12 changes: 7 additions & 5 deletions dcbc.py
@@ -19,7 +19,8 @@


def compute_DCBC(maxDist=35, binWidth=1, parcellation=np.empty([]),
-                 func=None, dist=None, weighting=True, backend='torch'):
+                 func=None, dist=None, weighting=True, backend='torch',
+                 batch_size=None):
""" DCBC calculation.
Automatically chooses the backend or uses user-specified backend.

@@ -52,7 +53,8 @@ def compute_DCBC(maxDist=35, binWidth=1, parcellation=np.empty([]),
"All inputs must be pytorch tensors!"
return compute_DCBC_pt(maxDist=maxDist, binWidth=binWidth,
parcellation=parcellation, func=func,
-                               dist=dist, weighting=weighting)
+                               dist=dist, weighting=weighting,
+                               batch_size=batch_size)
elif backend == 'numpy' or not TORCH_AVAILABLE:
return compute_DCBC_np(maxDist=maxDist, binWidth=binWidth,
parcellation=parcellation, func=func,
@@ -86,7 +88,7 @@ def compute_DCBC_np(maxDist=35, binWidth=1, parcellation=np.empty([]),
cov, var = compute_var_cov(func, backend='numpy')

# remove the nan value and medial wall from dist file
-    row, col, distance = sp.sparse.find(dist)
+    row, col, distance = sp.sparse.find(dist_scipy)

num_within, num_between, corr_within, corr_between = [], [], [], []
for i in range(numBins):
@@ -134,7 +136,7 @@ def compute_DCBC_np(maxDist=35, binWidth=1, parcellation=np.empty([]),


def compute_DCBC_pt(maxDist=35, binWidth=1, parcellation=np.empty([]),
-                    func=None, dist=None, weighting=True):
+                    func=None, dist=None, weighting=True, batch_size=None):
""" DCBC calculation (PyTorch version)

Args:
@@ -155,7 +157,7 @@ def compute_DCBC_pt(maxDist=35, binWidth=1, parcellation=np.empty([]),
D: a dictionary contains necessary information for DCBC analysis
"""
numBins = int(np.floor(maxDist / binWidth))
-    cov, var = compute_var_cov(func, backend='torch')
+    cov, var = compute_var_cov(func, backend='torch', batch_size=batch_size)
# cor = np.corrcoef(func)
if not dist.is_sparse:
dist = dist.to_sparse()
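For context, here is a minimal sketch of how the new `batch_size` argument flows from the public entry point down to the covariance computation. Everything below is illustrative: the shapes, the random placeholder inputs, and the assumption that the module is importable as `dcbc` are mine, not taken from this PR.

```python
# Illustrative only: random data stands in for real surface inputs.
import torch as pt
from dcbc import compute_DCBC

P = 10000                                  # number of cortical vertices
func = pt.randn(P, 34)                     # 34-condition activation profiles
parcellation = pt.randint(1, 18, (P,))     # placeholder 17-region parcellation
dist = (pt.rand(P, P) * 50).to_sparse()    # placeholder pairwise distances

# With batch_size set, compute_var_cov builds the P x P covariance matrix
# in row blocks instead of one full matmul, capping peak memory.
D = compute_DCBC(maxDist=35, binWidth=1, parcellation=parcellation,
                 func=func, dist=dist, backend='torch', batch_size=2048)
```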
75 changes: 50 additions & 25 deletions utilities.py
@@ -73,28 +73,41 @@ def scan_subdirs(path):
return sub_dirs


-def euclidean_distance(a, b, decimals=3):
-    """ Compute euclidean similarity between samples in a and b.
-        K(X, Y) = <X, Y> / (||X||*||Y||)
+def euclidean_distance(a, b=None, decimals=3, max_dist=None):
+    """ Compute the pairwise Euclidean distance between samples in a and b
+    incrementally (one row at a time) to save memory.

Args:
-        a (ndarray): shape of (n_samples, n_features) input data.
-            e.g (32492, 34) means 32,492 cortical nodes
-            with 34 task condition activation profile
-        b (ndarray): shape of (n_samples, n_features) input data.
-            If None, b = a
-        decimals: the precision when rounding
+        a (ndarray): shape of (n_samples, n_features) input data.
+            e.g. (32492, 34) means 32,492 cortical nodes with a
+            34-condition activation profile
+        b (ndarray): shape of (n_samples, n_features) input data.
+            If None, b = a.
+        decimals: the precision when rounding.
+        max_dist: Optional; distances greater than this value will be
+            set to 0.

Returns:
-        r: the cosine similarity matrix between nodes. [N * N]
-           N is the number of cortical nodes
+        r: the pairwise distance matrix [N x N], where N is the number
+           of samples in a.
"""
-    p1 = np.einsum('ij,ij->i', a, a)[:, np.newaxis]
-    p2 = np.einsum('ij,ij->i', b, b)[:, np.newaxis]
-    p3 = -2 * np.dot(a, b.T)
-    dist = np.round(np.sqrt(p1 + p2 + p3), decimals)
-    dist.flat[::dist.shape[0] + 1] = 0.0
+    if b is None:
+        b = a
+
+    N = a.shape[0]
+    # Allocate (N, b.shape[0]) so an explicit b with a different number of
+    # samples than a still fits (when b is None, this is N x N)
+    dist = np.zeros((N, b.shape[0]), dtype=np.float32)
+
+    for i in range(N):
+        # Compute distances incrementally for row 'i'
+        diff = a[i] - b
+        row_dist = np.sqrt(np.einsum('ij,ij->i', diff, diff))
+        dist[i] = np.round(row_dist, decimals)
+
+        # Apply max_dist condition if provided
+        if max_dist is not None:
+            dist[i][dist[i] > max_dist] = 0

return dist
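
A quick way to sanity-check the incremental loop (my own sketch, not part of this diff) is to compare it against the removed closed-form expansion, which computed all distances at once via ||a_i - b_j||^2 = ||a_i||^2 + ||b_j||^2 - 2<a_i, b_j>:

```python
# Sketch: compare the row-by-row loop against the removed vectorized formula.
import numpy as np

a = np.random.rand(200, 34)
loop_dist = euclidean_distance(a, decimals=3)

# Removed implementation, reconstructed; the clamp guards the tiny negative
# values that made the old diagonal NaN before it was zeroed out.
p1 = np.einsum('ij,ij->i', a, a)[:, np.newaxis]
p3 = -2 * np.dot(a, a.T)
full_dist = np.round(np.sqrt(np.maximum(p1 + p1.T + p3, 0)), 3)
np.fill_diagonal(full_dist, 0.0)

# atol allows for the loop's float32 accumulation vs. float64 here
assert np.allclose(loop_dist, full_dist, atol=2e-3)
```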

@@ -123,8 +136,8 @@ def compute_dist_from_surface(files, type, max_dist=50, hems='L', sparse=True):
dist = []
if files is None:
file_name = os.path.join(os.path.dirname(os.path.abspath(__file__)),
-                                 'parcellations', 'fs_LR_32k_template',
-                                 'fs_LR.32k.%s.sphere.surf.gii' % hems)
+                                 'parcellations', 'fs_LR_32k_template',
+                                 f'fs_LR.32k.{hems}.sphere.surf.gii')
else:
file_name = files

@@ -133,9 +146,8 @@ def compute_dist_from_surface(files, type, max_dist=50, hems='L', sparse=True):
surf = [x.data for x in mat.darrays]
surf_vertices = surf[0]

-    dist = euclidean_distance(surf_vertices, surf_vertices)
-    dist[dist > max_dist] = 0
-
+    dist = euclidean_distance(surf_vertices, surf_vertices,
+                              max_dist=max_dist)
elif type == 'dijstra':
mat = nb.load(file_name)
surf = [x.data for x in mat.darrays]
@@ -316,7 +328,8 @@ def compute_dist_np(coord, resolution=2):


### variance / covariance
-def compute_var_cov(data, cond='all', mean_centering=True, backend='torch'):
+def compute_var_cov(data, cond='all', mean_centering=True, backend='torch',
+                    batch_size=None):
""" Compute the variance and covariance for a given data matrix.
Automatically chooses the backend or uses user-specified backend.

@@ -340,7 +353,8 @@ def compute_var_cov(data, cond='all', mean_centering=True, backend='torch'):
if type(data) is np.ndarray:
data = pt.tensor(data, dtype=pt.get_default_dtype())
assert type(data) is pt.Tensor, "Input data must be pytorch tensor!"
-        return compute_var_cov_pt(data, cond=cond, mean_centering=mean_centering)
+        return compute_var_cov_pt(data, cond=cond, mean_centering=mean_centering,
+                                  batch_size=batch_size)
elif backend == 'numpy' or not TORCH_AVAILABLE:
return compute_var_cov_np(data, cond=cond, mean_centering=mean_centering)
else:
@@ -388,7 +402,8 @@ def compute_var_cov_np(data, cond='all', mean_centering=True):
return cov, var


-def compute_var_cov_pt(data, cond='all', mean_centering=True):
+def compute_var_cov_pt(data, cond='all', mean_centering=True,
+                       batch_size=None):
""" Compute the variance and covariance for a given data matrix.
(PyTorch GPU version)

@@ -416,7 +431,17 @@ def compute_var_cov_pt(data, cond='all', mean_centering=True):
raise TypeError("Invalid condition type input! cond must be either 'all'"
" or the column indices of expected task conditions")
k = data.shape[1]
-    cov = pt.matmul(data, data.T) / (k - 1)
+
+    if batch_size:
+        p = data.shape[0]
+        # Allocate on the same device/dtype as the data so GPU inputs work
+        cov = pt.zeros((p, p), dtype=data.dtype, device=data.device)
+        for start in range(0, p, batch_size):
+            end = min(start + batch_size, p)
+            batch_data = data[start:end]
+            # Fill the covariance matrix one block of rows at a time
+            cov[start:end, :] = pt.matmul(batch_data, data.T) / (k - 1)
+    else:
+        cov = pt.matmul(data, data.T) / (k - 1)
+
# sd = data.std(dim=1).reshape(-1, 1) # standard deviation
sd = pt.sqrt(pt.sum(data ** 2, dim=1, keepdim=True) / (k - 1))
var = pt.matmul(sd, sd.T)
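
As a closing check, the batched and unbatched covariance paths should agree up to floating-point precision. A small self-test along these lines is my own sketch, assuming `compute_var_cov_pt` returns `(cov, var)` as its NumPy counterpart does:

```python
# Sketch: the batched path should match the single-matmul path.
import torch as pt

data = pt.randn(500, 34)
cov_full, var_full = compute_var_cov_pt(data)                    # one big matmul
cov_batched, var_batched = compute_var_cov_pt(data, batch_size=128)

assert pt.allclose(cov_full, cov_batched, atol=1e-6)
assert pt.allclose(var_full, var_batched)
```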