1
- """Helper module for calculating the live activation counts."""
2
-
1
+ """Helper module for calculating and saving learned structures."""
3
2
from __future__ import absolute_import
4
3
from __future__ import division
5
4
# [internal] enable type annotations
12
11
import tensorflow as tf
13
12
from typing import Text , Sequence , Dict , Optional , IO , Iterable , Callable
14
13
15
- _SUPPORTED_OPS = ['Conv2D' ]
14
+ _SUPPORTED_OPS = ['Conv2D' , 'Conv2DBackpropInput' ]
16
15
_ALIVE_FILENAME = 'alive'
17
16
18
17
@@ -25,8 +24,8 @@ def compute_alive_counts(
25
24
"""Computes alive counts.
26
25
27
26
Args:
28
- alive_vectors: A mapping from op_name to a vector where each element says
29
- whether the corresponding output activation is alive.
27
+ alive_vectors: A mapping from op_name to a vector where each element
28
+ indicates whether the corresponding output activation is alive.
30
29
31
30
Returns:
32
31
Mapping from op_name to the number of its alive output activations.
@@ -41,8 +40,21 @@ class StructureExporter(object):
41
40
"""Reports statistics about the current state of regularization.
42
41
43
42
Obtains live activation counts for supported ops: a map from each op name
44
- to its count of alive activations (filters). Optionally, thresholds the counts
45
- so that very low counts are reported as 0. Currently, only supports Conv2D.
43
+ to its count of alive activations (filters).
44
+
45
+ Usage:
46
+ 1. Build model.
47
+ `logits = build_model(params)`
48
+ 2. Create network regularizer.
49
+ `network_regularizer = flop_regularizer.GammaFlopsRegularizer([logits.op])`
50
+ 3. Create StructureExporter:
51
+ `exporter = StructureExporter(network_regularizer.op_regularizer_manager)`
52
+ 4. Gather tensors to eval:
53
+ `tensor_to_eval_dict = exporter.tensors`
54
+ 5. Within a `tf.Session()` eval and populate tensors:
55
+ `exporter.populate_tensor_values(tensor_to_eval_dict.eval())`
56
+ 6. Export structure:
57
+ `exporter.save_alive_counts(tf.gfile.Open(...))`
46
58
"""
47
59
48
60
def __init__ (self ,
@@ -57,38 +69,51 @@ def __init__(self,
57
69
with the same prefix (up to and including the first '/'), and if so,
58
70
skip that prefix in exported data.
59
71
"""
60
- self . _op_regularizer_manager = op_regularizer_manager
61
- self ._alive_tensors = {} # type: Dict[Text, tf.Tensor]
72
+ # TODO(p1): Consider deleting unused `remove_common_prefix` b/133261798.
73
+ self ._tensors = {} # type: Dict[Text, tf.Tensor]
62
74
self ._alive_vectors = None # type: Optional[Dict[Text, Sequence[bool]]]
75
+ rename_fn = get_remove_common_prefix_fn (
76
+ self ._tensors ) if remove_common_prefix else lambda x : x
63
77
64
- for op in self . _op_regularizer_manager .ops :
78
+ for op in op_regularizer_manager .ops :
65
79
if op .type not in _SUPPORTED_OPS :
66
80
continue
67
- opreg = self ._op_regularizer_manager .get_regularizer (op )
68
- if opreg :
69
- # TODO(p1): use bool here (no cast), and then convert later?
70
- self ._alive_tensors [op .name ] = tf .cast (opreg .alive_vector , tf .int32 )
71
- else :
81
+
82
+ opreg = op_regularizer_manager .get_regularizer (op )
83
+ if not opreg :
72
84
tf .logging .warning ('No regularizer found for: %s' , op .name )
85
+ continue
73
86
74
- if remove_common_prefix :
75
- rename_op = get_remove_common_prefix_op (self ._alive_tensors )
76
- self ._alive_tensors = {
77
- rename_op (k ): v for k , v in self ._alive_tensors .items ()
78
- }
87
+ self ._tensors [rename_fn (op .name )] = tf .cast (opreg .alive_vector , tf .int32 )
79
88
80
89
@property
81
90
def tensors (self ):
82
- """The list of tensors required to compute statistics.
91
+ """A dictionary between op names and alive vectors.
92
+
93
+ Alive vectors are `tf.Tensor`s of type tf.int32.
83
94
84
95
Returns:
85
96
Dict: op name -> alive vector tensor
86
97
"""
87
- return self ._alive_tensors
98
+ # TODO(p1): Rename tensors to something better. tensors is a dict!
99
+ return self ._tensors
88
100
89
101
def populate_tensor_values (self , values : Dict [Text , Sequence [bool ]]) -> None :
90
- # TODO(p1): make this a hierarchy with 'alive_vectors' key at the top
91
- assert sorted (values ) == sorted (self .tensors )
102
+ """Records alive values for ops regularized by op_regularizer_manager.
103
+
104
+ The given mapping must match op names from `self.tensor`.
105
+
106
+ Args:
107
+ values: A dict mapping op names to a boolean alive status.
108
+
109
+ Raises:
110
+ ValueError: If keys of input do not match keys of `self.tensor`.
111
+ """
112
+ # TODO(p1): Rename values to something better. values is a dict!
113
+ if sorted (values ) != sorted (self .tensors ):
114
+ raise ValueError (
115
+ '`values` and `self.tensors` must have the same keys but are %s and %s'
116
+ % (sorted (values ), sorted (self .tensors )))
92
117
self ._alive_vectors = values
93
118
94
119
def get_alive_counts (self ) -> Dict [Text , int ]:
@@ -105,7 +130,6 @@ def get_alive_counts(self) -> Dict[Text, int]:
105
130
106
131
if self ._alive_vectors is None :
107
132
raise RuntimeError ('Tensor values not populated.' )
108
- # TODO(p1): consider warning if same values are used twice?
109
133
return compute_alive_counts (self ._alive_vectors )
110
134
111
135
def save_alive_counts (self , f : IO [bytes ]) -> None :
@@ -116,22 +140,24 @@ def save_alive_counts(self, f: IO[bytes]) -> None:
116
140
"""
117
141
f .write (format_structure (self .get_alive_counts ()))
118
142
119
- def create_file_and_save_alive_counts (self , train_dir : Text ,
120
- global_step : tf . Tensor ) -> None :
121
- """Creates a file and saves live counts to it .
143
+ def create_file_and_save_alive_counts (self , base_dir : Text ,
144
+ global_step : int ) -> None :
145
+ """Creates and updates files with alive counts .
122
146
123
- Creates the directory {train_dir}/learned_structure/ and saves the current
124
- alive counts to {path}/{_ALIVE_FILENAME}_{global_step} and overwrites
125
- {path}/{_ALIVE_FILENAME}.
147
+ Creates the directory `{base_dir}/learned_structure/` and saves the current
148
+ alive counts to:
149
+ `{base_dir}/learned_structure/{_ALIVE_FILENAME}_{global_step}`
150
+ and overwrites:
151
+ `{base_dir}/learned_structure/{_ALIVE_FILENAME}`.
126
152
127
153
Args:
128
- train_dir : where to export the alive counts.
154
+ base_dir : where to export the alive counts.
129
155
global_step: current value of global step, used as a suffix in filename.
130
156
"""
131
157
current_filename = '%s_%s' % (_ALIVE_FILENAME , global_step )
132
- directory = os .path .join (train_dir , 'learned_structure' )
158
+ directory = os .path .join (base_dir , 'learned_structure' )
133
159
try :
134
- tf .gfile .MkDir (directory )
160
+ tf .gfile .MakeDirs (directory )
135
161
except tf .errors .OpError :
136
162
# Probably already exists. If not, we'll see the error in the next line.
137
163
pass
@@ -143,8 +169,8 @@ def create_file_and_save_alive_counts(self, train_dir: Text,
143
169
144
170
# TODO(p1): maybe check that we still end up with unique names after prefix
145
171
# removal, and do nothing if that's not the case?
146
- def get_remove_common_prefix_op (
147
- iterable : Iterable [ Text ] ) -> Callable [[Text ], Text ]:
172
+ def get_remove_common_prefix_fn ( iterable : Iterable [ Text ]
173
+ ) -> Callable [[Text ], Text ]:
148
174
"""Obtains a function that removes common prefix.
149
175
150
176
Determines if all items in iterable start with the same substring (up to and
0 commit comments