Merge pull request #204 from stanfordnlp/zen/fsdp
[P0] Initiate the support of FSDP training (#205)
frankaging authored Feb 3, 2025
2 parents 4be6f6e + 262ed44 commit 781cd02
Showing 16 changed files with 107 additions and 85 deletions.
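
The central change is that self.interventions becomes a torch.nn.ModuleDict, so every
intervention is registered as a real submodule of the IntervenableModel and its parameters
are visible to module-walking utilities such as FSDP. A rough sketch of what that enables
(not part of this diff; wrap_with_fsdp and pv_model are hypothetical names, and a
single-node torchrun launch is assumed):

    # Sketch only: hand an IntervenableModel to FSDP now that its interventions
    # live in an nn.ModuleDict. Names below are illustrative, not from this PR.
    import os
    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def wrap_with_fsdp(pv_model: torch.nn.Module) -> FSDP:
        dist.init_process_group("nccl")             # assumes a torchrun launch
        local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
        torch.cuda.set_device(local_rank)
        # Trainable interventions are now discoverable submodules, so FSDP can
        # shard their parameters together with the wrapped language model's.
        return FSDP(pv_model, device_id=local_rank)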
88 changes: 50 additions & 38 deletions pyvene/models/intervenable_base.py
@@ -90,7 +90,8 @@ def __init__(self, config, model, backend, **kwargs):
# mapping between supported abstract type and module name.
###
self.representations = {}
self.interventions = {}
self.interventions = torch.nn.ModuleDict({})
self.intervention_hooks = {}
self._key_collision_counter = {}
self.return_collect_activations = False
# Flags and counters below are for interventions in the model.generate
@@ -116,6 +117,7 @@ def __init__(self, config, model, backend, **kwargs):
config.representations
):
_key = self._get_representation_key(representation)
print(f"Intervention key: {_key}")

if representation.intervention is not None:
intervention = representation.intervention
@@ -164,7 +166,11 @@ def __init__(self, config, model, backend, **kwargs):
model, representation, backend
)
self.representations[_key] = representation
self.interventions[_key] = (intervention, module_hook)
if isinstance(intervention, types.FunctionType):
self.interventions[_key] = LambdaIntervention(intervention)
else:
self.interventions[_key] = intervention
self.intervention_hooks[_key] = module_hook
self._key_getter_call_counter[
_key
] = 0 # we memo how many the hook is called,
@@ -266,11 +272,13 @@ def _get_representation_key(self, representation):
c = representation.component
u = representation.unit
n = representation.max_number_of_units
_u = u.replace(".", "_") # this will need internal functions to be changed as well.
if "." in c:
_c = c.replace(".", "_")
# string access for sure
key_proposal = f"comp.{c}.unit.{u}.nunit.{n}"
key_proposal = f"comp_{_c}_unit_{_u}_nunit_{n}"
else:
key_proposal = f"layer.{l}.comp.{c}.unit.{u}.nunit.{n}"
key_proposal = f"layer_{l}_comp_{c}_unit_{_u}_nunit_{n}"
if key_proposal not in self._key_collision_counter:
self._key_collision_counter[key_proposal] = 0
else:
@@ -283,8 +291,8 @@ def get_trainable_parameters(self):
"""
ret_params = []
for k, v in self.interventions.items():
if isinstance(v[0], TrainableIntervention):
ret_params += [p for p in v[0].parameters()]
if isinstance(v, TrainableIntervention):
ret_params += [p for p in v.parameters()]
for p in self.model.parameters():
if p.requires_grad:
ret_params += [p]
@@ -296,8 +304,8 @@ def named_parameters(self, recurse=True):
"""
ret_params = []
for k, v in self.interventions.items():
if isinstance(v[0], TrainableIntervention):
ret_params += [(k + '.' + n, p) for n, p in v[0].named_parameters()]
if isinstance(v, TrainableIntervention):
ret_params += [(k + '.' + n, p) for n, p in v.named_parameters()]
for n, p in self.model.named_parameters():
if p.requires_grad:
ret_params += [('model.' + n, p)]
@@ -320,9 +328,9 @@ def set_temperature(self, temp: torch.Tensor):
Set temperature if needed
"""
for k, v in self.interventions.items():
if isinstance(v[0], BoundlessRotatedSpaceIntervention) or \
isinstance(v[0], SigmoidMaskIntervention):
v[0].set_temperature(temp)
if isinstance(v, BoundlessRotatedSpaceIntervention) or \
isinstance(v, SigmoidMaskIntervention):
v.set_temperature(temp)

def enable_model_gradients(self):
"""
@@ -356,7 +364,7 @@ def set_device(self, device, set_model=True):
Set device of interventions and the model
"""
for k, v in self.interventions.items():
v[0].to(device)
v.to(device)
if set_model:
self.model.to(device)

@@ -373,13 +381,13 @@ def count_parameters(self, include_model=False):
_linked_key_set = set([])
total_parameters = 0
for k, v in self.interventions.items():
if isinstance(v[0], TrainableIntervention):
if isinstance(v, TrainableIntervention):
if k in self._intervention_reverse_link:
if not self._intervention_reverse_link[k] in _linked_key_set:
_linked_key_set.add(self._intervention_reverse_link[k])
total_parameters += count_parameters(v[0])
total_parameters += count_parameters(v)
else:
total_parameters += count_parameters(v[0])
total_parameters += count_parameters(v)
if include_model:
total_parameters += sum(
p.numel() for p in self.model.parameters() if p.requires_grad)
@@ -390,16 +398,16 @@ def set_zero_grad(self):
Set device of interventions and the model
"""
for k, v in self.interventions.items():
if isinstance(v[0], TrainableIntervention):
v[0].zero_grad()
if isinstance(v, TrainableIntervention):
v.zero_grad()

def zero_grad(self):
"""
The above, but for HuggingFace.
"""
for k, v in self.interventions.items():
if isinstance(v[0], TrainableIntervention):
v[0].zero_grad()
if isinstance(v, TrainableIntervention):
v.zero_grad()

def _input_validation(
self,
@@ -758,7 +766,8 @@ def _intervention_getter(
"""
handlers = []
for key_i, key in enumerate(keys):
intervention, (module_hook, hook_type) = self.interventions[key]
intervention = self.interventions[key]
(module_hook, hook_type) = self.intervention_hooks[key]
if self._is_generation:
raise NotImplementedError("Generation is not implemented for ndif backend")

@@ -803,7 +812,8 @@ def _intervention_setter(
self._tidy_stateful_activations()

for key_i, key in enumerate(keys):
intervention, (module_hook, hook_type) = self.interventions[key]
intervention = self.interventions[key]
(module_hook, hook_type) = self.intervention_hooks[key]
if unit_locations_base[0] is not None:
self._batched_setter_activation_select[key] = [
0 for _ in range(len(unit_locations_base[0]))
@@ -846,7 +856,7 @@ def _intervention_setter(
# no-op to the output

else:
if not isinstance(self.interventions[key][0], types.FunctionType):
if not isinstance(self.interventions[key], LambdaIntervention):
if intervention.is_source_constant:
intervened_representation = do_intervention(
selected_output,
@@ -944,8 +954,8 @@ def _sync_forward_with_parallel_intervention(
for key in keys:
# skip in case smart jump
if key in self.activations or \
isinstance(self.interventions[key][0], types.FunctionType) or \
self.interventions[key][0].is_source_constant:
isinstance(self.interventions[key], LambdaIntervention) or \
self.interventions[key].is_source_constant:
self._intervention_setter(
[key],
[
@@ -1056,7 +1066,7 @@ def forward(
if self.return_collect_activations:
for key in self.sorted_keys:
if isinstance(
self.interventions[key][0],
self.interventions[key],
CollectIntervention
):
collected_activations += self.activations[key].clone()
@@ -1191,7 +1201,7 @@ def save(
serialized_representations

for k, v in self.interventions.items():
intervention = v[0]
intervention = v
saving_config.intervention_types += [str(type(intervention))]
binary_filename = f"intkey_{k}.bin"
# save intervention binary file
@@ -1288,7 +1298,7 @@ def load(

# load binary files
for i, (k, v) in enumerate(intervenable.interventions.items()):
intervention = v[0]
intervention = v
binary_filename = f"intkey_{k}.bin"
intervention.is_source_constant = \
saving_config.intervention_constant_sources[i]
@@ -1334,7 +1344,7 @@ def save_intervention(self, save_directory, include_model=True):

# save binary files
for k, v in self.interventions.items():
intervention = v[0]
intervention = v
binary_filename = f"intkey_{k}.bin"
# save intervention binary file
if isinstance(intervention, TrainableIntervention):
@@ -1357,7 +1367,7 @@ def load_intervention(self, load_directory, include_model=True):
"""
# load binary files
for i, (k, v) in enumerate(self.interventions.items()):
intervention = v[0]
intervention = v
binary_filename = f"intkey_{k}.bin"
if isinstance(intervention, TrainableIntervention):
saved_state_dict = torch.load(os.path.join(load_directory, binary_filename))
@@ -1379,7 +1389,8 @@ def _intervention_getter(
"""
handlers = []
for key_i, key in enumerate(keys):
intervention, module_hook = self.interventions[key]
intervention = self.interventions[key]
module_hook = self.intervention_hooks[key]

def hook_callback(model, args, kwargs, output=None):
if self._is_generation:
@@ -1524,7 +1535,8 @@ def _intervention_setter(

handlers = []
for key_i, key in enumerate(keys):
intervention, module_hook = self.interventions[key]
intervention = self.interventions[key]
module_hook = self.intervention_hooks[key]
if unit_locations_base[0] is not None:
self._batched_setter_activation_select[key] = [
0 for _ in range(len(unit_locations_base[0]))
@@ -1570,7 +1582,7 @@ def hook_callback(model, args, kwargs, output=None):
# no-op to the output

else:
if not isinstance(self.interventions[key][0], types.FunctionType):
if not isinstance(self.interventions[key], LambdaIntervention):
if intervention.is_source_constant:
raw_intervened_representation = do_intervention(
selected_output,
@@ -1710,8 +1722,8 @@ def _wait_for_forward_with_parallel_intervention(
for key in keys:
# skip in case smart jump
if key in self.activations or \
isinstance(self.interventions[key][0], types.FunctionType) or \
self.interventions[key][0].is_source_constant:
isinstance(self.interventions[key], LambdaIntervention) or \
self.interventions[key].is_source_constant:
set_handlers = self._intervention_setter(
[key],
[
@@ -1780,8 +1792,8 @@ def _wait_for_forward_with_serial_intervention(
for key in keys:
# skip in case smart jump
if key in self.activations or \
isinstance(self.interventions[key][0], types.FunctionType) or \
self.interventions[key][0].is_source_constant:
isinstance(self.interventions[key], LambdaIntervention) or \
self.interventions[key].is_source_constant:
# set with intervened activation to source_i+1
set_handlers = self._intervention_setter(
[key],
@@ -1947,7 +1959,7 @@ def forward(
if self.return_collect_activations:
for key in self.sorted_keys:
if isinstance(
self.interventions[key][0],
self.interventions[key],
CollectIntervention
):
collected_activations += self.activations[key]
@@ -2081,7 +2093,7 @@ def generate(
if self.return_collect_activations:
for key in self.sorted_keys:
if isinstance(
self.interventions[key][0],
self.interventions[key],
CollectIntervention
):
collected_activations += self.activations[key]
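
The _get_representation_key rewrite above replaces the dotted key format with an
underscore-separated one. That is forced by the switch to nn.ModuleDict: dictionary keys
become submodule names, and PyTorch rejects module names that contain ".". A quick
illustration (not from this PR):

    import torch

    d = torch.nn.ModuleDict()
    d["layer_0_comp_block_output_unit_pos_nunit_1"] = torch.nn.Identity()  # accepted
    try:
        d["layer.0.comp.block_output.unit.pos.nunit.1"] = torch.nn.Identity()
    except KeyError as err:
        print(err)  # module name can't contain "."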
17 changes: 16 additions & 1 deletion pyvene/models/modeling_utils.py
@@ -6,6 +6,21 @@
from .constants import *


class LambdaIntervention(torch.nn.Module):
"""
A generic wrapper to turn any Python callable (e.g. a lambda)
into an nn.Module. This does *not* automatically turn external
Tensors into parameters or buffers—it's just a functional wrapper.
"""
def __init__(self, func):
super().__init__()
self.func = func # store the lambda or any callable

def forward(self, *args, **kwargs):
# Simply call the stored function
return self.func(*args, **kwargs)


def get_internal_model_type(model):
"""Return the model type."""
return type(model)
@@ -435,7 +450,7 @@ def do_intervention(
):
"""Do the actual intervention."""

if isinstance(intervention, types.FunctionType):
if isinstance(intervention, LambdaIntervention):
if subspaces is None:
return intervention(base_representation, source_representation)
else:
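
LambdaIntervention exists so that ad-hoc callables can still be stored in the interventions
ModuleDict: a bare lambda is not an nn.Module and could not be registered there. A small
usage sketch, assuming pyvene at this commit (the key string is illustrative):

    import torch
    from pyvene.models.modeling_utils import LambdaIntervention

    # Wrap a plain callable so it can sit inside an nn.ModuleDict.
    zero_out = LambdaIntervention(lambda base, source: torch.zeros_like(base))

    container = torch.nn.ModuleDict({"layer_0_comp_mlp_output_unit_pos_nunit_1": zero_out})
    out = container["layer_0_comp_mlp_output_unit_pos_nunit_1"](torch.ones(2, 3), None)
    print(out)  # zeros with the same shape as the base tensor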
2 changes: 1 addition & 1 deletion pyvene_101.ipynb
@@ -2715,7 +2715,7 @@
"# zero-out grads\n",
"_ = pv_gpt2.model.eval()\n",
"for k, v in pv_gpt2.interventions.items():\n",
" v[0].zero_grad()\n",
" v.zero_grad()\n",
"\n",
"original_outputs, counterfactual_outputs = pv_gpt2(\n",
" base, \n",
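
The notebook edit above shows the access-pattern migration downstream code needs: entries of
interventions are no longer (intervention, module_hook) tuples, and hooks move to the separate
intervention_hooks dict. A self-contained sketch of the new shapes (names are illustrative;
the same migration applies to the tutorial notebooks below):

    import torch
    from pyvene.models.modeling_utils import LambdaIntervention

    key = "layer_0_comp_mlp_output_unit_pos_nunit_1"
    interventions = torch.nn.ModuleDict({key: LambdaIntervention(lambda b, s: b)})
    intervention_hooks = {key: None}  # hooks now live beside, not inside, the entries

    # old access (pre-commit):  intervention, module_hook = interventions[key]
    intervention = interventions[key]      # new: the entry is the intervention itself
    module_hook = intervention_hooks[key]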
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@

setup(
name="pyvene",
version="0.1.6",
version="0.1.7dev",
description="Use Activation Intervention to Interpret Causal Mechanism of Model",
long_description=long_description,
long_description_content_type='text/markdown',
@@ -119,11 +119,8 @@ def _test_subspace_partition_in_forward(self, intervention_type):
RotatedSpaceIntervention,
LowRankRotatedSpaceIntervention,
}:
list(fast.interventions.values())[0][
0
].rotate_layer.weight = list(intervenable.interventions.values())[0][
0
].rotate_layer.weight
list(fast.interventions.values())[0].rotate_layer.weight = \
list(intervenable.interventions.values())[0].rotate_layer.weight

_, without_partition_our_output = fast(
base,
2 changes: 1 addition & 1 deletion tests/integration_tests/InterventionWithGPT2TestCase.py
@@ -89,7 +89,7 @@ def _test_with_head_position_intervention(
intervention_types=intervention_type,
)
intervenable = IntervenableModel(config, self.gpt2)
intervention = list(intervenable.interventions.values())[0][0]
intervention = list(intervenable.interventions.values())[0]

base_activations = {}
source_activations = {}
2 changes: 1 addition & 1 deletion tests/integration_tests/InterventionWithLlamaTestCase.py
@@ -88,7 +88,7 @@ def _test_with_head_position_intervention(
intervention_types=intervention_type,
)
intervenable = IntervenableModel(config, self.llama)
intervention = list(intervenable.interventions.values())[0][0]
intervention = list(intervenable.interventions.values())[0]

base_activations = {}
source_activations = {}
6 changes: 3 additions & 3 deletions tutorials/advanced_tutorials/Boundless_DAS.ipynb
@@ -422,8 +422,8 @@
"warm_up_steps = 0.1 * t_total\n",
"optimizer_params = []\n",
"for k, v in intervenable.interventions.items():\n",
" optimizer_params += [{\"params\": v[0].rotate_layer.parameters()}]\n",
" optimizer_params += [{\"params\": v[0].intervention_boundaries, \"lr\": 1e-2}]\n",
" optimizer_params += [{\"params\": v.rotate_layer.parameters()}]\n",
" optimizer_params += [{\"params\": v.intervention_boundaries, \"lr\": 1e-2}]\n",
"optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)\n",
"scheduler = get_linear_schedule_with_warmup(\n",
" optimizer, num_warmup_steps=warm_up_steps, num_training_steps=t_total\n",
@@ -470,7 +470,7 @@
" loss = loss_fct(shift_logits, shift_labels)\n",
"\n",
" for k, v in intervenable.interventions.items():\n",
" boundary_loss = 1.0 * v[0].intervention_boundaries.sum()\n",
" boundary_loss = 1.0 * v.intervention_boundaries.sum()\n",
" loss += boundary_loss\n",
"\n",
" return loss"
2 changes: 1 addition & 1 deletion tutorials/advanced_tutorials/DAS_Main_Introduction.ipynb
@@ -1438,7 +1438,7 @@
"t_total = int(len(dataset) * epochs)\n",
"optimizer_params = []\n",
"for k, v in intervenable.interventions.items():\n",
" optimizer_params += [{\"params\": v[0].rotate_layer.parameters()}]\n",
" optimizer_params += [{\"params\": v.rotate_layer.parameters()}]\n",
" break\n",
"optimizer = torch.optim.Adam(optimizer_params, lr=0.001)\n",
"\n",