Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ jobs:
- name: Set features variable and install dependencies (Ubuntu)
if: matrix.os == 'ubuntu-latest' && matrix.book == false
run: |
echo "ADDITIONAL_FEATURES_FLAGS=--features diffsl-llvm17 --features diffsl-cranelift --features suitesparse" >> $GITHUB_ENV
echo "ADDITIONAL_FEATURES_FLAGS=--features diffsl-ext --features diffsl-llvm17 --features diffsl-cranelift --features suitesparse" >> $GITHUB_ENV
sudo apt-get update
sudo apt-get install -y libsuitesparse-dev
- name: Install Trunk
Expand All @@ -134,11 +134,11 @@ jobs:
- name: Set features variable and install dependencies (macOS)
if: matrix.os == 'macos-13' || matrix.os == 'macos-latest'
run: |
echo "ADDITIONAL_FEATURES_FLAGS=--features diffsl-llvm17 --features diffsl-cranelift" >> $GITHUB_ENV
echo "ADDITIONAL_FEATURES_FLAGS=--features diffsl-ext --features diffsl-llvm17 --features diffsl-cranelift" >> $GITHUB_ENV
- name: Set features variable and install dependencies (Windows)
if: matrix.os == 'windows-latest'
run: |
echo "ADDITIONAL_FEATURES_FLAGS=--features diffsl-cranelift" >> $GITHUB_ENV
echo "ADDITIONAL_FEATURES_FLAGS=--features diffsl-ext --features diffsl-cranelift" >> $GITHUB_ENV
- name: Run tests - default features
if: matrix.tests == true
run: cargo test --verbose
Expand Down
2 changes: 2 additions & 0 deletions diffsol/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ cuda = ["dep:cudarc"]
sundials = ["suitesparse_sys", "bindgen", "cc"]
suitesparse = ["suitesparse_sys"]
diffsl = []
diffsl-ext = ["dep:diffsl", "diffsl"]
diffsl-ext-sens = ["diffsl-ext"]
diffsl-cranelift = ["diffsl/cranelift", "diffsl"]
diffsl-llvm = []
diffsl-llvm15 = ["diffsl/llvm15-0", "diffsl", "diffsl-llvm"]
Expand Down
106 changes: 66 additions & 40 deletions diffsol/src/ode_equations/diffsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use diffsl::{
Compiler,
};

