Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the file extension of model checkpoints uploaded by NeptuneLogger #20581

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions src/lightning/pytorch/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,17 +508,14 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
if not self._log_model_checkpoints:
return

from neptune.types import File

file_names = set()
checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints")

# save last model
if hasattr(checkpoint_callback, "last_model_path") and checkpoint_callback.last_model_path:
model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
file_names.add(model_last_name)
with open(checkpoint_callback.last_model_path, "rb") as fp:
self.run[f"{checkpoints_namespace}/{model_last_name}"] = File.from_stream(fp)
self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)

# save best k models
if hasattr(checkpoint_callback, "best_k_models"):
Expand All @@ -533,8 +530,7 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:

model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
file_names.add(model_name)
with open(checkpoint_callback.best_model_path, "rb") as fp:
self.run[f"{checkpoints_namespace}/{model_name}"] = File.from_stream(fp)
self.run[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path)

# remove old models logged to experiment if they are not part of best k models at this point
if self.run.exists(checkpoints_namespace):
Expand Down
9 changes: 4 additions & 5 deletions tests/tests_pytorch/loggers/test_neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,10 @@ def test_after_save_checkpoint(neptune_mock):
mock_file.side_effect = mock.Mock()
logger.after_save_checkpoint(cb_mock)

assert run_instance_mock.__setitem__.call_count == 3
assert run_instance_mock.__getitem__.call_count == 2
assert run_attr_mock.upload.call_count == 2

assert mock_file.from_stream.call_count == 2
assert run_instance_mock.__setitem__.call_count == 1 # best_model_path
assert run_instance_mock.__getitem__.call_count == 4 # last_model_path, best_k_models, best_model_path
assert run_attr_mock.upload.call_count == 4 # last_model_path, best_k_models, best_model_path
assert mock_file.from_stream.call_count == 0

run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model1")
run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model2/with/slashes")
Expand Down
Loading