-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathutils.py
executable file
·32 lines (28 loc) · 1012 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import json
import math
import pandas as pd
import torch
import os
import sys
import shutil
import pickle
def save_checkpoint(epoch, encoder, model, metrics, filename):
    """Write a training checkpoint to ``filename`` using ``torch.save``.

    The serialized object is a dict with the keys, in order:
    ``'epoch'``, ``'encoder_state_dict'`` (from ``encoder.state_dict()``),
    ``'classifier_state_dict'`` (from ``model.state_dict()``) and
    ``'metrics'`` (stored as given).

    Args:
        epoch: Epoch value stored verbatim.
        encoder: Object exposing ``state_dict()`` (e.g. an ``nn.Module``).
        model: Classifier exposing ``state_dict()``; saved under the
            ``'classifier_state_dict'`` key.
        metrics: Arbitrary metrics object stored verbatim.
        filename: Destination path or file-like object for ``torch.save``.
    """
    checkpoint = dict(
        epoch=epoch,
        encoder_state_dict=encoder.state_dict(),
        classifier_state_dict=model.state_dict(),
        metrics=metrics,
    )
    torch.save(checkpoint, filename)
def load_dict(model, fname):
    """Copy weights from a pickled ``{name: array}`` dict into ``model`` in place.

    The file is unpickled with ``encoding='latin1'``, which allows numpy
    arrays pickled under Python 2 to be read under Python 3. Each value is
    converted with ``torch.as_tensor`` (numpy arrays, tensors, scalars and
    lists are accepted) and copied into the tensor of the same name returned
    by ``model.state_dict()``. Every entry is validated before anything is
    copied, so a failed load leaves ``model`` unchanged. Entries of
    ``model.state_dict()`` that are absent from the file are left untouched.

    Warning:
        ``pickle.load`` can execute arbitrary code; only load trusted files.

    Args:
        model: Module whose ``state_dict()`` tensors receive the values.
        fname: Path to the pickle file.

    Raises:
        KeyError: If the file contains a key not present in
            ``model.state_dict()``.
        RuntimeError: If a value cannot be converted to a tensor, or its
            shape differs from the model tensor of the same name.
    """
    # SECURITY: unpickling untrusted data can run arbitrary code.
    with open(fname, 'rb') as f:
        weights = pickle.load(f, encoding='latin1')
    own_state = model.state_dict()

    # Validate every entry first and copy afterwards, so an error on a later
    # key cannot leave the model partially overwritten.
    staged = []
    for name, param in weights.items():
        if name not in own_state:
            raise KeyError('unexpected key "{}" in state_dict'.format(name))
        target = own_state[name]
        try:
            # Same as torch.from_numpy for ndarrays, but also accepts
            # tensors, numpy scalars and nested lists.
            source = torch.as_tensor(param)
        except (TypeError, ValueError, RuntimeError) as err:
            raise RuntimeError(
                'While copying the parameter named {}, the checkpoint value '
                'of type {} could not be converted to a tensor.'.format(
                    name, type(param).__name__)) from err
        # copy_ silently broadcasts (e.g. a (1,) value into an (N,) tensor),
        # so shapes must be compared explicitly.
        if source.shape != target.shape:
            raise RuntimeError(
                'While copying the parameter named {}, whose dimensions in the '
                'model are {} and whose dimensions in the checkpoint are {}.'
                .format(name, target.size(), source.size()))
        staged.append((target, source))

    for target, source in staged:
        target.copy_(source)