|
@@ -25,7 +25,6 @@
 
 from .base import BaseChronosPipeline, ForecastType
 
-
 logger = logging.getLogger(__file__)
 
 
@@ -240,13 +239,11 @@ def _init_weights(self, module):
         ):
             module.output_layer.bias.data.zero_()
 
-    def forward(
-        self,
-        context: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        target: Optional[torch.Tensor] = None,
-        target_mask: Optional[torch.Tensor] = None,
-    ) -> ChronosBoltOutput:
+    def encode(
+        self, context: torch.Tensor, mask: Optional[torch.Tensor] = None
+    ) -> Tuple[
+        torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor
+    ]:
         mask = (
             mask.to(context.dtype)
             if mask is not None
@@ -301,8 +298,21 @@ def forward(
             attention_mask=attention_mask,
             inputs_embeds=input_embeds,
         )
-        hidden_states = encoder_outputs[0]
 
+        return encoder_outputs[0], loc_scale, input_embeds, attention_mask
+
+    def forward(
+        self,
+        context: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        target: Optional[torch.Tensor] = None,
+        target_mask: Optional[torch.Tensor] = None,
+    ) -> ChronosBoltOutput:
+        batch_size = context.size(0)
+
+        hidden_states, loc_scale, input_embeds, attention_mask = self.encode(
+            context=context, mask=mask
+        )
         sequence_output = self.decode(input_embeds, attention_mask, hidden_states)
 
         quantile_preds_shape = (
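
These two hunks split the encoder pass out of forward() into a standalone encode() method, so the same code path serves both forecasting and embedding extraction. A minimal sketch of the resulting call flow, assuming `model` is a ChronosBoltModelForForecasting instance and `context` is a (batch, length) float tensor (all names as in the diff above):

# encode() returns the encoder hidden states plus the tensors that
# decode() needs, so callers can stop after encoding if they only
# want embeddings.
hidden_states, loc_scale, input_embeds, attention_mask = model.encode(
    context=context, mask=None
)
# Forecasting continues exactly as before, via the decoder.
sequence_output = model.decode(input_embeds, attention_mask, hidden_states)
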
@@ -426,6 +436,46 @@ def __init__(self, model: ChronosBoltModelForForecasting):
     def quantiles(self) -> List[float]:
         return self.model.config.chronos_config["quantiles"]
 
+    @torch.no_grad()
+    def embed(
+        self, context: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Get encoder embeddings for the given time series.
+
+        Parameters
+        ----------
+        context
+            Input series. This is either a 1D tensor, or a list
+            of 1D tensors, or a 2D tensor whose first dimension
+            is batch. In the latter case, use left-padding with
+            ``torch.nan`` to align series of different lengths.
+
+        Returns
+        -------
+        embeddings, loc_scale
+            A tuple of two items: the encoder embeddings and the loc_scale,
+            i.e., the mean and std of the original time series.
+            The encoder embeddings are shaped (batch_size, num_patches + 1, d_model),
+            where num_patches is the number of patches in the time series
+            and the extra 1 is for the [REG] token (if used by the model).
+        """
+        context_tensor = self._prepare_and_validate_context(context=context)
+        model_context_length = self.model.config.chronos_config["context_length"]
+
+        if context_tensor.shape[-1] > model_context_length:
+            context_tensor = context_tensor[..., -model_context_length:]
+
+        context_tensor = context_tensor.to(
+            device=self.model.device,
+            dtype=torch.float32,
+        )
+        embeddings, loc_scale, *_ = self.model.encode(context=context_tensor)
+        return embeddings.cpu(), (
+            loc_scale[0].squeeze(-1).cpu(),
+            loc_scale[1].squeeze(-1).cpu(),
+        )
+
     def predict(  # type: ignore[override]
         self,
         context: Union[torch.Tensor, List[torch.Tensor]],
|
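A minimal usage sketch for the new embed() API. The import path and checkpoint name are assumptions following chronos-forecasting conventions, not part of this diff:

import torch
from chronos import ChronosBoltPipeline  # assumed import path

# Assumed checkpoint name; any Chronos-Bolt checkpoint should work.
pipeline = ChronosBoltPipeline.from_pretrained(
    "amazon/chronos-bolt-small",
    device_map="cpu",
    torch_dtype=torch.float32,
)

# Two series of different lengths; per the docstring, a list of 1D
# tensors is accepted and shorter series are left-padded with NaN.
# Series longer than the model's context_length are truncated to
# their most recent observations before encoding.
series = [torch.randn(512), torch.randn(256)]

embeddings, (loc, scale) = pipeline.embed(series)
# embeddings: (batch_size, num_patches + 1, d_model)
# loc, scale: per-series mean and std used to normalize the input
print(embeddings.shape, loc.shape, scale.shape)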