Skip to content

Commit 6416870

Browse files
committed
Move computations to tensorflow and vectorize to first degree
1 parent f98d1e1 commit 6416870

File tree

3 files changed

+80
-87
lines changed

3 files changed

+80
-87
lines changed

agent.py

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55
from net import Net
66

77

8-
def do_plot(guided_backprops, indices, image, bounding_box, files_dict):
9-
top_grads = [guided_backprops[i] for i in indices]
10-
folder = list(files_dict.keys())[0]
11-
name = files_dict[folder][0]
12-
get_full_plot(top_grads, image, bounding_box, folder, name)
8+
def do_batch_plot(guided_backprops, indices, images, bounding_boxes, files_dict):
9+
folders_and_names = []
10+
for key, values in files_dict.items():
11+
for val in values:
12+
folders_and_names.append([key, val])
13+
for gbps, index, image, bb, fn in zip(guided_backprops, indices, images, bounding_boxes, folders_and_names):
14+
top_grads = [gbps[i] for i in index]
15+
get_full_plot(top_grads, image, bb, *fn)
1316

1417

1518
class Agent(object):
@@ -23,17 +26,17 @@ def __init__(self, config):
2326

2427
def get_bounding_box(self, image=None, files_dict=None):
2528
if image is None and files_dict is None:
26-
image, files_dict = sample_data(1)
29+
image, files_dict = sample_data(self.config.batch_size)
2730
preprocessed_image = self._preprocess(image.copy())
28-
kmax_neuron_indices, top_class = self._get_kmax_neurons(preprocessed_image)
31+
kmax_neuron_indices, top_classes_indices = self.net.get_top_kmax_neurons(preprocessed_image)
2932
guided_backprops = self._get_guided_backprops(preprocessed_image, kmax_neuron_indices)
30-
masks = self._get_image_masks(guided_backprops)
31-
topk_neurons_relative_indices = self._get_topk_neurons(preprocessed_image, masks, top_class)
32-
bounding_box, _ = self._get_bounding_box(masks, topk_neurons_relative_indices)
33+
masks = self._get_images_masks(guided_backprops)
34+
top_k_neurons_relative_indices = self._get_top_k_neurons(preprocessed_image, masks, top_classes_indices)
35+
bounding_boxes = self._get_all_bounding_boxes(masks, top_k_neurons_relative_indices)
3336
if self.config.do_plotting:
34-
do_plot(guided_backprops, topk_neurons_relative_indices,
35-
image, bounding_box, files_dict)
36-
return bounding_box
37+
do_batch_plot(guided_backprops, top_k_neurons_relative_indices,
38+
image, bounding_boxes, files_dict)
39+
return bounding_boxes
3740

3841
def make_tsne_pic_for_directory(self, folder='personal'):
3942
fc_features = None
@@ -48,52 +51,49 @@ def make_tsne_pic_for_directory(self, folder='personal'):
4851

4952
tsne = T_SNE()
5053
tsne_embedding = tsne.generate_tsne(fc_features)
51-
tsne.save_grid(images, tsne_embedding, folder+'.jpg')
54+
tsne.save_grid(images, tsne_embedding, folder + '.jpg')
5255

5356
def _preprocess(self, image):
5457
return self.net.vgg16.preprocess_input(image)
5558

56-
def _get_kmax_neurons(self, image):
57-
class_scores = self.net.get_class_scores(image)
58-
n = len(self.config.top_n_classes_weights)
59-
top_class = np.argmax(class_scores)
60-
top_classes = np.argsort(class_scores)[-n:]
61-
top_scores = class_scores[top_classes]
62-
activations = self.net.get_activations(image)
63-
64-
normalizer = 1.0 / (np.sum(self.config.top_n_classes_weights) * np.sum(top_scores))
65-
66-
impact_gradient = 0
67-
for i, class_index in enumerate(top_classes):
68-
impact_gradient_per_class = self.net.get_impact_gradient_per_class(activations=activations,
69-
class_index=class_index)
70-
impact_gradient += impact_gradient_per_class * top_scores[i] * \
71-
self.config.top_n_classes_weights[i] * normalizer
72-
73-
rank_score = activations * impact_gradient
74-
kmax_neurons = np.argsort(rank_score.ravel())[-self.config.kmax:]
75-
return kmax_neurons, top_class
76-
77-
def _get_guided_backprops(self, image, neuron_indices):
78-
guided_backprops = [self.net.get_guided_backprop(image, neuron_index) for neuron_index in neuron_indices]
59+
def _get_guided_backprops(self, images, neuron_indices):
60+
# lazy programming, this part should as well be vectorized
61+
guided_backprops = []
62+
for image, neuron_index in zip(images, neuron_indices):
63+
gbp = [self.net.get_guided_backprop(np.expand_dims(image, axis=0), ni)
64+
for ni in neuron_index]
65+
guided_backprops.append(gbp)
7966
return guided_backprops
8067

