[Minor] Allow intervening on generated tokens
PinetreePantry committed Feb 24, 2025
commit c8493b3 · 1 parent 3a9d242
Showing 1 changed file with 39 additions and 23 deletions: pyvene/models/intervenable_base.py
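The diff below threads a new `intervene_on_generated` flag from `generate()` down to the hook getters and setters, so interventions can also fire on tokens produced during decoding rather than only on the prompt forward pass. A minimal usage sketch of the new flag (the model, tokenizer, config fields, and location format here are illustrative assumptions and may differ across pyvene versions):

import pyvene as pv
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# One vanilla intervention on the layer-0 block output (illustrative config).
config = pv.IntervenableConfig(
    representations=[{"layer": 0, "component": "block_output"}],
    intervention_types=pv.VanillaIntervention,
)
intervenable = pv.IntervenableModel(config, model)

base = tokenizer("The capital of France is", return_tensors="pt")
source = tokenizer("The capital of Italy is", return_tensors="pt")

# New in this commit: with intervene_on_generated=True, setter hooks stay
# registered across decoding steps instead of firing only on the prompt.
# Position 0 is chosen so the location stays valid during decoding, when
# each forward pass sees a single new token (an assumption of this sketch).
_, counterfactual = intervenable.generate(
    base,
    sources=[source],
    unit_locations={"sources->base": ([[[0]]], [[[0]]])},
    intervene_on_prompt=True,
    intervene_on_generated=True,
    max_new_tokens=10,
)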
@@ -99,6 +99,7 @@ def __init__(self, config, model, backend, **kwargs):
# token, or on a combination of both.
self._is_generation = False
self._intervene_on_prompt = None
+ self._intervene_on_generated = False
self._key_getter_call_counter = {}
self._key_setter_call_counter = {}
self._intervention_pointers = {}
@@ -724,6 +725,7 @@ def _cleanup_states(self, skip_activation_gc=False):
Clean up all old in memo states of interventions
"""
self._is_generation = False
+ self._intervene_on_generated = False
if not skip_activation_gc:
self.activations.clear()
self.hot_activations.clear()
@@ -767,7 +769,7 @@ def _intervention_getter(
for key_i, key in enumerate(keys):
intervention = self.interventions[key]
(module_hook, hook_type) = self.intervention_hooks[key]
- if self._is_generation:
+ if self._is_generation or self._intervene_on_generated:
raise NotImplementedError("Generation is not implemented for ndif backend")

if hook_type == CONST_INPUT_HOOK:
@@ -818,7 +820,7 @@ def _intervention_setter(
0 for _ in range(len(unit_locations_base[0]))
] # batch_size

- if self._is_generation:
+ if self._is_generation or self._intervene_on_generated:
raise NotImplementedError("Generation is not implemented for ndif backend")

if hook_type == CONST_INPUT_HOOK:
@@ -916,7 +918,7 @@ def _sync_forward_with_parallel_intervention(
**kwargs,
):
# torch.autograd.set_detect_anomaly(True)
- all_set_handlers = HandlerList([])
+ all_handlers_to_remove = HandlerList([])
unit_locations_sources = unit_locations["sources->base"][0]
unit_locations_base = unit_locations["sources->base"][1]

@@ -1141,6 +1143,7 @@ def _cleanup_states(self, skip_activation_gc=False):
Clean up all old in memo states of interventions
"""
self._is_generation = False
+ self._intervene_on_generated = False
self._remove_forward_hooks()
self._reset_hook_count()
if not skip_activation_gc:
@@ -1392,7 +1395,7 @@ def _intervention_getter(
module_hook = self.intervention_hooks[key]

def hook_callback(model, args, kwargs, output=None):
- if self._is_generation:
+ if self._is_generation and not self._intervene_on_generated:
pass
# for getter, there is no restriction.
# is_prompt = self._key_getter_call_counter[key] == 0
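The getter change above inverts the old behavior: previously every getter call during generation fell through unconditionally, whereas now the no-op path applies only when generated tokens are not targeted. The commented-out line hints at how prompt and decoding passes can be told apart: hooks fire once per forward pass, so a per-key call counter works. A small illustrative sketch of that idea (an assumption based on the comment, not code this commit ships):

def is_prompt_pass(key_getter_call_counter: dict, key: str) -> bool:
    # The first invocation for a key is the prompt forward pass; each
    # later invocation corresponds to one generated token.
    is_prompt = key_getter_call_counter[key] == 0
    key_getter_call_counter[key] += 1
    return is_prompt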
@@ -1685,7 +1688,7 @@ def _wait_for_forward_with_parallel_intervention(
subspaces: Optional[List] = None,
):
# torch.autograd.set_detect_anomaly(True)
- all_set_handlers = HandlerList([])
+ all_handlers_to_remove = HandlerList([])
unit_locations_sources = unit_locations["sources->base"][0]
unit_locations_base = unit_locations["sources->base"][1]

@@ -1708,7 +1711,11 @@ def _wait_for_forward_with_parallel_intervention(
)
group_get_handlers.extend(get_handlers)
_ = self.model(**sources[group_id])
- group_get_handlers.remove()
+ if not self._is_generation or not self._intervene_on_generated:
+     group_get_handlers.remove()
+ else:
+     # For generation, we use one getter for all setters
+     all_handlers_to_remove.extend(group_get_handlers)
else:
# simply patch in the ones passed in
self.activations = activations_sources
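The new else branch keeps `group_get_handlers` registered instead of removing them right after the source forward pass. The reason is PyTorch's hook lifetime model: a registered forward hook fires once per forward pass, and `model.generate()` runs one forward pass per new token, so a getter removed after the prompt pass would never see generated tokens. A standalone PyTorch illustration (not pyvene code):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
calls = []
handle = layer.register_forward_hook(lambda mod, inp, out: calls.append(out.shape))

# Three forward passes stand in for three decoding steps: the hook fires
# on each one, so removing it after the first pass would miss the rest.
for _ in range(3):
    layer(torch.randn(1, 4))
handle.remove()  # removed only after all steps, mirroring the change above
print(len(calls))  # 3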
@@ -1740,8 +1747,8 @@ def _wait_for_forward_with_parallel_intervention(
else None,
)
# for setters, we don't remove them.
- all_set_handlers.extend(set_handlers)
- return all_set_handlers
+ all_handlers_to_remove.extend(set_handlers)
+ return all_handlers_to_remove

def _wait_for_forward_with_serial_intervention(
self,
@@ -1750,7 +1757,8 @@ def _wait_for_forward_with_serial_intervention(
activations_sources: Optional[Dict] = None,
subspaces: Optional[List] = None,
):
- all_set_handlers = HandlerList([])
+ set_handlers_to_remove = HandlerList([])
+ get_handlers_to_remove = HandlerList([])
for group_id, keys in self._intervention_group.items():
if sources[group_id] is None:
continue # smart jump for advance usage only
@@ -1782,11 +1790,15 @@ def _wait_for_forward_with_serial_intervention(
if activations_sources is None:
# this is when previous setter and THEN the getter get called
_ = self.model(**sources[group_id])
- get_handlers.remove()
+ if not self._is_generation or not self._intervene_on_generated:
+     get_handlers.remove()
+ else:
+     # For generation, we use one getter for all setters
+     get_handlers_to_remove.extend(get_handlers)
# remove existing setters after getting the curr intervened reprs
- if len(all_set_handlers) > 0:
-     all_set_handlers.remove()
-     all_set_handlers = HandlerList([])
+ if len(set_handlers_to_remove) > 0:
+     set_handlers_to_remove.remove()
+     set_handlers_to_remove = HandlerList([])

for key in keys:
# skip in case smart jump
@@ -1807,8 +1819,9 @@ def _wait_for_forward_with_serial_intervention(
else None,
)
# for setters, we don't remove them.
- all_set_handlers.extend(set_handlers)
- return all_set_handlers
+ set_handlers_to_remove.extend(set_handlers)
+ all_handlers_to_remove = set_handlers_to_remove + get_handlers_to_remove
+ return all_handlers_to_remove

def forward(
self,
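Both intervention helpers now return a single handler list that the caller removes after the model call. The code above relies on `HandlerList` supporting `extend`, `remove`, `len`, and `+` (used to merge the setter and getter lists). A minimal sketch of such a container, as an illustration of the assumed interface rather than pyvene's exact implementation:

class HandlerList:
    # Minimal illustrative container for PyTorch hook handles.
    def __init__(self, handlers):
        self.handlers = handlers

    def __len__(self):
        return len(self.handlers)

    def extend(self, other):
        self.handlers.extend(other.handlers)
        return self

    def remove(self):
        # Detach every registered hook in one call.
        for handler in self.handlers:
            handler.remove()

    def __add__(self, other):
        # Enables `set_handlers_to_remove + get_handlers_to_remove` above.
        return HandlerList(self.handlers + other.handlers)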
@@ -1923,7 +1936,7 @@ def forward(
try:
# intervene
if self.mode == "parallel":
- set_handlers_to_remove = (
+ all_handlers_to_remove = (
self._wait_for_forward_with_parallel_intervention(
sources,
unit_locations,
@@ -1932,7 +1945,7 @@ def forward(
)
)
elif self.mode == "serial":
- set_handlers_to_remove = (
+ all_handlers_to_remove = (
self._wait_for_forward_with_serial_intervention(
sources,
unit_locations,
@@ -1950,7 +1963,7 @@ def forward(

counterfactual_outputs = self.model(**base, **model_kwargs)

- set_handlers_to_remove.remove()
+ all_handlers_to_remove.remove()

self._output_validation()

@@ -2000,6 +2013,7 @@ def generate(
intervene_on_prompt: bool = False,
subspaces: Optional[List] = None,
output_original_output: Optional[bool] = False,
+ intervene_on_generated: bool = False,
**kwargs,
):
"""
@@ -2035,6 +2049,7 @@ def generate(
self._cleanup_states()

self._intervene_on_prompt = intervene_on_prompt
+ self._intervene_on_generated = intervene_on_generated
self._is_generation = True

if not intervene_on_prompt and unit_locations is None:
@@ -2061,11 +2076,11 @@ def generate(
# returning un-intervened output
base_outputs = self.model.generate(**base, **kwargs)

- set_handlers_to_remove = None
+ all_handlers_to_remove = None
try:
# intervene
if self.mode == "parallel":
- set_handlers_to_remove = (
+ all_handlers_to_remove = (
self._wait_for_forward_with_parallel_intervention(
sources,
unit_locations,
@@ -2074,7 +2089,7 @@ def generate(
)
)
elif self.mode == "serial":
- set_handlers_to_remove = (
+ all_handlers_to_remove = (
self._wait_for_forward_with_serial_intervention(
sources,
unit_locations,
@@ -2099,9 +2114,10 @@ def generate(
except Exception as e:
raise e
finally:
- if set_handlers_to_remove is not None:
-     set_handlers_to_remove.remove()
+ if all_handlers_to_remove is not None:
+     all_handlers_to_remove.remove()
self._is_generation = False
+ self._intervene_on_generated = False
self._cleanup_states(
skip_activation_gc = \
(sources is None and activations_sources is not None) or \