Skip to content

Commit

Permalink
update tutorials (no test)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankaging committed Feb 3, 2025
1 parent bc07940 commit 262ed44
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
8 changes: 4 additions & 4 deletions tutorials/advanced_tutorials/IOI_with_DAS.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7301,7 +7301,7 @@
],
"source": [
"intervention = boundless_das_intervenable.interventions[\n",
" \"layer.8.repr.attention_value_output.unit.pos.nunit.1#0\"\n",
" \"layer_8_repr_attention_value_output_unit_pos_nunit_1#0\"\n",
"]\n",
"boundary_mask = sigmoid_boundary(\n",
" intervention.intervention_population.repeat(1, 1),\n",
Expand Down Expand Up @@ -12475,7 +12475,7 @@
],
"source": [
"intervention = das_intervenable.interventions[\n",
" \"layer.8.repr.attention_value_output.unit.pos.nunit.1#0\"\n",
" \"layer_8_repr_attention_value_output_unit_pos_nunit_1#0\"\n",
"]\n",
"learned_weights = intervention.rotate_layer.weight\n",
"headwise_learned_weights = torch.chunk(learned_weights, chunks=12, dim=0)\n",
Expand Down Expand Up @@ -17400,7 +17400,7 @@
],
"source": [
"intervention = boundless_das_intervenable.interventions[\n",
" \"layer.9.repr.attention_value_output.unit.pos.nunit.1#0\"\n",
" \"layer_9_repr_attention_value_output_unit_pos_nunit_1#0\"\n",
"]\n",
"boundary_mask = sigmoid_boundary(\n",
" intervention.intervention_population.repeat(1, 1),\n",
Expand Down Expand Up @@ -23343,7 +23343,7 @@
],
"source": [
"intervention = das_intervenable.interventions[\n",
" \"layer.9.repr.attention_value_output.unit.pos.nunit.1#0\"\n",
" \"layer_9_repr_attention_value_output_unit_pos_nunit_1#0\"\n",
"]\n",
"learned_weights = intervention.rotate_layer.weight\n",
"headwise_learned_weights = torch.chunk(learned_weights, chunks=12, dim=0)\n",
Expand Down
10 changes: 5 additions & 5 deletions tutorials/advanced_tutorials/Voting_Mechanism.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@
"source": [
"torch.save(\n",
" pv_llama.interventions[\n",
" f\"layer.{layer}.comp.block_output.unit.pos.nunit.1#0\"].state_dict(), \n",
" f\"layer_{layer}_comp_block_output_unit_pos_nunit_1#0\"].state_dict(), \n",
" f\"./tutorial_data/layer.{layer}.pos.{token_position}.bin\"\n",
")"
]
Expand Down Expand Up @@ -522,9 +522,9 @@
"pv_llama = pv.IntervenableModel(pv_config, llama)\n",
"pv_llama.set_device(\"cuda\")\n",
"pv_llama.disable_model_gradients()\n",
"pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#0'].load_state_dict(\n",
"pv_llama.interventions[f'layer_{layer}_comp_block_output_unit_pos_nunit_1#0'].load_state_dict(\n",
" torch.load('./tutorial_data/layer.15.pos.75.bin'))\n",
"pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#1'].load_state_dict(\n",
"pv_llama.interventions[f'layer_{layer}_comp_block_output_unit_pos_nunit_1#1'].load_state_dict(\n",
" torch.load('./tutorial_data/layer.15.pos.80.bin'))"
]
},
Expand Down Expand Up @@ -665,11 +665,11 @@
"for loc in [78, 75, 80, [75, 80]]:\n",
" if loc == 78:\n",
" print(\"[control] intervening location: \", loc)\n",
" pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#0'].load_state_dict(\n",
" pv_llama.interventions[f'layer_{layer}_comp_block_output_unit_pos_nunit_1#0'].load_state_dict(\n",
" torch.load('./tutorial_data/layer.15.pos.78.bin'))\n",
" else:\n",
" print(\"intervening location: \", loc)\n",
" pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#0'].load_state_dict(\n",
" pv_llama.interventions[f'layer_{layer}_comp_block_output_unit_pos_nunit_1#0'].load_state_dict(\n",
" torch.load('./tutorial_data/layer.15.pos.75.bin'))\n",
" # evaluation on the test set\n",
" collected_probs = []\n",
Expand Down

0 comments on commit 262ed44

Please sign in to comment.