@@ -26,6 +26,7 @@ def get_args():
26
26
parser .add_argument ('--use_medianflow' , action = 'store_true' )
27
27
parser .add_argument ('--use_tld' , action = 'store_true' )
28
28
parser .add_argument ('--use_nano' , action = 'store_true' )
29
+ parser .add_argument ('--use_vit' , action = 'store_true' )
29
30
30
31
args = parser .parse_args ()
31
32
@@ -63,6 +64,10 @@ def initialize_tracker_list(window_name, image, tracker_algorithm_list):
63
64
# params.backbone = "model/nanotrackv3/nanotrack_backbone_sim.onnx"
64
65
# params.neckhead = "model/nanotrackv3/nanotrack_head_sim.onnx"
65
66
tracker = cv .TrackerNano_create (params )
67
+ if tracker_algorithm == 'Vit' :
68
+ params = cv .TrackerVit_Params ()
69
+ params .net = "model/vit/object_tracking_vittrack_2023sep.onnx"
70
+ tracker = cv .TrackerVit_create (params )
66
71
if tracker_algorithm == 'CSRT' :
67
72
tracker = cv .TrackerCSRT_create ()
68
73
if tracker_algorithm == 'KCF' :
@@ -125,6 +130,7 @@ def main():
125
130
use_medianflow = args .use_medianflow
126
131
use_tld = args .use_tld
127
132
use_nano = args .use_nano
133
+ use_vit = args .use_vit
128
134
129
135
# 使用アルゴリズム #########################################################
130
136
tracker_algorithm_list = []
@@ -148,7 +154,9 @@ def main():
148
154
tracker_algorithm_list .append ('TLD' )
149
155
if use_nano :
150
156
tracker_algorithm_list .append ('Nano' )
151
-
157
+ if use_vit :
158
+ tracker_algorithm_list .append ('Vit' )
159
+
152
160
if len (tracker_algorithm_list ) == 0 :
153
161
tracker_algorithm_list .append ('DaSiamRPN' )
154
162
print (tracker_algorithm_list )
0 commit comments