diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py
index 61e51e96..2b2e5de7 100644
--- a/pyvene/models/intervenable_base.py
+++ b/pyvene/models/intervenable_base.py
@@ -90,7 +90,8 @@ def __init__(self, config, model, backend, **kwargs):
         # mapping between supported abstract type and module name.
         ###
         self.representations = {}
-        self.interventions = {}
+        self.interventions = torch.nn.ModuleDict({})
+        self.intervention_hooks = {}
         self._key_collision_counter = {}
         self.return_collect_activations = False
         # Flags and counters below are for interventions in the model.generate
@@ -116,6 +117,7 @@ def __init__(self, config, model, backend, **kwargs):
             config.representations
         ):
             _key = self._get_representation_key(representation)
+            print(f"Intervention key: {_key}")
 
             if representation.intervention is not None:
                 intervention = representation.intervention
@@ -164,7 +166,11 @@ def __init__(self, config, model, backend, **kwargs):
                 model, representation, backend
             )
             self.representations[_key] = representation
-            self.interventions[_key] = (intervention, module_hook)
+            if isinstance(intervention, types.FunctionType):
+                self.interventions[_key] = LambdaIntervention(intervention)
+            else:
+                self.interventions[_key] = intervention
+            self.intervention_hooks[_key] = module_hook
             self._key_getter_call_counter[
                 _key
             ] = 0  # we memo how many the hook is called,
@@ -266,11 +272,13 @@ def _get_representation_key(self, representation):
         c = representation.component
         u = representation.unit
         n = representation.max_number_of_units
+        _u = u.replace(".", "_")  # this will need internal functions to be changed as well.
         if "." in c:
+            _c = c.replace(".", "_")
             # string access for sure
-            key_proposal = f"comp.{c}.unit.{u}.nunit.{n}"
+            key_proposal = f"comp_{_c}_unit_{_u}_nunit_{n}"
         else:
-            key_proposal = f"layer.{l}.comp.{c}.unit.{u}.nunit.{n}"
+            key_proposal = f"layer_{l}_comp_{c}_unit_{_u}_nunit_{n}"
         if key_proposal not in self._key_collision_counter:
             self._key_collision_counter[key_proposal] = 0
         else:
@@ -283,8 +291,8 @@ def get_trainable_parameters(self):
         """
         ret_params = []
         for k, v in self.interventions.items():
-            if isinstance(v[0], TrainableIntervention):
-                ret_params += [p for p in v[0].parameters()]
+            if isinstance(v, TrainableIntervention):
+                ret_params += [p for p in v.parameters()]
         for p in self.model.parameters():
             if p.requires_grad:
                 ret_params += [p]
@@ -296,8 +304,8 @@ def named_parameters(self, recurse=True):
         """
         ret_params = []
         for k, v in self.interventions.items():
-            if isinstance(v[0], TrainableIntervention):
-                ret_params += [(k + '.' + n, p) for n, p in v[0].named_parameters()]
+            if isinstance(v, TrainableIntervention):
+                ret_params += [(k + '.' + n, p) for n, p in v.named_parameters()]
         for n, p in self.model.named_parameters():
             if p.requires_grad:
                 ret_params += [('model.' + n, p)]
@@ -320,9 +328,9 @@ def set_temperature(self, temp: torch.Tensor):
         Set temperature if needed
         """
         for k, v in self.interventions.items():
-            if isinstance(v[0], BoundlessRotatedSpaceIntervention) or \
-                isinstance(v[0], SigmoidMaskIntervention):
-                v[0].set_temperature(temp)
+            if isinstance(v, BoundlessRotatedSpaceIntervention) or \
+                isinstance(v, SigmoidMaskIntervention):
+                v.set_temperature(temp)
 
     def enable_model_gradients(self):
         """
@@ -356,7 +364,7 @@ def set_device(self, device, set_model=True):
         Set device of interventions and the model
         """
         for k, v in self.interventions.items():
-            v[0].to(device)
+            v.to(device)
         if set_model:
             self.model.to(device)
 
@@ -373,13 +381,13 @@ def count_parameters(self, include_model=False):
         _linked_key_set = set([])
         total_parameters = 0
         for k, v in self.interventions.items():
-            if isinstance(v[0], TrainableIntervention):
+            if isinstance(v, TrainableIntervention):
                 if k in self._intervention_reverse_link:
                     if not self._intervention_reverse_link[k] in _linked_key_set:
                         _linked_key_set.add(self._intervention_reverse_link[k])
-                        total_parameters += count_parameters(v[0])
+                        total_parameters += count_parameters(v)
                 else:
-                    total_parameters += count_parameters(v[0])
+                    total_parameters += count_parameters(v)
         if include_model:
             total_parameters += sum(
                 p.numel() for p in self.model.parameters() if p.requires_grad)
@@ -390,16 +398,16 @@ def set_zero_grad(self):
         Set device of interventions and the model
         """
         for k, v in self.interventions.items():
-            if isinstance(v[0], TrainableIntervention):
-                v[0].zero_grad()
+            if isinstance(v, TrainableIntervention):
+                v.zero_grad()
 
     def zero_grad(self):
         """
         The above, but for HuggingFace.
         """
         for k, v in self.interventions.items():
-            if isinstance(v[0], TrainableIntervention):
-                v[0].zero_grad()
+            if isinstance(v, TrainableIntervention):
+                v.zero_grad()
 
     def _input_validation(
         self,
@@ -758,7 +766,8 @@ def _intervention_getter(
         """
         handlers = []
         for key_i, key in enumerate(keys):
-            intervention, (module_hook, hook_type) = self.interventions[key]
+            intervention = self.interventions[key]
+            (module_hook, hook_type) = self.intervention_hooks[key]
             if self._is_generation:
                 raise NotImplementedError("Generation is not implemented for ndif backend")
 
@@ -803,7 +812,8 @@ def _intervention_setter(
         self._tidy_stateful_activations()
 
         for key_i, key in enumerate(keys):
-            intervention, (module_hook, hook_type) = self.interventions[key]
+            intervention = self.interventions[key]
+            (module_hook, hook_type) = self.intervention_hooks[key]
             if unit_locations_base[0] is not None:
                 self._batched_setter_activation_select[key] = [
                     0 for _ in range(len(unit_locations_base[0]))
@@ -846,7 +856,7 @@ def _intervention_setter(
 
                 # no-op to the output
                 else:
-                    if not isinstance(self.interventions[key][0], types.FunctionType):
+                    if not isinstance(self.interventions[key], LambdaIntervention):
                         if intervention.is_source_constant:
                             intervened_representation = do_intervention(
                                 selected_output,
@@ -944,8 +954,8 @@ def _sync_forward_with_parallel_intervention(
             for key in keys:
                 # skip in case smart jump
                 if key in self.activations or \
-                    isinstance(self.interventions[key][0], types.FunctionType) or \
-                    self.interventions[key][0].is_source_constant:
+                    isinstance(self.interventions[key], LambdaIntervention) or \
+                    self.interventions[key].is_source_constant:
                     self._intervention_setter(
                         [key],
                         [
@@ -1056,7 +1066,7 @@ def forward(
             if self.return_collect_activations:
                 for key in self.sorted_keys:
                     if isinstance(
-                        self.interventions[key][0],
+                        self.interventions[key],
                         CollectIntervention
                     ):
                         collected_activations += self.activations[key].clone()
@@ -1191,7 +1201,7 @@ def save(
             serialized_representations
 
         for k, v in self.interventions.items():
-            intervention = v[0]
+            intervention = v
             saving_config.intervention_types += [str(type(intervention))]
             binary_filename = f"intkey_{k}.bin"
             # save intervention binary file
@@ -1288,7 +1298,7 @@ def load(
 
         # load binary files
        for i, (k, v) in enumerate(intervenable.interventions.items()):
-            intervention = v[0]
+            intervention = v
             binary_filename = f"intkey_{k}.bin"
             intervention.is_source_constant = \
                 saving_config.intervention_constant_sources[i]
@@ -1334,7 +1344,7 @@ def save_intervention(self, save_directory, include_model=True):
 
         # save binary files
         for k, v in self.interventions.items():
-            intervention = v[0]
+            intervention = v
             binary_filename = f"intkey_{k}.bin"
             # save intervention binary file
             if isinstance(intervention, TrainableIntervention):
@@ -1357,7 +1367,7 @@ def load_intervention(self, load_directory, include_model=True):
         """
         # load binary files
         for i, (k, v) in enumerate(self.interventions.items()):
-            intervention = v[0]
+            intervention = v
             binary_filename = f"intkey_{k}.bin"
             if isinstance(intervention, TrainableIntervention):
                 saved_state_dict = torch.load(os.path.join(load_directory, binary_filename))
@@ -1379,7 +1389,8 @@ def _intervention_getter(
         """
         handlers = []
         for key_i, key in enumerate(keys):
-            intervention, module_hook = self.interventions[key]
+            intervention = self.interventions[key]
+            module_hook = self.intervention_hooks[key]
 
             def hook_callback(model, args, kwargs, output=None):
                 if self._is_generation:
@@ -1524,7 +1535,8 @@ def _intervention_setter(
 
         handlers = []
         for key_i, key in enumerate(keys):
-            intervention, module_hook = self.interventions[key]
+            intervention = self.interventions[key]
+            module_hook = self.intervention_hooks[key]
             if unit_locations_base[0] is not None:
                 self._batched_setter_activation_select[key] = [
                     0 for _ in range(len(unit_locations_base[0]))
@@ -1570,7 +1582,7 @@ def hook_callback(model, args, kwargs, output=None):
 
                 # no-op to the output
                 else:
-                    if not isinstance(self.interventions[key][0], types.FunctionType):
+                    if not isinstance(self.interventions[key], LambdaIntervention):
                         if intervention.is_source_constant:
                             raw_intervened_representation = do_intervention(
                                 selected_output,
@@ -1710,8 +1722,8 @@ def _wait_for_forward_with_parallel_intervention(
         for key in keys:
             # skip in case smart jump
             if key in self.activations or \
-                isinstance(self.interventions[key][0], types.FunctionType) or \
-                self.interventions[key][0].is_source_constant:
+                isinstance(self.interventions[key], LambdaIntervention) or \
+                self.interventions[key].is_source_constant:
                 set_handlers = self._intervention_setter(
                     [key],
                     [
@@ -1780,8 +1792,8 @@ def _wait_for_forward_with_serial_intervention(
             for key in keys:
                 # skip in case smart jump
                 if key in self.activations or \
-                    isinstance(self.interventions[key][0], types.FunctionType) or \
-                    self.interventions[key][0].is_source_constant:
+                    isinstance(self.interventions[key], LambdaIntervention) or \
+                    self.interventions[key].is_source_constant:
                     # set with intervened activation to source_i+1
                     set_handlers = self._intervention_setter(
                         [key],
@@ -1947,7 +1959,7 @@ def forward(
             if self.return_collect_activations:
                 for key in self.sorted_keys:
                     if isinstance(
-                        self.interventions[key][0],
+                        self.interventions[key],
                         CollectIntervention
                     ):
                         collected_activations += self.activations[key]
@@ -2081,7 +2093,7 @@ def generate(
             if self.return_collect_activations:
                 for key in self.sorted_keys:
                     if isinstance(
-                        self.interventions[key][0],
+                        self.interventions[key],
                         CollectIntervention
                     ):
                         collected_activations += self.activations[key]
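
Note: the change above turns `self.interventions` into a `torch.nn.ModuleDict` whose values are the intervention modules themselves; the paired module hooks move to the new `self.intervention_hooks` dict. A minimal sketch of the resulting access pattern, assuming the dict-style config from pyvene_101 (the config values here are illustrative, not prescribed by this patch):

    import pyvene as pv
    from transformers import AutoModel

    gpt2 = AutoModel.from_pretrained("gpt2")
    pv_gpt2 = pv.IntervenableModel(
        {"layer": 0, "component": "block_output",
         "intervention_type": pv.VanillaIntervention},
        model=gpt2,
    )

    # Before this change: intervention, module_hook = pv_gpt2.interventions[key]
    # After this change: the value *is* the nn.Module; hooks live separately
    # in pv_gpt2.intervention_hooks under the same keys.
    for key, intervention in pv_gpt2.interventions.items():
        print(key, type(intervention).__name__)
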
diff --git a/pyvene/models/modeling_utils.py b/pyvene/models/modeling_utils.py
index 89d1b402..9d5f2afc 100644
--- a/pyvene/models/modeling_utils.py
+++ b/pyvene/models/modeling_utils.py
@@ -6,6 +6,21 @@
 from .constants import *
 
 
+class LambdaIntervention(torch.nn.Module):
+    """
+    A generic wrapper to turn any Python callable (e.g. a lambda)
+    into an nn.Module. This does *not* automatically turn external
+    Tensors into parameters or buffers—it's just a functional wrapper.
+    """
+    def __init__(self, func):
+        super().__init__()
+        self.func = func  # store the lambda or any callable
+
+    def forward(self, *args, **kwargs):
+        # Simply call the stored function
+        return self.func(*args, **kwargs)
+
+
 def get_internal_model_type(model):
     """Return the model type."""
     return type(model)
@@ -435,7 +450,7 @@ def do_intervention(
 ):
     """Do the actual intervention."""
 
-    if isinstance(intervention, types.FunctionType):
+    if isinstance(intervention, LambdaIntervention):
         if subspaces is None:
             return intervention(base_representation, source_representation)
         else:
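
Note: `LambdaIntervention` exists because `nn.ModuleDict` only accepts `nn.Module` values, while pyvene also lets users pass a bare callable as an intervention. A self-contained sketch of the wrapper's behavior (a standalone copy for illustration, not an import from the patched module):

    import torch
    import torch.nn as nn

    class LambdaIntervention(nn.Module):
        """Wrap any callable so it can live inside an nn.ModuleDict."""
        def __init__(self, func):
            super().__init__()
            self.func = func

        def forward(self, *args, **kwargs):
            return self.func(*args, **kwargs)

    # An interchange intervention written as a plain lambda, now storable
    # in a ModuleDict (which rejects values that are not nn.Modules).
    swap = LambdaIntervention(lambda base, source: source)
    interventions = nn.ModuleDict({"layer_0_comp_block_output_unit_pos_nunit_1": swap})
    base, source = torch.zeros(2, 4), torch.ones(2, 4)
    print(interventions["layer_0_comp_block_output_unit_pos_nunit_1"](base, source))
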
diff --git a/pyvene_101.ipynb b/pyvene_101.ipynb
index 18817bd6..9d8913d6 100644
--- a/pyvene_101.ipynb
+++ b/pyvene_101.ipynb
@@ -2715,7 +2715,7 @@
     "# zero-out grads\n",
     "_ = pv_gpt2.model.eval()\n",
     "for k, v in pv_gpt2.interventions.items():\n",
-    "    v[0].zero_grad()\n",
+    "    v.zero_grad()\n",
     "\n",
     "original_outputs, counterfactual_outputs = pv_gpt2(\n",
     "    base, \n",
diff --git a/setup.py b/setup.py
index db149699..5397734e 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@
 
 setup(
     name="pyvene",
-    version="0.1.6",
+    version="0.1.7dev",
     description="Use Activation Intervention to Interpret Causal Mechanism of Model",
     long_description=long_description,
     long_description_content_type='text/markdown',
diff --git a/tests/integration_tests/ComplexInterventionWithGPT2TestCase.py b/tests/integration_tests/ComplexInterventionWithGPT2TestCase.py
index 7900ce6c..66b1fa5d 100644
--- a/tests/integration_tests/ComplexInterventionWithGPT2TestCase.py
+++ b/tests/integration_tests/ComplexInterventionWithGPT2TestCase.py
@@ -119,11 +119,8 @@ def _test_subspace_partition_in_forward(self, intervention_type):
             RotatedSpaceIntervention,
             LowRankRotatedSpaceIntervention,
         }:
-            list(fast.interventions.values())[0][
-                0
-            ].rotate_layer.weight = list(intervenable.interventions.values())[0][
-                0
-            ].rotate_layer.weight
+            list(fast.interventions.values())[0].rotate_layer.weight = \
+                list(intervenable.interventions.values())[0].rotate_layer.weight
 
         _, without_partition_our_output = fast(
             base,
diff --git a/tests/integration_tests/InterventionWithGPT2TestCase.py b/tests/integration_tests/InterventionWithGPT2TestCase.py
index 1bc5fd94..9a0091b8 100644
--- a/tests/integration_tests/InterventionWithGPT2TestCase.py
+++ b/tests/integration_tests/InterventionWithGPT2TestCase.py
@@ -89,7 +89,7 @@ def _test_with_head_position_intervention(
             intervention_types=intervention_type,
         )
         intervenable = IntervenableModel(config, self.gpt2)
-        intervention = list(intervenable.interventions.values())[0][0]
+        intervention = list(intervenable.interventions.values())[0]
 
         base_activations = {}
         source_activations = {}
diff --git a/tests/integration_tests/InterventionWithLlamaTestCase.py b/tests/integration_tests/InterventionWithLlamaTestCase.py
index aa586d24..514aee5c 100644
--- a/tests/integration_tests/InterventionWithLlamaTestCase.py
+++ b/tests/integration_tests/InterventionWithLlamaTestCase.py
@@ -88,7 +88,7 @@ def _test_with_head_position_intervention(
             intervention_types=intervention_type,
         )
         intervenable = IntervenableModel(config, self.llama)
-        intervention = list(intervenable.interventions.values())[0][0]
+        intervention = list(intervenable.interventions.values())[0]
 
         base_activations = {}
         source_activations = {}
diff --git a/tutorials/advanced_tutorials/Boundless_DAS.ipynb b/tutorials/advanced_tutorials/Boundless_DAS.ipynb
index 7b40c536..7f6a84eb 100644
--- a/tutorials/advanced_tutorials/Boundless_DAS.ipynb
+++ b/tutorials/advanced_tutorials/Boundless_DAS.ipynb
@@ -422,8 +422,8 @@
     "warm_up_steps = 0.1 * t_total\n",
     "optimizer_params = []\n",
     "for k, v in intervenable.interventions.items():\n",
-    "    optimizer_params += [{\"params\": v[0].rotate_layer.parameters()}]\n",
-    "    optimizer_params += [{\"params\": v[0].intervention_boundaries, \"lr\": 1e-2}]\n",
+    "    optimizer_params += [{\"params\": v.rotate_layer.parameters()}]\n",
+    "    optimizer_params += [{\"params\": v.intervention_boundaries, \"lr\": 1e-2}]\n",
     "optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)\n",
     "scheduler = get_linear_schedule_with_warmup(\n",
     "    optimizer, num_warmup_steps=warm_up_steps, num_training_steps=t_total\n",
@@ -470,7 +470,7 @@
     "    loss = loss_fct(shift_logits, shift_labels)\n",
     "\n",
     "    for k, v in intervenable.interventions.items():\n",
-    "        boundary_loss = 1.0 * v[0].intervention_boundaries.sum()\n",
+    "        boundary_loss = 1.0 * v.intervention_boundaries.sum()\n",
     "        loss += boundary_loss\n",
     "\n",
     "    return loss"
diff --git a/tutorials/advanced_tutorials/DAS_Main_Introduction.ipynb b/tutorials/advanced_tutorials/DAS_Main_Introduction.ipynb
index 5115e93b..adb11ff5 100644
--- a/tutorials/advanced_tutorials/DAS_Main_Introduction.ipynb
+++ b/tutorials/advanced_tutorials/DAS_Main_Introduction.ipynb
@@ -1438,7 +1438,7 @@
     "t_total = int(len(dataset) * epochs)\n",
     "optimizer_params = []\n",
     "for k, v in intervenable.interventions.items():\n",
-    "    optimizer_params += [{\"params\": v[0].rotate_layer.parameters()}]\n",
+    "    optimizer_params += [{\"params\": v.rotate_layer.parameters()}]\n",
     "    break\n",
     "optimizer = torch.optim.Adam(optimizer_params, lr=0.001)\n",
     "\n",
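
Note: the notebook and test updates in this patch all follow one mechanical rule: `v[0]` becomes `v`. Since `interventions` is now itself an `nn.Module` container, the per-key loops could in principle also lean on standard module machinery; a hedged alternative, assuming every configured intervention is trainable and `pv_model` is an `IntervenableModel`:

    import torch

    # The whole ModuleDict exposes parameters() and state_dict() directly,
    # so one line can replace the per-key parameter-collection loop.
    optimizer = torch.optim.Adam(pv_model.interventions.parameters(), lr=1e-3)
    torch.save(pv_model.interventions.state_dict(), "interventions.pt")
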
diff --git a/tutorials/advanced_tutorials/IOI_with_DAS.ipynb b/tutorials/advanced_tutorials/IOI_with_DAS.ipynb
index d5b39266..f9e29ec8 100644
--- a/tutorials/advanced_tutorials/IOI_with_DAS.ipynb
+++ b/tutorials/advanced_tutorials/IOI_with_DAS.ipynb
@@ -7301,8 +7301,8 @@
    ],
    "source": [
     "intervention = boundless_das_intervenable.interventions[\n",
-    "    \"layer.8.repr.attention_value_output.unit.pos.nunit.1#0\"\n",
-    "][0]\n",
+    "    \"layer_8_repr_attention_value_output_unit_pos_nunit_1#0\"\n",
+    "]\n",
     "boundary_mask = sigmoid_boundary(\n",
     "    intervention.intervention_population.repeat(1, 1),\n",
     "    0.0,\n",
@@ -12475,8 +12475,8 @@
    ],
    "source": [
     "intervention = das_intervenable.interventions[\n",
-    "    \"layer.8.repr.attention_value_output.unit.pos.nunit.1#0\"\n",
-    "][0]\n",
+    "    \"layer_8_repr_attention_value_output_unit_pos_nunit_1#0\"\n",
+    "]\n",
     "learned_weights = intervention.rotate_layer.weight\n",
     "headwise_learned_weights = torch.chunk(learned_weights, chunks=12, dim=0)\n",
     "\n",
@@ -17400,8 +17400,8 @@
    ],
    "source": [
     "intervention = boundless_das_intervenable.interventions[\n",
-    "    \"layer.9.repr.attention_value_output.unit.pos.nunit.1#0\"\n",
-    "][0]\n",
+    "    \"layer_9_repr_attention_value_output_unit_pos_nunit_1#0\"\n",
+    "]\n",
     "boundary_mask = sigmoid_boundary(\n",
     "    intervention.intervention_population.repeat(1, 1),\n",
     "    0.0,\n",
@@ -23343,8 +23343,8 @@
    ],
    "source": [
     "intervention = das_intervenable.interventions[\n",
-    "    \"layer.9.repr.attention_value_output.unit.pos.nunit.1#0\"\n",
-    "][0]\n",
+    "    \"layer_9_repr_attention_value_output_unit_pos_nunit_1#0\"\n",
+    "]\n",
     "learned_weights = intervention.rotate_layer.weight\n",
     "headwise_learned_weights = torch.chunk(learned_weights, chunks=12, dim=0)\n",
     "\n",
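
Note: the key strings in this notebook change because `nn.ModuleDict` registers each key as a submodule name, and PyTorch rejects module names containing "."; this is why `_get_representation_key` now emits "_"-separated keys. Migrating an old key in user code is a one-liner:

    old_key = "layer.8.repr.attention_value_output.unit.pos.nunit.1#0"
    new_key = old_key.replace(".", "_")
    # -> "layer_8_repr_attention_value_output_unit_pos_nunit_1#0"
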
diff --git a/tutorials/advanced_tutorials/IOI_with_Mask_Intervention.ipynb b/tutorials/advanced_tutorials/IOI_with_Mask_Intervention.ipynb
index 53d727ca..aa7736c5 100644
--- a/tutorials/advanced_tutorials/IOI_with_Mask_Intervention.ipynb
+++ b/tutorials/advanced_tutorials/IOI_with_Mask_Intervention.ipynb
@@ -256,7 +256,7 @@
     "def calculate_loss_with_mask(logits, labels, intervenable, coeff=1):\n",
     "    loss = calculate_loss(logits, labels)\n",
     "    for k, v in intervenable.interventions.items():\n",
-    "        mask_loss = coeff * torch.norm(v[0].mask, 1)\n",
+    "        mask_loss = coeff * torch.norm(v.mask, 1)\n",
     "        loss += mask_loss\n",
     "    return loss\n",
     "\n",
@@ -363,8 +363,8 @@
     "    eval_preds += [counterfactual_outputs.logits]\n",
     "eval_metrics = compute_metrics(eval_preds, eval_labels)\n",
     "for k, v in pv_gpt2.interventions.items():\n",
-    "    mask = v[0].mask\n",
-    "    temperature = v[0].temperature\n",
+    "    mask = v.mask\n",
+    "    temperature = v.temperature\n",
     "    break\n",
     "print(eval_metrics)\n",
     "print(\n",
diff --git a/tutorials/advanced_tutorials/MQNLI.ipynb b/tutorials/advanced_tutorials/MQNLI.ipynb
index fd83a5aa..ba6980ab 100644
--- a/tutorials/advanced_tutorials/MQNLI.ipynb
+++ b/tutorials/advanced_tutorials/MQNLI.ipynb
@@ -1402,7 +1402,7 @@
     "\n",
     "optimizer_params = []\n",
     "for k, v in intervenable.interventions.items():\n",
-    "    optimizer_params += [{\"params\": v[0].rotate_layer.parameters()}]\n",
+    "    optimizer_params += [{\"params\": v.rotate_layer.parameters()}]\n",
     "    break\n",
     "optimizer = torch.optim.Adam(optimizer_params, lr=0.001)\n",
     "\n",
diff --git a/tutorials/advanced_tutorials/Probing_Gender.ipynb b/tutorials/advanced_tutorials/Probing_Gender.ipynb
index 0bd32733..1a741ded 100644
--- a/tutorials/advanced_tutorials/Probing_Gender.ipynb
+++ b/tutorials/advanced_tutorials/Probing_Gender.ipynb
@@ -905,7 +905,7 @@
     "    optimizer_params = []\n",
     "    for k, v in intervenable.interventions.items():\n",
     "        try:\n",
-    "            optimizer_params.append({\"params\": v[0].rotate_layer.parameters()})\n",
+    "            optimizer_params.append({\"params\": v.rotate_layer.parameters()})\n",
     "        except:\n",
     "            pass\n",
     "    optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)\n",
diff --git a/tutorials/advanced_tutorials/Voting_Mechanism.ipynb b/tutorials/advanced_tutorials/Voting_Mechanism.ipynb
index 5553f34c..56a45cbf 100644
--- a/tutorials/advanced_tutorials/Voting_Mechanism.ipynb
+++ b/tutorials/advanced_tutorials/Voting_Mechanism.ipynb
@@ -330,9 +330,9 @@
     "optimizer_params = []\n",
     "for k, v in pv_llama.interventions.items():\n",
     "    optimizer_params += [\n",
-    "        {\"params\": v[0].rotate_layer.parameters()}]\n",
+    "        {\"params\": v.rotate_layer.parameters()}]\n",
     "    optimizer_params += [\n",
-    "        {\"params\": v[0].intervention_boundaries, \"lr\": 1e-2}]\n",
+    "        {\"params\": v.intervention_boundaries, \"lr\": 1e-2}]\n",
     "optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)\n",
     "scheduler = get_linear_schedule_with_warmup(\n",
     "    optimizer, num_warmup_steps=warm_up_steps,\n",
@@ -351,7 +351,7 @@
     "    loss = loss_fct(shift_logits, shift_labels)\n",
     "\n",
     "    for k, v in pv_llama.interventions.items():\n",
-    "        boundary_loss = 1.0 * v[0].intervention_boundaries.sum()\n",
+    "        boundary_loss = 1.0 * v.intervention_boundaries.sum()\n",
     "        loss += boundary_loss\n",
     "\n",
     "    return loss"
@@ -481,7 +481,7 @@
    "source": [
     "torch.save(\n",
     "    pv_llama.interventions[\n",
-    "        f\"layer.{layer}.comp.block_output.unit.pos.nunit.1#0\"][0].state_dict(), \n",
+    "        f\"layer_{layer}_comp_block_output_unit_pos_nunit_1#0\"].state_dict(), \n",
     "    f\"./tutorial_data/layer.{layer}.pos.{token_position}.bin\"\n",
     ")"
   ]
@@ -522,9 +522,9 @@
     "pv_llama = pv.IntervenableModel(pv_config, llama)\n",
     "pv_llama.set_device(\"cuda\")\n",
     "pv_llama.disable_model_gradients()\n",
-    "pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#0'][0].load_state_dict(\n",
+    "pv_llama.interventions[f'layer_{layer}_comp_block_output_unit_pos_nunit_1#0'].load_state_dict(\n",
     "    torch.load('./tutorial_data/layer.15.pos.75.bin'))\n",
-    "pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#1'][0].load_state_dict(\n",
+    "pv_llama.interventions[f'layer_{layer}_comp_block_output_unit_pos_nunit_1#1'].load_state_dict(\n",
     "    torch.load('./tutorial_data/layer.15.pos.80.bin'))"
   ]
  },
@@ -665,11 +665,11 @@
     "for loc in [78, 75, 80, [75, 80]]:\n",
     "    if loc == 78:\n",
     "        print(\"[control] intervening location: \", loc)\n",
-    "        pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#0'][0].load_state_dict(\n",
+    "        pv_llama.interventions[f'layer_{layer}_comp_block_output_unit_pos_nunit_1#0'].load_state_dict(\n",
     "            torch.load('./tutorial_data/layer.15.pos.78.bin'))\n",
     "    else:\n",
     "        print(\"intervening location: \", loc)\n",
-    "        pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#0'][0].load_state_dict(\n",
+    "        pv_llama.interventions[f'layer_{layer}_comp_block_output_unit_pos_nunit_1#0'].load_state_dict(\n",
     "            torch.load('./tutorial_data/layer.15.pos.75.bin'))\n",
     "    # evaluation on the test set\n",
     "    collected_probs = []\n",
@@ -1382,9 +1382,9 @@
     "optimizer_params = []\n",
     "for k, v in pv_llama.interventions.items():\n",
     "    optimizer_params += [\n",
-    "        {\"params\": v[0].rotate_layer.parameters()}]\n",
+    "        {\"params\": v.rotate_layer.parameters()}]\n",
     "    optimizer_params += [\n",
-    "        {\"params\": v[0].intervention_boundaries, \"lr\": 1e-2}]\n",
+    "        {\"params\": v.intervention_boundaries, \"lr\": 1e-2}]\n",
     "optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)\n",
     "scheduler = get_linear_schedule_with_warmup(\n",
     "    optimizer, num_warmup_steps=warm_up_steps,\n",
@@ -1403,7 +1403,7 @@
     "    loss = loss_fct(shift_logits, shift_labels)\n",
     "\n",
     "    for k, v in pv_llama.interventions.items():\n",
-    "        boundary_loss = 1.0 * v[0].intervention_boundaries.sum()\n",
+    "        boundary_loss = 1.0 * v.intervention_boundaries.sum()\n",
     "        loss += boundary_loss\n",
     "\n",
     "    return loss"
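
Note: only the lookup keys change in the save/load cells above; each intervention's own `state_dict()` layout is untouched, so checkpoints written before this change (such as the `layer.15.pos.75.bin` files) should still load once the key is updated. A sketch, with `layer` and the checkpoint path assumed from the Voting_Mechanism tutorial:

    key = f"layer_{layer}_comp_block_output_unit_pos_nunit_1#0"
    pv_llama.interventions[key].load_state_dict(
        torch.load("./tutorial_data/layer.15.pos.75.bin"))
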
diff --git a/tutorials/advanced_tutorials/tutorial_ioi_utils.py b/tutorials/advanced_tutorials/tutorial_ioi_utils.py
index cb81320b..2da1aed5 100644
--- a/tutorials/advanced_tutorials/tutorial_ioi_utils.py
+++ b/tutorials/advanced_tutorials/tutorial_ioi_utils.py
@@ -519,7 +519,7 @@ def single_d_low_rank_das_position_config(
 def calculate_boundless_das_loss(logits, labels, intervenable):
     loss = calculate_loss(logits, labels)
     for k, v in intervenable.interventions.items():
-        boundary_loss = 2.0 * v[0].intervention_boundaries.sum()
+        boundary_loss = 2.0 * v.intervention_boundaries.sum()
         loss += boundary_loss
     return loss
 
@@ -658,9 +658,9 @@ def find_variable_at(
     if do_boundless_das:
         optimizer_params = []
         for k, v in intervenable.interventions.items():
-            optimizer_params += [{"params": v[0].rotate_layer.parameters()}]
+            optimizer_params += [{"params": v.rotate_layer.parameters()}]
             optimizer_params += [
-                {"params": v[0].intervention_boundaries, "lr": 0.5}
+                {"params": v.intervention_boundaries, "lr": 0.5}
             ]
         optimizer = torch.optim.Adam(optimizer_params, lr=initial_lr)
         target_total_step = int(len(D_train) / batch_size) * n_epochs
@@ -759,9 +759,7 @@ def find_variable_at(
                         temperature_schedule[total_step]
                     )
                     for k, v in intervenable.interventions.items():
-                        intervention_boundaries = v[
-                            0
-                        ].intervention_boundaries.sum()
+                        intervention_boundaries = v.intervention_boundaries.sum()
                 total_step += 1
 
         # eval
@@ -828,7 +826,7 @@ def find_variable_at(
 
     if do_boundless_das:
         for k, v in intervenable.interventions.items():
-            intervention_boundaries = v[0].intervention_boundaries.sum()
+            intervention_boundaries = v.intervention_boundaries.sum()
         data.append(
             {
                 "pos": aligning_pos,
diff --git a/tutorials/basic_tutorials/Subspace_Partition_with_Intervention.ipynb b/tutorials/basic_tutorials/Subspace_Partition_with_Intervention.ipynb
index 8f950656..2d1c4820 100644
--- a/tutorials/basic_tutorials/Subspace_Partition_with_Intervention.ipynb
+++ b/tutorials/basic_tutorials/Subspace_Partition_with_Intervention.ipynb
@@ -167,7 +167,7 @@
     "    )\n",
     "    intervenable = IntervenableModel(config, gpt)\n",
     "    for k, v in intervenable.interventions.items():\n",
-    "        v[0].set_interchange_dim(768)\n",
+    "        v.set_interchange_dim(768)\n",
     "    for pos_i in range(len(base.input_ids[0])):\n",
     "        _, counterfactual_outputs = intervenable(\n",
     "            base,\n",
@@ -194,7 +194,7 @@
     "    )\n",
     "    intervenable = IntervenableModel(config, gpt)\n",
     "    for k, v in intervenable.interventions.items():\n",
-    "        v[0].set_interchange_dim(768)\n",
+    "        v.set_interchange_dim(768)\n",
     "    for pos_i in range(len(base.input_ids[0])):\n",
     "        _, counterfactual_outputs = intervenable(\n",
     "            base,\n",
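
Note: downstream code that type-checked interventions with `types.FunctionType` needs the same update that `do_intervention` received above, since plain callables are now wrapped. A hedged migration sketch, with `pv_model` assumed to be an `IntervenableModel`:

    from pyvene.models.modeling_utils import LambdaIntervention

    for key, intervention in pv_model.interventions.items():
        if isinstance(intervention, LambdaIntervention):
            print(key, "wraps a plain callable")
        else:
            print(key, "is a regular intervention module")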