Make Logic._forward() method a public method #772

@linshokaku

When implementing an ML pipeline with PPE's engine feature, the following method is essentially where the model calculates the loss.

def _forward(self, model: torch.nn.Module, batch: Any) -> Any:
    if isinstance(batch, tuple) and hasattr(batch, "_fields"):
        # namedtuple
        return model(batch)
    if isinstance(batch, dict):
        return model(**batch)
    if isinstance(batch, (list, tuple)):
        return model(*batch)
    return model(batch)
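
To make the dispatch rules concrete, here is a minimal, self-contained sketch that copies the branching above into a free function and calls it with a plain Python callable standing in for the model. The names `dispatch_forward`, `echo_model`, and `Pair` are illustrative only and are not part of PPE.

```python
from collections import namedtuple
from typing import Any, Callable


def dispatch_forward(model: Callable[..., Any], batch: Any) -> Any:
    """Same branching as the _forward() method quoted above."""
    if isinstance(batch, tuple) and hasattr(batch, "_fields"):
        # namedtuple: passed to the model as one object
        return model(batch)
    if isinstance(batch, dict):
        # dict: unpacked into keyword arguments
        return model(**batch)
    if isinstance(batch, (list, tuple)):
        # list or plain tuple: unpacked into positional arguments
        return model(*batch)
    return model(batch)


def echo_model(*args: Any, **kwargs: Any) -> tuple:
    """Stand-in for a model (assumption): reports how it was called."""
    return args, kwargs


Pair = namedtuple("Pair", ["x", "t"])

as_dict = dispatch_forward(echo_model, {"x": 1, "t": 2})  # keyword arguments
as_list = dispatch_forward(echo_model, [1, 2])            # positional arguments
as_named = dispatch_forward(echo_model, Pair(1, 2))       # one namedtuple argument
as_other = dispatch_forward(echo_model, 3)                # one plain argument
```

Note that a namedtuple is checked before the generic tuple case, so it reaches the model as a single object rather than being unpacked.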

Currently, _forward() is abstracted to handle batches in various formats. Making it customizable would let users see and change this behavior with minimal code changes.

Overriding the _forward() method in a subclass would also make it possible to train an nn.Module that implements only inference, and therefore takes no labels, without wrapping it in another nn.Module for training. Specifically, this can be achieved by computing the loss inside the overridden _forward() method.
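
A rough sketch of what this customization could look like, assuming a public hook named `forward` on a hypothetical `SketchLogic` base class. Neither name is PPE's actual API, and the base class is heavily simplified. Plain Python callables stand in for the `nn.Module` and the loss function so that the example runs without PyTorch.

```python
from typing import Any, Callable, Dict


class SketchLogic:
    """Hypothetical stand-in for a logic class; not PPE's actual class."""

    def forward(self, model: Callable[..., Any], batch: Any) -> Any:
        # Simplified default: hand the batch to the model unchanged.
        return model(batch)

    def train_step(self, model: Callable[..., Any], batch: Any) -> Any:
        # The training step only relies on forward() returning the outputs.
        return self.forward(model, batch)


class LossInForwardLogic(SketchLogic):
    """Computes the loss in forward(), so the model stays inference-only."""

    def __init__(self, loss_fn: Callable[[Any, Any], Any]) -> None:
        self.loss_fn = loss_fn

    def forward(
        self, model: Callable[..., Any], batch: Dict[str, Any]
    ) -> Dict[str, Any]:
        prediction = model(batch["x"])  # the model never sees the label
        loss = self.loss_fn(prediction, batch["t"])
        return {"y": prediction, "loss": loss}


def double(x: float) -> float:
    """Inference-only stand-in for an nn.Module (assumption)."""
    return 2.0 * x


def squared_error(y: float, t: float) -> float:
    """Stand-in loss function (assumption)."""
    return (y - t) ** 2


logic = LossInForwardLogic(squared_error)
outputs = logic.train_step(double, {"x": 1.5, "t": 2.0})
```

Here the model only maps inputs to predictions, while the label handling and the loss live entirely in the overridden method, so no training wrapper around the model is needed.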
