Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions torchmdnet/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn.functional import local_response_norm, mse_loss, l1_loss
from torch.nn.functional import mse_loss, l1_loss
from torch import Tensor
from typing import Optional, Dict, Tuple

import time
from lightning import LightningModule
from torchmdnet.models.model import create_model, load_model

Expand Down Expand Up @@ -37,6 +37,8 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
self.losses = None
self._reset_losses_dict()

self.tstart = time.time()

def configure_optimizers(self):
optimizer = AdamW(
self.model.parameters(),
Expand Down Expand Up @@ -215,6 +217,7 @@ def on_validation_epoch_end(self):
"epoch": float(self.current_epoch),
"lr": self.trainer.optimizers[0].param_groups[0]["lr"],
}
result_dict["time"] = time.time() - self.tstart
result_dict.update(self._get_mean_loss_dict_for_type("total"))
result_dict.update(self._get_mean_loss_dict_for_type("y"))
result_dict.update(self._get_mean_loss_dict_for_type("neg_dy"))
Expand All @@ -226,6 +229,7 @@ def on_test_epoch_end(self):
# Log all test losses
if not self.trainer.sanity_checking:
result_dict = {}
result_dict["time"] = time.time() - self.tstart
result_dict.update(self._get_mean_loss_dict_for_type("total"))
result_dict.update(self._get_mean_loss_dict_for_type("y"))
result_dict.update(self._get_mean_loss_dict_for_type("neg_dy"))
Expand Down