From ac9cdbd4481b6385c7c6bde2134a96164d52c941 Mon Sep 17 00:00:00 2001 From: Philip Fabianek Date: Wed, 19 Feb 2025 10:58:29 +0100 Subject: [PATCH] Refactor From implementations by using macros, add tests (#2762) --- candle-core/src/shape.rs | 63 ++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 34 deletions(-) diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index ca05d216a5..e6fcc05a73 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -43,43 +43,22 @@ impl From for Shape { } } -impl From<(usize,)> for Shape { - fn from(d1: (usize,)) -> Self { - Self(vec![d1.0]) - } -} - -impl From<(usize, usize)> for Shape { - fn from(d12: (usize, usize)) -> Self { - Self(vec![d12.0, d12.1]) - } -} - -impl From<(usize, usize, usize)> for Shape { - fn from(d123: (usize, usize, usize)) -> Self { - Self(vec![d123.0, d123.1, d123.2]) - } -} - -impl From<(usize, usize, usize, usize)> for Shape { - fn from(d1234: (usize, usize, usize, usize)) -> Self { - Self(vec![d1234.0, d1234.1, d1234.2, d1234.3]) - } -} - -impl From<(usize, usize, usize, usize, usize)> for Shape { - fn from(d12345: (usize, usize, usize, usize, usize)) -> Self { - Self(vec![d12345.0, d12345.1, d12345.2, d12345.3, d12345.4]) +macro_rules! impl_from_tuple { + ($tuple:ty, $($index:tt),+) => { + impl From<$tuple> for Shape { + fn from(d: $tuple) -> Self { + Self(vec![$(d.$index,)+]) + } + } } } -impl From<(usize, usize, usize, usize, usize, usize)> for Shape { - fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self { - Self(vec![ - d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5, - ]) - } -} +impl_from_tuple!((usize,), 0); +impl_from_tuple!((usize, usize), 0, 1); +impl_from_tuple!((usize, usize, usize), 0, 1, 2); +impl_from_tuple!((usize, usize, usize, usize), 0, 1, 2, 3); +impl_from_tuple!((usize, usize, usize, usize, usize), 0, 1, 2, 3, 4); +impl_from_tuple!((usize, usize, usize, usize, usize, usize), 0, 1, 2, 3, 4, 5); impl From> for Shape { fn from(dims: Vec) -> Self { @@ -636,4 +615,20 @@ mod tests { let shape = Shape::from((299, 792, 458)); assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]); } + + #[test] + fn test_from_tuple() { + let shape = Shape::from((2,)); + assert_eq!(shape.dims(), &[2]); + let shape = Shape::from((2, 3)); + assert_eq!(shape.dims(), &[2, 3]); + let shape = Shape::from((2, 3, 4)); + assert_eq!(shape.dims(), &[2, 3, 4]); + let shape = Shape::from((2, 3, 4, 5)); + assert_eq!(shape.dims(), &[2, 3, 4, 5]); + let shape = Shape::from((2, 3, 4, 5, 6)); + assert_eq!(shape.dims(), &[2, 3, 4, 5, 6]); + let shape = Shape::from((2, 3, 4, 5, 6, 7)); + assert_eq!(shape.dims(), &[2, 3, 4, 5, 6, 7]); + } }