"""Neuron-level causal tracing helpers and plotting utilities.

This text was recovered from a whitespace-collapsed patch touching two files:

* ``experiments/causal_trace.py`` -- adds ``generate_test_matrix`` and
  ``trace_neurons_with_patch``.
* ``util/neuron_graph.py`` (new) -- adds ``save_layer_edges``,
  ``process_neuron_result`` and ``make_matrix_plot``.

NOTE: the same patch also threads a ``neuron_trace=False`` keyword through
``calculate_hidden_flow`` and ``trace_important_window``; when it is true,
``trace_important_window`` calls ``trace_neurons_with_patch`` instead of
``trace_with_patch``.  Only fragments of those functions are visible in the
patch, so they are not reproduced here.

NOTE(review): ``trace_important_window`` builds ``layerlist`` from a window of
layers, while ``trace_neurons_with_patch`` accepts exactly one state -- that
call path looks like it can only succeed when the window covers a single
layer.  Confirm against the caller.
"""

import os
from collections import defaultdict
from math import ceil

import numpy
import torch

# ``nethook`` (used by ``trace_neurons_with_patch``) is a project module that
# experiments/causal_trace.py already has in scope; it is not imported here.


# ---------------------------------------------------------------------------
# experiments/causal_trace.py
# ---------------------------------------------------------------------------


def generate_test_matrix(truth_matrix, noise_matrix, prior_cutoff=None):
    """Blend clean and corrupted activations, restoring one more value tier.

    Each call lowers the cutoff to the largest clean value strictly below
    ``prior_cutoff`` (or to the overall maximum on the first call) and keeps
    the clean value wherever ``truth_matrix >= cutoff``; every other position
    takes the value from ``noise_matrix``.

    Args:
        truth_matrix: Tensor of clean activations.
        noise_matrix: Tensor of corrupted activations, same shape.
        prior_cutoff: Cutoff returned by the previous call, or ``None`` to
            start from the maximum of ``truth_matrix``.

    Returns:
        ``(sample_matrix, truth_mask, next_cutoff)`` where ``truth_mask`` is
        the boolean mask of positions taken from ``truth_matrix``.
    """
    if prior_cutoff is None:
        next_cutoff = truth_matrix.max()
    else:
        below_prior = truth_matrix[truth_matrix < prior_cutoff]
        if below_prior.numel() > 0:
            next_cutoff = below_prior.max()
        else:
            # Nothing smaller is left: every value is already restored.
            # Settle on the minimum instead of calling max() on an empty
            # tensor, which would raise.
            next_cutoff = truth_matrix.min()
    truth_mask = truth_matrix.ge(next_cutoff)
    # torch.where selects instead of multiply-and-add, so a non-finite value
    # on the masked-out side cannot leak into the result.
    sample_matrix = torch.where(truth_mask, truth_matrix, noise_matrix)
    return sample_matrix, truth_mask, next_cutoff


