Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 153 additions & 5 deletions experiments/causal_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,148 @@ def patch_rep(x, layer):
return probs


def generate_test_matrix(truth_matrix, noise_matrix, prior_cutoff=None):
    """Blend truth and noise activations, keeping only the largest truth entries.

    Picks the next activation threshold: the global maximum of ``truth_matrix``
    on the first call (``prior_cutoff is None``), otherwise the largest value
    strictly below ``prior_cutoff``.  Entries at or above the threshold are
    taken from ``truth_matrix``; every other entry comes from ``noise_matrix``.

    Returns a tuple ``(sample_matrix, truth_mask, next_cutoff)``.
    """
    if prior_cutoff is None:
        next_cutoff = truth_matrix.max()
    else:
        below_prior = truth_matrix[truth_matrix < prior_cutoff]
        next_cutoff = below_prior.max()
    truth_mask = truth_matrix >= next_cutoff
    noise_mask = ~truth_mask  # complement: positions filled from the noise run
    blended = truth_matrix * truth_mask + noise_matrix * noise_mask
    return blended, truth_mask, next_cutoff


def trace_neurons_with_patch(
    model,  # The model
    inp,  # A set of inputs
    states_to_patch,  # A list of (token index, layername) triples to restore
    answers_t,  # Answer probabilities to collect
    tokens_to_mix,  # Range of tokens to corrupt (begin, end)
    noise=0.1,  # Level of noise to add
    trace_layers=None,  # List of traced outputs to return
):
    """Neuron-level causal trace for a single (token, layer) hidden state.

    First runs a standard corrupted-with-full-restoration pass to obtain a
    baseline answer probability, then repeatedly re-runs the model while
    restoring only the highest-magnitude neuron activations (the cutoff is
    lowered each round via generate_test_matrix) until the answer probability
    recovers to 95% of the baseline.

    Returns a per-neuron probability vector averaged over the corrupted batch
    items; if trace_layers is given, also returns stacked traced activations
    (warning printed — not fully implemented for this mode).
    """
    prng = numpy.random.RandomState(1)  # For reproducibility, use pseudorandom noise
    patch_spec = defaultdict(list)
    if len(states_to_patch) > 1:
        print('error, can only patch one state at a time in trace_neurons mode.')
        print('received states_to_patch:', states_to_patch)
    assert len(states_to_patch) == 1

    for t, l in states_to_patch:
        patch_spec[l].append(t)

    def untuple(x):
        return x[0] if isinstance(x, tuple) else x

    # Model-patching rule for the baseline (full-state restore) pass.
    # NOTE: parameter renamed from `model` to `x` — it receives a layer output,
    # not the model, and the old name shadowed the enclosing `model` argument.
    def patch_rep(x, layer):
        if layer == "transformer.wte":
            # If requested, we corrupt a range of token embeddings on batch items x[1:]
            if tokens_to_mix is not None:
                b, e = tokens_to_mix
                x[1:, b:e] += noise * torch.from_numpy(
                    prng.randn(x.shape[0] - 1, e - b, x.shape[2])
                ).to(x.device)
            return x
        if layer not in patch_spec:
            return x
        # If this layer is in the patch_spec, restore the uncorrupted hidden state
        # for selected tokens.
        h = untuple(x)
        for t in patch_spec[layer]:
            h[1:, t] = h[0, t]
        return x

    additional_layers = [] if trace_layers is None else trace_layers
    with torch.no_grad(), nethook.TraceDict(
        model,
        ["transformer.wte"] + list(patch_spec.keys()) + additional_layers,
        edit_output=patch_rep,
    ) as td:
        outputs_exp = model(**inp)

    # We report softmax probabilities for the answers_t token predictions of interest.
    baseline_probs = torch.softmax(outputs_exp.logits[1:, -1, :], dim=1).mean(dim=0)[answers_t]

    # failure state:
    #   cutoff == truth_matrix.min() would mean no more iterations to do.
    #   TODO(review): this is not guarded — generate_test_matrix would raise on
    #   an empty selection once the cutoff reaches the minimum activation.
    # success state:
    #   probs exceed prob_threshold
    # NOTE(review): the `while` comparison assumes answers_t selects a single
    # answer token (0-d tensor); multiple answers would make it ambiguous.
    prob_threshold = 0.95 * baseline_probs
    probs = 0 * baseline_probs
    cutoffs = {}  # there will be a different cutoff for each h[0, t] (layer)
    # so we have to track a dict of them
    truth_masks = {}

    while probs < prob_threshold:

        # Define the neuron-level model-patching rule for the current loop.
        def patch_rep_neurons(x, layer):
            if layer == "transformer.wte":
                # If requested, we corrupt a range of token embeddings on batch items x[1:]
                if tokens_to_mix is not None:
                    b, e = tokens_to_mix
                    x[1:, b:e] += noise * torch.from_numpy(
                        prng.randn(x.shape[0] - 1, e - b, x.shape[2])
                    ).to(x.device)
                return x
            if layer not in patch_spec:
                return x
            # If this layer is in the patch_spec, restore the uncorrupted hidden
            # state for selected tokens — but only the neurons at or above the
            # current cutoff; the rest keep their corrupted values.
            h = untuple(x)
            for t in patch_spec[layer]:
                truth_matrix = h[0, t]
                for i, noise_matrix in enumerate(h[1:, t]):
                    layer_batch_id = f'{layer}_{i}'
                    sample_matrix, truth_mask, new_cutoff = generate_test_matrix(
                        truth_matrix, noise_matrix, prior_cutoff=cutoffs.get(layer_batch_id))
                    cutoffs[layer_batch_id] = new_cutoff
                    truth_masks[layer_batch_id] = truth_mask
                    h[i + 1, t] = sample_matrix
            return x

        with torch.no_grad(), nethook.TraceDict(
            model,
            ["transformer.wte"] + list(patch_spec.keys()) + additional_layers,
            edit_output=patch_rep_neurons,
        ) as td:
            outputs_exp = model(**inp)

        # We report softmax probabilities for the answers_t token predictions of interest.
        probs = torch.softmax(outputs_exp.logits[1:, -1, :], dim=1).mean(dim=0)[answers_t]

    # All masks share the hidden-state width; grab it from any entry.
    mask_size = next(iter(truth_masks.values())).shape[0]
    targets = len(truth_masks)
    # BUGFIX: was `np.zeros(...)` wrapped in torch.as_tensor, but this module
    # uses the plain `numpy` import (see prng above) — `np` was undefined.
    # torch.float64 matches the dtype numpy.zeros would have produced.
    neuron_probs = torch.zeros(targets * mask_size, dtype=torch.float64)
    for k, v in truth_masks.items():
        # BUGFIX: keys are f"{layer}_{i}" and layer names may themselves contain
        # underscores, so split from the RIGHT to recover the batch index.
        target_layer, batch_id = k.rsplit('_', 1)
        # NOTE(review): this index is the corrupted-batch-item index i, not a
        # token index, despite the original variable name.
        target_token = int(batch_id)
        neuron_probs[target_token * mask_size:(1 + target_token) * mask_size] = v * probs

    neuron_probs = torch.reshape(neuron_probs, [targets, mask_size]).mean(dim=0)

    # If tracing all layers, collect all activations together to return.
    if trace_layers is not None:
        print('warning, trace_layers option not implemented yet for neuron trace. may fail.')
        all_traced = torch.stack(
            [untuple(td[layer].output).detach().cpu() for layer in trace_layers], dim=2)
        return neuron_probs, all_traced

    return neuron_probs


