Skip to content

Commit

Permalink
avoid tensor squash for localist repr intervention
Browse files Browse the repository at this point in the history
  • Loading branch information
frankaging committed Jan 17, 2024
1 parent d40fff9 commit 769143a
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 19 deletions.
42 changes: 30 additions & 12 deletions pyvene/models/interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,31 @@ def tie_weight(self, linked_intervention):

class ConstantSourceIntervention(Intervention):

"""Intervention the original representations."""
"""Constant source."""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.is_source_constant = True



class LocalistRepresentationIntervention(torch.nn.Module):

"""Localist representation."""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.is_repr_distributed = False


class DistributedRepresentationIntervention(torch.nn.Module):

"""Distributed representation."""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.is_repr_distributed = True


class BasisAgnosticIntervention(Intervention):

"""Intervention that will modify its basis in a uncontrolled manner."""
Expand All @@ -66,7 +84,7 @@ def __init__(self, **kwargs):
self.shared_weights = True


class ZeroIntervention(ConstantSourceIntervention):
class ZeroIntervention(ConstantSourceIntervention, LocalistRepresentationIntervention):

"""Zero-out activations."""

Expand Down Expand Up @@ -126,7 +144,7 @@ def __str__(self):
return f"CollectIntervention(embed_dim={self.embed_dim})"


class SkipIntervention(BasisAgnosticIntervention):
class SkipIntervention(BasisAgnosticIntervention, LocalistRepresentationIntervention):

"""Skip the current intervening layer's computation in the hook function."""

Expand Down Expand Up @@ -156,7 +174,7 @@ def __str__(self):
return f"SkipIntervention(embed_dim={self.embed_dim})"


class VanillaIntervention(Intervention):
class VanillaIntervention(Intervention, LocalistRepresentationIntervention):

"""Intervention the original representations."""

Expand Down Expand Up @@ -191,7 +209,7 @@ def __str__(self):
return f"VanillaIntervention(embed_dim={self.embed_dim})"


class AdditionIntervention(BasisAgnosticIntervention):
class AdditionIntervention(BasisAgnosticIntervention, LocalistRepresentationIntervention):

"""Intervention the original representations with activation addition."""

Expand Down Expand Up @@ -226,7 +244,7 @@ def __str__(self):
return f"AdditionIntervention(embed_dim={self.embed_dim})"


class SubtractionIntervention(BasisAgnosticIntervention):
class SubtractionIntervention(BasisAgnosticIntervention, LocalistRepresentationIntervention):

"""Intervention the original representations with activation subtraction."""

Expand Down Expand Up @@ -261,7 +279,7 @@ def __str__(self):
return f"SubtractionIntervention(embed_dim={self.embed_dim})"


class RotatedSpaceIntervention(TrainableIntervention):
class RotatedSpaceIntervention(TrainableIntervention, DistributedRepresentationIntervention):

"""Intervention in the rotated space."""

Expand Down Expand Up @@ -299,7 +317,7 @@ def __str__(self):
return f"RotatedSpaceIntervention(embed_dim={self.embed_dim})"


class BoundlessRotatedSpaceIntervention(TrainableIntervention):
class BoundlessRotatedSpaceIntervention(TrainableIntervention, DistributedRepresentationIntervention):

"""Intervention in the rotated space with boundary mask."""

Expand Down Expand Up @@ -366,7 +384,7 @@ def __str__(self):
return f"BoundlessRotatedSpaceIntervention(embed_dim={self.embed_dim})"


class SigmoidMaskRotatedSpaceIntervention(TrainableIntervention):
class SigmoidMaskRotatedSpaceIntervention(TrainableIntervention, DistributedRepresentationIntervention):

"""Intervention in the rotated space with boundary mask."""

Expand Down Expand Up @@ -420,7 +438,7 @@ def __str__(self):
return f"SigmoidMaskRotatedSpaceIntervention(embed_dim={self.embed_dim})"


class LowRankRotatedSpaceIntervention(TrainableIntervention):
class LowRankRotatedSpaceIntervention(TrainableIntervention, DistributedRepresentationIntervention):

"""Intervention in the rotated space."""

Expand Down Expand Up @@ -503,7 +521,7 @@ def __str__(self):
return f"LowRankRotatedSpaceIntervention(embed_dim={self.embed_dim})"


class PCARotatedSpaceIntervention(BasisAgnosticIntervention):
class PCARotatedSpaceIntervention(BasisAgnosticIntervention, DistributedRepresentationIntervention):
"""Intervention in the pca space."""

def __init__(self, embed_dim, **kwargs):
Expand Down
7 changes: 5 additions & 2 deletions pyvene/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from torch import nn
import numpy as np
from .intervenable_modelcard import *
from .interventions import *


def get_internal_model_type(model):
Expand Down Expand Up @@ -517,7 +518,8 @@ def do_intervention(

# flatten
original_base_shape = base_representation.shape
if len(original_base_shape) == 2:
if len(original_base_shape) == 2 or \
isinstance(intervention, LocalistRepresentationIntervention):
# no pos dimension, e.g., gru
base_representation_f = base_representation
source_representation_f = source_representation
Expand All @@ -537,7 +539,8 @@ def do_intervention(
)

# unflatten
if len(original_base_shape) == 2:
if len(original_base_shape) == 2 or \
isinstance(intervention, LocalistRepresentationIntervention):
# no pos dimension, e.g., gru
pass
elif len(original_base_shape) == 3:
Expand Down
29 changes: 24 additions & 5 deletions tests/integration_tests/InterventionWithGPT2TestCase.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,10 +418,14 @@ def _test_with_position_intervention_constant_source(
_key = f"{intervention_layer}.{intervention_stream}"

for position in positions:
base_activations[_key][:, position] = intervention(
base_activations[_key][:, position],
None,
)
if intervention_type == ZeroIntervention:
base_activations[_key][:, position] = torch.zeros_like(
base_activations[_key][:, position])
else:
base_activations[_key][:, position] = intervention(
base_activations[_key][:, position],
None,
)

golden_out = GPT2_RUN(
self.gpt2, base["input_ids"], {}, {_key: base_activations[_key]}
Expand Down Expand Up @@ -535,7 +539,7 @@ def test_with_position_intervention_constant_source_subtraction_intervention_pos
use_base_only=True
)

def test_with_position_intervention_constant_source_subtraction_intervention_positive(self):
def test_with_position_intervention_constant_source_zero_intervention_positive(self):
"""
Enable constant source with subtraction intervention.
"""
Expand Down Expand Up @@ -605,6 +609,21 @@ def suite():
"test_with_position_intervention_constant_source_vanilla_intervention_positive"
)
)
suite.addTest(
InterventionWithGPT2TestCase(
"test_with_position_intervention_constant_source_addition_intervention_positive"
)
)
suite.addTest(
InterventionWithGPT2TestCase(
"test_with_position_intervention_constant_source_subtraction_intervention_positive"
)
)
suite.addTest(
InterventionWithGPT2TestCase(
"test_with_position_intervention_constant_source_zero_intervention_positive"
)
)
return suite


Expand Down

0 comments on commit 769143a

Please sign in to comment.