def trace_neurons_with_patch(
    model,  # The model
    inp,  # A set of inputs
    states_to_patch,  # A list of (token index, layername) pairs to restore
    answers_t,  # Answer probabilities to collect
    tokens_to_mix,  # Range of tokens to corrupt (begin, end)
    noise=0.1,  # Level of noise to add
    trace_layers=None,  # List of traced outputs to return
):
    """Trace which neurons of one hidden state recover the answer probability.

    Batch row 0 is left uncorrupted; the embeddings of rows ``1:`` receive
    pseudorandom noise over ``tokens_to_mix``.  A first run restores the whole
    selected hidden state in the corrupted rows to obtain a baseline answer
    probability.  Further runs restore only the neurons whose clean value
    reaches a cutoff that is lowered once per run, until the answer
    probability reaches 95% of the baseline or every neuron is restored.

    Returns:
        A float64 CPU tensor with one entry per neuron: the final answer
        probability multiplied by the fraction of corrupted rows in which
        that neuron was restored.  When ``trace_layers`` is given, returns
        ``(neuron_probs, all_traced)`` with the stacked traced outputs.

    Raises:
        ValueError: If ``states_to_patch`` does not hold exactly one state.
    """
    if len(states_to_patch) != 1:
        raise ValueError(
            "trace_neurons_with_patch can only patch one state at a time; "
            f"received states_to_patch: {states_to_patch}"
        )

    prng = numpy.random.RandomState(1)  # For reproducibility, use pseudorandom noise
    patch_spec = defaultdict(list)
    for t, l in states_to_patch:
        patch_spec[l].append(t)

    additional_layers = [] if trace_layers is None else list(trace_layers)
    hooked_layers = ["transformer.wte"] + list(patch_spec.keys()) + additional_layers

    # Per (layer, batch row) state carried from one run to the next.
    cutoffs = {}
    truth_masks = {}

    def untuple(x):
        return x[0] if isinstance(x, tuple) else x

    def corrupt_embeddings(x):
        # If requested, corrupt a range of token embeddings on batch items x[1:].
        if tokens_to_mix is not None:
            b, e = tokens_to_mix
            x[1:, b:e] += noise * torch.from_numpy(
                prng.randn(x.shape[0] - 1, e - b, x.shape[2])
            ).to(x.device)
        return x

    def patch_rep(x, layer):
        # Baseline rule: restore the full uncorrupted hidden state.
        if layer == "transformer.wte":
            return corrupt_embeddings(x)
        if layer not in patch_spec:
            return x
        h = untuple(x)
        for t in patch_spec[layer]:
            h[1:, t] = h[0, t]
        return x

    def patch_rep_neurons(x, layer):
        # Neuron-level rule: restore only the neurons above the current cutoff.
        if layer == "transformer.wte":
            return corrupt_embeddings(x)
        if layer not in patch_spec:
            return x
        h = untuple(x)
        for t in patch_spec[layer]:
            truth_matrix = h[0, t]
            for i, noise_matrix in enumerate(h[1:, t]):
                layer_batch_id = f"{layer}_{i}"
                sample_matrix, truth_mask, new_cutoff = generate_test_matrix(
                    truth_matrix,
                    noise_matrix,
                    prior_cutoff=cutoffs.get(layer_batch_id),
                )
                cutoffs[layer_batch_id] = new_cutoff
                truth_masks[layer_batch_id] = truth_mask
                h[i + 1, t] = sample_matrix
        return x

    def run_model(edit_rule):
        # One forward pass under the given patching rule.
        with torch.no_grad(), nethook.TraceDict(
            model, hooked_layers, edit_output=edit_rule
        ) as td:
            outputs_exp = model(**inp)
        # Softmax probabilities for the answers_t token predictions of
        # interest, averaged over the corrupted batch rows.
        answer_probs = torch.softmax(
            outputs_exp.logits[1:, -1, :], dim=1
        ).mean(dim=0)[answers_t]
        return answer_probs, td

    baseline_probs, td = run_model(patch_rep)

    # Success state: probs reach prob_threshold.
    # Failure state: every neuron is restored (or the patched layer never
    # fired), so another pass cannot make progress.
    prob_threshold = 0.95 * baseline_probs
    probs = 0 * baseline_probs
    while probs < prob_threshold:
        probs, td = run_model(patch_rep_neurons)
        if not truth_masks or all(bool(m.all()) for m in truth_masks.values()):
            break

    if truth_masks:
        # One row per corrupted batch item; the rows are only averaged, so
        # their order (and the text of the mask keys) does not matter.
        neuron_probs = (
            torch.stack([mask * probs for mask in truth_masks.values()])
            .to(dtype=torch.float64, device="cpu")
            .mean(dim=0)
        )
    else:
        neuron_probs = torch.zeros(0, dtype=torch.float64)

    # If tracing all layers, collect all activations together to return.
    if trace_layers is not None:
        print('warning, trace_layers option not implemented yet for neuron trace. may fail.')
        all_traced = torch.stack(
            [untuple(td[layer].output).detach().cpu() for layer in trace_layers],
            dim=2,
        )
        return neuron_probs, all_traced

    return neuron_probs


# ---------------------------------------------------------------------------
# util/neuron_graph.py
# ---------------------------------------------------------------------------


