Skip to content

Commit

Permalink
Add simple location broadcast for an easier interface
Browse files Browse the repository at this point in the history
  • Loading branch information
frankaging committed Jan 17, 2024
1 parent 4f9e5ca commit a36ac29
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<a href="https://nlp.stanford.edu/~wuzhengx/"><strong>Library Paper and Doc Are Forthcoming »</strong></a>
</div>

<a href="https://pypi.org/project/pyvene/"><img src="https://img.shields.io/pypi/v/pyvene?color=red"></img></a>
<a href="https://pypi.org/project/pyvene/"><img src="https://img.shields.io/pypi/v/pyvene?color=red"></img></a> *This is a beta-release.*

# **Use _Activation Intervention_ to Interpret _Causal Mechanism_ of Model**
**pyvene** supports customizable interventions on different neural architectures (e.g., RNN or Transformers). It supports complex intervention schemas (e.g., parallel or serialized interventions) and a wide range of intervention modes (e.g., static or trained interventions) at scale to gain interpretability insights.
Expand Down
19 changes: 19 additions & 0 deletions pyvene/models/basic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,22 @@ def top_vals(tokenizer, res, n=10):
for i, _ in enumerate(top_values):
tok = format_token(tokenizer, top_indices[i].item())
print(f"{tok:<20} {top_values[i].item()}")


def get_list_depth(lst):
    """Return the maximum nesting depth of ``lst``.

    A value that is not a ``list`` has depth 0, an empty list has depth 1,
    and every additional level of nested lists adds 1.

    Args:
        lst: Any object. Only ``list`` instances count as a nesting level;
            tuples and other containers are treated as leaves.

    Returns:
        int: Depth of the most deeply nested list inside ``lst``.
    """
    if not isinstance(lst, list):
        return 0
    # ``default=0`` makes the empty list evaluate to depth 1 instead of
    # letting ``max`` raise on an empty iterable.
    return 1 + max((get_list_depth(item) for item in lst), default=0)

def get_batch_size(model_input):
    """Get the batch size (size of dimension 0) of a model input.

    Args:
        model_input: Either a ``torch.Tensor`` or a mapping whose values
            expose ``.shape`` (for example a dict of input tensors). For a
            mapping, only the first value is inspected.

    Returns:
        The size of dimension 0 of the tensor, or of the mapping's first
        value.

    Raises:
        ValueError: If ``model_input`` is an empty mapping, since no batch
            size can be inferred from it.
    """
    if isinstance(model_input, torch.Tensor):
        return model_input.shape[0]
    for value in model_input.values():
        # Entries are assumed to share the same leading dimension, so the
        # first one is enough.
        return value.shape[0]
    raise ValueError("Cannot infer batch size from an empty model input.")
