1313# limitations under the License.
1414# ==============================================================================
1515
16- import numpy as np
1716from enum import Enum , auto
1817
18+ import numpy as np
19+
1920from lucid .modelzoo .aligned_activations import (
2021 push_activations ,
2122 NUMBER_OF_AVAILABLE_SAMPLES ,
23+ layer_inverse_covariance ,
2224)
2325from lucid .recipes .activation_atlas .layout import aligned_umap
2426from lucid .recipes .activation_atlas .render import render_icons
@@ -37,7 +39,7 @@ def activation_atlas(
3739
3840 activations = layer .activations [:number_activations , ...]
3941 layout , = aligned_umap (activations , verbose = verbose )
40- directions , coordinates , _ = _bin_laid_out_activations (
42+ directions , coordinates , _ = bin_laid_out_activations (
4143 layout , activations , grid_size
4244 )
4345 icons = []
@@ -46,7 +48,7 @@ def activation_atlas(
4648 directions_batch , model , layer = layer .name , size = icon_size , num_attempts = 1
4749 )
4850 icons += icon_batch
49- canvas = _make_canvas (icons , coordinates , grid_size )
51+ canvas = make_canvas (icons , coordinates , grid_size )
5052
5153 return canvas
5254
@@ -57,36 +59,43 @@ def aligned_activation_atlas(
5759 model2 ,
5860 layer2 ,
5961 grid_size = 10 ,
60- icon_size = 96 ,
62+ icon_size = 80 ,
63+ num_steps = 1024 ,
64+ whiten_layers = False ,
6165 number_activations = NUMBER_OF_AVAILABLE_SAMPLES ,
6266 verbose = False ,
6367):
68+ """Renders two aligned Activation Atlases of the given models' layers.
69+
70+ Returns a generator of the two atlasses, and a nested generator for intermediate
71+ atlasses while they're being rendered.
72+ """
6473 combined_activations = _combine_activations (
6574 layer1 , layer2 , number_activations = number_activations
6675 )
6776 layouts = aligned_umap (combined_activations , verbose = verbose )
6877
69- atlasses = []
7078 for model , layer , layout in zip ((model1 , model2 ), (layer1 , layer2 ), layouts ):
71- directions , coordinates , densities = _bin_laid_out_activations (
72- layout , layer .activations [:number_activations , ...], grid_size
79+ directions , coordinates , densities = bin_laid_out_activations (
80+ layout , layer .activations [:number_activations , ...], grid_size , threshold = 10
7381 )
74- icons = []
75- for directions_batch in batch (directions , batch_size = 64 ):
76- icon_batch , losses = render_icons (
77- directions_batch ,
78- model ,
79- alpha = False ,
80- layer = layer .name ,
81- size = icon_size ,
82- num_attempts = 1 ,
83- n_steps = 1024 ,
84- )
85- icons += icon_batch
86- canvas = _make_canvas (icons , coordinates , grid_size )
87- atlasses .append (canvas )
88-
89- return atlasses
82+
83+ def _progressive_canvas_iterator ():
84+ icons = []
85+ for directions_batch in batch (directions , batch_size = 32 , as_list = True ):
86+ icon_batch , losses = render_icons (
87+ directions_batch ,
88+ model ,
89+ alpha = False ,
90+ layer = layer .name ,
91+ size = icon_size ,
92+ n_steps = num_steps ,
93+ S = layer_inverse_covariance (layer ) if whiten_layers else None ,
94+ )
95+ icons += icon_batch
96+ yield make_canvas (icons , coordinates , grid_size )
97+
98+ yield _progressive_canvas_iterator ()
9099
91100
92101# Helpers
@@ -100,6 +109,8 @@ class ActivationTranslation(Enum):
100109def _combine_activations (
101110 layer1 ,
102111 layer2 ,
112+ activations1 = None ,
113+ activations2 = None ,
103114 mode = ActivationTranslation .BIDIRECTIONAL ,
104115 number_activations = NUMBER_OF_AVAILABLE_SAMPLES ,
105116):
@@ -114,8 +125,8 @@ def _combine_activations(
114125 into the space of layer 1, concatenate them along their channels, and returns a
115126 tuple of the concatenated activations for each layer.
116127 """
117- activations1 = layer1 .activations [:number_activations , ...]
118- activations2 = layer2 .activations [:number_activations , ...]
128+ activations1 = activations1 or layer1 .activations [:number_activations , ...]
129+ activations2 = activations2 or layer2 .activations [:number_activations , ...]
119130
120131 if mode is ActivationTranslation .ONE_TO_TWO :
121132
@@ -133,10 +144,10 @@ def _combine_activations(
133144 return activations_model1 , activations_model2
134145
135146
136- def _bin_laid_out_activations (layout , activations , grid_size , threshold = 5 ):
147+ def bin_laid_out_activations (layout , activations , grid_size , threshold = 5 ):
137148 """Given a layout and activations, overlays a grid on the layout and returns
138149 averaged activations for each grid cell. If a cell contains less than `threshold`
139- activations it will not be used , so the number of returned directions is variable."""
150+ activations it will be discarded , so the number of returned data is variable."""
140151
141152 assert layout .shape [0 ] == activations .shape [0 ]
142153
@@ -151,28 +162,30 @@ def _bin_laid_out_activations(layout, activations, grid_size, threshold=5):
151162
152163 # iterate over all grid cell coordinates to compute their average directions
153164 grid_coordinates = np .indices ((grid_size , grid_size )).transpose ().reshape (- 1 , 2 )
154- for xy in grid_coordinates :
155- mask = np .equal (xy , indices ).all (axis = 1 )
165+ for xy_coordinates in grid_coordinates :
166+ mask = np .equal (xy_coordinates , indices ).all (axis = 1 )
156167 count = np .count_nonzero (mask )
157168 if count > threshold :
158169 counts .append (count )
159- coordinates .append (xy )
170+ coordinates .append (xy_coordinates )
160171 mean = np .average (activations [mask ], axis = 0 )
161172 means .append (mean )
162173
163174 assert len (means ) == len (coordinates ) == len (counts )
175+ if len (coordinates ) == 0 :
176+ raise RuntimeError ("Binning activations led to 0 cells containing activations!" )
164177
165- return np . array ( means ), np . array ( coordinates ), np . array ( counts )
178+ return means , coordinates , counts
166179
167180
def make_canvas(icon_batch, coordinates, grid_size):
    """Given a list of images and their coordinates, places them on a white canvas.

    `icon_batch` is a sequence of equally shaped icon images; `coordinates` gives
    the (row, col) grid cell for each icon. Cells without an icon stay white (1.0).
    Returns a single 2D image assembled from the grid of icons.
    """
    # All icons share one shape; start from an all-ones (white) grid of icon slots.
    icon_shape = icon_batch[0].shape
    canvas = np.ones((grid_size, grid_size) + tuple(icon_shape))

    # Drop each icon into its assigned grid cell.
    for icon, cell in zip(icon_batch, coordinates):
        row, col = cell
        canvas[row, col] = icon

    # Collapse the (grid, grid, H, W, C) block structure into one flat image.
    return np.hstack(np.hstack(canvas))
0 commit comments