Skip to content

Commit c8493b3

Browse files
[Minor] Allow intervene on generated tokens
1 parent 3a9d242 commit c8493b3

File tree

1 file changed

+39
-23
lines changed

1 file changed

+39
-23
lines changed

pyvene/models/intervenable_base.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def __init__(self, config, model, backend, **kwargs):
9999
# token, or on a combination of both.
100100
self._is_generation = False
101101
self._intervene_on_prompt = None
102+
self._intervene_on_generated = False
102103
self._key_getter_call_counter = {}
103104
self._key_setter_call_counter = {}
104105
self._intervention_pointers = {}
@@ -724,6 +725,7 @@ def _cleanup_states(self, skip_activation_gc=False):
724725
Clean up all old in memo states of interventions
725726
"""
726727
self._is_generation = False
728+
self._intervene_on_generated = False
727729
if not skip_activation_gc:
728730
self.activations.clear()
729731
self.hot_activations.clear()
@@ -767,7 +769,7 @@ def _intervention_getter(
767769
for key_i, key in enumerate(keys):
768770
intervention = self.interventions[key]
769771
(module_hook, hook_type) = self.intervention_hooks[key]
770-
if self._is_generation:
772+
if self._is_generation or self._intervene_on_generated:
771773
raise NotImplementedError("Generation is not implemented for ndif backend")
772774

773775
if hook_type == CONST_INPUT_HOOK:
@@ -818,7 +820,7 @@ def _intervention_setter(
818820
0 for _ in range(len(unit_locations_base[0]))
819821
] # batch_size
820822

821-
if self._is_generation:
823+
if self._is_generation or self._intervene_on_generated:
822824
raise NotImplementedError("Generation is not implemented for ndif backend")
823825

824826
if hook_type == CONST_INPUT_HOOK:
@@ -916,7 +918,7 @@ def _sync_forward_with_parallel_intervention(
916918
**kwargs,
917919
):
918920
# torch.autograd.set_detect_anomaly(True)
919-
all_set_handlers = HandlerList([])
921+
all_handlers_to_remove = HandlerList([])
920922
unit_locations_sources = unit_locations["sources->base"][0]
921923
unit_locations_base = unit_locations["sources->base"][1]
922924

@@ -1141,6 +1143,7 @@ def _cleanup_states(self, skip_activation_gc=False):
11411143
Clean up all old in memo states of interventions
11421144
"""
11431145
self._is_generation = False
1146+
self._intervene_on_generated = False
11441147
self._remove_forward_hooks()
11451148
self._reset_hook_count()
11461149
if not skip_activation_gc:
@@ -1392,7 +1395,7 @@ def _intervention_getter(
13921395
module_hook = self.intervention_hooks[key]
13931396

13941397
def hook_callback(model, args, kwargs, output=None):
1395-
if self._is_generation:
1398+
if self._is_generation and not self._intervene_on_generated:
13961399
pass
13971400
# for getter, there is no restriction.
13981401
# is_prompt = self._key_getter_call_counter[key] == 0
@@ -1685,7 +1688,7 @@ def _wait_for_forward_with_parallel_intervention(
16851688
subspaces: Optional[List] = None,
16861689
):
16871690
# torch.autograd.set_detect_anomaly(True)
1688-
all_set_handlers = HandlerList([])
1691+
all_handlers_to_remove = HandlerList([])
16891692
unit_locations_sources = unit_locations["sources->base"][0]
16901693
unit_locations_base = unit_locations["sources->base"][1]
16911694

@@ -1708,7 +1711,11 @@ def _wait_for_forward_with_parallel_intervention(
17081711
)
17091712
group_get_handlers.extend(get_handlers)
17101713
_ = self.model(**sources[group_id])
1711-
group_get_handlers.remove()
1714+
if not self._is_generation or not self._intervene_on_generated:
1715+
group_get_handlers.remove()
1716+
else:
1717+
# For generation, we use one getter for all setters
1718+
all_handlers_to_remove.extend(group_get_handlers)
17121719
else:
17131720
# simply patch in the ones passed in
17141721
self.activations = activations_sources
@@ -1740,8 +1747,8 @@ def _wait_for_forward_with_parallel_intervention(
17401747
else None,
17411748
)
17421749
# for setters, we don't remove them.
1743-
all_set_handlers.extend(set_handlers)
1744-
return all_set_handlers
1750+
all_handlers_to_remove.extend(set_handlers)
1751+
return all_handlers_to_remove
17451752

17461753
def _wait_for_forward_with_serial_intervention(
17471754
self,
@@ -1750,7 +1757,8 @@ def _wait_for_forward_with_serial_intervention(
17501757
activations_sources: Optional[Dict] = None,
17511758
subspaces: Optional[List] = None,
17521759
):
1753-
all_set_handlers = HandlerList([])
1760+
set_handlers_to_remove = HandlerList([])
1761+
get_handlers_to_remove = HandlerList([])
17541762
for group_id, keys in self._intervention_group.items():
17551763
if sources[group_id] is None:
17561764
continue # smart jump for advance usage only
@@ -1782,11 +1790,15 @@ def _wait_for_forward_with_serial_intervention(
17821790
if activations_sources is None:
17831791
# this is when previous setter and THEN the getter get called
17841792
_ = self.model(**sources[group_id])
1785-
get_handlers.remove()
1793+
if not self._is_generation or not self._intervene_on_generated:
1794+
get_handlers.remove()
1795+
else:
1796+
# For generation, we use one getter for all setters
1797+
get_handlers_to_remove.extend(get_handlers)
17861798
# remove existing setters after getting the curr intervened reprs
1787-
if len(all_set_handlers) > 0:
1788-
all_set_handlers.remove()
1789-
all_set_handlers = HandlerList([])
1799+
if len(set_handlers_to_remove) > 0:
1800+
set_handlers_to_remove.remove()
1801+
set_handlers_to_remove = HandlerList([])
17901802

17911803
for key in keys:
17921804
# skip in case smart jump
@@ -1807,8 +1819,9 @@ def _wait_for_forward_with_serial_intervention(
18071819
else None,
18081820
)
18091821
# for setters, we don't remove them.
1810-
all_set_handlers.extend(set_handlers)
1811-
return all_set_handlers
1822+
set_handlers_to_remove.extend(set_handlers)
1823+
all_handlers_to_remove = set_handlers_to_remove + get_handlers_to_remove
1824+
return all_handlers_to_remove
18121825

18131826
def forward(
18141827
self,
@@ -1923,7 +1936,7 @@ def forward(
19231936
try:
19241937
# intervene
19251938
if self.mode == "parallel":
1926-
set_handlers_to_remove = (
1939+
all_handlers_to_remove = (
19271940
self._wait_for_forward_with_parallel_intervention(
19281941
sources,
19291942
unit_locations,
@@ -1932,7 +1945,7 @@ def forward(
19321945
)
19331946
)
19341947
elif self.mode == "serial":
1935-
set_handlers_to_remove = (
1948+
all_handlers_to_remove = (
19361949
self._wait_for_forward_with_serial_intervention(
19371950
sources,
19381951
unit_locations,
@@ -1950,7 +1963,7 @@ def forward(
19501963

19511964
counterfactual_outputs = self.model(**base, **model_kwargs)
19521965

1953-
set_handlers_to_remove.remove()
1966+
all_handlers_to_remove.remove()
19541967

19551968
self._output_validation()
19561969

@@ -2000,6 +2013,7 @@ def generate(
20002013
intervene_on_prompt: bool = False,
20012014
subspaces: Optional[List] = None,
20022015
output_original_output: Optional[bool] = False,
2016+
intervene_on_generated: bool = False,
20032017
**kwargs,
20042018
):
20052019
"""
@@ -2035,6 +2049,7 @@ def generate(
20352049
self._cleanup_states()
20362050

20372051
self._intervene_on_prompt = intervene_on_prompt
2052+
self._intervene_on_generated = intervene_on_generated
20382053
self._is_generation = True
20392054

20402055
if not intervene_on_prompt and unit_locations is None:
@@ -2061,11 +2076,11 @@ def generate(
20612076
# returning un-intervened output
20622077
base_outputs = self.model.generate(**base, **kwargs)
20632078

2064-
set_handlers_to_remove = None
2079+
all_handlers_to_remove = None
20652080
try:
20662081
# intervene
20672082
if self.mode == "parallel":
2068-
set_handlers_to_remove = (
2083+
all_handlers_to_remove = (
20692084
self._wait_for_forward_with_parallel_intervention(
20702085
sources,
20712086
unit_locations,
@@ -2074,7 +2089,7 @@ def generate(
20742089
)
20752090
)
20762091
elif self.mode == "serial":
2077-
set_handlers_to_remove = (
2092+
all_handlers_to_remove = (
20782093
self._wait_for_forward_with_serial_intervention(
20792094
sources,
20802095
unit_locations,
@@ -2099,9 +2114,10 @@ def generate(
20992114
except Exception as e:
21002115
raise e
21012116
finally:
2102-
if set_handlers_to_remove is not None:
2103-
set_handlers_to_remove.remove()
2117+
if all_handlers_to_remove is not None:
2118+
all_handlers_to_remove.remove()
21042119
self._is_generation = False
2120+
self._intervene_on_generated = False
21052121
self._cleanup_states(
21062122
skip_activation_gc = \
21072123
(sources is None and activations_sources is not None) or \

0 commit comments

Comments
 (0)