Reproduction of: “Quantum Self-Supervised Learning” (Jaderberg et al.), arXiv:2103.14653 — https://arxiv.org/abs/2103.14653
In this folder, you will find an implementation and evaluation of the core ideas from the paper. It supports three representation networks under the same SSL pipeline: a photonic (MerLin/Perceval) model, a gate-model (Qiskit) model, and a classical MLP baseline.
- Default backend in this repo: MerLin (photonic).
- Dataset and task: CIFAR-10, restricted to the first k labels (e.g., k=5).
- Training: Self-supervised pretraining with InfoNCE on two augmented views (SimCLR-style), followed by linear evaluation with a frozen encoder.
- Models (representation network):
  - MerLin photonic circuit (Perceval + Merlin quantum layer)
  - Qiskit parameterized circuit (via `qSSL/qnn`)
  - Classical MLP baseline
- Metrics: SSL losses over epochs and linear-probing accuracy curves; checkpoints and run metadata saved per experiment.
Pretrained models in `./results` give the following linear-probing accuracies (%):
| Number of epochs | Number of classes (CIFAR10) | Qiskit based | Classical SSL | Quantum SSL (`no_bunching=False`) | Quantum SSL (`no_bunching=True`) |
|---|---|---|---|---|---|
| 2 | 5 | 48.37 ✅ OK #32 x0.08/x0.008 | 48.08 🚫 #144 x1/x1 | 8 modes: 49.22 #184 x0.97/x0.95<br>10 modes: 47.28 #320 x0.89/x0.88<br>12 modes: 46.46 #488 x0.83/x0.65 | 8 modes: 45.58 #184 x0.97/x0.97<br>10 modes: 45.58 #320 x0.97/x0.93<br>12 modes: 45.76 #488 x0.94/x0.82 |
| 5 | 5 | 47.88 | 49.04 | 8 modes: 49.9<br>10 modes: 51.12<br>12 modes: 50.64 | 8 modes: 49.3<br>10 modes: 48.86<br>12 modes: 51.74 |
Legend:
- `#N` — number of trainable parameters
- `x.../x...` — speed-up (relative to the classical baseline)
Overall, we reproduced the results highlighted in the paper, and our photonic implementation of it, using MerLin, is faster and more accurate (but has more trainable parameters).
- `lib/runner.py` — runtime entry point consumed by the repo-level runner
- `lib/` — core library modules used by scripts
  - `data_utils.py` — datasets, transforms (SSL and linear eval)
  - `model.py` — backbone, representation networks (MerLin/Qiskit/Classical), projection head
  - `training_utils.py` — InfoNCE, training loops, metrics and results I/O
  - `defaults.py` — helper to expose `configs/defaults.json` to notebooks/tests
- `configs/` — default configs + CLI schema consumed by the shared runner (`defaults.json`, `cli.json`)
- Other
  - `utils/linear_probing.py` — evaluate frozen features with a linear head. Pretrained models live under `outdir/`
  - `requirements.txt` — Python dependencies
  - `utils/`, `tests/` — placeholders following the template
```bash
python -m venv ssl-venv
source ssl-venv/bin/activate
pip install -r requirements.txt
```

Run with the default MerLin settings from the repository root:
```bash
python implementation.py --paper qSSL --config qSSL/configs/defaults.json
```

Or from inside the project directory:
```bash
cd qSSL
python ../implementation.py --paper qSSL --config configs/defaults.json
```

CLI overrides (mix and match as needed):
```bash
# MerLin (photonic)
python implementation.py --paper qSSL --merlin --classes 5 --modes 10 --epochs 2 --batch_size 256 --ckpt-step 1

# Qiskit (gate-model)
python implementation.py --paper qSSL --qiskit --classes 5 --epochs 2 --batch_size 256 --ckpt-step 1

# Classical baseline
python implementation.py --paper qSSL --classical --classes 5 --epochs 2 --batch_size 256 --ckpt-step 1
```

Need to see every toggle first? Run `python implementation.py --paper qSSL --help` for the auto-generated CLI, including dataset paths, backend switches, and visualization flags.
Data root: CIFAR-10 downloads under `<DATA_DIR>/qSSL` (default `DATA_DIR` env var or `<repo>/data`). Override the base root with `--datadir` if needed; the paper subfolder is added automatically.
See `configs/defaults.json` (overrides are described in `cli.json`). Key fields:

- `dataset`: `root`, `classes`, `batch_size`
- `model`: `backend` (`merlin`|`qiskit`|`classical`), `width`, `loss_dim`, `batch_norm`, `temperature`
  - Qiskit-specific: `layers`, `encoding`, `q_ansatz`, `q_sweeps`, `activation`, `shots`, `q_backend`
  - MerLin-specific: `modes`, `no_bunching`
- `training`: `epochs`, `ckpt_step`, `le_epochs`
You can combine `--config` with CLI overrides. The runner resolves the final configuration and saves it to the results directory (`args.json`).
- Reference weights are hosted on Hugging Face under `Quandela/ReproducedPapersQML/qSSL`. Each run directory mirrors the layout produced locally (checkpoints plus `args.json`).
- `qSSL/utils/linear_probing.py` defaults to the MerLin checkpoint at `merlin/20250827_181840/model-cl-5-epoch-5.pth`. When `--pretrained` is a repo-relative path (or a full HF URL), the script automatically downloads the `.pth` file and matching `args.json`.
- Use `--hf-repo`, `--hf-prefix`, and `--hf-revision` if you need to point to another Hugging Face namespace or branch (defaults are set to `Quandela/ReproducedPapersQML/qSSL`).
- Example:

```bash
python qSSL/utils/linear_probing.py \
    --pretrained merlin/20250827_181840/model-cl-5-epoch-5.pth \
    --hf-repo Quandela/ReproducedPapersQML --hf-prefix qSSL --hf-revision main
```
- SSL pretraining
  - Input: for each image, generate two strong augmentations (query/key) using `TwoCropsTransform`.
  - Backbone: ResNet18 (final FC replaced by Identity).
  - Compression: linear layer down to `width` (a quantum-friendly size).
  - Representation network (choose one): MerLin, Qiskit, or classical MLP.
  - Projection head: MLP to `loss_dim` with BN + ReLU.
  - Loss: InfoNCE (temperature τ) on the two views.
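The contrastive objective above can be sketched as a standalone function. This is illustrative only: the repo's actual loss lives in `lib/training_utils.py` and may differ in normalization or masking details.

```python
import torch
import torch.nn.functional as F

def info_nce(q, k, temperature=0.5):
    """InfoNCE over two views. q, k: (N, D) projections of the same batch."""
    q = F.normalize(q, dim=1)
    k = F.normalize(k, dim=1)
    logits = q @ k.t() / temperature   # (N, N) cosine-similarity matrix
    labels = torch.arange(q.size(0))   # positive pairs sit on the diagonal
    return F.cross_entropy(logits, labels)

views = torch.randn(8, 32)
loss = info_nce(views, views + 0.01 * torch.randn_like(views))
```

Each row's diagonal entry is the similarity between the two views of the same image (the positive), and the other entries in the row act as negatives, which is why a plain cross-entropy over rows implements the contrastive objective.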
- Linear evaluation
- Freeze backbone + compression + representation.
- Train a linear classifier on top using lightly augmented train data and minimal val transforms.
- Report accuracy curves and final/best validation accuracy.
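The freeze-then-probe step above can be sketched as follows; the encoder here is a toy stand-in, not the repo's actual backbone/compression/representation stack (see `qSSL/utils/linear_probing.py` for the real script).

```python
import torch
import torch.nn as nn

# Toy stand-in for the frozen backbone + compression + representation.
encoder = nn.Sequential(nn.Linear(16, 8), nn.ReLU())
for p in encoder.parameters():
    p.requires_grad = False          # freeze everything upstream of the head
encoder.eval()

head = nn.Linear(8, 5)               # 5 classes in the k=5 CIFAR-10 setting
opt = torch.optim.SGD(head.parameters(), lr=0.1)

x, y = torch.randn(32, 16), torch.randint(0, 5, (32,))
with torch.no_grad():
    feats = encoder(x)               # features from the frozen encoder
loss = nn.functional.cross_entropy(head(feats), y)
opt.zero_grad(); loss.backward(); opt.step()
```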
- MerLin (default)
  - Photonic circuit built with Perceval: two trainable interferometers around a phase-encoding layer.
  - Features are sigmoid-normalized and scaled by 1/π to map into phase parameters.
  - Parameters: `modes` (number of photonic modes), `no_bunching` (photon statistics), `width` (input feature size to the circuit), plus trainable circuit phases.
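The sigmoid-plus-1/π normalization can be sketched in isolation. Only the feature scaling is shown; the Merlin/Perceval circuit API itself is not reproduced here.

```python
import math
import torch

def features_to_phases(x):
    # Sigmoid squashes features into (0, 1); the 1/pi factor then scales
    # them into the range fed to the phase-encoding layer.
    return torch.sigmoid(x) / math.pi

phases = features_to_phases(torch.randn(4, 10))  # e.g. a width-10 feature batch
```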
- Qiskit (gate-model)
  - Representation network `QNet` with `n_qubits = width`.
  - Configurable `encoding`, `q_ansatz`, `layers`, `q_sweeps`, `activation`, `shots`, and `q_backend` (e.g., `qasm_simulator`).
- Classical baseline
  - Simple MLP with `args.layers` repetitions of `Linear(width, width)` + LeakyReLU.
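A sketch of what such a baseline might look like in PyTorch (an illustrative reconstruction, not the repo's exact `model.py` code):

```python
import torch
import torch.nn as nn

def make_classical_net(width, layers):
    """Stack `layers` blocks of Linear(width, width) + LeakyReLU."""
    blocks = []
    for _ in range(layers):
        blocks += [nn.Linear(width, width), nn.LeakyReLU()]
    return nn.Sequential(*blocks)

net = make_classical_net(width=8, layers=2)
out = net(torch.randn(4, 8))   # the network preserves the feature width
```

Keeping the input and output dimensions equal to `width` lets the classical baseline slot into the same pipeline position as the quantum representation networks.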
Each invocation writes to `<outdir>/run_YYYYMMDD-HHMMSS/` (default base `outdir/` inside `qSSL/`):

- `config_snapshot.json` — final config after merging defaults, CLI, and extra overrides
- `args.json` — lightweight namespace serialized for backward-compatible tools (e.g., `utils/linear_probing.py`)
- `run.log` — streaming logs from the shared runtime
- `training_metrics.json` — SSL and linear-eval losses/accuracies over epochs
- `experiment_summary.json` — consolidated summary with final and best val accuracy
- `model-cl-<classes>-epoch-<n>.pth` — checkpoints saved every `ckpt_step` epochs
Evaluate pretrained encoders with a frozen representation and train a linear head:
```bash
# Default run (downloads the reference Hugging Face checkpoint)
python qSSL/utils/linear_probing.py

# Evaluate all checkpoints from a local run directory
python qSSL/utils/linear_probing.py --pretrained ./outdir/run_<timestamp>/

# Evaluate a specific local checkpoint file
python qSSL/utils/linear_probing.py --pretrained ./outdir/run_<timestamp>/model-cl-5-epoch-5.pth

# Evaluate any other Hugging Face checkpoint via repo-relative path
python qSSL/utils/linear_probing.py --pretrained merlin/<run_id>/model-cl-5-epoch-5.pth
```

- Original paper: Quantum Self-Supervised Learning — https://arxiv.org/abs/2103.14653
- Portions of the Qiskit pipeline and general approach are inspired by the original authors’ resources where relevant.
- For Qiskit, ensure `qiskit-aer` is installed and the selected backend (e.g., `qasm_simulator`) is available.
Tests live in the `./tests` folder and validate one forward pass through each of the classical, MerLin, and Qiskit models, as well as the InfoNCE loss. Once the environment is installed, run them with:

```bash
python3 -m pytest tests/
```
