Skip to content

Commit 2db6b93

Browse files
authored
Merge pull request #11 from KlugerLab/10-check-gene_expression-matrix-is-positive
10 check gene expression matrix is positive
2 parents 35184f0 + 9989cff commit 2db6b93

8 files changed

+138
-8
lines changed

gene_trajectory/coarse_grain.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import scanpy as sc
55
from sklearn.cluster import KMeans
66

7+
from gene_trajectory.util.input_validation import validate_matrix
8+
79

810
def select_top_genes(
911
adata: sc.AnnData,
@@ -64,6 +66,10 @@ def coarse_grain(
6466
:param random_seed: the random seed
6567
:return: the updated cell embedding and gene expression matrices
6668
"""
69+
validate_matrix(gene_expression, obj_name='Gene Expression Matrix', min_value=0)
70+
ncells, ngenes = gene_expression.shape
71+
validate_matrix(cell_embedding, obj_name='Cell embedding', nrows=ncells)
72+
6773
if cluster is None:
6874
k_means = KMeans(n_clusters=n, random_state=random_seed).fit(cell_embedding)
6975
cluster = k_means.labels_ # noqa

gene_trajectory/diffusion_map.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import Union
22
import numpy as np
33

4+
from gene_trajectory.util.input_validation import validate_matrix
5+
46

57
def diffusion_map(
68
dist_mat: np.array,
@@ -19,6 +21,8 @@ def diffusion_map(
1921
:param t: Number of diffusion times
2022
:return: the diffusion embedding and the eigenvalues
2123
"""
24+
validate_matrix(dist_mat, square=True)
25+
2226
affinity_matrix_symm = get_symmetrized_affinity_matrix(dist_mat=dist_mat, k=k, sigma=sigma)
2327
normalized_vec = np.sqrt(1 / affinity_matrix_symm.sum(axis=1))
2428
affinity_matrix_norm = (affinity_matrix_symm * normalized_vec * normalized_vec[:, None])
@@ -50,7 +54,8 @@ def get_symmetrized_affinity_matrix(
5054
5155
:return:
5256
"""
53-
assert dist_mat.shape[0] == dist_mat.shape[1]
57+
validate_matrix(dist_mat, square=True)
58+
5459
dists = np.nan_to_num(dist_mat, 1e-6) # noqa
5560
k = min(k, dist_mat.shape[0])
5661

gene_trajectory/extract_gene_trajectory.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pandas as pd
88

99
from gene_trajectory.diffusion_map import diffusion_map, get_symmetrized_affinity_matrix
10+
from gene_trajectory.util.input_validation import validate_matrix
1011

1112
logger = logging.getLogger()
1213

@@ -28,6 +29,8 @@ def get_gene_embedding(
2829
:param t: Number of diffusion times
2930
:return: the diffusion embedding and the eigenvalues
3031
"""
32+
validate_matrix(dist_mat, square=True)
33+
3134
k = min(k, dist_mat.shape[0])
3235
n_ev = min(n_ev + 1, dist_mat.shape[0])
3336
diffu_emb, eigen_vals = diffusion_map(dist_mat=dist_mat, k=k, sigma=sigma, n_ev=n_ev, t=t)
@@ -47,6 +50,8 @@ def get_randow_walk_matrix(
4750
:param k: Adaptive kernel bandwidth for each point set to be the distance to its `K`-th nearest neighbor
4851
:return: Random-walk matrix
4952
"""
53+
validate_matrix(dist_mat, square=True)
54+
5055
affinity_matrix_symm = get_symmetrized_affinity_matrix(dist_mat=dist_mat, k=k)
5156
normalized_vec = 1 / affinity_matrix_symm.sum(axis=1)
5257
affinity_matrix_norm = (affinity_matrix_symm * normalized_vec[:, None])
@@ -67,7 +72,7 @@ def get_gene_pseudoorder(
6772
:param max_id: Index of the terminal gene
6873
:return: The pseudoorder
6974
"""
70-
assert dist_mat.shape[0] == dist_mat.shape[1]
75+
validate_matrix(dist_mat, square=True)
7176

7277
emd = dist_mat[subset][:, subset]
7378
dm_emb, _ = diffusion_map(emd)
@@ -108,6 +113,8 @@ def extract_gene_trajectory(
108113
:param other: Label for genes not in a trajectory. Default: 'Other'
109114
:return: A data frame indicating gene trajectories and gene ordering along each trajectory
110115
"""
116+
validate_matrix(dist_mat, square=True)
117+
111118
if np.isscalar(t_list):
112119
if n is None:
113120
raise ValueError(f'n should be specified if t_list is a number: {t_list}')

gene_trajectory/gene_distance_shared.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import ot
1010
from tqdm import tqdm
1111

12+
from gene_trajectory.util.input_validation import validate_matrix
1213
from gene_trajectory.util.shared_array import SharedArray, PartialStarApply
1314

1415
logger = logging.getLogger()
@@ -42,18 +43,20 @@ def cal_ot_mat(
4243
:return: the distance matrix
4344
"""
4445
processes = int(processes) if isinstance(processes, float) else os.cpu_count()
45-
n = gene_expr.shape[1]
46+
validate_matrix(gene_expr, obj_name='Gene Expression Matrix', min_value=0)
47+
ncells, ngenes = gene_expr.shape
48+
validate_matrix(ot_cost, obj_name='Cost Matrix', shape=(ncells, ncells), min_value=0)
49+
4650
if show_progress_bar:
4751
logger.info(f'Computing emd distance..')
4852

4953
if gene_pairs is None:
50-
pairs = ((i, j) for i in range(0, n - 1) for j in range(i + 1, n))
51-
npairs = (n * (n - 1)) // 2
54+
pairs = ((i, j) for i in range(0, ngenes - 1) for j in range(i + 1, ngenes))
55+
npairs = (ngenes * (ngenes - 1)) // 2
5256
else:
5357
pairs = gene_pairs
5458
npairs = len(gene_pairs)
55-
56-
emd_mat = np.full((n, n), fill_value=np.NaN)
59+
emd_mat = np.full((ngenes, ngenes), fill_value=np.NaN)
5760

5861
with SharedMemoryManager() as manager:
5962
start_time = time.perf_counter()
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from typing import Optional
2+
3+
import numpy as np
4+
5+
6+
def validate_not_none(obj, obj_name: str = 'input'):
    """
    Ensures that an object is not None
    @param obj: the object to check
    @param obj_name: the name of the object, for error reporting
    @raise ValueError: if obj is None
    """
    if obj is not None:
        return
    raise ValueError(f"{obj_name} is None")
9+
10+
11+
def validate_matrix(
        m: np.ndarray,
        obj_name: str = 'input',
        nrows: Optional[int] = None,
        ncols: Optional[int] = None,
        shape: Optional[tuple[int, int]] = None,
        square: Optional[bool] = None,
        min_size: Optional[int] = 1,
        min_value: Optional[float] = None,
        max_value: Optional[float] = None,
):
    """
    Validates an input matrix, raising a ValueError on the first check that fails

    @param m: the input matrix. It must expose a `shape` attribute, and `min()`/`max()` if value bounds are requested
    @param obj_name: the name of the object, for error reporting
    @param min_size: Minimum matrix size in each dimension. Defaults to 1, and will raise for an empty matrix
    @param nrows: Number of rows in the matrix
    @param ncols: Number of columns in the matrix
    @param shape: the expected shape of the matrix (any two-element sequence, e.g. a tuple or a list)
    @param square: If True, ensures the matrix is square. If False, ensures the matrix is not
    @param min_value: Minimum value for each element
    @param max_value: Maximum value for each element
    @raise ValueError: if the matrix is None, is not two-dimensional, or fails any of the requested checks
    """
    validate_not_none(m, obj_name=obj_name)

    if len(m.shape) != 2:
        raise ValueError(f"{obj_name} is not a matrix. Shape: {m.shape}")
    mr, mc = m.shape

    if nrows is not None and mr != nrows:
        raise ValueError(f"{obj_name} does not have {nrows} rows. Shape: {m.shape}")

    if ncols is not None and mc != ncols:
        raise ValueError(f"{obj_name} does not have {ncols} columns. Shape: {m.shape}")

    # Compare as tuples so that an expected shape given as a list (e.g. [3, 3]) is not
    # rejected just because a tuple never compares equal to a list
    if shape is not None and tuple(m.shape) != tuple(shape):
        raise ValueError(f"{obj_name} does not have shape {shape}. Shape: {m.shape}")

    if square is True:
        if mr != mc:
            raise ValueError(f"{obj_name} is not a square matrix. Shape: {m.shape}")
    elif square is False:
        if mr == mc:
            raise ValueError(f"{obj_name} is a square matrix. Shape: {m.shape}")

    if min_size is not None:
        for s in m.shape:
            if s < min_size:
                raise ValueError(f"{obj_name} does not have enough elements. Min_size: {min_size}, Shape: {m.shape}")

    # An empty matrix (only reachable when min_size allows it) has no element that could
    # violate a bound, and min()/max() on a zero-size array would raise an unrelated error
    if mr == 0 or mc == 0:
        return

    if min_value is not None:
        smallest = m.min()
        if smallest < min_value:
            raise ValueError(f"{obj_name} should not have values less than {min_value}. Minimum found: {smallest}")

    if max_value is not None:
        largest = m.max()
        if largest > max_value:
            raise ValueError(f"{obj_name} should not have values greater than {max_value}. Maximum found: {largest}")

tests/test_compute_gene_distance_cmd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from gene_trajectory.compute_gene_distance_cmd import cal_ot_mat
88

99

10-
class DiffusionMapTestCase(unittest.TestCase):
10+
class ComputeGeneDistanceTestCase(unittest.TestCase):
1111
gdm = np.array([
1212
[0, 1, 2],
1313
[1, 0, 2],

tests/test_gene_distance_shared.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,16 @@ def test_gene_distance_shared(self):
2828
mt = cal_ot_mat(ot_cost=self.gdm, gene_expr=self.gem.T, show_progress_bar=False)
2929
np.testing.assert_almost_equal(self.expected_emd, mt, 6)
3030

31+
def test_gene_distance_input_validation(self):
    """Checks that cal_ot_mat rejects mismatched cost matrices and negative gene expression values"""
    # assertRaisesRegex replaces the deprecated alias assertRaisesRegexp, which was removed in Python 3.12

    # Gene expression matrix whose row count does not match the cost matrix fixture
    with self.assertRaisesRegex(ValueError, 'Cost Matrix does not have shape.*'):
        cal_ot_mat(ot_cost=self.gdm, gene_expr=np.ones(shape=(6, 3)), show_progress_bar=False)

    # Cost matrix that is not square
    with self.assertRaisesRegex(ValueError, 'Cost Matrix does not have shape.*'):
        cal_ot_mat(ot_cost=np.ones(shape=(6, 3)), gene_expr=self.gem.T, show_progress_bar=False)

    # Gene expression shifted down by one; the expected error is the one about values below 0
    with self.assertRaisesRegex(ValueError, 'Gene Expression Matrix should not have values less than 0.*'):
        cal_ot_mat(ot_cost=np.ones(shape=(6, 3)), gene_expr=self.gem.T - 1, show_progress_bar=False)
40+
3141
def test_cal_ot_mat_gene_pairs(self):
3242
exp = self.expected_emd.copy()
3343
exp[0, 2] = exp[2, 0] = 900

tests/test_input_validation.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import unittest
2+
3+
import numpy as np
4+
5+
from gene_trajectory.util.input_validation import validate_matrix
6+
7+
8+
class InputValidationTestCase(unittest.TestCase):
    """Unit tests for the matrix input validation helper"""

    def test_validate_matrix(self):
        """A valid matrix passes silently; each violated constraint raises a ValueError with a specific message"""
        m = np.array([[1, 2], [3, 4]])

        # All constraints satisfied: must not raise
        validate_matrix(m, min_value=1, max_value=4, square=True, shape=(2, 2))

        # assertRaisesRegex replaces the deprecated alias assertRaisesRegexp, which was removed in Python 3.12
        with self.assertRaisesRegex(ValueError, '.*does not have 3 rows.*'):
            validate_matrix(m, nrows=3)
        with self.assertRaisesRegex(ValueError, '.*does not have 8 columns.*'):
            validate_matrix(m, ncols=8)
        with self.assertRaisesRegex(ValueError, r'.*does not have shape \(1, 1\)'):
            validate_matrix(m, shape=(1, 1))
        with self.assertRaisesRegex(ValueError, '.*Min_size: 3.*'):
            validate_matrix(m, min_size=3)
        with self.assertRaisesRegex(ValueError, '.*should not have values less than 5.*'):
            validate_matrix(m, min_value=5)
        with self.assertRaisesRegex(ValueError, '.*should not have values greater than 1.*'):
            validate_matrix(m, max_value=1)
26+
27+
28+
if __name__ == '__main__':
29+
unittest.main()

0 commit comments

Comments
 (0)