diff --git a/examples/scripts/fix_backwards_compatibility.py b/examples/scripts/fix_backwards_compatibility.py index 2f2dc331..254bddd9 100644 --- a/examples/scripts/fix_backwards_compatibility.py +++ b/examples/scripts/fix_backwards_compatibility.py @@ -36,12 +36,14 @@ if k == 'state_dict': state_dict_renamed[k] = {} for k1, v1 in v.items(): - if 'model.encoder.' in k1: - state_dict_renamed[k][k1.replace('model.encoder.', 'model.encoder._timm_module.')] = v1 + splits = k1.split(".") + splits_ = [s for s in splits if "timm" not in s] + k1_ = ".".join(splits_) + if k1 != k1_: + state_dict_renamed[k][k1_] = v1 else: state_dict_renamed[k][k1] = v1 else: state_dict_renamed[k] = v - -torch.save(state_dict_renamed, path_out) \ No newline at end of file +torch.save(state_dict_renamed, path_out)