@@ -7,7 +7,7 @@ use std::{
7
7
time:: Duration ,
8
8
} ;
9
9
10
- use anyhow:: Result ;
10
+ use anyhow:: { bail , Result } ;
11
11
use derivative:: Derivative ;
12
12
use flume:: { Receiver , Sender , TryRecvError } ;
13
13
use itertools:: Itertools ;
@@ -400,22 +400,30 @@ impl CoreRuntime {
400
400
) ;
401
401
Ok ( id)
402
402
}
403
- InputState :: Path ( path ) => {
404
- let name = path . name . clone ( ) ;
405
- let id = path . id ;
403
+ InputState :: File ( file ) => {
404
+ let name = file . name . clone ( ) ;
405
+ let id = file . id ;
406
406
let default = false ;
407
407
408
- let file = tokio:: fs:: File :: open ( & path . path ) . await ?;
408
+ let file = tokio:: fs:: File :: open ( & file . path ) . await ?;
409
409
let data = unsafe { Mmap :: map ( & file) } ?;
410
- let model = SafeTensors :: deserialize ( & data) ?;
411
- let data = load_model_state ( & self . context , & self . info , model) . await ?;
412
410
413
- let state = InitState {
414
- name,
415
- id,
416
- default,
417
- data,
411
+ let st = SafeTensors :: deserialize ( & data) ;
412
+ let prefab = cbor4ii:: serde:: from_slice :: < InitState > ( & data) ;
413
+ let state = match ( st, prefab) {
414
+ ( Ok ( model) , _) => {
415
+ let data = load_model_state ( & self . context , & self . info , model) . await ?;
416
+ InitState {
417
+ name,
418
+ id,
419
+ default,
420
+ data,
421
+ }
422
+ }
423
+ ( _, Ok ( state) ) => state,
424
+ _ => bail ! ( "failed to load init state" ) ,
418
425
} ;
426
+
419
427
let mut caches = self . caches . lock ( ) . await ;
420
428
caches. backed . insert (
421
429
id,
0 commit comments