@@ -37,12 +37,15 @@ def get_prune_params_config(graph, origin_model_config):
         ### TODO(ceci3):
         ### 1. fix config when this op is concat by graph.pre_ops(op)
         ### 2. add kernel_size in config
-        ### 3. add channel in config
         for inp in op.all_inputs():
             n_ops = graph.next_ops(op)
             if inp._var.name in origin_model_config.keys():
-                if 'expand_ratio' in origin_model_config[inp._var.name].keys():
-                    tmp = origin_model_config[inp._var.name]['expand_ratio']
+                if 'expand_ratio' in origin_model_config[
+                        inp._var.name] or 'channel' in origin_model_config[
+                            inp._var.name]:
+                    key = 'channel' if 'channel' in origin_model_config[
+                        inp._var.name] else 'expand_ratio'
+                    tmp = origin_model_config[inp._var.name][key]
                     if len(inp._var.shape) > 1:
                         if inp._var.name in param_config.keys():
                             param_config[inp._var.name].append(tmp)
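A minimal sketch of the new key selection: when a parameter's entry carries both keys, 'channel' takes precedence over 'expand_ratio'. The config dicts below are hypothetical, not taken from the patch:

# Hypothetical per-parameter configs; 'channel' wins when both keys are present.
origin_model_config = {
    'conv1.weight': {'expand_ratio': 0.5},
    'conv2.weight': {'channel': 16, 'expand_ratio': 0.5},
}
for name, cfg in origin_model_config.items():
    key = 'channel' if 'channel' in cfg else 'expand_ratio'
    print(name, key, cfg[key])  # conv1.weight expand_ratio 0.5 / conv2.weight channel 16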
@@ -59,9 +62,13 @@ def get_prune_params_config(graph, origin_model_config):
                         if next_inp._var.persistable == True:
                             if next_inp._var.name in origin_model_config.keys():
                                 if 'expand_ratio' in origin_model_config[
-                                        next_inp._var.name].keys():
+                                        next_inp._var.
+                                        name] or 'channel' in origin_model_config[
+                                            next_inp._var.name]:
+                                    key = 'channel' if 'channel' in origin_model_config[
+                                        next_inp._var.name] else 'expand_ratio'
                                     tmp = origin_model_config[next_inp._var.name][
-                                        'expand_ratio']
+                                        key]
                                     pre = tmp if precedor is None else precedor
                                     if len(next_inp._var.shape) > 1:
                                         param_config[next_inp._var.name] = [pre]
@@ -78,9 +85,19 @@ def get_prune_params_config(graph, origin_model_config):
     return param_config


+def get_actual_shape(transform, channel):
+    if transform == None:
+        channel = int(channel)
+    else:
+        if isinstance(transform, float):
+            channel = int(channel * transform)
+        else:
+            channel = int(transform)
+    return channel
+
+
 def prune_params(model, param_config, super_model_sd=None):
     """ Prune parameters according to the config.
-
     Parameters:
         model(paddle.nn.Layer): instance of model.
         param_config(dict): prune config of each weight.
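A quick sketch of what the new get_actual_shape helper returns for the three kinds of transforms (None keeps the channel, a float acts as an expand ratio, an int overrides the channel count); the numbers are illustrative only:

assert get_actual_shape(None, 64) == 64
assert get_actual_shape(0.5, 64) == 32   # expand_ratio-style transform
assert get_actual_shape(16, 64) == 16    # channel-style transform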
@@ -104,25 +121,18 @@ def prune_params(model, param_config, super_model_sd=None):
                     in_exp = param_config[param.name][0]
                     out_exp = param_config[param.name][1]
                     if sublayer.__class__.__name__.lower() in CONV_TYPES:
-                        in_chn = int(value.shape[1]) if in_exp == None else int(
-                            value.shape[1] * in_exp)
-                        out_chn = int(value.shape[
-                            0]) if out_exp == None else int(value.shape[0] *
-                                                            out_exp)
+                        in_chn = get_actual_shape(in_exp, value.shape[1])
+                        out_chn = get_actual_shape(out_exp, value.shape[0])
                         prune_value = super_value[:out_chn, :in_chn, ...] \
                                 if super_model_sd != None else value[:out_chn, :in_chn, ...]
                     else:
-                        in_chn = int(value.shape[0]) if in_exp == None else int(
-                            value.shape[0] * in_exp)
-                        out_chn = int(value.shape[
-                            1]) if out_exp == None else int(value.shape[1] *
-                                                            out_exp)
+                        in_chn = get_actual_shape(in_exp, value.shape[0])
+                        out_chn = get_actual_shape(out_exp, value.shape[1])
                         prune_value = super_value[:in_chn, :out_chn, ...] \
                                 if super_model_sd != None else value[:in_chn, :out_chn, ...]
                 else:
-                    out_chn = int(value.shape[0]) if param_config[param.name][
-                        0] == None else int(value.shape[0] *
-                                            param_config[param.name][0])
+                    out_chn = get_actual_shape(param_config[param.name][0],
+                                               value.shape[0])
                     prune_value = super_value[:out_chn, ...] \
                             if super_model_sd != None else value[:out_chn, ...]
 
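To illustrate the slicing that prune_params performs on a conv weight, a small numpy sketch with made-up shapes, assuming the usual [out, in, kh, kw] weight layout that the conv branch above relies on:

import numpy as np

# Super-net conv weight with 8 output and 4 input channels, 3x3 kernel.
value = np.random.rand(8, 4, 3, 3).astype('float32')
in_chn = get_actual_shape(0.5, value.shape[1])   # -> 2
out_chn = get_actual_shape(4, value.shape[0])    # -> 4
prune_value = value[:out_chn, :in_chn, ...]      # shape (4, 2, 3, 3)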
@@ -140,23 +150,24 @@ def prune_params(model, param_config, super_model_sd=None):
                 if param.trainable:
                     param.clear_gradient()
 
-    ### initialize param which not in sublayers, such as create persistable inputs by create_parameters
+    ### initialize params which are not in sublayers, such as persistable inputs created by create_parameters
     if super_model_sd != None and len(super_model_sd) != 0:
         for k, v in super_model_sd.items():
             setattr(model, k, v)
 
 
 def _is_depthwise(op):
-    """Check if this op is depthwise conv.
+    """Check if this op is depthwise conv. Only Cin == Cout == groups is considered depthwise conv here.
     The shape of input and the shape of output in depthwise conv must be same in superlayer,
     so depthwise op cannot be consider as weight op
     """
-    if op.type() == 'depthwise_conv':
-        return True
-    elif 'conv' in op.type():
+    # if op.type() == 'depthwise_conv2d':  ### depthwise_conv2d in paddle only requires Cout % Cin == 0
+    #     return True
+    if 'conv' in op.type():
         for inp in op.all_inputs():
-            if not inp._var.persistable and op.attr('groups') == inp._var.shape[
-                1]:
+            if inp._var.persistable and (
+                    op.attr('groups') == inp._var.shape[0] and
+                    op.attr('groups') * inp._var.shape[1] == inp._var.shape[0]):
                 return True
     return False
 
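As a sanity check of the new depthwise condition (groups == Cout and groups * Cin_per_group == Cout, i.e. exactly one input channel per group), here is a standalone sketch on hypothetical weight shapes, not the patch's own helper:

def looks_depthwise(groups, weight_shape):
    # weight_shape is assumed to be [Cout, Cin // groups, kh, kw] for a conv weight
    return groups == weight_shape[0] and groups * weight_shape[1] == weight_shape[0]

print(looks_depthwise(8, [8, 1, 3, 3]))  # True: Cin == Cout == groups
print(looks_depthwise(4, [8, 2, 3, 3]))  # False: only Cout % Cin == 0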
@@ -179,6 +190,7 @@ def _find_weight_ops(op, graph, weights):
                     weights.append(inp._var.name)
             return weights
         return _find_weight_ops(pre_op, graph, weights)
+    return weights
 
 
 def _find_pre_elementwise_add(op, graph):
@@ -236,3 +248,36 @@ def check_search_space(graph):
     depthwise_conv = sorted(depthwise_conv)
 
     return (final_search_space, depthwise_conv)
+
+
+def broadcast_search_space(same_search_space, param2key, origin_config):
+    """
+    Inplace broadcast the origin_config according to the same search space. For example: same_search_space = [['conv1_weight', 'conv3_weight']], param2key = {'conv1_weight': 'conv1.weight', 'conv3_weight': 'conv3.weight'}, origin_config = {'conv1.weight': {'channel': 10}, 'conv2.weight': {'channel': 20}}; after this function, origin_config = {'conv1.weight': {'channel': 10}, 'conv2.weight': {'channel': 20}, 'conv3.weight': {'channel': 10}}.
+
+    Args:
+        same_search_space(list<list>): broadcast according to this list; the parameters in each inner list must keep consistent channels.
+        param2key(dict): mapping from each parameter name to the key of its layer.
+        origin_config(dict): the search space which can be searched.
+    """
+    for per_ss in same_search_space:
+        for ss in per_ss[1:]:
+            key = param2key[ss]
+            pre_key = param2key[per_ss[0]]
+            if key in origin_config:
+                if 'expand_ratio' in origin_config[pre_key]:
+                    origin_config[key].update({
+                        'expand_ratio': origin_config[pre_key]['expand_ratio']
+                    })
+                elif 'channel' in origin_config[pre_key]:
+                    origin_config[key].update({
+                        'channel': origin_config[pre_key]['channel']
+                    })
+            else:
+                if 'expand_ratio' in origin_config[pre_key]:
+                    origin_config[key] = {
+                        'expand_ratio': origin_config[pre_key]['expand_ratio']
+                    }
+                elif 'channel' in origin_config[pre_key]:
+                    origin_config[key] = {
+                        'channel': origin_config[pre_key]['channel']
+                    }
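A usage sketch of broadcast_search_space, reusing the dictionaries from the docstring example above; the names are illustrative only:

same_search_space = [['conv1_weight', 'conv3_weight']]
param2key = {'conv1_weight': 'conv1.weight', 'conv3_weight': 'conv3.weight'}
origin_config = {'conv1.weight': {'channel': 10}, 'conv2.weight': {'channel': 20}}

broadcast_search_space(same_search_space, param2key, origin_config)
# origin_config now also contains {'conv3.weight': {'channel': 10}}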