81-
def _get_image_masks(self, guided_backprops):
82-
projected_gbps = [np.max(gb, axis=-1).squeeze() for gb in guided_backprops]
83-
raw_masks = [pgbp > np.percentile(pgbp, self.config.cut_off_percentile) for pgbp in projected_gbps]
84-
# erosion and dilation
85-
masks = [binary_dilation(binary_erosion(raw_mask)).astype(projected_gbps[0].dtype) for raw_mask in raw_masks]
68+
def _get_images_masks(self, guided_backprops):
69+
masks = []
70+
for gbps in guided_backprops:
71+
projected_gbps = [np.max(gb, axis=-1).squeeze() for gb in gbps]
72+
raw_masks = [pgbp > np.percentile(pgbp, self.config.cut_off_percentile) for pgbp in projected_gbps]
73+
# erosion and dilation
74+
masks_per_image = [binary_dilation(binary_erosion(raw_mask)).astype(projected_gbps[0].dtype) for raw_mask in
75+
raw_masks]
76+
masks.append(masks_per_image)
8677
return masks
8778

88-
def _get_topk_neurons(self, image, masks, top_class):
89-
reshaped_image = image.reshape(np.roll(self.config.vgg16.input_size, 1))
90-
images = np.stack([np.reshape(reshaped_image * mask, self.config.vgg16.input_size) for mask in masks])
91-
losses = self.net.get_batch_loss(images, top_class)
92-
return list(np.argsort(losses)[:self.config.k])
93-
# return kmax_neuron_indices[np.argsort(losses)[:self.config.k]]
79+
def _get_top_k_neurons(self, images, masks, top_class):
80+
top_k_neurons_relative_indices = []
81+
for i, image in enumerate(images):
82+
reshaped_image = image.reshape(np.roll(self.config.vgg16.input_size, 1))
83+
masked_images = np.stack(
84+
[np.reshape(reshaped_image * mask, self.config.vgg16.input_size) for mask in masks[i]])
85+
losses = self.net.get_batch_loss(masked_images, top_class[i])
86+
top_k_neurons_relative_indices.append(list(np.argsort(losses)[:self.config.k]))
87+
return top_k_neurons_relative_indices
88+
89+
def _get_all_bounding_boxes(self, all_masks, all_mask_indices):
90+
bounding_boxes = []
91+
for mask, mask_indices in zip(all_masks, all_mask_indices):
92+
bounding_boxes.append(self._get_bounding_box(mask, mask_indices))
93+
return bounding_boxes
9494

9595
def _get_bounding_box(self, masks, mask_indices):
96-
# sorry, worst piece of code mankind has ever done
96+
# sorry, super lazy
9797
final_masks = np.array(masks)[mask_indices]
9898
the_mask = final_masks[0] * False
9999
for mask in final_masks:
@@ -108,4 +108,4 @@ def _get_bounding_box(self, masks, mask_indices):
108108
x_min = min(x_min, j)
109109
y_max = max(y_max, i)
110110
x_max = max(x_max, j)
111-
return [[x_min, y_min], [x_max, y_max]], the_mask
111+
return [[x_min, y_min], [x_max, y_max]]

main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
def main():
77
agent = Agent(Config)
88
# agent.make_tsne_pic_for_directory()
9-
for i in range(100):
10-
bounding_box = agent.get_bounding_box()
9+
for i in range(3):
10+
agent.get_bounding_box()
1111

1212

1313
if __name__ == '__main__':

net.py

Lines changed: 28 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def _build_net(self):
2929
self._vgg16()
3030
self._fetch_shapes()
3131
self._build_impact_gradient()
32+
self._build_top_kmax_neuron_selection()
3233
self._build_guided_backprop()
3334
self._build_softmax_loss()
3435

@@ -39,33 +40,52 @@ def _vgg16(self):
3940
self.fc_features = self.model.get_layer('fc2').output
4041

4142
def _fetch_shapes(self):
42-
# lets not run on batch mode for now (gpu shortage and complexity)
4343
self.features_shape = self.feature_tensor.get_shape().as_list()
4444
self.classes_shape = self.model.output.get_shape().as_list()
4545
self.features_shape[0] = self.classes_shape[0] = 1
4646