def save_layer_edges(layer_num, layer_edges, save_path):
    """Append edge strings, one per line, to ``layer_<layer_num>_edges.txt``.

    The file lives in ``save_path``; an empty ``save_path`` means the current
    working directory.
    """
    filename = f'layer_{layer_num}_edges.txt'
    # os.path.join keeps an empty save_path relative; plain concatenation
    # with '/' would point at the filesystem root.
    with open(os.path.join(save_path, filename), 'a+') as outf:
        outf.writelines(f'{edge}\n' for edge in layer_edges)


def process_neuron_result(context_string, context_id, neuron_result, save_path='', layer_average=False, save_edges=False):
    """Tile per-neuron scores into one 2-D matrix for plotting.

    ``neuron_result['scores']`` is indexed as ``[token][layer]``.  Each
    (token, layer) cell becomes a tile ``layer_width`` (80) columns wide;
    layers are stacked along rows and tokens along columns.

    Args:
        context_string: Text appended as a trailing ``#`` comment on every
            saved edge line.
        context_id: Identifier written into every saved edge line.
        neuron_result: Mapping with ``'scores'`` and, when edges are saved,
            ``'kind'``.
        save_path: Directory that receives the edge files.
        layer_average: If true, fill each tile with a single value (the mean
            of the cell's scores plus their maximum) instead of the
            per-neuron scores.
        save_edges: If true, append one weighted-edge line per non-zero
            neuron score via ``save_layer_edges``.

    Returns:
        ``(total_matrix, layer_shape, layer_labels)``.
    """
    scores = neuron_result['scores']
    token_total = scores.shape[0]
    layer_total = scores.shape[1]
    layer_width = 80
    neuron_total = scores[0][0].flatten().shape[0]
    layer_shape = [int(ceil(neuron_total / layer_width)), layer_width]
    tile_rows, tile_cols = layer_shape
    total_matrix = numpy.zeros([tile_rows * layer_total, tile_cols * token_total])

    for t in range(token_total):
        for l in range(layer_total):
            cell_scores = scores[t][l]
            # Flattened up front: both the tile and the edge list read it.
            flat_scores = cell_scores.flatten()
            if layer_average:
                layer_mean = float(cell_scores.mean()) + float(cell_scores.max())
                layer_matrix = numpy.ones(layer_shape) * layer_mean
            else:
                # Zero-pad so the scores fill a whole number of rows.
                layer_matrix = numpy.zeros([tile_rows * tile_cols])
                layer_matrix[:flat_scores.shape[0]] = flat_scores
                layer_matrix = numpy.reshape(layer_matrix, layer_shape)
            if save_edges:
                # formatting to work with networkx lib, weighted edge list
                layer_edges = [
                    f'layer_{l};{neuron_result["kind"]};n_{i} token_{t};context_{context_id} {w} # {context_string}'
                    for i, w in enumerate(flat_scores)
                    if w != 0
                ]
                save_layer_edges(l, layer_edges, save_path)
            row = l * tile_rows
            col = t * tile_cols
            total_matrix[row:row + tile_rows, col:col + tile_cols] = layer_matrix

    layer_labels = list(range(0, layer_total + 1))
    return total_matrix, layer_shape, layer_labels


def make_matrix_plot(total_matrix, layer_shape, layer_labels, input_tokens):
    """Show the tiled score matrix with one tick per layer and per token."""
    # Imported here so the matrix helpers above stay importable on machines
    # without matplotlib.
    import matplotlib.pyplot as plt

    _, ax = plt.subplots(figsize=(15, 20))
    ax.matshow(total_matrix, cmap='Reds', vmin=total_matrix.min())
    ax.grid(True, alpha=0.15)
    y_ticks = list(range(0, total_matrix.shape[0], layer_shape[0]))
    x_ticks = [0.5 + i for i in range(0, total_matrix.shape[1], layer_shape[1])]
    ax.set_yticks(y_ticks)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(input_tokens)
    # process_neuron_result returns one label more than there are layer
    # ticks; recent Matplotlib versions reject a label list whose length
    # differs from the tick count, so trim to the ticks actually set.
    ax.set_yticklabels(list(layer_labels)[:len(y_ticks)])
    ax.set_ylabel('model layer')
    ax.set_xlabel('context tokens')
    ax.tick_params(axis='y', labelsize=8)
    plt.show()