@@ -99,6 +99,7 @@ def __init__(self, config, model, backend, **kwargs):
99
99
# token, or on a combination of both.
100
100
self ._is_generation = False
101
101
self ._intervene_on_prompt = None
102
+ self ._intervene_on_generated = False
102
103
self ._key_getter_call_counter = {}
103
104
self ._key_setter_call_counter = {}
104
105
self ._intervention_pointers = {}
@@ -724,6 +725,7 @@ def _cleanup_states(self, skip_activation_gc=False):
724
725
Clean up all old in memo states of interventions
725
726
"""
726
727
self ._is_generation = False
728
+ self ._intervene_on_generated = False
727
729
if not skip_activation_gc :
728
730
self .activations .clear ()
729
731
self .hot_activations .clear ()
@@ -767,7 +769,7 @@ def _intervention_getter(
767
769
for key_i , key in enumerate (keys ):
768
770
intervention = self .interventions [key ]
769
771
(module_hook , hook_type ) = self .intervention_hooks [key ]
770
- if self ._is_generation :
772
+ if self ._is_generation or self . _intervene_on_generated :
771
773
raise NotImplementedError ("Generation is not implemented for ndif backend" )
772
774
773
775
if hook_type == CONST_INPUT_HOOK :
@@ -818,7 +820,7 @@ def _intervention_setter(
818
820
0 for _ in range (len (unit_locations_base [0 ]))
819
821
] # batch_size
820
822
821
- if self ._is_generation :
823
+ if self ._is_generation or self . _intervene_on_generated :
822
824
raise NotImplementedError ("Generation is not implemented for ndif backend" )
823
825
824
826
if hook_type == CONST_INPUT_HOOK :
@@ -916,7 +918,7 @@ def _sync_forward_with_parallel_intervention(
916
918
** kwargs ,
917
919
):
918
920
# torch.autograd.set_detect_anomaly(True)
919
- all_set_handlers = HandlerList ([])
921
+ all_handlers_to_remove = HandlerList ([])
920
922
unit_locations_sources = unit_locations ["sources->base" ][0 ]
921
923
unit_locations_base = unit_locations ["sources->base" ][1 ]
922
924
@@ -1141,6 +1143,7 @@ def _cleanup_states(self, skip_activation_gc=False):
1141
1143
Clean up all old in memo states of interventions
1142
1144
"""
1143
1145
self ._is_generation = False
1146
+ self ._intervene_on_generated = False
1144
1147
self ._remove_forward_hooks ()
1145
1148
self ._reset_hook_count ()
1146
1149
if not skip_activation_gc :
@@ -1392,7 +1395,7 @@ def _intervention_getter(
1392
1395
module_hook = self .intervention_hooks [key ]
1393
1396
1394
1397
def hook_callback (model , args , kwargs , output = None ):
1395
- if self ._is_generation :
1398
+ if self ._is_generation and not self . _intervene_on_generated :
1396
1399
pass
1397
1400
# for getter, there is no restriction.
1398
1401
# is_prompt = self._key_getter_call_counter[key] == 0
@@ -1685,7 +1688,7 @@ def _wait_for_forward_with_parallel_intervention(
1685
1688
subspaces : Optional [List ] = None ,
1686
1689
):
1687
1690
# torch.autograd.set_detect_anomaly(True)
1688
- all_set_handlers = HandlerList ([])
1691
+ all_handlers_to_remove = HandlerList ([])
1689
1692
unit_locations_sources = unit_locations ["sources->base" ][0 ]
1690
1693
unit_locations_base = unit_locations ["sources->base" ][1 ]
1691
1694
@@ -1708,7 +1711,11 @@ def _wait_for_forward_with_parallel_intervention(
1708
1711
)
1709
1712
group_get_handlers .extend (get_handlers )
1710
1713
_ = self .model (** sources [group_id ])
1711
- group_get_handlers .remove ()
1714
+ if not self ._is_generation or not self ._intervene_on_generated :
1715
+ group_get_handlers .remove ()
1716
+ else :
1717
+ # For generation, we use one getter for all setters
1718
+ all_handlers_to_remove .extend (group_get_handlers )
1712
1719
else :
1713
1720
# simply patch in the ones passed in
1714
1721
self .activations = activations_sources
@@ -1740,8 +1747,8 @@ def _wait_for_forward_with_parallel_intervention(
1740
1747
else None ,
1741
1748
)
1742
1749
# for setters, we don't remove them.
1743
- all_set_handlers .extend (set_handlers )
1744
- return all_set_handlers
1750
+ all_handlers_to_remove .extend (set_handlers )
1751
+ return all_handlers_to_remove
1745
1752
1746
1753
def _wait_for_forward_with_serial_intervention (
1747
1754
self ,
@@ -1750,7 +1757,8 @@ def _wait_for_forward_with_serial_intervention(
1750
1757
activations_sources : Optional [Dict ] = None ,
1751
1758
subspaces : Optional [List ] = None ,
1752
1759
):
1753
- all_set_handlers = HandlerList ([])
1760
+ set_handlers_to_remove = HandlerList ([])
1761
+ get_handlers_to_remove = HandlerList ([])
1754
1762
for group_id , keys in self ._intervention_group .items ():
1755
1763
if sources [group_id ] is None :
1756
1764
continue # smart jump for advance usage only
@@ -1782,11 +1790,15 @@ def _wait_for_forward_with_serial_intervention(
1782
1790
if activations_sources is None :
1783
1791
# this is when previous setter and THEN the getter get called
1784
1792
_ = self .model (** sources [group_id ])
1785
- get_handlers .remove ()
1793
+ if not self ._is_generation or not self ._intervene_on_generated :
1794
+ get_handlers .remove ()
1795
+ else :
1796
+ # For generation, we use one getter for all setters
1797
+ get_handlers_to_remove .extend (get_handlers )
1786
1798
# remove existing setters after getting the curr intervened reprs
1787
- if len (all_set_handlers ) > 0 :
1788
- all_set_handlers .remove ()
1789
- all_set_handlers = HandlerList ([])
1799
+ if len (set_handlers_to_remove ) > 0 :
1800
+ set_handlers_to_remove .remove ()
1801
+ set_handlers_to_remove = HandlerList ([])
1790
1802
1791
1803
for key in keys :
1792
1804
# skip in case smart jump
@@ -1807,8 +1819,9 @@ def _wait_for_forward_with_serial_intervention(
1807
1819
else None ,
1808
1820
)
1809
1821
# for setters, we don't remove them.
1810
- all_set_handlers .extend (set_handlers )
1811
- return all_set_handlers
1822
+ set_handlers_to_remove .extend (set_handlers )
1823
+ all_handlers_to_remove = set_handlers_to_remove + get_handlers_to_remove
1824
+ return all_handlers_to_remove
1812
1825
1813
1826
def forward (
1814
1827
self ,
@@ -1923,7 +1936,7 @@ def forward(
1923
1936
try :
1924
1937
# intervene
1925
1938
if self .mode == "parallel" :
1926
- set_handlers_to_remove = (
1939
+ all_handlers_to_remove = (
1927
1940
self ._wait_for_forward_with_parallel_intervention (
1928
1941
sources ,
1929
1942
unit_locations ,
@@ -1932,7 +1945,7 @@ def forward(
1932
1945
)
1933
1946
)
1934
1947
elif self .mode == "serial" :
1935
- set_handlers_to_remove = (
1948
+ all_handlers_to_remove = (
1936
1949
self ._wait_for_forward_with_serial_intervention (
1937
1950
sources ,
1938
1951
unit_locations ,
@@ -1950,7 +1963,7 @@ def forward(
1950
1963
1951
1964
counterfactual_outputs = self .model (** base , ** model_kwargs )
1952
1965
1953
- set_handlers_to_remove .remove ()
1966
+ all_handlers_to_remove .remove ()
1954
1967
1955
1968
self ._output_validation ()
1956
1969
@@ -2000,6 +2013,7 @@ def generate(
2000
2013
intervene_on_prompt : bool = False ,
2001
2014
subspaces : Optional [List ] = None ,
2002
2015
output_original_output : Optional [bool ] = False ,
2016
+ intervene_on_generated : bool = False ,
2003
2017
** kwargs ,
2004
2018
):
2005
2019
"""
@@ -2035,6 +2049,7 @@ def generate(
2035
2049
self ._cleanup_states ()
2036
2050
2037
2051
self ._intervene_on_prompt = intervene_on_prompt
2052
+ self ._intervene_on_generated = intervene_on_generated
2038
2053
self ._is_generation = True
2039
2054
2040
2055
if not intervene_on_prompt and unit_locations is None :
@@ -2061,11 +2076,11 @@ def generate(
2061
2076
# returning un-intervened output
2062
2077
base_outputs = self .model .generate (** base , ** kwargs )
2063
2078
2064
- set_handlers_to_remove = None
2079
+ all_handlers_to_remove = None
2065
2080
try :
2066
2081
# intervene
2067
2082
if self .mode == "parallel" :
2068
- set_handlers_to_remove = (
2083
+ all_handlers_to_remove = (
2069
2084
self ._wait_for_forward_with_parallel_intervention (
2070
2085
sources ,
2071
2086
unit_locations ,
@@ -2074,7 +2089,7 @@ def generate(
2074
2089
)
2075
2090
)
2076
2091
elif self .mode == "serial" :
2077
- set_handlers_to_remove = (
2092
+ all_handlers_to_remove = (
2078
2093
self ._wait_for_forward_with_serial_intervention (
2079
2094
sources ,
2080
2095
unit_locations ,
@@ -2099,9 +2114,10 @@ def generate(
2099
2114
except Exception as e :
2100
2115
raise e
2101
2116
finally :
2102
- if set_handlers_to_remove is not None :
2103
- set_handlers_to_remove .remove ()
2117
+ if all_handlers_to_remove is not None :
2118
+ all_handlers_to_remove .remove ()
2104
2119
self ._is_generation = False
2120
+ self ._intervene_on_generated = False
2105
2121
self ._cleanup_states (
2106
2122
skip_activation_gc = \
2107
2123
(sources is None and activations_sources is not None ) or \
0 commit comments