@@ -19,9 +19,9 @@ class ChebaiBaseNet(LightningModule):
1919 Args:
2020 criterion (torch.nn.Module, optional): The loss criterion for the model. Defaults to None.
2121 out_dim (int, optional): The output dimension of the model. Defaults to None.
22- train_metrics (Dict[str, Metric] , optional): The metrics to be used during training. Defaults to None.
23- val_metrics (Dict[str, Metric] , optional): The metrics to be used during validation. Defaults to None.
24- test_metrics (Dict[str, Metric] , optional): The metrics to be used during testing. Defaults to None.
22+ train_metrics (torch.nn.Module , optional): The metrics to be used during training. Defaults to None.
23+ val_metrics (torch.nn.Module , optional): The metrics to be used during validation. Defaults to None.
24+ test_metrics (torch.nn.Module , optional): The metrics to be used during testing. Defaults to None.
2525 pass_loss_kwargs (bool, optional): Whether to pass loss kwargs to the criterion. Defaults to True.
2626 optimizer_kwargs (Dict[str, Any], optional): Additional keyword arguments for the optimizer. Defaults to None.
2727 **kwargs: Additional keyword arguments.
@@ -36,9 +36,9 @@ def __init__(
3636 self ,
3737 criterion : torch .nn .Module = None ,
3838 out_dim : Optional [int ] = None ,
39- train_metrics : Optional [Dict [ str , Metric ] ] = None ,
40- val_metrics : Optional [Dict [ str , Metric ] ] = None ,
41- test_metrics : Optional [Dict [ str , Metric ] ] = None ,
39+ train_metrics : Optional [torch . nn . Module ] = None ,
40+ val_metrics : Optional [torch . nn . Module ] = None ,
41+ test_metrics : Optional [torch . nn . Module ] = None ,
4242 pass_loss_kwargs : bool = True ,
4343 optimizer_kwargs : Optional [Dict [str , Any ]] = None ,
4444 ** kwargs ,
@@ -207,7 +207,7 @@ def _execute(
207207 self ,
208208 batch : XYData ,
209209 batch_idx : int ,
210- metrics : Dict [ str , Metric ] ,
210+ metrics : Optional [ torch . nn . Module ] = None ,
211211 prefix : Optional [str ] = "" ,
212212 log : Optional [bool ] = True ,
213213 sync_dist : Optional [bool ] = False ,
@@ -218,7 +218,7 @@ def _execute(
218218 Args:
219219 batch (XYData): The input batch of data.
220220 batch_idx (int): The index of the current batch.
221- metrics (Dict[str, Metric] ): A dictionary of metrics to track.
221+ metrics (torch.nn.Module ): A dictionary of metrics to track.
222222 prefix (str, optional): A prefix to add to the metric names. Defaults to "".
223223 log (bool, optional): Whether to log the metrics. Defaults to True.
224224 sync_dist (bool, optional): Whether to synchronize distributed training. Defaults to False.
@@ -275,13 +275,13 @@ def _execute(
275275 self ._log_metrics (prefix , metrics , len (batch ))
276276 return d
277277
278- def _log_metrics (self , prefix : str , metrics : Dict [ str , Metric ] , batch_size : int ):
278+ def _log_metrics (self , prefix : str , metrics : torch . nn . Module , batch_size : int ):
279279 """
280280 Logs the metrics for the given prefix.
281281
282282 Args:
283283 prefix (str): The prefix to be added to the metric names.
284- metrics (Dict[str, Metric] ): A dictionary containing the metrics to be logged.
284+ metrics (torch.nn.Module ): A dictionary containing the metrics to be logged.
285285 batch_size (int): The batch size used for logging.
286286
287287 Returns:
0 commit comments