Skip to content

Commit 77d6ac4

Browse files
eladebanmn-robot
authored and committed
Support conv2d_transpose for structure exporting.
PiperOrigin-RevId: 249389711
1 parent b9cc6d8 commit 77d6ac4

File tree

2 files changed

+183
-158
lines changed

2 files changed

+183
-158
lines changed

morph_net/tools/structure_exporter.py

+62-36
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
"""Helper module for calculating the live activation counts."""
2-
1+
"""Helper module for calculating and saving learned structures."""
32
from __future__ import absolute_import
43
from __future__ import division
54
# [internal] enable type annotations
@@ -12,7 +11,7 @@
1211
import tensorflow as tf
1312
from typing import Text, Sequence, Dict, Optional, IO, Iterable, Callable
1413

15-
_SUPPORTED_OPS = ['Conv2D']
14+
_SUPPORTED_OPS = ['Conv2D', 'Conv2DBackpropInput']
1615
_ALIVE_FILENAME = 'alive'
1716

1817

@@ -25,8 +24,8 @@ def compute_alive_counts(
2524
"""Computes alive counts.
2625
2726
Args:
28-
alive_vectors: A mapping from op_name to a vector where each element says
29-
whether the corresponding output activation is alive.
27+
alive_vectors: A mapping from op_name to a vector where each element
28+
indicates whether the corresponding output activation is alive.
3029
3130
Returns:
3231
Mapping from op_name to the number of its alive output activations.
@@ -41,8 +40,21 @@ class StructureExporter(object):
4140
"""Reports statistics about the current state of regularization.
4241
4342
Obtains live activation counts for supported ops: a map from each op name
44-
to its count of alive activations (filters). Optionally, thresholds the counts
45-
so that very low counts are reported as 0. Currently, only supports Conv2D.
43+
to its count of alive activations (filters).
44+
45+
Usage:
46+
1. Build model.
47+
`logits = build_model(params)`
48+
2. Create network regularizer.
49+
`network_regularizer = flop_regularizer.GammaFlopsRegularizer([logits.op])`
50+
3. Create StructureExporter:
51+
`exporter = StructureExporter(network_regularizer.op_regularizer_manager)`
52+
4. Gather tensors to eval:
53+
`tensor_to_eval_dict = exporter.tensors`
54+
5. Within a `tf.Session()` eval and populate tensors:
55+
`exporter.populate_tensor_values(tensor_to_eval_dict.eval())`
56+
6. Export structure:
57+
`exporter.save_alive_counts(tf.gfile.Open(...))`
4658
"""
4759

4860
def __init__(self,
@@ -57,38 +69,51 @@ def __init__(self,
5769
with the same prefix (up to and including the first '/'), and if so,
5870
skip that prefix in exported data.
5971
"""
60-
self._op_regularizer_manager = op_regularizer_manager
61-
self._alive_tensors = {} # type: Dict[Text, tf.Tensor]
72+
# TODO(p1): Consider deleting unused `remove_common_prefix` b/133261798.
73+
self._tensors = {} # type: Dict[Text, tf.Tensor]
6274
self._alive_vectors = None # type: Optional[Dict[Text, Sequence[bool]]]
75+
rename_fn = get_remove_common_prefix_fn(
76+
self._tensors) if remove_common_prefix else lambda x: x
6377

64-
for op in self._op_regularizer_manager.ops:
78+
for op in op_regularizer_manager.ops:
6579
if op.type not in _SUPPORTED_OPS:
6680
continue
67-
opreg = self._op_regularizer_manager.get_regularizer(op)
68-
if opreg:
69-
# TODO(p1): use bool here (no cast), and then convert later?
70-
self._alive_tensors[op.name] = tf.cast(opreg.alive_vector, tf.int32)
71-
else:
81+
82+
opreg = op_regularizer_manager.get_regularizer(op)
83+
if not opreg:
7284
tf.logging.warning('No regularizer found for: %s', op.name)
85+
continue
7386

74-
if remove_common_prefix:
75-
rename_op = get_remove_common_prefix_op(self._alive_tensors)
76-
self._alive_tensors = {
77-
rename_op(k): v for k, v in self._alive_tensors.items()
78-
}
87+
self._tensors[rename_fn(op.name)] = tf.cast(opreg.alive_vector, tf.int32)
7988

8089
@property
8190
def tensors(self):
82-
"""The list of tensors required to compute statistics.
91+
"""A dictionary between op names and alive vectors.
92+
93+
Alive vectors are `tf.Tensor`s of type tf.int32.
8394
8495
Returns:
8596
Dict: op name -> alive vector tensor
8697
"""
87-
return self._alive_tensors
98+
# TODO(p1): Rename tensors to something better. tensors is a dict!
99+
return self._tensors
88100

89101
def populate_tensor_values(self, values: Dict[Text, Sequence[bool]]) -> None:
90-
# TODO(p1): make this a hierarchy with 'alive_vectors' key at the top
91-
assert sorted(values) == sorted(self.tensors)
102+
"""Records alive values for ops regularized by op_regularizer_manager.
103+
104+
The given mapping must match op names from `self.tensors`.
105+
106+
Args:
107+
values: A dict mapping op names to a boolean alive status.
108+
109+
Raises:
110+
ValueError: If keys of input do not match keys of `self.tensors`.
111+
"""
112+
# TODO(p1): Rename values to something better. values is a dict!
113+
if sorted(values) != sorted(self.tensors):
114+
raise ValueError(
115+
'`values` and `self.tensors` must have the same keys but are %s and %s'
116+
% (sorted(values), sorted(self.tensors)))
92117
self._alive_vectors = values
93118

94119
def get_alive_counts(self) -> Dict[Text, int]:
@@ -105,7 +130,6 @@ def get_alive_counts(self) -> Dict[Text, int]:
105130

106131
if self._alive_vectors is None:
107132
raise RuntimeError('Tensor values not populated.')
108-
# TODO(p1): consider warning if same values are used twice?
109133
return compute_alive_counts(self._alive_vectors)
110134

111135
def save_alive_counts(self, f: IO[bytes]) -> None:
@@ -116,22 +140,24 @@ def save_alive_counts(self, f: IO[bytes]) -> None:
116140
"""
117141
f.write(format_structure(self.get_alive_counts()))
118142

119-
def create_file_and_save_alive_counts(self, train_dir: Text,
120-
global_step: tf.Tensor) -> None:
121-
"""Creates a file and saves live counts to it.
143+
def create_file_and_save_alive_counts(self, base_dir: Text,
144+
global_step: int) -> None:
145+
"""Creates and updates files with alive counts.
122146
123-
Creates the directory {train_dir}/learned_structure/ and saves the current
124-
alive counts to {path}/{_ALIVE_FILENAME}_{global_step} and overwrites
125-
{path}/{_ALIVE_FILENAME}.
147+
Creates the directory `{base_dir}/learned_structure/` and saves the current
148+
alive counts to:
149+
`{base_dir}/learned_structure/{_ALIVE_FILENAME}_{global_step}`
150+
and overwrites:
151+
`{base_dir}/learned_structure/{_ALIVE_FILENAME}`.
126152
127153
Args:
128-
train_dir: where to export the alive counts.
154+
base_dir: where to export the alive counts.
129155
global_step: current value of global step, used as a suffix in filename.
130156
"""
131157
current_filename = '%s_%s' % (_ALIVE_FILENAME, global_step)
132-
directory = os.path.join(train_dir, 'learned_structure')
158+
directory = os.path.join(base_dir, 'learned_structure')
133159
try:
134-
tf.gfile.MkDir(directory)
160+
tf.gfile.MakeDirs(directory)
135161
except tf.errors.OpError:
136162
# Probably already exists. If not, we'll see the error in the next line.
137163
pass
@@ -143,8 +169,8 @@ def create_file_and_save_alive_counts(self, train_dir: Text,
143169

144170
# TODO(p1): maybe check that we still end up with unique names after prefix
145171
# removal, and do nothing if that's not the case?
146-
def get_remove_common_prefix_op(
147-
iterable: Iterable[Text]) -> Callable[[Text], Text]:
172+
def get_remove_common_prefix_fn(iterable: Iterable[Text]
173+
) -> Callable[[Text], Text]:
148174
"""Obtains a function that removes common prefix.
149175
150176
Determines if all items in iterable start with the same substring (up to and

0 commit comments

Comments
 (0)