Description
keras/keras/src/trainers/data_adapters/__init__.py
Lines 93 to 97 in c7b66fc
The model.fit() method doesn't work with torch DataLoaders when class_weight is passed.
Having dug into the source code, I understand what's going on internally:
- TFDatasetAdapter receives a tf.data.Dataset
- on top of it, TFDatasetAdapter maps a function that:
  - extracts the respective weight for the label
  - adds it as a third element of the data pipeline's output
Having looked into the inner workings of the training step, I can tell that it simply unpacks the data into three arguments: input tensor, label, and class weight. So to solve this issue I have to redefine __getitem__ of the torch Dataset so that it outputs the correct class weight as a third argument.
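A minimal sketch of that workaround, written in plain Python so it runs without torch installed. The names ClassWeightedDataset and base are hypothetical and not part of Keras or torch, and the sketch assumes labels are integer class indices rather than one-hot vectors:

```python
class ClassWeightedDataset:
    """Hypothetical map-style wrapper (not a Keras or torch API).

    Wraps any dataset whose __getitem__ returns (x, y) and returns
    (x, y, weight) instead, with weight looked up from class_weight.
    """

    def __init__(self, base, class_weight):
        self.base = base                  # indexable, yields (x, y) pairs
        self.class_weight = class_weight  # e.g. {0: 1.0, 1: 5.0}

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        x, y = self.base[index]
        # Assumption: y is an integer class index (or a 0-d tensor
        # convertible with int()), not a one-hot encoded vector.
        weight = float(self.class_weight[int(y)])
        return x, y, weight


# Toy data standing in for a real torch Dataset: two features, binary label.
samples = [([0.1, 0.2], 0), ([0.3, 0.4], 1), ([0.5, 0.6], 1)]
dataset = ClassWeightedDataset(samples, class_weight={0: 1.0, 1: 5.0})
```

In practice base would be a torch Dataset returning tensors, and the wrapper would be handed to torch.utils.data.DataLoader, which accepts any object with __len__ and __getitem__. model.fit() would then be called on the loader without class_weight, relying on the training step treating the third element as a per-sample weight, as described above.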
Now, I understand why it may not be technically feasible to map a custom function on top of DataLoaders. Either way, the solution is simple enough. But to find it, I had to look through a bunch of source code, which definitely shouldn't be necessary.
So this error message is just not helpful for two reasons:
- It's not obvious that I simply cannot do that. By looking at the documentation of model.fit() I have no idea that I cannot use class_weight with a torch DataLoader
- The error message itself does not offer any solutions for the problem. A guide, a note, anything helpful, basically