Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[P0] Initiate the support of FSDP training (#205) #204

Merged
merged 9 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 50 additions & 38 deletions pyvene/models/intervenable_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def __init__(self, config, model, backend, **kwargs):
# mapping between supported abstract type and module name.
###
self.representations = {}
self.interventions = {}
self.interventions = torch.nn.ModuleDict({})
self.intervention_hooks = {}
self._key_collision_counter = {}
self.return_collect_activations = False
# Flags and counters below are for interventions in the model.generate
Expand All @@ -116,6 +117,7 @@ def __init__(self, config, model, backend, **kwargs):
config.representations
):
_key = self._get_representation_key(representation)
print(f"Intervention key: {_key}")

if representation.intervention is not None:
intervention = representation.intervention
Expand Down Expand Up @@ -164,7 +166,11 @@ def __init__(self, config, model, backend, **kwargs):
model, representation, backend
)
self.representations[_key] = representation
self.interventions[_key] = (intervention, module_hook)
if isinstance(intervention, types.FunctionType):
self.interventions[_key] = LambdaIntervention(intervention)
else:
self.interventions[_key] = intervention
self.intervention_hooks[_key] = module_hook
self._key_getter_call_counter[
_key
] = 0 # we memo how many the hook is called,
Expand Down Expand Up @@ -266,11 +272,13 @@ def _get_representation_key(self, representation):
c = representation.component
u = representation.unit
n = representation.max_number_of_units
_u = u.replace(".", "_") # this will need internal functions to be changed as well.
if "." in c:
_c = c.replace(".", "_")
# string access for sure
key_proposal = f"comp.{c}.unit.{u}.nunit.{n}"
key_proposal = f"comp_{_c}_unit_{_u}_nunit_{n}"
else:
key_proposal = f"layer.{l}.comp.{c}.unit.{u}.nunit.{n}"
key_proposal = f"layer_{l}_comp_{c}_unit_{_u}_nunit_{n}"
if key_proposal not in self._key_collision_counter:
self._key_collision_counter[key_proposal] = 0
else:
Expand All @@ -283,8 +291,8 @@ def get_trainable_parameters(self):
"""
ret_params = []
for k, v in self.interventions.items():
if isinstance(v[0], TrainableIntervention):
ret_params += [p for p in v[0].parameters()]
if isinstance(v, TrainableIntervention):
ret_params += [p for p in v.parameters()]
for p in self.model.parameters():
if p.requires_grad:
ret_params += [p]
Expand All @@ -296,8 +304,8 @@ def named_parameters(self, recurse=True):
"""
ret_params = []
for k, v in self.interventions.items():
if isinstance(v[0], TrainableIntervention):
ret_params += [(k + '.' + n, p) for n, p in v[0].named_parameters()]
if isinstance(v, TrainableIntervention):
ret_params += [(k + '.' + n, p) for n, p in v.named_parameters()]
for n, p in self.model.named_parameters():
if p.requires_grad:
ret_params += [('model.' + n, p)]
Expand All @@ -320,9 +328,9 @@ def set_temperature(self, temp: torch.Tensor):
Set temperature if needed
"""
for k, v in self.interventions.items():
if isinstance(v[0], BoundlessRotatedSpaceIntervention) or \
isinstance(v[0], SigmoidMaskIntervention):
v[0].set_temperature(temp)
if isinstance(v, BoundlessRotatedSpaceIntervention) or \
isinstance(v, SigmoidMaskIntervention):
v.set_temperature(temp)

def enable_model_gradients(self):
"""
Expand Down Expand Up @@ -356,7 +364,7 @@ def set_device(self, device, set_model=True):
Set device of interventions and the model
"""
for k, v in self.interventions.items():
v[0].to(device)
v.to(device)
if set_model:
self.model.to(device)

Expand All @@ -373,13 +381,13 @@ def count_parameters(self, include_model=False):
_linked_key_set = set([])
total_parameters = 0
for k, v in self.interventions.items():
if isinstance(v[0], TrainableIntervention):
if isinstance(v, TrainableIntervention):
if k in self._intervention_reverse_link:
if not self._intervention_reverse_link[k] in _linked_key_set:
_linked_key_set.add(self._intervention_reverse_link[k])
total_parameters += count_parameters(v[0])
total_parameters += count_parameters(v)
else:
total_parameters += count_parameters(v[0])
total_parameters += count_parameters(v)
if include_model:
total_parameters += sum(
p.numel() for p in self.model.parameters() if p.requires_grad)
Expand All @@ -390,16 +398,16 @@ def set_zero_grad(self):
Set device of interventions and the model
"""
for k, v in self.interventions.items():
if isinstance(v[0], TrainableIntervention):
v[0].zero_grad()
if isinstance(v, TrainableIntervention):
v.zero_grad()

