Add safetensors export feature #3345
base: main
Conversation
I think I know the issue you are facing with the serializer going all the way down. It's okay to stop if you come across a …
My recommendation:
Codecov Report: All modified and coverable lines are covered by tests ✅

❌ Your project check has failed because the head coverage (35.19%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage.

```
@@            Coverage Diff             @@
##             main    #3345      +/-   ##
===========================================
- Coverage   82.49%   35.19%   -47.31%
===========================================
  Files         990      342      -648
  Lines      127088    53164    -73924
===========================================
- Hits       104846    18709    -86137
- Misses      22242    34455   +12213
```
Congrats on making it work! I know dealing with Serde is not easy.
I did a quick pass; here are my high-level comments. I will dive into the details later.
- We should hook save_item under SafetensorsFileRecorder.
- We should aim not to duplicate tensor data. A tensor view should be sufficient for safetensors to pull data from directly.
- It seems the serde serializer could be simplified and made more robust. I am still looking into the implementation, so I do not have concrete suggestions yet.
- We should verify serialization of tuples, enums, and vecs of modules. It might be worth updating the safetensors tests with more complete test data.
- We should add a module adapter, just like in SafetensorsFileRecorder, because our modules are not one-to-one. SafetensorsFileRecorder's default adapter is PyTorch.
```diff
@@ -26,6 +26,7 @@ safetensors = [
     "thiserror",
     "zip",
     "candle-core",
+    "dep:safetensors",
```
You don't need to use `dep:` on newer Rust toolchains.
```rust
let mut file = File::create("model.safetensors").unwrap();
file.write_all(&serialized).unwrap();

let record = SafetensorsFileRecorder::<FullPrecisionSettings>::default()
    .load(
        LoadArgs::new("model.safetensors".into()).with_adapter_type(AdapterType::NoAdapter),
        &device,
    )
    .expect("Should decode state successfully");

std::fs::remove_file("model.safetensors").unwrap();
```
I recommend using NamedTempFile, as in https://github.com/tracel-ai/burn/blob/main/crates/burn-dataset/src/dataset/sqlite.rs#L713. Otherwise you'll get collisions and leftover files.
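To illustrate the collision problem, here is a std-only sketch of a collision-resistant temp path. `unique_model_path` is a hypothetical helper, not burn API; in the PR itself the `tempfile` crate's NamedTempFile is the better fix, since it also deletes the file on drop.

```rust
use std::env;
use std::path::PathBuf;
use std::process;
use std::sync::atomic::{AtomicU64, Ordering};

// Per-process counter so two tests in the same process never collide.
static COUNTER: AtomicU64 = AtomicU64::new(0);

// Hypothetical helper: builds a path in the OS temp directory that is
// unique per process and per call, unlike a hard-coded
// "model.safetensors" in the working directory.
fn unique_model_path(stem: &str) -> PathBuf {
    let n = COUNTER.fetch_add(1, Ordering::Relaxed);
    env::temp_dir().join(format!("{stem}-{}-{n}.safetensors", process::id()))
}

fn main() {
    let a = unique_model_path("model");
    let b = unique_model_path("model");
    // Successive calls never collide, even within one process.
    assert_ne!(a, b);
    println!("{}", a.display());
}
```

This only avoids name collisions; cleanup on panic is what NamedTempFile adds on top.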
```rust
    .serialize(&mut ser)
    .unwrap();
safetensors::serialize(ser.into_map(), None)
}
```
We should also enhance SafetensorsFileRecorder's save_item method, just like in
https://github.com/tracel-ai/burn/blob/main/crates/burn-core/src/record/file.rs#L194
This way the SafetensorsFileRecorder feature is symmetrical.
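The symmetry being asked for can be sketched roughly as follows. This is a hypothetical, std-only illustration: `BytesFileRecorder` and the byte-level signatures are invented for the sketch, not burn's actual Recorder API.

```rust
use std::fs;
use std::io;
use std::path::Path;

// Hypothetical recorder sketch: whatever bytes save_item writes,
// load_item must be able to read back, so the recorder works in both
// directions rather than being load-only.
struct BytesFileRecorder;

impl BytesFileRecorder {
    fn save_item(&self, bytes: &[u8], path: &Path) -> io::Result<()> {
        fs::write(path, bytes)
    }

    fn load_item(&self, path: &Path) -> io::Result<Vec<u8>> {
        fs::read(path)
    }
}

fn main() -> io::Result<()> {
    let recorder = BytesFileRecorder;
    let path = std::env::temp_dir().join("roundtrip.safetensors");
    recorder.save_item(b"header+tensor bytes", &path)?;
    let back = recorder.load_item(&path)?;
    // A symmetrical recorder round-trips its own output.
    assert_eq!(back, b"header+tensor bytes");
    fs::remove_file(&path)?;
    Ok(())
}
```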
```rust
pub struct SafetensorsTensorData {
    bytes: Vec<u8>,
    shape: Vec<usize>,
    dtype: DType,
}

impl safetensors::View for SafetensorsTensorData {
    fn dtype(&self) -> safetensors::Dtype {
```
Can't we create a view on TensorData directly, without the need to save into bytes? That way there won't be intermediate memory usage for large models: if we have an 8 GB model sitting on the GPU, we won't create a copy in RAM.
Yes, you are right. I was preoccupied with making it work and didn't try using just a reference to the original bytes. I'll try modifying it so that instead of a Vec<u8>, we point to the original data using a &[u8].
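A rough sketch of that direction, with a local stand-in for the safetensors `View` trait so the snippet is self-contained (the real trait lives in the safetensors crate and also exposes a `Dtype`, omitted here; `SafetensorsTensorView` is an illustrative name):

```rust
use std::borrow::Cow;

// Local stand-in for safetensors::View (dtype omitted), defined only
// so this sketch compiles on its own.
trait View {
    fn shape(&self) -> &[usize];
    fn data(&self) -> Cow<'_, [u8]>;
    fn data_len(&self) -> usize;
}

// Borrowed tensor view: points at the original bytes with &[u8]
// instead of owning a Vec<u8>, so serializing a large model does not
// duplicate its tensor data in RAM.
struct SafetensorsTensorView<'a> {
    bytes: &'a [u8],
    shape: Vec<usize>,
}

impl View for SafetensorsTensorView<'_> {
    fn shape(&self) -> &[usize] {
        &self.shape
    }
    fn data(&self) -> Cow<'_, [u8]> {
        Cow::Borrowed(self.bytes) // no copy made here
    }
    fn data_len(&self) -> usize {
        self.bytes.len()
    }
}

fn main() {
    let original: Vec<u8> = vec![0u8; 16]; // e.g. four f32 values
    let view = SafetensorsTensorView { bytes: &original, shape: vec![2, 2] };
    assert_eq!(view.data_len(), 16);
    // The data is borrowed, never cloned.
    assert!(matches!(view.data(), Cow::Borrowed(_)));
}
```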
This is still a work in progress. I started out following antimora's indications in #3260 (comment).

I did not find a clean way to walk the Module with the Serializer and perform the serialization only with the Recorder. The problem is that once the Serializer is finished, I'd need to reconstruct each TensorData back from its bytes, dtype, and shape fields, since the serializer walks down to the basic types.

The alternative I found was to use the Serializer to get a mapping from ParamId to the tensor name, and a ModuleVisitor to link to the TensorData. I think this is a bit cleaner, but the API is pretty different from the rest of the serializations, so I'm not sure if it would be better to do it the other way, even if the logic ends up a bit more convoluted.

Any comments are appreciated :)