-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path RISE.py
More file actions
140 lines (123 loc) · 4.49 KB
/
RISE.py
File metadata and controls
140 lines (123 loc) · 4.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import numpy as np
from skimage.transform import resize
import matplotlib.pyplot as plt
import os
def generate_masks(
    M, s, p1, img_shape, oversize=0.1, crops_per_mask=8,
    upsample_mode='bilinear', random_state=None
):
    """
    Generate random RISE masks.

    ``M`` low-resolution binary grids are sampled. Each grid is upsampled to a
    canvas ``(1 + oversize)`` times larger than the image. ``crops_per_mask``
    random image-sized crops are then taken from every canvas, so each base
    grid yields several spatially shifted masks.

    Parameters
    ----------
    M : int
        Number of base low-res masks to generate.
    s : int
        Low-resolution mask size (s x s).
    p1 : float
        Probability of a low-res pixel being 1 (unmasked).
    img_shape : tuple
        Target image shape; only the first two entries (H, W) are used.
    oversize : float
        Fractional enlargement of the canvas before random cropping.
    crops_per_mask : int
        Number of random crops to sample per upscaled mask.
    upsample_mode : {'bilinear', 'nearest'}
        Interpolation used when upsampling the low-res grids.
    random_state : int or None
        Seed for ``np.random.RandomState``.

    Returns
    -------
    masks : np.ndarray
        float32 array of shape (M * crops_per_mask, H, W) in [0, 1].
        Crops are ordered base-mask-major: the masks derived from base grid
        ``i`` occupy rows ``i * crops_per_mask`` to
        ``(i + 1) * crops_per_mask - 1``.

    Raises
    ------
    ValueError
        If ``upsample_mode`` is not 'bilinear' or 'nearest'.
    """
    # Validate explicitly: previously any unrecognised mode (e.g. a typo)
    # silently fell back to nearest-neighbour interpolation.
    interp_orders = {'bilinear': 1, 'nearest': 0}
    if upsample_mode not in interp_orders:
        raise ValueError(
            f"upsample_mode must be 'bilinear' or 'nearest', got {upsample_mode!r}"
        )
    order = interp_orders[upsample_mode]

    rng = np.random.RandomState(random_state)
    H, W = img_shape[:2]
    H_large = int(np.ceil(H * (1 + oversize)))
    W_large = int(np.ceil(W * (1 + oversize)))

    # Step 1: low-res Bernoulli(p1) binary grids.
    low_res = rng.binomial(1, p1, size=(M, s, s)).astype(np.float32)

    # Step 2: upscale each grid to a canvas slightly larger than the image.
    upscaled = np.zeros((M, H_large, W_large), dtype=np.float32)
    for i in range(M):
        upscaled[i] = resize(
            low_res[i],
            (H_large, W_large),
            order=order,
            mode='reflect',
            anti_aliasing=False
        )

    # Step 3: take random image-sized crops from every canvas. The RNG draw
    # order (dx then dy, per crop, per mask) is kept fixed so that seeded
    # runs stay reproducible.
    max_dx = W_large - W + 1
    max_dy = H_large - H + 1
    masks = np.zeros((M * crops_per_mask, H, W), dtype=np.float32)
    idx = 0
    for i in range(M):
        for _ in range(crops_per_mask):
            dx = rng.randint(0, max_dx)
            dy = rng.randint(0, max_dy)
            masks[idx] = upscaled[i, dy:dy + H, dx:dx + W]
            idx += 1
    return masks
def RISE(model, img, masks, target_class, batch_size=16):
    """
    Compute a RISE saliency map for a given image and model.

    Each mask is applied to the image, and the masked images are scored by
    ``model`` in batches. The saliency map is the score-weighted sum of the
    masks, min-max normalised.

    Parameters
    ----------
    model : callable
        Called as ``model(batch)`` on an array of shape (B, H, W, C). It must
        return something ``np.asarray`` can turn into shape (B, num_classes),
        e.g. a ``tf.keras.Model``.
    img : np.ndarray
        Input image of shape (H, W, C), preprocessed as the model expects.
    masks : np.ndarray
        Array of shape (N, H, W) with float values in [0, 1].
    target_class : int
        Class index to explain.
    batch_size : int, optional
        Number of masked images passed to ``model`` per call (default 16).

    Returns
    -------
    saliency_map : np.ndarray
        Saliency map of shape (H, W), min-max normalised to (approximately)
        [0, 1].

    Raises
    ------
    ValueError
        If ``masks`` is not 3-D, contains no masks, or ``batch_size`` < 1.
    """
    masks = np.asarray(masks)
    if masks.ndim != 3:
        raise ValueError(f"masks must have shape (N, H, W), got {masks.shape}")
    N = masks.shape[0]
    if N == 0:
        raise ValueError("masks must contain at least one mask")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Build the masked images one batch at a time. Materialising all N at
    # once needs N * H * W * C floats, which is gigabytes for thousands of
    # full-resolution masks.
    preds = []
    for start in range(0, N, batch_size):
        batch_masks = masks[start:start + batch_size]
        batch = img[None, ...] * batch_masks[..., None]
        preds.append(np.asarray(model(batch)))
    preds = np.concatenate(preds, axis=0)

    # Extract target class scores, one per mask.
    scores = preds[:, target_class]

    # Score-weighted sum over all masks -> (H, W).
    saliency = np.tensordot(scores, masks, axes=(0, 0))
    mask_mean = np.mean(masks)
    # Guard all-zero masks: the weighted sum is then all zeros, and dividing
    # by 0 would turn the map into NaNs.
    if mask_mean > 0:
        saliency = saliency / (N * mask_mean)
    saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-8)
    return saliency
