Skip to content

Commit

Permalink
finish up
Browse files Browse the repository at this point in the history
  • Loading branch information
frankaging committed Feb 3, 2025
1 parent 820553f commit bc07940
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
22 changes: 13 additions & 9 deletions pyvene/models/intervenable_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def __init__(self, config, model, backend, **kwargs):
config.representations
):
_key = self._get_representation_key(representation)
print(f"Intervention key: {_key}")

if representation.intervention is not None:
intervention = representation.intervention
Expand Down Expand Up @@ -165,7 +166,10 @@ def __init__(self, config, model, backend, **kwargs):
model, representation, backend
)
self.representations[_key] = representation
self.interventions[_key] = intervention
if isinstance(intervention, types.FunctionType):
self.interventions[_key] = LambdaIntervention(intervention)
else:
self.interventions[_key] = intervention
self.intervention_hooks[_key] = module_hook
self._key_getter_call_counter[
_key
Expand Down Expand Up @@ -268,13 +272,13 @@ def _get_representation_key(self, representation):
c = representation.component
u = representation.unit
n = representation.max_number_of_units
_n = n.replace(".", "_") # this will need internal functions to be changed as well.
_u = u.replace(".", "_") # this will need internal functions to be changed as well.
if "." in c:
_c = c.replace(".", "_")
# string access for sure
key_proposal = f"comp_{_c}_unit_{u}_nunit_{_n}"
key_proposal = f"comp_{_c}_unit_{_u}_nunit_{n}"
else:
key_proposal = f"layer_{l}_comp_{c}_unit_{u}_nunit_{_n}"
key_proposal = f"layer_{l}_comp_{c}_unit_{_u}_nunit_{n}"
if key_proposal not in self._key_collision_counter:
self._key_collision_counter[key_proposal] = 0
else:
Expand Down Expand Up @@ -852,7 +856,7 @@ def _intervention_setter(
# no-op to the output

else:
if not isinstance(self.interventions[key], types.FunctionType):
if not isinstance(self.interventions[key], LambdaIntervention):
if intervention.is_source_constant:
intervened_representation = do_intervention(
selected_output,
Expand Down Expand Up @@ -950,7 +954,7 @@ def _sync_forward_with_parallel_intervention(
for key in keys:
# skip in case smart jump
if key in self.activations or \
isinstance(self.interventions[key], types.FunctionType) or \
isinstance(self.interventions[key], LambdaIntervention) or \
self.interventions[key].is_source_constant:
self._intervention_setter(
[key],
Expand Down Expand Up @@ -1578,7 +1582,7 @@ def hook_callback(model, args, kwargs, output=None):
# no-op to the output

else:
if not isinstance(self.interventions[key], types.FunctionType):
if not isinstance(self.interventions[key], LambdaIntervention):
if intervention.is_source_constant:
raw_intervened_representation = do_intervention(
selected_output,
Expand Down Expand Up @@ -1718,7 +1722,7 @@ def _wait_for_forward_with_parallel_intervention(
for key in keys:
# skip in case smart jump
if key in self.activations or \
isinstance(self.interventions[key], types.FunctionType) or \
isinstance(self.interventions[key], LambdaIntervention) or \
self.interventions[key].is_source_constant:
set_handlers = self._intervention_setter(
[key],
Expand Down Expand Up @@ -1788,7 +1792,7 @@ def _wait_for_forward_with_serial_intervention(
for key in keys:
# skip in case smart jump
if key in self.activations or \
isinstance(self.interventions[key], types.FunctionType) or \
isinstance(self.interventions[key], LambdaIntervention) or \
self.interventions[key].is_source_constant:
# set with intervened activation to source_i+1
set_handlers = self._intervention_setter(
Expand Down
17 changes: 16 additions & 1 deletion pyvene/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,21 @@
from .constants import *


class LambdaIntervention(torch.nn.Module):
    """Wrap an arbitrary Python callable (e.g. a lambda) as an ``nn.Module``.

    Note that any external tensors captured by the callable are *not*
    registered as parameters or buffers here — this is a purely
    functional wrapper around the callable.
    """

    def __init__(self, func):
        super().__init__()
        # Keep a reference to the wrapped callable; it is invoked in forward().
        self.func = func

    def forward(self, *args, **kwargs):
        # Delegate straight through to the wrapped callable.
        return self.func(*args, **kwargs)


def get_internal_model_type(model):
    """Return the concrete class of *model*."""
    model_type = type(model)
    return model_type
Expand Down Expand Up @@ -435,7 +450,7 @@ def do_intervention(
):
"""Do the actual intervention."""

if isinstance(intervention, types.FunctionType):
if isinstance(intervention, LambdaIntervention):
if subspaces is None:
return intervention(base_representation, source_representation)
else:
Expand Down

0 comments on commit bc07940

Please sign in to comment.