Merge pull request #177 from stanfordnlp/zen/as_adaptor
[Minor] Start to support generic intervention output, and adaptor-like tuning
frankaging authored Jul 25, 2024
2 parents 90417a7 + 3b35f9b commit 9591f47
Showing 4 changed files with 40 additions and 6 deletions.
1 change: 1 addition & 0 deletions pyvene/__init__.py
@@ -30,6 +30,7 @@
 from .models.interventions import NoiseIntervention
 from .models.interventions import SigmoidMaskIntervention
 from .models.interventions import AutoencoderIntervention
+from .models.interventions import InterventionOutput


 # Utils
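With the new export, the dataclass is reachable from the package root. A minimal sketch of what that enables (plain usage, not part of the diff):

    import torch
    import pyvene as pv

    # InterventionOutput is a transformers-style ModelOutput with two fields:
    # the (possibly intervened) representation and an optional auxiliary loss.
    h = torch.randn(1, 4, 16)
    out = pv.InterventionOutput(output=h, loss=h.pow(2).mean())
    print(out.output.shape, out.loss)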
17 changes: 14 additions & 3 deletions pyvene/models/intervenable_base.py
@@ -17,7 +17,8 @@
     TrainableIntervention,
     SkipIntervention,
     CollectIntervention,
-    BoundlessRotatedSpaceIntervention
+    BoundlessRotatedSpaceIntervention,
+    InterventionOutput
 )

 from torch import optim
@@ -56,6 +57,7 @@ def __init__(self, config, model, backend, **kwargs):
         self.is_model_stateless = is_stateless(model)
         self.config.model_type = str(type(model)) # backfill
         self.use_fast = kwargs["use_fast"] if "use_fast" in kwargs else False
+        self.as_adaptor = kwargs["as_adaptor"] if "as_adaptor" in kwargs else False

         self.model_has_grad = False
         if self.use_fast:
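An aside on the new kwarg: nothing else in this diff reads as_adaptor yet; it is only stored on the instance. A hedged construction sketch, with the model/config setup taken from pyvene's standard tutorials rather than from this commit:

    import pyvene as pv

    # small GPT-2 plus a one-intervention config on the layer-0 block output
    # (create_gpt2 and the dict-style representation are assumed from the
    # library's examples; the default intervention type is used)
    _, tokenizer, gpt2 = pv.create_gpt2()
    pv_config = pv.IntervenableConfig(
        representations=[{"layer": 0, "component": "block_output"}]
    )

    # the flag simply ends up on the instance via the kwargs handling above
    intervenable = pv.IntervenableModel(pv_config, gpt2, as_adaptor=True)
    assert intervenable.as_adaptor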
@@ -224,6 +226,8 @@ def __init__(self, config, model, backend, **kwargs):
         # cached swapped activations (hot)
         self.hot_activations = {}

+        self.aux_loss = []
+
         # temp fields should not be accessed outside
         self._batched_setter_activation_select = {}
         """
@@ -1509,7 +1513,8 @@ def _intervention_setter(
             ] # batch_size

             def hook_callback(model, args, kwargs, output=None):
-                if self._is_generation:
+                # if it is None, we use it as adaptor.
+                if unit_locations_base[key_i] is not None and self._is_generation:
                     is_prompt = self._key_setter_call_counter[key] == 0
                     if not self._intervene_on_prompt or is_prompt:
                         self._key_setter_call_counter[key] += 1
@@ -1555,6 +1560,10 @@ def hook_callback(model, args, kwargs, output=None):
                         intervention,
                         subspaces[key_i] if subspaces is not None else None,
                     )
+                    if isinstance(intervened_representation, InterventionOutput):
+                        if intervened_representation.loss is not None:
+                            self.aux_loss.append(intervened_representation.loss)
+                        intervened_representation = intervened_representation.output
                 else:
                     intervened_representation = do_intervention(
                         selected_output,
@@ -1852,7 +1861,9 @@ def forward(
             activations_sources = source_representations
         if sources is not None and not isinstance(sources, list):
             sources = [sources]
-
+
+        self.aux_loss.clear()
+
         self._cleanup_states()

         # if no source input or intervention, we return base
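Taken together: forward() clears self.aux_loss at the start of every call, and the setter hook appends any loss an intervention reports, so a caller can fold those terms into its objective after the pass. A hedged sketch continuing the construction above (the unit-location shorthand and output fields are assumed from pyvene's tutorials; the task loss is a stand-in):

    base = tokenizer("The capital of Spain is", return_tensors="pt")
    source = tokenizer("The capital of Italy is", return_tensors="pt")

    # a bare dict source is fine: forward() wraps non-list sources, per the diff
    _, counterfactual_outputs = intervenable(
        base=base,
        sources=source,
        unit_locations={"sources->base": 3},  # swap position 3 at layer 0
    )

    task_loss = counterfactual_outputs.last_hidden_state.mean()  # stand-in objective
    aux = sum(intervenable.aux_loss) if len(intervenable.aux_loss) > 0 else 0.0
    total_loss = task_loss + aux  # aux is non-zero only for loss-reporting interventions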
14 changes: 14 additions & 0 deletions pyvene/models/interventions.py
@@ -1,11 +1,25 @@
 import torch
 import numpy as np
 from abc import ABC, abstractmethod
 from typing import Dict, Optional, Sequence, Union, List, Any

 from .layers import RotateLayer, LowRankRotateLayer, SubspaceLowRankRotateLayer, AutoencoderLayer
 from .basic_utils import sigmoid_boundary
 from .intervention_utils import _can_use_fast, _do_intervention_by_swap

+from dataclasses import dataclass
+from transformers.activations import ACT2FN
+from transformers.utils import ModelOutput
+
+
+@dataclass
+class InterventionOutput(ModelOutput):
+    """
+    Output of the IntervenableModel, including original outputs, intervened outputs, and collected activations.
+    """
+    output: Optional[Any] = None
+    loss: Optional[Any] = None
+
+
 class Intervention(torch.nn.Module):

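InterventionOutput lets an intervention hand back more than a tensor. A sketch of a custom trainable intervention that reports an auxiliary sparsity penalty through the new dataclass; the class is illustrative, not part of pyvene, and it assumes pyvene passes embed_dim when instantiating interventions, as it does for its built-in trainable ones:

    import torch
    import pyvene as pv

    class SparseAdditionIntervention(pv.TrainableIntervention):
        """Adds a learned offset and reports its L1 norm as an auxiliary loss."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.offset = torch.nn.Parameter(torch.zeros(self.embed_dim))

        def forward(self, base, source=None, subspaces=None):
            return pv.InterventionOutput(
                output=base + self.offset,
                loss=self.offset.abs().mean(),  # collected into model.aux_loss
            )

    # it plugs into a config like any built-in type, e.g.
    # pv.IntervenableConfig(representations=[...], intervention_types=SparseAdditionIntervention)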
14 changes: 11 additions & 3 deletions pyvene/models/modeling_utils.py
@@ -461,10 +461,14 @@ def do_intervention(
         source_representation_f = bhsd_to_bs_hd(source_representation)
     else:
         assert False # what's going on?
-    intervened_representation = intervention(
+
+    intervention_output = intervention(
         base_representation_f, source_representation_f, subspaces
     )
+    if isinstance(intervention_output, InterventionOutput):
+        intervened_representation = intervention_output.output
+    else:
+        intervened_representation = intervention_output

     post_d = intervened_representation.shape[-1]

@@ -481,7 +485,11 @@
     else:
         assert False # what's going on?

-    return intervened_representation
+    if not isinstance(intervention_output, InterventionOutput):
+        return intervened_representation
+
+    intervention_output.output = intervened_representation
+    return intervention_output


 def simple_output_to_subcomponent(output, representation_type, model_config):
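Net effect of the two hunks above: do_intervention still reshapes the tensor payload, but when the intervention returned an InterventionOutput it re-attaches the reshaped tensor to the wrapper so the caller (the hook in intervenable_base.py) can still read .loss. A standalone sketch of that unwrap/reshape/rewrap pattern (illustrative helper, not pyvene's actual code):

    import torch
    from pyvene import InterventionOutput

    def reshape_keeping_wrapper(result, reshape):
        # mirror of the pattern above: operate on the tensor, keep any wrapper
        hidden = result.output if isinstance(result, InterventionOutput) else result
        hidden = reshape(hidden)
        if not isinstance(result, InterventionOutput):
            return hidden
        result.output = hidden
        return result

    wrapped = InterventionOutput(output=torch.randn(2, 12), loss=torch.tensor(0.1))
    out = reshape_keeping_wrapper(wrapped, lambda h: h.view(2, 3, 4))
    print(type(out).__name__, out.output.shape, out.loss)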
