Skip to content

Commit 5e33ca5

Browse files
committed
Adds Vit tracker option
1 parent 0658d11 commit 5e33ca5

File tree

3 files changed

+14
-2
lines changed

3 files changed

+14
-2
lines changed

README.md

+5-1
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@ Python版OpenCVのTracking APIのサンプルです。<br>
77
* opencv-contrib-python 4.8.0.74 or later
88

99
# Algorithm
10-
2023/07/25時点でOpenCVには以下10アルゴリズムが実装されています
10+
2024/01/18時点でOpenCVには以下11アルゴリズムが実装されています
1111
* DaSiamRPN
1212
* NanoTrack
13+
* Vit
1314
* MIL
1415
* GOTURN
1516
* CSRT
@@ -68,6 +69,9 @@ DaSiamRPNトラッカーの使用有無<br>
6869
* --use_nano<br>
6970
NanoTrackの使用有無<br>
7071
デフォルト:指定なし
72+
* --use_vit<br>
73+
Vitの使用有無<br>
74+
デフォルト:指定なし
7175
* --use_csrt<br>
7276
CSRTトラッカーの使用有無<br>
7377
デフォルト:指定なし
Binary file not shown.

performance_comparison_sample.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def get_args():
2626
parser.add_argument('--use_medianflow', action='store_true')
2727
parser.add_argument('--use_tld', action='store_true')
2828
parser.add_argument('--use_nano', action='store_true')
29+
parser.add_argument('--use_vit', action='store_true')
2930

3031
args = parser.parse_args()
3132

@@ -63,6 +64,10 @@ def initialize_tracker_list(window_name, image, tracker_algorithm_list):
6364
# params.backbone = "model/nanotrackv3/nanotrack_backbone_sim.onnx"
6465
# params.neckhead = "model/nanotrackv3/nanotrack_head_sim.onnx"
6566
tracker = cv.TrackerNano_create(params)
67+
if tracker_algorithm == 'Vit':
68+
params = cv.TrackerVit_Params()
69+
params.net = "model/vit/object_tracking_vittrack_2023sep.onnx"
70+
tracker = cv.TrackerVit_create(params)
6671
if tracker_algorithm == 'CSRT':
6772
tracker = cv.TrackerCSRT_create()
6873
if tracker_algorithm == 'KCF':
@@ -125,6 +130,7 @@ def main():
125130
use_medianflow = args.use_medianflow
126131
use_tld = args.use_tld
127132
use_nano = args.use_nano
133+
use_vit = args.use_vit
128134

129135
# 使用アルゴリズム #########################################################
130136
tracker_algorithm_list = []
@@ -148,7 +154,9 @@ def main():
148154
tracker_algorithm_list.append('TLD')
149155
if use_nano:
150156
tracker_algorithm_list.append('Nano')
151-
157+
if use_vit:
158+
tracker_algorithm_list.append('Vit')
159+
152160
if len(tracker_algorithm_list) == 0:
153161
tracker_algorithm_list.append('DaSiamRPN')
154162
print(tracker_algorithm_list)

0 commit comments

Comments
 (0)