def trace_with_repatch(
model, # The model
inp, # A set of inputs
Expand Down Expand Up @@ -208,7 +350,7 @@ def patch_rep(x, layer):


def calculate_hidden_flow(
mt, prompt, subject, samples=10, noise=0.1, window=10, kind=None
mt, prompt, subject, samples=10, noise=0.1, window=10, kind=None, neuron_trace=False
):
"""
Runs causal tracing over every token/layer combination in the network
Expand Down Expand Up @@ -236,6 +378,7 @@ def calculate_hidden_flow(
noise=noise,
window=window,
kind=kind,
neuron_trace=neuron_trace
)
differences = differences.detach().cpu()
return dict(
Expand Down Expand Up @@ -271,7 +414,7 @@ def trace_important_states(model, num_layers, inp, e_range, answer_t, noise=0.1)


def trace_important_window(
model, num_layers, inp, e_range, answer_t, kind, window=10, noise=0.1
model, num_layers, inp, e_range, answer_t, kind, window=10, noise=0.1, neuron_trace=False
):
ntoks = inp["input_ids"].shape[1]
table = []
Expand All @@ -284,9 +427,14 @@ def trace_important_window(
max(0, layer - window // 2), min(num_layers, layer - (-window // 2))
)
]
r = trace_with_patch(
model, inp, layerlist, answer_t, tokens_to_mix=e_range, noise=noise
)
if neuron_trace:
r = trace_neurons_with_patch(
model, inp, layerlist, answer_t, tokens_to_mix=e_range, noise=noise
)
else:
r = trace_with_patch(
model, inp, layerlist, answer_t, tokens_to_mix=e_range, noise=noise
)
row.append(r)
table.append(torch.stack(row))
return torch.stack(table)
Expand Down
66 changes: 66 additions & 0 deletions util/neuron_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from math import ceil
from numpy import zeros, ones, reshape
import matplotlib.pyplot as plt


def save_layer_edges(layer_num, layer_edges, save_path):
    """Append one layer's edge strings, one per line, to its per-layer file."""
    out_file = save_path + '/' + f'layer_{layer_num}_edges.txt'
    # 'a+' so repeated calls accumulate edges across contexts
    with open(out_file, 'a+') as outf:
        outf.write('\n'.join(layer_edges))
        outf.write('\n')


def process_neuron_result(context_string, context_id, neuron_result, save_path='', layer_average=False, save_edges=False):
    """Arrange per-neuron causal-trace scores into one 2-D matrix for plotting.

    Each (token, layer) score block is flattened and reshaped into rows of
    ``layer_width`` neurons, then tiled into a big matrix: rows grouped by
    layer, columns grouped by token.  Optionally writes nonzero scores out as
    networkx-style weighted edge lines (one file per layer).

    Args:
        context_string: prompt text, embedded as a comment in edge lines.
        context_id: identifier for this prompt, embedded in edge lines.
        neuron_result: dict with 'scores' (token x layer x neurons array)
            and, when save_edges is used, 'kind'.
        save_path: directory for edge files (used only with save_edges).
        layer_average: fill each layer block with mean+max instead of raw scores.
        save_edges: write weighted edge lists via save_layer_edges.

    Returns:
        (total_matrix, layer_shape, layer_labels).
    """
    token_total = neuron_result['scores'].shape[0]
    layer_total = neuron_result['scores'].shape[1]
    layer_width = 80  # neurons per display row
    layer_shape = [int(ceil(neuron_result['scores'][0]
                            [0].flatten().shape[0] / layer_width)), layer_width]
    total_matrix_shape = [layer_shape[0] *
                          layer_total, layer_shape[1] * token_total]
    total_matrix = zeros(total_matrix_shape)

    for t in range(token_total):
        for l in range(layer_total):
            layer_edges = []
            # BUGFIX: flatten unconditionally. Previously flat_scores was only
            # defined in the non-averaged branch, so layer_average=True together
            # with save_edges=True raised NameError (or reused a stale value
            # from an earlier iteration).
            flat_scores = neuron_result['scores'][t][l].flatten()
            if layer_average:
                # one uniform value per layer block: mean + max of its scores
                layer_mean = float(neuron_result['scores'][t][l].mean(
                )) + float(neuron_result['scores'][t][l].max())
                layer_matrix = ones(layer_shape) * layer_mean
            else:
                # zero-pad up to a full layer_shape rectangle
                layer_matrix = zeros([layer_shape[0] * layer_shape[1]])
                true_size = flat_scores.shape[0]
                layer_matrix[:true_size] = flat_scores
                layer_matrix = reshape(layer_matrix, layer_shape)
            if save_edges:
                for i, w in enumerate(flat_scores):
                    if w != 0:
                        # formatting to work with networkx lib, weighted edge list
                        layer_edges.append(
                            f'layer_{l};{neuron_result["kind"]};n_{i} token_{t};context_{context_id} {w} # {context_string}')
                save_layer_edges(l, layer_edges, save_path)
            row = l * layer_shape[0]
            col = t * layer_shape[1]
            row_stop = row + layer_matrix.shape[0]
            col_stop = col + layer_matrix.shape[1]
            total_matrix[row:row_stop, col:col_stop] = layer_matrix
    layer_labels = list(range(0, layer_total + 1))
    return total_matrix, layer_shape, layer_labels


def make_matrix_plot(total_matrix, layer_shape, layer_labels, input_tokens):
    """Render the assembled neuron-score matrix as a red heat map.

    Rows are grouped by model layer (layer_shape[0] rows per layer) and
    columns by context token (layer_shape[1] columns per token); tick labels
    come from layer_labels and input_tokens respectively.
    """
    fig, ax = plt.subplots(figsize=(15, 20))
    mat = ax.matshow(total_matrix, cmap='Reds', vmin=total_matrix.min())
    ax.grid(True, alpha=0.15)
    # ax.invert_yaxis()
    row_step = layer_shape[0]
    col_step = layer_shape[1]
    ax.set_yticks(list(range(0, total_matrix.shape[0], row_step)))
    # x ticks sit half a cell in so token labels center on their block
    ax.set_xticks(
        [i + 0.5 for i in range(0, total_matrix.shape[1], col_step)])
    ax.set_xticklabels(input_tokens)  # differences.shape[1] - 6, 5)))
    ax.set_yticklabels(layer_labels)
    ax.set_ylabel('model layer')
    ax.set_xlabel('context tokens')
    ax.tick_params(axis='y', labelsize='8')
    plt.show()