From 2e2a5c7fbac62cf143fd00450906effb81175115 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Thu, 23 Oct 2025 22:26:26 +0000 Subject: [PATCH 1/4] feat: allow specifying equations via extern C functions --- diffsol/Cargo.toml | 2 + diffsol/src/ode_equations/diffsl.rs | 114 +++++---- diffsol/src/ode_equations/external_linkage.rs | 238 ++++++++++++++++++ diffsol/src/ode_equations/mod.rs | 3 + diffsol/src/ode_solver/builder.rs | 19 ++ examples/external-linkage/Cargo.toml | 10 + examples/external-linkage/src/main.rs | 200 +++++++++++++++ 7 files changed, 540 insertions(+), 46 deletions(-) create mode 100644 diffsol/src/ode_equations/external_linkage.rs create mode 100644 examples/external-linkage/Cargo.toml create mode 100644 examples/external-linkage/src/main.rs diff --git a/diffsol/Cargo.toml b/diffsol/Cargo.toml index 495dd6b5..c6e16e04 100644 --- a/diffsol/Cargo.toml +++ b/diffsol/Cargo.toml @@ -19,6 +19,8 @@ cuda = ["dep:cudarc"] sundials = ["suitesparse_sys", "bindgen", "cc"] suitesparse = ["suitesparse_sys"] diffsl = [] +diffsl-ext = ["dep:diffsl", "diffsl"] +diffsl-ext-sens = ["diffsl-ext"] diffsl-cranelift = ["diffsl/cranelift", "diffsl"] diffsl-llvm = [] diffsl-llvm15 = ["diffsl/llvm15-0", "diffsl", "diffsl-llvm"] diff --git a/diffsol/src/ode_equations/diffsl.rs b/diffsol/src/ode_equations/diffsl.rs index d5d23811..5ef2f42e 100644 --- a/diffsol/src/ode_equations/diffsl.rs +++ b/diffsol/src/ode_equations/diffsl.rs @@ -8,12 +8,15 @@ use diffsl::{ }; use crate::{ - error::DiffsolError, find_jacobian_non_zeros, find_matrix_non_zeros, find_sens_non_zeros, - jacobian::JacobianColoring, matrix::sparsity::MatrixSparsity, - op::nonlinear_op::NonLinearOpJacobian, ConstantOp, ConstantOpSens, ConstantOpSensAdjoint, - LinearOp, LinearOpTranspose, Matrix, MatrixHost, NonLinearOp, NonLinearOpAdjoint, - NonLinearOpSens, NonLinearOpSensAdjoint, OdeEquations, OdeEquationsRef, Op, Scale, Vector, - VectorHost, + error::DiffsolError, + find_jacobian_non_zeros, find_matrix_non_zeros, find_sens_non_zeros, + jacobian::JacobianColoring, + matrix::sparsity::MatrixSparsity, + ode_equations::external_linkage::{symbol_map, ExtLinkModule}, + op::nonlinear_op::NonLinearOpJacobian, + ConstantOp, ConstantOpSens, ConstantOpSensAdjoint, LinearOp, LinearOpTranspose, Matrix, + MatrixHost, NonLinearOp, NonLinearOpAdjoint, NonLinearOpSens, NonLinearOpSensAdjoint, + OdeEquations, OdeEquationsRef, Op, Scale, Vector, VectorHost, }; pub type T = f64; @@ -40,37 +43,8 @@ pub struct DiffSlContext, CG: CodegenModule> { ctx: M::C, } -impl, CG: CodegenModuleCompile + CodegenModuleJit> DiffSlContext { - /// Create a new context for the ODE equations specified using the [DiffSL language](https://martinjrobins.github.io/diffsl/). - /// The input parameters are not initialized and must be set using the [OdeEquations::set_params] function before solving the ODE. - /// - /// # Arguments - /// - /// * `text` - The text of the ODE equations in the DiffSL language. - /// * `nthreads` - The number of threads to use for code generation (0 for automatic, 1 for single-threaded). - /// - pub fn new(text: &str, nthreads: usize, ctx: M::C) -> Result { - let mode = match nthreads { - 0 => diffsl::execution::compiler::CompilerMode::MultiThreaded(None), - 1 => diffsl::execution::compiler::CompilerMode::SingleThreaded, - _ => diffsl::execution::compiler::CompilerMode::MultiThreaded(Some(nthreads)), - }; - let compiler = Compiler::from_discrete_str(text, mode) - .map_err(|e| DiffsolError::Other(e.to_string()))?; - let (nstates, _nparams, _nout, _ndata, _nroots, _has_mass) = compiler.get_dims(); - - let compiler = if nthreads == 0 { - let num_cpus = std::thread::available_parallelism().unwrap().get(); - let nthreads = num_cpus.min(nstates / 1000).max(1); - Compiler::from_discrete_str( - text, - diffsl::execution::compiler::CompilerMode::MultiThreaded(Some(nthreads)), - ) - .map_err(|e| DiffsolError::Other(e.to_string()))? - } else { - compiler - }; - +impl, CG: CodegenModule> DiffSlContext { + fn from_compiler(compiler: Compiler, ctx: M::C) -> Result { let (nstates, nparams, nout, _ndata, nroots, has_mass) = compiler.get_dims(); let has_root = nroots > 0; @@ -104,6 +78,41 @@ impl, CG: CodegenModuleCompile + CodegenModuleJit> DiffSlContex } } +impl, CG: CodegenModuleCompile + CodegenModuleJit> DiffSlContext { + /// Create a new context for the ODE equations specified using the [DiffSL language](https://martinjrobins.github.io/diffsl/). + /// The input parameters are not initialized and must be set using the [OdeEquations::set_params] function before solving the ODE. + /// + /// # Arguments + /// + /// * `text` - The text of the ODE equations in the DiffSL language. + /// * `nthreads` - The number of threads to use for code generation (0 for automatic, 1 for single-threaded). + /// + pub fn new(text: &str, nthreads: usize, ctx: M::C) -> Result { + let mode = match nthreads { + 0 => diffsl::execution::compiler::CompilerMode::MultiThreaded(None), + 1 => diffsl::execution::compiler::CompilerMode::SingleThreaded, + _ => diffsl::execution::compiler::CompilerMode::MultiThreaded(Some(nthreads)), + }; + let compiler = Compiler::from_discrete_str(text, mode) + .map_err(|e| DiffsolError::Other(e.to_string()))?; + Self::from_compiler(compiler, ctx) + } +} + +impl> DiffSlContext { + pub fn from_external_linkage(nthreads: usize, ctx: M::C) -> Result { + let mode = match nthreads { + 0 => diffsl::execution::compiler::CompilerMode::MultiThreaded(None), + 1 => diffsl::execution::compiler::CompilerMode::SingleThreaded, + _ => diffsl::execution::compiler::CompilerMode::MultiThreaded(Some(nthreads)), + }; + let symbol_map = symbol_map(); + let compiler = Compiler::new(ExtLinkModule, symbol_map, mode) + .map_err(|e| DiffsolError::Other(e.to_string()))?; + Self::from_compiler(compiler, ctx) + } +} + impl, CG: CodegenModuleJit + CodegenModuleCompile> Default for DiffSlContext { @@ -137,15 +146,7 @@ pub struct DiffSl, CG: CodegenModule> { rhs_sens_adjoint_coloring: Option>, } -impl, CG: CodegenModuleJit + CodegenModuleCompile> DiffSl { - pub fn compile( - code: &str, - ctx: M::C, - include_sensitivities: bool, - ) -> Result { - let context = DiffSlContext::::new(code, 1, ctx)?; - Ok(Self::from_context(context, include_sensitivities)) - } +impl, CG: CodegenModule> DiffSl { pub fn from_context(context: DiffSlContext, include_sensitivities: bool) -> Self { let mut ret = Self { context, @@ -233,6 +234,27 @@ impl, CG: CodegenModuleJit + CodegenModuleCompile> DiffSl> DiffSl { + pub fn from_external_linkage( + ctx: M::C, + include_sensitivities: bool, + ) -> Result { + let context = DiffSlContext::::from_external_linkage(1, ctx)?; + Ok(Self::from_context(context, include_sensitivities)) + } +} + +impl, CG: CodegenModuleJit + CodegenModuleCompile> DiffSl { + pub fn compile( + code: &str, + ctx: M::C, + include_sensitivities: bool, + ) -> Result { + let context = DiffSlContext::::new(code, 1, ctx)?; + Ok(Self::from_context(context, include_sensitivities)) + } +} + pub struct DiffSlRoot<'a, M: Matrix, CG: CodegenModule>(&'a DiffSl); pub struct DiffSlOut<'a, M: Matrix, CG: CodegenModule>(&'a DiffSl); pub struct DiffSlRhs<'a, M: Matrix, CG: CodegenModule>(&'a DiffSl); diff --git a/diffsol/src/ode_equations/external_linkage.rs b/diffsol/src/ode_equations/external_linkage.rs new file mode 100644 index 00000000..dbb06a1f --- /dev/null +++ b/diffsol/src/ode_equations/external_linkage.rs @@ -0,0 +1,238 @@ +use diffsl::execution::module::CodegenModule; +use std::collections::HashMap; + +pub type RealType = f64; +pub type UIntType = u32; + +extern "C" { + pub fn set_constants(thread_id: UIntType, thread_dim: UIntType); + pub fn stop( + time: RealType, + u: *const RealType, + data: *mut RealType, + root: *mut RealType, + thread_id: UIntType, + thread_dim: UIntType, + ); + pub fn rhs( + time: RealType, + u: *const RealType, + data: *mut RealType, + rr: *mut RealType, + thread_id: UIntType, + thread_dim: UIntType, + ); + pub fn rhs_grad( + time: RealType, + u: *const RealType, + du: *const RealType, + data: *const RealType, + ddata: *mut RealType, + rr: *const RealType, + drr: *mut RealType, + thread_id: UIntType, + thread_dim: UIntType, + ); + + pub fn mass( + time: RealType, + u: *const RealType, + data: *mut RealType, + mv: *mut RealType, + thread_id: UIntType, + thread_dim: UIntType, + ); + pub fn set_u0(u: *mut RealType, data: *mut RealType, thread_id: UIntType, thread_dim: UIntType); + + pub fn set_u0_grad( + u: *const RealType, + du: *mut RealType, + data: *const RealType, + ddata: *mut RealType, + thread_id: UIntType, + thread_dim: UIntType, + ); + + pub fn calc_out( + time: RealType, + u: *const RealType, + data: *mut RealType, + out: *mut RealType, + thread_id: UIntType, + thread_dim: UIntType, + ); + pub fn calc_out_grad( + time: RealType, + u: *const RealType, + du: *const RealType, + data: *const RealType, + ddata: *mut RealType, + out: *const RealType, + dout: *mut RealType, + thread_id: UIntType, + thread_dim: UIntType, + ); + + pub fn get_dims( + states: *mut UIntType, + inputs: *mut UIntType, + outputs: *mut UIntType, + data: *mut UIntType, + stop: *mut UIntType, + has_mass: *mut UIntType, + ); + pub fn set_inputs(inputs: *const RealType, data: *mut RealType); + pub fn get_inputs(inputs: *mut RealType, data: *const RealType); + pub fn set_inputs_grad( + inputs: *const RealType, + dinputs: *const RealType, + data: *const RealType, + ddata: *mut RealType, + ); + + pub fn set_id(id: *mut RealType); + pub fn get_tensor( + data: *const RealType, + tensor_data: *mut *mut RealType, + tensor_size: *mut UIntType, + ); +} + +#[cfg(feature = "diffsl-ext-sens")] +extern "C" { + pub fn rhs_rgrad( + time: RealType, + u: *const RealType, + du: *mut RealType, + data: *const RealType, + ddata: *mut RealType, + rr: *const RealType, + drr: *mut RealType, + thread_id: UIntType, + thread_dim: UIntType, + ); + pub fn set_u0_rgrad( + u: *const RealType, + du: *mut RealType, + data: *const RealType, + ddata: *mut RealType, + thread_id: UIntType, + thread_dim: UIntType, + ); + pub fn calc_out_rgrad( + time: RealType, + u: *const RealType, + du: *mut RealType, + data: *const RealType, + ddata: *mut RealType, + out: *const RealType, + dout: *mut RealType, + thread_id: UIntType, + thread_dim: UIntType, + ); + pub fn set_inputs_rgrad( + inputs: *const RealType, + dinputs: *mut RealType, + data: *const RealType, + ddata: *mut RealType, + ); + + pub fn rhs_srgrad( + time: RealType, + u: *const RealType, + data: *const RealType, + ddata: *mut RealType, + rr: *const RealType, + drr: *mut RealType, + thread_id: UIntType, + thread_dim: UIntType, + ); + pub fn calc_out_srgrad( + time: RealType, + u: *const RealType, + data: *const RealType, + ddata: *mut RealType, + out: *const RealType, + dout: *mut RealType, + thread_id: UIntType, + thread_dim: UIntType, + ); +} + +#[cfg(feature = "diffsl-ext-sens")] +extern "C" { + pub fn rhs_sgrad( + time: RealType, + u: *const RealType, + data: *const RealType, + ddata: *mut RealType, + rr: *const RealType, + drr: *mut RealType, + thread_id: UIntType, + thread_dim: UIntType, + ); + pub fn set_u0_sgrad( + u: *const RealType, + du: *mut RealType, + data: *const RealType, + ddata: *mut RealType, + thread_id: UIntType, + thread_dim: UIntType, + ); + pub fn calc_out_sgrad( + time: RealType, + u: *const RealType, + data: *const RealType, + ddata: *mut RealType, + out: *const RealType, + dout: *mut RealType, + thread_id: UIntType, + thread_dim: UIntType, + ); +} + +pub struct ExtLinkModule; + +impl CodegenModule for ExtLinkModule {} + +pub(crate) fn symbol_map() -> HashMap { + let mut map = HashMap::new(); + map.insert("set_u0".to_string(), set_u0 as *const u8); + map.insert("rhs".to_string(), rhs as *const u8); + map.insert("mass".to_string(), mass as *const u8); + map.insert("calc_out".to_string(), calc_out as *const u8); + map.insert("calc_stop".to_string(), stop as *const u8); + map.insert("set_id".to_string(), set_id as *const u8); + map.insert("get_dims".to_string(), get_dims as *const u8); + map.insert("set_inputs".to_string(), set_inputs as *const u8); + map.insert("get_inputs".to_string(), get_inputs as *const u8); + map.insert("set_constants".to_string(), set_constants as *const u8); + + map.insert("set_u0_grad".to_string(), set_u0_grad as *const u8); + map.insert("rhs_grad".to_string(), rhs_grad as *const u8); + map.insert("calc_out_grad".to_string(), calc_out_grad as *const u8); + map.insert("set_inputs_grad".to_string(), set_inputs_grad as *const u8); + + #[cfg(feature = "diffsl-ext-sens")] + { + map.insert("set_u0_rgrad".to_string(), set_u0_rgrad as *const u8); + map.insert("rhs_rgrad".to_string(), rhs_rgrad as *const u8); + map.insert("calc_out_rgrad".to_string(), calc_out_rgrad as *const u8); + map.insert( + "set_inputs_rgrad".to_string(), + set_inputs_rgrad as *const u8, + ); + + map.insert("rhs_srgrad".to_string(), rhs_srgrad as *const u8); + map.insert("calc_out_srgrad".to_string(), calc_out_srgrad as *const u8); + } + + #[cfg(feature = "diffsl-ext-sens")] + { + map.insert("rhs_sgrad".to_string(), rhs_sgrad as *const u8); + map.insert("calc_out_sgrad".to_string(), calc_out_sgrad as *const u8); + map.insert("set_u0_sgrad".to_string(), set_u0_sgrad as *const u8); + } + + map +} diff --git a/diffsol/src/ode_equations/mod.rs b/diffsol/src/ode_equations/mod.rs index b8800fbb..1bbe2096 100644 --- a/diffsol/src/ode_equations/mod.rs +++ b/diffsol/src/ode_equations/mod.rs @@ -8,6 +8,9 @@ use serde::Serialize; pub mod adjoint_equations; #[cfg(feature = "diffsl")] pub mod diffsl; +#[cfg(feature = "diffsl-ext")] +pub mod external_linkage; + pub mod sens_equations; pub mod test_models; diff --git a/diffsol/src/ode_solver/builder.rs b/diffsol/src/ode_solver/builder.rs index eb9b5ee8..af878c77 100644 --- a/diffsol/src/ode_solver/builder.rs +++ b/diffsol/src/ode_solver/builder.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "diffsl-ext")] +use crate::ode_equations::external_linkage::ExtLinkModule; use crate::{ error::{DiffsolError, OdeSolverError}, matrix::dense_nalgebra_serial::NalgebraMat, @@ -996,6 +998,23 @@ where self.build_from_eqn(eqn) } + #[cfg(feature = "diffsl-ext")] + pub fn build_from_external_linkage( + mut self, + ) -> Result>, DiffsolError> + where + M: Matrix, + { + let include_sensitivities = M::is_sparse() && cfg!(feature = "diffsl-ext-sens"); + let eqn = crate::DiffSl::from_external_linkage(self.ctx.clone(), include_sensitivities)?; + // if the user hasn't set the parameters, resize them to match the number of parameters in the equations + let nparams = eqn.rhs().nparams(); + if self.p.len() != nparams && self.p.is_empty() { + self.p.resize(nparams, 0.0); + } + self.build_from_eqn(eqn) + } + /// Build an ODE problem from a set of equations pub fn build_from_eqn(self, mut eqn: Eqn) -> Result, DiffsolError> where diff --git a/examples/external-linkage/Cargo.toml b/examples/external-linkage/Cargo.toml new file mode 100644 index 00000000..c660495d --- /dev/null +++ b/examples/external-linkage/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "external-linkage" +version = "0.1.0" +edition.workspace = true +publish = false + +[dependencies] +diffsol = { path = "../../diffsol", features = ["diffsl-ext"] } +nalgebra = { workspace = true } +plotly = { workspace = true } \ No newline at end of file diff --git a/examples/external-linkage/src/main.rs b/examples/external-linkage/src/main.rs new file mode 100644 index 00000000..797a900c --- /dev/null +++ b/examples/external-linkage/src/main.rs @@ -0,0 +1,200 @@ +use diffsol::{ + ode_equations::external_linkage::{RealType, UIntType}, + DenseMatrix, OdeBuilder, OdeSolverMethod, +}; +type M = diffsol::NalgebraMat; +type LS = diffsol::NalgebraLU; + +#[no_mangle] +extern "C" fn stop( + _time: RealType, + _u: *const RealType, + _data: *mut RealType, + _root: *mut RealType, + _thread_id: UIntType, + _thread_dim: UIntType, +) { +} + +#[no_mangle] +extern "C" fn rhs( + _time: RealType, + u: *const RealType, + data: *mut RealType, + rr: *mut RealType, + _thread_id: UIntType, + _thread_dim: UIntType, +) { + let r = unsafe { *data.add(0) }; + let k = unsafe { *data.add(1) }; + let u = unsafe { *u.add(0) }; + let f = r * u * (1.0 - u / k); + unsafe { + *rr.add(0) = f; + } +} + +#[no_mangle] +extern "C" fn rhs_grad( + _time: RealType, + u: *const RealType, + du: *const RealType, + data: *const RealType, + _ddata: *mut RealType, + _rr: *const RealType, + drr: *mut RealType, + _thread_id: UIntType, + _thread_dim: UIntType, +) { + let r = unsafe { *data.add(0) }; + let k = unsafe { *data.add(1) }; + let u = unsafe { *u.add(0) }; + let du = unsafe { *du.add(0) }; + + let df_du = r * (1.0 - 2.0 * u / k); + unsafe { + *drr.add(0) = df_du * du; + } +} + +#[no_mangle] +extern "C" fn mass( + _time: RealType, + _u: *const RealType, + _data: *mut RealType, + _mv: *mut RealType, + _thread_id: UIntType, + _thread_dim: UIntType, +) { +} + +#[no_mangle] +extern "C" fn set_u0( + u: *mut RealType, + _data: *mut RealType, + _thread_id: UIntType, + _thread_dim: UIntType, +) { + unsafe { + *u.add(0) = 0.1; + } +} + +#[no_mangle] +extern "C" fn set_u0_grad( + _u: *const RealType, + du: *mut RealType, + _data: *const RealType, + _ddata: *mut RealType, + _thread_id: UIntType, + _thread_dim: UIntType, +) { + unsafe { + *du.add(0) = 0.0; + } +} + +#[no_mangle] +extern "C" fn calc_out( + _time: RealType, + _u: *const RealType, + _data: *mut RealType, + _out: *mut RealType, + _thread_id: UIntType, + _thread_dim: UIntType, +) { +} + +#[no_mangle] +extern "C" fn calc_out_grad( + _time: RealType, + _u: *const RealType, + _du: *const RealType, + _data: *const RealType, + _ddata: *mut RealType, + _out: *const RealType, + _dout: *mut RealType, + _thread_id: UIntType, + _thread_dim: UIntType, +) { +} + +#[no_mangle] +extern "C" fn get_dims( + states: *mut UIntType, + inputs: *mut UIntType, + outputs: *mut UIntType, + data: *mut UIntType, + stop: *mut UIntType, + has_mass: *mut UIntType, +) { + unsafe { + *states = 1; + *inputs = 2; + *outputs = 0; + *data = 2; + *stop = 0; + *has_mass = 0; + } +} +#[no_mangle] +extern "C" fn set_inputs(inputs: *const RealType, data: *mut RealType) { + unsafe { + *data.add(0) = *inputs.add(0); + *data.add(1) = *inputs.add(1); + } +} +#[no_mangle] +extern "C" fn get_inputs(inputs: *mut RealType, data: *const RealType) { + unsafe { + *inputs.add(0) = *data.add(0); + *inputs.add(1) = *data.add(1); + } +} +#[no_mangle] +extern "C" fn set_inputs_grad( + _inputs: *const RealType, + dinputs: *const RealType, + _data: *const RealType, + ddata: *mut RealType, +) { + unsafe { + *ddata.add(0) = *dinputs.add(0); + *ddata.add(1) = *dinputs.add(1); + } +} + +#[no_mangle] +extern "C" fn set_id(id: *mut RealType) { + unsafe { + *id.add(0) = 1.0; + } +} + +#[no_mangle] +extern "C" fn set_constants(_thread_id: UIntType, _thread_dim: UIntType) {} + +fn main() { + let r = 1.0; + let k = 10.0; + let y0 = 0.1; + let problem = OdeBuilder::::new() + .rtol(1e-6) + .p([r, k]) + .build_from_external_linkage() + .unwrap(); + let mut solver = problem.bdf::().unwrap(); + let t = 0.4; + let (ys, ts) = solver.solve(t).unwrap(); + for (i, t) in ts.iter().enumerate() { + let y = ys.column(i); + let expect_y = k / (1.0 + (k - y0) * (-r * t).exp() / y0); + assert!( + (y[0] - expect_y).abs() < 1e-6, + "at t={:.3}, got {}, expected {}", + t, + y[0], + expect_y + ); + } +} From d58352a6e8296c72cfff8472f72df5ede9ff389a Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Thu, 23 Oct 2025 22:51:08 +0000 Subject: [PATCH 2/4] feature fixes --- diffsol/src/ode_equations/diffsl.rs | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/diffsol/src/ode_equations/diffsl.rs b/diffsol/src/ode_equations/diffsl.rs index 5ef2f42e..ce026a5a 100644 --- a/diffsol/src/ode_equations/diffsl.rs +++ b/diffsol/src/ode_equations/diffsl.rs @@ -7,16 +7,15 @@ use diffsl::{ Compiler, }; +#[cfg(feature = "diffsl-ext")] +use crate::ode_equations::external_linkage::ExtLinkModule; use crate::{ - error::DiffsolError, - find_jacobian_non_zeros, find_matrix_non_zeros, find_sens_non_zeros, - jacobian::JacobianColoring, - matrix::sparsity::MatrixSparsity, - ode_equations::external_linkage::{symbol_map, ExtLinkModule}, - op::nonlinear_op::NonLinearOpJacobian, - ConstantOp, ConstantOpSens, ConstantOpSensAdjoint, LinearOp, LinearOpTranspose, Matrix, - MatrixHost, NonLinearOp, NonLinearOpAdjoint, NonLinearOpSens, NonLinearOpSensAdjoint, - OdeEquations, OdeEquationsRef, Op, Scale, Vector, VectorHost, + error::DiffsolError, find_jacobian_non_zeros, find_matrix_non_zeros, find_sens_non_zeros, + jacobian::JacobianColoring, matrix::sparsity::MatrixSparsity, + op::nonlinear_op::NonLinearOpJacobian, ConstantOp, ConstantOpSens, ConstantOpSensAdjoint, + LinearOp, LinearOpTranspose, Matrix, MatrixHost, NonLinearOp, NonLinearOpAdjoint, + NonLinearOpSens, NonLinearOpSensAdjoint, OdeEquations, OdeEquationsRef, Op, Scale, Vector, + VectorHost, }; pub type T = f64; @@ -99,8 +98,11 @@ impl, CG: CodegenModuleCompile + CodegenModuleJit> DiffSlContex } } +#[cfg(feature = "diffsl-ext")] impl> DiffSlContext { pub fn from_external_linkage(nthreads: usize, ctx: M::C) -> Result { + use crate::ode_equations::external_linkage::symbol_map; + let mode = match nthreads { 0 => diffsl::execution::compiler::CompilerMode::MultiThreaded(None), 1 => diffsl::execution::compiler::CompilerMode::SingleThreaded, @@ -234,6 +236,7 @@ impl, CG: CodegenModule> DiffSl { } } +#[cfg(feature = "diffsl-ext")] impl> DiffSl { pub fn from_external_linkage( ctx: M::C, From ee417476c201b095dc69a50bf644cc80480718dc Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Fri, 24 Oct 2025 16:47:50 +0000 Subject: [PATCH 3/4] add integration test, add diffsl-ext to CI --- .github/workflows/rust.yml | 6 +- diffsol/src/ode_equations/diffsl.rs | 1 + diffsol/tests/external_linkage.rs | 205 ++++++++++++++++++++++++++ examples/external-linkage/src/main.rs | 13 +- 4 files changed, 210 insertions(+), 15 deletions(-) create mode 100644 diffsol/tests/external_linkage.rs diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 132a90d2..ff2bf5ba 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -117,7 +117,7 @@ jobs: - name: Set features variable and install dependencies (Ubuntu) if: matrix.os == 'ubuntu-latest' && matrix.book == false run: | - echo "ADDITIONAL_FEATURES_FLAGS=--features diffsl-llvm17 --features diffsl-cranelift --features suitesparse" >> $GITHUB_ENV + echo "ADDITIONAL_FEATURES_FLAGS=--features diffsl-ext --features diffsl-llvm17 --features diffsl-cranelift --features suitesparse" >> $GITHUB_ENV sudo apt-get update sudo apt-get install -y libsuitesparse-dev - name: Install Trunk @@ -134,11 +134,11 @@ jobs: - name: Set features variable and install dependencies (macOS) if: matrix.os == 'macos-13' || matrix.os == 'macos-latest' run: | - echo "ADDITIONAL_FEATURES_FLAGS=--features diffsl-llvm17 --features diffsl-cranelift" >> $GITHUB_ENV + echo "ADDITIONAL_FEATURES_FLAGS=--features diffsl-ext --features diffsl-llvm17 --features diffsl-cranelift" >> $GITHUB_ENV - name: Set features variable and install dependencies (Windows) if: matrix.os == 'windows-latest' run: | - echo "ADDITIONAL_FEATURES_FLAGS=--features diffsl-cranelift" >> $GITHUB_ENV + echo "ADDITIONAL_FEATURES_FLAGS=--features diffsl-ext --features diffsl-cranelift" >> $GITHUB_ENV - name: Run tests - default features if: matrix.tests == true run: cargo test --verbose diff --git a/diffsol/src/ode_equations/diffsl.rs b/diffsol/src/ode_equations/diffsl.rs index ce026a5a..656bcb9c 100644 --- a/diffsol/src/ode_equations/diffsl.rs +++ b/diffsol/src/ode_equations/diffsl.rs @@ -821,6 +821,7 @@ mod tests { diffsl_logistic_growth::, diffsl::LlvmModule>(); } + #[allow(dead_code)] fn diffsl_logistic_growth< M: Matrix + DefaultSolver, CG: CodegenModuleJit + CodegenModuleCompile, diff --git a/diffsol/tests/external_linkage.rs b/diffsol/tests/external_linkage.rs new file mode 100644 index 00000000..d74af078 --- /dev/null +++ b/diffsol/tests/external_linkage.rs @@ -0,0 +1,205 @@ +#[cfg(not(feature = "diffsl-ext"))] +type RealType = f64; +#[cfg(not(feature = "diffsl-ext"))] +type UIntType = u32; +#[cfg(feature = "diffsl-ext")] +use diffsol::ode_equations::external_linkage::{RealType, UIntType}; + +#[no_mangle] +extern "C" fn stop( + _time: RealType, + _u: *const RealType, + _data: *mut RealType, + _root: *mut RealType, + _thread_id: UIntType, + _thread_dim: UIntType, +) { +} + +#[no_mangle] +extern "C" fn rhs( + _time: RealType, + u: *const RealType, + data: *mut RealType, + rr: *mut RealType, + _thread_id: UIntType, + _thread_dim: UIntType, +) { + let r = unsafe { *data.add(0) }; + let k = unsafe { *data.add(1) }; + let u = unsafe { *u.add(0) }; + let f = r * u * (1.0 - u / k); + unsafe { + *rr.add(0) = f; + } +} + +#[no_mangle] +extern "C" fn rhs_grad( + _time: RealType, + u: *const RealType, + du: *const RealType, + data: *const RealType, + _ddata: *mut RealType, + _rr: *const RealType, + drr: *mut RealType, + _thread_id: UIntType, + _thread_dim: UIntType, +) { + let r = unsafe { *data.add(0) }; + let k = unsafe { *data.add(1) }; + let u = unsafe { *u.add(0) }; + let du = unsafe { *du.add(0) }; + + let df_du = r * (1.0 - 2.0 * u / k); + unsafe { + *drr.add(0) = df_du * du; + } +} + +#[no_mangle] +extern "C" fn mass( + _time: RealType, + _u: *const RealType, + _data: *mut RealType, + _mv: *mut RealType, + _thread_id: UIntType, + _thread_dim: UIntType, +) { +} + +#[no_mangle] +extern "C" fn set_u0( + u: *mut RealType, + _data: *mut RealType, + _thread_id: UIntType, + _thread_dim: UIntType, +) { + unsafe { + *u.add(0) = 0.1; + } +} + +#[no_mangle] +extern "C" fn set_u0_grad( + _u: *const RealType, + du: *mut RealType, + _data: *const RealType, + _ddata: *mut RealType, + _thread_id: UIntType, + _thread_dim: UIntType, +) { + unsafe { + *du.add(0) = 0.0; + } +} + +#[no_mangle] +extern "C" fn calc_out( + _time: RealType, + _u: *const RealType, + _data: *mut RealType, + _out: *mut RealType, + _thread_id: UIntType, + _thread_dim: UIntType, +) { +} + +#[no_mangle] +extern "C" fn calc_out_grad( + _time: RealType, + _u: *const RealType, + _du: *const RealType, + _data: *const RealType, + _ddata: *mut RealType, + _out: *const RealType, + _dout: *mut RealType, + _thread_id: UIntType, + _thread_dim: UIntType, +) { +} + +#[no_mangle] +extern "C" fn get_dims( + states: *mut UIntType, + inputs: *mut UIntType, + outputs: *mut UIntType, + data: *mut UIntType, + stop: *mut UIntType, + has_mass: *mut UIntType, +) { + unsafe { + *states = 1; + *inputs = 2; + *outputs = 0; + *data = 2; + *stop = 0; + *has_mass = 0; + } +} +#[no_mangle] +extern "C" fn set_inputs(inputs: *const RealType, data: *mut RealType) { + unsafe { + *data.add(0) = *inputs.add(0); + *data.add(1) = *inputs.add(1); + } +} +#[no_mangle] +extern "C" fn get_inputs(inputs: *mut RealType, data: *const RealType) { + unsafe { + *inputs.add(0) = *data.add(0); + *inputs.add(1) = *data.add(1); + } +} +#[no_mangle] +extern "C" fn set_inputs_grad( + _inputs: *const RealType, + dinputs: *const RealType, + _data: *const RealType, + ddata: *mut RealType, +) { + unsafe { + *ddata.add(0) = *dinputs.add(0); + *ddata.add(1) = *dinputs.add(1); + } +} + +#[no_mangle] +extern "C" fn set_id(id: *mut RealType) { + unsafe { + *id.add(0) = 1.0; + } +} + +#[no_mangle] +extern "C" fn set_constants(_thread_id: UIntType, _thread_dim: UIntType) {} + +#[cfg(feature = "diffsl-ext")] +#[test] +fn logistic() { + use diffsol::{DenseMatrix, OdeBuilder, OdeSolverMethod}; + type M = diffsol::NalgebraMat; + type LS = diffsol::NalgebraLU; + let r = 1.0; + let k = 10.0; + let y0 = 0.1; + let problem = OdeBuilder::::new() + .rtol(1e-6) + .p([r, k]) + .build_from_external_linkage() + .unwrap(); + let mut solver = problem.bdf::().unwrap(); + let t = 0.4; + let (ys, ts) = solver.solve(t).unwrap(); + for (i, t) in ts.iter().enumerate() { + let y = ys.column(i); + let expect_y = k / (1.0 + (k - y0) * (-r * t).exp() / y0); + assert!( + (y[0] - expect_y).abs() < 1e-6, + "at t={:.3}, got {}, expected {}", + t, + y[0], + expect_y + ); + } +} diff --git a/examples/external-linkage/src/main.rs b/examples/external-linkage/src/main.rs index 797a900c..d65b48a4 100644 --- a/examples/external-linkage/src/main.rs +++ b/examples/external-linkage/src/main.rs @@ -185,16 +185,5 @@ fn main() { .unwrap(); let mut solver = problem.bdf::().unwrap(); let t = 0.4; - let (ys, ts) = solver.solve(t).unwrap(); - for (i, t) in ts.iter().enumerate() { - let y = ys.column(i); - let expect_y = k / (1.0 + (k - y0) * (-r * t).exp() / y0); - assert!( - (y[0] - expect_y).abs() < 1e-6, - "at t={:.3}, got {}, expected {}", - t, - y[0], - expect_y - ); - } + let _ = solver.solve(t).unwrap(); } From 1d97505d3011dcfbe6309adae1de6739b355f635 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Fri, 24 Oct 2025 16:57:32 +0000 Subject: [PATCH 4/4] fix example --- examples/external-linkage/src/main.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/external-linkage/src/main.rs b/examples/external-linkage/src/main.rs index d65b48a4..20aa01b5 100644 --- a/examples/external-linkage/src/main.rs +++ b/examples/external-linkage/src/main.rs @@ -1,6 +1,6 @@ use diffsol::{ ode_equations::external_linkage::{RealType, UIntType}, - DenseMatrix, OdeBuilder, OdeSolverMethod, + OdeBuilder, OdeSolverMethod, }; type M = diffsol::NalgebraMat; type LS = diffsol::NalgebraLU; @@ -177,7 +177,6 @@ extern "C" fn set_constants(_thread_id: UIntType, _thread_dim: UIntType) {} fn main() { let r = 1.0; let k = 10.0; - let y0 = 0.1; let problem = OdeBuilder::::new() .rtol(1e-6) .p([r, k])