@@ -40,13 +40,31 @@ def tie_weight(self, linked_intervention):
40
40
41
41
class ConstantSourceIntervention(Intervention):
    """Intervention whose instances are flagged with a constant source.

    All keyword arguments are handed, unchanged, to the next initializer
    in the method resolution order; afterwards ``is_source_constant`` is
    set to ``True`` on the instance.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Assigned after the parent initializer has run, so this is the
        # value left on the instance once construction finishes.
        self.is_source_constant = True
48
-
49
48
49
+
50
class LocalistRepresentationIntervention(torch.nn.Module):
    """Base class tagging an intervention as using a localist representation.

    Instances expose ``is_repr_distributed`` set to ``False``. Every
    keyword argument received by the constructor is passed on to the next
    initializer in the method resolution order.
    """

    def __init__(self, **kwargs):
        # Cooperative initialisation: forward the keyword arguments as-is
        # to whichever class follows this one in the MRO.
        super().__init__(**kwargs)
        self.is_repr_distributed = False
57
+
58
+
59
class DistributedRepresentationIntervention(torch.nn.Module):
    """Base class tagging an intervention as using a distributed representation.

    Instances expose ``is_repr_distributed`` set to ``True``. Every
    keyword argument received by the constructor is passed on to the next
    initializer in the method resolution order.
    """

    def __init__(self, **kwargs):
        # Cooperative initialisation: forward the keyword arguments as-is
        # to whichever class follows this one in the MRO.
        super().__init__(**kwargs)
        self.is_repr_distributed = True
66
+
67
+
50
68
class BasisAgnosticIntervention (Intervention ):
51
69
52
70
"""Intervention that will modify its basis in a uncontrolled manner."""
@@ -66,7 +84,7 @@ def __init__(self, **kwargs):
66
84
self .shared_weights = True
67
85
68
86
69
- class ZeroIntervention (ConstantSourceIntervention ):
87
+ class ZeroIntervention (ConstantSourceIntervention , LocalistRepresentationIntervention ):
70
88
71
89
"""Zero-out activations."""
72
90
@@ -126,7 +144,7 @@ def __str__(self):
126
144
return f"CollectIntervention(embed_dim={ self .embed_dim } )"
127
145
128
146
129
- class SkipIntervention (BasisAgnosticIntervention ):
147
+ class SkipIntervention (BasisAgnosticIntervention , LocalistRepresentationIntervention ):
130
148
131
149
"""Skip the current intervening layer's computation in the hook function."""
132
150
@@ -156,7 +174,7 @@ def __str__(self):
156
174
return f"SkipIntervention(embed_dim={ self .embed_dim } )"
157
175
158
176
159
- class VanillaIntervention (Intervention ):
177
+ class VanillaIntervention (Intervention , LocalistRepresentationIntervention ):
160
178
161
179
"""Intervention the original representations."""
162
180
@@ -191,7 +209,7 @@ def __str__(self):
191
209
return f"VanillaIntervention(embed_dim={ self .embed_dim } )"
192
210
193
211
194
- class AdditionIntervention (BasisAgnosticIntervention ):
212
+ class AdditionIntervention (BasisAgnosticIntervention , LocalistRepresentationIntervention ):
195
213
196
214
"""Intervention the original representations with activation addition."""
197
215
@@ -226,7 +244,7 @@ def __str__(self):
226
244
return f"AdditionIntervention(embed_dim={ self .embed_dim } )"
227
245
228
246
229
- class SubtractionIntervention (BasisAgnosticIntervention ):
247
+ class SubtractionIntervention (BasisAgnosticIntervention , LocalistRepresentationIntervention ):
230
248
231
249
"""Intervention the original representations with activation subtraction."""
232
250
@@ -261,7 +279,7 @@ def __str__(self):
261
279
return f"SubtractionIntervention(embed_dim={ self .embed_dim } )"
262
280
263
281
264
- class RotatedSpaceIntervention (TrainableIntervention ):
282
+ class RotatedSpaceIntervention (TrainableIntervention , DistributedRepresentationIntervention ):
265
283
266
284
"""Intervention in the rotated space."""
267
285
@@ -299,7 +317,7 @@ def __str__(self):
299
317
return f"RotatedSpaceIntervention(embed_dim={ self .embed_dim } )"
300
318
301
319
302
- class BoundlessRotatedSpaceIntervention (TrainableIntervention ):
320
+ class BoundlessRotatedSpaceIntervention (TrainableIntervention , DistributedRepresentationIntervention ):
303
321
304
322
"""Intervention in the rotated space with boundary mask."""
305
323
@@ -366,7 +384,7 @@ def __str__(self):
366
384
return f"BoundlessRotatedSpaceIntervention(embed_dim={ self .embed_dim } )"
367
385
368
386
369
- class SigmoidMaskRotatedSpaceIntervention (TrainableIntervention ):
387
+ class SigmoidMaskRotatedSpaceIntervention (TrainableIntervention , DistributedRepresentationIntervention ):
370
388
371
389
"""Intervention in the rotated space with boundary mask."""
372
390
@@ -420,7 +438,7 @@ def __str__(self):
420
438
return f"SigmoidMaskRotatedSpaceIntervention(embed_dim={ self .embed_dim } )"
421
439
422
440
423
- class LowRankRotatedSpaceIntervention (TrainableIntervention ):
441
+ class LowRankRotatedSpaceIntervention (TrainableIntervention , DistributedRepresentationIntervention ):
424
442
425
443
"""Intervention in the rotated space."""
426
444
@@ -503,7 +521,7 @@ def __str__(self):
503
521
return f"LowRankRotatedSpaceIntervention(embed_dim={ self .embed_dim } )"
504
522
505
523
506
- class PCARotatedSpaceIntervention (BasisAgnosticIntervention ):
524
+ class PCARotatedSpaceIntervention (BasisAgnosticIntervention , DistributedRepresentationIntervention ):
507
525
"""Intervention in the pca space."""
508
526
509
527
def __init__ (self , embed_dim , ** kwargs ):
0 commit comments