Skip to content

Commit 769143a

Browse files
committed
avoid tensor squash for localist repr intervention
1 parent d40fff9 commit 769143a

File tree

3 files changed

+59
-19
lines changed

3 files changed

+59
-19
lines changed

pyvene/models/interventions.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,31 @@ def tie_weight(self, linked_intervention):
4040

4141
class ConstantSourceIntervention(Intervention):
4242

43-
"""Intervention the original representations."""
43+
"""Constant source."""
4444

4545
def __init__(self, **kwargs):
4646
super().__init__(**kwargs)
4747
self.is_source_constant = True
48-
4948

49+
50+
class LocalistRepresentationIntervention(torch.nn.Module):
51+
52+
"""Localist representation."""
53+
54+
def __init__(self, **kwargs):
55+
super().__init__(**kwargs)
56+
self.is_repr_distributed = False
57+
58+
59+
class DistributedRepresentationIntervention(torch.nn.Module):
60+
61+
"""Distributed representation."""
62+
63+
def __init__(self, **kwargs):
64+
super().__init__(**kwargs)
65+
self.is_repr_distributed = True
66+
67+
5068
class BasisAgnosticIntervention(Intervention):
5169

5270
"""Intervention that will modify its basis in a uncontrolled manner."""
@@ -66,7 +84,7 @@ def __init__(self, **kwargs):
6684
self.shared_weights = True
6785

6886

69-
class ZeroIntervention(ConstantSourceIntervention):
87+
class ZeroIntervention(ConstantSourceIntervention, LocalistRepresentationIntervention):
7088

7189
"""Zero-out activations."""
7290

@@ -126,7 +144,7 @@ def __str__(self):
126144
return f"CollectIntervention(embed_dim={self.embed_dim})"
127145

128146

129-
class SkipIntervention(BasisAgnosticIntervention):
147+
class SkipIntervention(BasisAgnosticIntervention, LocalistRepresentationIntervention):
130148

131149
"""Skip the current intervening layer's computation in the hook function."""
132150

@@ -156,7 +174,7 @@ def __str__(self):
156174
return f"SkipIntervention(embed_dim={self.embed_dim})"
157175

158176

159-
class VanillaIntervention(Intervention):
177+
class VanillaIntervention(Intervention, LocalistRepresentationIntervention):
160178

161179
"""Intervention the original representations."""
162180

@@ -191,7 +209,7 @@ def __str__(self):
191209
return f"VanillaIntervention(embed_dim={self.embed_dim})"
192210

193211

194-
class AdditionIntervention(BasisAgnosticIntervention):
212+
class AdditionIntervention(BasisAgnosticIntervention, LocalistRepresentationIntervention):
195213

196214
"""Intervention the original representations with activation addition."""
197215

@@ -226,7 +244,7 @@ def __str__(self):
226244
return f"AdditionIntervention(embed_dim={self.embed_dim})"
227245

228246

229-
class SubtractionIntervention(BasisAgnosticIntervention):
247+
class SubtractionIntervention(BasisAgnosticIntervention, LocalistRepresentationIntervention):
230248

231249
"""Intervention the original representations with activation subtraction."""
232250

@@ -261,7 +279,7 @@ def __str__(self):
261279
return f"SubtractionIntervention(embed_dim={self.embed_dim})"
262280

263281

264-
class RotatedSpaceIntervention(TrainableIntervention):
282+
class RotatedSpaceIntervention(TrainableIntervention, DistributedRepresentationIntervention):
265283

266284
"""Intervention in the rotated space."""
267285

@@ -299,7 +317,7 @@ def __str__(self):
299317
return f"RotatedSpaceIntervention(embed_dim={self.embed_dim})"
300318

301319

302-
class BoundlessRotatedSpaceIntervention(TrainableIntervention):
320+
class BoundlessRotatedSpaceIntervention(TrainableIntervention, DistributedRepresentationIntervention):
303321

304322
"""Intervention in the rotated space with boundary mask."""
305323

@@ -366,7 +384,7 @@ def __str__(self):
366384
return f"BoundlessRotatedSpaceIntervention(embed_dim={self.embed_dim})"
367385

368386

