diff --git a/model_coverage.py b/model_coverage.py index f5ba1a3..98c9a4a 100644 --- a/model_coverage.py +++ b/model_coverage.py @@ -91,7 +91,7 @@ def mask_text(self, text_tokenized): return masked def reload_model(self, model_file): - print(self.model.load_state_dict(torch.load(model_file), strict=False)) + print(self.model.load_state_dict(torch.load(model_file, map_location=torch.device(self.device)), strict=False)) def save_model(self, model_file): torch.save(self.model.state_dict(), model_file) diff --git a/model_generator.py b/model_generator.py index a8a0efc..b1eca97 100644 --- a/model_generator.py +++ b/model_generator.py @@ -36,7 +36,7 @@ def __init__(self, max_output_length=25, max_input_length=300, device='cpu', tok self.mode = "train" def reload(self, from_file): - print(self.model.load_state_dict(torch.load(from_file), strict=False)) + print(self.model.load_state_dict(torch.load(from_file, map_location=torch.device(self.device)), strict=False)) def save(self, to_file): torch.save(self.model.state_dict(), to_file)