Open
Labels
cat:enhancement (New feature or request)
Description
When implementing an ML pipeline with the engine functionality of PPE (pytorch-pfn-extras), the following part is essentially where the model computes the loss.
pytorch-pfn-extras/pytorch_pfn_extras/handler/_logic.py
Lines 209 to 217 in c5b4d58
    def _forward(self, model: torch.nn.Module, batch: Any) -> Any:
        if isinstance(batch, tuple) and hasattr(batch, "_fields"):
            # namedtuple
            return model(batch)
        if isinstance(batch, dict):
            return model(**batch)
        if isinstance(batch, (list, tuple)):
            return model(*batch)
        return model(batch)
Currently, _forward() is abstracted so that it works with batches in several formats (namedtuple, dict, list/tuple, or a single object). Making this step easier to customize would let users see exactly how a batch is passed to the model and adapt that behavior with minimal code changes.
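To make the current behavior concrete, here is a minimal sketch of the dispatch rules in the snippet above. It is a standalone copy written for illustration, not PPE code: `forward` and `record_call` are hypothetical names, and the "model" is a plain callable that only reports how it was called.

```python
from collections import namedtuple
from typing import Any, Callable


def forward(model: Callable[..., Any], batch: Any) -> Any:
    """Standalone copy of the dispatch rules in the quoted _forward()."""
    if isinstance(batch, tuple) and hasattr(batch, "_fields"):
        return model(batch)    # namedtuple: passed as a single argument
    if isinstance(batch, dict):
        return model(**batch)  # dict: expanded into keyword arguments
    if isinstance(batch, (list, tuple)):
        return model(*batch)   # list/tuple: expanded into positional arguments
    return model(batch)        # anything else: passed as a single argument


def record_call(*args: Any, **kwargs: Any) -> tuple:
    """Hypothetical stand-in for a model; returns the arguments it received."""
    return args, kwargs


Pair = namedtuple("Pair", ["x", "t"])

as_namedtuple = forward(record_call, Pair(1, 2))
as_dict = forward(record_call, {"x": 1, "t": 2})
as_list = forward(record_call, [1, 2])
as_scalar = forward(record_call, 1)
```

Note that the namedtuple check has to come before the list/tuple check, because a namedtuple is also a tuple and would otherwise be unpacked into positional arguments.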
Inheriting from the logic class and overriding the _forward() method would also make it possible to train an nn.Module that implements only inference (and therefore does not take labels) without wrapping it in another nn.Module for training. Specifically, this can be achieved by computing the loss inside the _forward() method.
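The idea could look roughly like the sketch below. Everything here is an assumption made for illustration: `SketchLogic` is a stand-in for the class that owns _forward() (it is not PPE's actual class), `LossInForwardLogic`, `input_key`, and `label_key` are hypothetical names, the batch is assumed to be a dict, and the returned dict with `loss` and `prediction` keys is only one possible output format. Plain Python callables are used so the sketch runs without PyTorch.

```python
from typing import Any, Callable, Dict


class SketchLogic:
    """Stand-in for the class that owns _forward(); NOT the real PPE class."""

    def _forward(self, model: Callable[..., Any], batch: Any) -> Any:
        if isinstance(batch, tuple) and hasattr(batch, "_fields"):
            return model(batch)
        if isinstance(batch, dict):
            return model(**batch)
        if isinstance(batch, (list, tuple)):
            return model(*batch)
        return model(batch)


class LossInForwardLogic(SketchLogic):
    """Hypothetical subclass that computes the loss inside _forward()."""

    def __init__(
        self,
        loss_fn: Callable[[Any, Any], Any],
        input_key: str = "x",
        label_key: str = "t",
    ) -> None:
        self.loss_fn = loss_fn
        self.input_key = input_key
        self.label_key = label_key

    def _forward(self, model: Callable[..., Any], batch: Dict[str, Any]) -> Any:
        # The model only sees the input, so it can stay inference-only.
        prediction = model(batch[self.input_key])
        # The label is consumed here rather than inside the model.
        loss = self.loss_fn(prediction, batch[self.label_key])
        return {"loss": loss, "prediction": prediction}


def inference_only_model(x: float) -> float:
    """Toy model that takes no label."""
    return 2.0 * x


def squared_error(prediction: float, label: float) -> float:
    return (prediction - label) ** 2


logic = LossInForwardLogic(squared_error)
outputs = logic._forward(inference_only_model, {"x": 3.0, "t": 5.0})
```

With PyTorch, `inference_only_model` would be an nn.Module and `loss_fn` could be something like `torch.nn.functional.mse_loss`. The same model object can then be used for both training and inference, with no training-only wrapper module.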