From 78d0278755555e934d68024b3fbf4a15f65443ad Mon Sep 17 00:00:00 2001 From: Martin Stancsics Date: Wed, 9 Aug 2023 15:09:39 +0200 Subject: [PATCH] Be consistent when instantiating from 1d arrays --- src/tabmat/sparse_matrix.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/tabmat/sparse_matrix.py b/src/tabmat/sparse_matrix.py index 8c2a3b2b..d98f180f 100644 --- a/src/tabmat/sparse_matrix.py +++ b/src/tabmat/sparse_matrix.py @@ -31,8 +31,14 @@ class SparseMatrix(MatrixBase): SparseMatrix is instantiated in the same way as scipy.sparse.csc_matrix. """ - def __init__(self, arg1, shape=None, dtype=None, copy=False): - self._array = sps.csc_matrix(arg1, shape, dtype, copy) + def __init__(self, input_array, shape=None, dtype=None, copy=False): + if isinstance(input_array, np.ndarray): + if input_array.ndim == 1: + input_array = input_array.reshape(-1, 1) + elif input_array.ndim > 2: + raise ValueError("Input array must be 1- or 2-dimensional") + + self._array = sps.csc_matrix(input_array, shape, dtype, copy) self.idx_dtype = max(self._array.indices.dtype, self._array.indptr.dtype) if self._array.indices.dtype != self.idx_dtype: