Skip to content

Commit 27a8566

Browse files
author
grimoire
committed
fix mmcls ut
1 parent ce036d5 commit 27a8566

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

Diff for: tests/test_codebase/test_mmcls/test_classification.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_visualize(backend_model):
9898
results = backend_model.test_step([input_dict])
9999
with TemporaryDirectory() as dir:
100100
filename = dir + '/tmp.jpg'
101-
task_processor.visualize(img, results[0], filename, '')
101+
task_processor.visualize(img, results[0], filename, 'window')
102102
assert os.path.exists(filename)
103103

104104

Diff for: tests/test_codebase/test_mmcls/test_mmcls_models.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from mmengine import Config
66

77
from mmdeploy.codebase import import_codebase
8+
from mmdeploy.core.rewriters.rewriter_manager import RewriterContext
89
from mmdeploy.utils import Backend, Codebase
910
from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs
1011

@@ -62,7 +63,6 @@ def test_baseclassifier_forward():
6263
from mmcls.models.classifiers import ImageClassifier
6364

6465
from mmdeploy.codebase.mmcls import models # noqa
65-
from mmdeploy.core.rewriters import patch_model
6666

6767
class DummyClassifier(ImageClassifier):
6868

@@ -75,8 +75,8 @@ def extract_feat(self, batch_inputs: torch.Tensor):
7575
def head(self, x):
7676
return x
7777

78-
def forward(self, batch_inputs, data_samples, mode):
79-
return batch_inputs + 1
78+
def predict(self, x, data_samples):
79+
return x
8080

8181
backbone_cfg = dict(
8282
type='ResNet',
@@ -86,11 +86,12 @@ def forward(self, batch_inputs, data_samples, mode):
8686
style='pytorch')
8787
model = DummyClassifier(backbone_cfg).eval()
8888

89-
model_output = model(input, None, None)
90-
model = patch_model(model, {}, bachend='onnxruntime', data_samples=None)
91-
backend_output = model(input)
89+
model_output = model(input, None, mode='predict')
90+
91+
with RewriterContext({}):
92+
backend_output = model(input)
9293

93-
assert model_output == input + 1
94+
assert model_output == input
9495
assert backend_output == input
9596

9697

@@ -186,7 +187,7 @@ def test_vision_transformer_backbone__forward(backend_type: Backend):
186187
model_output.reshape(-1),
187188
rewrite_output.reshape(-1),
188189
rtol=1e-03,
189-
atol=1e-03)
190+
atol=1e-02)
190191

191192

192193
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)