369-
class SigmoidMaskRotatedSpaceIntervention(TrainableIntervention):
387+
class SigmoidMaskRotatedSpaceIntervention(TrainableIntervention, DistributedRepresentationIntervention):
370388

371389
"""Intervention in the rotated space with boundary mask."""
372390

@@ -420,7 +438,7 @@ def __str__(self):
420438
return f"SigmoidMaskRotatedSpaceIntervention(embed_dim={self.embed_dim})"
421439

422440

423-
class LowRankRotatedSpaceIntervention(TrainableIntervention):
441+
class LowRankRotatedSpaceIntervention(TrainableIntervention, DistributedRepresentationIntervention):
424442

425443
"""Intervention in the rotated space."""
426444

@@ -503,7 +521,7 @@ def __str__(self):
503521
return f"LowRankRotatedSpaceIntervention(embed_dim={self.embed_dim})"
504522

505523

506-
class PCARotatedSpaceIntervention(BasisAgnosticIntervention):
524+
class PCARotatedSpaceIntervention(BasisAgnosticIntervention, DistributedRepresentationIntervention):
507525
"""Intervention in the pca space."""
508526

509527
def __init__(self, embed_dim, **kwargs):

pyvene/models/modeling_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from torch import nn
33
import numpy as np
44
from .intervenable_modelcard import *
5+
from .interventions import *
56

67

78
def get_internal_model_type(model):
@@ -517,7 +518,8 @@ def do_intervention(
517518

518519
# flatten
519520
original_base_shape = base_representation.shape
520-
if len(original_base_shape) == 2:
521+
if len(original_base_shape) == 2 or \
522+
isinstance(intervention, LocalistRepresentationIntervention):
521523
# no pos dimension, e.g., gru
522524
base_representation_f = base_representation
523525
source_representation_f = source_representation
@@ -537,7 +539,8 @@ def do_intervention(
537539
)
538540

539541
# unflatten
540-
if len(original_base_shape) == 2:
542+
if len(original_base_shape) == 2 or \
543+
isinstance(intervention, LocalistRepresentationIntervention):
541544
# no pos dimension, e.g., gru
542545
pass
543546
elif len(original_base_shape) == 3:

tests/integration_tests/InterventionWithGPT2TestCase.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -418,10 +418,14 @@ def _test_with_position_intervention_constant_source(
418418
_key = f"{intervention_layer}.{intervention_stream}"
419419

420420
for position in positions:
421-
base_activations[_key][:, position] = intervention(
422-
base_activations[_key][:, position],
423-
None,
424-
)
421+
if intervention_type == ZeroIntervention:
422+
base_activations[_key][:, position] = torch.zeros_like(
423+
base_activations[_key][:, position])
424+
else:
425+
base_activations[_key][:, position] = intervention(
426+
base_activations[_key][:, position],
427+
None,
428+
)
425429

426430
golden_out = GPT2_RUN(
427431
self.gpt2, base["input_ids"], {}, {_key: base_activations[_key]}
@@ -535,7 +539,7 @@ def test_with_position_intervention_constant_source_subtraction_intervention_pos
535539
use_base_only=True
536540
)
537541

538-
def test_with_position_intervention_constant_source_subtraction_intervention_positive(self):
542+
def test_with_position_intervention_constant_source_zero_intervention_positive(self):
539543
"""
540544
Enable constant source with zero intervention.
541545
"""
@@ -605,6 +609,21 @@ def suite():
605609
"test_with_position_intervention_constant_source_vanilla_intervention_positive"
606610
)
607611
)
612+
suite.addTest(
613+
InterventionWithGPT2TestCase(
614+
"test_with_position_intervention_constant_source_addition_intervention_positive"
615+
)
616+
)
617+
suite.addTest(
618+
InterventionWithGPT2TestCase(
619+
"test_with_position_intervention_constant_source_subtraction_intervention_positive"
620+
)
621+
)
622+
suite.addTest(
623+
InterventionWithGPT2TestCase(
624+
"test_with_position_intervention_constant_source_zero_intervention_positive"
625+
)
626+
)
608627
return suite
609628

610629

0 commit comments

Comments
 (0)