#[cfg(feature = "diffsl-ext")]
use crate::ode_equations::external_linkage::ExtLinkModule;
use crate::{
error::DiffsolError, find_jacobian_non_zeros, find_matrix_non_zeros, find_sens_non_zeros,
jacobian::JacobianColoring, matrix::sparsity::MatrixSparsity,
Expand Down Expand Up @@ -40,37 +42,8 @@ pub struct DiffSlContext<M: Matrix<T = T>, CG: CodegenModule> {
ctx: M::C,
}

impl<M: Matrix<T = T>, CG: CodegenModuleCompile + CodegenModuleJit> DiffSlContext<M, CG> {
/// Create a new context for the ODE equations specified using the [DiffSL language](https://martinjrobins.github.io/diffsl/).
/// The input parameters are not initialized and must be set using the [OdeEquations::set_params] function before solving the ODE.
///
/// # Arguments
///
/// * `text` - The text of the ODE equations in the DiffSL language.
/// * `nthreads` - The number of threads to use for code generation (0 for automatic, 1 for single-threaded).
///
pub fn new(text: &str, nthreads: usize, ctx: M::C) -> Result<Self, DiffsolError> {
let mode = match nthreads {
0 => diffsl::execution::compiler::CompilerMode::MultiThreaded(None),
1 => diffsl::execution::compiler::CompilerMode::SingleThreaded,
_ => diffsl::execution::compiler::CompilerMode::MultiThreaded(Some(nthreads)),
};
let compiler = Compiler::from_discrete_str(text, mode)
.map_err(|e| DiffsolError::Other(e.to_string()))?;
let (nstates, _nparams, _nout, _ndata, _nroots, _has_mass) = compiler.get_dims();

let compiler = if nthreads == 0 {
let num_cpus = std::thread::available_parallelism().unwrap().get();
let nthreads = num_cpus.min(nstates / 1000).max(1);
Compiler::from_discrete_str(
text,
diffsl::execution::compiler::CompilerMode::MultiThreaded(Some(nthreads)),
)
.map_err(|e| DiffsolError::Other(e.to_string()))?
} else {
compiler
};

impl<M: Matrix<T = T>, CG: CodegenModule> DiffSlContext<M, CG> {
fn from_compiler(compiler: Compiler<CG>, ctx: M::C) -> Result<Self, DiffsolError> {
let (nstates, nparams, nout, _ndata, nroots, has_mass) = compiler.get_dims();

let has_root = nroots > 0;
Expand Down Expand Up @@ -104,6 +77,44 @@ impl<M: Matrix<T = T>, CG: CodegenModuleCompile + CodegenModuleJit> DiffSlContex
}
}

impl<M: Matrix<T = T>, CG: CodegenModuleCompile + CodegenModuleJit> DiffSlContext<M, CG> {
/// Create a new context for the ODE equations specified using the [DiffSL language](https://martinjrobins.github.io/diffsl/).
/// The input parameters are not initialized and must be set using the [OdeEquations::set_params] function before solving the ODE.
///
/// # Arguments
///
/// * `text` - The text of the ODE equations in the DiffSL language.
/// * `nthreads` - The number of threads to use for code generation (0 for automatic, 1 for single-threaded).
///
pub fn new(text: &str, nthreads: usize, ctx: M::C) -> Result<Self, DiffsolError> {
let mode = match nthreads {
0 => diffsl::execution::compiler::CompilerMode::MultiThreaded(None),
1 => diffsl::execution::compiler::CompilerMode::SingleThreaded,
_ => diffsl::execution::compiler::CompilerMode::MultiThreaded(Some(nthreads)),
};
let compiler = Compiler::from_discrete_str(text, mode)
.map_err(|e| DiffsolError::Other(e.to_string()))?;
Self::from_compiler(compiler, ctx)
}
}

#[cfg(feature = "diffsl-ext")]
impl<M: Matrix<T = T>> DiffSlContext<M, ExtLinkModule> {
pub fn from_external_linkage(nthreads: usize, ctx: M::C) -> Result<Self, DiffsolError> {
use crate::ode_equations::external_linkage::symbol_map;

let mode = match nthreads {
0 => diffsl::execution::compiler::CompilerMode::MultiThreaded(None),
1 => diffsl::execution::compiler::CompilerMode::SingleThreaded,
_ => diffsl::execution::compiler::CompilerMode::MultiThreaded(Some(nthreads)),
};
let symbol_map = symbol_map();
let compiler = Compiler::new(ExtLinkModule, symbol_map, mode)
.map_err(|e| DiffsolError::Other(e.to_string()))?;
Self::from_compiler(compiler, ctx)
}
}

impl<M: Matrix<T = T>, CG: CodegenModuleJit + CodegenModuleCompile> Default
for DiffSlContext<M, CG>
{
Expand Down Expand Up @@ -137,15 +148,7 @@ pub struct DiffSl<M: Matrix<T = T>, CG: CodegenModule> {
rhs_sens_adjoint_coloring: Option<JacobianColoring<M>>,
}

impl<M: MatrixHost<T = T>, CG: CodegenModuleJit + CodegenModuleCompile> DiffSl<M, CG> {
pub fn compile(
code: &str,
ctx: M::C,
include_sensitivities: bool,
) -> Result<Self, DiffsolError> {
let context = DiffSlContext::<M, CG>::new(code, 1, ctx)?;
Ok(Self::from_context(context, include_sensitivities))
}
impl<M: MatrixHost<T = T>, CG: CodegenModule> DiffSl<M, CG> {
pub fn from_context(context: DiffSlContext<M, CG>, include_sensitivities: bool) -> Self {
let mut ret = Self {
context,
Expand Down Expand Up @@ -233,6 +236,28 @@ impl<M: MatrixHost<T = T>, CG: CodegenModuleJit + CodegenModuleCompile> DiffSl<M
}
}

#[cfg(feature = "diffsl-ext")]
impl<M: MatrixHost<T = T>> DiffSl<M, ExtLinkModule> {
pub fn from_external_linkage(
ctx: M::C,
include_sensitivities: bool,
) -> Result<Self, DiffsolError> {
let context = DiffSlContext::<M, _>::from_external_linkage(1, ctx)?;
Ok(Self::from_context(context, include_sensitivities))
}
}

impl<M: MatrixHost<T = T>, CG: CodegenModuleJit + CodegenModuleCompile> DiffSl<M, CG> {
pub fn compile(
code: &str,
ctx: M::C,
include_sensitivities: bool,
) -> Result<Self, DiffsolError> {
let context = DiffSlContext::<M, CG>::new(code, 1, ctx)?;
Ok(Self::from_context(context, include_sensitivities))
}
}

pub struct DiffSlRoot<'a, M: Matrix<T = T>, CG: CodegenModule>(&'a DiffSl<M, CG>);
pub struct DiffSlOut<'a, M: Matrix<T = T>, CG: CodegenModule>(&'a DiffSl<M, CG>);
pub struct DiffSlRhs<'a, M: Matrix<T = T>, CG: CodegenModule>(&'a DiffSl<M, CG>);
Expand Down Expand Up @@ -796,6 +821,7 @@ mod tests {
diffsl_logistic_growth::<crate::FaerSparseMat<f64>, diffsl::LlvmModule>();
}

#[allow(dead_code)]
fn diffsl_logistic_growth<
M: Matrix<V: VectorHost + DefaultDenseMatrix, T = f64> + DefaultSolver,
CG: CodegenModuleJit + CodegenModuleCompile,
Expand Down
Loading
Loading