diff --git a/dcbc.py b/dcbc.py
index 7164e0d..58b0880 100644
--- a/dcbc.py
+++ b/dcbc.py
@@ -19,7 +19,8 @@
 def compute_DCBC(maxDist=35, binWidth=1, parcellation=np.empty([]),
-                 func=None, dist=None, weighting=True, backend='torch'):
+                 func=None, dist=None, weighting=True, backend='torch',
+                 batch_size=None):
     """ DCBC calculation. Automatically chooses the backend
         or uses user-specified backend.
@@ -52,7 +53,8 @@
             "All inputs must be pytorch tensors!"
         return compute_DCBC_pt(maxDist=maxDist, binWidth=binWidth,
                                parcellation=parcellation, func=func,
-                               dist=dist, weighting=weighting)
+                               dist=dist, weighting=weighting,
+                               batch_size=batch_size)
     elif backend == 'numpy' or not TORCH_AVAILABLE:
         return compute_DCBC_np(maxDist=maxDist, binWidth=binWidth,
                                parcellation=parcellation, func=func,
@@ -86,7 +88,7 @@
     cov, var = compute_var_cov(func, backend='numpy')
     # remove the nan value and medial wall from dist file
-    row, col, distance = sp.sparse.find(dist)
+    row, col, distance = sp.sparse.find(dist_scipy)
     num_within, num_between, corr_within, corr_between = [], [], [], []
     for i in range(numBins):
@@ -134,7 +136,7 @@
 def compute_DCBC_pt(maxDist=35, binWidth=1, parcellation=np.empty([]),
-                    func=None, dist=None, weighting=True):
+                    func=None, dist=None, weighting=True, batch_size=None):
     """ DCBC calculation (PyTorch version)
     Args:
@@ -155,7 +157,7 @@
         D: a dictionary contains necessary information for DCBC analysis
     """
     numBins = int(np.floor(maxDist / binWidth))
-    cov, var = compute_var_cov(func, backend='torch')
+    cov, var = compute_var_cov(func, backend='torch', batch_size=batch_size)
     # cor = np.corrcoef(func)
     if not dist.is_sparse:
         dist = dist.to_sparse()
diff --git a/utilities.py b/utilities.py
index 8621ac2..1eac552 100644
--- a/utilities.py
+++ b/utilities.py
@@ -73,28 +73,41 @@
     return sub_dirs
 
 
-def euclidean_distance(a, b, decimals=3):
-    """ Compute euclidean similarity between samples in a and b.
-        K(X, Y) = <X, Y> / (||X||*||Y||)
+def euclidean_distance(a, b=None, decimals=3, max_dist=None):
+    """
+    Compute euclidean distances between samples in a and b
+    incrementally to save memory.
     Args:
-        a (ndarray): shape of (n_samples, n_features) input data.
-            e.g (32492, 34) means 32,492 cortical nodes
-            with 34 task condition activation profile
-        b (ndarray): shape of (n_samples, n_features) input data.
-            If None, b = a
-        decimals: the precision when rounding
+        a (ndarray): shape of (n_samples, n_features) input data.
+            e.g (32492, 34) means 32,492 cortical nodes with
+            34 tasks + condition activation profile
+        b (ndarray): shape of (n_samples, n_features) input data.
+            If None, b = a.
+        decimals: the precision when rounding.
+        max_dist: Optional; distances greater than this value will be
+            set to 0.
     Returns:
-        r: the cosine similarity matrix between nodes. [N * N]
-            N is the number of cortical nodes
+        r: the pairwise distance matrix [N x N], where N is the number
+            of samples in a.
""" - p1 = np.einsum('ij,ij->i', a, a)[:, np.newaxis] - p2 = np.einsum('ij,ij->i', b, b)[:, np.newaxis] - p3 = -2 * np.dot(a, b.T) + if b is None: + b = a + + N = a.shape[0] + dist = np.zeros((N, N), dtype=np.float32) + + for i in range(N): + # Compute distances incrementally for row 'i' + diff = a[i] - b + row_dist = np.sqrt(np.einsum('ij,ij->i', diff, diff)) + dist[i] = np.round(row_dist, decimals) - dist = np.round(np.sqrt(p1 + p2 + p3), decimals) - dist.flat[::dist.shape[0] + 1] = 0.0 + # Apply max_dist condition if provided + if max_dist is not None: + dist[i][dist[i] > max_dist] = 0 return dist @@ -123,8 +136,8 @@ def compute_dist_from_surface(files, type, max_dist=50, hems='L', sparse=True): dist = [] if files is None: file_name = os.path.join(os.path.dirname(os.path.abspath(__file__)), - 'parcellations', 'fs_LR_32k_template', - 'fs_LR.32k.%s.sphere.surf.gii' % hems) + 'parcellations', 'fs_LR_32k_template', + f'fs_LR.32k.{hems}.sphere.surf.gii') else: file_name = files @@ -133,9 +146,8 @@ def compute_dist_from_surface(files, type, max_dist=50, hems='L', sparse=True): surf = [x.data for x in mat.darrays] surf_vertices = surf[0] - dist = euclidean_distance(surf_vertices, surf_vertices) - dist[dist > max_dist] = 0 - + dist = euclidean_distance(surf_vertices, surf_vertices, + max_dist=max_dist) elif type == 'dijstra': mat = nb.load(file_name) surf = [x.data for x in mat.darrays] @@ -316,7 +328,8 @@ def compute_dist_np(coord, resolution=2): ### variance / covariance -def compute_var_cov(data, cond='all', mean_centering=True, backend='torch'): +def compute_var_cov(data, cond='all', mean_centering=True, backend='torch', + batch_size=None): """ Compute the variance and covariance for a given data matrix. Automatically chooses the backend or uses user-specified backend. @@ -340,7 +353,8 @@ def compute_var_cov(data, cond='all', mean_centering=True, backend='torch'): if type(data) is np.ndarray: data = pt.tensor(data, dtype=pt.get_default_dtype()) assert type(data) is pt.Tensor, "Input data must be pytorch tensor!" - return compute_var_cov_pt(data, cond=cond, mean_centering=mean_centering) + return compute_var_cov_pt(data, cond=cond, mean_centering=mean_centering, + batch_size=batch_size) elif backend == 'numpy' or not TORCH_AVAILABLE: return compute_var_cov_np(data, cond=cond, mean_centering=mean_centering) else: @@ -388,7 +402,8 @@ def compute_var_cov_np(data, cond='all', mean_centering=True): return cov, var -def compute_var_cov_pt(data, cond='all', mean_centering=True): +def compute_var_cov_pt(data, cond='all', mean_centering=True, + batch_size=None): """ Compute the variance and covariance for a given data matrix. (PyTorch GPU version) @@ -416,7 +431,17 @@ def compute_var_cov_pt(data, cond='all', mean_centering=True): raise TypeError("Invalid condition type input! cond must be either 'all'" " or the column indices of expected task conditions") k = data.shape[1] - cov = pt.matmul(data, data.T) / (k - 1) + + if batch_size: + p = data.shape[0] + cov = pt.zeros((p, p)) + for start in range(0, p, batch_size): + end = min(start + batch_size, p) + batch_data = data[start:end] + cov[start:end, :] = pt.matmul(batch_data, data.T) / (k - 1) + else: + cov = pt.matmul(data, data.T) / (k - 1) + # sd = data.std(dim=1).reshape(-1, 1) # standard deviation sd = pt.sqrt(pt.sum(data ** 2, dim=1, keepdim=True) / (k - 1)) var = pt.matmul(sd, sd.T)