def zero_grad(self):
"""
The above, but for HuggingFace.
"""
for k, v in self.interventions.items():
if isinstance(v[0], TrainableIntervention):
v[0].zero_grad()
if isinstance(v, TrainableIntervention):
v.zero_grad()

def _input_validation(
self,
Expand Down Expand Up @@ -758,7 +766,8 @@ def _intervention_getter(
"""
handlers = []
for key_i, key in enumerate(keys):
intervention, (module_hook, hook_type) = self.interventions[key]
intervention = self.interventions[key]
(module_hook, hook_type) = self.intervention_hooks[key]
if self._is_generation:
raise NotImplementedError("Generation is not implemented for ndif backend")

Expand Down Expand Up @@ -803,7 +812,8 @@ def _intervention_setter(
self._tidy_stateful_activations()

for key_i, key in enumerate(keys):
intervention, (module_hook, hook_type) = self.interventions[key]
intervention = self.interventions[key]
(module_hook, hook_type) = self.intervention_hooks[key]
if unit_locations_base[0] is not None:
self._batched_setter_activation_select[key] = [
0 for _ in range(len(unit_locations_base[0]))
Expand Down Expand Up @@ -846,7 +856,7 @@ def _intervention_setter(
# no-op to the output

else:
if not isinstance(self.interventions[key][0], types.FunctionType):
if not isinstance(self.interventions[key], LambdaIntervention):
if intervention.is_source_constant:
intervened_representation = do_intervention(
selected_output,
Expand Down Expand Up @@ -944,8 +954,8 @@ def _sync_forward_with_parallel_intervention(
for key in keys:
# skip in case smart jump
if key in self.activations or \
isinstance(self.interventions[key][0], types.FunctionType) or \
self.interventions[key][0].is_source_constant:
isinstance(self.interventions[key], LambdaIntervention) or \
self.interventions[key].is_source_constant:
self._intervention_setter(
[key],
[
Expand Down Expand Up @@ -1056,7 +1066,7 @@ def forward(
if self.return_collect_activations:
for key in self.sorted_keys:
if isinstance(
self.interventions[key][0],
self.interventions[key],
CollectIntervention
):
collected_activations += self.activations[key].clone()
Expand Down Expand Up @@ -1191,7 +1201,7 @@ def save(
serialized_representations

for k, v in self.interventions.items():
intervention = v[0]
intervention = v
saving_config.intervention_types += [str(type(intervention))]
binary_filename = f"intkey_{k}.bin"
# save intervention binary file
Expand Down Expand Up @@ -1288,7 +1298,7 @@ def load(

# load binary files
for i, (k, v) in enumerate(intervenable.interventions.items()):
intervention = v[0]
intervention = v
binary_filename = f"intkey_{k}.bin"
intervention.is_source_constant = \
saving_config.intervention_constant_sources[i]
Expand Down Expand Up @@ -1334,7 +1344,7 @@ def save_intervention(self, save_directory, include_model=True):

# save binary files
for k, v in self.interventions.items():
intervention = v[0]
intervention = v
binary_filename = f"intkey_{k}.bin"
# save intervention binary file
if isinstance(intervention, TrainableIntervention):
Expand All @@ -1357,7 +1367,7 @@ def load_intervention(self, load_directory, include_model=True):
"""
# load binary files
for i, (k, v) in enumerate(self.interventions.items()):
intervention = v[0]
intervention = v
binary_filename = f"intkey_{k}.bin"
if isinstance(intervention, TrainableIntervention):
saved_state_dict = torch.load(os.path.join(load_directory, binary_filename))
Expand All @@ -1379,7 +1389,8 @@ def _intervention_getter(
"""
handlers = []
for key_i, key in enumerate(keys):
intervention, module_hook = self.interventions[key]
intervention = self.interventions[key]
module_hook = self.intervention_hooks[key]

def hook_callback(model, args, kwargs, output=None):
if self._is_generation:
Expand Down Expand Up @@ -1524,7 +1535,8 @@ def _intervention_setter(

handlers = []
for key_i, key in enumerate(keys):
intervention, module_hook = self.interventions[key]
intervention = self.interventions[key]
module_hook = self.intervention_hooks[key]
if unit_locations_base[0] is not None:
self._batched_setter_activation_select[key] = [
0 for _ in range(len(unit_locations_base[0]))
Expand Down Expand Up @@ -1570,7 +1582,7 @@ def hook_callback(model, args, kwargs, output=None):
# no-op to the output

else:
if not isinstance(self.interventions[key][0], types.FunctionType):
if not isinstance(self.interventions[key], LambdaIntervention):
if intervention.is_source_constant:
raw_intervened_representation = do_intervention(
selected_output,
Expand Down Expand Up @@ -1710,8 +1722,8 @@ def _wait_for_forward_with_parallel_intervention(
for key in keys:
# skip in case smart jump
if key in self.activations or \
isinstance(self.interventions[key][0], types.FunctionType) or \
self.interventions[key][0].is_source_constant:
isinstance(self.interventions[key], LambdaIntervention) or \
self.interventions[key].is_source_constant:
set_handlers = self._intervention_setter(
[key],
[
Expand Down Expand Up @@ -1780,8 +1792,8 @@ def _wait_for_forward_with_serial_intervention(
for key in keys:
# skip in case smart jump
if key in self.activations or \
isinstance(self.interventions[key][0], types.FunctionType) or \
self.interventions[key][0].is_source_constant:
isinstance(self.interventions[key], LambdaIntervention) or \
self.interventions[key].is_source_constant:
# set with intervened activation to source_i+1
set_handlers = self._intervention_setter(
[key],
Expand Down Expand Up @@ -1947,7 +1959,7 @@ def forward(
if self.return_collect_activations:
for key in self.sorted_keys:
if isinstance(
self.interventions[key][0],
self.interventions[key],
CollectIntervention
):
collected_activations += self.activations[key]
Expand Down Expand Up @@ -2081,7 +2093,7 @@ def generate(
if self.return_collect_activations:
for key in self.sorted_keys:
if isinstance(
self.interventions[key][0],
self.interventions[key],
CollectIntervention
):
collected_activations += self.activations[key]
Expand Down
17 changes: 16 additions & 1 deletion pyvene/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,21 @@
from .constants import *


class LambdaIntervention(torch.nn.Module):
    """Adapter that wraps an arbitrary Python callable as an ``nn.Module``.

    The callable is stored verbatim and invoked from :meth:`forward`.
    Tensors it closes over are *not* registered as parameters or
    buffers — this is a purely functional wrapper whose only purpose is
    to let plain callables (e.g. lambdas) live inside module containers
    such as ``nn.ModuleDict``.
    """

    def __init__(self, func):
        super().__init__()
        # Hold a reference to the callable; it is called as-is later.
        self.func = func

    def forward(self, *args, **kwargs):
        """Delegate the call, with all arguments, to the wrapped callable."""
        return self.func(*args, **kwargs)


def get_internal_model_type(model):
    """Return the class object of the given model instance."""
    model_cls = type(model)
    return model_cls
Expand Down Expand Up @@ -435,7 +450,7 @@ def do_intervention(
):
"""Do the actual intervention."""

if isinstance(intervention, types.FunctionType):
if isinstance(intervention, LambdaIntervention):
if subspaces is None:
return intervention(base_representation, source_representation)
else:
Expand Down
2 changes: 1 addition & 1 deletion pyvene_101.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2715,7 +2715,7 @@
"# zero-out grads\n",
"_ = pv_gpt2.model.eval()\n",
"for k, v in pv_gpt2.interventions.items():\n",
" v[0].zero_grad()\n",
" v.zero_grad()\n",
"\n",
"original_outputs, counterfactual_outputs = pv_gpt2(\n",
" base, \n",
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

setup(
name="pyvene",
version="0.1.6",
version="0.1.7dev",
description="Use Activation Intervention to Interpret Causal Mechanism of Model",
long_description=long_description,
long_description_content_type='text/markdown',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,8 @@ def _test_subspace_partition_in_forward(self, intervention_type):
RotatedSpaceIntervention,
LowRankRotatedSpaceIntervention,
}:
list(fast.interventions.values())[0][
0
].rotate_layer.weight = list(intervenable.interventions.values())[0][
0
].rotate_layer.weight
list(fast.interventions.values())[0].rotate_layer.weight = \
list(intervenable.interventions.values())[0].rotate_layer.weight

_, without_partition_our_output = fast(
base,
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/InterventionWithGPT2TestCase.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _test_with_head_position_intervention(
intervention_types=intervention_type,
)
intervenable = IntervenableModel(config, self.gpt2)
intervention = list(intervenable.interventions.values())[0][0]
intervention = list(intervenable.interventions.values())[0]

base_activations = {}
source_activations = {}
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/InterventionWithLlamaTestCase.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _test_with_head_position_intervention(
intervention_types=intervention_type,
)
intervenable = IntervenableModel(config, self.llama)
intervention = list(intervenable.interventions.values())[0][0]
intervention = list(intervenable.interventions.values())[0]

base_activations = {}
source_activations = {}
Expand Down
6 changes: 3 additions & 3 deletions tutorials/advanced_tutorials/Boundless_DAS.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,8 @@
"warm_up_steps = 0.1 * t_total\n",
"optimizer_params = []\n",
"for k, v in intervenable.interventions.items():\n",
" optimizer_params += [{\"params\": v[0].rotate_layer.parameters()}]\n",
" optimizer_params += [{\"params\": v[0].intervention_boundaries, \"lr\": 1e-2}]\n",
" optimizer_params += [{\"params\": v.rotate_layer.parameters()}]\n",
" optimizer_params += [{\"params\": v.intervention_boundaries, \"lr\": 1e-2}]\n",
"optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)\n",
"scheduler = get_linear_schedule_with_warmup(\n",
" optimizer, num_warmup_steps=warm_up_steps, num_training_steps=t_total\n",
Expand Down Expand Up @@ -470,7 +470,7 @@
" loss = loss_fct(shift_logits, shift_labels)\n",
"\n",
" for k, v in intervenable.interventions.items():\n",
" boundary_loss = 1.0 * v[0].intervention_boundaries.sum()\n",
" boundary_loss = 1.0 * v.intervention_boundaries.sum()\n",
" loss += boundary_loss\n",
"\n",
" return loss"
Expand Down
2 changes: 1 addition & 1 deletion tutorials/advanced_tutorials/DAS_Main_Introduction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1438,7 +1438,7 @@
"t_total = int(len(dataset) * epochs)\n",
"optimizer_params = []\n",
"for k, v in intervenable.interventions.items():\n",
" optimizer_params += [{\"params\": v[0].rotate_layer.parameters()}]\n",
" optimizer_params += [{\"params\": v.rotate_layer.parameters()}]\n",
" break\n",
"optimizer = torch.optim.Adam(optimizer_params, lr=0.001)\n",
"\n",
Expand Down
Loading
Loading