Support saving and loading from remote paths in Fabric #19113

Open
@claudio-alanaai

Bug description

Hello everyone,

I am training a model using FSDP with Fabric.

When I save the model to an S3 bucket by calling:

    fabric.save(
        "s3://model-training-us/models/experiment/checkpoint",
        state
    )

I get the following error:

╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /vision-language-model/vlm/train_vlm_fabric.py:520 in <module>               │
│                                                                              │
│   517 │                                                                      │
│   518 │   cfg = Config(args)                                                 │
│   519 │                                                                      │
│ ❱ 520 │   main(cfg)                                                          │
│   521                                                                        │
│                                                                              │
│ /vision-language-model/vlm/train_vlm_fabric.py:387 in main                   │
│                                                                              │
│   384 │   │   "step_count": 0,                                               │
│   385 │   }                                                                  │
│   386 │                                                                      │
│ ❱ 387 │   fabric.save(                                                       │
│   388 │   │   "s3://model-training-us/models/experiment1/t │
│   389 │   │   state,                                                         │
│   390 │   )                                                                  │
│                                                                              │
│ /usr/lib/python3/dist-packages/lightning/fabric/fabric.py:738 in save        │
│                                                                              │
│    735 │   │   │   for k, v in filter.items():                               │
│    736 │   │   │   │   if not callable(v):                                   │
│    737 │   │   │   │   │   raise TypeError(f"Expected `fabric.save(filter=.. │
│ ❱  738 │   │   self._strategy.save_checkpoint(path=path, state=_unwrap_objec │
│    739 │   │   self.barrier()                                                │
│    740 │                                                                     │
│    741 │   def load(                                                         │
│                                                                              │
│ /usr/lib/python3/dist-packages/lightning/fabric/strategies/fsdp.py:498 in    │
│ save_checkpoint                                                              │
│                                                                              │
│   495 │   │   │   │   │   _apply_filter(key, filter or {}, converted, full_s │
│   496 │   │   │                                                              │
│   497 │   │   │   if self.global_rank == 0:                                  │
│ ❱ 498 │   │   │   │   torch.save(full_state, path)                           │
│   499 │   │   else:                                                          │
│   500 │   │   │   raise ValueError(f"Unknown state_dict_type: {self._state_d │
│   501                                                                        │
│                                                                              │
│ /usr/lib/python3/dist-packages/torch/serialization.py:618 in save            │
│                                                                              │
│    615 │   _check_save_filelike(f)                                           │
│    616 │                                                                     │
│    617 │   if _use_new_zipfile_serialization:                                │
│ ❱  618 │   │   with _open_zipfile_writer(f) as opened_zipfile:               │
│    619 │   │   │   _save(obj, opened_zipfile, pickle_module, pickle_protocol │
│    620 │   │   │   return                                                    │
│    621 │   else:                                                             │
│                                                                              │
│ /usr/lib/python3/dist-packages/torch/serialization.py:492 in                 │
│ _open_zipfile_writer                                                         │
│                                                                              │
│    489 │   │   container = _open_zipfile_writer_file                         │
│    490 │   else:                                                             │
│    491 │   │   container = _open_zipfile_writer_buffer                       │
│ ❱  492 │   return container(name_or_buffer)                                  │
│    493                                                                       │
│    494                                                                       │
│    495 def _is_compressed_file(f) -> bool:                                   │
│                                                                              │
│ /usr/lib/python3/dist-packages/torch/serialization.py:463 in __init__        │
│                                                                              │
│    460 │   │   │   self.file_stream = io.FileIO(self.name, mode='w')         │
│    461 │   │   │   super().__init__(torch._C.PyTorchFileWriter(self.file_str │
│    462 │   │   else:                                                         │
│ ❱  463 │   │   │   super().__init__(torch._C.PyTorchFileWriter(self.name))   │
│    464 │                                                                     │
│    465 │   def __exit__(self, *args) -> None:                                │
│    466 │   │   self.file_like.write_end_of_file()                            │
╰──────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Parent directory
s3:/model-training-us/models/experiment does not exist.
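
The single slash in `s3:/model-training-us/...` is consistent with the URL being handled as an ordinary local filesystem path. A minimal sketch of that effect using only Python's standard library; whether PyTorch normalizes the path in exactly this way internally is an assumption, not something the traceback confirms:

```python
import posixpath

# Parent "directory" of the checkpoint URL passed to fabric.save above.
remote_parent = "s3://model-training-us/models/experiment"

# Treated as a plain POSIX path, the empty component between the two
# slashes is dropped, which matches the spelling in the error message.
as_local_path = posixpath.normpath(remote_parent)
print(as_local_path)
```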

The path s3://model-training-us/models/ exists.

Isn't saving to an S3 bucket supported by fabric.save? Or am I encountering a weird bug?
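
A possible workaround, sketched under assumptions: serialize into an in-memory buffer and write the bytes through `fsspec`. The helper names `is_remote_path` and `save_state_anywhere` are hypothetical and not part of Fabric, and the sketch assumes `state` is a plain, already gathered state dict held by a single process (it does no FSDP gathering itself).

```python
import io
from urllib.parse import urlparse


def is_remote_path(path: str) -> bool:
    """Return True for URL-style paths such as s3://bucket/key.

    Single-letter schemes are treated as Windows drive letters (C:\\...).
    """
    return len(urlparse(path).scheme) > 1


def save_state_anywhere(state, path: str) -> None:
    """Hypothetical helper: save `state` to a local path or a remote URL."""
    import torch  # imported lazily so the path check above has no heavy deps

    if not is_remote_path(path):
        torch.save(state, path)
        return

    import fsspec  # needs a backend for the scheme, e.g. s3fs for s3://

    # torch.save accepts file-like objects, so serialize to memory first
    # and then write the bytes through fsspec.
    buffer = io.BytesIO()
    torch.save(state, buffer)
    with fsspec.open(path, "wb") as f:
        f.write(buffer.getvalue())
```

With this in place, something like `save_state_anywhere(state, "s3://model-training-us/models/experiment/checkpoint")` would go through `fsspec` rather than the local-file writer. It should be called from a single rank only, and it holds the whole serialized checkpoint in memory before uploading.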

Thank you in advance for your attention.

Kind regards,
Claudio

What version are you seeing the problem on?

v2.1

How to reproduce the bug

No response

Error messages and logs

Same traceback as shown in the bug description above.

Environment

  • Lightning:
    • lightning: 2.1.0
    • lightning-utilities: 0.9.0
    • pytorch-lightning: 2.1.0
    • pytorch-ranger: 0.1.1
    • torch: 2.1.0
    • torch-optimizer: 0.3.0
    • torchmetrics: 1.0.3
    • torchvision: 0.16.0

More info

No response

cc @Borda @awaelchli @carmocca @justusschock

Labels

checkpointing, duplicate, fabric, feature, ver: 2.1.x
