Commit 9358820

introspective-swallows, stes, and MMathisLab authored
Implementation for discrete multisession (#135)
* Implemented discrete multisession.
* Added tests for discrete multisession.
* Added tests for discrete multisession.
* Add sklearn integration test.
* Added comments.
* Updating test setup
* Pin pytest-sphinx to 0.5.0
* Pin pytest to 7.4.4
* Revert unintended commit.
* Fixed tests.
* Limit pandas < 2.2.0 for docs build
* Update intersphinx mapping for pandas
* Unpin pandas
* apply pre-commit
* Remove outdated TODO statement
* Update usage.rst
* Update usage.rst - old spelling error, caught by new tests :)
* Update CODE_OF_CONDUCT.md - spelling
* Update make_neuropixel.py - spelling
* Update usage.rst

---------

Co-authored-by: Steffen Schneider <[email protected]>
Co-authored-by: Mackenzie Mathis <[email protected]>
1 parent bbf9f6e commit 9358820

13 files changed (+352 −33 lines)

CODE_OF_CONDUCT.md (+1 −1)

@@ -6,7 +6,7 @@ In the interest of fostering an open and welcoming environment, we as
 contributors and maintainers pledge to making participation in our project and
 our community a harassment-free experience for everyone, regardless of age, body
 size, disability, ethnicity, sex characteristics, gender identity and expression,
-level of experience, education, socio-economic status, nationality, personal
+level of experience, education, socioeconomic status, nationality, personal
 appearance, race, religion, or sexual identity and orientation.

 ## Our Standards

cebra/data/datasets.py (+3 −3)

@@ -222,9 +222,9 @@ def __init__(
         else:
             self._cindex = None
         if discrete:
-            raise NotImplementedError(
-                "Multisession implementation does not support discrete index yet."
-            )
+            self._dindex = torch.cat(list(
+                self._iter_property("discrete_index")),
+                                     dim=0)
         else:
             self._dindex = None
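In effect, the branch that previously raised ``NotImplementedError`` now builds the collection-level discrete index the same way the continuous one is built above it: by concatenating each session's index along the time axis. A minimal standalone sketch of that behavior (toy tensors, not CEBRA code):

    # Toy sketch: per-session discrete indices are concatenated into one
    # flat index covering every timestep of the collection.
    import torch

    # hypothetical sessions of length 3 and 2
    discrete_indices = [torch.tensor([0, 1, 1]), torch.tensor([2, 0])]

    dindex = torch.cat(list(iter(discrete_indices)), dim=0)
    print(dindex)  # tensor([0, 1, 1, 2, 0])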

cebra/data/multi_session.py (+10 −1)

@@ -160,7 +160,16 @@ def index(self):

 @dataclasses.dataclass
 class DiscreteMultiSessionDataLoader(MultiSessionLoader):
-    pass
+    """Contrastive learning conditioned on a discrete behavior variable."""
+
+    # Overwrite sampler with the discrete implementation
+    # Generalize MultisessionSampler to avoid doing this?
+    def __post_init__(self):
+        self.sampler = cebra_distr.DiscreteMultisessionSampler(self.dataset)
+
+    @property
+    def index(self):
+        return self.dataset.discrete_index


 @dataclasses.dataclass
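Because the loader is a dataclass, ``__post_init__`` is the hook that swaps in the discrete sampler. A hedged usage sketch; the constructor fields (``num_steps``, ``batch_size``) follow the base ``Loader`` dataclass and are assumed here rather than shown in this diff:

    # Hypothetical usage with the demo dataset registered in
    # cebra/datasets/demo.py below. Field names are assumed from the
    # base Loader dataclass, not part of this diff.
    import cebra.data
    import cebra.datasets

    dataset = cebra.datasets.init("demo-discrete-multisession")
    loader = cebra.data.DiscreteMultiSessionDataLoader(dataset,
                                                       num_steps=10,
                                                       batch_size=32)
    batch = next(iter(loader))  # one batch of ref/pos/neg multi-session samples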

cebra/datasets/demo.py (+7 −5)

@@ -111,15 +111,17 @@ def discrete_index(self):
         return self.dindex


-# TODO(stes) remove this from the demo datasets until multi-session training
-# with discrete indices is implemented in the sklearn API.
-# @register("demo-discrete-multisession")
+@register("demo-discrete-multisession")
 class MultiDiscrete(cebra.data.DatasetCollection):
     """Demo dataset for testing."""

-    def __init__(self, nums_neural=[3, 4, 5]):
+    def __init__(
+        self,
+        nums_neural=[3, 4, 5],
+        num_timepoints=_DEFAULT_NUM_TIMEPOINTS,
+    ):
         super().__init__(*[
-            DemoDatasetDiscrete(_DEFAULT_NUM_TIMEPOINTS, num_neural)
+            DemoDatasetDiscrete(num_timepoints, num_neural)
             for num_neural in nums_neural
         ])

cebra/datasets/make_neuropixel.py (+1 −1)

@@ -171,7 +171,7 @@ def read_neuropixel(
 ):
     """Load 120Hz Neuropixels data recorded in the specified cortex during the movie1 stimulus.

-    The Neuropixels recordin is filtered and transformed to spike counts in a bin size specified by the sampling rat.
+    The Neuropixels recording is filtered and transformed to spike counts in a bin size specified by the sampling rat.

     Args:
         path: The wildcard file path where the neuropixels .nwb files are located.

cebra/distributions/multisession.py (+124 −0)

@@ -259,3 +259,127 @@ def __getitem__(self, pos_idx):
         for i in range(self.num_sessions):
             pos_samples[i] = self.data[i][pos_idx[i]]
         return pos_samples
+
+
+class DiscreteMultisessionSampler(cebra_distr.PriorDistribution,
+                                  cebra_distr.ConditionalDistribution):
+    """Discrete multi-session sampling.
+
+    Discrete indices don't need to be aligned. Positive pairs are found
+    by matching the discrete index in randomly assigned sessions.
+
+    After data processing, the dimensionality of the returned features
+    matches. The resulting embeddings can be concatenated, and shuffling
+    (across the session axis) can be applied to the reference samples, or
+    reversed for the positive samples.
+
+    TODO:
+        * Add better CUDA support and refactor ``numpy`` implementation.
+    """
+
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+        # TODO(stes): implement in pytorch
+        self.all_data = self.dataset.discrete_index.cpu().numpy()
+        self.session_lengths = np.array(self.dataset.session_lengths)
+
+        self.lengths = np.cumsum(self.session_lengths)
+        self.lengths[1:] = self.lengths[:-1]
+        self.lengths[0] = 0
+
+        self.index = [
+            cebra_distr.DiscreteUniform(
+                dataset.discrete_index.int().to(_device))
+            for dataset in self.dataset.iter_sessions()
+        ]
+
+    @property
+    def num_sessions(self) -> int:
+        """The number of sessions in the index."""
+        return len(self.lengths)
+
+    def mix(self, array: np.ndarray, idx: np.ndarray):
+        """Re-order array elements according to the given index mapping.
+
+        The given array should be of the shape ``(session, batch, ...)`` and the
+        indices should have length ``session x batch``, representing a mapping
+        between indices.
+
+        The resulting array will be rearranged such that
+        ``out.reshape(session*batch, -1)[i] = array.reshape(session*batch, -1)[idx[i]]``
+
+        For the inverse mapping, convert the indices first using the
+        ``_invert_index`` function.
+
+        Args:
+            array: A 2D matrix containing samples for each session.
+            idx: A list of indexes to re-order ``array`` on.
+        """
+        n, m = array.shape[:2]
+        return array.reshape(n * m, -1)[idx].reshape(array.shape)
+
+    def sample_prior(self, num_samples):
+        # TODO(stes) implement empirical/uniform resampling
+        ref_idx = np.random.uniform(0, 1, (self.num_sessions, num_samples))
+        ref_idx = (ref_idx * self.session_lengths[:, None]).astype(int)
+        return ref_idx
+
+    def sample_conditional(self, idx: torch.Tensor) -> torch.Tensor:
+        """Sample from the conditional distribution.
+
+        Note:
+            * Reference samples are sampled equally between sessions.
+            * In order to guarantee the same number of positive samples per
+              session, reference samples are randomly assigned to a session and
+              its corresponding positive sample is searched in that session only.
+            * As a result, ref/pos pairing is shuffled and can be recovered with
+              the reverse shuffle operation.
+
+        Args:
+            idx: Reference indices, with dimension ``(session, batch)``.
+
+        Returns:
+            Positive indices (1st return value), which will be grouped by
+            session and *not* match the reference indices.
+            In addition, a mapping will be returned to apply the same shuffle
+            operation that was applied to assign reference samples to a session
+            along the session/batch dimension (2nd return value), or to reverse
+            the shuffle operation (3rd return value).
+            Returned shapes are ``(session, batch), (session, batch), (session, batch)``.
+
+        TODO:
+            * re-implement in pytorch for additional speed gains
+        """
+
+        shape = idx.shape
+        # TODO(stes) unclear why we cannot restrict to 2d overall
+        # idx has shape (2, #samples per batch)
+        s = idx.shape[:2]
+        idx_all = (idx + self.lengths[:, None]).flatten()
+
+        # get discrete indices
+        query = self.all_data[idx_all]
+
+        # shuffle operation to assign each index to a session
+        idx = np.random.permutation(len(query))
+
+        # TODO this part fails in Pytorch
+        # apply shuffle
+        query = query[idx.reshape(s)]
+        query = torch.from_numpy(query).to(_device)
+
+        # sample conditional for each assigned session
+        pos_idx = torch.zeros(shape, device=_device).long()
+        for i in range(self.num_sessions):
+            pos_idx[i] = self.index[i].sample_conditional(query[i])
+        pos_idx = pos_idx.cpu().numpy()
+
+        # reverse indices to recover the ref/pos samples matching
+        idx_rev = _invert_index(idx)
+        return pos_idx, idx, idx_rev
+
+    def __getitem__(self, pos_idx):
+        pos_samples = np.zeros(pos_idx.shape[:2] + (self.data.shape[2],))
+        for i in range(self.num_sessions):
+            pos_samples[i] = self.data[i][pos_idx[i]]
+        return pos_samples
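The shuffle bookkeeping is the subtle part: ``sample_conditional`` returns both the permutation used to assign reference samples to sessions and its inverse, so callers can restore the ref/pos pairing. A standalone numpy sketch of that round trip; ``_invert_index`` is not shown in this hunk and is assumed here to be an argsort-style inverse permutation:

    # Standalone sketch of the shuffle/unshuffle logic used by
    # DiscreteMultisessionSampler. `_invert_index` is assumed to behave
    # like np.argsort applied to a permutation.
    import numpy as np

    n_sessions, batch = 2, 4
    array = np.arange(n_sessions * batch).reshape(n_sessions, batch, 1)

    def mix(array, idx):
        # same reshape trick as DiscreteMultisessionSampler.mix
        n, m = array.shape[:2]
        return array.reshape(n * m, -1)[idx].reshape(array.shape)

    idx = np.random.permutation(n_sessions * batch)  # session assignment
    idx_rev = np.argsort(idx)                        # assumed _invert_index

    shuffled = mix(array, idx)         # shuffle across the session axis
    restored = mix(shuffled, idx_rev)  # inverse mapping recovers the input
    assert (restored == array).all()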

cebra/integrations/sklearn/cebra.py (+5 −2)

@@ -153,6 +153,7 @@ def _require_arg(key):

     # Discrete behavior contrastive training is selected with the default dataloader
     if not is_cont and is_disc:
+        kwargs = dict(**shared_kwargs,)
         if is_full:
             if is_hybrid:
                 raise_not_implemented_error = True
@@ -162,7 +163,10 @@ def _require_arg(key):
             if is_hybrid:
                 raise_not_implemented_error = True
             else:
-                raise_not_implemented_error = True
+                return (
+                    cebra.data.DiscreteMultiSessionDataLoader(**kwargs),
+                    "multi-session",
+                )

     # Mixed behavior contrastive training is selected with the default dataloader
     if is_cont and is_disc:
@@ -1030,7 +1034,6 @@ def _partial_fit(
             if callback is None:
                 raise ValueError(
                     "callback_frequency requires to specify a callback.")
-
         model.train()

         solver.fit(
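Net effect on the sklearn API: passing a list of sessions together with one discrete label array per session now resolves to the new multi-session loader instead of raising. A hedged smoke-test sketch (toy data; the usage.rst hunk below documents the same path in full):

    # Hypothetical quick check of the new dispatch branch.
    import numpy as np
    import cebra

    neural = [np.random.normal(size=(100, 5)), np.random.normal(size=(80, 3))]
    labels = [np.random.randint(0, 4, (100,)), np.random.randint(0, 4, (80,))]

    model = cebra.CEBRA(batch_size=64, max_iterations=5)
    model.fit(neural, labels)  # selects DiscreteMultiSessionDataLoader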

docs/source/usage.rst (+26 −3)

@@ -12,7 +12,7 @@ Firstly, why use CEBRA?

 CEBRA is primarily designed for producing robust, consistent extractions of latent factors from time-series data. It supports three modes, and is a self-supervised representation learning algorithm that uses our modified contrastive learning approach designed for multi-modal time-series data. In short, it is a type of non-linear dimensionality reduction, like `tSNE <https://www.jmlr.org/papers/v9/vandermaaten08a.html>`_ and `UMAP <https://arxiv.org/abs/1802.03426>`_. We show in our original paper that it outperforms tSNE and UMAP at producing closer-to-ground-truth latents and is more consistent.

-That being said, CEBRA can be used on non-time-series data and it does not strictly require multi-modal data. In general, we recommend considering using CEBRA for measuring changes in consistency across conditions (brain areas, cells, animals), for hypothesis-guided decoding, and for toplogical exploration of the resulting embedding spaces. It can also be used for visualization and considering dynamics within the embedding space. For examples of how CEBRA can be used to map space, decode natural movies, and make hypotheses for neural coding of sensorimotor systems, see our paper (Schneider, Lee, Mathis, 2023).
+That being said, CEBRA can be used on non-time-series data and it does not strictly require multi-modal data. In general, we recommend considering using CEBRA for measuring changes in consistency across conditions (brain areas, cells, animals), for hypothesis-guided decoding, and for topological exploration of the resulting embedding spaces. It can also be used for visualization and considering dynamics within the embedding space. For examples of how CEBRA can be used to map space, decode natural movies, and make hypotheses for neural coding of sensorimotor systems, see our paper (Schneider, Lee, Mathis, 2023).

 The CEBRA workflow
 ------------------
@@ -419,10 +419,10 @@ We can now fit the model in different modes.

 .. rubric:: Multi-session training

-For multi-sesson training, lists of data are provided instead of a single dataset and eventual corresponding auxiliary variables.
+For multi-session training, lists of data are provided instead of a single dataset and eventual corresponding auxiliary variables.

 .. warning::
-    For now, multi-session training can only handle a **unique set of continuous labels**. All other combinations will raise an error.
+    For now, multi-session training can only handle a **unique set of continuous labels** or a **unique discrete label**. All other combinations will raise an error. For the continuous case we provide the following example:


 .. testcode::
@@ -450,6 +450,29 @@ Once you defined your CEBRA model, you can run:

     multi_cebra_model.fit([neural_session1, neural_session2], [continuous_label1, continuous_label2])


+Similarly, for the discrete case a discrete label can be provided and the CEBRA model will use the discrete multisession mode:
+
+.. testcode::
+
+    timesteps1 = 5000
+    timesteps2 = 3000
+    neurons1 = 50
+    neurons2 = 30
+    out_dim = 8
+
+    neural_session1 = np.random.normal(0,1,(timesteps1, neurons1))
+    neural_session2 = np.random.normal(0,1,(timesteps2, neurons2))
+    discrete_label1 = np.random.randint(0,10,(timesteps1, ))
+    discrete_label2 = np.random.randint(0,10,(timesteps2, ))
+
+    multi_cebra_model = cebra.CEBRA(batch_size=512,
+                                    output_dimension=out_dim,
+                                    max_iterations=10,
+                                    max_adapt_iterations=10)
+
+
+    multi_cebra_model.fit([neural_session1, neural_session2], [discrete_label1, discrete_label2])
+
 .. admonition:: See API docs
     :class: dropdown

tests/test_distributions.py (+25 −0)

@@ -298,6 +298,31 @@ def test_multi_session_time_contrastive(time_offset):
                 len(rev_idx.flatten())).all())


+def test_multi_session_discrete():
+    dataset = cebra_datasets.init("demo-discrete-multisession")
+    sampler = cebra_distr.DiscreteMultisessionSampler(dataset)
+
+    num_samples = 5
+    sample = sampler.sample_prior(num_samples)
+    assert sample.shape == (dataset.num_sessions, num_samples)
+
+    positive, idx, rev_idx = sampler.sample_conditional(sample)
+    assert positive.shape == (dataset.num_sessions, num_samples)
+    assert idx.shape == (dataset.num_sessions * num_samples,)
+    assert rev_idx.shape == (dataset.num_sessions * num_samples,)
+    # NOTE(celia): test the private function ``_inverse_idx()``, with idx arrays flat
+    assert (idx.flatten()[rev_idx.flatten()].all() == np.arange(
+        len(rev_idx.flatten())).all())
+
+    # Check positive samples' labels match reference samples' labels
+    sample_labels = sampler.all_data[(sample +
+                                      sampler.lengths[:, None]).flatten()]
+    sample_labels = sample_labels[idx.reshape(sample.shape[:2])].flatten()
+    positive_labels = sampler.all_data[(positive +
+                                        sampler.lengths[:, None]).flatten()]
+    assert (sample_labels == positive_labels).all()
+
+
 class OldDeltaDistribution(cebra_distr_base.JointDistribution,
                            cebra_distr_base.HasGenerator):
     """

tests/test_integration_train.py (+7 −1)

@@ -35,6 +35,11 @@
 import cebra.models
 import cebra.solver

+if torch.cuda.is_available():
+    _DEVICE = "cuda"
+else:
+    _DEVICE = "cpu"
+

 def _init_single_session_solver(loader, args):
     """Train a single session CEBRA model."""
@@ -77,6 +82,7 @@ def _list_data_loaders():
         cebra.data.HybridDataLoader,
         cebra.data.FullDataLoader,
         cebra.data.ContinuousMultiSessionDataLoader,
+        cebra.data.DiscreteMultiSessionDataLoader,
     ]
     # TODO limit this to the valid combinations---however this
    # requires to adapt the dataset API slightly; it is currently
@@ -95,7 +101,7 @@ def _list_data_loaders():
 @pytest.mark.requires_dataset
 @pytest.mark.parametrize("dataset_name, loader_type", _list_data_loaders())
 def test_train(dataset_name, loader_type):
-    args = cebra.config.Config(num_steps=1, device="cuda").as_namespace()
+    args = cebra.config.Config(num_steps=1, device=_DEVICE).as_namespace()

     dataset = cebra.datasets.init(dataset_name)
     if loader_type not in cebra_data_helper.get_loader_options(dataset):
