
Documentation: better info about class_weight and/or DataAdapters #21355

Open
@DLumi

Description


if class_weight is not None:
    raise ValueError(
        "Argument `class_weight` is not supported for torch "
        f"DataLoader inputs. Received: class_weight={class_weight}"
    )

The model.fit() method does not accept the class_weight argument when the input is a torch DataLoader.

Having dug into the source code, I understand what's going on internally:

  1. TFDatasetAdapter receives a tf.data.Dataset.
  2. On top of it, TFDatasetAdapter maps a function that:
  • extracts the respective weight for each sample's label;
  • appends it as a third element of the data pipeline.
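The mapping step above can be approximated in plain Python. This is a minimal sketch, not Keras's actual implementation (which maps a function over a tf.data.Dataset); it assumes integer class labels and a class_weight dict keyed by class index, and map_class_weights is a hypothetical name.

```python
from typing import Any, Dict, Iterable, Iterator, Tuple


def map_class_weights(
    pairs: Iterable[Tuple[Any, Any]],
    class_weight: Dict[int, float],
) -> Iterator[Tuple[Any, Any, float]]:
    """Hypothetical helper: turn each (x, y) element into (x, y, weight).

    Plain-Python stand-in for mapping a class-weight lookup over a dataset;
    assumes y is an integer class label.
    """
    for x, y in pairs:
        # Look up the weight for this sample's class label and append it
        # as the third element of the tuple.
        yield x, y, class_weight[int(y)]


if __name__ == "__main__":
    data = [([0.1, 0.2], 0), ([0.3, 0.4], 1), ([0.5, 0.6], 1)]
    for element in map_class_weights(data, {0: 1.0, 1: 3.0}):
        print(element)
```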

Having looked into the training step, I can see that it simply unpacks each batch into three elements: input tensor, label, and class weight (used as the per-sample weight). So, to work around this issue, I had to redefine __getitem__ of the torch Dataset so that it returns the correct class weight as a third element.
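A minimal sketch of that workaround, assuming each item of the original dataset is an (x, y) pair with x a tensor or array and y an integer class label. ClassWeightedDataset is a hypothetical wrapper, not part of Keras or PyTorch; it relies on a map-style torch DataLoader needing only __len__ and __getitem__, so subclassing torch.utils.data.Dataset is optional here.

```python
from typing import Any, Dict, Sequence, Tuple


class ClassWeightedDataset:
    """Hypothetical map-style dataset that appends a class weight to each item.

    Wraps any indexable dataset of (x, y) pairs and returns (x, y, weight),
    where weight = class_weight[y].
    """

    def __init__(self, base: Sequence[Tuple[Any, Any]], class_weight: Dict[int, float]):
        self.base = base
        self.class_weight = class_weight

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, index: int) -> Tuple[Any, Any, float]:
        x, y = self.base[index]
        # int(y) converts a label tensor (e.g. a 0-d torch tensor) to a plain
        # int, so the dict lookup works.
        return x, y, self.class_weight[int(y)]


if __name__ == "__main__":
    ds = ClassWeightedDataset([([0.1, 0.2], 0), ([0.3, 0.4], 1)], {0: 1.0, 1: 3.0})
    print(len(ds), ds[1])
    # With torch installed and a real dataset of tensors, one would then do
    # something like (not run here):
    # loader = torch.utils.data.DataLoader(ds, batch_size=32, shuffle=True)
    # model.fit(loader, epochs=10)  # no class_weight argument
```

The wrapped dataset can then be passed to a DataLoader and on to model.fit() without class_weight, since the weights should now arrive as the third element of each batch.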

Now, I understand why it may not be technically feasible to map a custom function on top of a DataLoader. Still, the workaround is simple enough. But to find it, I had to read through a fair amount of source code, which definitely shouldn't be necessary.

So this error message is unhelpful for two reasons:

  1. It is not obvious that this combination is unsupported at all. Nothing in the documentation of model.fit() indicates that class_weight cannot be used with a torch DataLoader.
  2. The error message itself does not suggest any workaround: a guide, a note, or basically anything helpful.

Metadata


Labels: type:docs (Need to modify the documentation)