def save_example_masks(masks, save_dir, num_samples=6):
    """
    Save example masks to visualize RISE mask generation.

    ``num_samples`` evenly spaced masks are drawn in one row and saved as
    ``sample_masks.png`` inside ``save_dir``.

    Parameters
    ----------
    masks : np.ndarray
        Array of shape (N, H, W), e.g. the output of ``generate_masks``.
    save_dir : str
        Output directory; created if it does not exist.
    num_samples : int, optional
        Number of masks to display (default 6).

    Raises
    ------
    ValueError
        If ``masks`` is empty or ``num_samples`` < 1.
    """
    # Validate before touching the filesystem or creating a figure.
    if len(masks) == 0:
        raise ValueError("masks must contain at least one mask")
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    os.makedirs(save_dir, exist_ok=True)
    sample_idxs = np.linspace(0, len(masks) - 1, num_samples, dtype=int)
    # squeeze=False keeps `axes` 2-D even when num_samples == 1. Otherwise
    # plt.subplots returns a single Axes, which cannot be indexed.
    fig, axes = plt.subplots(
        1, num_samples, figsize=(3 * num_samples, 3), squeeze=False
    )
    try:
        for ax, idx in zip(axes[0], sample_idxs):
            ax.imshow(masks[idx], cmap='gray')
            ax.axis('off')
            ax.set_title(f"Mask {idx}")
        fig.tight_layout()
        save_path = os.path.join(save_dir, "sample_masks.png")
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
    finally:
        # Always release the figure so repeated calls don't accumulate them.
        plt.close(fig)
def save_masked_examples(img, masks, save_dir, num_samples=4):
    """
    Save examples of masked images.

    ``num_samples`` evenly spaced masks are applied to ``img``. The results
    are drawn in one row and saved as ``masked_examples.png`` inside
    ``save_dir``.

    Parameters
    ----------
    img : np.ndarray
        Image of shape (H, W, C). It is displayed as ``img / 2 + 0.5``, which
        presumably assumes inputs preprocessed to [-1, 1].
        NOTE(review): confirm this matches the model's preprocessing.
    masks : np.ndarray
        Array of shape (N, H, W) with values in [0, 1].
    save_dir : str
        Output directory; created if it does not exist.
    num_samples : int, optional
        Number of masked images to display (default 4).

    Raises
    ------
    ValueError
        If ``masks`` is empty or ``num_samples`` < 1.
    """
    # Validate before touching the filesystem or creating a figure.
    if len(masks) == 0:
        raise ValueError("masks must contain at least one mask")
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    os.makedirs(save_dir, exist_ok=True)
    sample_idxs = np.linspace(0, len(masks) - 1, num_samples, dtype=int)
    # squeeze=False keeps `axes` 2-D even when num_samples == 1. Otherwise
    # plt.subplots returns a single Axes, which cannot be indexed.
    fig, axes = plt.subplots(
        1, num_samples, figsize=(4 * num_samples, 4), squeeze=False
    )
    try:
        for ax, idx in zip(axes[0], sample_idxs):
            masked_img = img * masks[idx][..., None]
            # Rescale for display and clip explicitly: imshow clips float RGB
            # data outside [0, 1] anyway, but emits a warning when it does.
            ax.imshow(np.clip(masked_img / 2 + 0.5, 0.0, 1.0))
            ax.axis('off')
            ax.set_title(f"Masked {idx}")
        fig.tight_layout()
        save_path = os.path.join(save_dir, "masked_examples.png")
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
    finally:
        # Always release the figure so repeated calls don't accumulate them.
        plt.close(fig)