Add safetensors export feature #3345

Open — wants to merge 3 commits into main
Conversation

@jonboh (Contributor) commented on Jul 3, 2025

This is still a work in progress. I started out following antimora's indications in #3260 (comment).

I did not find a clean way to walk the Module with the Serializer and perform the serialization only with the Recorder. The problem is that once the Serializer has finished, I'd need to reconstruct each TensorData from its bytes, dtype and shape fields, since the serializer walks all the way down to the basic types.

The alternative I found was to use the Serializer to get a mapping from ParamId to the tensor name, and a ModuleVisitor to link each ParamId to its TensorData. I think this is a bit cleaner, but the API is quite different from the rest of the serializations, so I'm not sure whether it would be better to do it the other way, even if the logic ends up a bit more convoluted.
Any comments are appreciated :)

@antimora (Collaborator) commented on Jul 3, 2025

I think I know the issue you are facing regarding the serializer going all the way down. It's okay to stop when you come across a TensorData struct type and handle its serialization differently.

My recommendation:

  1. Don't use ModuleVisitor (you can get away with a serde serializer).
  2. The only intermediate data structure should be a hash map from full path name to safetensors tensor view.

@antimora antimora changed the title add safetensors export feature Add safetensors export feature Jul 12, 2025
@jonboh jonboh marked this pull request as ready for review July 13, 2025 12:01
@jonboh jonboh requested a review from antimora July 13, 2025 12:01
codecov bot commented on Jul 13, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 35.19%. Comparing base (f8273f0) to head (b504a8e).
Report is 41 commits behind head on main.

❌ Your project check has failed because the head coverage (35.19%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@             Coverage Diff             @@
##             main    #3345       +/-   ##
===========================================
- Coverage   82.49%   35.19%   -47.31%     
===========================================
  Files         990      342      -648     
  Lines      127088    53164    -73924     
===========================================
- Hits       104846    18709    -86137     
- Misses      22242    34455    +12213     


@antimora (Collaborator) left a comment

Congrats on making it work! I know dealing with Serde is not easy.

I did a quick pass; here are my high-level comments. I will dive into the details later.

  1. We should hook save_item under SafetensorsFileRecorder.
  2. We should aim not to duplicate tensor data. The tensor view should be sufficient for safetensors to pull the data directly.
  3. It seems the serde serializer could be simplified and made more robust. I am still looking into the implementation, so I do not have concrete suggestions yet.
  4. We should verify serialization of tuples, enums and vecs of modules. It might be worth updating the safetensors tests with more complete test data.
  5. We should add a module adapter just like in SafetensorsFileRecorder, because our modules are not one-to-one. SafetensorsFileRecorder's default adapter is PyTorch.

@@ -26,6 +26,7 @@ safetensors = [
"thiserror",
"zip",
"candle-core",
"dep:safetensors",
Collaborator:

You don't need to use dep: for the newer Rust editions.

Comment on lines +33 to +41
let mut file = File::create("model.safetensors").unwrap();
file.write_all(&serialized).unwrap();
let record = SafetensorsFileRecorder::<FullPrecisionSettings>::default()
    .load(
        LoadArgs::new("model.safetensors".into()).with_adapter_type(AdapterType::NoAdapter),
        &device,
    )
    .expect("Should decode state successfully");
std::fs::remove_file("model.safetensors").unwrap();
Collaborator:

I recommend using NamedTempFile https://github.com/tracel-ai/burn/blob/main/crates/burn-dataset/src/dataset/sqlite.rs#L713. Otherwise you'll get collisions and leftover files.

        .serialize(&mut ser)
        .unwrap();
    safetensors::serialize(ser.into_map(), None)
}
Collaborator:

We should also enhance SafetensorsFileRecorder's save_item method, just like in
https://github.com/tracel-ai/burn/blob/main/crates/burn-core/src/record/file.rs#L194

This way the SafetensorsFileRecorder feature is symmetrical.

Comment on lines +10 to +17
pub struct SafetensorsTensorData {
    bytes: Vec<u8>,
    shape: Vec<usize>,
    dtype: DType,
}

impl safetensors::View for SafetensorsTensorData {
    fn dtype(&self) -> safetensors::Dtype {
Collaborator:

Can't we create a view on TensorData directly, without the need to copy into bytes? That way there won't be intermediate memory usage for large models: if we have an 8 GB model sitting on the GPU, we won't create a copy in RAM.

@jonboh (Contributor, author) commented on Jul 13, 2025

Yes, you are right. I was preoccupied with making it work and didn't try using just a reference to the original bytes. I'll modify it so that instead of Vec<u8> we point to the original data with a &[u8].
