@@ -4,7 +4,7 @@
 from fast_llm.engine.checkpoint.external import SplitWeightConverter, WeightConverter
 from fast_llm.layers.decoder.mlp.config import MoEMLPConfig
 from fast_llm.models.gpt.conversion.config import MixtralCheckpointFormat
-from fast_llm.models.gpt.conversion.llama import LlamaMLPConverter, get_weight_and_bias_converters
+from fast_llm.models.gpt.conversion.llama import LlamaMLPConverter, MLPLayer2Converter, get_weight_and_bias_converters
 from fast_llm.models.gpt.conversion.mistral import (
     MistralBaseModelConverter,
     MistralBlockConverter,
@@ -50,16 +50,29 @@ def get_converters(
         return [
             *get_weight_and_bias_converters(
                 f"{fast_llm_prefix}.router",
-                () if drop_on_export else (f"{hf_prefix}.router",),
-                config.add_linear_biases,
+                f"{hf_prefix}.gate",
+                False,
+                drop_on_export=drop_on_export,
+            ),
+            *get_weight_and_bias_converters(
+                f"{fast_llm_prefix}.layer_1",
+                tuple(f"{hf_prefix}.experts.{i}.{w}" for i in range(config.experts) for w in ("w1", "w3")),
+                False,
                 SplitWeightConverter,
                 drop_on_export=drop_on_export,
             ),
-            *super().get_converters(config, fast_llm_prefix, hf_prefix, drop_on_export=drop_on_export),
+            *get_weight_and_bias_converters(
+                f"{fast_llm_prefix}.layer_2",
+                tuple(f"{hf_prefix}.experts.{i}.w2" for i in range(config.experts)),
+                False,
+                MLPLayer2Converter,
+                drop_on_export=drop_on_export,
+            ),
         ]
 
 
 class MixtralBlockConverter(MistralBlockConverter):
+    hf_mlp_name: typing.ClassVar[str] = "block_sparse_moe"
     mlp_converter_class: typing.ClassVar[type[MixtralMLPConverter]] = MixtralMLPConverter
 
 
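For context (not part of the diff): the converters above map Fast-LLM's fused MoE weights onto the Hugging Face Mixtral layout, with the router exported as gate, layer_1 split across each expert's w1/w3, and layer_2 mapped to each expert's w2 under the block_sparse_moe module named by hf_mlp_name. The sketch below only enumerates the export-side parameter names implied by that mapping; the model.layers.{index} prefix and the expected_hf_moe_weights helper are illustrative assumptions, not code from this PR.

# Illustrative sketch only -- assumes the standard HF Mixtral naming
# ("model.layers.<i>.block_sparse_moe. ...") in front of the suffixes used in the diff.
def expected_hf_moe_weights(layer_index: int, num_experts: int) -> list[str]:
    hf_prefix = f"model.layers.{layer_index}.block_sparse_moe"
    names = [f"{hf_prefix}.gate.weight"]  # Fast-LLM ".router" -> HF ".gate"
    for i in range(num_experts):
        # layer_1 is split across w1 and w3 (SplitWeightConverter);
        # layer_2 becomes w2 (MLPLayer2Converter).
        names.extend(f"{hf_prefix}.experts.{i}.{w}.weight" for w in ("w1", "w3", "w2"))
    return names

# e.g. for layer 0 of an 8-expert model:
# ['model.layers.0.block_sparse_moe.gate.weight',
#  'model.layers.0.block_sparse_moe.experts.0.w1.weight', ...]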