Design choices on the solver side #31

@rtavenar

This is more of a discussion than an issue.

During the sprint, we made some design choices for the models (which could of course be revisited later).

The idea is that the minimal code needed to implement a new model looks like this (the Chronos-2 example):

```python
import torch

# BaseTSFMSolver and SUPPORTED_TASKS are defined in the benchmark package
# (their import path is not shown in this example).


class Solver(BaseTSFMSolver):
    name = "Chronos"

    requirements = ["pip::chronos-forecasting>=2.2,<3"]

    parameters = {
        "model_size": ["small"],
        "layer": [None],
        "pooler": ["mean"],
    }

    def __init__(self, model_size="small", layer=None, pooler="mean"):
        """Initialize Chronos-specific state.

        Parameters
        ----------
        model_size : str, default="small"
            Chronos model variant to load.
        layer : int or None, default=None
            Encoder block index for classification embeddings.
        pooler : {"mean", "max", "last"}, default="mean"
            Pooling strategy over the time-token axis for classification.
        """
        super().__init__(
            model_size=model_size,
            layer=layer,
            pooler=pooler,
        )
        # Populated lazily by load_model().
        self._pipeline = None
        self._loaded_model = None

    @property
    def supported_tasks(self):
        return SUPPORTED_TASKS

    def load_model(self, device, dtype):
        """Load Chronos-2 pipeline (cached if already loaded)."""
        from chronos import Chronos2Pipeline

        model_id = f"autogluon/chronos-2-{self.model_size}"
        # Reload only if nothing is cached yet or a different variant is requested.
        if self._pipeline is None or self._loaded_model != model_id:
            self._pipeline = Chronos2Pipeline.from_pretrained(
                model_id,
                device_map=device,
                dtype=dtype,
            )
            self._loaded_model = model_id
        return self._pipeline

    def forecast_batch(self, inputs):
        """Chronos-specific batch prediction.

        Parameters
        ----------
        inputs : list of torch.Tensor
            Each tensor has shape (C, T_cutoff)

        Returns
        -------
        list of torch.Tensor
            Each tensor has shape (H, C, Q)
        """
        # self.model and self.prediction_length are assumed to be set by BaseTSFMSolver.
        with torch.no_grad():
            return self.model.predict(inputs, prediction_length=self.prediction_length)
```

In the end, models should also provide a time_embed_batch method and, optionally, an embed_batch method, so that subclasses only deal with batches while the BaseTSFMSolver base class implements the generic parts, such as iterating over the cutoffs.
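
A minimal sketch of what the generic cutoff loop in the base class could look like, assuming a series is an array of shape (C, T) and that forecast_batch receives a list of (C, T_cutoff) contexts as in the Chronos example. The names BaseSolverSketch, forecast_at_cutoffs, batch_size and the toy LastValueSolver are hypothetical and are not the actual BaseTSFMSolver API; the toy model also drops the quantile axis Q for brevity.

```python
import numpy as np


class BaseSolverSketch:
    """Hypothetical base class: owns the cutoff loop, subclasses own the model call."""

    def __init__(self, prediction_length=3, batch_size=2):
        self.prediction_length = prediction_length
        self.batch_size = batch_size

    def forecast_batch(self, inputs):
        # Model-specific hook: each input has shape (C, T_cutoff).
        raise NotImplementedError

    def forecast_at_cutoffs(self, series, cutoffs):
        """Forecast `series` (shape (C, T)) from each cutoff, batching the model calls."""
        # Each context only contains the observations before its cutoff.
        contexts = [series[:, :cutoff] for cutoff in cutoffs]
        outputs = []
        for start in range(0, len(contexts), self.batch_size):
            outputs.extend(self.forecast_batch(contexts[start:start + self.batch_size]))
        return outputs


class LastValueSolver(BaseSolverSketch):
    """Toy model: repeats the last observed value of each channel H times."""

    def forecast_batch(self, inputs):
        # (C, T) -> last column (C,) -> tiled to (H, C)
        return [np.tile(x[:, -1], (self.prediction_length, 1)) for x in inputs]


series = np.arange(20.0).reshape(2, 10)  # 2 channels, 10 time steps
preds = LastValueSolver().forecast_at_cutoffs(series, cutoffs=[4, 6, 8])
print([p.shape for p in preds])  # [(3, 2), (3, 2), (3, 2)]
print(preds[0][0])  # [ 3. 13.]
```

With this split, a new model only has to implement the per-batch hook; batching and cutoff handling live in one place.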

The BaseTSFMSolver class should then handle basic adapters: depending on which tasks the model supports, it builds them on top of the time_embed, embed and forecast methods.

Contributors can provide task-specific adapters if they want; otherwise, the default adapters implemented in the base class are used.
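
The precedence described above (a contributor-supplied adapter wins, otherwise the base class default is used, and unsupported tasks are rejected) could be sketched as follows. The adapters attribute, default_adapter and get_adapter are hypothetical names chosen for illustration, and plain strings stand in for real adapter objects.

```python
class AdapterResolutionSketch:
    """Hypothetical base: a contributor-supplied adapter takes precedence over the default."""

    supported_tasks = ("forecasting", "classification")
    # Contributors may register task-specific adapters here (hypothetical attribute).
    adapters = {}

    def default_adapter(self, task):
        # Strings stand in for the base class's generic adapter objects.
        return f"default-{task}"

    def get_adapter(self, task):
        if task not in self.supported_tasks:
            raise ValueError(f"{type(self).__name__} does not support task {task!r}")
        custom = self.adapters.get(task)
        return custom if custom is not None else self.default_adapter(task)


class CustomClassifierSolver(AdapterResolutionSketch):
    adapters = {"classification": "my-classification-adapter"}


solver = CustomClassifierSolver()
print(solver.get_adapter("classification"))  # my-classification-adapter
print(solver.get_adapter("forecasting"))     # default-forecasting
```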

The standard adaptation strategies are:

  • for forecasting:
    • if the model implements a forecast method, use it;
    • otherwise, use windowing on top of the embed method.
  • for classification:
    • if the model implements an embed method, use it;
    • otherwise, pool the outputs of time_embed over time.
  • for anomaly detection:
    • if the model implements an embed method, use it;
    • otherwise, rely on the residuals of forecast.
  • for event detection:
    • if the model implements a time_embed method, use it;
    • otherwise, use windowing on top of the embed method.
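
One way the base class could pick among these strategies is to check whether a subclass overrides each hook, then fall back as listed. The sketch below is hypothetical (Base, implements, STRATEGIES and select_strategy are illustrative names, and task keys are assumed) and also shows two of the fallbacks: mean pooling of per-time-step embeddings for classification, and absolute forecast residuals as anomaly scores.

```python
import numpy as np


class Base:
    """Hypothetical base: the three model-specific hooks are unimplemented by default."""

    def forecast_batch(self, inputs):
        raise NotImplementedError

    def embed_batch(self, inputs):
        raise NotImplementedError

    def time_embed_batch(self, inputs):
        raise NotImplementedError


def implements(solver, method):
    # A hook counts as implemented if the subclass overrides the base version.
    return getattr(type(solver), method) is not getattr(Base, method)


# task -> (preferred hook, fallback strategy), following the list above
STRATEGIES = {
    "forecasting": ("forecast_batch", "windowing over embed_batch"),
    "classification": ("embed_batch", "pooling over time_embed_batch"),
    "anomaly_detection": ("embed_batch", "residuals from forecast_batch"),
    "event_detection": ("time_embed_batch", "windowing over embed_batch"),
}


def select_strategy(solver, task):
    preferred, fallback = STRATEGIES[task]
    return preferred if implements(solver, preferred) else fallback


def mean_pool(time_embeddings):
    # (T, D) per-time-step embeddings -> (D,) series-level embedding
    return time_embeddings.mean(axis=0)


def residual_scores(observed, forecast):
    # Absolute forecast error per time step, used as an anomaly score
    return np.abs(observed - forecast)


class ForecastOnly(Base):
    def forecast_batch(self, inputs):
        return inputs


solver = ForecastOnly()
print(select_strategy(solver, "forecasting"))        # forecast_batch
print(select_strategy(solver, "anomaly_detection"))  # residuals from forecast_batch
print(mean_pool(np.array([[1.0, 2.0], [3.0, 4.0]])))  # [2. 3.]
print(residual_scores(np.array([1.0, 5.0]), np.array([1.0, 2.0])))  # [0. 3.]
```

Detecting overrides keeps the subclass contract small: a model declares nothing beyond the hooks it actually implements.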
