Skip to content

Commit fa59d75

Browse files
committed
Allow reading state prefab for init state.
1 parent b60b570 commit fa59d75

File tree

2 files changed

+24
-16
lines changed

2 files changed

+24
-16
lines changed

crates/ai00-core/src/lib.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ pub struct StateValue {
278278
}
279279

280280
#[derive(Debug, Default, Clone, Serialize, Deserialize, ToSchema)]
281-
pub struct StatePath {
281+
pub struct StateFile {
282282
pub name: String,
283283
pub id: StateId,
284284
#[salvo(schema(value_type = String))]
@@ -291,7 +291,7 @@ pub struct StatePath {
291291
pub enum InputState {
292292
Key(StateId),
293293
Value(StateValue),
294-
Path(StatePath),
294+
File(StateFile),
295295
}
296296

297297
impl Default for InputState {
@@ -305,12 +305,12 @@ impl InputState {
305305
match self {
306306
InputState::Key(id) => *id,
307307
InputState::Value(value) => value.id,
308-
InputState::Path(path) => path.id,
308+
InputState::File(file) => file.id,
309309
}
310310
}
311311
}
312312

313-
#[derive(Derivative, Clone)]
313+
#[derive(Derivative, Clone, Serialize, Deserialize)]
314314
#[derivative(Debug)]
315315
pub struct InitState {
316316
pub name: String,

crates/ai00-core/src/run.rs

+20-12
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::{
77
time::Duration,
88
};
99

10-
use anyhow::Result;
10+
use anyhow::{bail, Result};
1111
use derivative::Derivative;
1212
use flume::{Receiver, Sender, TryRecvError};
1313
use itertools::Itertools;
@@ -400,22 +400,30 @@ impl CoreRuntime {
400400
);
401401
Ok(id)
402402
}
403-
InputState::Path(path) => {
404-
let name = path.name.clone();
405-
let id = path.id;
403+
InputState::File(file) => {
404+
let name = file.name.clone();
405+
let id = file.id;
406406
let default = false;
407407

408-
let file = tokio::fs::File::open(&path.path).await?;
408+
let file = tokio::fs::File::open(&file.path).await?;
409409
let data = unsafe { Mmap::map(&file) }?;
410-
let model = SafeTensors::deserialize(&data)?;
411-
let data = load_model_state(&self.context, &self.info, model).await?;
412410

413-
let state = InitState {
414-
name,
415-
id,
416-
default,
417-
data,
411+
let st = SafeTensors::deserialize(&data);
412+
let prefab = cbor4ii::serde::from_slice::<InitState>(&data);
413+
let state = match (st, prefab) {
414+
(Ok(model), _) => {
415+
let data = load_model_state(&self.context, &self.info, model).await?;
416+
InitState {
417+
name,
418+
id,
419+
default,
420+
data,
421+
}
422+
}
423+
(_, Ok(state)) => state,
424+
_ => bail!("failed to load init state"),
418425
};
426+
419427
let mut caches = self.caches.lock().await;
420428
caches.backed.insert(
421429
id,

0 commit comments

Comments
 (0)