
Commit 2f92a12

abdulfatir (Abdul Fatir Ansari) and Abdul Fatir Ansari authored
Add support for causal models (#113)
*Description of changes:* This PR adds support for training causal/decoder-only models.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Co-authored-by: Abdul Fatir Ansari <[email protected]>
1 parent 79028e3 · commit 2f92a12

File tree

4 files changed: +96 -4 lines changed


scripts/README.md

Lines changed: 3 additions & 0 deletions
````diff
@@ -89,6 +89,9 @@
 The output and checkpoints will be saved in `output/run-{id}/`.
 > [!TIP]
 > If the initial training step is too slow, you might want to change the `shuffle_buffer_length` and/or set `torch_compile` to `false`.
+
+> [!IMPORTANT]
+> When pretraining causal models (such as GPT2), the training script does [`LastValueImputation`](https://github.com/awslabs/gluonts/blob/f0f2266d520cb980f4c1ce18c28b003ad5cd2599/src/gluonts/transform/feature.py#L103) for missing values by default. If you pretrain causal models, please ensure that missing values are imputed similarly before passing the context tensor to `ChronosPipeline.predict()` for accurate results.
 - (Optional) Once trained, you can easily push your fine-tuned model to HuggingFace🤗 Hub. Before that, do not forget to [create an access token](https://huggingface.co/settings/tokens) with **write permissions** and put it in `~/.cache/huggingface/token`. Here's a snippet that will push a fine-tuned model to HuggingFace🤗 Hub at `<your_hf_username>/chronos-t5-small-fine-tuned`.
 ```py
 from chronos import ChronosPipeline
````
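For context, the note above asks users of causal checkpoints to impute missing values themselves at inference time. Below is a minimal sketch of that step, assuming GluonTS is installed; the model id is a hypothetical placeholder, and the `from_pretrained`/`predict` arguments shown are the commonly used ones rather than an exhaustive list.

```py
import numpy as np
import torch
from gluonts.transform import LastValueImputation
from chronos import ChronosPipeline

# Context with missing values (NaNs)
context = np.array([1.0, 2.0, np.nan, np.nan, 5.0, 6.0, np.nan, 8.0])

# Forward-fill the NaNs, mirroring what the training script does for causal models
context = LastValueImputation()(context)

# Hypothetical model id for a fine-tuned causal checkpoint
pipeline = ChronosPipeline.from_pretrained("<your_hf_username>/chronos-gpt2-fine-tuned")
forecast = pipeline.predict(torch.tensor(context), prediction_length=64)
```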
New file: training config (YAML)

Lines changed: 35 additions & 0 deletions
```diff
@@ -0,0 +1,35 @@
+training_data_paths:
+- "/home/ubuntu/tsmixup-data.arrow"
+- "/home/ubuntu/kernelsynth-data.arrow"
+probability:
+- 0.9
+- 0.1
+context_length: 512
+prediction_length: 64
+min_past: 60
+max_steps: 200_000
+save_steps: 100_000
+log_steps: 500
+per_device_train_batch_size: 32
+learning_rate: 0.001
+optim: adamw_torch_fused
+num_samples: 20
+shuffle_buffer_length: 100_000
+gradient_accumulation_steps: 1
+model_id: openai-community/gpt2
+model_type: causal
+random_init: false
+tie_embeddings: false
+output_dir: ./output/
+tf32: true
+torch_compile: true
+tokenizer_class: "MeanScaleUniformBins"
+tokenizer_kwargs:
+  low_limit: -15.0
+  high_limit: 15.0
+n_tokens: 4096
+lr_scheduler_type: linear
+warmup_ratio: 0.0
+dataloader_num_workers: 1
+max_missing_prop: 0.1
+use_eos_token: true
```
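Assuming the config above is saved locally (the path below is a hypothetical placeholder) and following the usage described in `scripts/README.md`, training can be launched programmatically as sketched here; verify the exact CLI of `train.py` against your checkout.

```py
import subprocess

# Hypothetical local copy of the YAML config shown above
config_path = "chronos-gpt2.yaml"

# Equivalent to running: python training/train.py --config chronos-gpt2.yaml
subprocess.run(
    ["python", "training/train.py", "--config", config_path],
    check=True,
)
```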

scripts/training/train.py

Lines changed: 57 additions & 3 deletions
```diff
@@ -39,6 +39,9 @@
     ValidationSplitSampler,
     InstanceSplitter,
     ExpectedNumInstanceSampler,
+    MissingValueImputation,
+    LeavesMissingValues,
+    LastValueImputation,
 )
 
 from chronos import ChronosConfig, ChronosTokenizer
```
```diff
@@ -301,13 +304,16 @@ def __init__(
         prediction_length: int = 64,
         drop_prob: float = 0.2,
         min_past: Optional[int] = None,
+        model_type: str = "seq2seq",
+        imputation_method: Optional[MissingValueImputation] = None,
         mode: str = "training",
         np_dtype=np.float32,
     ) -> None:
         super().__init__()
 
         assert len(probabilities) == len(datasets)
         assert mode in ("training", "validation", "test")
+        assert model_type in ("seq2seq", "causal")
 
         self.datasets = datasets
         self.probabilities = probabilities
```
```diff
@@ -316,6 +322,8 @@ def __init__(
         self.prediction_length = prediction_length
         self.drop_prob = drop_prob
         self.min_past = min_past or prediction_length
+        self.model_type = model_type
+        self.imputation_method = imputation_method or LeavesMissingValues()
         self.mode = mode
         self.np_dtype = np_dtype
 
```
```diff
@@ -324,6 +332,11 @@ def preprocess_entry(self, entry: dict, mode: str) -> dict:
         entry["target"] = np.asarray(entry["target"], dtype=self.np_dtype)
         assert entry["target"].ndim == 1, f"got {entry['target'].ndim=}, expected 1"
 
+        if self.model_type == "causal":
+            # Causal models do not play nice with missing values, so it is
+            # recommended to use an imputation method, e.g., LastValueImputation
+            entry["target"] = self.imputation_method(entry["target"])
+
         if mode == "training" and self.drop_prob > 0:
             target = entry["target"].copy()
             drop_p = np.random.uniform(low=0.0, high=self.drop_prob)
```
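For reference, a small self-contained sketch (not part of this diff) of what the two imputation defaults used above do to a target array with NaNs:

```py
import numpy as np
from gluonts.transform import LastValueImputation, LeavesMissingValues

target = np.array([1.0, 2.0, np.nan, np.nan, 5.0])

# Default for seq2seq models: the target is returned unchanged, NaNs included
print(LeavesMissingValues()(target))  # [ 1.  2. nan nan  5.]

# Default for causal models: NaNs are filled with the last observed value
print(LastValueImputation()(target))  # [1. 2. 2. 2. 5.]
```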
```diff
@@ -386,6 +399,48 @@ def to_hf_format(self, entry: dict) -> dict:
         future_target = torch.tensor(entry["future_target"]).unsqueeze(0)
         labels, labels_mask = self.tokenizer.label_input_transform(future_target, scale)
         labels[labels_mask == 0] = -100
+
+        if self.model_type == "causal":
+            # The InstanceSplitter pads time series on the left to be equal to the
+            # context_length. However, certain models (e.g., GPT2) with absolute
+            # position embeddings should not be trained with left padding.
+            # The following piece of code moves padding from left to right.
+
+            assert input_ids.shape[-1] == entry["past_is_pad"].shape[0]
+
+            # Find the index where padding starts
+            pad_start_idx = np.searchsorted(1 - entry["past_is_pad"], 1)
+            padded_input_ids, obs_input_ids = torch.tensor_split(
+                input_ids, [pad_start_idx], dim=-1
+            )
+            padded_attention_mask, obs_attention_mask = torch.tensor_split(
+                attention_mask, [pad_start_idx], dim=-1
+            )
+
+            # Move padding to the right
+            input_ids = torch.cat(
+                [
+                    obs_input_ids,
+                    labels,
+                    padded_input_ids,
+                ],
+                axis=-1,
+            )
+            attention_mask = torch.cat(
+                [
+                    obs_attention_mask,
+                    labels_mask,
+                    padded_attention_mask,
+                ],
+                axis=-1,
+            )
+
+            # labels for causal models are same as the input_ids.
+            # Internally transformers shifts the labels by one during training.
+            labels = input_ids.clone()
+            input_ids[~attention_mask] = self.tokenizer.config.pad_token_id
+            labels[~attention_mask] = -100
+
         return {
             "input_ids": input_ids.squeeze(0),
             "attention_mask": attention_mask.squeeze(0),
```
```diff
@@ -520,9 +575,6 @@ def main(
 
     assert model_type in ["seq2seq", "causal"]
 
-    if not model_type == "seq2seq":
-        raise NotImplementedError("Only seq2seq models are currently supported")
-
     output_dir = get_next_path("run", base_dir=output_dir, file_type="")
 
     log_on_main(f"Logging dir: {output_dir}", logger)
```
```diff
@@ -588,6 +640,8 @@ def main(
         context_length=context_length,
         prediction_length=prediction_length,
         min_past=min_past,
+        model_type=model_type,
+        imputation_method=LastValueImputation() if model_type == "causal" else None,
         mode="training",
     ).shuffle(shuffle_buffer_length=shuffle_buffer_length)
 
```
src/chronos/chronos.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -551,7 +551,7 @@ def from_pretrained(cls, *args, **kwargs):
         if chronos_config.model_type == "seq2seq":
             inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs)
         else:
-            assert config.model_type == "causal"
+            assert chronos_config.model_type == "causal"
             inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs)
 
         return cls(
```

0 commit comments