32 changes: 29 additions & 3 deletions pyvene/models/intervenable_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,27 @@ def _wait_for_forward_with_serial_intervention(
# for setters, we don't remove them.
all_set_handlers.extend(set_handlers)
return all_set_handlers


def _broadcast_unit_locations(
    self,
    batch_size,
    unit_locations
):
    """Expand shorthand integer unit locations into nested-list form.

    Each entry of ``unit_locations`` is handled as follows:

    * an ``int`` ``p`` becomes
      ``([[[p]] * batch_size], [[[p]] * batch_size])``;
    * a pair of ints ``(s, b)`` becomes
      ``([[[s]] * batch_size], [[[b]] * batch_size])``;
    * a pair of lists is kept unchanged (as a deep copy); broadcasting is
      not supported for this form yet.

    Args:
        batch_size: Number of times each position is repeated.
        unit_locations: Mapping from a location key to one of the formats
            listed above.

    Returns:
        A new mapping with the expanded locations; the input is not
        modified.

    Raises:
        ValueError: If an entry matches none of the supported formats.
    """
    def repeat_for_batch(position):
        # One outer group holding ``batch_size`` single-position entries.
        return [[[position]] * batch_size]

    broadcasted = copy.deepcopy(unit_locations)
    for key, location in unit_locations.items():
        if isinstance(location, int):
            # A single int is used for both members of the pair.
            broadcasted[key] = (
                repeat_for_batch(location),
                repeat_for_batch(location),
            )
            continue

        first = location[0]
        if isinstance(first, int) and isinstance(location[1], int):
            broadcasted[key] = (
                repeat_for_batch(first),
                repeat_for_batch(location[1]),
            )
            continue

        # Already-nested lists pass through untouched (no broadcast yet).
        if isinstance(first, list) and isinstance(location[1], list):
            continue

        raise ValueError(
            f"unit_locations {unit_locations} contains invalid format."
        )

    return broadcasted

def forward(
self,
base,
Expand Down Expand Up @@ -1153,7 +1173,10 @@ def forward(
# if no source inputs, we are calling a simple forward
if sources is None and activations_sources is None:
return self.model(**base), None


unit_locations = self._broadcast_unit_locations(
get_batch_size(base), unit_locations)

self._input_validation(
base,
sources,
Expand Down Expand Up @@ -1258,7 +1281,10 @@ def generate(

if sources is None and activations_sources is None:
return self.model.generate(inputs=base["input_ids"], **kwargs), None


unit_locations = self._broadcast_unit_locations(
get_batch_size(base), unit_locations)

self._input_validation(
base,
sources,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def _test_with_position_intervention(
intervention_type,
positions=[0],
use_fast=False,
use_boardcast=False,
):
max_position = np.max(np.array(positions))
if isinstance(positions[0], list):
Expand Down Expand Up @@ -155,19 +156,31 @@ def _test_with_position_intervention(
golden_out = GPT2_RUN(
self.gpt2, base["input_ids"], {}, {_key: base_activations[_key]}
)

if isinstance(positions[0], list):
_, out_output = intervenable(
base, [source], {"sources->base": ([positions], [positions])}

if use_boardcast:
assert isinstance(positions[0], int)
_, out_output_1 = intervenable(
base, [source], {"sources->base": positions[0]}
)
else:
_, out_output = intervenable(
base,
[source],
{"sources->base": ([[positions] * b_s], [[positions] * b_s])},
self.assertTrue(torch.allclose(out_output_1[0], golden_out))

_, out_output_2 = intervenable(
base, [source], {"sources->base": (positions[0], positions[0])}
)
self.assertTrue(torch.allclose(out_output_2[0], golden_out))
else:
if isinstance(positions[0], list):
_, out_output = intervenable(
base, [source], {"sources->base": ([positions], [positions])}
)
else:
_, out_output = intervenable(
base,
[source],
{"sources->base": ([[positions] * b_s], [[positions] * b_s])},
)

self.assertTrue(torch.allclose(out_output[0], golden_out))
self.assertTrue(torch.allclose(out_output[0], golden_out))

def test_with_single_position_vanilla_intervention_positive(self):
"""
Expand Down Expand Up @@ -337,6 +350,28 @@ def test_with_use_fast_vanilla_intervention_positive(self):
use_fast=True,
)

def test_with_location_broadcast_vanilla_intervention_positive(self):
"""
Positive test for unit-location broadcast with vanilla intervention.

For every stream in ``self.nonhead_streams``, call
``_test_with_position_intervention`` with ``use_boardcast=True`` on a
randomly chosen layer and a single randomly chosen position: first with
the default ``use_fast`` setting, then with ``use_fast=True``.
"""
for stream in self.nonhead_streams:
print(f"testing broadcast with stream: {stream} with a single position")
# Layer and position are each drawn from 0..3 (inclusive) per call.
self._test_with_position_intervention(
intervention_layer=random.randint(0, 3),
intervention_stream=stream,
intervention_type=VanillaIntervention,
positions=[random.randint(0, 3)],
use_boardcast=True,
)
print(f"testing broadcast with stream: {stream} with a single position (with fast)")
# Same check again with the fast path enabled; layer and position are re-drawn.
self._test_with_position_intervention(
intervention_layer=random.randint(0, 3),
intervention_stream=stream,
intervention_type=VanillaIntervention,
positions=[random.randint(0, 3)],
use_fast=True,
use_boardcast=True,
)

def suite():
suite = unittest.TestSuite()
Expand Down Expand Up @@ -376,6 +411,11 @@ def suite():
"test_with_use_fast_vanilla_intervention_positive"
)
)
suite.addTest(
VanillaInterventionWithTransformerTestCase(
"test_with_location_broadcast_vanilla_intervention_positive"
)
)
return suite


Expand Down

0 comments on commit a36ac29

Please sign in to comment.