Commit a8173f3

fix get sub model (#733) (#746)
1 parent 0642d9a commit a8173f3

File tree

6 files changed: +111 −66 lines changed

paddleslim/nas/ofa/get_sub_model.py

+71 −26
```diff
@@ -37,12 +37,15 @@ def get_prune_params_config(graph, origin_model_config):
         ### TODO(ceci3):
         ### 1. fix config when this op is concat by graph.pre_ops(op)
         ### 2. add kernel_size in config
-        ### 3. add channel in config
         for inp in op.all_inputs():
             n_ops = graph.next_ops(op)
             if inp._var.name in origin_model_config.keys():
-                if 'expand_ratio' in origin_model_config[inp._var.name].keys():
-                    tmp = origin_model_config[inp._var.name]['expand_ratio']
+                if 'expand_ratio' in origin_model_config[
+                        inp._var.name] or 'channel' in origin_model_config[
+                            inp._var.name]:
+                    key = 'channel' if 'channel' in origin_model_config[
+                        inp._var.name] else 'expand_ratio'
+                    tmp = origin_model_config[inp._var.name][key]
                     if len(inp._var.shape) > 1:
                         if inp._var.name in param_config.keys():
                             param_config[inp._var.name].append(tmp)
@@ -59,9 +62,13 @@ def get_prune_params_config(graph, origin_model_config):
                     if next_inp._var.persistable == True:
                         if next_inp._var.name in origin_model_config.keys():
                             if 'expand_ratio' in origin_model_config[
-                                    next_inp._var.name].keys():
+                                    next_inp._var.
+                                    name] or 'channel' in origin_model_config[
+                                        next_inp._var.name]:
+                                key = 'channel' if 'channel' in origin_model_config[
+                                    next_inp._var.name] else 'expand_ratio'
                                 tmp = origin_model_config[next_inp._var.name][
-                                    'expand_ratio']
+                                    key]
                                 pre = tmp if precedor is None else precedor
                                 if len(next_inp._var.shape) > 1:
                                     param_config[next_inp._var.name] = [pre]
```
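These first two hunks teach the config walker to accept a `'channel'` key alongside `'expand_ratio'`, preferring `'channel'` when both are present. A hypothetical per-parameter config showing both forms (parameter names made up):

```python
origin_model_config = {
    'conv2d_0.w_0': {'expand_ratio': 0.5},  # scale the original width by 0.5
    'conv2d_1.w_0': {'channel': 24},        # pin an exact channel count
}
```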
```diff
@@ -78,9 +85,19 @@ def get_prune_params_config(graph, origin_model_config):
     return param_config


+def get_actual_shape(transform, channel):
+    if transform == None:
+        channel = int(channel)
+    else:
+        if isinstance(transform, float):
+            channel = int(channel * transform)
+        else:
+            channel = int(transform)
+    return channel
+
+
 def prune_params(model, param_config, super_model_sd=None):
     """ Prune parameters according to the config.
-
     Parameters:
         model(paddle.nn.Layer): instance of model.
         param_config(dict): prune config of each weight.
```
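`get_actual_shape` centralizes the channel arithmetic that the next hunk deletes at three call sites: `None` keeps the original width, a float acts as an `expand_ratio` multiplier, and any other value (the int from a `'channel'` entry) is used verbatim. A quick check of the contract, using the helper defined above (example numbers made up):

```python
assert get_actual_shape(None, 64) == 64  # no transform: keep 64 channels
assert get_actual_shape(0.5, 64) == 32   # float: expand_ratio-style scaling
assert get_actual_shape(48, 64) == 48    # int: explicit 'channel' count
```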
```diff
@@ -104,25 +121,18 @@ def prune_params(model, param_config, super_model_sd=None):
                     in_exp = param_config[param.name][0]
                     out_exp = param_config[param.name][1]
                     if sublayer.__class__.__name__.lower() in CONV_TYPES:
-                        in_chn = int(value.shape[1]) if in_exp == None else int(
-                            value.shape[1] * in_exp)
-                        out_chn = int(value.shape[
-                            0]) if out_exp == None else int(value.shape[0] *
-                                                            out_exp)
+                        in_chn = get_actual_shape(in_exp, value.shape[1])
+                        out_chn = get_actual_shape(out_exp, value.shape[0])
                         prune_value = super_value[:out_chn, :in_chn, ...] \
                             if super_model_sd != None else value[:out_chn, :in_chn, ...]
                     else:
-                        in_chn = int(value.shape[0]) if in_exp == None else int(
-                            value.shape[0] * in_exp)
-                        out_chn = int(value.shape[
-                            1]) if out_exp == None else int(value.shape[1] *
-                                                            out_exp)
+                        in_chn = get_actual_shape(in_exp, value.shape[0])
+                        out_chn = get_actual_shape(out_exp, value.shape[1])
                         prune_value = super_value[:in_chn, :out_chn, ...] \
                             if super_model_sd != None else value[:in_chn, :out_chn, ...]
                 else:
-                    out_chn = int(value.shape[0]) if param_config[param.name][
-                        0] == None else int(value.shape[0] *
-                                            param_config[param.name][0])
+                    out_chn = get_actual_shape(param_config[param.name][0],
+                                               value.shape[0])
                     prune_value = super_value[:out_chn, ...] \
                         if super_model_sd != None else value[:out_chn, ...]

```
```diff
@@ -140,23 +150,24 @@ def prune_params(model, param_config, super_model_sd=None):
             if param.trainable:
                 param.clear_gradient()

-    ### initialize param which not in sublayers, such as create persistable inputs by create_parameters
+    ### initialize param which not in sublayers, such as create persistable inputs by create_parameters
     if super_model_sd != None and len(super_model_sd) != 0:
         for k, v in super_model_sd.items():
             setattr(model, k, v)


 def _is_depthwise(op):
-    """Check if this op is depthwise conv.
+    """Check if this op is depthwise conv. Only Cin == Cout == groups be consider as depthwise conv.
     The shape of input and the shape of output in depthwise conv must be same in superlayer,
     so depthwise op cannot be consider as weight op
     """
-    if op.type() == 'depthwise_conv':
-        return True
-    elif 'conv' in op.type():
+    #if op.type() == 'depthwise_conv2d': ### depthwise_conv2d in paddle is Cout % Cin =0
+    #    return True
+    if 'conv' in op.type():
         for inp in op.all_inputs():
-            if not inp._var.persistable and op.attr('groups') == inp._var.shape[
-                    1]:
+            if inp._var.persistable and (
+                    op.attr('groups') == inp._var.shape[0] and
+                    op.attr('groups') * inp._var.shape[1] == inp._var.shape[0]):
                 return True
     return False

```
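The old predicate fired on the non-persistable activation input and compared `groups` with `shape[1]`, so grouped convs could be misclassified. The new one inspects the persistable weight, whose Paddle layout is `[Cout, Cin/groups, kh, kw]`, and accepts only `groups == Cout` with `groups * (Cin/groups) == Cout`, i.e. `Cin == Cout == groups`. A shape-only restatement (hypothetical helper, for intuition only):

```python
def looks_depthwise(groups, weight_shape):
    # weight_shape follows Paddle's conv layout [Cout, Cin/groups, kh, kw]
    cout, cin_per_group = weight_shape[0], weight_shape[1]
    return groups == cout and groups * cin_per_group == cout

assert looks_depthwise(32, [32, 1, 3, 3])       # Cin == Cout == groups
assert not looks_depthwise(32, [64, 2, 3, 3])   # grouped, but Cout != groups
assert not looks_depthwise(1, [64, 32, 3, 3])   # ordinary conv
```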

```diff
@@ -179,6 +190,7 @@ def _find_weight_ops(op, graph, weights):
                     weights.append(inp._var.name)
             return weights
         return _find_weight_ops(pre_op, graph, weights)
+    return weights


 def _find_pre_elementwise_add(op, graph):
```
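The appended `return weights` closes a fall-through path: when `graph.pre_ops(op)` is empty the loop body never runs, the function previously returned an implicit `None`, and everything collected so far was lost. A toy reproduction of the pattern (hypothetical code, not PaddleSlim's):

```python
def walk_buggy(parents, acc):
    for p in parents:
        acc.append(p)
        return walk_buggy([], acc)
    # falls off the end here -> implicit None, `acc` is dropped

def walk_fixed(parents, acc):
    for p in parents:
        acc.append(p)
        return walk_fixed([], acc)
    return acc  # the one-line fix: always propagate the accumulator

assert walk_buggy([], ['w0']) is None
assert walk_fixed([], ['w0']) == ['w0']
```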
```diff
@@ -236,3 +248,36 @@ def check_search_space(graph):
     depthwise_conv = sorted(depthwise_conv)

     return (final_search_space, depthwise_conv)
+
+
+def broadcast_search_space(same_search_space, param2key, origin_config):
+    """
+    Inplace broadcast the origin_config according to the same search space. Such as: same_search_space = [['conv1_weight', 'conv3_weight']], param2key = {'conv1_weight': 'conv1.conv', 'conv3_weight': 'conv3.weight'}, origin_config= {'conv1.weight': {'channel': 10}, 'conv2.weight': {'channel': 20}}, the result after this function is origin_config={'conv1.weight': {'channel': 10}, 'conv2.weight': {'channel': 20}, 'conv3.weight': {'channel': 10}}
+
+    Args:
+        same_search_space(list<list>): broadcast according this list, each list in same_search_space means the channel must be consistent.
+        param2key(dict): the name of layers corresponds to the name of parameter.
+        origin_config(dict): the search space which can be searched.
+    """
+    for per_ss in same_search_space:
+        for ss in per_ss[1:]:
+            key = param2key[ss]
+            pre_key = param2key[per_ss[0]]
+            if key in origin_config:
+                if 'expand_ratio' in origin_config[pre_key]:
+                    origin_config[key].update({
+                        'expand_ratio': origin_config[pre_key]['expand_ratio']
+                    })
+                elif 'channel' in origin_config[pre_key]:
+                    origin_config[key].update({
+                        'channel': origin_config[pre_key]['channel']
+                    })
+            else:
+                if 'expand_ratio' in origin_config[pre_key]:
+                    origin_config[key] = {
+                        'expand_ratio': origin_config[pre_key]['expand_ratio']
+                    }
+                elif 'channel' in origin_config[pre_key]:
+                    origin_config[key] = {
+                        'channel': origin_config[pre_key]['channel']
+                    }
```
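`broadcast_search_space` mutates `origin_config` in place, copying the group leader's `'expand_ratio'` or `'channel'` entry onto every other member of a tied group (the docstring's `'conv1.conv'` appears to be a typo for `'conv1.weight'`, which is what its example config actually uses). A usage sketch built from the docstring's own example:

```python
same_search_space = [['conv1_weight', 'conv3_weight']]
param2key = {'conv1_weight': 'conv1.weight', 'conv3_weight': 'conv3.weight'}
config = {'conv1.weight': {'channel': 10}, 'conv2.weight': {'channel': 20}}

broadcast_search_space(same_search_space, param2key, config)

assert config['conv3.weight'] == {'channel': 10}  # inherited from conv1
assert config['conv2.weight'] == {'channel': 20}  # untied, untouched
```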

paddleslim/nas/ofa/ofa.py

+20 −36
```diff
@@ -31,7 +31,7 @@
 from .utils.utils import search_idx
 from ...common import get_logger
 from ...core import GraphWrapper, dygraph2program
-from .get_sub_model import get_prune_params_config, prune_params, check_search_space
+from .get_sub_model import get_prune_params_config, prune_params, check_search_space, broadcast_search_space

 _logger = get_logger(__name__, level=logging.INFO)

@@ -156,7 +156,6 @@ class OFA(OFABase):
            sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
            sp_model = Convert(sp_net_config).convert(model)
            ofa_model = OFA(sp_model)
-
        """

    def __init__(self,
```
```diff
@@ -461,6 +460,23 @@ def search(self, eval_func, condition):

    def _export_sub_model_config(self, origin_model, config, input_shapes,
                                 input_dtypes):
+        param2name = {}
+        for name, sublayer in origin_model.named_sublayers():
+            for param in sublayer.parameters(include_sublayers=False):
+                if name.split('.')[-1] == 'fn':
+                    ### if sublayer is Block, the name of the param.name has 'fn', the config always donnot have 'fn'
+                    param2name[param.name] = name[:-3]
+                else:
+                    param2name[param.name] = name
+
+        program = dygraph2program(
+            origin_model, inputs=input_shapes, dtypes=input_dtypes)
+        graph = GraphWrapper(program)
+
+        same_config, _ = check_search_space(graph)
+        if same_config != None:
+            broadcast_search_space(same_config, param2name, config)
+
        origin_model_config = {}
        for name, sublayer in origin_model.named_sublayers():
            if isinstance(sublayer, BaseBlock):
```
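Tracing the graph and consulting `check_search_space` now happens inside `_export_sub_model_config` itself, before the per-parameter config is assembled, so an export config that names only one member of a tied group is broadcast to the rest ahead of pruning; `param2name` bridges parameter names back to sublayer names, trimming the `.fn` suffix that OFA's `Block` wrapper introduces. A sketch of the intended effect (layer names hypothetical):

```python
# 'conv1' and 'conv3' feed the same elementwise_add, so they form one group.
config = {'conv1': {'channel': 10}}
# After check_search_space + broadcast_search_space run inside
# _export_sub_model_config, the config is effectively:
#     {'conv1': {'channel': 10}, 'conv3': {'channel': 10}}
```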
```diff
@@ -469,9 +485,6 @@ def _export_sub_model_config(self, origin_model, config, input_shapes,
                if name in config.keys():
                    origin_model_config[param.name] = config[name]

-        program = dygraph2program(
-            origin_model, inputs=input_shapes, dtypes=input_dtypes)
-        graph = GraphWrapper(program)
        param_prune_config = get_prune_params_config(graph, origin_model_config)
        return param_prune_config

@@ -493,7 +506,6 @@ def export(self,
            .. code-block:: python
              from paddle.vision.models import mobilenet_v1
              origin_model = mobilenet_v1()
-
              config = {'conv2d_0': {'expand_ratio': 2}, 'conv2d_1': {'expand_ratio': 2}}
              origin_model = ofa_model.export(origin_model, config, input_shapes=[1, 3, 28, 28], input_dtypes=['float32'])
        """
@@ -505,7 +517,6 @@ def export(self,
            origin_model = self.model
        origin_model = origin_model._layers if isinstance(
            origin_model, DataParallel) else origin_model
-
        param_config = self._export_sub_model_config(origin_model, config,
                                                     input_shapes, input_dtypes)
        prune_params(origin_model, param_config, super_sd)
```
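For reference, a hedged end-to-end call mirroring the new test added below; the model, layer names, and values are made up, and `'channel'` entries are now accepted alongside `'expand_ratio'` ones:

```python
config = {'conv2d_0': {'channel': 16}, 'conv2d_1': {'expand_ratio': 0.5}}
sub_model = ofa_model.export(
    config, input_shapes=[[1, 3, 28, 28]], input_dtypes=['float32'])
```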
```diff
@@ -602,7 +613,6 @@ def _clear_search_space(self, *inputs, **kwargs):
                    per_ss.append(key)
                else:
                    _logger.info("{} not in ss".format(key))
-
            if len(per_ss) != 0:
                tmp_same_ss.append(per_ss)

@@ -626,33 +636,6 @@ def _clear_search_space(self, *inputs, **kwargs):
            ):
                self._clear_width(name)

-    def _broadcast_ss(self):
-        """ broadcast search space after random sample."""
-        for per_ss in self._same_ss:
-            for ss in per_ss[1:]:
-                key = self._param2key[ss]
-                pre_key = self._param2key[per_ss[0]]
-                if key in self.current_config:
-                    if 'expand_ratio' in self.current_config[pre_key]:
-                        self.current_config[key].update({
-                            'expand_ratio':
-                            self.current_config[pre_key]['expand_ratio']
-                        })
-                    elif 'channel' in self.current_config[pre_key]:
-                        self.current_config[key].update({
-                            'channel': self.current_config[pre_key]['channel']
-                        })
-                else:
-                    if 'expand_ratio' in self.current_config[pre_key]:
-                        self.current_config[key] = {
-                            'expand_ratio':
-                            self.current_config[pre_key]['expand_ratio']
-                        }
-                    elif 'channel' in self.current_config[pre_key]:
-                        self.current_config[key] = {
-                            'channel': self.current_config[pre_key]['channel']
-                        }
-
    def forward(self, *inputs, **kwargs):
        # ===================== teacher process =====================
        teacher_output = None
@@ -692,7 +675,8 @@ def forward(self, *inputs, **kwargs):
            kwargs['depth'] = self.current_config['depth']

        if self._broadcast:
-            self._broadcast_ss()
+            broadcast_search_space(self._same_ss, self._param2key,
+                                   self.current_config)

        student_output = self.model.forward(*inputs, **kwargs)

```
paddleslim/nas/one_shot/one_shot_nas.py

+1 −1
```diff
@@ -34,7 +34,7 @@ def OneShotSearch(model, eval_func, strategy='sa', search_steps=100):
        list<int>: The best tokens searched.
    """
    super_net = None
-    for layer in model.sublayers(include_sublayers=False):
+    for layer in model.sublayers(include_self=True):
        print("layer: {}".format(layer))
        if isinstance(layer, OneShotSuperNet):
            super_net = layer
```
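This tracks Paddle 2.x, where `paddle.nn.Layer.sublayers` takes `include_self` rather than the old `include_sublayers` keyword; with `include_self=True` the layer itself is yielded too, so `OneShotSearch` also finds the supernet when `model` *is* an `OneShotSuperNet` instead of merely containing one. A minimal check (assuming the Paddle 2.x API):

```python
import paddle.nn as nn

class Net(nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

net = Net()
assert net not in net.sublayers()               # default: children only
assert net in net.sublayers(include_self=True)  # the root is searched too
```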

paddleslim/teachers/bert/model/transformer_encoder.py

+2 −3
```diff
@@ -37,14 +37,13 @@ def __init__(self, process_cmd, d_model, dropout_rate, name):

        for cmd in self.process_cmd:
            if cmd == "a":  # add residual connection
-                self.functors.append(
-                    lambda x, y: x + y if y is not None else x)
+                self.functors.append(lambda x, y: x + y if y is not None else x)
                self.exec_order += "a"
            elif cmd == "n":  # add layer normalization
                self.functors.append(
                    self.add_sublayer(
                        "layer_norm_%d" % len(
-                            self.sublayers(include_sublayers=False)),
+                            self.sublayers(include_self=True)),
                        LayerNorm(
                            normalized_shape=d_model,
                            param_attr=fluid.ParamAttr(
```

tests/test_ofa.py

+16
```diff
@@ -449,5 +449,21 @@ def test_export_model(self):
        assert len(self.ofa_model.ofa_layers) == 38


+class TestExportCase1(unittest.TestCase):
+    def setUp(self):
+        model = ModelLinear1()
+        data_np = np.random.random((3, 64)).astype(np.int64)
+        self.data = paddle.to_tensor(data_np)
+        self.ofa_model = OFA(model)
+        self.ofa_model.set_epoch(0)
+        outs, _ = self.ofa_model(self.data)
+        self.config = self.ofa_model.current_config
+
+    def test_export_model(self):
+        self.ofa_model.export(
+            self.config, input_shapes=[[3, 64]], input_dtypes=['int64'])
+        assert len(self.ofa_model.ofa_layers) == 4
+
+
 if __name__ == '__main__':
     unittest.main()
```

tests/test_ofa_v2.py

+1
```diff
@@ -77,6 +77,7 @@ def forward(self, x):
        y = x + y
        z = self.branch2(y)
        z = z + y
+        z = self.out(z)
        return z

```