
Commit f40a266

Fix type-checking issues (#295)
*Issue #, if available:* See example build https://github.com/amazon-science/chronos-forecasting/actions/runs/14302765904/job/40313421985

*Description of changes:*

- Address type-checker complaints, where possible
- Bump bugfix version of the package

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
1 parent eec771e · commit f40a266

File tree

3 files changed: 10 additions & 6 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "chronos-forecasting"
-version = "1.5.0"
+version = "1.5.1"
 authors = [
     { name="Abdul Fatir Ansari", email="[email protected]" },
     { name="Lorenzo Stella", email="[email protected]" },
```

src/chronos/chronos.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -305,6 +305,8 @@ def encode(
         assert (
             self.config.model_type == "seq2seq"
         ), "Encoder embeddings are only supported for encoder-decoder models"
+        assert hasattr(self.model, "encoder")
+
         return self.model.encoder(
             input_ids=input_ids, attention_mask=attention_mask
         ).last_hidden_state
@@ -344,6 +346,8 @@ def forward(
         if top_p is None:
             top_p = self.config.top_p
 
+        assert hasattr(self.model, "generate")
+
         preds = self.model.generate(
             input_ids=input_ids,
             attention_mask=attention_mask,
```
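
The two added asserts are a standard way to satisfy a type checker rather than a behavioral change: when an attribute such as `encoder` or `generate` is not declared on the variable's static type, `assert hasattr(...)` lets checkers that support `hasattr()` narrowing (e.g., pyright and recent mypy versions) accept the access. A minimal sketch of the pattern, using generic stand-in names rather than Chronos internals:

```python
# Generic illustration of hasattr-based narrowing: BaseModel stands in for
# a broadly typed base class that does not declare an `encoder` attribute.
class BaseModel:
    pass


class Seq2SeqModel(BaseModel):
    def encoder(self, input_ids: list[int]) -> list[int]:
        # Stand-in for a real encoder returning hidden states.
        return input_ids


def encode(model: BaseModel, input_ids: list[int]) -> list[int]:
    # Without this assert, a checker reports something like:
    #   "BaseModel" has no attribute "encoder"
    assert hasattr(model, "encoder")
    return model.encoder(input_ids)


print(encode(Seq2SeqModel(), [1, 2, 3]))  # [1, 2, 3]
```

At runtime the assert also fails earlier and more clearly than an `AttributeError` raised deep inside the call.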

src/chronos/chronos_bolt.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -136,12 +136,12 @@ def forward(self, x: torch.Tensor):
 
 
 class ChronosBoltModelForForecasting(T5PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [
+    _keys_to_ignore_on_load_missing = [  # type: ignore
         r"input_patch_embedding\.",
         r"output_patch_embedding\.",
     ]
-    _keys_to_ignore_on_load_unexpected = [r"lm_head.weight"]
-    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+    _keys_to_ignore_on_load_unexpected = [r"lm_head.weight"]  # type: ignore
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]  # type: ignore
 
     def __init__(self, config: T5Config):
         assert hasattr(config, "chronos_config"), "Not a Chronos config file"
@@ -358,7 +358,7 @@ def forward(
                 (target - quantile_preds)
                 * (
                     (target <= quantile_preds).float()
-                    - self.quantiles.view(1, self.num_quantiles, 1)
+                    - self.quantiles.view(1, self.num_quantiles, 1)  # type: ignore
                 )
             )
             * target_mask.float()
@@ -429,7 +429,7 @@ class ChronosBoltPipeline(BaseChronosPipeline):
     default_context_length: int = 2048
 
     def __init__(self, model: ChronosBoltModelForForecasting):
-        super().__init__(inner_model=model)
+        super().__init__(inner_model=model)  # type: ignore
         self.model = model
 
     @property
```
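
These `# type: ignore` comments are line-scoped suppressions, so the rest of the file stays fully checked. A common trigger, illustrated below with made-up names (a sketch, not the actual `transformers` API), is overriding a base-class attribute whose declared type differs from the value assigned in the subclass:

```python
# Made-up base class: the attribute is declared with a type that the
# subclass's override does not match, so a checker flags the assignment.
from typing import Optional


class FakePretrainedBase:
    _keys_to_ignore_on_load_missing: Optional[str] = None


class PatchedModel(FakePretrainedBase):
    # Incompatible override (list[str] vs. Optional[str]); the trailing
    # comment suppresses the error on this line only.
    _keys_to_ignore_on_load_missing = [r"patch_embedding\."]  # type: ignore
```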
