diff --git a/pyvene/models/intervenable_base.py b/pyvene/models/intervenable_base.py index 1033cbaa..f1426648 100644 --- a/pyvene/models/intervenable_base.py +++ b/pyvene/models/intervenable_base.py @@ -1317,6 +1317,7 @@ def forward( unit_locations: Optional[Dict] = None, source_representations: Optional[Dict] = None, subspaces: Optional[List] = None, + labels: Optional[torch.LongTensor] = None, output_original_output: Optional[bool] = False, return_dict: Optional[bool] = None, ): @@ -1438,7 +1439,10 @@ def forward( ) # run intervened forward - counterfactual_outputs = self.model(**base) + if labels is not None: + counterfactual_outputs = self.model(**base, labels=labels) + else: + counterfactual_outputs = self.model(**base) set_handlers_to_remove.remove() self._output_validation() diff --git a/setup.py b/setup.py index 07692470..506b45d3 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( name="pyvene", - version="0.0.8dev", + version="0.0.8", description="Use Activation Intervention to Interpret Causal Mechanism of Model", long_description=long_description, long_description_content_type='text/markdown',