
Problems with multi-head replay fine-tuning #1333

@SaltyLemonWHU


Hi,

I'm doing multi-head replay fine-tuning with the "mace-mh-1" model and have run into some problems.

1. The "fine_tuning_select.py" script does not seem to support specifying a head when the model provides multiple heads.

I used the following command to select configurations from the replay dataset based on my dataset:

```bash
python -m mace.cli.fine_tuning_select \
    --configs_pt /home/yanning2/project/MACE/database/replay/replay-data-mh-1-omat-pbe.xyz \
    --configs_ft /home/yanning2/project/MACE/database/finetuning/total_energy.db \
    --num_samples 30000 \
    --subselect fps \
    --model /home/yanning2/project/MACE/model/mace-mh-1.model \
    --output /home/yanning2/project/MACE/database/select/selected_configs_omat.xyz \
    --filtering_type combinations \
    --head_pt oc20_usemppbe \
    --head_ft norr \
    --weight_pt 1.0 \
    --weight_ft 10.0
```

and it raised the following error:

```
Traceback (most recent call last):
  File "/home/yanning2/scratch/conda_env/MACE/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/yanning2/scratch/conda_env/MACE/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/yanning2/scratch/conda_env/MACE/lib/python3.10/site-packages/mace/cli/fine_tuning_select.py", line 529, in <module>
    main()
  File "/home/yanning2/scratch/conda_env/MACE/lib/python3.10/site-packages/mace/cli/fine_tuning_select.py", line 525, in main
    select_samples(settings)
  File "/home/yanning2/scratch/conda_env/MACE/lib/python3.10/site-packages/mace/cli/fine_tuning_select.py", line 463, in select_samples
    calc = _load_calc(
  File "/home/yanning2/scratch/conda_env/MACE/lib/python3.10/site-packages/mace/cli/fine_tuning_select.py", line 273, in _load_calc
    calc = MACECalculator(
  File "/home/yanning2/scratch/conda_env/MACE/lib/python3.10/site-packages/mace/calculators/mace.py", line 271, in __init__
    raise ValueError(
ValueError: Head keyword was not provided, and no head in the model is 'default'. Please provide a head keyword to specify the head you want to use. Available heads are: ['matpes_r2scan', 'mp_pbe_refit_add', 'spice_wB97M', 'oc20_usemppbe', 'omol', 'omat_pbe']
```

After checking fine_tuning_select.py, I found there is no support for choosing a specific head. This error does not occur with the "mace-mp0" model, which provides only a single "default" head:

```
Using head Default out of ['Default']
```

So what role does the model play in this selection step? Is it important to use the same model for selecting configurations and for fine-tuning?
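For clarity, the head-selection behaviour I observe in the two runs can be sketched as a small Python function (`resolve_head` is only my illustration of the behaviour, not the actual MACE code; the real logic lives in `MACECalculator.__init__`):

```python
def resolve_head(available_heads, head=None):
    """Illustration of the observed behaviour, not actual MACE code:
    an explicit head is used if given; otherwise a 'default' head is
    looked up; a multi-head model without one raises the ValueError
    shown in the traceback above."""
    if head is not None:
        if head not in available_heads:
            raise ValueError(f"Unknown head {head!r}; available: {available_heads}")
        return head
    for h in available_heads:
        if h.lower() == "default":
            # mace-mp0 case: "Using head Default out of ['Default']"
            return h
    raise ValueError(
        "Head keyword was not provided, and no head in the model is 'default'. "
        f"Available heads are: {available_heads}"
    )

# single-default-head model works without any head argument
print(resolve_head(["Default"]))  # -> Default
# multi-head model requires an explicit head
print(resolve_head(["oc20_usemppbe", "omat_pbe"], head="oc20_usemppbe"))  # -> oc20_usemppbe
```

With a multi-head model and no head argument, this is exactly where fine_tuning_select.py fails, since it exposes no way to pass the head through.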

2. Error when using the "select_head.py" script.

To work around the problem above, I noticed there is a mace_select_head CLI tool that can extract a single-head model from a multi-head model. My command was:

```bash
python -m mace.cli.select_head 'E:\MACE\model\mace-mh-1.model' --head_name oc20_usemppbe
```

and it raised the following error:

```
Traceback (most recent call last):
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python310\lib\runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python310\lib\runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "E:\others_script\lib\site-packages\mace\cli\select_head.py", line 60, in <module>
    main()
  File "E:\others_script\lib\site-packages\mace\cli\select_head.py", line 52, in main
    model_single = remove_pt_head(model, args.head_name)
  File "E:\others_script\lib\site-packages\mace\tools\scripts_utils.py", line 415, in remove_pt_head
    new_model.load_state_dict(new_state_dict)
  File "E:\others_script\lib\site-packages\torch\nn\modules\module.py", line 2584, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for ScaleShiftMACE:
    size mismatch for interactions.0.linear_up.weight: copying a param with shape torch.Size([65536]) from checkpoint, the shape in current model is torch.Size([262144]).
    size mismatch for interactions.0.linear_up.output_mask: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for interactions.0.conv_tp.output_mask: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([8192]).
    size mismatch for interactions.0.conv_tp_weights.net.9.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([2048, 64]).
    size mismatch for interactions.0.conv_tp_weights.net.9.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
    size mismatch for interactions.0.linear_res.weight: copying a param with shape torch.Size([262144]) from checkpoint, the shape in current model is torch.Size([1048576]).
    size mismatch for interactions.0.linear_1.weight: copying a param with shape torch.Size([458752]) from checkpoint, the shape in current model is torch.Size([1835008]).
```

Could you please check what's wrong here?
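One observation that may help with debugging (shape numbers copied from the traceback above): every mismatched tensor in the rebuilt single-head model is exactly 4x the size of the corresponding checkpoint tensor. To me this suggests remove_pt_head reconstructs the model with a 4x larger channel width (512 vs 128, judging by linear_up.output_mask) rather than anything being wrong with the checkpoint itself. A quick check:

```python
# (checkpoint size, rebuilt-model size) pairs copied from the RuntimeError;
# for net.9.weight only the leading dimension differs, so that is compared
mismatches = {
    "interactions.0.linear_up.weight": (65536, 262144),
    "interactions.0.linear_up.output_mask": (128, 512),
    "interactions.0.conv_tp.output_mask": (2048, 8192),
    "interactions.0.conv_tp_weights.net.9.weight": (512, 2048),
    "interactions.0.conv_tp_weights.net.9.bias": (512, 2048),
    "interactions.0.linear_res.weight": (262144, 1048576),
    "interactions.0.linear_1.weight": (458752, 1835008),
}
ratios = {name: new // old for name, (old, new) in mismatches.items()}
print(ratios)  # every ratio is exactly 4
```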

3. Error during fine-tuning.

Using the selected configurations produced with the "mace-mp0" model (as mentioned in problem 1), I tried to fine-tune the "mace-mh-1" model. My command was:

```bash
python -m mace.cli.run_train \
    --name mymodel_finetuned \
    --pt_train_file /home/yanning2/project/MACE/database/select/selected_configs_omat.xyz \
    --train_file /home/yanning2/project/MACE/database/finetuning/total_energy.db \
    --valid_fraction 0.05 \
    --foundation_model /home/yanning2/project/MACE/model/mace-mh-1.model \
    --foundation_head oc20_usemppbe \
    --energy_weight 1.0 \
    --forces_weight 100.0 \
    --swa \
    --swa_energy_weight 10.0 \
    --swa_forces_weight 100.0 \
    --results_dir /home/yanning2/project/MACE/results
```

and it raised the following error:

```
Traceback (most recent call last):
  File "/home/yanning2/scratch/conda_env/MACE/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/yanning2/scratch/conda_env/MACE/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/yanning2/scratch/conda_env/MACE/lib/python3.10/site-packages/mace/cli/run_train.py", line 1068, in <module>
    main()
  File "/home/yanning2/scratch/conda_env/MACE/lib/python3.10/site-packages/mace/cli/run_train.py", line 77, in main
    run(args)
  File "/home/yanning2/scratch/conda_env/MACE/lib/python3.10/site-packages/mace/cli/run_train.py", line 190, in run
    model_foundation = remove_pt_head(
  File "/home/yanning2/scratch/conda_env/MACE/lib/python3.10/site-packages/mace/tools/scripts_utils.py", line 415, in remove_pt_head
    new_model.load_state_dict(new_state_dict)
  File "/home/yanning2/scratch/conda_env/MACE/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2629, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for ScaleShiftMACE:
    size mismatch for interactions.0.linear_up.weight: copying a param with shape torch.Size([65536]) from checkpoint, the shape in current model is torch.Size([262144]).
    size mismatch for interactions.0.linear_up.output_mask: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for interactions.0.conv_tp.output_mask: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([8192]).
    size mismatch for interactions.0.conv_tp_weights.net.9.weight: copying a param with shape torch.Size([512, 64]) from checkpoint, the shape in current model is torch.Size([2048, 64]).
    size mismatch for interactions.0.conv_tp_weights.net.9.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2048]).
    size mismatch for interactions.0.linear_res.weight: copying a param with shape torch.Size([262144]) from checkpoint, the shape in current model is torch.Size([1048576]).
    size mismatch for interactions.0.linear_1.weight: copying a param with shape torch.Size([458752]) from checkpoint, the shape in current model is torch.Size([1835008]).
Using agnostic product in EquivariantProductBasisBlock
Using agnostic product in EquivariantProductBasisBlock
```

Does this error result from using different models for selecting configurations and for fine-tuning, or from something else? Could you please check?

My running environment:

Python 3.10

Name | Version
---- | -------
_libgcc_mutex | 0.1
_openmp_mutex | 5.1
ase | 3.27.0
bzip2 | 1.0.8
ca-certificates | 2025.12.2
configargparse | 1.7.1
contourpy | 1.3.2
cuequivariance | 0.8.1
cuequivariance-ops-cu12 | 0.8.1
cuequivariance-ops-torch-cu12 | 0.8.1
cuequivariance-torch | 0.8.1
cycler | 0.12.1
e3nn | 0.4.4
expat | 2.7.3
filelock | 3.20.3
fonttools | 4.61.1
fsspec | 2026.1.0
gitdb | 4.0.12
gitpython | 3.1.46
h5py | 3.15.1
jinja2 | 3.1.6
kiwisolver | 1.4.9
ld_impl_linux-64 | 2.44
libexpat | 2.7.3
libffi | 3.4.4
libgcc | 15.2.0
libgcc-ng | 15.2.0
libgomp | 15.2.0
libnsl | 2.0.0
libstdcxx | 15.2.0
libstdcxx-ng | 15.2.0
libuuid | 1.41.5
libxcb | 1.17.0
libzlib | 1.3.1
lightning-utilities | 0.15.2
lmdb | 1.7.5
mace-torch | 0.3.14
markupsafe | 3.0.3
matplotlib | 3.10.8
matscipy | 1.2.0
mpmath | 1.3.0
ncurses | 6.5
networkx | 3.4.2
numpy | 2.2.6
nvidia-cublas-cu12 | 12.8.4.1
nvidia-cuda-cupti-cu12 | 12.8.90
nvidia-cuda-nvrtc-cu12 | 12.8.93
nvidia-cuda-runtime-cu12 | 12.8.90
nvidia-cudnn-cu12 | 9.10.2.21
nvidia-cufft-cu12 | 11.3.3.83
nvidia-cufile-cu12 | 1.13.1.3
nvidia-curand-cu12 | 10.3.9.90
nvidia-cusolver-cu12 | 11.7.3.90
nvidia-cusparse-cu12 | 12.5.8.93
nvidia-cusparselt-cu12 | 0.7.1
nvidia-ml-py | 13.590.44
nvidia-nccl-cu12 | 2.27.5
nvidia-nvjitlink-cu12 | 12.8.93
nvidia-nvshmem-cu12 | 3.3.20
nvidia-nvtx-cu12 | 12.8.90
openssl | 3.0.18
opt-einsum | 3.4.0
opt-einsum-fx | 0.1.4
orjson | 3.11.5
packaging | 25
pandas | 2.3.3
pillow | 12.1.0
pip | 25.3
platformdirs | 4.5.1
prettytable | 3.17.0
pthread-stubs | 0.3
pyparsing | 3.3.1
python | 3.10.19
python-dateutil | 2.9.0.post0
python-hostlist | 2.3.0
pytz | 2025.2
pyyaml | 6.0.3
readline | 8.3
scipy | 1.15.3
setuptools | 80.9.0
six | 1.17.0
smmap | 5.0.2
sqlite | 3.51.1
sympy | 1.14.0
tk | 8.6.15
torch | 2.9.1
torch-ema | 0.3
torchmetrics | 1.8.2
tqdm | 4.67.1
triton | 3.5.1
typing-extensions | 4.15.0
tzdata | 2025.3
wcwidth | 0.2.14
wheel | 0.45.1
xorg-libx11 | 1.8.12
xorg-libxau | 1.0.12
xorg-libxdmcp | 1.1.5
xorg-xorgproto | 2024.1
xz | 5.6.4
zlib | 1.3.1

Thank you for your time and for developing this tool. I'm looking forward to your reply on these points.

Yours sincerely.
