@@ -5,7 +5,6 @@ use once_cell::sync::Lazy;
55use reqwest:: Url ;
66use std:: hash:: Hasher ;
77use std:: sync:: Arc ;
8- use std:: sync:: Mutex ;
98use tokio_util:: compat:: FuturesAsyncWriteCompatExt ;
109use tracing:: debug;
1110use tracing:: instrument;
@@ -25,17 +24,16 @@ use ort::session::Session;
2524
2625use crate :: onnx:: ensure_onnx_env_init;
2726
28- static SESSIONS : Lazy < DashMap < String , Arc < Mutex < Session > > > > =
29- Lazy :: new ( DashMap :: new) ;
27+ static SESSIONS : Lazy < DashMap < String , Arc < Session > > > = Lazy :: new ( DashMap :: new) ;
3028
3129#[ derive( Debug ) ]
3230pub struct SessionWithId {
3331 pub ( crate ) id : String ,
34- pub ( crate ) session : Arc < Mutex < Session > > ,
32+ pub ( crate ) session : Arc < Session > ,
3533}
3634
37- impl From < ( String , Arc < Mutex < Session > > ) > for SessionWithId {
38- fn from ( value : ( String , Arc < Mutex < Session > > ) ) -> Self {
35+ impl From < ( String , Arc < Session > ) > for SessionWithId {
36+ fn from ( value : ( String , Arc < Session > ) ) -> Self {
3937 Self {
4038 id : value. 0 ,
4139 session : value. 1 ,
@@ -50,7 +48,7 @@ impl std::fmt::Display for SessionWithId {
5048}
5149
5250impl SessionWithId {
53- pub fn into_split ( self ) -> ( String , Arc < Mutex < Session > > ) {
51+ pub fn into_split ( self ) -> ( String , Arc < Session > ) {
5452 ( self . id , self . session )
5553 }
5654}
@@ -106,7 +104,7 @@ fn get_execution_providers() -> Vec<ExecutionProviderDispatch> {
106104 [ cpu] . to_vec ( )
107105}
108106
109- fn create_session ( model_bytes : & [ u8 ] ) -> Result < Arc < Mutex < Session > > , Error > {
107+ fn create_session ( model_bytes : & [ u8 ] ) -> Result < Arc < Session > , Error > {
110108 let session = {
111109 if let Some ( err) = ensure_onnx_env_init ( ) {
112110 return Err ( anyhow ! ( "failed to create onnx environment: {err}" ) ) ;
@@ -117,7 +115,14 @@ fn create_session(model_bytes: &[u8]) -> Result<Arc<Mutex<Session>>, Error> {
117115 . commit_from_memory ( model_bytes) ?
118116 } ;
119117
120- Ok ( Arc :: new ( Mutex :: new ( session) ) )
118+ Ok ( Arc :: new ( session) )
119+ }
120+
121+ #[ allow( mutable_transmutes) ]
122+ #[ allow( clippy:: mut_from_ref) ]
123+ pub ( crate ) unsafe fn as_mut_session ( session : & Arc < Session > ) -> & mut Session {
124+ // SAFETY: CPU EP https://github.com/pykeio/ort/issues/402#issuecomment-2949993914
125+ unsafe { std:: mem:: transmute :: < & Session , & mut Session > ( & session. clone ( ) ) }
121126}
122127
123128#[ instrument( level = "debug" , skip_all, fields( model_bytes = model_bytes. len( ) ) , err) ]
@@ -174,7 +179,7 @@ pub(crate) async fn load_session_from_url(
174179 Ok ( ( session_id, session) . into ( ) )
175180}
176181
177- pub ( crate ) async fn get_session ( id : & str ) -> Option < Arc < Mutex < Session > > > {
182+ pub ( crate ) async fn get_session ( id : & str ) -> Option < Arc < Session > > {
178183 SESSIONS . get ( id) . map ( |value| value. pair ( ) . 1 . clone ( ) )
179184}
180185
0 commit comments