Skip to content

Commit 7d5094b

Browse files
Adjust in the checkpoint adjuster script
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent b1443fa commit 7d5094b

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

examples/scripts/fix_backwards_compatibility.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,14 @@
3636
if k == 'state_dict':
3737
state_dict_renamed[k] = {}
3838
for k1, v1 in v.items():
39-
if 'model.encoder.' in k1:
40-
state_dict_renamed[k][k1.replace('model.encoder.', 'model.encoder._timm_module.')] = v1
39+
splits = k1.split(".")
40+
splits_ = [s for s in splits if "timm" not in s]
41+
k1_ = ".".join(splits_)
42+
if k1 != k1_:
43+
state_dict_renamed[k][k1_] = v1
4144
else:
4245
state_dict_renamed[k][k1] = v1
4346
else:
4447
state_dict_renamed[k] = v
45-
4648

47-
torch.save(state_dict_renamed, path_out)
49+
torch.save(state_dict_renamed, path_out)

0 commit comments

Comments
 (0)