Skip to content

Commit 64b90c5

Browse files
committed
Replaced unsafe ptr logic with chained split_at_mut in DenseMatrix and DenseMatrixMutView
1 parent 9dd5411 commit 64b90c5

3 files changed

Lines changed: 194 additions & 100 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

7+
## [0.5.1] - 2026-05-20
8+
- Replaced `unsafe` pointer arithmetic in `DenseMatrix` / `DenseMatrixMutView` mutable iterators with a safe, chained `split_at_mut` implementation to ensure memory safety without performance loss.
9+
710
## [0.4.8] - 2025-11-29
811
- WARNING: Breaking changes!
912
- `LassoParameters` and `LassoSearchParameters` have a new field `fit_intercept`. When it is set to false, the `beta_0` term in the formula will be forced to zero, and `intercept` field in `Lasso` will be set to `None`.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
name = "smartcore"
33
description = "Machine Learning in Rust."
44
homepage = "https://smartcorelib.org"
5-
version = "0.5.0"
5+
version = "0.5.1"
66
authors = ["smartcore Developers"]
77
edition = "2021"
88
license = "Apache-2.0"

src/linalg/basic/matrix.rs

Lines changed: 190 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -143,81 +143,135 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
143143
}
144144

145145
fn iter_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
146+
assert!(
147+
axis == 1 || axis == 0,
148+
"For two dimensional array `axis` should be either 0 or 1"
149+
);
150+
146151
let column_major = self.column_major;
147152
let stride = self.stride;
148153
let nrows = self.nrows;
149154
let ncols = self.ncols;
150-
let ptr = self.values.as_mut_ptr();
151-
152-
// Safety: for each (r, c) pair the offset is uniquely determined by the
153-
// index formula below, so no two iterations alias the same memory location.
154-
// We assert this in debug mode by verifying the traversal covers exactly
155-
// nrows * ncols distinct offsets within [0, values.len()).
156-
#[cfg(debug_assertions)]
157-
{
158-
let len = self.values.len();
159-
let mut seen = std::collections::HashSet::new();
160-
match axis {
161-
0 => {
162-
for r in 0..nrows {
163-
for c in 0..ncols {
164-
let off = if column_major {
165-
r + c * stride
166-
} else {
167-
r * stride + c
168-
};
169-
assert!(
170-
off < len,
171-
"iterator_mut: offset {off} out of bounds (len={len})"
172-
);
173-
assert!(
174-
seen.insert(off),
175-
"iterator_mut: aliasing detected at offset {off}"
176-
);
177-
}
155+
156+
// Axis = 0: row-by-row (outer loop over rows, inner over cols)
157+
// Axis = 1: col-by-col (outer loop over cols, inner over rows)
158+
// Four cases: column-major (axis 0 or 1), row-major (axis 1 or 0)
159+
160+
// Collect all mutable references up-front using split_at_mut so
161+
// that the resulting iterator owns no borrow of "self.values"
162+
163+
match (column_major, axis) {
164+
// Case B: column-major, col-by-col
165+
(true, 1) => {
166+
let mut refs: Vec<&'b mut T> = Vec::with_capacity(ncols * nrows);
167+
let mut remaining: &'b mut [T] = self.values;
168+
for _c in 0..ncols {
169+
let col_end = if _c == ncols - 1 {
170+
remaining.len()
171+
} else {
172+
stride
173+
};
174+
let (col_slice, tail) = remaining.split_at_mut(col_end);
175+
for elem in col_slice[..nrows].iter_mut() {
176+
refs.push(elem);
178177
}
178+
remaining = tail;
179179
}
180-
_ => {
181-
for c in 0..ncols {
182-
for r in 0..nrows {
183-
let off = if column_major {
184-
r + c * stride
185-
} else {
186-
r * stride + c
187-
};
188-
assert!(
189-
off < len,
190-
"iterator_mut: offset {off} out of bounds (len={len})"
191-
);
192-
assert!(
193-
seen.insert(off),
194-
"iterator_mut: aliasing detected at offset {off}"
195-
);
180+
Box::new(refs.into_iter())
181+
}
182+
183+
// Case A: column-major, row-by-row
184+
(true, _) => {
185+
let mut refs: Vec<&'b mut T> = Vec::with_capacity(nrows * ncols);
186+
187+
let total = nrows * ncols;
188+
189+
let mut by_col: Vec<&'b mut T> = Vec::with_capacity(total);
190+
{
191+
let mut remaining: &'b mut [T] = self.values;
192+
for _c in 0..ncols {
193+
let col_end = if _c == ncols - 1 {
194+
remaining.len()
195+
} else {
196+
stride
197+
};
198+
let (col_slice, tail) = remaining.split_at_mut(col_end);
199+
for elem in col_slice[..nrows].iter_mut() {
200+
by_col.push(elem);
196201
}
202+
remaining = tail;
197203
}
198204
}
199-
}
200-
}
201205

202-
match axis {
203-
0 => Box::new((0..nrows).flat_map(move |r| {
204-
(0..ncols).map(move |c| unsafe {
205-
&mut *ptr.add(if column_major {
206-
r + c * stride
207-
} else {
208-
r * stride + c
206+
let mut indexed: Vec<(usize, &'b mut T)> = by_col
207+
.into_iter()
208+
.enumerate()
209+
.map(|(flat_col_idx, r)| {
210+
let c = flat_col_idx / nrows;
211+
let row = flat_col_idx % nrows;
212+
let out_idx = row * ncols + c;
213+
(out_idx, r)
209214
})
210-
})
211-
})),
212-
_ => Box::new((0..ncols).flat_map(move |c| {
213-
(0..nrows).map(move |r| unsafe {
214-
&mut *ptr.add(if column_major {
215-
r + c * stride
215+
.collect();
216+
indexed.sort_unstable_by_key(|(idx, _)| *idx);
217+
refs.extend(indexed.into_iter().map(|(_, r)| r));
218+
Box::new(refs.into_iter())
219+
}
220+
221+
// Case C: row-major, row-by-row
222+
(false, 0) => {
223+
let mut refs: Vec<&'b mut T> = Vec::with_capacity(nrows * ncols);
224+
let mut remaining: &'b mut [T] = self.values;
225+
for _r in 0..nrows {
226+
let row_end = if _r == nrows - 1 {
227+
remaining.len()
216228
} else {
217-
r * stride + c
229+
stride
230+
};
231+
let (row_slice, tail) = remaining.split_at_mut(row_end);
232+
for elem in row_slice[..ncols].iter_mut() {
233+
refs.push(elem);
234+
}
235+
remaining = tail;
236+
}
237+
Box::new(refs.into_iter())
238+
}
239+
240+
// Case D: row-major, col-by-col
241+
(false, _) => {
242+
let total = nrows * ncols;
243+
let mut by_row: Vec<&'b mut T> = Vec::with_capacity(total);
244+
{
245+
let mut remaining: &'b mut [T] = self.values;
246+
for _r in 0..nrows {
247+
let row_end = if _r == nrows - 1 {
248+
remaining.len()
249+
} else {
250+
stride
251+
};
252+
let (row_slice, tail) = remaining.split_at_mut(row_end);
253+
for elem in row_slice[..ncols].iter_mut() {
254+
by_row.push(elem);
255+
}
256+
remaining = tail;
257+
}
258+
}
259+
260+
let mut indexed: Vec<(usize, &'b mut T)> = by_row
261+
.into_iter()
262+
.enumerate()
263+
.map(|(flat_row_idx, r)| {
264+
let row = flat_row_idx / ncols;
265+
let col = flat_row_idx % ncols;
266+
let out_idx = col * nrows + row;
267+
(out_idx, r)
218268
})
219-
})
220-
})),
269+
.collect();
270+
indexed.sort_unstable_by_key(|(idx, _)| *idx);
271+
let mut refs: Vec<&'b mut T> = Vec::with_capacity(total);
272+
refs.extend(indexed.into_iter().map(|(_, r)| r));
273+
Box::new(refs.into_iter())
274+
}
221275
}
222276
}
223277
}
@@ -502,49 +556,84 @@ impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMat
502556
}
503557

504558
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
505-
let ptr = self.values.as_mut_ptr();
559+
assert!(
560+
axis == 1 || axis == 0,
561+
"For two dimensional array `axis` should be either 0 or 1"
562+
);
563+
506564
let column_major = self.column_major;
507565
let (nrows, ncols) = self.shape();
508566

509-
#[cfg(debug_assertions)]
510-
{
511-
let len = self.values.len();
512-
let mut seen = std::collections::HashSet::new();
513-
for r in 0..nrows {
514-
for c in 0..ncols {
515-
let off = if column_major {
516-
r + c * nrows
517-
} else {
518-
r * ncols + c
519-
};
520-
assert!(
521-
off < len,
522-
"iterator_mut: offset {off} out of bounds (len={len})"
523-
);
524-
assert!(seen.insert(off), "iterator_mut: aliasing at offset {off}");
525-
}
567+
match (column_major, axis) {
568+
// Case B: column-major, col-by-col
569+
(true, 1) => {
570+
let refs: Vec<&'b mut T> = self
571+
.values
572+
.chunks_mut(nrows)
573+
.flat_map(|col| col.iter_mut())
574+
.collect();
575+
Box::new(refs.into_iter())
526576
}
527-
}
528577

529-
match axis {
530-
0 => Box::new((0..nrows).flat_map(move |r| {
531-
(0..ncols).map(move |c| unsafe {
532-
&mut *ptr.add(if column_major {
533-
r + c * nrows
534-
} else {
535-
r * ncols + c
578+
// Case A: column-major, row-by-row
579+
(true, _) => {
580+
let total = nrows * ncols;
581+
let by_col: Vec<&'b mut T> = self
582+
.values
583+
.chunks_mut(nrows)
584+
.flat_map(|col| col.iter_mut())
585+
.collect();
586+
587+
let mut indexed: Vec<(usize, &'b mut T)> = by_col
588+
.into_iter()
589+
.enumerate()
590+
.map(|(flat_col_idx, elem)| {
591+
let c = flat_col_idx / nrows;
592+
let r = flat_col_idx % nrows;
593+
(r * ncols + c, elem)
536594
})
537-
})
538-
})),
539-
_ => Box::new((0..ncols).flat_map(move |c| {
540-
(0..nrows).map(move |r| unsafe {
541-
&mut *ptr.add(if column_major {
542-
r + c * nrows
543-
} else {
544-
r * ncols + c
595+
.collect();
596+
indexed.sort_unstable_by_key(|(idx, _)| *idx);
597+
598+
let mut refs: Vec<&'b mut T> = Vec::with_capacity(total);
599+
refs.extend(indexed.into_iter().map(|(_, e)| e));
600+
Box::new(refs.into_iter())
601+
}
602+
603+
// Case C: row-major, row-by-row
604+
(false, 0) => {
605+
let refs: Vec<&'b mut T> = self
606+
.values
607+
.chunks_mut(ncols)
608+
.flat_map(|row| row.iter_mut())
609+
.collect();
610+
Box::new(refs.into_iter())
611+
}
612+
613+
// Case D: row-major, col-by-col
614+
(false, _) => {
615+
let total = nrows * ncols;
616+
let by_row: Vec<&'b mut T> = self
617+
.values
618+
.chunks_mut(ncols)
619+
.flat_map(|row| row.iter_mut())
620+
.collect();
621+
622+
let mut indexed: Vec<(usize, &'b mut T)> = by_row
623+
.into_iter()
624+
.enumerate()
625+
.map(|(flat_row_idx, elem)| {
626+
let r = flat_row_idx / ncols;
627+
let c = flat_row_idx % ncols;
628+
(c * nrows + r, elem)
545629
})
546-
})
547-
})),
630+
.collect();
631+
indexed.sort_unstable_by_key(|(idx, _)| *idx);
632+
633+
let mut refs: Vec<&'b mut T> = Vec::with_capacity(total);
634+
refs.extend(indexed.into_iter().map(|(_, e)| e));
635+
Box::new(refs.into_iter())
636+
}
548637
}
549638
}
550639
}
@@ -910,7 +999,9 @@ mod tests {
910999
assert_eq!(vec!["1", "4", "7", "2", "5", "8", "3", "6", "9"], x.values);
9111000
x.iterator_mut(0).for_each(|v| *v = "str");
9121001
assert_eq!(
913-
vec!["str", "str", "str", "str", "str", "str", "str", "str", "str"],
1002+
vec![
1003+
"str", "str", "str", "str", "str", "str", "str", "str", "str"
1004+
],
9141005
x.values
9151006
);
9161007
}

0 commit comments

Comments
 (0)