
DistributedGradientBoostedTreesModel does not support Ranking task #209

Open
JackGammack opened this issue Feb 13, 2024 · 1 comment


JackGammack commented Feb 13, 2024

The documentation shows that the ranking task can be used with this model, but nothing warns or fails until training time. The resulting error message does not make clear that the ranking task is unavailable for this model, and I couldn't find any documentation stating this.
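The fail-fast behavior asked for here could look roughly like the following minimal sketch: reject the unsupported task when the model is configured, before any worker starts. `Task` below is a hypothetical stand-in for `tfdf.keras.Task`, and `check_distributed_task` is a hypothetical helper, not a real TF-DF function. The only task marked unsupported is ranking, as the maintainer confirms in this thread.

```python
from enum import Enum


class Task(Enum):
    """Hypothetical stand-in for tfdf.keras.Task (only the members used here)."""
    CLASSIFICATION = "CLASSIFICATION"
    REGRESSION = "REGRESSION"
    RANKING = "RANKING"


# Tasks rejected by distributed gradient boosted trees training, per this thread.
UNSUPPORTED_DISTRIBUTED_TASKS = frozenset({Task.RANKING})


def check_distributed_task(task: Task) -> None:
    """Hypothetical constructor-time check: fail fast with an explicit message."""
    if task in UNSUPPORTED_DISTRIBUTED_TASKS:
        raise ValueError(
            f"Task {task.name} is not supported by distributed gradient "
            "boosted trees training."
        )


check_distributed_task(Task.REGRESSION)  # passes silently
try:
    check_distributed_task(Task.RANKING)
except ValueError as err:
    print(err)
```

Running such a check in the model constructor would surface the problem at configuration time with a task-specific message, instead of a generic `Not supported task` status from a worker.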

Are there plans to add support for distributed ranking models? I suspect there may be limitations when examples from the same `ranking_group` end up on different workers, since NDCG has to be computed over each complete group.
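To make the split-group concern concrete, here is a minimal pure-Python sketch. It assumes one common NDCG formulation (gain 2^rel - 1 with a log2 rank discount); YDF's own NDCG implementation may differ, for example in truncation. It scores one group as a whole, then scores the same group split across two hypothetical workers and averages the per-worker values, which gives a different number. The helpers `dcg`, `ndcg`, `worker_for_group`, and `shard_by_group` are hypothetical, and the group-hash sharding at the end is only one possible mitigation, not how TF-DF or YDF distributes data.

```python
import math
import zlib
from collections import defaultdict


def dcg(relevances):
    """Discounted cumulative gain, assuming gain 2^rel - 1 and a log2 discount."""
    return sum((2 ** rel - 1) / math.log2(rank + 2)
               for rank, rel in enumerate(relevances))


def ndcg(labels, scores):
    """NDCG for a single ranking group (no truncation)."""
    # True labels ordered by descending predicted score.
    ranked = [label for _, label in sorted(zip(scores, labels), key=lambda p: -p[0])]
    best = dcg(sorted(labels, reverse=True))
    return dcg(ranked) / best if best > 0 else 0.0


# One query group with four documents.
labels = [3, 2, 0, 1]
scores = [0.1, 0.9, 0.8, 0.2]
full = ndcg(labels, scores)

# Same group split across two workers, each computing NDCG on its half.
split_avg = (ndcg(labels[:2], scores[:2]) + ndcg(labels[2:], scores[2:])) / 2
print(f"full group: {full:.3f}, split average: {split_avg:.3f}")  # 0.694 vs 0.732


def worker_for_group(group_id: str, num_workers: int) -> int:
    # crc32 is stable across processes, unlike Python's salted hash() for str.
    return zlib.crc32(group_id.encode("utf-8")) % num_workers


def shard_by_group(examples, num_workers):
    """Route every example of a group to the same worker."""
    shards = defaultdict(list)
    for ex in examples:
        shards[worker_for_group(ex["group"], num_workers)].append(ex)
    return shards
```

Keying the shard on the group identifier keeps each group intact on one worker, so per-group metrics stay exact; the trade-off is that shard sizes depend on the group-size distribution and can become unbalanced.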

Minimal example

```python
import tensorflow as tf
import tensorflow_decision_forests as tfdf

strategy = tf.distribute.experimental.ParameterServerStrategy(...)

with strategy.scope():
    model = tfdf.keras.DistributedGradientBoostedTreesModel(
        task=tfdf.keras.Task.RANKING,
        ranking_group="group",
    )

model.fit_on_dataset_path(
    train_path=train_input_pattern,
    label_key="label",
    weight_key="sample_weight",
    dataset_format="tfrecord+tfe",
)
```

The error message is below. Changing the task to regression lets the model train successfully.

```
  File "/opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/core.py", line 1942, in fit_on_dataset_path
    tf_core.train_on_file_dataset(
  File "/opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/tensorflow/core.py", line 779, in train_on_file_dataset
    training_op.SimpleMLCheckStatus(process_id=process_id) == 1
  File "/opt/conda/lib/python3.10/site-packages/tensorflow/python/util/tf_export.py", line 403, in wrapper
    return f(**kwargs)
  File "<string>", line 1373, in simple_ml_check_status
  File "/opt/conda/lib/python3.10/site-packages/tensorflow/python/framework/ops.py", line 5883, in raise_from_not_ok_status
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.UnknownError: {{function_node __wrapped__SimpleMLCheckStatus_device_/job:chief/replica:0/task:0/device:CPU:0}} TensorFlow: INVALID_ARGUMENT: Worker #0: INVALID_ARGUMENT: Not supported task [Op:SimpleMLCheckStatus] name:
```
Collaborator

rstz commented Feb 14, 2024

Hi Jack,

You're right: ranking is not currently available in distributed training. I've improved the error message and updated the documentation about it at https://ydf.readthedocs.org.

Our team is always happy to implement missing features in TF-DF / Yggdrasil Decision Forests, but our resources are limited, and we have to prioritize based on impact, among other factors. If you have a cool or strong use case for this feature, please contact us at [email protected].
