@@ -143,81 +143,135 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
143143 }
144144
145145 fn iter_mut < ' b > ( & ' b mut self , axis : u8 ) -> Box < dyn Iterator < Item = & ' b mut T > + ' b > {
146+ assert ! (
147+ axis == 1 || axis == 0 ,
148+ "For two dimensional array `axis` should be either 0 or 1"
149+ ) ;
150+
146151 let column_major = self . column_major ;
147152 let stride = self . stride ;
148153 let nrows = self . nrows ;
149154 let ncols = self . ncols ;
150- let ptr = self . values . as_mut_ptr ( ) ;
151-
152- // Safety: for each (r, c) pair the offset is uniquely determined by the
153- // index formula below, so no two iterations alias the same memory location.
154- // We assert this in debug mode by verifying the traversal covers exactly
155- // nrows * ncols distinct offsets within [0, values.len()).
156- #[ cfg( debug_assertions) ]
157- {
158- let len = self . values . len ( ) ;
159- let mut seen = std:: collections:: HashSet :: new ( ) ;
160- match axis {
161- 0 => {
162- for r in 0 ..nrows {
163- for c in 0 ..ncols {
164- let off = if column_major {
165- r + c * stride
166- } else {
167- r * stride + c
168- } ;
169- assert ! (
170- off < len,
171- "iterator_mut: offset {off} out of bounds (len={len})"
172- ) ;
173- assert ! (
174- seen. insert( off) ,
175- "iterator_mut: aliasing detected at offset {off}"
176- ) ;
177- }
155+
156+ // Axis = 0: row-by-row (outer loop over rows, inner over cols)
157+ // Axis = 1: col-by-col (outer loop over cols, inner over rows)
158+ // Four cases: column-major (axis 0 or 1), row-major (axis 1 or 0)
159+
160+ // Collect all mutable references up-front using split_at_mut so
161+ // that the resulting iterator owns no borrow of "self.values"
162+
163+ match ( column_major, axis) {
164+ // Case B: column-major, col-by-col
165+ ( true , 1 ) => {
166+ let mut refs: Vec < & ' b mut T > = Vec :: with_capacity ( ncols * nrows) ;
167+ let mut remaining: & ' b mut [ T ] = self . values ;
168+ for _c in 0 ..ncols {
169+ let col_end = if _c == ncols - 1 {
170+ remaining. len ( )
171+ } else {
172+ stride
173+ } ;
174+ let ( col_slice, tail) = remaining. split_at_mut ( col_end) ;
175+ for elem in col_slice[ ..nrows] . iter_mut ( ) {
176+ refs. push ( elem) ;
178177 }
178+ remaining = tail;
179179 }
180- _ => {
181- for c in 0 ..ncols {
182- for r in 0 ..nrows {
183- let off = if column_major {
184- r + c * stride
185- } else {
186- r * stride + c
187- } ;
188- assert ! (
189- off < len,
190- "iterator_mut: offset {off} out of bounds (len={len})"
191- ) ;
192- assert ! (
193- seen. insert( off) ,
194- "iterator_mut: aliasing detected at offset {off}"
195- ) ;
180+ Box :: new ( refs. into_iter ( ) )
181+ }
182+
183+ // Case A: column-major, row-by-row
184+ ( true , _) => {
185+ let mut refs: Vec < & ' b mut T > = Vec :: with_capacity ( nrows * ncols) ;
186+
187+ let total = nrows * ncols;
188+
189+ let mut by_col: Vec < & ' b mut T > = Vec :: with_capacity ( total) ;
190+ {
191+ let mut remaining: & ' b mut [ T ] = self . values ;
192+ for _c in 0 ..ncols {
193+ let col_end = if _c == ncols - 1 {
194+ remaining. len ( )
195+ } else {
196+ stride
197+ } ;
198+ let ( col_slice, tail) = remaining. split_at_mut ( col_end) ;
199+ for elem in col_slice[ ..nrows] . iter_mut ( ) {
200+ by_col. push ( elem) ;
196201 }
202+ remaining = tail;
197203 }
198204 }
199- }
200- }
201205
202- match axis {
203- 0 => Box :: new ( ( 0 ..nrows) . flat_map ( move |r| {
204- ( 0 ..ncols) . map ( move |c| unsafe {
205- & mut * ptr. add ( if column_major {
206- r + c * stride
207- } else {
208- r * stride + c
206+ let mut indexed: Vec < ( usize , & ' b mut T ) > = by_col
207+ . into_iter ( )
208+ . enumerate ( )
209+ . map ( |( flat_col_idx, r) | {
210+ let c = flat_col_idx / nrows;
211+ let row = flat_col_idx % nrows;
212+ let out_idx = row * ncols + c;
213+ ( out_idx, r)
209214 } )
210- } )
211- } ) ) ,
212- _ => Box :: new ( ( 0 ..ncols) . flat_map ( move |c| {
213- ( 0 ..nrows) . map ( move |r| unsafe {
214- & mut * ptr. add ( if column_major {
215- r + c * stride
215+ . collect ( ) ;
216+ indexed. sort_unstable_by_key ( |( idx, _) | * idx) ;
217+ refs. extend ( indexed. into_iter ( ) . map ( |( _, r) | r) ) ;
218+ Box :: new ( refs. into_iter ( ) )
219+ }
220+
221+ // Case C: row-major, row-by-row
222+ ( false , 0 ) => {
223+ let mut refs: Vec < & ' b mut T > = Vec :: with_capacity ( nrows * ncols) ;
224+ let mut remaining: & ' b mut [ T ] = self . values ;
225+ for _r in 0 ..nrows {
226+ let row_end = if _r == nrows - 1 {
227+ remaining. len ( )
216228 } else {
217- r * stride + c
229+ stride
230+ } ;
231+ let ( row_slice, tail) = remaining. split_at_mut ( row_end) ;
232+ for elem in row_slice[ ..ncols] . iter_mut ( ) {
233+ refs. push ( elem) ;
234+ }
235+ remaining = tail;
236+ }
237+ Box :: new ( refs. into_iter ( ) )
238+ }
239+
240+ // Case D: row-major, col-by-col
241+ ( false , _) => {
242+ let total = nrows * ncols;
243+ let mut by_row: Vec < & ' b mut T > = Vec :: with_capacity ( total) ;
244+ {
245+ let mut remaining: & ' b mut [ T ] = self . values ;
246+ for _r in 0 ..nrows {
247+ let row_end = if _r == nrows - 1 {
248+ remaining. len ( )
249+ } else {
250+ stride
251+ } ;
252+ let ( row_slice, tail) = remaining. split_at_mut ( row_end) ;
253+ for elem in row_slice[ ..ncols] . iter_mut ( ) {
254+ by_row. push ( elem) ;
255+ }
256+ remaining = tail;
257+ }
258+ }
259+
260+ let mut indexed: Vec < ( usize , & ' b mut T ) > = by_row
261+ . into_iter ( )
262+ . enumerate ( )
263+ . map ( |( flat_row_idx, r) | {
264+ let row = flat_row_idx / ncols;
265+ let col = flat_row_idx % ncols;
266+ let out_idx = col * nrows + row;
267+ ( out_idx, r)
218268 } )
219- } )
220- } ) ) ,
269+ . collect ( ) ;
270+ indexed. sort_unstable_by_key ( |( idx, _) | * idx) ;
271+ let mut refs: Vec < & ' b mut T > = Vec :: with_capacity ( total) ;
272+ refs. extend ( indexed. into_iter ( ) . map ( |( _, r) | r) ) ;
273+ Box :: new ( refs. into_iter ( ) )
274+ }
221275 }
222276 }
223277}
@@ -502,49 +556,84 @@ impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMat
502556 }
503557
504558 fn iterator_mut < ' b > ( & ' b mut self , axis : u8 ) -> Box < dyn Iterator < Item = & ' b mut T > + ' b > {
505- let ptr = self . values . as_mut_ptr ( ) ;
559+ assert ! (
560+ axis == 1 || axis == 0 ,
561+ "For two dimensional array `axis` should be either 0 or 1"
562+ ) ;
563+
506564 let column_major = self . column_major ;
507565 let ( nrows, ncols) = self . shape ( ) ;
508566
509- #[ cfg( debug_assertions) ]
510- {
511- let len = self . values . len ( ) ;
512- let mut seen = std:: collections:: HashSet :: new ( ) ;
513- for r in 0 ..nrows {
514- for c in 0 ..ncols {
515- let off = if column_major {
516- r + c * nrows
517- } else {
518- r * ncols + c
519- } ;
520- assert ! (
521- off < len,
522- "iterator_mut: offset {off} out of bounds (len={len})"
523- ) ;
524- assert ! ( seen. insert( off) , "iterator_mut: aliasing at offset {off}" ) ;
525- }
567+ match ( column_major, axis) {
568+ // Case B: column-major, col-by-col
569+ ( true , 1 ) => {
570+ let refs: Vec < & ' b mut T > = self
571+ . values
572+ . chunks_mut ( nrows)
573+ . flat_map ( |col| col. iter_mut ( ) )
574+ . collect ( ) ;
575+ Box :: new ( refs. into_iter ( ) )
526576 }
527- }
528577
529- match axis {
530- 0 => Box :: new ( ( 0 ..nrows) . flat_map ( move |r| {
531- ( 0 ..ncols) . map ( move |c| unsafe {
532- & mut * ptr. add ( if column_major {
533- r + c * nrows
534- } else {
535- r * ncols + c
578+ // Case A: column-major, row-by-row
579+ ( true , _) => {
580+ let total = nrows * ncols;
581+ let by_col: Vec < & ' b mut T > = self
582+ . values
583+ . chunks_mut ( nrows)
584+ . flat_map ( |col| col. iter_mut ( ) )
585+ . collect ( ) ;
586+
587+ let mut indexed: Vec < ( usize , & ' b mut T ) > = by_col
588+ . into_iter ( )
589+ . enumerate ( )
590+ . map ( |( flat_col_idx, elem) | {
591+ let c = flat_col_idx / nrows;
592+ let r = flat_col_idx % nrows;
593+ ( r * ncols + c, elem)
536594 } )
537- } )
538- } ) ) ,
539- _ => Box :: new ( ( 0 ..ncols) . flat_map ( move |c| {
540- ( 0 ..nrows) . map ( move |r| unsafe {
541- & mut * ptr. add ( if column_major {
542- r + c * nrows
543- } else {
544- r * ncols + c
595+ . collect ( ) ;
596+ indexed. sort_unstable_by_key ( |( idx, _) | * idx) ;
597+
598+ let mut refs: Vec < & ' b mut T > = Vec :: with_capacity ( total) ;
599+ refs. extend ( indexed. into_iter ( ) . map ( |( _, e) | e) ) ;
600+ Box :: new ( refs. into_iter ( ) )
601+ }
602+
603+ // Case C: row-major, row-by-row
604+ ( false , 0 ) => {
605+ let refs: Vec < & ' b mut T > = self
606+ . values
607+ . chunks_mut ( ncols)
608+ . flat_map ( |row| row. iter_mut ( ) )
609+ . collect ( ) ;
610+ Box :: new ( refs. into_iter ( ) )
611+ }
612+
613+ // Case D: row-major, col-by-col
614+ ( false , _) => {
615+ let total = nrows * ncols;
616+ let by_row: Vec < & ' b mut T > = self
617+ . values
618+ . chunks_mut ( ncols)
619+ . flat_map ( |row| row. iter_mut ( ) )
620+ . collect ( ) ;
621+
622+ let mut indexed: Vec < ( usize , & ' b mut T ) > = by_row
623+ . into_iter ( )
624+ . enumerate ( )
625+ . map ( |( flat_row_idx, elem) | {
626+ let r = flat_row_idx / ncols;
627+ let c = flat_row_idx % ncols;
628+ ( c * nrows + r, elem)
545629 } )
546- } )
547- } ) ) ,
630+ . collect ( ) ;
631+ indexed. sort_unstable_by_key ( |( idx, _) | * idx) ;
632+
633+ let mut refs: Vec < & ' b mut T > = Vec :: with_capacity ( total) ;
634+ refs. extend ( indexed. into_iter ( ) . map ( |( _, e) | e) ) ;
635+ Box :: new ( refs. into_iter ( ) )
636+ }
548637 }
549638 }
550639}
@@ -910,7 +999,9 @@ mod tests {
910999 assert_eq ! ( vec![ "1" , "4" , "7" , "2" , "5" , "8" , "3" , "6" , "9" ] , x. values) ;
9111000 x. iterator_mut ( 0 ) . for_each ( |v| * v = "str" ) ;
9121001 assert_eq ! (
913- vec![ "str" , "str" , "str" , "str" , "str" , "str" , "str" , "str" , "str" ] ,
1002+ vec![
1003+ "str" , "str" , "str" , "str" , "str" , "str" , "str" , "str" , "str"
1004+ ] ,
9141005 x. values
9151006 ) ;
9161007 }
0 commit comments