4747
def _build_impact_gradient(self):
48-
self.category_index_ph = tf.placeholder(tf.int32, shape=[])
49-
fake_upstream_grad = tf.one_hot(self.category_index_ph, self.classes_shape[-1], axis=-1)
48+
self.category_indices = tf.argmax(self.model.output, axis=-1, name='max_scoring_categories')
49+
fake_upstream_grad = tf.one_hot(self.category_indices, self.classes_shape[-1], axis=-1)
5050
self.impact_grad = tf.gradients(self.model.output, self.feature_tensor,
51-
grad_ys=[fake_upstream_grad],
52-
name='impact_gradients')
51+
grad_ys=[fake_upstream_grad], name='impact_gradients')[0]
52+
53+
def _build_top_kmax_neuron_selection(self):
54+
# use DAM heuristic for selection
55+
neurons_effect = self.impact_grad * self.feature_tensor
56+
neurons_effect_flat_batch = tf.reshape(neurons_effect, (-1, np.prod(self.features_shape)))
57+
self.batch_top_kmax_neuron_indices = tf.nn.top_k(neurons_effect_flat_batch, k=self.config.kmax)[1]
5358

5459
def _build_guided_backprop(self):
5560
self.neuron_index = tf.placeholder(tf.int32, shape=[])
5661
fake_upstream_grad = tf.one_hot(self.neuron_index, np.prod(self.features_shape), axis=-1)
5762
fake_upstream_grad = tf.reshape(fake_upstream_grad, shape=self.features_shape)
5863
self.guided_backprop = tf.gradients(self.feature_tensor, self.model.input,
59-
grad_ys=[fake_upstream_grad],
60-
name='guided_backprop')
64+
grad_ys=[fake_upstream_grad], name='guided_backprop')
6165

6266
def _build_softmax_loss(self):
6367
self.top_class_index_ph = tf.placeholder(tf.int32, shape=[])
64-
self.top_class_batch_one_hot = tf.one_hot(tf.ones([self.config.kmax,], dtype=tf.int32) * self.top_class_index_ph,
68+
self.top_class_batch_one_hot = tf.one_hot(tf.ones([self.config.kmax, ], dtype=tf.int32) * self.top_class_index_ph,
6569
self.classes_shape[-1], axis=-1)
6670
self.softmax_loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.top_class_batch_one_hot,
6771
logits=self.model.output)
6872

73+
def get_top_kmax_neurons(self, images):
74+
top_kmax_neurons_indices, max_scoring_indices = self.sess.run([self.batch_top_kmax_neuron_indices,
75+
self.category_indices],
76+
feed_dict={self.model.input: images,
77+
self.not_guided_flag: 1.0})
78+
return top_kmax_neurons_indices, max_scoring_indices
79+
80+
def get_guided_backprop(self, image, neuron_index):
81+
numerical_guided_backprop = self.sess.run(self.guided_backprop,
82+
feed_dict={
83+
self.model.input: image,
84+
self.neuron_index: neuron_index,
85+
self.not_guided_flag: 0.0,
86+
})
87+
return numerical_guided_backprop[0]
88+
6989
def get_batch_loss(self, images, top_class):
7090
assert images.shape[0] is self.config.kmax
7191
batch_loss = self.sess.run(self.softmax_loss,
@@ -75,35 +95,8 @@ def get_batch_loss(self, images, top_class):
7595
})
7696
return batch_loss
7797

78-
def get_class_scores(self, image):
79-
prediction_scores = np.squeeze(self.model.predict(image))
80-
return prediction_scores
81-
82-
def get_activations(self, image):
83-
activations = self.sess.run(self.feature_tensor,
84-
feed_dict={self.model.input: image})
85-
return activations
86-
8798
def get_fc_features(self, images):
8899
fc_features = self.sess.run(self.fc_features,
89100
feed_dict={self.model.input: images})
90101
return fc_features
91102

92-
def get_impact_gradient_per_class(self, activations, class_index):
93-
numerical_impact_grad = self.sess.run(self.impact_grad,
94-
feed_dict={
95-
self.feature_tensor: activations,
96-
self.category_index_ph: class_index,
97-
self.not_guided_flag: 1.0,
98-
})
99-
return numerical_impact_grad[0]
100-
101-
def get_guided_backprop(self, image, neuron_index):
102-
numerical_guided_backprop = self.sess.run(self.guided_backprop,
103-
feed_dict={
104-
self.model.input: image,
105-
self.neuron_index: neuron_index,
106-
self.not_guided_flag: 0.0,
107-
})
108-
return numerical_guided_backprop[0]
109-

0 commit comments

Comments
 (0)