Skip to content

Commit 6903233

Browse files
authored
[Enhancement] support all matting models inferencer (open-mmlab#1836)
1 parent 50e2b9b commit 6903233

File tree

3 files changed

+8
-14
lines changed

3 files changed

+8
-14
lines changed

demo/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,9 @@ python mmagic_inference_demo.py \
124124

125125
```shell
126126
python mmagic_inference_demo.py \
127-
--model-name global_local \
127+
--model-name gca \
128128
--img ../resources/input/matting/GT05.jpg \
129-
--mask ../resources/input/matting/GT05_trimap.jpg \
129+
--trimap ../resources/input/matting/GT05_trimap.jpg \
130130
--result-out-dir ../resources/output/matting/demo_matting_gca_res.png
131131
```
132132

mmagic/apis/inferencers/matting_inferencer.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import torch
88
from mmengine import mkdir_or_exist
99
from mmengine.dataset import Compose
10-
from mmengine.dataset.utils import default_collate as collate
11-
from torch.nn.parallel import scatter
1210

1311
from mmagic.structures import DataSample
1412
from .base_mmagic_inferencer import BaseMMagicInferencer, InputsType, PredType
@@ -53,21 +51,15 @@ def preprocess(self, img: InputsType, trimap: InputsType) -> Dict:
5351
_data = test_pipeline(data)
5452
trimap = _data['data_samples'].trimap.data
5553
preprocess_res = dict()
56-
preprocess_res['inputs'] = torch.cat([_data['inputs'], trimap],
57-
dim=0).float()
58-
preprocess_res = collate([preprocess_res])
59-
preprocess_res['data_samples'] = DataSample.stack(
60-
[_data['data_samples']])
61-
preprocess_res['mode'] = 'predict'
62-
if 'cuda' in str(self.device):
63-
preprocess_res = scatter(preprocess_res, [self.device])[0]
64-
54+
preprocess_res['inputs'] = [_data['inputs']]
55+
preprocess_res['data_samples'] = [_data['data_samples']]
6556
return preprocess_res
6657

6758
def forward(self, inputs: InputsType) -> PredType:
6859
"""Forward the inputs to the model."""
60+
inputs = self.model.data_preprocessor(inputs)
6961
with torch.no_grad():
70-
return self.model(**inputs)
62+
return self.model(mode='predict', **inputs)
7163

7264
def visualize(self,
7365
preds: PredType,

mmagic/apis/mmagic_inferencer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ class MMagicInferencer:
6363
'styleganv3',
6464

6565
# matting models
66+
'dim',
67+
'indexnet',
6668
'gca',
6769

6870
# inpainting models

0 commit comments

Comments
 (0)