From 810c90daac94a56a123734e25818d4926998efed Mon Sep 17 00:00:00 2001 From: atticusg Date: Tue, 14 Jan 2025 21:06:12 -0800 Subject: [PATCH 1/5] add qwen model --- pyvene/models/intervenable_modelcard.py | 11 ++- .../qwen/modelings_intervenable_qwen.py | 77 +++++++++++++++++++ 2 files changed, 85 insertions(+), 3 deletions(-) create mode 100644 pyvene/models/qwen/modelings_intervenable_qwen.py diff --git a/pyvene/models/intervenable_modelcard.py b/pyvene/models/intervenable_modelcard.py index f8997850..e9b8022f 100644 --- a/pyvene/models/intervenable_modelcard.py +++ b/pyvene/models/intervenable_modelcard.py @@ -11,6 +11,7 @@ from .blip.modelings_intervenable_blip_itm import * from .backpack_gpt2.modelings_intervenable_backpack_gpt2 import * from .llava.modelings_intervenable_llava import * +from .qwen.modelings_intervenable_qwen import * # Add Qwen import ######################################################################### @@ -62,7 +63,9 @@ GRULMHeadModel: gru_lm_type_to_module_mapping, GRUForClassification: gru_classifier_type_to_module_mapping, BackpackGPT2LMHeadModel: backpack_gpt2_lm_type_to_module_mapping, - # new model type goes here after defining the model files + hf_models.qwen.modeling_qwen.QWenModel: qwen_type_to_module_mapping, + hf_models.qwen.modeling_qwen.QWenForCausalLM: qwen_lm_type_to_module_mapping, + hf_models.qwen.modeling_qwen.QWenForSequenceClassification: qwen_classifier_type_to_module_mapping, } @@ -93,6 +96,8 @@ GRULMHeadModel: gru_lm_type_to_dimension_mapping, GRUForClassification: gru_classifier_type_to_dimension_mapping, BackpackGPT2LMHeadModel: backpack_gpt2_lm_type_to_dimension_mapping, - # new model type goes here after defining the model files + hf_models.qwen.modeling_qwen.QWenModel: qwen_type_to_dimension_mapping, + hf_models.qwen.modeling_qwen.QWenForCausalLM: qwen_lm_type_to_dimension_mapping, + hf_models.qwen.modeling_qwen.QWenForSequenceClassification: qwen_classifier_type_to_dimension_mapping, } -######################################################################### +######################################################################### \ No newline at end of file diff --git a/pyvene/models/qwen/modelings_intervenable_qwen.py b/pyvene/models/qwen/modelings_intervenable_qwen.py new file mode 100644 index 00000000..1a99911a --- /dev/null +++ b/pyvene/models/qwen/modelings_intervenable_qwen.py @@ -0,0 +1,77 @@ +""" +Each modeling file in this library is a mapping between +abstract naming of intervention anchor points and actual +model module defined in the huggingface library. +We also want to let the intervention library know how to +config the dimensions of intervention based on model config +defined in the huggingface library. +""" +import torch +from ..constants import * + +qwen_type_to_module_mapping = { + "block_input": ("h[%s]", CONST_INPUT_HOOK), + "block_output": ("h[%s]", CONST_OUTPUT_HOOK), + "mlp_activation": ("h[%s].mlp.act", CONST_OUTPUT_HOOK), + "mlp_output": ("h[%s].mlp", CONST_OUTPUT_HOOK), + "mlp_input": ("h[%s].mlp", CONST_INPUT_HOOK), + "attention_value_output": ("h[%s].attn.c_proj", CONST_INPUT_HOOK), + "head_attention_value_output": ("h[%s].attn.c_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")), + "attention_output": ("h[%s].attn", CONST_OUTPUT_HOOK), + "attention_input": ("h[%s].attn", CONST_INPUT_HOOK), + "query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK), + "key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK), + "value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK), + "head_query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")), + "head_key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")), + "head_value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")), +} + +qwen_type_to_dimension_mapping = { + "n_head": ("num_attention_heads",), + "n_kv_head": ("num_key_value_heads",), + "block_input": ("hidden_size",), + "block_output": ("hidden_size",), + "mlp_activation": ("intermediate_size",), + "mlp_output": ("hidden_size",), + "mlp_input": ("hidden_size",), + "attention_value_output": ("hidden_size",), + "head_attention_value_output": ("head_dim",), + "attention_output": ("hidden_size",), + "attention_input": ("hidden_size",), + "query_output": ("hidden_size",), + "key_output": ("hidden_size",), + "value_output": ("hidden_size",), + "head_query_output": ("head_dim",), + "head_key_output": ("head_dim",), + "head_value_output": ("head_dim",), +} + +"""qwen model with LM head""" +qwen_lm_type_to_module_mapping = {} +for k, v in qwen_type_to_module_mapping.items(): + qwen_lm_type_to_module_mapping[k] = (f"transformer.{v[0]}", ) + v[1:] +qwen_lm_type_to_dimension_mapping = qwen_type_to_dimension_mapping + +"""qwen model with classifier head""" +qwen_classifier_type_to_module_mapping = {} +for k, v in qwen_type_to_module_mapping.items(): + qwen_classifier_type_to_module_mapping[k] = (f"transformer.{v[0]}", ) + v[1:] +qwen_classifier_type_to_dimension_mapping = qwen_type_to_dimension_mapping + +def create_qwen( + name="Qwen/Qwen2.5-0.5B", cache_dir=None, dtype=torch.bfloat16 +): + """Creates a Causal LM model, config, and tokenizer from the given name and revision""" + from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + + config = AutoConfig.from_pretrained(name, cache_dir=cache_dir) + tokenizer = AutoTokenizer.from_pretrained(name, cache_dir=cache_dir) + model = AutoModelForCausalLM.from_pretrained( + name, + config=config, + cache_dir=cache_dir, + torch_dtype=dtype, + ) + print("loaded model") + return config, tokenizer, model \ No newline at end of file From 2fd81abf7812bcbc23cae22aaf22f63af1487df5 Mon Sep 17 00:00:00 2001 From: atticusg Date: Tue, 14 Jan 2025 21:12:25 -0800 Subject: [PATCH 2/5] Update intervenable_modelcard.py --- pyvene/models/intervenable_modelcard.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pyvene/models/intervenable_modelcard.py b/pyvene/models/intervenable_modelcard.py index e9b8022f..1ee4bc4f 100644 --- a/pyvene/models/intervenable_modelcard.py +++ b/pyvene/models/intervenable_modelcard.py @@ -63,9 +63,9 @@ GRULMHeadModel: gru_lm_type_to_module_mapping, GRUForClassification: gru_classifier_type_to_module_mapping, BackpackGPT2LMHeadModel: backpack_gpt2_lm_type_to_module_mapping, - hf_models.qwen.modeling_qwen.QWenModel: qwen_type_to_module_mapping, - hf_models.qwen.modeling_qwen.QWenForCausalLM: qwen_lm_type_to_module_mapping, - hf_models.qwen.modeling_qwen.QWenForSequenceClassification: qwen_classifier_type_to_module_mapping, + hf_models.qwen2.modeling_qwen.QWenModel: qwen_type_to_module_mapping, + hf_models.qwen2.modeling_qwen.QWenForCausalLM: qwen_lm_type_to_module_mapping, + hf_models.qwen2.modeling_qwen.QWenForSequenceClassification: qwen_classifier_type_to_module_mapping, } @@ -96,8 +96,8 @@ GRULMHeadModel: gru_lm_type_to_dimension_mapping, GRUForClassification: gru_classifier_type_to_dimension_mapping, BackpackGPT2LMHeadModel: backpack_gpt2_lm_type_to_dimension_mapping, - hf_models.qwen.modeling_qwen.QWenModel: qwen_type_to_dimension_mapping, - hf_models.qwen.modeling_qwen.QWenForCausalLM: qwen_lm_type_to_dimension_mapping, - hf_models.qwen.modeling_qwen.QWenForSequenceClassification: qwen_classifier_type_to_dimension_mapping, + hf_models.qwen2.modeling_qwen.QWenModel: qwen_type_to_dimension_mapping, + hf_models.qwen2.modeling_qwen.QWenForCausalLM: qwen_lm_type_to_dimension_mapping, + hf_models.qwen2.modeling_qwen.QWenForSequenceClassification: qwen_classifier_type_to_dimension_mapping, } ######################################################################### \ No newline at end of file From 484acee0d860673a25d3693308fb63ec89e18c2a Mon Sep 17 00:00:00 2001 From: atticusg Date: Tue, 14 Jan 2025 21:18:28 -0800 Subject: [PATCH 3/5] fix qwen --- pyvene/models/intervenable_modelcard.py | 14 ++++++------ .../modelings_intervenable_qwen2.py} | 22 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) rename pyvene/models/{qwen/modelings_intervenable_qwen.py => qwen2/modelings_intervenable_qwen2.py} (82%) diff --git a/pyvene/models/intervenable_modelcard.py b/pyvene/models/intervenable_modelcard.py index 1ee4bc4f..48158fee 100644 --- a/pyvene/models/intervenable_modelcard.py +++ b/pyvene/models/intervenable_modelcard.py @@ -11,7 +11,7 @@ from .blip.modelings_intervenable_blip_itm import * from .backpack_gpt2.modelings_intervenable_backpack_gpt2 import * from .llava.modelings_intervenable_llava import * -from .qwen.modelings_intervenable_qwen import * # Add Qwen import +from .qwen2.modelings_intervenable_qwen2 import * ######################################################################### @@ -63,9 +63,9 @@ GRULMHeadModel: gru_lm_type_to_module_mapping, GRUForClassification: gru_classifier_type_to_module_mapping, BackpackGPT2LMHeadModel: backpack_gpt2_lm_type_to_module_mapping, - hf_models.qwen2.modeling_qwen.QWenModel: qwen_type_to_module_mapping, - hf_models.qwen2.modeling_qwen.QWenForCausalLM: qwen_lm_type_to_module_mapping, - hf_models.qwen2.modeling_qwen.QWenForSequenceClassification: qwen_classifier_type_to_module_mapping, + hf_models.qwen2.modeling_qwen2.Qwen2Model: qwen2_type_to_module_mapping, + hf_models.qwen2.modeling_qwen2.Qwen2ForCausalLM: qwen2_lm_type_to_module_mapping, + hf_models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification: qwen2_classifier_type_to_module_mapping, } @@ -96,8 +96,8 @@ GRULMHeadModel: gru_lm_type_to_dimension_mapping, GRUForClassification: gru_classifier_type_to_dimension_mapping, BackpackGPT2LMHeadModel: backpack_gpt2_lm_type_to_dimension_mapping, - hf_models.qwen2.modeling_qwen.QWenModel: qwen_type_to_dimension_mapping, - hf_models.qwen2.modeling_qwen.QWenForCausalLM: qwen_lm_type_to_dimension_mapping, - hf_models.qwen2.modeling_qwen.QWenForSequenceClassification: qwen_classifier_type_to_dimension_mapping, + hf_models.qwen2.modeling_qwen2.Qwen2Model: qwen2_type_to_dimension_mapping, + hf_models.qwen2.modeling_qwen2.Qwen2ForCausalLM: qwen2_lm_type_to_dimension_mapping, + hf_models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification: qwen2_classifier_type_to_dimension_mapping, } ######################################################################### \ No newline at end of file diff --git a/pyvene/models/qwen/modelings_intervenable_qwen.py b/pyvene/models/qwen2/modelings_intervenable_qwen2.py similarity index 82% rename from pyvene/models/qwen/modelings_intervenable_qwen.py rename to pyvene/models/qwen2/modelings_intervenable_qwen2.py index 1a99911a..08a24d6a 100644 --- a/pyvene/models/qwen/modelings_intervenable_qwen.py +++ b/pyvene/models/qwen2/modelings_intervenable_qwen2.py @@ -9,7 +9,7 @@ import torch from ..constants import * -qwen_type_to_module_mapping = { +qwen2_type_to_module_mapping = { "block_input": ("h[%s]", CONST_INPUT_HOOK), "block_output": ("h[%s]", CONST_OUTPUT_HOOK), "mlp_activation": ("h[%s].mlp.act", CONST_OUTPUT_HOOK), @@ -27,7 +27,7 @@ "head_value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")), } -qwen_type_to_dimension_mapping = { +qwen2_type_to_dimension_mapping = { "n_head": ("num_attention_heads",), "n_kv_head": ("num_key_value_heads",), "block_input": ("hidden_size",), @@ -48,18 +48,18 @@ } """qwen model with LM head""" -qwen_lm_type_to_module_mapping = {} -for k, v in qwen_type_to_module_mapping.items(): - qwen_lm_type_to_module_mapping[k] = (f"transformer.{v[0]}", ) + v[1:] -qwen_lm_type_to_dimension_mapping = qwen_type_to_dimension_mapping +qwen2_lm_type_to_module_mapping = {} +for k, v in qwen2_type_to_module_mapping.items(): + qwen2_lm_type_to_module_mapping[k] = (f"transformer.{v[0]}", ) + v[1:] +qwen2_lm_type_to_dimension_mapping = qwen_type_to_dimension_mapping """qwen model with classifier head""" -qwen_classifier_type_to_module_mapping = {} -for k, v in qwen_type_to_module_mapping.items(): - qwen_classifier_type_to_module_mapping[k] = (f"transformer.{v[0]}", ) + v[1:] -qwen_classifier_type_to_dimension_mapping = qwen_type_to_dimension_mapping +qwen2_classifier_type_to_module_mapping = {} +for k, v in qwen2_type_to_module_mapping.items(): + qwen2_classifier_type_to_module_mapping[k] = (f"transformer.{v[0]}", ) + v[1:] +qwen2_classifier_type_to_dimension_mapping = qwen_type_to_dimension_mapping -def create_qwen( +def create_qwen2( name="Qwen/Qwen2.5-0.5B", cache_dir=None, dtype=torch.bfloat16 ): """Creates a Causal LM model, config, and tokenizer from the given name and revision""" From 63ce83179cbc73e60c1ade32353609387d547c03 Mon Sep 17 00:00:00 2001 From: atticusg Date: Tue, 14 Jan 2025 21:23:10 -0800 Subject: [PATCH 4/5] Update modelings_intervenable_qwen2.py --- pyvene/models/qwen2/modelings_intervenable_qwen2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyvene/models/qwen2/modelings_intervenable_qwen2.py b/pyvene/models/qwen2/modelings_intervenable_qwen2.py index 08a24d6a..c9013e82 100644 --- a/pyvene/models/qwen2/modelings_intervenable_qwen2.py +++ b/pyvene/models/qwen2/modelings_intervenable_qwen2.py @@ -51,13 +51,13 @@ qwen2_lm_type_to_module_mapping = {} for k, v in qwen2_type_to_module_mapping.items(): qwen2_lm_type_to_module_mapping[k] = (f"transformer.{v[0]}", ) + v[1:] -qwen2_lm_type_to_dimension_mapping = qwen_type_to_dimension_mapping +qwen2_lm_type_to_dimension_mapping = qwen2_type_to_dimension_mapping """qwen model with classifier head""" qwen2_classifier_type_to_module_mapping = {} for k, v in qwen2_type_to_module_mapping.items(): qwen2_classifier_type_to_module_mapping[k] = (f"transformer.{v[0]}", ) + v[1:] -qwen2_classifier_type_to_dimension_mapping = qwen_type_to_dimension_mapping +qwen2_classifier_type_to_dimension_mapping = qwen2_type_to_dimension_mapping def create_qwen2( name="Qwen/Qwen2.5-0.5B", cache_dir=None, dtype=torch.bfloat16 From 129fbc805a61b19bd7f25a86f9f793a9d4cbcee8 Mon Sep 17 00:00:00 2001 From: atticusg Date: Tue, 14 Jan 2025 21:55:51 -0800 Subject: [PATCH 5/5] Update modelings_intervenable_qwen2.py --- .../qwen2/modelings_intervenable_qwen2.py | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/pyvene/models/qwen2/modelings_intervenable_qwen2.py b/pyvene/models/qwen2/modelings_intervenable_qwen2.py index c9013e82..74c428e1 100644 --- a/pyvene/models/qwen2/modelings_intervenable_qwen2.py +++ b/pyvene/models/qwen2/modelings_intervenable_qwen2.py @@ -10,26 +10,26 @@ from ..constants import * qwen2_type_to_module_mapping = { - "block_input": ("h[%s]", CONST_INPUT_HOOK), - "block_output": ("h[%s]", CONST_OUTPUT_HOOK), - "mlp_activation": ("h[%s].mlp.act", CONST_OUTPUT_HOOK), - "mlp_output": ("h[%s].mlp", CONST_OUTPUT_HOOK), - "mlp_input": ("h[%s].mlp", CONST_INPUT_HOOK), - "attention_value_output": ("h[%s].attn.c_proj", CONST_INPUT_HOOK), - "head_attention_value_output": ("h[%s].attn.c_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")), - "attention_output": ("h[%s].attn", CONST_OUTPUT_HOOK), - "attention_input": ("h[%s].attn", CONST_INPUT_HOOK), - "query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK), - "key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK), - "value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK), - "head_query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")), - "head_key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")), - "head_value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")), + "block_input": ("layers[%s]", CONST_INPUT_HOOK), + "block_output": ("layers[%s]", CONST_OUTPUT_HOOK), + "mlp_activation": ("layers[%s].mlp.act_fn", CONST_OUTPUT_HOOK), + "mlp_output": ("layers[%s].mlp", CONST_OUTPUT_HOOK), + "mlp_input": ("layers[%s].mlp", CONST_INPUT_HOOK), + "attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK), + "head_attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")), + "attention_output": ("layers[%s].self_attn", CONST_OUTPUT_HOOK), + "attention_input": ("layers[%s].self_attn", CONST_INPUT_HOOK), + "query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK), + "key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK), + "value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK), + "head_query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")), + "head_key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")), + "head_value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_kv_head")), } qwen2_type_to_dimension_mapping = { "n_head": ("num_attention_heads",), - "n_kv_head": ("num_key_value_heads",), + "n_kv_head": ("num_key_value_heads",), "block_input": ("hidden_size",), "block_output": ("hidden_size",), "mlp_activation": ("intermediate_size",), @@ -47,20 +47,20 @@ "head_value_output": ("head_dim",), } -"""qwen model with LM head""" +"""qwen2 model with LM head""" qwen2_lm_type_to_module_mapping = {} for k, v in qwen2_type_to_module_mapping.items(): - qwen2_lm_type_to_module_mapping[k] = (f"transformer.{v[0]}", ) + v[1:] + qwen2_lm_type_to_module_mapping[k] = (f"model.{v[0]}", ) + v[1:] qwen2_lm_type_to_dimension_mapping = qwen2_type_to_dimension_mapping -"""qwen model with classifier head""" +"""qwen2 model with classifier head""" qwen2_classifier_type_to_module_mapping = {} for k, v in qwen2_type_to_module_mapping.items(): - qwen2_classifier_type_to_module_mapping[k] = (f"transformer.{v[0]}", ) + v[1:] + qwen2_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", ) + v[1:] qwen2_classifier_type_to_dimension_mapping = qwen2_type_to_dimension_mapping def create_qwen2( - name="Qwen/Qwen2.5-0.5B", cache_dir=None, dtype=torch.bfloat16 + name="Qwen/Qwen2-7B-beta", cache_dir=None, dtype=torch.bfloat16 ): """Creates a Causal LM model, config, and tokenizer from the given name and revision""" from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer