5
5
from mmengine import Config
6
6
7
7
from mmdeploy .codebase import import_codebase
8
+ from mmdeploy .core .rewriters .rewriter_manager import RewriterContext
8
9
from mmdeploy .utils import Backend , Codebase
9
10
from mmdeploy .utils .test import WrapModel , check_backend , get_rewrite_outputs
10
11
@@ -62,7 +63,6 @@ def test_baseclassifier_forward():
62
63
from mmcls .models .classifiers import ImageClassifier
63
64
64
65
from mmdeploy .codebase .mmcls import models # noqa
65
- from mmdeploy .core .rewriters import patch_model
66
66
67
67
class DummyClassifier (ImageClassifier ):
68
68
@@ -75,8 +75,8 @@ def extract_feat(self, batch_inputs: torch.Tensor):
75
75
def head (self , x ):
76
76
return x
77
77
78
- def forward (self , batch_inputs , data_samples , mode ):
79
- return batch_inputs + 1
78
+ def predict (self , x , data_samples ):
79
+ return x
80
80
81
81
backbone_cfg = dict (
82
82
type = 'ResNet' ,
@@ -86,11 +86,12 @@ def forward(self, batch_inputs, data_samples, mode):
86
86
style = 'pytorch' )
87
87
model = DummyClassifier (backbone_cfg ).eval ()
88
88
89
- model_output = model (input , None , None )
90
- model = patch_model (model , {}, bachend = 'onnxruntime' , data_samples = None )
91
- backend_output = model (input )
89
+ model_output = model (input , None , mode = 'predict' )
90
+
91
+ with RewriterContext ({}):
92
+ backend_output = model (input )
92
93
93
- assert model_output == input + 1
94
+ assert model_output == input
94
95
assert backend_output == input
95
96
96
97
@@ -186,7 +187,7 @@ def test_vision_transformer_backbone__forward(backend_type: Backend):
186
187
model_output .reshape (- 1 ),
187
188
rewrite_output .reshape (- 1 ),
188
189
rtol = 1e-03 ,
189
- atol = 1e-03 )
190
+ atol = 1e-02 )
190
191
191
192
192
193
@pytest .mark .parametrize (
0 commit comments