From a24825245d41a6a9e1fa2233a11c03634c2039d7 Mon Sep 17 00:00:00 2001 From: Theodor Isacsson Date: Mon, 5 Jan 2026 11:08:08 -0800 Subject: [PATCH 01/39] Update xcode version --- .circleci/config.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index c92ab72..eac4381 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -2,7 +2,7 @@ version: 2.1 orbs: ocean: dwave/ocean@1 - windows: circleci/windows@5.0 + windows: circleci/windows@5.1 # latest as of Jan 2026 environment: PIP_PROGRESS_BAR: 'off' @@ -37,7 +37,7 @@ jobs: executor: name: ocean/macos - xcode: "16.4.0" + xcode: "26.2.0" # latest as of Jan 2026 steps: - checkout From 5283e68aa80f00f3c2409cde8048e07d52e7792d Mon Sep 17 00:00:00 2001 From: Theodor Isacsson Date: Mon, 5 Jan 2026 11:32:54 -0800 Subject: [PATCH 02/39] Skip broken numpy 2.4.0 --- requirements.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/requirements.txt b/requirements.txt index 1583954..2c6c120 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,7 @@ +# skip numpy 2.4.0 due to a bug fixed in numpy#30500 +# TODO: remove the constraint once numpy 2.4.0 gets yanked +numpy~=2.0,!=2.4.0 + torch==2.9.1 dimod==0.12.18 dwave-system==1.28.0 From faf9462727936e0b99d8e99c54a236bd0fd533bf Mon Sep 17 00:00:00 2001 From: kchern Date: Wed, 12 Nov 2025 22:43:50 +0000 Subject: [PATCH 03/39] Add maximum mean discrepancy and radial basis MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Vladimir Vargas Calderón --- dwave/plugins/torch/models/losses/mmd.py | 214 ++++++++++++++++++ ...dd-mmd-loss-function-3fa9e9a2cb452391.yaml | 10 + tests/requirements.txt | 1 + tests/test_dvae_winci2020.py | 67 ++++++ 4 files changed, 292 insertions(+) create mode 100755 dwave/plugins/torch/models/losses/mmd.py create mode 100644 releasenotes/notes/add-mmd-loss-function-3fa9e9a2cb452391.yaml diff --git a/dwave/plugins/torch/models/losses/mmd.py b/dwave/plugins/torch/models/losses/mmd.py new file mode 100755 index 0000000..39075b7 --- /dev/null +++ b/dwave/plugins/torch/models/losses/mmd.py @@ -0,0 +1,214 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +from typing import Optional + +import torch +import torch.nn as nn + +from dwave.plugins.torch.nn.modules.utils import store_config + +__all__ = ["Kernel", "RadialBasisFunction", "mmd_loss", "MMDLoss"] + + +class Kernel(nn.Module): + """Base class for kernels. + + Kernels are functions that compute a similarity measure between data points. Any ``Kernel`` + subclass must implement the ``_kernel`` method, which computes the kernel matrix for a given + input multi-dimensional tensor with shape (n, f1, f2, ...), where n is the number of items + and f1, f2, ... are feature dimensions, so that the output is a tensor of shape (n, n) + containing the pairwise kernel values. + """ + + @abstractmethod + def _kernel(self, x: torch.Tensor) -> torch.Tensor: + """ + Computes the kernel matrix for an input of shape (n, f1, f2, ...), whose shape is (n, n) + containing the pairwise kernel values. + + Args: + x (torch.Tensor): A (n, f1, f2, ...) tensor. + + Returns: + torch.Tensor: A (n, n) tensor. + """ + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Computes kernels for intra and inter set pairs in ``x`` and ``y``. In general, ``x`` and + ``y`` are (n_x, f1, f2, ...) and (n_y, f1, f2, ...) shaped tensors, and the output is a + (n_x + n_y, n_x + n_y) shaped tensor containing the pairwise kernel values. + + Args: + x (torch.Tensor): A (n_x, f1, f2, ...) tensor. + y (torch.Tensor): A (n_y, f1, f2, ...) tensor. + + Returns: + torch.Tensor: A (n_x + n_y, n_x + n_y) tensor. + """ + if x.shape[1:] != y.shape[1:]: + raise ValueError( + "Input dimensions must match. You are trying to compute " + f"the kernel between tensors of shape {x.shape} and {y.shape}." + ) + # Concatenate along batch dimension + xy = torch.cat([x, y], dim=0) + return self._kernel(xy) + + +class RadialBasisFunction(Kernel): + """Radial basis function kernel. + + This kernel between two data points x and y is defined as + :math:`k(x, y) = exp(-||x-y||^2 / (2 * \sigma))`, where :math:`\sigma` is the bandwidth + parameter. + + This implementation considers aggregating multiple radial basis function kernels with different + bandwidths. The bandwidths are determined by multiplying a base bandwidth with a set of + multipliers. The base bandwidth can be provided directly or estimated from the data using the + average distance between samples. + + Args: + num_features (int): Number of kernel bandwidths to use. + mul_factor (int | float): Multiplicative factor to generate bandwidths. The bandwidths are + computed as :math:`\sigma_i = \sigma * mul\_factor^{i - num\_features // 2}` for + :math:`i` in ``[0, num_features - 1]``. Defaults to 2.0. + bandwidth (float | None): Base bandwidth parameter. If None, the bandwidth is estimated + from the data. Defaults to None. + """ + + @store_config + def __init__( + self, num_features: int, mul_factor: int | float = 2.0, bandwidth: Optional[float] = None + ): + super().__init__() + bandwidth_multipliers = mul_factor ** (torch.arange(num_features) - num_features // 2) + self.register_buffer("bandwidth_multipliers", bandwidth_multipliers) + self.bandwidth = bandwidth + + @torch.no_grad() + def _get_bandwidth(self, l2_distance_matrix: torch.Tensor) -> torch.Tensor | float: + """ + Computes the base bandwidth parameter as the average distance between samples if the + bandwidth is not provided during initialization. Otherwise, returns the provided bandwidth. + See https://arxiv.org/abs/1707.07269 for more details about the motivation behind taking + the average distance as the bandwidth. + + Args: + l2_distance_matrix (torch.Tensor): A (n, n) tensor representing the pairwise + L2 distances between samples. If it is None and the bandwidth is not provided, an + error will be raised. Defaults to None. + + Returns: + torch.Tensor | float: The base bandwidth parameter. + """ + if self.bandwidth is None: + num_samples = l2_distance_matrix.shape[0] + return l2_distance_matrix.sum() / (num_samples**2 - num_samples) + return self.bandwidth + + def _kernel(self, x: torch.Tensor) -> torch.Tensor: + """ + Computes the radial basis function kernel as + + .. math:: + k(x, y) = \sum_{i=1}^{num\_features} exp(-||x-y||^2 / (2 * \sigma_i)), + + where :math:`\sigma_i` are the bandwidths. + + Args: + x (torch.Tensor): A (n, f1, f2, ...) tensor. + + Returns: + torch.Tensor: A (n, n) tensor representing the kernel matrix. + """ + distance_matrix = torch.cdist(x, x, p=2) + bandwidth = self._get_bandwidth(distance_matrix.detach()) * self.bandwidth_multipliers + return torch.exp(-distance_matrix.unsqueeze(0) / bandwidth.reshape(-1, 1, 1)).sum(dim=0) + + +def mmd_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: + """ + Computes the maximum mean discrepancy (MMD) loss between two sets of samples x and y. + + This is a two-sample test to test the null hypothesis that the two samples are drawn from the + same distribution (https://dl.acm.org/doi/abs/10.5555/2188385.2188410). The squared MMD is + defined as + + .. math:: + MMD^2(X, Y) = \|E_{x\sim p}[\varphi(x)] - E_{y\sim q}[\varphi(y)] \|^2, + + where :math:`\varphi` is a feature map associated with the kernel function + :math:`k(x, y) = \langle \varphi(x), \varphi(y) \rangle`, and :math:`p` and :math:`q` are the + distributions of the samples. It follows that, in terms of the kernel function, the squared MMD + can be computed as + + .. math:: + E_{x, x'\sim p}[k(x, x')] + E_{y, y'\sim q}[k(y, y')] - 2E_{x\sim p, y\sim q}[k(x, y)]. + + If :math:`p = q`, then :math:`MMD^2(X, Y) = 0`. In machine learning applications, the MMD can be + used as a loss function to compare the distribution of model-generated samples to the + distribution of real data samples to force model-generated samples to match the real data + distribution. + + Args: + x (torch.Tensor): A (n_x, f1, f2, ...) tensor of samples from distribution p. + y (torch.Tensor): A (n_y, f1, f2, ...) tensor of samples from distribution q. + kernel (Kernel): A kernel function object. + + Returns: + torch.Tensor: The computed MMD loss. + """ + num_x = x.shape[0] + num_y = y.shape[0] + kernel_matrix = kernel(x, y) + kernel_xx = kernel_matrix[:num_x, :num_x] + kernel_yy = kernel_matrix[num_x:, num_x:] + kernel_xy = kernel_matrix[:num_x, num_x:] + xx = (kernel_xx.sum() - kernel_xx.trace()) / (num_x * (num_x - 1)) + yy = (kernel_yy.sum() - kernel_yy.trace()) / (num_y * (num_y - 1)) + xy = kernel_xy.sum() / (num_x * num_y) + return xx + yy - 2 * xy + + +class MMDLoss(nn.Module): + """ + Creates a module that computes the maximum mean discrepancy (MMD) loss between two sets of + samples. + + This uses the `mmd_loss` function to compute the loss. + + Args: + kernel (Kernel): A kernel function object. + """ + + @store_config + def __init__(self, kernel: Kernel): + super().__init__() + self.kernel = kernel + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Computes the MMD loss between two sets of samples x and y. + + Args: + x (torch.Tensor): A (n_x, f1, f2, ...) tensor of samples from distribution p. + y (torch.Tensor): A (n_y, f1, f2, ...) tensor of samples from distribution q. + + Returns: + torch.Tensor: The computed MMD loss. + """ + return mmd_loss(x, y, self.kernel) diff --git a/releasenotes/notes/add-mmd-loss-function-3fa9e9a2cb452391.yaml b/releasenotes/notes/add-mmd-loss-function-3fa9e9a2cb452391.yaml new file mode 100644 index 0000000..a431a7d --- /dev/null +++ b/releasenotes/notes/add-mmd-loss-function-3fa9e9a2cb452391.yaml @@ -0,0 +1,10 @@ +--- +features: + - | + MMD loss is available in ``dwave.plugins.torch.models.losses.mmd.mmd_loss``, + which computes the MMD loss using a ``dwave.plugins.torch.models.losses.mmd.Kernel`` + (specialized to the ``dwave.plugins.torch.models.losses.mmd.RBFKernel``). This + enables training encoders in discrete variational autoencoders to match the + distribution of the prior model. + + \ No newline at end of file diff --git a/tests/requirements.txt b/tests/requirements.txt index b2bd102..d7abc8f 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,3 +1,4 @@ coverage codecov parameterized +einops diff --git a/tests/test_dvae_winci2020.py b/tests/test_dvae_winci2020.py index f40ef8f..d0b4393 100644 --- a/tests/test_dvae_winci2020.py +++ b/tests/test_dvae_winci2020.py @@ -15,6 +15,7 @@ import unittest import torch +from einops import repeat from parameterized import parameterized from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine @@ -22,6 +23,7 @@ DiscreteVariationalAutoencoder as DVAE, ) from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss +from dwave.plugins.torch.models.losses.mmd import MMDLoss, RadialBasisFunction, mmd_loss from dwave.samplers import SimulatedAnnealingSampler @@ -84,6 +86,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.dvaes = {i: DVAE(self.encoders[i], self.decoders[i]) for i in latent_dims_list} + # Now we also create a DVAE with a trainable Encoder + def deterministic_latent_to_discrete(logits: torch.Tensor, n_samples: int) -> torch.Tensor: + # straight-through estimator that maps positive logits to 1 and negative logits to -1 + hard = torch.sign(logits) + soft = logits + result = hard - soft.detach() + soft + # Now we need to repeat the result n_samples times along a new dimension + return repeat(result, "b ... -> b n ...", n=n_samples) + + self.dvae_with_trainable_encoder = DVAE( + encoder=torch.nn.Linear(input_features, latent_features), + decoder=Decoder(latent_features, input_features), + latent_to_discrete=deterministic_latent_to_discrete, + ) + + self.fixed_boltzmann_machine = GraphRestrictedBoltzmannMachine( + nodes=(0, 1), + edges=[(0, 1)], + linear={0: 0.0, 1: 0.0}, + quadratic={(0, 1): 0.0}, + ) # Creates a uniform distribution over spin strings of length 2 + self.boltzmann_machine = GraphRestrictedBoltzmannMachine( nodes=(0, 1), edges=[(0, 1)], @@ -110,6 +134,49 @@ def test_mappings(self): # map [0, 1] to [-1, 1]: torch.testing.assert_close(torch.tensor([-1, 1]).float(), discretes[3]) + @parameterized.expand([True, False]) + def test_train_encoder_with_mmd(self, use_mmd_loss_class: bool = False): + """Test training the encoder of the DVAE with MMD loss and fixed decoder and GRBM prior.""" + dvae = self.dvae_with_trainable_encoder + optimiser = torch.optim.SGD(dvae.encoder.parameters(), lr=0.01, momentum=0.9) + kernel = RadialBasisFunction(num_features=5, mul_factor=2.0, bandwidth=None) + # Before training, the encoder will not map data points to the correct spin strings: + expected_set = {(1.0, 1.0), (1.0, -1.0), (-1.0, -1.0), (-1.0, 1.0)} + _, discretes, _ = dvae(self.data, n_samples=1) + discretes = discretes.squeeze(1) + discretes_set = {tuple(row.tolist()) for row in discretes} + self.assertNotEqual(discretes_set, expected_set) + mmd_loss_module = None + # Train the encoder so that the latent distribution matches the prior GRBM distribution + for _ in range(1000): + optimiser.zero_grad() + _, discretes, _ = dvae(self.data, n_samples=1) + discretes = discretes.reshape(discretes.shape[0], -1) + prior_samples = self.fixed_boltzmann_machine.sample( + sampler=self.sampler_sa, + as_tensor=True, + device=discretes.device, + prefactor=1.0, + linear_range=None, + quadratic_range=None, + sample_params=dict(num_sweeps=10, seed=1234, num_reads=100), + ) + if use_mmd_loss_class: + if mmd_loss_module is None: + mmd_loss_module = MMDLoss(kernel) + mmd = mmd_loss_module(discretes, prior_samples) + else: + mmd = mmd_loss(discretes, prior_samples, kernel) + mmd.backward() + optimiser.step() + # After training, the encoder should map data points to spin strings that match the samples + # from the prior GRBM. Since the prior GRBM is uniform over spin strings of length 2, the + # encoder should map the four data points to the four spin strings (in any order). + _, discretes, _ = dvae(self.data, n_samples=1) + discretes = discretes.squeeze(1) + discretes_set = {tuple(row.tolist()) for row in discretes} + self.assertEqual(discretes_set, expected_set) + @parameterized.expand([1, 2]) def test_train(self, n_latent_dims): """Test training simple dataset.""" From c7618740b23b4ecf30fd4e31ef36fa321a4e08f1 Mon Sep 17 00:00:00 2001 From: kchern Date: Tue, 16 Dec 2025 17:45:01 +0000 Subject: [PATCH 04/39] Rename acronyms and fix first-line in docstring --- dwave/plugins/torch/models/losses/mmd.py | 41 ++++++++++++------------ tests/test_dvae_winci2020.py | 6 ++-- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/dwave/plugins/torch/models/losses/mmd.py b/dwave/plugins/torch/models/losses/mmd.py index 39075b7..1fcb737 100755 --- a/dwave/plugins/torch/models/losses/mmd.py +++ b/dwave/plugins/torch/models/losses/mmd.py @@ -20,7 +20,7 @@ from dwave.plugins.torch.nn.modules.utils import store_config -__all__ = ["Kernel", "RadialBasisFunction", "mmd_loss", "MMDLoss"] +__all__ = ["Kernel", "RadialBasisFunction", "maximum_mean_discrepancy", "MaximumMeanDiscrepancy"] class Kernel(nn.Module): @@ -35,26 +35,28 @@ class Kernel(nn.Module): @abstractmethod def _kernel(self, x: torch.Tensor) -> torch.Tensor: - """ + """Perform a pairwise kernel evaluation over samples. + Computes the kernel matrix for an input of shape (n, f1, f2, ...), whose shape is (n, n) containing the pairwise kernel values. Args: - x (torch.Tensor): A (n, f1, f2, ...) tensor. + x (torch.Tensor): A (n, f1, f2, ..., fk) tensor. Returns: torch.Tensor: A (n, n) tensor. """ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """ - Computes kernels for intra and inter set pairs in ``x`` and ``y``. In general, ``x`` and - ``y`` are (n_x, f1, f2, ...) and (n_y, f1, f2, ...) shaped tensors, and the output is a - (n_x + n_y, n_x + n_y) shaped tensor containing the pairwise kernel values. + """Computes kernels for all pairs between and within ``x`` and ``y``. + + In general, ``x`` and ``y`` are (n_x, f1, f2, ..., fk) and (n_y, f1, f2, ..., fk)-shaped + tensors, and the output is a (n_x + n_y, n_x + n_y)-shaped tensor containing the pairwise + kernel values. Args: - x (torch.Tensor): A (n_x, f1, f2, ...) tensor. - y (torch.Tensor): A (n_y, f1, f2, ...) tensor. + x (torch.Tensor): A (n_x, f1, f2, ..., fk) tensor. + y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor. Returns: torch.Tensor: A (n_x + n_y, n_x + n_y) tensor. @@ -64,13 +66,12 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: "Input dimensions must match. You are trying to compute " f"the kernel between tensors of shape {x.shape} and {y.shape}." ) - # Concatenate along batch dimension xy = torch.cat([x, y], dim=0) return self._kernel(xy) class RadialBasisFunction(Kernel): - """Radial basis function kernel. + """The radial basis function kernel. This kernel between two data points x and y is defined as :math:`k(x, y) = exp(-||x-y||^2 / (2 * \sigma))`, where :math:`\sigma` is the bandwidth @@ -101,7 +102,8 @@ def __init__( @torch.no_grad() def _get_bandwidth(self, l2_distance_matrix: torch.Tensor) -> torch.Tensor | float: - """ + """Heuristically determine a bandwidth parameter as the average distance between samples. + Computes the base bandwidth parameter as the average distance between samples if the bandwidth is not provided during initialization. Otherwise, returns the provided bandwidth. See https://arxiv.org/abs/1707.07269 for more details about the motivation behind taking @@ -140,9 +142,8 @@ def _kernel(self, x: torch.Tensor) -> torch.Tensor: return torch.exp(-distance_matrix.unsqueeze(0) / bandwidth.reshape(-1, 1, 1)).sum(dim=0) -def mmd_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: - """ - Computes the maximum mean discrepancy (MMD) loss between two sets of samples x and y. +def maximum_mean_discrepancy(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: + """Computes the maximum mean discrepancy (MMD) loss between two sets of samples ``x`` and ``y``. This is a two-sample test to test the null hypothesis that the two samples are drawn from the same distribution (https://dl.acm.org/doi/abs/10.5555/2188385.2188410). The squared MMD is @@ -184,9 +185,8 @@ def mmd_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: return xx + yy - 2 * xy -class MMDLoss(nn.Module): - """ - Creates a module that computes the maximum mean discrepancy (MMD) loss between two sets of +class MaximumMeanDiscrepancy(nn.Module): + """Creates a module that computes the maximum mean discrepancy (MMD) loss between two sets of samples. This uses the `mmd_loss` function to compute the loss. @@ -201,8 +201,7 @@ def __init__(self, kernel: Kernel): self.kernel = kernel def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """ - Computes the MMD loss between two sets of samples x and y. + """Computes the MMD loss between two sets of samples x and y. Args: x (torch.Tensor): A (n_x, f1, f2, ...) tensor of samples from distribution p. @@ -211,4 +210,4 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: The computed MMD loss. """ - return mmd_loss(x, y, self.kernel) + return maximum_mean_discrepancy(x, y, self.kernel) diff --git a/tests/test_dvae_winci2020.py b/tests/test_dvae_winci2020.py index d0b4393..5ee1c80 100644 --- a/tests/test_dvae_winci2020.py +++ b/tests/test_dvae_winci2020.py @@ -23,7 +23,7 @@ DiscreteVariationalAutoencoder as DVAE, ) from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss -from dwave.plugins.torch.models.losses.mmd import MMDLoss, RadialBasisFunction, mmd_loss +from dwave.plugins.torch.models.losses.mmd import MaximumMeanDiscrepancy, RadialBasisFunction, maximum_mean_discrepancy from dwave.samplers import SimulatedAnnealingSampler @@ -163,10 +163,10 @@ def test_train_encoder_with_mmd(self, use_mmd_loss_class: bool = False): ) if use_mmd_loss_class: if mmd_loss_module is None: - mmd_loss_module = MMDLoss(kernel) + mmd_loss_module = MaximumMeanDiscrepancy(kernel) mmd = mmd_loss_module(discretes, prior_samples) else: - mmd = mmd_loss(discretes, prior_samples, kernel) + mmd = maximum_mean_discrepancy(discretes, prior_samples, kernel) mmd.backward() optimiser.step() # After training, the encoder should map data points to spin strings that match the samples From 54f1ffe2dc434040b237f3135538dd5f64fec0fa Mon Sep 17 00:00:00 2001 From: kchern Date: Tue, 16 Dec 2025 18:18:09 +0000 Subject: [PATCH 05/39] Define kernel as function of two inputs --- dwave/plugins/torch/models/losses/mmd.py | 59 +++++++++++++----------- tests/test_dvae_winci2020.py | 6 +-- 2 files changed, 34 insertions(+), 31 deletions(-) diff --git a/dwave/plugins/torch/models/losses/mmd.py b/dwave/plugins/torch/models/losses/mmd.py index 1fcb737..895a98f 100755 --- a/dwave/plugins/torch/models/losses/mmd.py +++ b/dwave/plugins/torch/models/losses/mmd.py @@ -20,7 +20,8 @@ from dwave.plugins.torch.nn.modules.utils import store_config -__all__ = ["Kernel", "RadialBasisFunction", "maximum_mean_discrepancy", "MaximumMeanDiscrepancy"] +__all__ = ["Kernel", "RadialBasisFunction", + "maximum_mean_discrepancy_loss", "MaximumMeanDiscrepancyLoss"] class Kernel(nn.Module): @@ -34,17 +35,18 @@ class Kernel(nn.Module): """ @abstractmethod - def _kernel(self, x: torch.Tensor) -> torch.Tensor: + def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Perform a pairwise kernel evaluation over samples. Computes the kernel matrix for an input of shape (n, f1, f2, ...), whose shape is (n, n) containing the pairwise kernel values. Args: - x (torch.Tensor): A (n, f1, f2, ..., fk) tensor. + x (torch.Tensor): A (nx, f1, f2, ..., fk) tensor. + y (torch.Tensor): A (ny, f1, f2, ..., fk) tensor. Returns: - torch.Tensor: A (n, n) tensor. + torch.Tensor: A (nx, ny) tensor. """ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -66,8 +68,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: "Input dimensions must match. You are trying to compute " f"the kernel between tensors of shape {x.shape} and {y.shape}." ) - xy = torch.cat([x, y], dim=0) - return self._kernel(xy) + return self._kernel(x, y) class RadialBasisFunction(Kernel): @@ -122,7 +123,7 @@ def _get_bandwidth(self, l2_distance_matrix: torch.Tensor) -> torch.Tensor | flo return l2_distance_matrix.sum() / (num_samples**2 - num_samples) return self.bandwidth - def _kernel(self, x: torch.Tensor) -> torch.Tensor: + def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ Computes the radial basis function kernel as @@ -132,22 +133,21 @@ def _kernel(self, x: torch.Tensor) -> torch.Tensor: where :math:`\sigma_i` are the bandwidths. Args: - x (torch.Tensor): A (n, f1, f2, ...) tensor. + x (torch.Tensor): A (nx, f1, f2, ..., fk) tensor. + y (torch.Tensor): A (ny, f1, f2, ..., fk) tensor. Returns: - torch.Tensor: A (n, n) tensor representing the kernel matrix. + torch.Tensor: A (nx, ny) tensor representing the kernel matrix. """ - distance_matrix = torch.cdist(x, x, p=2) + distance_matrix = torch.cdist(x, y, p=2) bandwidth = self._get_bandwidth(distance_matrix.detach()) * self.bandwidth_multipliers return torch.exp(-distance_matrix.unsqueeze(0) / bandwidth.reshape(-1, 1, 1)).sum(dim=0) -def maximum_mean_discrepancy(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: - """Computes the maximum mean discrepancy (MMD) loss between two sets of samples ``x`` and ``y``. +def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: + """Estimates the squared maximum mean discrepancy (MMD) given two samples ``x`` and ``y``. - This is a two-sample test to test the null hypothesis that the two samples are drawn from the - same distribution (https://dl.acm.org/doi/abs/10.5555/2188385.2188410). The squared MMD is - defined as + The `squared MMD `_ is defined as .. math:: MMD^2(X, Y) = \|E_{x\sim p}[\varphi(x)] - E_{y\sim q}[\varphi(y)] \|^2, @@ -160,22 +160,25 @@ def maximum_mean_discrepancy(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) - .. math:: E_{x, x'\sim p}[k(x, x')] + E_{y, y'\sim q}[k(y, y')] - 2E_{x\sim p, y\sim q}[k(x, y)]. - If :math:`p = q`, then :math:`MMD^2(X, Y) = 0`. In machine learning applications, the MMD can be - used as a loss function to compare the distribution of model-generated samples to the - distribution of real data samples to force model-generated samples to match the real data - distribution. + If :math:`p = q`, then :math:`MMD^2(X, Y) = 0`. This motivates the squared MMD as a loss + function for minimizing the distance between the model distribution and data distribution. + + For more information, see + Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., & Smola, A. (2012). + A kernel two-sample test. The journal of machine learning research, 13(1), 723-773. Args: - x (torch.Tensor): A (n_x, f1, f2, ...) tensor of samples from distribution p. - y (torch.Tensor): A (n_y, f1, f2, ...) tensor of samples from distribution q. + x (torch.Tensor): A (n_x, f1, f2, ..., fk) tensor of samples from distribution p. + y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor of samples from distribution q. kernel (Kernel): A kernel function object. Returns: - torch.Tensor: The computed MMD loss. + torch.Tensor: The squared maximum mean discrepancy estimate. """ num_x = x.shape[0] num_y = y.shape[0] - kernel_matrix = kernel(x, y) + xy = torch.cat([x, y], dim=0) + kernel_matrix = kernel(xy, xy) kernel_xx = kernel_matrix[:num_x, :num_x] kernel_yy = kernel_matrix[num_x:, num_x:] kernel_xy = kernel_matrix[:num_x, num_x:] @@ -185,11 +188,11 @@ def maximum_mean_discrepancy(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) - return xx + yy - 2 * xy -class MaximumMeanDiscrepancy(nn.Module): - """Creates a module that computes the maximum mean discrepancy (MMD) loss between two sets of - samples. +class MaximumMeanDiscrepancyLoss(nn.Module): + """An unbiased estimator for the squared maximum mean discrepancy. - This uses the `mmd_loss` function to compute the loss. + This uses the ``dwave.plugins.torch.nn.functional.maximum_mean_discrepancy_loss`` function to + compute the loss. Args: kernel (Kernel): A kernel function object. @@ -210,4 +213,4 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: The computed MMD loss. """ - return maximum_mean_discrepancy(x, y, self.kernel) + return maximum_mean_discrepancy_loss(x, y, self.kernel) diff --git a/tests/test_dvae_winci2020.py b/tests/test_dvae_winci2020.py index 5ee1c80..9ac14c7 100644 --- a/tests/test_dvae_winci2020.py +++ b/tests/test_dvae_winci2020.py @@ -23,7 +23,7 @@ DiscreteVariationalAutoencoder as DVAE, ) from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss -from dwave.plugins.torch.models.losses.mmd import MaximumMeanDiscrepancy, RadialBasisFunction, maximum_mean_discrepancy +from dwave.plugins.torch.models.losses.mmd import MaximumMeanDiscrepancyLoss, RadialBasisFunction, maximum_mean_discrepancy_loss from dwave.samplers import SimulatedAnnealingSampler @@ -163,10 +163,10 @@ def test_train_encoder_with_mmd(self, use_mmd_loss_class: bool = False): ) if use_mmd_loss_class: if mmd_loss_module is None: - mmd_loss_module = MaximumMeanDiscrepancy(kernel) + mmd_loss_module = MaximumMeanDiscrepancyLoss(kernel) mmd = mmd_loss_module(discretes, prior_samples) else: - mmd = maximum_mean_discrepancy(discretes, prior_samples, kernel) + mmd = maximum_mean_discrepancy_loss(discretes, prior_samples, kernel) mmd.backward() optimiser.step() # After training, the encoder should map data points to spin strings that match the samples From 208d2dc7b853b397ed8ab123bcd8ad4c023303f4 Mon Sep 17 00:00:00 2001 From: kchern Date: Tue, 16 Dec 2025 18:36:04 +0000 Subject: [PATCH 06/39] Refactor MMD into kernels, functional, and loss --- dwave/plugins/torch/nn/functional.py | 71 ++++++++++++++ .../losses/mmd.py => nn/modules/kernels.py} | 92 ++----------------- dwave/plugins/torch/nn/modules/loss.py | 56 +++++++++++ tests/test_dvae_winci2020.py | 15 +-- 4 files changed, 145 insertions(+), 89 deletions(-) create mode 100755 dwave/plugins/torch/nn/functional.py rename dwave/plugins/torch/{models/losses/mmd.py => nn/modules/kernels.py} (59%) create mode 100755 dwave/plugins/torch/nn/modules/loss.py diff --git a/dwave/plugins/torch/nn/functional.py b/dwave/plugins/torch/nn/functional.py new file mode 100755 index 0000000..3399632 --- /dev/null +++ b/dwave/plugins/torch/nn/functional.py @@ -0,0 +1,71 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Functional interface.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dwave.plugins.torch.nn.modules.kernels import Kernel + +import torch + +__all__ = ["maximum_mean_discrepancy_loss"] + + +def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: + """Estimates the squared maximum mean discrepancy (MMD) given two samples ``x`` and ``y``. + + The `squared MMD `_ is defined as + + .. math:: + MMD^2(X, Y) = \|E_{x\sim p}[\varphi(x)] - E_{y\sim q}[\varphi(y)] \|^2, + + where :math:`\varphi` is a feature map associated with the kernel function + :math:`k(x, y) = \langle \varphi(x), \varphi(y) \rangle`, and :math:`p` and :math:`q` are the + distributions of the samples. It follows that, in terms of the kernel function, the squared MMD + can be computed as + + .. math:: + E_{x, x'\sim p}[k(x, x')] + E_{y, y'\sim q}[k(y, y')] - 2E_{x\sim p, y\sim q}[k(x, y)]. + + If :math:`p = q`, then :math:`MMD^2(X, Y) = 0`. This motivates the squared MMD as a loss + function for minimizing the distance between the model distribution and data distribution. + + For more information, see + Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., & Smola, A. (2012). + A kernel two-sample test. The journal of machine learning research, 13(1), 723-773. + + Args: + x (torch.Tensor): A (n_x, f1, f2, ..., fk) tensor of samples from distribution p. + y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor of samples from distribution q. + kernel (Kernel): A kernel function object. + + Returns: + torch.Tensor: The squared maximum mean discrepancy estimate. + """ + num_x = x.shape[0] + num_y = y.shape[0] + xy = torch.cat([x, y], dim=0) + kernel_matrix = kernel(xy, xy) + kernel_xx = kernel_matrix[:num_x, :num_x] + kernel_yy = kernel_matrix[num_x:, num_x:] + kernel_xy = kernel_matrix[:num_x, num_x:] + xx = (kernel_xx.sum() - kernel_xx.trace()) / (num_x * (num_x - 1)) + yy = (kernel_yy.sum() - kernel_yy.trace()) / (num_y * (num_y - 1)) + xy = kernel_xy.sum() / (num_x * num_y) + return xx + yy - 2 * xy + +torch.nn.MSELoss \ No newline at end of file diff --git a/dwave/plugins/torch/models/losses/mmd.py b/dwave/plugins/torch/nn/modules/kernels.py similarity index 59% rename from dwave/plugins/torch/models/losses/mmd.py rename to dwave/plugins/torch/nn/modules/kernels.py index 895a98f..e13ab1d 100755 --- a/dwave/plugins/torch/models/losses/mmd.py +++ b/dwave/plugins/torch/nn/modules/kernels.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Kernel functions.""" from abc import abstractmethod from typing import Optional @@ -20,18 +21,17 @@ from dwave.plugins.torch.nn.modules.utils import store_config -__all__ = ["Kernel", "RadialBasisFunction", - "maximum_mean_discrepancy_loss", "MaximumMeanDiscrepancyLoss"] +__all__ = ["Kernel", "RadialBasisFunction"] class Kernel(nn.Module): """Base class for kernels. - Kernels are functions that compute a similarity measure between data points. Any ``Kernel`` - subclass must implement the ``_kernel`` method, which computes the kernel matrix for a given - input multi-dimensional tensor with shape (n, f1, f2, ...), where n is the number of items - and f1, f2, ... are feature dimensions, so that the output is a tensor of shape (n, n) - containing the pairwise kernel values. + `Kernels `_ are functions that compute a similarity + measure between data points. Any ``Kernel`` subclass must implement the ``_kernel`` method, + which computes the kernel matrix for a given input multi-dimensional tensor with shape + (n, f1, f2, ...), where n is the number of items and f1, f2, ... are feature dimensions, so that + the output is a tensor of shape (n, n) containing the pairwise kernel values. """ @abstractmethod @@ -52,9 +52,9 @@ def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Computes kernels for all pairs between and within ``x`` and ``y``. - In general, ``x`` and ``y`` are (n_x, f1, f2, ..., fk) and (n_y, f1, f2, ..., fk)-shaped - tensors, and the output is a (n_x + n_y, n_x + n_y)-shaped tensor containing the pairwise - kernel values. + In general, ``x`` and ``y`` are (n_x, f1, f2, ..., fk)- and (n_y, f1, f2, ..., fk)-shaped + tensors, and the output is a (n_x + n_y, n_x + n_y)-shaped tensor containing pairwise kernel + evaluations. Args: x (torch.Tensor): A (n_x, f1, f2, ..., fk) tensor. @@ -142,75 +142,3 @@ def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: distance_matrix = torch.cdist(x, y, p=2) bandwidth = self._get_bandwidth(distance_matrix.detach()) * self.bandwidth_multipliers return torch.exp(-distance_matrix.unsqueeze(0) / bandwidth.reshape(-1, 1, 1)).sum(dim=0) - - -def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: - """Estimates the squared maximum mean discrepancy (MMD) given two samples ``x`` and ``y``. - - The `squared MMD `_ is defined as - - .. math:: - MMD^2(X, Y) = \|E_{x\sim p}[\varphi(x)] - E_{y\sim q}[\varphi(y)] \|^2, - - where :math:`\varphi` is a feature map associated with the kernel function - :math:`k(x, y) = \langle \varphi(x), \varphi(y) \rangle`, and :math:`p` and :math:`q` are the - distributions of the samples. It follows that, in terms of the kernel function, the squared MMD - can be computed as - - .. math:: - E_{x, x'\sim p}[k(x, x')] + E_{y, y'\sim q}[k(y, y')] - 2E_{x\sim p, y\sim q}[k(x, y)]. - - If :math:`p = q`, then :math:`MMD^2(X, Y) = 0`. This motivates the squared MMD as a loss - function for minimizing the distance between the model distribution and data distribution. - - For more information, see - Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., & Smola, A. (2012). - A kernel two-sample test. The journal of machine learning research, 13(1), 723-773. - - Args: - x (torch.Tensor): A (n_x, f1, f2, ..., fk) tensor of samples from distribution p. - y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor of samples from distribution q. - kernel (Kernel): A kernel function object. - - Returns: - torch.Tensor: The squared maximum mean discrepancy estimate. - """ - num_x = x.shape[0] - num_y = y.shape[0] - xy = torch.cat([x, y], dim=0) - kernel_matrix = kernel(xy, xy) - kernel_xx = kernel_matrix[:num_x, :num_x] - kernel_yy = kernel_matrix[num_x:, num_x:] - kernel_xy = kernel_matrix[:num_x, num_x:] - xx = (kernel_xx.sum() - kernel_xx.trace()) / (num_x * (num_x - 1)) - yy = (kernel_yy.sum() - kernel_yy.trace()) / (num_y * (num_y - 1)) - xy = kernel_xy.sum() / (num_x * num_y) - return xx + yy - 2 * xy - - -class MaximumMeanDiscrepancyLoss(nn.Module): - """An unbiased estimator for the squared maximum mean discrepancy. - - This uses the ``dwave.plugins.torch.nn.functional.maximum_mean_discrepancy_loss`` function to - compute the loss. - - Args: - kernel (Kernel): A kernel function object. - """ - - @store_config - def __init__(self, kernel: Kernel): - super().__init__() - self.kernel = kernel - - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """Computes the MMD loss between two sets of samples x and y. - - Args: - x (torch.Tensor): A (n_x, f1, f2, ...) tensor of samples from distribution p. - y (torch.Tensor): A (n_y, f1, f2, ...) tensor of samples from distribution q. - - Returns: - torch.Tensor: The computed MMD loss. - """ - return maximum_mean_discrepancy_loss(x, y, self.kernel) diff --git a/dwave/plugins/torch/nn/modules/loss.py b/dwave/plugins/torch/nn/modules/loss.py new file mode 100755 index 0000000..eed4b28 --- /dev/null +++ b/dwave/plugins/torch/nn/modules/loss.py @@ -0,0 +1,56 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +import torch.nn as nn + +from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss +from dwave.plugins.torch.nn.modules.utils import store_config + +if TYPE_CHECKING: + from dwave.plugins.torch.nn.modules.kernels import Kernel + +__all__ = ["MaximumMeanDiscrepancyLoss"] + + +class MaximumMeanDiscrepancyLoss(nn.Module): + """An unbiased estimator for the squared maximum mean discrepancy as a loss function. + + This uses the ``dwave.plugins.torch.nn.functional.maximum_mean_discrepancy_loss`` function to + compute the loss. + + Args: + kernel (Kernel): A kernel function object. + """ + + @store_config + def __init__(self, kernel: Kernel): + super().__init__() + self.kernel = kernel + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Computes the MMD loss between two sets of samples x and y. + + Args: + x (torch.Tensor): A (n_x, f1, f2, ...) tensor of samples from distribution p. + y (torch.Tensor): A (n_y, f1, f2, ...) tensor of samples from distribution q. + + Returns: + torch.Tensor: The computed MMD loss. + """ + return mmd_loss(x, y, self.kernel) diff --git a/tests/test_dvae_winci2020.py b/tests/test_dvae_winci2020.py index 9ac14c7..13f962e 100644 --- a/tests/test_dvae_winci2020.py +++ b/tests/test_dvae_winci2020.py @@ -19,11 +19,12 @@ from parameterized import parameterized from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine -from dwave.plugins.torch.models.discrete_variational_autoencoder import ( - DiscreteVariationalAutoencoder as DVAE, -) +from dwave.plugins.torch.models.discrete_variational_autoencoder import \ + DiscreteVariationalAutoencoder as DVAE from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss -from dwave.plugins.torch.models.losses.mmd import MaximumMeanDiscrepancyLoss, RadialBasisFunction, maximum_mean_discrepancy_loss +from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss +from dwave.plugins.torch.nn.modules.kernels import RadialBasisFunction as RBF +from dwave.plugins.torch.nn.modules.loss import MaximumMeanDiscrepancyLoss as MMDLoss from dwave.samplers import SimulatedAnnealingSampler @@ -139,7 +140,7 @@ def test_train_encoder_with_mmd(self, use_mmd_loss_class: bool = False): """Test training the encoder of the DVAE with MMD loss and fixed decoder and GRBM prior.""" dvae = self.dvae_with_trainable_encoder optimiser = torch.optim.SGD(dvae.encoder.parameters(), lr=0.01, momentum=0.9) - kernel = RadialBasisFunction(num_features=5, mul_factor=2.0, bandwidth=None) + kernel = RBF(num_features=5, mul_factor=2.0, bandwidth=None) # Before training, the encoder will not map data points to the correct spin strings: expected_set = {(1.0, 1.0), (1.0, -1.0), (-1.0, -1.0), (-1.0, 1.0)} _, discretes, _ = dvae(self.data, n_samples=1) @@ -163,10 +164,10 @@ def test_train_encoder_with_mmd(self, use_mmd_loss_class: bool = False): ) if use_mmd_loss_class: if mmd_loss_module is None: - mmd_loss_module = MaximumMeanDiscrepancyLoss(kernel) + mmd_loss_module = MMDLoss(kernel) mmd = mmd_loss_module(discretes, prior_samples) else: - mmd = maximum_mean_discrepancy_loss(discretes, prior_samples, kernel) + mmd = mmd_loss(discretes, prior_samples, kernel) mmd.backward() optimiser.step() # After training, the encoder should map data points to spin strings that match the samples From 550b6f426cdfad569b29383bda9f7320397c8d72 Mon Sep 17 00:00:00 2001 From: kchern Date: Tue, 16 Dec 2025 19:39:43 +0000 Subject: [PATCH 07/39] Add unit tests for kernels --- dwave/plugins/torch/nn/modules/kernels.py | 38 ++++---- tests/test_dvae_winci2020.py | 2 +- tests/test_kernels.py | 102 ++++++++++++++++++++++ 3 files changed, 121 insertions(+), 21 deletions(-) create mode 100755 tests/test_kernels.py diff --git a/dwave/plugins/torch/nn/modules/kernels.py b/dwave/plugins/torch/nn/modules/kernels.py index e13ab1d..aed869a 100755 --- a/dwave/plugins/torch/nn/modules/kernels.py +++ b/dwave/plugins/torch/nn/modules/kernels.py @@ -14,7 +14,6 @@ """Kernel functions.""" from abc import abstractmethod -from typing import Optional import torch import torch.nn as nn @@ -84,25 +83,25 @@ class RadialBasisFunction(Kernel): average distance between samples. Args: - num_features (int): Number of kernel bandwidths to use. - mul_factor (int | float): Multiplicative factor to generate bandwidths. The bandwidths are - computed as :math:`\sigma_i = \sigma * mul\_factor^{i - num\_features // 2}` for - :math:`i` in ``[0, num_features - 1]``. Defaults to 2.0. - bandwidth (float | None): Base bandwidth parameter. If None, the bandwidth is estimated - from the data. Defaults to None. + n_kernels (int): Number of kernel bandwidths to use. + factor (int | float): Multiplicative factor to generate bandwidths. The bandwidths are + computed as :math:`\sigma_i = \sigma * factor^{i - n\_kernels // 2}` for + :math:`i` in ``[0, n\_kernels - 1]``. Defaults to 2.0. + bandwidth (float | None): Base bandwidth parameter. If ``None``, the bandwidth is computed + from the data (without gradients). Defaults to ``None``. """ @store_config def __init__( - self, num_features: int, mul_factor: int | float = 2.0, bandwidth: Optional[float] = None + self, n_kernels: int, factor: int | float = 2.0, bandwidth: float | None = None ): super().__init__() - bandwidth_multipliers = mul_factor ** (torch.arange(num_features) - num_features // 2) - self.register_buffer("bandwidth_multipliers", bandwidth_multipliers) + factors = factor ** (torch.arange(n_kernels) - n_kernels // 2) + self.register_buffer("factors", factors) self.bandwidth = bandwidth @torch.no_grad() - def _get_bandwidth(self, l2_distance_matrix: torch.Tensor) -> torch.Tensor | float: + def _get_bandwidth(self, distance_matrix: torch.Tensor) -> torch.Tensor | float: """Heuristically determine a bandwidth parameter as the average distance between samples. Computes the base bandwidth parameter as the average distance between samples if the @@ -111,21 +110,20 @@ def _get_bandwidth(self, l2_distance_matrix: torch.Tensor) -> torch.Tensor | flo the average distance as the bandwidth. Args: - l2_distance_matrix (torch.Tensor): A (n, n) tensor representing the pairwise - L2 distances between samples. If it is None and the bandwidth is not provided, an - error will be raised. Defaults to None. + distance_matrix (torch.Tensor): A (n, n) tensor representing the pairwise + L2 distances between samples. If it is ``None`` and the bandwidth is not provided, + an error will be raised. Defaults to ``None``. Returns: torch.Tensor | float: The base bandwidth parameter. """ if self.bandwidth is None: - num_samples = l2_distance_matrix.shape[0] - return l2_distance_matrix.sum() / (num_samples**2 - num_samples) + num_samples = distance_matrix.shape[0] + return distance_matrix.sum() / (num_samples**2 - num_samples) return self.bandwidth def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """ - Computes the radial basis function kernel as + """Compute the radial basis function kernel between ``x`` and ``y``. .. math:: k(x, y) = \sum_{i=1}^{num\_features} exp(-||x-y||^2 / (2 * \sigma_i)), @@ -139,6 +137,6 @@ def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: A (nx, ny) tensor representing the kernel matrix. """ - distance_matrix = torch.cdist(x, y, p=2) - bandwidth = self._get_bandwidth(distance_matrix.detach()) * self.bandwidth_multipliers + distance_matrix = torch.cdist(x.flatten(1), y.flatten(1), p=2) + bandwidth = self._get_bandwidth(distance_matrix.detach()) * self.factors return torch.exp(-distance_matrix.unsqueeze(0) / bandwidth.reshape(-1, 1, 1)).sum(dim=0) diff --git a/tests/test_dvae_winci2020.py b/tests/test_dvae_winci2020.py index 13f962e..efdaeaa 100644 --- a/tests/test_dvae_winci2020.py +++ b/tests/test_dvae_winci2020.py @@ -140,7 +140,7 @@ def test_train_encoder_with_mmd(self, use_mmd_loss_class: bool = False): """Test training the encoder of the DVAE with MMD loss and fixed decoder and GRBM prior.""" dvae = self.dvae_with_trainable_encoder optimiser = torch.optim.SGD(dvae.encoder.parameters(), lr=0.01, momentum=0.9) - kernel = RBF(num_features=5, mul_factor=2.0, bandwidth=None) + kernel = RBF(n_kernels=5, factor=2.0, bandwidth=None) # Before training, the encoder will not map data points to the correct spin strings: expected_set = {(1.0, 1.0), (1.0, -1.0), (-1.0, -1.0), (-1.0, 1.0)} _, discretes, _ = dvae(self.data, n_samples=1) diff --git a/tests/test_kernels.py b/tests/test_kernels.py new file mode 100755 index 0000000..a507885 --- /dev/null +++ b/tests/test_kernels.py @@ -0,0 +1,102 @@ +import unittest + +import torch +from parameterized import parameterized + +from dwave.plugins.torch.nn.modules.kernels import Kernel +from dwave.plugins.torch.nn.modules.kernels import RadialBasisFunction as RBF + + +class TestKernel(unittest.TestCase): + def test_forward(self): + class One(Kernel): + def _kernel(self, x, y): + return 1 + one = One() + x = torch.rand((5, 3)) + y = torch.randn((9, 3)) + self.assertEqual(1, one(x, y)) + + def test_shape_mismatch(self): + class One(Kernel): + def _kernel(self, x, y): + return 1 + one = One() + x = torch.rand((5, 4)) + y = torch.randn((9, 3)) + self.assertRaises(ValueError, one, x, y) + + +class TestRadialBasisFunction(unittest.TestCase): + + def test_has_config(self): + rbf = RBF(5, 2.1, 0.1) + self.assertDictEqual(dict(rbf.config), dict(module_name="RadialBasisFunction", + n_kernels=5, factor=2.1, bandwidth=0.1)) + + @parameterized.expand([ + (torch.randn((5, 12)), torch.rand((7, 12))), + (torch.randn((5, 12, 34)), torch.rand((7, 12, 34))), + ]) + def test_shape(self, x, y): + rbf = RBF(2, 2.1, 0.1) + k = rbf(x, y) + self.assertEqual(tuple(k.shape), (x.shape[0], y.shape[0])) + + def test_get_bandwidth_default(self): + rbf = RBF(2, 2.1, 0.1) + d = torch.tensor(123) + self.assertEqual(0.1, rbf._get_bandwidth(d)) + + def test_get_bandwidth(self): + rbf = RBF(2, 2.1, None) + d = torch.tensor([[0.0, 3.4,], [3.4, 0.0]]) + self.assertEqual(3.4, rbf._get_bandwidth(d)) + + def test_get_bandwidth_no_grad(self): + rbf = RBF(2, 2.1, None) + d = torch.tensor([[0.0, 3.4,], [3.4, 0.0]], requires_grad=True) + self.assertEqual(3.4, rbf._get_bandwidth(d)) + self.assertIsNone(rbf._get_bandwidth(d).grad) + + def test_single_factors(self): + rbf = RBF(1, 2.1, None) + self.assertListEqual(rbf.factors.tolist(), [1.0]) + + def test_two_factors(self): + rbf = RBF(2, 2.1, None) + torch.testing.assert_close(torch.tensor([2.1**-1, 1]), rbf.factors) + + def test_three_factors(self): + rbf = RBF(3, 2.1, None) + torch.testing.assert_close(torch.tensor([2.1**-1, 1, 2.1]), rbf.factors) + + def test_kernel(self): + x = torch.tensor([[1.0, 1.0], + [2.0, 3.0]], requires_grad=True) + y = torch.tensor([[0.0, 1.0], + [-3.0, 5.0], + [1.2, 9.0]], requires_grad=True) + dist = torch.cdist(x, y) + + with self.subTest("Adaptive bandwidth"): + rbf = RBF(1, 2.1, None) + bandwidths = rbf._get_bandwidth(dist) * rbf.factors + manual = torch.exp(-dist/bandwidths) + torch.testing.assert_close(manual, rbf(x, y)) + + with self.subTest("Simple bandwidth"): + rbf = RBF(1, 2.1, 12.34) + bandwidths = 12.34 * rbf.factors + manual = torch.exp(-dist/bandwidths) + torch.testing.assert_close(manual, rbf(x, y)) + + with self.subTest("Multiple kernels"): + rbf = RBF(3, 2.1, 123) + bandwidths = rbf._get_bandwidth(dist) * rbf.factors + manual = torch.exp(-dist/bandwidths.reshape(-1, 1, 1)).sum(0) + torch.testing.assert_close(manual, rbf(x, y)) + + +if __name__ == "__main__": + unittest.main() From decd59037e7af95adbc804baa5b7b77a6ab1bb05 Mon Sep 17 00:00:00 2001 From: kchern Date: Tue, 16 Dec 2025 23:19:33 +0000 Subject: [PATCH 08/39] Add tests for functional and loss add errors --- dwave/plugins/torch/nn/functional.py | 26 +++++++- dwave/plugins/torch/nn/modules/kernels.py | 7 +- tests/test_functional.py | 80 +++++++++++++++++++++++ tests/test_kernels.py | 4 +- tests/test_loss.py | 47 +++++++++++++ tests/test_nn.py | 13 ++++ 6 files changed, 170 insertions(+), 7 deletions(-) create mode 100755 tests/test_functional.py create mode 100755 tests/test_loss.py diff --git a/dwave/plugins/torch/nn/functional.py b/dwave/plugins/torch/nn/functional.py index 3399632..b8b250b 100755 --- a/dwave/plugins/torch/nn/functional.py +++ b/dwave/plugins/torch/nn/functional.py @@ -25,13 +25,21 @@ __all__ = ["maximum_mean_discrepancy_loss"] +class SampleSizeError(ValueError): + pass + + +class DimensionMismatchError(ValueError): + pass + + def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: """Estimates the squared maximum mean discrepancy (MMD) given two samples ``x`` and ``y``. The `squared MMD `_ is defined as .. math:: - MMD^2(X, Y) = \|E_{x\sim p}[\varphi(x)] - E_{y\sim q}[\varphi(y)] \|^2, + MMD^2(X, Y) = |E_{x\sim p}[\varphi(x)] - E_{y\sim q}[\varphi(y)] |^2, where :math:`\varphi` is a feature map associated with the kernel function :math:`k(x, y) = \langle \varphi(x), \varphi(y) \rangle`, and :math:`p` and :math:`q` are the @@ -53,11 +61,25 @@ def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kern y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor of samples from distribution q. kernel (Kernel): A kernel function object. + Raises: + SampleSizeError: If the sample size of ``x`` or ``y`` is less than two. + DimensionMismatchError: If shape of ``x`` and ``y`` mismatch (excluding batch size) + Returns: torch.Tensor: The squared maximum mean discrepancy estimate. """ num_x = x.shape[0] num_y = y.shape[0] + if num_x < 2 or num_y < 2: + raise SampleSizeError( + "Sample size of ``x`` and ``y`` must be at least two. " + f"Got, respectively, {x.shape} and {y.shape}." + ) + if x.shape[1:] != y.shape[1:]: + raise DimensionMismatchError( + "Input dimensions must match. You are trying to compute " + f"the kernel between tensors of shape {x.shape} and {y.shape}." + ) xy = torch.cat([x, y], dim=0) kernel_matrix = kernel(xy, xy) kernel_xx = kernel_matrix[:num_x, :num_x] @@ -67,5 +89,3 @@ def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kern yy = (kernel_yy.sum() - kernel_yy.trace()) / (num_y * (num_y - 1)) xy = kernel_xy.sum() / (num_x * num_y) return xx + yy - 2 * xy - -torch.nn.MSELoss \ No newline at end of file diff --git a/dwave/plugins/torch/nn/modules/kernels.py b/dwave/plugins/torch/nn/modules/kernels.py index aed869a..61e9c21 100755 --- a/dwave/plugins/torch/nn/modules/kernels.py +++ b/dwave/plugins/torch/nn/modules/kernels.py @@ -18,6 +18,7 @@ import torch import torch.nn as nn +from dwave.plugins.torch.nn.functional import DimensionMismatchError from dwave.plugins.torch.nn.modules.utils import store_config __all__ = ["Kernel", "RadialBasisFunction"] @@ -32,7 +33,6 @@ class Kernel(nn.Module): (n, f1, f2, ...), where n is the number of items and f1, f2, ... are feature dimensions, so that the output is a tensor of shape (n, n) containing the pairwise kernel values. """ - @abstractmethod def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Perform a pairwise kernel evaluation over samples. @@ -59,11 +59,14 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: x (torch.Tensor): A (n_x, f1, f2, ..., fk) tensor. y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor. + Raises: + DimensionMismatchError: If shape of ``x`` and ``y`` mismatch (excluding batch size) + Returns: torch.Tensor: A (n_x + n_y, n_x + n_y) tensor. """ if x.shape[1:] != y.shape[1:]: - raise ValueError( + raise DimensionMismatchError( "Input dimensions must match. You are trying to compute " f"the kernel between tensors of shape {x.shape} and {y.shape}." ) diff --git a/tests/test_functional.py b/tests/test_functional.py new file mode 100755 index 0000000..da85182 --- /dev/null +++ b/tests/test_functional.py @@ -0,0 +1,80 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import torch + +from dwave.plugins.torch.nn.functional import SampleSizeError +from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss +from dwave.plugins.torch.nn.modules.kernels import DimensionMismatchError, Kernel + + +class TestMaximumMeanDiscrepancyLoss(unittest.TestCase): + def test_mmd_loss_constant(self): + x = torch.tensor([[1.2], [4.1]]) + y = torch.tensor([[0.3], [0.5]]) + + class Constant(Kernel): + def __init__(self): + super().__init__() + self.k = torch.tensor([[10, 4, 0, 1], + [4, 10, 4, 2], + [0, 4, 10, 3], + [1, 2, 3, 10]]).float() + + def _kernel(self, x, y): + return self.k + # The resulting kernel matrix will be constant, so (averages) KXX = KYY = 2KXY + kernel = Constant() + # kxx = (4 + 4)/2 + # kyy = (3 + 3)/2 + # kxy = (0 + 1 + 4 + 2)/4 + # kxx + kyy -2kxy = 4 + 3 - 3.5 = 3.5 + self.assertEqual(3.5, mmd_loss(x, y, kernel)) + + def test_sample_size_error(self): + x = torch.tensor([[1.2], [4.1]]) + y = torch.tensor([[0.3]]) + self.assertRaises(SampleSizeError, mmd_loss, x, y, None) + + def test_mmd_loss_dim_mismatch(self): + x = torch.tensor([[1], [4]], dtype=torch.float32) + y = torch.tensor([[0.1, 0.2, 0.3], + [0.4, 0.5, 0.6]]) + self.assertRaises(DimensionMismatchError, mmd_loss, x, y, None) + + def test_mmd_loss_arange(self): + x = torch.tensor([[1.0], [4.0], [5.0]]) + y = torch.tensor([[0.3], [0.4]]) + + class Constant(Kernel): + def _kernel(self, x, y): + return torch.tensor([[150, 22, 39, 34, 28], + [22, 630, 98, 56, 44], + [39, 98, 560, 78, 33], + [-99, -99, -99, 299, 13], + [-99, -99, -99, 13, 970]], dtype=torch.float32) + + mmd_loss(x, y, Constant()) + # NOTE: calculation takes kxy = upper-right corner; no PSD assumption + # kxx = (22+39+98)/3 + # kyy = 13 + # kxy = (34+28+56+44+78+33)/6 + # kxx + kyy - 2*kxy + # kxx + kyy - 2*kxy = -25.0 + self.assertEqual(-25, mmd_loss(x, y, Constant())) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_kernels.py b/tests/test_kernels.py index a507885..7b29d72 100755 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -3,7 +3,7 @@ import torch from parameterized import parameterized -from dwave.plugins.torch.nn.modules.kernels import Kernel +from dwave.plugins.torch.nn.modules.kernels import DimensionMismatchError, Kernel from dwave.plugins.torch.nn.modules.kernels import RadialBasisFunction as RBF @@ -24,7 +24,7 @@ def _kernel(self, x, y): one = One() x = torch.rand((5, 4)) y = torch.randn((9, 3)) - self.assertRaises(ValueError, one, x, y) + self.assertRaises(DimensionMismatchError, one, x, y) class TestRadialBasisFunction(unittest.TestCase): diff --git a/tests/test_loss.py b/tests/test_loss.py new file mode 100755 index 0000000..e59dbe3 --- /dev/null +++ b/tests/test_loss.py @@ -0,0 +1,47 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import torch +from parameterized import parameterized + +from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss +from dwave.plugins.torch.nn.modules.kernels import Kernel +from dwave.plugins.torch.nn.modules.loss import MaximumMeanDiscrepancyLoss as MMDLoss + + +class TestMaximumMeanDiscrepancyLoss(unittest.TestCase): + @parameterized.expand([ + (torch.tensor([[1.2], [4.1]]), torch.tensor([[0.3], [0.5]])), + (torch.randn((123, 4, 3, 2)), torch.rand(100, 4, 3, 2)), + ]) + def test_mmd_loss(self, x, y): + class Constant(Kernel): + def __init__(self): + super().__init__() + self.k = torch.tensor([[10, 4, 0, 1], + [4, 10, 4, 2], + [0, 4, 10, 3], + [1, 2, 3, 10]]).float() + + def _kernel(self, x, y): + return self.k + # The resulting kernel matrix will be constant, so (averages) KXX = KYY = 2KXY + kernel = Constant() + compute_mmd = MMDLoss(kernel) + torch.testing.assert_close(mmd_loss(x, y, kernel), compute_mmd(x, y)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_nn.py b/tests/test_nn.py index c84929d..bac40c9 100755 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -1,3 +1,16 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import unittest import torch From 69a3072c1846a9043fdf360e952e7338bbf95551 Mon Sep 17 00:00:00 2001 From: kchern Date: Tue, 16 Dec 2025 23:45:26 +0000 Subject: [PATCH 09/39] Update release note --- .../add-mmd-loss-function-3fa9e9a2cb452391.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/releasenotes/notes/add-mmd-loss-function-3fa9e9a2cb452391.yaml b/releasenotes/notes/add-mmd-loss-function-3fa9e9a2cb452391.yaml index a431a7d..46ea631 100644 --- a/releasenotes/notes/add-mmd-loss-function-3fa9e9a2cb452391.yaml +++ b/releasenotes/notes/add-mmd-loss-function-3fa9e9a2cb452391.yaml @@ -1,10 +1,10 @@ --- features: - | - MMD loss is available in ``dwave.plugins.torch.models.losses.mmd.mmd_loss``, - which computes the MMD loss using a ``dwave.plugins.torch.models.losses.mmd.Kernel`` - (specialized to the ``dwave.plugins.torch.models.losses.mmd.RBFKernel``). This - enables training encoders in discrete variational autoencoders to match the - distribution of the prior model. + Add a ``MaximumMeanDiscrepancyLoss`` in ``dwave.plugins.torch.nn.loss`` for estimating the + squared maximum mean discrepancy (MMD) for a given kernel and two samples. + Its functional counterpart ``maximum_mean_discrepancy_loss`` is in + ``dwave.plugins.torch.nn.functional``. + Kernels reside in ``dwave.plugins.torch.nn.modules.kernels``. This enables, for example, + training discrete autoencoders to match the distribution of a target distribution (e.g., prior). - \ No newline at end of file From bd31d7d1972fed7d02f38b693d3abc3e2fc04d5f Mon Sep 17 00:00:00 2001 From: kchern Date: Wed, 17 Dec 2025 18:49:33 +0000 Subject: [PATCH 10/39] Rename RBF to GaussianKernel --- dwave/plugins/torch/nn/modules/kernels.py | 10 ++++---- tests/test_dvae_winci2020.py | 2 +- tests/test_kernels.py | 29 +++++++++++------------ 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/dwave/plugins/torch/nn/modules/kernels.py b/dwave/plugins/torch/nn/modules/kernels.py index 61e9c21..2d614d5 100755 --- a/dwave/plugins/torch/nn/modules/kernels.py +++ b/dwave/plugins/torch/nn/modules/kernels.py @@ -21,7 +21,7 @@ from dwave.plugins.torch.nn.functional import DimensionMismatchError from dwave.plugins.torch.nn.modules.utils import store_config -__all__ = ["Kernel", "RadialBasisFunction"] +__all__ = ["Kernel", "GaussianKernel"] class Kernel(nn.Module): @@ -73,14 +73,14 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return self._kernel(x, y) -class RadialBasisFunction(Kernel): - """The radial basis function kernel. +class GaussianKernel(Kernel): + """The Gaussian kernel. This kernel between two data points x and y is defined as :math:`k(x, y) = exp(-||x-y||^2 / (2 * \sigma))`, where :math:`\sigma` is the bandwidth parameter. - This implementation considers aggregating multiple radial basis function kernels with different + This implementation considers aggregating multiple Gaussian kernels with different bandwidths. The bandwidths are determined by multiplying a base bandwidth with a set of multipliers. The base bandwidth can be provided directly or estimated from the data using the average distance between samples. @@ -126,7 +126,7 @@ def _get_bandwidth(self, distance_matrix: torch.Tensor) -> torch.Tensor | float: return self.bandwidth def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """Compute the radial basis function kernel between ``x`` and ``y``. + """Compute the Gaussian kernel between ``x`` and ``y``. .. math:: k(x, y) = \sum_{i=1}^{num\_features} exp(-||x-y||^2 / (2 * \sigma_i)), diff --git a/tests/test_dvae_winci2020.py b/tests/test_dvae_winci2020.py index efdaeaa..f5f67c6 100644 --- a/tests/test_dvae_winci2020.py +++ b/tests/test_dvae_winci2020.py @@ -23,7 +23,7 @@ DiscreteVariationalAutoencoder as DVAE from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss -from dwave.plugins.torch.nn.modules.kernels import RadialBasisFunction as RBF +from dwave.plugins.torch.nn.modules.kernels import GaussianKernel as RBF from dwave.plugins.torch.nn.modules.loss import MaximumMeanDiscrepancyLoss as MMDLoss from dwave.samplers import SimulatedAnnealingSampler diff --git a/tests/test_kernels.py b/tests/test_kernels.py index 7b29d72..fd106af 100755 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -3,8 +3,7 @@ import torch from parameterized import parameterized -from dwave.plugins.torch.nn.modules.kernels import DimensionMismatchError, Kernel -from dwave.plugins.torch.nn.modules.kernels import RadialBasisFunction as RBF +from dwave.plugins.torch.nn.modules.kernels import DimensionMismatchError, GaussianKernel, Kernel class TestKernel(unittest.TestCase): @@ -27,11 +26,11 @@ def _kernel(self, x, y): self.assertRaises(DimensionMismatchError, one, x, y) -class TestRadialBasisFunction(unittest.TestCase): +class TestGaussianKernel(unittest.TestCase): def test_has_config(self): - rbf = RBF(5, 2.1, 0.1) - self.assertDictEqual(dict(rbf.config), dict(module_name="RadialBasisFunction", + rbf = GaussianKernel(5, 2.1, 0.1) + self.assertDictEqual(dict(rbf.config), dict(module_name="GaussianKernel", n_kernels=5, factor=2.1, bandwidth=0.1)) @parameterized.expand([ @@ -39,36 +38,36 @@ def test_has_config(self): (torch.randn((5, 12, 34)), torch.rand((7, 12, 34))), ]) def test_shape(self, x, y): - rbf = RBF(2, 2.1, 0.1) + rbf = GaussianKernel(2, 2.1, 0.1) k = rbf(x, y) self.assertEqual(tuple(k.shape), (x.shape[0], y.shape[0])) def test_get_bandwidth_default(self): - rbf = RBF(2, 2.1, 0.1) + rbf = GaussianKernel(2, 2.1, 0.1) d = torch.tensor(123) self.assertEqual(0.1, rbf._get_bandwidth(d)) def test_get_bandwidth(self): - rbf = RBF(2, 2.1, None) + rbf = GaussianKernel(2, 2.1, None) d = torch.tensor([[0.0, 3.4,], [3.4, 0.0]]) self.assertEqual(3.4, rbf._get_bandwidth(d)) def test_get_bandwidth_no_grad(self): - rbf = RBF(2, 2.1, None) + rbf = GaussianKernel(2, 2.1, None) d = torch.tensor([[0.0, 3.4,], [3.4, 0.0]], requires_grad=True) self.assertEqual(3.4, rbf._get_bandwidth(d)) self.assertIsNone(rbf._get_bandwidth(d).grad) def test_single_factors(self): - rbf = RBF(1, 2.1, None) + rbf = GaussianKernel(1, 2.1, None) self.assertListEqual(rbf.factors.tolist(), [1.0]) def test_two_factors(self): - rbf = RBF(2, 2.1, None) + rbf = GaussianKernel(2, 2.1, None) torch.testing.assert_close(torch.tensor([2.1**-1, 1]), rbf.factors) def test_three_factors(self): - rbf = RBF(3, 2.1, None) + rbf = GaussianKernel(3, 2.1, None) torch.testing.assert_close(torch.tensor([2.1**-1, 1, 2.1]), rbf.factors) def test_kernel(self): @@ -80,19 +79,19 @@ def test_kernel(self): dist = torch.cdist(x, y) with self.subTest("Adaptive bandwidth"): - rbf = RBF(1, 2.1, None) + rbf = GaussianKernel(1, 2.1, None) bandwidths = rbf._get_bandwidth(dist) * rbf.factors manual = torch.exp(-dist/bandwidths) torch.testing.assert_close(manual, rbf(x, y)) with self.subTest("Simple bandwidth"): - rbf = RBF(1, 2.1, 12.34) + rbf = GaussianKernel(1, 2.1, 12.34) bandwidths = 12.34 * rbf.factors manual = torch.exp(-dist/bandwidths) torch.testing.assert_close(manual, rbf(x, y)) with self.subTest("Multiple kernels"): - rbf = RBF(3, 2.1, 123) + rbf = GaussianKernel(3, 2.1, 123) bandwidths = rbf._get_bandwidth(dist) * rbf.factors manual = torch.exp(-dist/bandwidths.reshape(-1, 1, 1)).sum(0) torch.testing.assert_close(manual, rbf(x, y)) From c3eb030867bb0372261008ca6f7752db794e80cc Mon Sep 17 00:00:00 2001 From: kchern Date: Wed, 17 Dec 2025 18:52:08 +0000 Subject: [PATCH 11/39] Renme RBF to GaussianKernel --- tests/test_dvae_winci2020.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_dvae_winci2020.py b/tests/test_dvae_winci2020.py index f5f67c6..38dfff7 100644 --- a/tests/test_dvae_winci2020.py +++ b/tests/test_dvae_winci2020.py @@ -23,7 +23,7 @@ DiscreteVariationalAutoencoder as DVAE from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss -from dwave.plugins.torch.nn.modules.kernels import GaussianKernel as RBF +from dwave.plugins.torch.nn.modules.kernels import GaussianKernel from dwave.plugins.torch.nn.modules.loss import MaximumMeanDiscrepancyLoss as MMDLoss from dwave.samplers import SimulatedAnnealingSampler @@ -140,7 +140,7 @@ def test_train_encoder_with_mmd(self, use_mmd_loss_class: bool = False): """Test training the encoder of the DVAE with MMD loss and fixed decoder and GRBM prior.""" dvae = self.dvae_with_trainable_encoder optimiser = torch.optim.SGD(dvae.encoder.parameters(), lr=0.01, momentum=0.9) - kernel = RBF(n_kernels=5, factor=2.0, bandwidth=None) + kernel = GaussianKernel(n_kernels=5, factor=2.0, bandwidth=None) # Before training, the encoder will not map data points to the correct spin strings: expected_set = {(1.0, 1.0), (1.0, -1.0), (-1.0, -1.0), (-1.0, 1.0)} _, discretes, _ = dvae(self.data, n_samples=1) From 58f858ae0ad31d63c65b433f21b07e334374bde6 Mon Sep 17 00:00:00 2001 From: kchern Date: Mon, 5 Jan 2026 18:09:50 +0000 Subject: [PATCH 12/39] Remove custom errors and fix docstrings Co-Authored-By: Theodor Isacsson --- dwave/plugins/torch/nn/functional.py | 15 ++++----------- dwave/plugins/torch/nn/modules/kernels.py | 14 +++++++++----- dwave/plugins/torch/nn/modules/loss.py | 4 ++-- tests/test_functional.py | 7 +++---- tests/test_kernels.py | 15 ++++++++++++--- 5 files changed, 30 insertions(+), 25 deletions(-) diff --git a/dwave/plugins/torch/nn/functional.py b/dwave/plugins/torch/nn/functional.py index b8b250b..8327f7a 100755 --- a/dwave/plugins/torch/nn/functional.py +++ b/dwave/plugins/torch/nn/functional.py @@ -25,13 +25,6 @@ __all__ = ["maximum_mean_discrepancy_loss"] -class SampleSizeError(ValueError): - pass - - -class DimensionMismatchError(ValueError): - pass - def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: """Estimates the squared maximum mean discrepancy (MMD) given two samples ``x`` and ``y``. @@ -62,8 +55,8 @@ def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kern kernel (Kernel): A kernel function object. Raises: - SampleSizeError: If the sample size of ``x`` or ``y`` is less than two. - DimensionMismatchError: If shape of ``x`` and ``y`` mismatch (excluding batch size) + ValueError: If the sample size of ``x`` or ``y`` is less than two. + ValueError: If shape of ``x`` and ``y`` mismatch (excluding batch size) Returns: torch.Tensor: The squared maximum mean discrepancy estimate. @@ -71,12 +64,12 @@ def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kern num_x = x.shape[0] num_y = y.shape[0] if num_x < 2 or num_y < 2: - raise SampleSizeError( + raise ValueError( "Sample size of ``x`` and ``y`` must be at least two. " f"Got, respectively, {x.shape} and {y.shape}." ) if x.shape[1:] != y.shape[1:]: - raise DimensionMismatchError( + raise ValueError( "Input dimensions must match. You are trying to compute " f"the kernel between tensors of shape {x.shape} and {y.shape}." ) diff --git a/dwave/plugins/torch/nn/modules/kernels.py b/dwave/plugins/torch/nn/modules/kernels.py index 2d614d5..2aa6ced 100755 --- a/dwave/plugins/torch/nn/modules/kernels.py +++ b/dwave/plugins/torch/nn/modules/kernels.py @@ -13,18 +13,17 @@ # limitations under the License. """Kernel functions.""" -from abc import abstractmethod +from abc import ABC, abstractmethod import torch import torch.nn as nn -from dwave.plugins.torch.nn.functional import DimensionMismatchError from dwave.plugins.torch.nn.modules.utils import store_config __all__ = ["Kernel", "GaussianKernel"] -class Kernel(nn.Module): +class Kernel(ABC, nn.Module): """Base class for kernels. `Kernels `_ are functions that compute a similarity @@ -60,16 +59,21 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor. Raises: - DimensionMismatchError: If shape of ``x`` and ``y`` mismatch (excluding batch size) + ValueError: If shape of ``x`` and ``y`` mismatch (excluding batch size) Returns: torch.Tensor: A (n_x + n_y, n_x + n_y) tensor. """ if x.shape[1:] != y.shape[1:]: - raise DimensionMismatchError( + raise ValueError( "Input dimensions must match. You are trying to compute " f"the kernel between tensors of shape {x.shape} and {y.shape}." ) + if x.shape[0] < 2 or y.shape[0] < 2: + raise ValueError( + "Sample size of ``x`` and ``y`` must be at least two. " + f"Got, respectively, {x.shape} and {y.shape}." + ) return self._kernel(x, y) diff --git a/dwave/plugins/torch/nn/modules/loss.py b/dwave/plugins/torch/nn/modules/loss.py index eed4b28..5da41f5 100755 --- a/dwave/plugins/torch/nn/modules/loss.py +++ b/dwave/plugins/torch/nn/modules/loss.py @@ -29,7 +29,7 @@ class MaximumMeanDiscrepancyLoss(nn.Module): - """An unbiased estimator for the squared maximum mean discrepancy as a loss function. + """An unbiased estimator for the squared maximum mean discrepancy (MMD) as a loss function. This uses the ``dwave.plugins.torch.nn.functional.maximum_mean_discrepancy_loss`` function to compute the loss. @@ -44,7 +44,7 @@ def __init__(self, kernel: Kernel): self.kernel = kernel def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """Computes the MMD loss between two sets of samples x and y. + """Computes the MMD loss between two sets of samples ``x`` and ``y``. Args: x (torch.Tensor): A (n_x, f1, f2, ...) tensor of samples from distribution p. diff --git a/tests/test_functional.py b/tests/test_functional.py index da85182..17f81f4 100755 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -15,9 +15,8 @@ import torch -from dwave.plugins.torch.nn.functional import SampleSizeError from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss -from dwave.plugins.torch.nn.modules.kernels import DimensionMismatchError, Kernel +from dwave.plugins.torch.nn.modules.kernels import Kernel class TestMaximumMeanDiscrepancyLoss(unittest.TestCase): @@ -46,13 +45,13 @@ def _kernel(self, x, y): def test_sample_size_error(self): x = torch.tensor([[1.2], [4.1]]) y = torch.tensor([[0.3]]) - self.assertRaises(SampleSizeError, mmd_loss, x, y, None) + self.assertRaisesRegex(ValueError, "must be at least two", mmd_loss, x, y, None) def test_mmd_loss_dim_mismatch(self): x = torch.tensor([[1], [4]], dtype=torch.float32) y = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) - self.assertRaises(DimensionMismatchError, mmd_loss, x, y, None) + self.assertRaisesRegex(ValueError, "Input dimensions must match. You are trying to compute ", mmd_loss, x, y, None) def test_mmd_loss_arange(self): x = torch.tensor([[1.0], [4.0], [5.0]]) diff --git a/tests/test_kernels.py b/tests/test_kernels.py index fd106af..278ca53 100755 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -3,7 +3,7 @@ import torch from parameterized import parameterized -from dwave.plugins.torch.nn.modules.kernels import DimensionMismatchError, GaussianKernel, Kernel +from dwave.plugins.torch.nn.modules.kernels import Kernel, GaussianKernel class TestKernel(unittest.TestCase): @@ -16,6 +16,16 @@ def _kernel(self, x, y): y = torch.randn((9, 3)) self.assertEqual(1, one(x, y)) + @parameterized.expand([(1, 2), (2, 1)]) + def test_sample_size(self, nx, ny): + class One(Kernel): + def _kernel(self, x, y): + return 1 + one = One() + x = torch.rand((nx, 5)) + y = torch.randn((ny, 5)) + self.assertRaisesRegex(ValueError, "must be at least two", one, x, y) + def test_shape_mismatch(self): class One(Kernel): def _kernel(self, x, y): @@ -23,8 +33,7 @@ def _kernel(self, x, y): one = One() x = torch.rand((5, 4)) y = torch.randn((9, 3)) - self.assertRaises(DimensionMismatchError, one, x, y) - + self.assertRaisesRegex(ValueError, "Input dimensions must match", one, x, y) class TestGaussianKernel(unittest.TestCase): From 6e85113d36d70c108b84dd2b9b009b66cd8dad15 Mon Sep 17 00:00:00 2001 From: kchern Date: Mon, 5 Jan 2026 18:20:11 +0000 Subject: [PATCH 13/39] Fix a docstring --- dwave/plugins/torch/nn/modules/kernels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dwave/plugins/torch/nn/modules/kernels.py b/dwave/plugins/torch/nn/modules/kernels.py index 2aa6ced..5dfaf4e 100755 --- a/dwave/plugins/torch/nn/modules/kernels.py +++ b/dwave/plugins/torch/nn/modules/kernels.py @@ -36,7 +36,7 @@ class Kernel(ABC, nn.Module): def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Perform a pairwise kernel evaluation over samples. - Computes the kernel matrix for an input of shape (n, f1, f2, ...), whose shape is (n, n) + Computes the kernel matrix for inputs of shape (nx, f1, f2, ..., fk) and (ny, f1, f2, ..., fk), whose shape is (nx, ny) containing the pairwise kernel values. Args: From fb756c2f5879bcffb66f374464f407b81df79cc5 Mon Sep 17 00:00:00 2001 From: Kevin Chern <32395608+kevinchern@users.noreply.github.com> Date: Tue, 6 Jan 2026 10:23:11 -0800 Subject: [PATCH 14/39] Fix minor code aesthetics Co-authored-by: Theodor Isacsson --- dwave/plugins/torch/nn/modules/kernels.py | 4 +++- dwave/plugins/torch/nn/modules/loss.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/dwave/plugins/torch/nn/modules/kernels.py b/dwave/plugins/torch/nn/modules/kernels.py index 5dfaf4e..6119cb5 100755 --- a/dwave/plugins/torch/nn/modules/kernels.py +++ b/dwave/plugins/torch/nn/modules/kernels.py @@ -32,11 +32,13 @@ class Kernel(ABC, nn.Module): (n, f1, f2, ...), where n is the number of items and f1, f2, ... are feature dimensions, so that the output is a tensor of shape (n, n) containing the pairwise kernel values. """ + @abstractmethod def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Perform a pairwise kernel evaluation over samples. - Computes the kernel matrix for inputs of shape (nx, f1, f2, ..., fk) and (ny, f1, f2, ..., fk), whose shape is (nx, ny) + Computes the kernel matrix for inputs of shape (nx, f1, f2, ..., fk) and + (ny, f1, f2, ..., fk), whose shape is (nx, ny) containing the pairwise kernel values. Args: diff --git a/dwave/plugins/torch/nn/modules/loss.py b/dwave/plugins/torch/nn/modules/loss.py index 5da41f5..45488ec 100755 --- a/dwave/plugins/torch/nn/modules/loss.py +++ b/dwave/plugins/torch/nn/modules/loss.py @@ -39,7 +39,7 @@ class MaximumMeanDiscrepancyLoss(nn.Module): """ @store_config - def __init__(self, kernel: Kernel): + def __init__(self, kernel: Kernel) -> None: super().__init__() self.kernel = kernel From 616d1bd9924903f9f336b79e0cfbada392b01c5e Mon Sep 17 00:00:00 2001 From: Theodor Isacsson Date: Wed, 7 Jan 2026 16:02:08 -0800 Subject: [PATCH 15/39] Update requirements for Python 3.14 --- requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 2c6c120..52a6645 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,9 +3,9 @@ numpy~=2.0,!=2.4.0 torch==2.9.1 -dimod==0.12.18 -dwave-system==1.28.0 -dwave-hybrid==0.6.13 +dimod==0.12.21 +dwave-system==1.34.0 +dwave-hybrid==0.6.14 # Development requirements reno==4.1.0 From d08b203e920871d25e0719cb5cac4ef63d2b93a4 Mon Sep 17 00:00:00 2001 From: kchern Date: Fri, 14 Nov 2025 21:54:00 +0000 Subject: [PATCH 16/39] Add block-spin update sampler --- dwave/plugins/torch/nn/functional.py | 42 ++- dwave/plugins/torch/samplers/__init__.py | 15 + .../torch/samplers/block_spin_sampler.py | 332 ++++++++++++++++++ dwave/plugins/torch/tensor.py | 47 +++ dwave/plugins/torch/utils.py | 2 +- .../block-spin-sampler-b62ba4c83880c729.yaml | 10 + tests/test_block_sampler.py | 274 +++++++++++++++ tests/test_functional.py | 24 +- tests/test_tensor.py | 27 ++ 9 files changed, 768 insertions(+), 5 deletions(-) create mode 100755 dwave/plugins/torch/samplers/__init__.py create mode 100644 dwave/plugins/torch/samplers/block_spin_sampler.py create mode 100755 dwave/plugins/torch/tensor.py create mode 100755 releasenotes/notes/block-spin-sampler-b62ba4c83880c729.yaml create mode 100755 tests/test_block_sampler.py create mode 100755 tests/test_tensor.py diff --git a/dwave/plugins/torch/nn/functional.py b/dwave/plugins/torch/nn/functional.py index 8327f7a..8b26d49 100755 --- a/dwave/plugins/torch/nn/functional.py +++ b/dwave/plugins/torch/nn/functional.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Functional interface.""" - from __future__ import annotations from typing import TYPE_CHECKING @@ -22,8 +21,7 @@ import torch -__all__ = ["maximum_mean_discrepancy_loss"] - +__all__ = ["maximum_mean_discrepancy_loss", "bit2spin_soft", "spin2bit_soft"] def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: @@ -82,3 +80,41 @@ def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kern yy = (kernel_yy.sum() - kernel_yy.trace()) / (num_y * (num_y - 1)) xy = kernel_xy.sum() / (num_x * num_y) return xx + yy - 2 * xy + + +def bit2spin_soft(b: torch.Tensor) -> torch.Tensor: + """Maps input :math:`b` to :math:`2b-1`. + + The mapping does not require :math:`b` to be binary, only that it is in the interval :math:`[0, 1]`. + + Args: + b (torch.Tensor): Input tensor of values in :math:`[0, 1]`. + + Raises: + ValueError: If not all ``b`` values are in :math:`[0, 1]`. + + Returns: + torch.Tensor: A tensor with values :math:`2b-1`. + """ + if not ((b >= 0) & (b <= 1)).all(): + raise ValueError(f"Not all inputs are in [0, 1]: {b}") + return b * 2 - 1 + + +def spin2bit_soft(s: torch.Tensor) -> torch.Tensor: + """Maps input :math:`s` to :math:`(s+1)/2`. + + The mapping does not require :math:`s` to be spin-valued, only that it is in the interval :math:`[-1, 1]`. + + Args: + s (torch.Tensor): Input tensor of values in :math:`[-1, 1]`. + + Raises: + ValueError: If not all ``s`` values are in `[-1, 1]`. + + Returns: + torch.Tensor: A tensor with values :math:`(s+1)/2`. + """ + if (s.abs() > 1).any(): + raise ValueError(f"Not all inputs are in [-1, 1]: {s}") + return (s + 1) / 2 diff --git a/dwave/plugins/torch/samplers/__init__.py b/dwave/plugins/torch/samplers/__init__.py new file mode 100755 index 0000000..932b865 --- /dev/null +++ b/dwave/plugins/torch/samplers/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dwave.plugins.torch.samplers.block_spin_sampler import * diff --git a/dwave/plugins/torch/samplers/block_spin_sampler.py b/dwave/plugins/torch/samplers/block_spin_sampler.py new file mode 100644 index 0000000..cadb256 --- /dev/null +++ b/dwave/plugins/torch/samplers/block_spin_sampler.py @@ -0,0 +1,332 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Iterable +from typing import TYPE_CHECKING, Callable, Hashable, Literal + +import torch +from torch import nn + +if TYPE_CHECKING: + from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine as GRBM + from torch._prims_common import DeviceLikeType + +from dwave.plugins.torch.nn.functional import bit2spin_soft +from dwave.plugins.torch.tensor import randspin + +__all__ = ["BlockSampler"] + + +class BlockSampler: + """A block-spin update sampler for graph-restricted Boltzmann machines. + + Due to the sparse definition of GRBMs, some tedious indexing tricks are required to + efficiently sample in blocks of spins. Ideally, an adjacency list can be used, however, + adjacencies are ragged, making vectorization inapplicable. + + Block-Gibbs and Block-Metropolis obey detailed balance and are ergodic methods at finite nonzero + temperature which, at fixed parameters, converge upon Boltzmann distributions. Block-Metropolis + allows higher acceptance rates for proposals (faster single-step mixing), but is non-ergodic in + the limit of zero or infinite temperature. Decorrelation from an initial condition can be slower. + Block-Gibbs represents best practice for independent sampling. + + Args: + grbm (GRBM): The Graph-Restricted Boltzmann Machine to sample from. + crayon (Callable[Hashable, Hashable]): A colouring function that maps a single + node of the ``grbm`` to its colour. + num_chains (int): Number of Markov chains to run in parallel. + initial_states (torch.Tensor | None): A tensor of +/-1 values of shape + (``num_chains``, ``grbm.n_nodes``) representing the initial states of the Markov chains. + If None, initial states will be uniformly randomized with number of chains equal to + ``num_chains``. Defaults to None. + schedule (Iterable[Float]): The inverse temperature schedule. + proposal_acceptance_criteria (Literal["Gibbs", "Metropolis"]): The proposal acceptance + criterion used to accept or reject states in the Markov chain. Defaults to "Gibbs". + seed (int | None): Random seed. Defaults to None. + + Raises: + InvalidProposalAcceptanceCriteriaError: If the proposal acceptance criteria is not one of + "Gibbs" or "Metropolis". + """ + + def __init__(self, grbm: GRBM, crayon: Callable[[Hashable], Hashable], num_chains: int, + schedule: Iterable[float], + proposal_acceptance_criteria: Literal["Gibbs", "Metropolis"] = "Gibbs", + initial_states: torch.Tensor | None = None, + seed: int | None = None): + super().__init__() + + if num_chains < 1: + raise ValueError("Number of reads should be a positive integer.") + + self._proposal_acceptance_criteria = proposal_acceptance_criteria.title() + if self._proposal_acceptance_criteria not in {"Gibbs", "Metropolis"}: + raise ValueError( + 'Proposal acceptance criterion should be one of "Gibbs" or "Metropolis"' + ) + + self._grbm: GRBM = grbm + self._crayon: Callable[[Hashable], Hashable] = crayon + if not self._valid_crayon(): + raise ValueError( + "crayon is not a valid colouring of grbm. " + + "At least one edge has vertices of the same colour." + ) + + self._partition = self._get_partition() + self._padded_adjacencies, self._padded_adjacencies_weight = self._get_adjacencies() + + self._rng = torch.Generator() + if seed is not None: + self._rng = self._rng.manual_seed(seed) + + initial_states = self._prepare_initial_states(num_chains, initial_states, self._rng) + self._schedule = nn.Parameter(torch.tensor(list(schedule)), requires_grad=False) + self._x = nn.Parameter(initial_states.float(), requires_grad=False) + self._zeros = nn.Parameter(torch.zeros((num_chains, 1)), requires_grad=False) + + def to(self, device: DeviceLikeType) -> BlockSampler: + """Moves sampler components to the target device. + + If the device is "meta", then the random number generator (RNG) + will not be modified at all. For all other devices, all attributes used for performing + block-spin updates will be moved to the target device. Importantly, the RNG's device is + relayed by the following procedure: + 1. Draw a random integer between 0 (inclusive) and 2**60 (exclusive) with the current + generator as a new seed ``s``. + 2. Create a new generator on the target device. + 3. Set the new generator's seed as ``s``. + + Developer-note: Not sure the above constitutes a good practice, but I not aware of any + obvious solution for moving generators across devices. + + Args: + device (DeviceLikeType): The target device. + """ + self._x = self._x.to(device) + self._zeros = self._zeros.to(device) + self._schedule = self._schedule.to(device) + self._partition = self._partition.to(device) + self._padded_adjacencies = self._padded_adjacencies.to(device) + self._padded_adjacencies_weight = self._padded_adjacencies_weight.to(device) + if device != "meta": + rng = torch.Generator(device) + rng.manual_seed(torch.randint(0, 2**60, (1,), generator=self._rng).item()) + self._rng = rng + return self + + def _prepare_initial_states( + self, num_chains: int, initial_states: torch.Tensor | None = None, + generator: torch.Generator | None = None + ) -> torch.Tensor: + """Convert initial states to tensor or sample uniformly random spins as initial states. + + Args: + num_chains (int): Number of initial states. + initial_states (torch.Tensor | None): A tensor of shape + (``num_chains``, ``self._grbm.n_nodes``) representing the initial states of the + sampler's Markov chains. If None, then initial states are sampled uniformly from + +/-1 values. Defaults to None. + generator (torch.Generator | None): A random number generator. + + Raises: + ShapeMismatchError: If the shape of initial states do not match that of the expected + (``num_chains``, ``self._grbm.n_nodes``). + NonSpinError: If the provided initial states have nonspin-valued entries. + + Returns: + torch.Tensor: The initial states of the sampler's Markov chain. + """ + if initial_states is None: + initial_states = randspin((num_chains, self._grbm.n_nodes), generator=generator) + + if initial_states.shape != (num_chains, self._grbm.n_nodes): + raise ValueError( + "Initial states should be of shape ``num_chains, grbm.n_nodes`` " + f"{(num_chains, self._grbm.n_nodes)}, but got {tuple(initial_states.shape)} instead." + ) + + if not set(initial_states.unique().tolist()).issubset({-1, 1}): + raise ValueError("Initial states contain nonspin values.") + + return initial_states + + def _valid_crayon(self) -> bool: + """Determines whether ``crayon`` is a valid colouring of the graph-restricted Boltzmann machine. + + Returns: + bool: True if the colouring is valid and False otherwise. + """ + for u, v in self._grbm.edges: + if self._crayon(u) == self._crayon(v): + return False + return True + + def _get_partition(self) -> nn.ParameterList: + """Computes the vertex partition induced by the colouring function. + + Returns: + nn.ParameterList: The partition induced by the colouring. + """ + partition = defaultdict(list) + for node in self._grbm.nodes: + idx = self._grbm.node_to_idx[node] + c = self._crayon(node) + partition[c].append(idx) + partition = nn.ParameterList([ + nn.Parameter(torch.tensor(partition[k], requires_grad=False), requires_grad=False) + for k in sorted(partition) + ]) + return partition + + def _get_adjacencies(self) -> tuple[torch.Tensor, torch.Tensor]: + """Create two adjacency matrices, one for neighbouring indices and another for the + corresponding edge weights' indices. + + The issue begins with the adjacency lists being ragged. To address this, we pad adjacencies + with ``-1`` values. The exact values do not matter, as the way these adjacencies will be used + is by padding an input state with 0s, so when accessing ``-1``, the output will be masked out. + + For example, consider the returned adjacency matrices ``padded_adjacencies`` and + ``padded_adjacencies_weight``. + + In the first adjacency matrix, ``padded_adjacencies[0]`` is a + ``torch.Tensor`` consisting of indices of neighbouring vertices of vertex ``0``. Values of + ``-1`` in this tensor indicates no neighbour. + + In the second adjacency matrix, ``padded_adjacencies_weight[0]`` is a ``torch.Tensor`` + consisting of indices of edge weight indices corresponding to edges of vertex ``0``. + Similarly, ``-1`` values in this tensor indicates no neighbour. + + Returns: + tuple[list[torch.Tensor], list[torch.Tensor]]: The first output is a padded adjacency + matrix, the second output is an adjacency matrix of edge weight indices. + """ + max_degree = 0 + if self._grbm.n_edges: + max_degree = torch.unique(torch.cat([self._grbm.edge_idx_i, self._grbm.edge_idx_j]), + return_counts=True)[1].max().item() + adjacency = nn.Parameter( + -torch.ones(self._grbm.n_nodes, max_degree, dtype=int), requires_grad=False + ) + adjacency_weight = nn.Parameter( + -torch.ones(self._grbm.n_nodes, max_degree, dtype=int), requires_grad=False + ) + + adjacency_dict = defaultdict(list) + edge_to_idx = dict() + for idx, (u, v) in enumerate( + zip(self._grbm.edge_idx_i.tolist(), + self._grbm.edge_idx_j.tolist())): + adjacency_dict[v].append(u) + adjacency_dict[u].append(v) + edge_to_idx[u, v] = idx + edge_to_idx[v, u] = idx + for u in self._grbm.idx_to_node: + neighbours = adjacency_dict[u] + adj_weight_idxs = [edge_to_idx[u, v] for v in neighbours] + num_neighbours = len(neighbours) + adjacency[u][:num_neighbours] = torch.tensor(neighbours) + adjacency_weight[u][:num_neighbours] = torch.tensor(adj_weight_idxs) + return adjacency, adjacency_weight + + @torch.no_grad + def _compute_effective_field(self, block) -> torch.Tensor: + """Computes the effective field for all vertices in ``block``. + + Args: + block (nn.ParameterList): A list of integers (indices) corresponding to the vertices of + a colour. + + Returns: + torch.Tensor: The effective fields of each vertex in ``block``. + """ + xnbr = torch.hstack([self._x, self._zeros])[:, self._padded_adjacencies[block]] + h = self._grbm.linear[block] + J = self._grbm.quadratic[self._padded_adjacencies_weight[block]] + return (xnbr * J.unsqueeze(0)).sum(2) + h + + @torch.no_grad + def _metropolis_update(self, beta: float, block: nn.ParameterList, + effective_field: torch.Tensor) -> None: + """Performs a Metropolis update in-place. + + Args: + beta (float): The inverse temperature to sample at. + block (nn.ParameterList): A list of integers (indices) corresponding to the vertices of + a colour. + effective_field (torch.Tensor): Effective fields of each spin corresponding to indices + of the block. + """ + delta = -2 * self._x[:, block] * effective_field + prob = (-delta * beta).exp().clip(0, 1) + + # if the delta field is negative, then flipping the spin will improve the energy + prob[delta <= 0] = 1 + flip = -bit2spin_soft(prob.bernoulli(generator=self._rng)) + self._x[:, block] = flip * self._x[:, block] + + @torch.no_grad + def _gibbs_update(self, beta: torch.Tensor, block: torch.nn.ParameterList, effective_field: torch.Tensor) -> None: + """Performs a Gibbs update in-place. + + Args: + beta (torch.Tensor): The (scalar) inverse temperature to sample at. + block (nn.ParameterList): A list of integers (indices) corresponding to the vertices of + a colour. + effective_field (torch.Tensor): Effective fields of each spin corresponding to indices + of the block. + """ + prob = 1 / (1 + torch.exp(2 * beta * effective_field)) + spins = bit2spin_soft(prob.bernoulli(generator=self._rng)) + self._x[:, block] = spins + + @torch.no_grad + def _step(self, beta: torch.Tensor) -> None: + """Performs a block-spin update in-place. + + Args: + beta (torch.Tensor): Inverse temperature to sample at. + """ + for block in self._partition: + effective_field = self._compute_effective_field(block) + if self._proposal_acceptance_criteria == "Metropolis": + self._metropolis_update(beta, block, effective_field) + elif self._proposal_acceptance_criteria == "Gibbs": + self._gibbs_update(beta, block, effective_field) + else: + # NOTE: This line should never be reached because acceptance proposal criterion + # should've been checked on instantiation + raise ValueError(f"Invalid proposal acceptance criterion.") + + @torch.no_grad + def sample(self, x: torch.Tensor | None = None) -> torch.Tensor: + """Performs block updates. + + Args: + x (torch.Tensor): A tensor of shape (``batch_size``, ``dim``) or (``batch_size``, ``n_nodes``) + interpreted as a batch of partially-observed spins. Entries marked with ``torch.nan`` will + be sampled; entries with +/-1 values will remain constant. + + Returns: + torch.Tensor: A tensor of shape (batch_size, dim) of +/-1 values sampled from the model. + """ + if x is not None: + raise NotImplementedError("Support for conditional sampling has not been implemented.") + for beta in self._schedule: + self._step(beta) + return self._x diff --git a/dwave/plugins/torch/tensor.py b/dwave/plugins/torch/tensor.py new file mode 100755 index 0000000..b2e860a --- /dev/null +++ b/dwave/plugins/torch/tensor.py @@ -0,0 +1,47 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import torch + +if TYPE_CHECKING: + from torch import Generator + from torch._prims_common import DeviceLikeType + from torch.types import _bool, _dtype, _size + +__all__ = ["randspin"] + + +def randspin(size: _size, **kwargs) -> torch.Tensor: + """Wrapper for ``torch.randint`` restricted to spin outputs (+/-1 values). + + Args: + size (torch.types._size): Shape of the output tensor. + **kwargs: Keyword arguments of ``torch.randint``. + + Raises: + ValueError: If ``low`` is supplied as a keyword argument. + ValueError: If ``high`` is supplied as a keyword argument. + + Returns: + torch.Tensor: A tensor of +/-1 values. + """ + if "low" in kwargs: + raise ValueError("Invalid keyword argument `low`.") + if "high" in kwargs: + raise ValueError("Invalid keyword argument `high`.") + b = torch.randint(0, 2, size, **kwargs) + return 2 * b - 1 diff --git a/dwave/plugins/torch/utils.py b/dwave/plugins/torch/utils.py index fd142db..b3e147e 100755 --- a/dwave/plugins/torch/utils.py +++ b/dwave/plugins/torch/utils.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch diff --git a/releasenotes/notes/block-spin-sampler-b62ba4c83880c729.yaml b/releasenotes/notes/block-spin-sampler-b62ba4c83880c729.yaml new file mode 100755 index 0000000..14d2de6 --- /dev/null +++ b/releasenotes/notes/block-spin-sampler-b62ba4c83880c729.yaml @@ -0,0 +1,10 @@ +--- +features: + - | + Add ``BlockSampler`` for performing block Gibbs (or Metropolis) sampling of + graph-restricted Boltzmann Machines. + - | + Add functions for converting spins to bits and bits to spins. + - | + Add ``randspin`` for generating random spins. + diff --git a/tests/test_block_sampler.py b/tests/test_block_sampler.py new file mode 100755 index 0000000..0df7d43 --- /dev/null +++ b/tests/test_block_sampler.py @@ -0,0 +1,274 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import dwave_networkx as dnx +import networkx as nx +import torch +from parameterized import parameterized + +from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine as GRBM +from dwave.plugins.torch.samplers.block_spin_sampler import BlockSampler + + +class TestBlockSampler(unittest.TestCase): + ZEPHYR = dnx.zephyr_graph(1, coordinates=True) + GRBM_ZEPHYR = GRBM(ZEPHYR.nodes, ZEPHYR.edges) + CRAYON_ZEPHYR = dnx.zephyr_four_color + + BIPARTITE = nx.complete_bipartite_graph(5, 3) + GRBM_BIPARTITE = GRBM(BIPARTITE.nodes, BIPARTITE.edges) + def CRAYON_BIPARTITE(b): return b < 5 + + GRBM_SINGLE = GRBM([0], []) + def CRAYON_SINGLE(s): 0 + + GRBM_CRAYON_TEST_CASES = [(GRBM_ZEPHYR, CRAYON_ZEPHYR), + (GRBM_BIPARTITE, CRAYON_BIPARTITE), + (GRBM_SINGLE, CRAYON_SINGLE)] + + def setUp(self) -> None: + self.crayon_veqa = lambda v: v == "a" + return super().setUp() + + @parameterized.expand(GRBM_CRAYON_TEST_CASES) + def test_sample(self, grbm, crayon): + for pac in "Metropolis", "Gibbs": + schedule = [0.0, 1.0, 2.0] + bss1 = BlockSampler(grbm, crayon, 10, schedule, pac, seed=1) + bss1.sample() + + bss2 = BlockSampler(grbm, crayon, 10, [1.0], pac, seed=1) + for beta in schedule: + bss2._step(beta) + + self.assertListEqual(bss1._x.tolist(), bss2._x.tolist()) + + def test_device(self): + grbm = GRBM(list("ab"), [["a", "b"]]) + + crayon = self.crayon_veqa + sample_size = 1_000_000 + bss = BlockSampler(grbm, crayon, sample_size, [1.0], "Gibbs", seed=2) + bss.to('meta') + self.assertEqual("cpu", bss._grbm.linear.device.type) + self.assertEqual("cpu", bss._grbm.quadratic.device.type) + self.assertEqual("meta", bss._x.device.type) + self.assertEqual("meta", bss._padded_adjacencies.device.type) + self.assertEqual("meta", bss._padded_adjacencies_weight.device.type) + self.assertEqual("meta", bss._zeros.device.type) + self.assertEqual("meta", bss._schedule.device.type) + self.assertEqual("meta", bss._partition[0].device.type) + self.assertEqual("meta", bss._partition[1].device.type) + # NOTE: "meta" device is not supported for torch.Generator + self.assertEqual("cpu", bss._rng.device.type) + + def test_gibbs_update(self): + grbm = GRBM(list("ab"), [["a", "b"]]) + + crayon = self.crayon_veqa + sample_size = 1_000_000 + bss = BlockSampler(grbm, crayon, sample_size, [1.0], "Gibbs", seed=2) + bss._x.data[:] = 1 + zero = torch.tensor(0.0) + ones = torch.ones((sample_size, 1)) + bss._gibbs_update(0.0, bss._partition[0], ones*zero) + torch.testing.assert_close(torch.tensor(0.5), bss._x.mean(), atol=1e-3, rtol=1e-3) + bss._gibbs_update(0.0, bss._partition[1], ones*zero) + torch.testing.assert_close(torch.tensor(0.0), bss._x.mean(), atol=1e-3, rtol=1e-3) + + effective_field = torch.tensor(1.2) + bss._gibbs_update(1.0, bss._partition[0], effective_field*ones) + bss._gibbs_update(1.0, bss._partition[1], effective_field*ones) + torch.testing.assert_close( + torch.tanh(-effective_field), + bss._x.mean(), + atol=1e-3, rtol=1e-3) + + def test_initial_states_respected(self): + grbm = GRBM(list("ab"), [["a", "b"]]) + + crayon = self.crayon_veqa + initial_states = torch.tensor([[-1, 1], [1, 1], [-1, -1], [1, 1], [-1, 1], [-1, 1], [1, 1]]) + + bss = BlockSampler(grbm, crayon, len(initial_states), [1.0], "Metropolis", + initial_states, 2) + self.assertListEqual(bss._x.tolist(), initial_states.tolist()) + + def test_metropolis_update_average(self): + grbm = GRBM(list("ab"), [["a", "b"]]) + + crayon = self.crayon_veqa + sample_size = 1_000_000 + bss = BlockSampler(grbm, crayon, sample_size, [1.0], "Metropolis", seed=2) + bss._x.data[:] = 1 + ones = torch.ones((sample_size, 1)) + effective_field = torch.tensor(1.2) + for i in range(10): + bss._metropolis_update(1.0, bss._partition[0], effective_field*ones) + bss._metropolis_update(1.0, bss._partition[1], effective_field*ones) + torch.testing.assert_close( + torch.tanh(-effective_field), + bss._x.mean(), + atol=1e-3, rtol=1e-3) + + def test_metropolis_update_oscillates(self): + grbm = GRBM(list("ab"), [["a", "b"]]) + + crayon = self.crayon_veqa + sample_size = 1_00 + bss = BlockSampler(grbm, crayon, sample_size, [1.0], "Metropolis", seed=2) + bss._x.data[:] = 1 + zero_effective_field = torch.zeros((sample_size, 1)) + bss._metropolis_update(0.0, bss._partition[0], zero_effective_field) + self.assertTrue((bss._x[:, 1] == -1).all()) + bss._metropolis_update(0.0, bss._partition[1], zero_effective_field) + self.assertTrue((bss._x == -1).all()) + + def test_effective_field(self): + # Create a triangle graph with an additional dangling vertex + # a + # / | \ + # b--c d + self.nodes = list("abcd") + self.edges = [["a", "b"], ["a", "c"], ["a", "d"], ["b", "c"]] + + # Manually set the parameter weights for testing + dtype = torch.float32 + grbm = GRBM(self.nodes, self.edges) + grbm._linear.data = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=dtype) + grbm._quadratic.data = torch.tensor([1.1, 2.2, 3.3, 6.6], dtype=dtype) + + def crayon(v): + if v == "a": + return 0 + if v == "b": + return 1 + if v == "c": + return 2 + if v == "d": + return 1 + bss = BlockSampler(grbm, crayon, 3, [1.0], seed=3) + bss._x.data[:] = torch.tensor([[1, 1, -1, -1], + [-1, -1, 1, -1], + [1, 1, 1, -1]]) + # effective field for a + effective_field_a = bss._compute_effective_field(bss._partition[0]) + torch.testing.assert_close( + effective_field_a, + torch.tensor([[0.0 + 1.1 - 2.2 - 3.3], + [0.0 - 1.1 + 2.2 - 3.3], + [0.0 + 1.1 + 2.2 - 3.3]]) + ) + # effective field for b, d + effective_field_bd = bss._compute_effective_field(bss._partition[1]) + torch.testing.assert_close(effective_field_bd, + torch.tensor([[1.0 + 1.1 - 6.6, 3.0 + 3.3], + [1.0 - 1.1 + 6.6, 3.0 - 3.3], + [1.0 + 1.1 + 6.6, 3.0 + 3.3]])) + # effective field for c + effective_field_c = bss._compute_effective_field(bss._partition[2]) + torch.testing.assert_close(effective_field_c, + torch.tensor([[2.0 + 2.2 + 6.6], + [2.0 - 2.2 - 6.6], + [2.0 + 2.2 + 6.6]])) + + def test_get_adjacencies(self): + # Create a triangle graph with an additional dangling vertex + # a + # / | \ + # b--c d + self.nodes = list("abcd") + self.edges = [["a", "b"], ["a", "c"], ["a", "d"], ["b", "c"]] + + # Manually set the parameter weights for testing + dtype = torch.float32 + grbm = GRBM(self.nodes, self.edges) + grbm._linear.data = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=dtype) + grbm._quadratic.data = torch.tensor([1.1, 2.2, 3.3, 6.6], dtype=dtype) + + def crayon(v): + if v == "a": + return 0 + if v == "b": + return 1 + if v == "c": + return 2 + if v == "d": + return 1 + bss = BlockSampler(grbm, crayon, 10, [1.0], seed=4) + padded_adj, padded_adj_weights = bss._get_adjacencies() + + # First, check the neighbour indices are correct + # a has neighbours b, c, d in that order, so 2, 3, 4 + self.assertListEqual(padded_adj[0].tolist(), [1, 2, 3]) + # b has neighbours a, c, in that order, so 0, 2, and padded -1 + self.assertListEqual(padded_adj[1].tolist(), [0, 2, -1]) + # c has neighbours a, b, in that order, so 0, 1, and padded -1 + self.assertListEqual(padded_adj[2].tolist(), [0, 1, -1]) + # d has neighbour a, so 0, and two padded -1 + self.assertListEqual(padded_adj[3].tolist(), [0, -1, -1]) + + # Next, check weights are correct + # a has edges 0, 1, 2 + self.assertListEqual(padded_adj_weights[0].tolist(), [0, 1, 2]) + # b has edges 0, 3, + self.assertListEqual(padded_adj_weights[1].tolist(), [0, 3, -1]) + # c has edges 0, 3, + self.assertListEqual(padded_adj_weights[2].tolist(), [1, 3, -1]) + # d has edges 2 + self.assertListEqual(padded_adj_weights[3].tolist(), [2, -1, -1]) + + @parameterized.expand(GRBM_CRAYON_TEST_CASES) + def test_get_partition(self, grbm: GRBM, crayon): + bss = BlockSampler(grbm, crayon, 10, [1.0], seed=5) + # Check every block is indeed coloured correctly + for block in bss._partition: + self.assertEqual(1, len({crayon(grbm.idx_to_node[bidx]) for bidx in block.tolist()})) + # Check every node has been included + self.assertSetEqual({idx for block in bss._partition for idx in block.tolist()}, + {bss._grbm.node_to_idx[node] for node in bss._grbm.nodes}) + + def test_invalid_crayon(self): + grbm = GRBM([0, 1], [(0, 1)]) + def crayon(n): return 1 + self.assertRaisesRegex(ValueError, "not a valid colouring", BlockSampler, grbm, crayon, 10, [1.0]) + + def test_invalid_proposal(self): + grbm = GRBM([0, 1], [(0, 1)]) + def crayon(n): return 1 + self.assertRaisesRegex(ValueError, "Proposal acceptance criterion should be one of", BlockSampler, + grbm, crayon, 10, [1.0], "abc") + + def test_prepare_initial_states(self): + grbm = GRBM([0, 1, 2], [(0, 1)]) + def crayon(n): return n + bss = BlockSampler(grbm, crayon, 1, [1.0],) + + with self.subTest("Nonspin initial states."): + self.assertRaisesRegex(ValueError, "contain nonspin values", bss._prepare_initial_states, + initial_states=torch.tensor([[0, 1, -1]]), num_chains=1) + + with self.subTest("Testing initial states with incorrect shape."): + self.assertRaisesRegex(ValueError, "Initial states should be of shape", bss._prepare_initial_states, + num_chains=10, initial_states=torch.tensor([[-1, 1, 1, 1, -1]])) + + @parameterized.expand(GRBM_CRAYON_TEST_CASES) + def test_invalid_num_reads(self, grbm, crayon): + self.assertRaisesRegex(ValueError, "should be a positive integer", BlockSampler, grbm, crayon, 0, [1.0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_functional.py b/tests/test_functional.py index 17f81f4..da3e770 100755 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -14,8 +14,11 @@ import unittest import torch +from parameterized import parameterized +from dwave.plugins.torch.nn.functional import bit2spin_soft from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss +from dwave.plugins.torch.nn.functional import spin2bit_soft from dwave.plugins.torch.nn.modules.kernels import Kernel @@ -51,7 +54,9 @@ def test_mmd_loss_dim_mismatch(self): x = torch.tensor([[1], [4]], dtype=torch.float32) y = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) - self.assertRaisesRegex(ValueError, "Input dimensions must match. You are trying to compute ", mmd_loss, x, y, None) + self.assertRaisesRegex(ValueError, + "Input dimensions must match. You are trying to compute ", + mmd_loss, x, y, None) def test_mmd_loss_arange(self): x = torch.tensor([[1.0], [4.0], [5.0]]) @@ -75,5 +80,22 @@ def _kernel(self, x, y): self.assertEqual(-25, mmd_loss(x, y, Constant())) +class TestFunctional(unittest.TestCase): + + def test_spin2bit_soft(self): + self.assertListEqual(spin2bit_soft(torch.tensor([-1.0, 1.0, 0.5])).tolist(), [0, 1, 0.75]) + + @parameterized.expand([([-1.1, 1.0],), ([-0.5, 1.1],)]) + def test_spin2bit_raises(self, input): + self.assertRaises(ValueError, spin2bit_soft, torch.tensor(input)) + + def test_bit2spin_soft(self): + self.assertListEqual(bit2spin_soft(torch.tensor([0.0, 1.0, 0.5])).tolist(), [-1, 1, 0]) + + @parameterized.expand([([-0.1, 1.0],), ([0.1, 1.1],)]) + def test_bit2spin_soft_raises(self, input): + self.assertRaises(ValueError, bit2spin_soft, torch.tensor(input)) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_tensor.py b/tests/test_tensor.py new file mode 100755 index 0000000..67022a4 --- /dev/null +++ b/tests/test_tensor.py @@ -0,0 +1,27 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from dwave.plugins.torch.tensor import randspin + + +class TestTensor(unittest.TestCase): + + def test_rands(self): + self.assertSetEqual({-1, 1}, set(randspin((2000,)).unique().tolist())) + + +if __name__ == "__main__": + unittest.main() From 868e38b085b0952c133689cdaed0bc1b4c545d4c Mon Sep 17 00:00:00 2001 From: Theodor Isacsson Date: Thu, 18 Dec 2025 17:09:50 -0800 Subject: [PATCH 17/39] Add samplers submodule --- .../plugins/torch/models/boltzmann_machine.py | 80 +--------- dwave/plugins/torch/samplers/__init__.py | 2 + dwave/plugins/torch/samplers/_base.py | 104 +++++++++++++ dwave/plugins/torch/samplers/dimod_sampler.py | 107 ++++++++++++++ tests/test_samplers/test_annealing.py | 15 ++ tests/test_samplers/test_base.py | 138 ++++++++++++++++++ 6 files changed, 367 insertions(+), 79 deletions(-) create mode 100644 dwave/plugins/torch/samplers/_base.py create mode 100644 dwave/plugins/torch/samplers/dimod_sampler.py create mode 100644 tests/test_samplers/test_annealing.py create mode 100644 tests/test_samplers/test_base.py diff --git a/dwave/plugins/torch/models/boltzmann_machine.py b/dwave/plugins/torch/models/boltzmann_machine.py index 8bd0350..9f36756 100644 --- a/dwave/plugins/torch/models/boltzmann_machine.py +++ b/dwave/plugins/torch/models/boltzmann_machine.py @@ -32,12 +32,12 @@ import torch if TYPE_CHECKING: + from dwave.plugins.torch.samplers.dimod_sampler import DimodSampler from dimod import Sampler, SampleSet from dimod import BinaryQuadraticModel from hybrid.composers import AggregatedSamples -from dwave.plugins.torch.utils import sampleset_to_tensor from dwave.system.temperatures import maximum_pseudolikelihood_temperature as mple spread = AggregatedSamples.spread @@ -261,84 +261,6 @@ def theta(self) -> torch.Tensor: by the model's input ``nodes`` and ``edges``.""" return torch.cat([self._linear, self._quadratic]) - @overload - def sample(self, sampler: Sampler, as_tensor: Literal[True], **kwargs) -> torch.Tensor: ... - - @overload - def sample(self, sampler: Sampler, as_tensor: Literal[False], **kwargs) -> SampleSet: ... - - def sample( - self, - sampler: Sampler, - *, - prefactor: float, - linear_range: Optional[tuple[float, float]] = None, - quadratic_range: Optional[tuple[float, float]] = None, - device: Optional[torch.device] = None, - sample_params: Optional[dict] = None, - as_tensor: bool = True, - ) -> Union[torch.Tensor, SampleSet]: - """Sample from the Boltzmann machine. - - This method samples and converts a sample of spins to tensors and ensures they - are not aggregated---provided the aggregation information is retained in the - sample set. - - Args: - sampler (Sampler): The sampler used to sample from the model. - prefactor (float): The prefactor for which the Hamiltonian is scaled by. - This quantity is typically the temperature at which the sampler operates - at. Standard CPU-based samplers such as Metropolis- or Gibbs-based - samplers will often default to sampling at an unit temperature, thus a - unit prefactor should be used. In the case of a quantum annealer, a - reasonable choice of a prefactor is 1/beta where beta is the effective - inverse temperature and can be estimated using - :meth:`GraphRestrictedBoltzmannMachine.estimate_beta`. - linear_range (tuple[float, float], optional): Linear weights are clipped to - ``linear_range`` prior to sampling. This clipping occurs after the ``prefactor`` - scaling has been applied. When None, no clipping is applied. Defaults to None. - quadratic_range (tuple[float, float], optional): Quadratic weights are clipped to - ``quadratic_range`` prior to sampling. This clipping occurs after the ``prefactor`` - scaling has been applied. When None, no clipping is applied.Defaults to None. - device (torch.device, optional): The device of the constructed tensor. - If ``None`` and data is a tensor then the device of data is used. - If ``None`` and data is not a tensor then the result tensor is - constructed on the current device. - sample_params (dict, optional): Parameters of the `sampler.sample` method. - as_tensor (bool): Whether to return the sampleset as a tensor. - Defaults to ``True``. If ``False`` returns a ``dimod.SampleSet``. - - Returns: - torch.Tensor | SampleSet: Spins sampled from the model - (shape prescribed by ``sampler`` and ``sample_params``). - """ - if sample_params is None: - sample_params = dict() - h, J = self.to_ising(prefactor, linear_range, quadratic_range) - sample_set = spread(sampler.sample_ising(h, J, **sample_params)) - - if as_tensor: - return self.sampleset_to_tensor(sample_set, device=device) - - return sample_set - - def sampleset_to_tensor( - self, sample_set: SampleSet, device: Optional[torch.device] = None - ) -> torch.Tensor: - """Converts a ``dimod.SampleSet`` to a ``torch.Tensor`` using the node order of the class. - - Args: - sample_set (dimod.SampleSet): A sample set. - device (torch.device, optional): The device of the constructed tensor. - If ``None`` and data is a tensor then the device of data is used. - If ``None`` and data is not a tensor then the result tensor is constructed - on the current device. - - Returns: - torch.Tensor: The sample set as a ``torch.Tensor``. - """ - return sampleset_to_tensor(self._nodes, sample_set, device) - def quasi_objective( self, s_observed: torch.Tensor, diff --git a/dwave/plugins/torch/samplers/__init__.py b/dwave/plugins/torch/samplers/__init__.py index 932b865..294f10a 100755 --- a/dwave/plugins/torch/samplers/__init__.py +++ b/dwave/plugins/torch/samplers/__init__.py @@ -12,4 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dwave.plugins.torch.samplers._base import * from dwave.plugins.torch.samplers.block_spin_sampler import * +from dwave.plugins.torch.samplers.dimod_sampler import * diff --git a/dwave/plugins/torch/samplers/_base.py b/dwave/plugins/torch/samplers/_base.py new file mode 100644 index 0000000..a38e8f3 --- /dev/null +++ b/dwave/plugins/torch/samplers/_base.py @@ -0,0 +1,104 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import copy + +import torch + +from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine + + +__all__ = ["TorchSampler"] + + +class TorchSampler(abc.ABC): + """Base class for all PyTorch plugin samplers.""" + + def __init__(self, refresh: bool = True) -> None: + self._parameters = {} + self._modules = {} + + if refresh: + self.refresh_parameters() + + def parameters(self): + """Parameters in the sampler.""" + for p in self._parameters.values(): + yield p + + def modules(self): + """Modules in the sampler.""" + for m in self._modules.values(): + yield m + + @abc.abstractmethod + def sample(self, x: torch.Tensor | None = None) -> torch.Tensor: + """Abstract sample method.""" + + def to(self, *args, **kwargs): + """Performs Tensor dtype and/or device conversion on sampler parameters. + + See :meth:`torch.Tensor.to` for usage details.""" + # perform a shallow copy of the sampler to be returned + sampler = copy.copy(self) + parameters = {} + modules = {} + + for name, p in self._parameters.items(): + new_p = p.to(*args, **kwargs) + + # set attribute and update parameters + setattr(sampler, name, new_p) + parameters[name] = new_p + + for name, m in self._modules.items(): + new_m = m.to(*args, **kwargs) + + # set attribute and update modules + setattr(sampler, name, new_m) + modules[name] = new_m + + sampler._parameters = parameters + sampler._modules = modules + + return sampler + + def refresh_parameters(self, replace=True, clear=True): + """Refreshes the parameters and modules attributes in-place. + + Searches the sampler for any initialized torch parameters and modules + and adds them to the :attr:`TorchSampler_parameters` attribute, which + is used to update device or dtype using the + :meth:`TorchSampler.to` method. + + Args: + replace: Replace any previous parameters with new values. + clear: Clear the parameters attribute before adding new ones. + """ + if clear: + self._parameters.clear() + self._modules.clear() + + for attr_, val in self.__dict__.items(): + # NOTE: Only refreshes torch parameters and modules, but _not_ any + # GRBM models. Can be generalized if plugin gets a baseclass module. + if replace or attr_ not in self._parameters: + if isinstance(val, torch.Tensor): + self._parameters[attr_] = val + elif ( + isinstance(val, torch.nn.Module) and + not isinstance(val, GraphRestrictedBoltzmannMachine) + ): + self._modules[attr_] = val diff --git a/dwave/plugins/torch/samplers/dimod_sampler.py b/dwave/plugins/torch/samplers/dimod_sampler.py new file mode 100644 index 0000000..0dce1bf --- /dev/null +++ b/dwave/plugins/torch/samplers/dimod_sampler.py @@ -0,0 +1,107 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from dimod import Sampler +import torch + +from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine +from dwave.plugins.torch.samplers._base import TorchSampler +from dwave.plugins.torch.utils import sampleset_to_tensor +from hybrid.composers import AggregatedSamples + +if TYPE_CHECKING: + import dimod + from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine + + +__all__ = ["DimodSampler"] + + +class DimodSampler(TorchSampler): + """PyTorch plugin wrapper for a dimod sampler. + + Args: + module (GraphRestrictedBoltzmannMachine): GraphRestrictedBoltzmannMachine module. Requires the + methods ``to_ising`` and ``nodes``. + sampler (dimod.Sampler): Dimod sampler. + prefactor (float): The prefactor for which the Hamiltonian is scaled by. + This quantity is typically the temperature at which the sampler operates + at. Standard CPU-based samplers such as Metropolis- or Gibbs-based + samplers will often default to sampling at an unit temperature, thus a + unit prefactor should be used. In the case of a quantum annealer, a + reasonable choice of a prefactor is 1/beta where beta is the effective + inverse temperature and can be estimated using + :meth:`GraphRestrictedBoltzmannMachine.estimate_beta`. + linear_range (tuple[float, float], optional): Linear weights are clipped to + ``linear_range`` prior to sampling. This clipping occurs after the ``prefactor`` + scaling has been applied. When None, no clipping is applied. Defaults to None. + quadratic_range (tuple[float, float], optional): Quadratic weights are clipped to + ``quadratic_range`` prior to sampling. This clipping occurs after the ``prefactor`` + scaling has been applied. When None, no clipping is applied.Defaults to None. + sample_kwargs (dict[str, Any]): Dictionary containing optional arguments for the dimod sampler. + """ + + def __init__( + self, + module: GraphRestrictedBoltzmannMachine, + sampler: dimod.Sampler, + prefactor: float | None = None, + linear_range: tuple[float, float] | None = None, + quadratic_range: tuple[float, float] | None = None, + sample_kwargs: dict[str, Any] | None = None + ) -> None: + self._module = module + + # use default prefactor value of 1 + self._prefactor = prefactor or 1 + + self._linear_range = linear_range + self._quadratic_range = quadratic_range + + self._sampler = sampler + self._sampler_params = sample_kwargs or {} + + # cached sample_set from latest sample + self._sample_set = None + + # adds all torch parameters to 'self._parameters' for automatic device/dtype + # update support unless 'refresh_parameters = False' + super().__init__() + + def sample(self, x: torch.Tensor | None = None) -> torch.Tensor: + """Sample from the dimod sampler and return the corresponding tensor. + + The sample set returned from the latest sample call is stored in :func:`DimodSampler.sample_set` + which is overwritten by subsequent calls. + + Args: + x (torch.Tensor): TODO + """ + h, J = self._module.to_ising(self._prefactor, self._linear_range, self._quadratic_range) + self._sample_set = AggregatedSamples.spread(self._sampler.sample_ising(h, J, **self._sampler_params)) + + # use same device as modules linear + device = self._module._linear.device + return sampleset_to_tensor(self._module.nodes, self._sample_set, device) + + @property + def sample_set(self) -> dimod.SampleSet: + """The sample set returned from the latest sample call.""" + if self._sample_set is None: + raise AttributeError("no samples found; call 'sample()' first") + + return self._sample_set diff --git a/tests/test_samplers/test_annealing.py b/tests/test_samplers/test_annealing.py new file mode 100644 index 0000000..c4198e5 --- /dev/null +++ b/tests/test_samplers/test_annealing.py @@ -0,0 +1,15 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest diff --git a/tests/test_samplers/test_base.py b/tests/test_samplers/test_base.py new file mode 100644 index 0000000..2d3c711 --- /dev/null +++ b/tests/test_samplers/test_base.py @@ -0,0 +1,138 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from dwave.plugins.torch.samplers._base import TorchSampler + +class TestTorchSampler(unittest.TestCase): + """Test TorchSampler base class.""" + + def test_subclass_without_sample(self): + """Test creating a new subclass.""" + + class EmptySubClass(TorchSampler): + """Empty subclass without any methods.""" + + def something_else(self): + pass + + with self.assertRaises(TypeError): + EmptySubClass() # type: ignore + + def test_simple_subclass(self): + """Test creating a new subclass.""" + + expected_sample = torch.Tensor([1, 2, 3]) + + class SimpleSubClass(TorchSampler): + """Simple subclass with a dummy sample method.""" + + def sample(self, x: torch.Tensor | None = None): + return x or expected_sample + + simple_obj = SimpleSubClass() + + torch.testing.assert_close(simple_obj.sample(), expected_sample) + + def test_parameters(self): + """Test that parameters are correctly.""" + init_device = torch.device("cpu") + + parameters = { + "param_0": torch.nn.Parameter(torch.Tensor([2, 4, 8], device=init_device)), + "param_1": torch.Tensor([1, 1, 2], device=init_device), + } + + class SubClassWithParameters(TorchSampler): + """Simple subclass with a dummy sample method.""" + + def __init__(self) -> None: + self.param_0 = parameters["param_0"] + self.param_1 = parameters["param_1"] + + # refresh parameters + super().__init__(refresh=True) + + def sample(self, x: torch.Tensor | None = None): # type: ignore + pass + + params_obj = SubClassWithParameters() + + self.assertDictEqual(params_obj._parameters, parameters) + + with self.subTest("set device to meta"): + params_obj = params_obj.to(torch.device("meta")) + + for p in params_obj.parameters(): + self.assertEqual(p.device, torch.device("meta")) + + self.assertIs(params_obj.param_0, list(params_obj.parameters())[0]) + self.assertIs(params_obj.param_1, list(params_obj.parameters())[1]) + + with self.subTest("add new parameter on cpu"): + cpu_param = torch.Tensor([4, 2, 0], device=torch.device("cpu")) + setattr(params_obj, "cpu_param", cpu_param) + + # check that new param is _not_ part of parameters unless refreshed + self.assertNotIn(cpu_param, list(params_obj.parameters())) + + # refresh parameters and check again + params_obj.refresh_parameters() + self.assertIn(cpu_param, list(params_obj.parameters())) + + # check that new param has different device + self.assertEqual(params_obj._parameters["cpu_param"].device, torch.device("cpu")) + self.assertEqual(params_obj._parameters["param_0"].device, torch.device("meta")) + self.assertEqual(params_obj._parameters["param_1"].device, torch.device("meta")) + + # finally, check that setting 'params_obj' to meta device again works + params_obj = params_obj.to(torch.device("meta")) + for p in params_obj.parameters(): + self.assertEqual(p.device, torch.device("meta")) + + def test_module_parameters(self): + """Test that modules are correctly set.""" + init_device = torch.device("cpu") + + param = torch.nn.Conv2d(1, 20, 5).to(init_device) + + class SubClassWithModule(TorchSampler): + """Simple subclass with a dummy sample method.""" + + def __init__(self) -> None: + self.param = param + + # refresh parameters + super().__init__(refresh=True) + + def sample(self, x: torch.Tensor | None = None): # type: ignore + pass + + module_obj = SubClassWithModule() + + self.assertEqual(type(next(module_obj.modules())), type(param)) + + with self.subTest("set device to meta"): + res_device = torch.device("meta") + res_module = next(module_obj.modules()) + + # set all modules to device, recursively setting all the + # modules parameters to device + module_obj = module_obj.to(res_device) + + # assert that modules parameters have set device (just check one param) + self.assertEqual(next(res_module.parameters()).device, res_device) From 891e85220c71cca268fe39eecbe29c6b66ee053e Mon Sep 17 00:00:00 2001 From: kchern Date: Tue, 20 Jan 2026 22:57:22 -0800 Subject: [PATCH 18/39] Fix formatting --- dwave/plugins/torch/samplers/_base.py | 1 - dwave/plugins/torch/samplers/dimod_sampler.py | 15 +++++++++++---- tests/test_samplers/test_base.py | 5 +++-- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/dwave/plugins/torch/samplers/_base.py b/dwave/plugins/torch/samplers/_base.py index a38e8f3..c6b2020 100644 --- a/dwave/plugins/torch/samplers/_base.py +++ b/dwave/plugins/torch/samplers/_base.py @@ -19,7 +19,6 @@ from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine - __all__ = ["TorchSampler"] diff --git a/dwave/plugins/torch/samplers/dimod_sampler.py b/dwave/plugins/torch/samplers/dimod_sampler.py index 0dce1bf..e8c8574 100644 --- a/dwave/plugins/torch/samplers/dimod_sampler.py +++ b/dwave/plugins/torch/samplers/dimod_sampler.py @@ -15,16 +15,17 @@ from typing import TYPE_CHECKING, Any -from dimod import Sampler import torch +from dimod import Sampler +from hybrid.composers import AggregatedSamples from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine from dwave.plugins.torch.samplers._base import TorchSampler from dwave.plugins.torch.utils import sampleset_to_tensor -from hybrid.composers import AggregatedSamples if TYPE_CHECKING: import dimod + from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine @@ -89,10 +90,16 @@ def sample(self, x: torch.Tensor | None = None) -> torch.Tensor: which is overwritten by subsequent calls. Args: - x (torch.Tensor): TODO + x (torch.Tensor): A tensor of shape (``batch_size``, ``dim``) or (``batch_size``, ``n_nodes``) + interpreted as a batch of partially-observed spins. Entries marked with ``torch.nan`` will + be sampled; entries with +/-1 values will remain constant. """ + if x is not None: + raise NotImplementedError("Support for conditional sampling has not been implemented.") h, J = self._module.to_ising(self._prefactor, self._linear_range, self._quadratic_range) - self._sample_set = AggregatedSamples.spread(self._sampler.sample_ising(h, J, **self._sampler_params)) + self._sample_set = AggregatedSamples.spread( + self._sampler.sample_ising(h, J, **self._sampler_params) + ) # use same device as modules linear device = self._module._linear.device diff --git a/tests/test_samplers/test_base.py b/tests/test_samplers/test_base.py index 2d3c711..2171616 100644 --- a/tests/test_samplers/test_base.py +++ b/tests/test_samplers/test_base.py @@ -18,6 +18,7 @@ from dwave.plugins.torch.samplers._base import TorchSampler + class TestTorchSampler(unittest.TestCase): """Test TorchSampler base class.""" @@ -67,7 +68,7 @@ def __init__(self) -> None: # refresh parameters super().__init__(refresh=True) - def sample(self, x: torch.Tensor | None = None): # type: ignore + def sample(self, x: torch.Tensor | None = None): # type: ignore pass params_obj = SubClassWithParameters() @@ -119,7 +120,7 @@ def __init__(self) -> None: # refresh parameters super().__init__(refresh=True) - def sample(self, x: torch.Tensor | None = None): # type: ignore + def sample(self, x: torch.Tensor | None = None): # type: ignore pass module_obj = SubClassWithModule() From 5ce22b28804e241a33774ab073a26a1a7b080c98 Mon Sep 17 00:00:00 2001 From: kchern Date: Tue, 20 Jan 2026 22:58:12 -0800 Subject: [PATCH 19/39] Restore GRBM methods and show deprecation warnings --- .../plugins/torch/models/boltzmann_machine.py | 86 ++++++++++++++++++- 1 file changed, 85 insertions(+), 1 deletion(-) diff --git a/dwave/plugins/torch/models/boltzmann_machine.py b/dwave/plugins/torch/models/boltzmann_machine.py index 9f36756..812f37f 100644 --- a/dwave/plugins/torch/models/boltzmann_machine.py +++ b/dwave/plugins/torch/models/boltzmann_machine.py @@ -32,12 +32,12 @@ import torch if TYPE_CHECKING: - from dwave.plugins.torch.samplers.dimod_sampler import DimodSampler from dimod import Sampler, SampleSet from dimod import BinaryQuadraticModel from hybrid.composers import AggregatedSamples +from dwave.plugins.torch.utils import sampleset_to_tensor from dwave.system.temperatures import maximum_pseudolikelihood_temperature as mple spread = AggregatedSamples.spread @@ -45,6 +45,8 @@ __all__ = ["GraphRestrictedBoltzmannMachine"] +_SAMPLING_DEPRECATION_MESSAGE = "Use `dwave.plugins.torch.samplers` module for all sampling-related tasks instead." + class GraphRestrictedBoltzmannMachine(torch.nn.Module): """Creates a graph-restricted Boltzmann machine. @@ -261,6 +263,88 @@ def theta(self) -> torch.Tensor: by the model's input ``nodes`` and ``edges``.""" return torch.cat([self._linear, self._quadratic]) + @overload + def sample(self, sampler: Sampler, as_tensor: Literal[True], **kwargs) -> torch.Tensor: + warnings.warn(_SAMPLING_DEPRECATION_MESSAGE, DeprecationWarning) + + @overload + def sample(self, sampler: Sampler, as_tensor: Literal[False], **kwargs) -> SampleSet: + warnings.warn(_SAMPLING_DEPRECATION_MESSAGE, DeprecationWarning) + + def sample( + self, + sampler: Sampler, + *, + prefactor: float, + linear_range: Optional[tuple[float, float]] = None, + quadratic_range: Optional[tuple[float, float]] = None, + device: Optional[torch.device] = None, + sample_params: Optional[dict] = None, + as_tensor: bool = True, + ) -> Union[torch.Tensor, SampleSet]: + """Sample from the Boltzmann machine. + + This method samples and converts a sample of spins to tensors and ensures they + are not aggregated---provided the aggregation information is retained in the + sample set. + + Args: + sampler (Sampler): The sampler used to sample from the model. + prefactor (float): The prefactor for which the Hamiltonian is scaled by. + This quantity is typically the temperature at which the sampler operates + at. Standard CPU-based samplers such as Metropolis- or Gibbs-based + samplers will often default to sampling at an unit temperature, thus a + unit prefactor should be used. In the case of a quantum annealer, a + reasonable choice of a prefactor is 1/beta where beta is the effective + inverse temperature and can be estimated using + :meth:`GraphRestrictedBoltzmannMachine.estimate_beta`. + linear_range (tuple[float, float], optional): Linear weights are clipped to + ``linear_range`` prior to sampling. This clipping occurs after the ``prefactor`` + scaling has been applied. When None, no clipping is applied. Defaults to None. + quadratic_range (tuple[float, float], optional): Quadratic weights are clipped to + ``quadratic_range`` prior to sampling. This clipping occurs after the ``prefactor`` + scaling has been applied. When None, no clipping is applied.Defaults to None. + device (torch.device, optional): The device of the constructed tensor. + If ``None`` and data is a tensor then the device of data is used. + If ``None`` and data is not a tensor then the result tensor is + constructed on the current device. + sample_params (dict, optional): Parameters of the `sampler.sample` method. + as_tensor (bool): Whether to return the sampleset as a tensor. + Defaults to ``True``. If ``False`` returns a ``dimod.SampleSet``. + + Returns: + torch.Tensor | SampleSet: Spins sampled from the model + (shape prescribed by ``sampler`` and ``sample_params``). + """ + warnings.warn(_SAMPLING_DEPRECATION_MESSAGE, DeprecationWarning) + if sample_params is None: + sample_params = dict() + h, J = self.to_ising(prefactor, linear_range, quadratic_range) + sample_set = spread(sampler.sample_ising(h, J, **sample_params)) + + if as_tensor: + return self.sampleset_to_tensor(sample_set, device=device) + + return sample_set + + def sampleset_to_tensor( + self, sample_set: SampleSet, device: Optional[torch.device] = None + ) -> torch.Tensor: + """Converts a ``dimod.SampleSet`` to a ``torch.Tensor`` using the node order of the class. + + Args: + sample_set (dimod.SampleSet): A sample set. + device (torch.device, optional): The device of the constructed tensor. + If ``None`` and data is a tensor then the device of data is used. + If ``None`` and data is not a tensor then the result tensor is constructed + on the current device. + + Returns: + torch.Tensor: The sample set as a ``torch.Tensor``. + """ + warnings.warn(_SAMPLING_DEPRECATION_MESSAGE, DeprecationWarning) + return sampleset_to_tensor(self._nodes, sample_set, device) + def quasi_objective( self, s_observed: torch.Tensor, From 867c584219c83773dca3b884d613cbaf83e7fcad Mon Sep 17 00:00:00 2001 From: kchern Date: Thu, 22 Jan 2026 15:42:58 -0800 Subject: [PATCH 20/39] Rename module field to grbm --- dwave/plugins/torch/samplers/dimod_sampler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dwave/plugins/torch/samplers/dimod_sampler.py b/dwave/plugins/torch/samplers/dimod_sampler.py index e8c8574..4aa52e1 100644 --- a/dwave/plugins/torch/samplers/dimod_sampler.py +++ b/dwave/plugins/torch/samplers/dimod_sampler.py @@ -58,14 +58,14 @@ class DimodSampler(TorchSampler): def __init__( self, - module: GraphRestrictedBoltzmannMachine, + grbm: GraphRestrictedBoltzmannMachine, sampler: dimod.Sampler, prefactor: float | None = None, linear_range: tuple[float, float] | None = None, quadratic_range: tuple[float, float] | None = None, sample_kwargs: dict[str, Any] | None = None ) -> None: - self._module = module + self._grbm = grbm # use default prefactor value of 1 self._prefactor = prefactor or 1 @@ -96,14 +96,14 @@ def sample(self, x: torch.Tensor | None = None) -> torch.Tensor: """ if x is not None: raise NotImplementedError("Support for conditional sampling has not been implemented.") - h, J = self._module.to_ising(self._prefactor, self._linear_range, self._quadratic_range) + h, J = self._grbm.to_ising(self._prefactor, self._linear_range, self._quadratic_range) self._sample_set = AggregatedSamples.spread( self._sampler.sample_ising(h, J, **self._sampler_params) ) # use same device as modules linear - device = self._module._linear.device - return sampleset_to_tensor(self._module.nodes, self._sample_set, device) + device = self._grbm._linear.device + return sampleset_to_tensor(self._grbm.nodes, self._sample_set, device) @property def sample_set(self) -> dimod.SampleSet: From 581383de8a42f91893a58fe93e13bb688a2eaae1 Mon Sep 17 00:00:00 2001 From: kchern Date: Thu, 22 Jan 2026 21:15:47 -0800 Subject: [PATCH 21/39] Remove default value for prefactor --- dwave/plugins/torch/samplers/dimod_sampler.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/dwave/plugins/torch/samplers/dimod_sampler.py b/dwave/plugins/torch/samplers/dimod_sampler.py index 4aa52e1..f4c999b 100644 --- a/dwave/plugins/torch/samplers/dimod_sampler.py +++ b/dwave/plugins/torch/samplers/dimod_sampler.py @@ -60,15 +60,14 @@ def __init__( self, grbm: GraphRestrictedBoltzmannMachine, sampler: dimod.Sampler, - prefactor: float | None = None, + prefactor: float, linear_range: tuple[float, float] | None = None, quadratic_range: tuple[float, float] | None = None, sample_kwargs: dict[str, Any] | None = None ) -> None: self._grbm = grbm - # use default prefactor value of 1 - self._prefactor = prefactor or 1 + self._prefactor = prefactor self._linear_range = linear_range self._quadratic_range = quadratic_range From 76ecc295d2fc58907680649315c3e8706c03b5c5 Mon Sep 17 00:00:00 2001 From: kchern Date: Thu, 22 Jan 2026 21:16:06 -0800 Subject: [PATCH 22/39] Add unit tests for DimodSampler --- tests/test_samplers/test_annealing.py | 15 --- tests/test_samplers/test_dimod_sampler.py | 147 ++++++++++++++++++++++ 2 files changed, 147 insertions(+), 15 deletions(-) delete mode 100644 tests/test_samplers/test_annealing.py create mode 100644 tests/test_samplers/test_dimod_sampler.py diff --git a/tests/test_samplers/test_annealing.py b/tests/test_samplers/test_annealing.py deleted file mode 100644 index c4198e5..0000000 --- a/tests/test_samplers/test_annealing.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2025 D-Wave -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest diff --git a/tests/test_samplers/test_dimod_sampler.py b/tests/test_samplers/test_dimod_sampler.py new file mode 100644 index 0000000..59e7b88 --- /dev/null +++ b/tests/test_samplers/test_dimod_sampler.py @@ -0,0 +1,147 @@ +# Copyright 2026 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from dimod import SPIN, BinaryQuadraticModel, IdentitySampler, SampleSet, TrackingComposite +from parameterized import parameterized + +from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine as GRBM +from dwave.plugins.torch.samplers.dimod_sampler import DimodSampler +from dwave.samplers import SteepestDescentSampler +from dwave.system.temperatures import maximum_pseudolikelihood_temperature as mple + + +class TestDimodSampler(unittest.TestCase): + def setUp(self) -> None: + # Create a triangle graph with an additional dangling vertex + # a + # / | \ + # b--c d + # Note the node order is deliberately "dbac" in order to test variable orderings + self.nodes = list("dbac") + self.edges = [["a", "b"], ["a", "c"], ["a", "d"], ["b", "c"]] + self.n = 4 + + # Manually set the parameter weights for testing + dtype = torch.float32 + h = [0.0, 1, 2, 3] + + bm = GRBM(self.nodes, self.edges) + bm._linear.data = torch.tensor(h, dtype=dtype) + bm._quadratic.data = torch.tensor([1, 2, 3, 6], dtype=dtype) + + self.bm = bm + + self.ones = torch.ones(4).unsqueeze(0) + self.mones = -torch.ones(4).unsqueeze(0) + self.pmones = torch.tensor([[1, -1, 1, -1]], dtype=dtype) + self.mpones = torch.tensor([[-1, 1, -1, 1]], dtype=dtype) + + self.sample_1 = torch.vstack([self.ones, self.ones, self.ones, self.pmones]) + self.sample_2 = torch.vstack([self.ones, self.ones, self.ones, self.mpones]) + return super().setUp() + + def test_sample(self): + grbm = GRBM(list("abcd"), [("a", "b"), ("a", "c"), ("a", "d"), ("b", "c")]) + + with self.subTest("Spins should be identical to input."): + initial_states = [[1, 1, 1, 1], + [1, 1, 1, 1], + [-1, -1, 1, -1]] + sampler = DimodSampler(grbm, IdentitySampler(), + prefactor=1, linear_range=None, quadratic_range=None, + sample_kwargs=dict(initial_states=(initial_states, "abcd"))) + spins = sampler.sample() + self.assertIsInstance(spins, torch.Tensor) + self.assertTupleEqual((3, 4), tuple(spins.shape)) + self.assertListEqual(initial_states, spins.tolist()) + + with self.subTest("Prefactor should scale weights up."): + grbm.linear.data[:] = 1 + grbm.quadratic.data[:] = -1 + prefactor = 12345 + tracker = TrackingComposite(SteepestDescentSampler()) + sampler = DimodSampler(grbm, tracker, + prefactor=prefactor, linear_range=None, quadratic_range=None, + sample_kwargs=dict()) + sampler.sample() + self.assertDictEqual(tracker.input['h'], dict(zip(grbm.nodes, [prefactor]*4))) + self.assertDictEqual(tracker.input['J'], dict(zip(grbm.edges, [-prefactor]*4))) + + with self.subTest("Linear weights should be clipped to be 0."): + grbm.linear.data[:] = torch.tensor([-2, -0.002, 0.002, 3]) + tracker = TrackingComposite(SteepestDescentSampler()) + sampler = DimodSampler(grbm, tracker, + prefactor=100, linear_range=[0, 0], quadratic_range=None, + sample_kwargs=dict()) + sampler.sample() + torch.testing.assert_close( + torch.tensor(list(tracker.input['h'].values())), + torch.tensor([0, 0, 0, 0.0]) + ) + with self.subTest("Linear weights should be clipped to be within range."): + grbm.linear.data[:] = torch.tensor([-2, -0.002, 0.002, 3]) + tracker = TrackingComposite(SteepestDescentSampler()) + sampler = DimodSampler(grbm, tracker, + prefactor=100, linear_range=[-1, 1], quadratic_range=None, + sample_kwargs=dict()) + sampler.sample() + torch.testing.assert_close( + torch.tensor(list(tracker.input['h'].values())), + torch.tensor([-1, -0.2, 0.2, 1]) + ) + with self.subTest("Quadratic weights should be clipped to be within range."): + grbm.quadratic.data[:] = torch.tensor([-2, -0.002, 0.002, 3]) + tracker = TrackingComposite(SteepestDescentSampler()) + sampler = DimodSampler(grbm, tracker, + prefactor=100, linear_range=None, quadratic_range=[-1, 1], + sample_kwargs=dict()) + sampler.sample() + torch.testing.assert_close( + torch.tensor(list(tracker.input['J'].values())), + torch.tensor([-1, -0.2, 0.2, 1]) + ) + with self.subTest("Quadratic weights should be clipped to be 0."): + grbm.quadratic.data[:] = torch.tensor([-2, -0.002, 0.002, 3]) + tracker = TrackingComposite(SteepestDescentSampler()) + sampler = DimodSampler(grbm, tracker, + prefactor=100, linear_range=None, quadratic_range=[0, 0], + sample_kwargs=dict()) + sampler.sample() + torch.testing.assert_close( + torch.tensor(list(tracker.input['J'].values())), + torch.tensor([0, 0, 0, 0.0]) + ) + + def test_sample_set(self): + grbm = GRBM(list("abcd"), [("a", "b")]) + initial_states = [[1, 1, 1, 1], + [1, 1, 1, 1], + [-1, -1, 1, -1]] + sampler = DimodSampler(grbm, IdentitySampler(), + prefactor=1, linear_range=None, quadratic_range=None, + sample_kwargs=dict(initial_states=(initial_states, "abcd"))) + with self.subTest("Accessing `sample_set` field before sampling should raise an error."): + with self.assertRaisesRegex(AttributeError, "no samples found"): + sampler.sample_set + + sampler.sample() + with self.subTest("The `sample_set` attribute should be of type `dimod.SampleSet`."): + self.assertTrue(isinstance(sampler.sample_set, SampleSet)) + + +if __name__ == "__main__": + unittest.main() From 1956d7be7952a4bf6938009d99b4290925e05a80 Mon Sep 17 00:00:00 2001 From: kchern Date: Fri, 23 Jan 2026 11:50:48 -0800 Subject: [PATCH 23/39] Add sampler example --- README.rst | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/README.rst b/README.rst index 383726a..d726968 100644 --- a/README.rst +++ b/README.rst @@ -6,6 +6,49 @@ hybrid solvers and the PyTorch framework, including a Torch neural network module for building and training Boltzmann Machines along with various sampler utility functions. +Example +------- +Boltzmann Machines are probabilistic generative models for high-dimensional binary data. +The following example walks through a typical workflow for fitting Boltzmann Machines via maximum likelihood. + +Define a Graph-Restricted Boltzmann Machine with a cycle of length four: +.. code-block:: python + import torch + from torch.optim import SGD + + from dwave.plugins.torch.models import GraphRestrictedBoltzmannMachine as GRBM + from dwave.plugins.torch.samplers import BlockSampler + + grbm = GRBM(nodes=["a", "b", "c", "d"], edges=[("a", "b"), ("b", "c"), ("c", "d"), ("d", "a")]) + print("Linear weights:", grbm.linear) + print("Quadratic weights:", grbm.quadratic) + +The following instantiates a block-Gibbs sampler. +Variables "a" and "c" are in block 0; variables "b" and "d" are in block 1. +The sampler consists of three parallel Markov chains of length ten each. +Each Markov chain samples at a constant unit inverse temperature. + +.. code-block:: python + sampler = BlockSampler(grbm=grbm, crayon=lambda v: v in {"b", "d"}, num_chains=3, schedule=[1]*10) + +Create a batch of data and perform one likelihood-optimization step +.. code-block:: python + # Example optimization step for maximizing the likelihood of dummy data + x_data = torch.tensor([[1, -1, 1, -1], [-1, 1, 1, 1]], dtype=torch.float32) + optimizer = SGD(grbm.parameters(), lr=1) + x_model = sampler.sample() + grbm.quasi_objective(x_data, x_model).backward() + optimizer.step() + print("Updated quadratic weights:", grbm.quadratic) + +To use a [dimod](https://github.com/dwavesystems/dimod/) sampler, replace the `sampler = BlockSampler(...)` line with +.. code-block:: python + from dwave.plugins.torch.samplers import DimodSampler + from dwave.samplers import RandomSampler + sampler = DimodSampler(grbm=grbm, sampler=RandomSampler(), + prefactor=1, sample_kwargs=dict(num_reads=5)) + + License ------- From 3dbf68152849196fd8c7f366dbfcc0be96c5a6d7 Mon Sep 17 00:00:00 2001 From: kchern Date: Fri, 23 Jan 2026 11:58:41 -0800 Subject: [PATCH 24/39] Fix formatting --- README.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.rst b/README.rst index d726968..622c121 100644 --- a/README.rst +++ b/README.rst @@ -12,6 +12,7 @@ Boltzmann Machines are probabilistic generative models for high-dimensional bina The following example walks through a typical workflow for fitting Boltzmann Machines via maximum likelihood. Define a Graph-Restricted Boltzmann Machine with a cycle of length four: + .. code-block:: python import torch from torch.optim import SGD @@ -32,6 +33,7 @@ Each Markov chain samples at a constant unit inverse temperature. sampler = BlockSampler(grbm=grbm, crayon=lambda v: v in {"b", "d"}, num_chains=3, schedule=[1]*10) Create a batch of data and perform one likelihood-optimization step + .. code-block:: python # Example optimization step for maximizing the likelihood of dummy data x_data = torch.tensor([[1, -1, 1, -1], [-1, 1, 1, 1]], dtype=torch.float32) @@ -42,6 +44,7 @@ Create a batch of data and perform one likelihood-optimization step print("Updated quadratic weights:", grbm.quadratic) To use a [dimod](https://github.com/dwavesystems/dimod/) sampler, replace the `sampler = BlockSampler(...)` line with + .. code-block:: python from dwave.plugins.torch.samplers import DimodSampler from dwave.samplers import RandomSampler From 951be8295bdaf69439d29863b9d6bf6169b77770 Mon Sep 17 00:00:00 2001 From: Kevin Chern <32395608+kevinchern@users.noreply.github.com> Date: Fri, 23 Jan 2026 14:18:02 -0800 Subject: [PATCH 25/39] Fix formatting --- README.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 622c121..c061b9b 100644 --- a/README.rst +++ b/README.rst @@ -14,6 +14,7 @@ The following example walks through a typical workflow for fitting Boltzmann Mac Define a Graph-Restricted Boltzmann Machine with a cycle of length four: .. code-block:: python + import torch from torch.optim import SGD @@ -35,7 +36,7 @@ Each Markov chain samples at a constant unit inverse temperature. Create a batch of data and perform one likelihood-optimization step .. code-block:: python - # Example optimization step for maximizing the likelihood of dummy data + x_data = torch.tensor([[1, -1, 1, -1], [-1, 1, 1, 1]], dtype=torch.float32) optimizer = SGD(grbm.parameters(), lr=1) x_model = sampler.sample() @@ -43,9 +44,10 @@ Create a batch of data and perform one likelihood-optimization step optimizer.step() print("Updated quadratic weights:", grbm.quadratic) -To use a [dimod](https://github.com/dwavesystems/dimod/) sampler, replace the `sampler = BlockSampler(...)` line with +To use a `dimod `_ sampler, replace the `sampler = BlockSampler(...)` line with .. code-block:: python + from dwave.plugins.torch.samplers import DimodSampler from dwave.samplers import RandomSampler sampler = DimodSampler(grbm=grbm, sampler=RandomSampler(), From d61879683f0f0898dcce339e7503a33ba1493566 Mon Sep 17 00:00:00 2001 From: Kevin Chern <32395608+kevinchern@users.noreply.github.com> Date: Fri, 23 Jan 2026 14:21:16 -0800 Subject: [PATCH 26/39] Fix formatting for inline code --- README.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index c061b9b..09ddfc8 100644 --- a/README.rst +++ b/README.rst @@ -11,7 +11,7 @@ Example Boltzmann Machines are probabilistic generative models for high-dimensional binary data. The following example walks through a typical workflow for fitting Boltzmann Machines via maximum likelihood. -Define a Graph-Restricted Boltzmann Machine with a cycle of length four: +Define a Graph-Restricted Boltzmann Machine with a square graph: .. code-block:: python @@ -44,7 +44,7 @@ Create a batch of data and perform one likelihood-optimization step optimizer.step() print("Updated quadratic weights:", grbm.quadratic) -To use a `dimod `_ sampler, replace the `sampler = BlockSampler(...)` line with +To use a `dimod `_ sampler, replace the :code:`sampler = BlockSampler(...)` line with .. code-block:: python From dd91fc86609535a77b14827ee971317eb4a8ac69 Mon Sep 17 00:00:00 2001 From: Kevin Chern <32395608+kevinchern@users.noreply.github.com> Date: Fri, 23 Jan 2026 15:30:39 -0800 Subject: [PATCH 27/39] Polish README --- README.rst | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/README.rst b/README.rst index 09ddfc8..b3ac0cf 100644 --- a/README.rst +++ b/README.rst @@ -1,17 +1,16 @@ D-Wave PyTorch Plugin ===================== -This plugin provides an interface between D-Wave's quantum-classical -hybrid solvers and the PyTorch framework, including a Torch neural -network module for building and training Boltzmann Machines along with -various sampler utility functions. +This plugin provides an interface between D-Wave's quantum computers and +the PyTorch framework, including neural network modules for building +and training Boltzmann Machines along with various sampler utility functions. Example ------- Boltzmann Machines are probabilistic generative models for high-dimensional binary data. The following example walks through a typical workflow for fitting Boltzmann Machines via maximum likelihood. -Define a Graph-Restricted Boltzmann Machine with a square graph: +Define a Graph-Restricted Boltzmann Machine with a square graph .. code-block:: python @@ -25,14 +24,17 @@ Define a Graph-Restricted Boltzmann Machine with a square graph: print("Linear weights:", grbm.linear) print("Quadratic weights:", grbm.quadratic) -The following instantiates a block-Gibbs sampler. + +Instantiate a `block-Gibbs sampler `_. Variables "a" and "c" are in block 0; variables "b" and "d" are in block 1. The sampler consists of three parallel Markov chains of length ten each. Each Markov chain samples at a constant unit inverse temperature. .. code-block:: python + sampler = BlockSampler(grbm=grbm, crayon=lambda v: v in {"b", "d"}, num_chains=3, schedule=[1]*10) + Create a batch of data and perform one likelihood-optimization step .. code-block:: python From 95a2f3356107bd54f089581ede1bf2a0e6ae358c Mon Sep 17 00:00:00 2001 From: Theodor Isacsson Date: Fri, 23 Jan 2026 17:27:27 -0800 Subject: [PATCH 28/39] Add suggestions from code review --- .../plugins/torch/models/boltzmann_machine.py | 24 ++++++++++++------- dwave/plugins/torch/samplers/__init__.py | 2 +- .../torch/samplers/{_base.py => base.py} | 0 dwave/plugins/torch/samplers/dimod_sampler.py | 3 ++- tests/test_samplers/test_base.py | 2 +- 5 files changed, 20 insertions(+), 11 deletions(-) rename dwave/plugins/torch/samplers/{_base.py => base.py} (100%) diff --git a/dwave/plugins/torch/models/boltzmann_machine.py b/dwave/plugins/torch/models/boltzmann_machine.py index 812f37f..ff8739f 100644 --- a/dwave/plugins/torch/models/boltzmann_machine.py +++ b/dwave/plugins/torch/models/boltzmann_machine.py @@ -45,8 +45,6 @@ __all__ = ["GraphRestrictedBoltzmannMachine"] -_SAMPLING_DEPRECATION_MESSAGE = "Use `dwave.plugins.torch.samplers` module for all sampling-related tasks instead." - class GraphRestrictedBoltzmannMachine(torch.nn.Module): """Creates a graph-restricted Boltzmann machine. @@ -264,12 +262,10 @@ def theta(self) -> torch.Tensor: return torch.cat([self._linear, self._quadratic]) @overload - def sample(self, sampler: Sampler, as_tensor: Literal[True], **kwargs) -> torch.Tensor: - warnings.warn(_SAMPLING_DEPRECATION_MESSAGE, DeprecationWarning) + def sample(self, sampler: Sampler, as_tensor: Literal[True], **kwargs) -> torch.Tensor: ... @overload - def sample(self, sampler: Sampler, as_tensor: Literal[False], **kwargs) -> SampleSet: - warnings.warn(_SAMPLING_DEPRECATION_MESSAGE, DeprecationWarning) + def sample(self, sampler: Sampler, as_tensor: Literal[False], **kwargs) -> SampleSet: ... def sample( self, @@ -316,7 +312,13 @@ def sample( torch.Tensor | SampleSet: Spins sampled from the model (shape prescribed by ``sampler`` and ``sample_params``). """ - warnings.warn(_SAMPLING_DEPRECATION_MESSAGE, DeprecationWarning) + warnings.warn( + f"`{self.__class__}.sample()()` is deprecated since dwave-pytorch-plugin " + "0.3.0 and will be removed in 0.4.0. Use Use `dwave.plugins.torch.samplers` module " + "for all sampling-related tasks instead." + , DeprecationWarning + ) + if sample_params is None: sample_params = dict() h, J = self.to_ising(prefactor, linear_range, quadratic_range) @@ -342,7 +344,13 @@ def sampleset_to_tensor( Returns: torch.Tensor: The sample set as a ``torch.Tensor``. """ - warnings.warn(_SAMPLING_DEPRECATION_MESSAGE, DeprecationWarning) + warnings.warn( + f"`{self.__class__}.sampleset_to_tensor()` is deprecated since dwave-pytorch-plugin " + "0.3.0 and will be removed in 0.4.0. Use Use `dwave.plugins.torch.samplers` module " + "for all sampling-related tasks instead." + , DeprecationWarning + ) + return sampleset_to_tensor(self._nodes, sample_set, device) def quasi_objective( diff --git a/dwave/plugins/torch/samplers/__init__.py b/dwave/plugins/torch/samplers/__init__.py index 294f10a..6dc3ad1 100755 --- a/dwave/plugins/torch/samplers/__init__.py +++ b/dwave/plugins/torch/samplers/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dwave.plugins.torch.samplers._base import * +from dwave.plugins.torch.samplers.base import * from dwave.plugins.torch.samplers.block_spin_sampler import * from dwave.plugins.torch.samplers.dimod_sampler import * diff --git a/dwave/plugins/torch/samplers/_base.py b/dwave/plugins/torch/samplers/base.py similarity index 100% rename from dwave/plugins/torch/samplers/_base.py rename to dwave/plugins/torch/samplers/base.py diff --git a/dwave/plugins/torch/samplers/dimod_sampler.py b/dwave/plugins/torch/samplers/dimod_sampler.py index f4c999b..1e2c992 100644 --- a/dwave/plugins/torch/samplers/dimod_sampler.py +++ b/dwave/plugins/torch/samplers/dimod_sampler.py @@ -20,7 +20,7 @@ from hybrid.composers import AggregatedSamples from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine -from dwave.plugins.torch.samplers._base import TorchSampler +from dwave.plugins.torch.samplers.base import TorchSampler from dwave.plugins.torch.utils import sampleset_to_tensor if TYPE_CHECKING: @@ -95,6 +95,7 @@ def sample(self, x: torch.Tensor | None = None) -> torch.Tensor: """ if x is not None: raise NotImplementedError("Support for conditional sampling has not been implemented.") + h, J = self._grbm.to_ising(self._prefactor, self._linear_range, self._quadratic_range) self._sample_set = AggregatedSamples.spread( self._sampler.sample_ising(h, J, **self._sampler_params) diff --git a/tests/test_samplers/test_base.py b/tests/test_samplers/test_base.py index 2171616..8c55988 100644 --- a/tests/test_samplers/test_base.py +++ b/tests/test_samplers/test_base.py @@ -16,7 +16,7 @@ import torch -from dwave.plugins.torch.samplers._base import TorchSampler +from dwave.plugins.torch.samplers.base import TorchSampler class TestTorchSampler(unittest.TestCase): From 78de3135efa93a6d2a065b6feca2b7783f054dfc Mon Sep 17 00:00:00 2001 From: Theodor Isacsson Date: Fri, 23 Jan 2026 17:48:51 -0800 Subject: [PATCH 29/39] Make BlockSampler sublass of TorchSampler --- .../torch/samplers/block_spin_sampler.py | 26 ++++++++++--------- tests/test_block_sampler.py | 2 +- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/dwave/plugins/torch/samplers/block_spin_sampler.py b/dwave/plugins/torch/samplers/block_spin_sampler.py index cadb256..1138742 100644 --- a/dwave/plugins/torch/samplers/block_spin_sampler.py +++ b/dwave/plugins/torch/samplers/block_spin_sampler.py @@ -25,13 +25,14 @@ from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine as GRBM from torch._prims_common import DeviceLikeType +from dwave.plugins.torch.samplers.base import TorchSampler from dwave.plugins.torch.nn.functional import bit2spin_soft from dwave.plugins.torch.tensor import randspin __all__ = ["BlockSampler"] -class BlockSampler: +class BlockSampler(TorchSampler): """A block-spin update sampler for graph-restricted Boltzmann machines. Due to the sparse definition of GRBMs, some tedious indexing tricks are required to @@ -68,7 +69,6 @@ def __init__(self, grbm: GRBM, crayon: Callable[[Hashable], Hashable], num_chain proposal_acceptance_criteria: Literal["Gibbs", "Metropolis"] = "Gibbs", initial_states: torch.Tensor | None = None, seed: int | None = None): - super().__init__() if num_chains < 1: raise ValueError("Number of reads should be a positive integer.") @@ -99,8 +99,12 @@ def __init__(self, grbm: GRBM, crayon: Callable[[Hashable], Hashable], num_chain self._x = nn.Parameter(initial_states.float(), requires_grad=False) self._zeros = nn.Parameter(torch.zeros((num_chains, 1)), requires_grad=False) + # call base sampler after setting parameters for correctly identifying them + # in super methods 'properties' and 'modules' + super().__init__() + def to(self, device: DeviceLikeType) -> BlockSampler: - """Moves sampler components to the target device. + """Creates a sampler copy with components moved to the target device. If the device is "meta", then the random number generator (RNG) will not be modified at all. For all other devices, all attributes used for performing @@ -117,17 +121,15 @@ def to(self, device: DeviceLikeType) -> BlockSampler: Args: device (DeviceLikeType): The target device. """ - self._x = self._x.to(device) - self._zeros = self._zeros.to(device) - self._schedule = self._schedule.to(device) - self._partition = self._partition.to(device) - self._padded_adjacencies = self._padded_adjacencies.to(device) - self._padded_adjacencies_weight = self._padded_adjacencies_weight.to(device) + sampler = super().to(device=device) + if device != "meta": rng = torch.Generator(device) - rng.manual_seed(torch.randint(0, 2**60, (1,), generator=self._rng).item()) - self._rng = rng - return self + rand_tensor = torch.randint(0, 2**60, (1,), generator=sampler._rng) + rng.manual_seed(int(rand_tensor.item())) + sampler._rng = rng + + return sampler def _prepare_initial_states( self, num_chains: int, initial_states: torch.Tensor | None = None, diff --git a/tests/test_block_sampler.py b/tests/test_block_sampler.py index 0df7d43..0d2d307 100755 --- a/tests/test_block_sampler.py +++ b/tests/test_block_sampler.py @@ -62,7 +62,7 @@ def test_device(self): crayon = self.crayon_veqa sample_size = 1_000_000 bss = BlockSampler(grbm, crayon, sample_size, [1.0], "Gibbs", seed=2) - bss.to('meta') + bss = bss.to('meta') self.assertEqual("cpu", bss._grbm.linear.device.type) self.assertEqual("cpu", bss._grbm.quadratic.device.type) self.assertEqual("meta", bss._x.device.type) From ac377e56b7985a76732e2d2c7f18db07c98ad142 Mon Sep 17 00:00:00 2001 From: kchern Date: Fri, 23 Jan 2026 22:16:10 -0800 Subject: [PATCH 30/39] Rename to colouring --- .../torch/samplers/block_spin_sampler.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/dwave/plugins/torch/samplers/block_spin_sampler.py b/dwave/plugins/torch/samplers/block_spin_sampler.py index cadb256..e2b5983 100644 --- a/dwave/plugins/torch/samplers/block_spin_sampler.py +++ b/dwave/plugins/torch/samplers/block_spin_sampler.py @@ -46,7 +46,7 @@ class BlockSampler: Args: grbm (GRBM): The Graph-Restricted Boltzmann Machine to sample from. - crayon (Callable[Hashable, Hashable]): A colouring function that maps a single + colouring (Callable[Hashable, Hashable]): A colouring function that maps a single node of the ``grbm`` to its colour. num_chains (int): Number of Markov chains to run in parallel. initial_states (torch.Tensor | None): A tensor of +/-1 values of shape @@ -63,7 +63,7 @@ class BlockSampler: "Gibbs" or "Metropolis". """ - def __init__(self, grbm: GRBM, crayon: Callable[[Hashable], Hashable], num_chains: int, + def __init__(self, grbm: GRBM, colouring: Callable[[Hashable], Hashable], num_chains: int, schedule: Iterable[float], proposal_acceptance_criteria: Literal["Gibbs", "Metropolis"] = "Gibbs", initial_states: torch.Tensor | None = None, @@ -80,10 +80,10 @@ def __init__(self, grbm: GRBM, crayon: Callable[[Hashable], Hashable], num_chain ) self._grbm: GRBM = grbm - self._crayon: Callable[[Hashable], Hashable] = crayon - if not self._valid_crayon(): + self._colouring: Callable[[Hashable], Hashable] = colouring + if not self._valid_colouring(): raise ValueError( - "crayon is not a valid colouring of grbm. " + "`colouring` is not a valid colouring of grbm. " + "At least one edge has vertices of the same colour." ) @@ -165,14 +165,14 @@ def _prepare_initial_states( return initial_states - def _valid_crayon(self) -> bool: - """Determines whether ``crayon`` is a valid colouring of the graph-restricted Boltzmann machine. + def _valid_colouring(self) -> bool: + """Determines whether ``colouring`` is a valid colouring of the graph-restricted Boltzmann machine. Returns: bool: True if the colouring is valid and False otherwise. """ for u, v in self._grbm.edges: - if self._crayon(u) == self._crayon(v): + if self._colouring(u) == self._colouring(v): return False return True @@ -185,7 +185,7 @@ def _get_partition(self) -> nn.ParameterList: partition = defaultdict(list) for node in self._grbm.nodes: idx = self._grbm.node_to_idx[node] - c = self._crayon(node) + c = self._colouring(node) partition[c].append(idx) partition = nn.ParameterList([ nn.Parameter(torch.tensor(partition[k], requires_grad=False), requires_grad=False) From 22ba88990d21bee3b006839538f22220c1ad3c99 Mon Sep 17 00:00:00 2001 From: Ahmed Abdelaziz <71100796+abdela47@users.noreply.github.com> Date: Mon, 26 Jan 2026 23:29:59 +0400 Subject: [PATCH 31/39] Add unit tests for pseudo_kl_divergence_loss (#59) * Add unit tests for pseudo_kl_divergence_loss * Refine pseudo_kl_divergence_loss tests * Copyright & Apache License Reviewer Suggestion Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> * updated the 2D test to derive batch_size and n_spins directly from spins_data.shape for clarity and consistency * Added 3D explanation, UnitTest Framework, Concise Test Naming * Update tests/test_pseudo_kl_divergence.py Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> * Update tests/test_pseudo_kl_divergence.py Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> * Update tests/test_pseudo_kl_divergence.py Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> * Update tests/test_pseudo_kl_divergence.py Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> * Update tests/test_pseudo_kl_divergence.py Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> * Update tests/test_pseudo_kl_divergence.py Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> * Update tests/test_pseudo_kl_divergence.py Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> * Update tests/test_pseudo_kl_divergence.py Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> * Update tests/test_pseudo_kl_divergence.py Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> * Changing the name of Boltzmann class * correcting error cause by incomplete sugggestion note * Edited Missed class initializations --------- Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com> --- tests/test_pseudo_kl_divergence.py | 157 +++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 tests/test_pseudo_kl_divergence.py diff --git a/tests/test_pseudo_kl_divergence.py b/tests/test_pseudo_kl_divergence.py new file mode 100644 index 0000000..a57e6cc --- /dev/null +++ b/tests/test_pseudo_kl_divergence.py @@ -0,0 +1,157 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Unit tests for pseudo_kl_divergence_loss. + +These tests verify the *statistical structure* of the pseudo-KL divergence used +in the DVAE setting, not the correctness of the Boltzmann machine itself. + +In particular, we test that: +1) The loss matches the reference decomposition: + pseudo_KL = cross_entropy_with_prior - entropy_of_encoder +2) The function supports both documented spin shapes. +3) The gradient w.r.t. encoder logits behaves as expected. + +The tests intentionally use deterministic dummy Boltzmann machines to isolate +and validate the behavior of pseudo_kl_divergence_loss in isolation. +""" +import unittest + +import torch +import torch.nn.functional as F + +from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss + + +class UnitLinearBiasObjective: + """A minimal and deterministic stand-in for GraphRestrictedBoltzmannMachine. + + The purpose of this class is NOT to model a real Boltzmann machine. + Instead, it provides a simple, deterministic quasi_objective so that + we can verify how pseudo_kl_divergence_loss combines its terms. + """ + + def quasi_objective(self, spins_data: torch.Tensor, spins_model: torch.Tensor) -> torch.Tensor: + """Return a deterministic scalar representing a positive-minus-negative phase + objective, independent of encoder logits. + """ + return spins_data.float().mean() - spins_model.float().mean() + +class TestPseudoKLDivergenceLoss(unittest.TestCase): + """Unit tests for pseudo_kl_divergence_loss.""" + + def test_matches_reference_2d(self): + """Match explicit cross-entropy minus entropy reference for 2D spins.""" + + bm = UnitLinearBiasObjective() + + spins_data = torch.tensor( + [[-1, 1, -1, 1, -1, 1], + [1, -1, 1, -1, 1, -1], + [-1, -1, 1, 1, -1, 1], + [1, 1, -1, -1, 1, -1]], + dtype=torch.float32 + ) + + batch_size, n_spins = spins_data.shape + logits = torch.linspace(-2.0, 2.0, steps=batch_size * n_spins).reshape(batch_size, n_spins) + + spins_model = torch.ones(batch_size, n_spins, dtype=torch.float32) + + out = pseudo_kl_divergence_loss( + spins=spins_data, + logits=logits, + samples=spins_model, + boltzmann_machine=bm + ) + + probs = torch.sigmoid(logits) + entropy = F.binary_cross_entropy_with_logits(logits, probs) + cross_entropy = bm.quasi_objective(spins_data, spins_model) + ref = cross_entropy - entropy + + torch.testing.assert_close(out, ref) + + def test_supports_3d_spins(self): + """Support 3D spins of shape (batch_size, n_samples, n_spins) as documented.""" + bm = UnitLinearBiasObjective() + + batch_size, n_samples, n_spins = 3, 5, 4 + logits = torch.zeros(batch_size, n_spins) + # Zero logits are used in the 3D shape test to keep the entropy term simple and stable (p = 0.5), + # allowing the test to focus purely on documented shape support; nonzero values are covered in the + # 2D numerical correctness test. + + # spins: (batch_size, n_samples, n_spins) + spins_data = torch.ones(batch_size, n_samples, n_spins) + spins_model = torch.zeros(batch_size, n_spins) + + out = pseudo_kl_divergence_loss( + spins=spins_data, + logits=logits, + samples=spins_model, + boltzmann_machine=bm + ) + + probs = torch.sigmoid(logits) + entropy = F.binary_cross_entropy_with_logits(logits, probs) + cross_entropy = bm.quasi_objective(spins_data, spins_model) + + torch.testing.assert_close(out, cross_entropy - entropy) + + + def test_gradient_from_entropy_only(self): + """Verify gradient behavior of pseudo_kl_divergence_loss. + + If the Boltzmann machine quasi_objective returns a constant value, + then the loss gradient w.r.t. logits must come entirely from the + negative entropy term. + + This test ensures that pseudo_kl_divergence_loss applies the correct + statistical pressure on encoder logits. + """ + + class ConstantObjectiveBM: + def quasi_objective(self, spins_data: torch.Tensor, + spins_model: torch.Tensor) -> torch.Tensor: + # Constant => contributes no gradient wrt logits + return torch.tensor(1.2345, dtype=spins_data.dtype, device=spins_data.device) + + bm = ConstantObjectiveBM() + + batch_size, n_spins = 2, 3 + + logits = torch.randn(batch_size, n_spins, requires_grad=True) + spins_data = torch.ones(batch_size, n_spins) + spins_model = torch.zeros(batch_size, n_spins) + + out = pseudo_kl_divergence_loss( + spins=spins_data, + logits=logits, + samples=spins_model, + boltzmann_machine=bm + ) + + out.backward() + + # reference gradient from -entropy only + logits2 = logits.detach().clone().requires_grad_(True) + probs2 = torch.sigmoid(logits2) + entropy2 = F.binary_cross_entropy_with_logits(logits2, probs2) + (-entropy2).backward() + + torch.testing.assert_close(logits.grad, logits2.grad) + +if __name__ == "__main__": + unittest.main() From efa2a9f05f81065b8db49026f0c10b0d6b78f64b Mon Sep 17 00:00:00 2001 From: Theodor Isacsson Date: Mon, 26 Jan 2026 11:55:12 -0800 Subject: [PATCH 32/39] Add missing release notes for sampler submodule --- .../add-samplers-submodule-7868a19c103d81cb.yaml | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 releasenotes/notes/add-samplers-submodule-7868a19c103d81cb.yaml diff --git a/releasenotes/notes/add-samplers-submodule-7868a19c103d81cb.yaml b/releasenotes/notes/add-samplers-submodule-7868a19c103d81cb.yaml new file mode 100644 index 0000000..e1172b7 --- /dev/null +++ b/releasenotes/notes/add-samplers-submodule-7868a19c103d81cb.yaml @@ -0,0 +1,14 @@ +--- +features: + - Add a samplers submodule with a baseclass ``TorchSampler``. + - Add a new PyTorch plugin wrapper for a dimod sampler named ``DimodSampler``. +upgrade: + - Update ``BlockSampler`` to inherits from the new ``TorchSampler``. + - | + ``BlockSampler.to()`` now returns a copy with components moved to the set device, instead + of performing the device change it in-place. +deprecations: + - | + Deprecate ``GraphRestrictedBoltzmannMachine.sample()`` and + ``GraphRestrictedBoltzmannMachine.sampleset_to_tensor`` in favour of the new + ``dwave.plugins.torch.samplers`` module. Will be removed in version 0.4.0." From f2a7c67527540607cff6380c4aa7c7ed78af7a75 Mon Sep 17 00:00:00 2001 From: Theodor Isacsson Date: Mon, 26 Jan 2026 11:55:21 -0800 Subject: [PATCH 33/39] Increment version number for release --- dwave/plugins/torch/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dwave/plugins/torch/__init__.py b/dwave/plugins/torch/__init__.py index faa6890..a4630b0 100644 --- a/dwave/plugins/torch/__init__.py +++ b/dwave/plugins/torch/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.2.0" +__version__ = "0.3.0" From b67bf2438f1ffd5ee11d39cf87e2ff02371d5cd0 Mon Sep 17 00:00:00 2001 From: kchern Date: Mon, 26 Jan 2026 18:23:29 -0800 Subject: [PATCH 34/39] Add release note --- releasenotes/notes/rename-colour-2dc56d0cb0a011e2.yaml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100755 releasenotes/notes/rename-colour-2dc56d0cb0a011e2.yaml diff --git a/releasenotes/notes/rename-colour-2dc56d0cb0a011e2.yaml b/releasenotes/notes/rename-colour-2dc56d0cb0a011e2.yaml new file mode 100755 index 0000000..fecf618 --- /dev/null +++ b/releasenotes/notes/rename-colour-2dc56d0cb0a011e2.yaml @@ -0,0 +1,4 @@ +--- +other: + - | + Rename the ``crayon`` argument in `BlockSampler`` to ``colouring``. From 6d35940eead3203523373a55e71a5651621f154b Mon Sep 17 00:00:00 2001 From: Theodor Isacsson Date: Mon, 26 Jan 2026 20:49:00 -0800 Subject: [PATCH 35/39] Increment version to 0.3.1 for release --- dwave/plugins/torch/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dwave/plugins/torch/__init__.py b/dwave/plugins/torch/__init__.py index a4630b0..2c2c177 100644 --- a/dwave/plugins/torch/__init__.py +++ b/dwave/plugins/torch/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.3.0" +__version__ = "0.3.1" From 2a343add3971ca58187b834a3069f3e83552cac4 Mon Sep 17 00:00:00 2001 From: kchern Date: Thu, 13 Nov 2025 18:42:23 +0000 Subject: [PATCH 36/39] Add minimal example of MMD-AE --- examples/mmd_ae.py | 144 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 examples/mmd_ae.py diff --git a/examples/mmd_ae.py b/examples/mmd_ae.py new file mode 100644 index 0000000..87beeee --- /dev/null +++ b/examples/mmd_ae.py @@ -0,0 +1,144 @@ +from itertools import cycle +from math import prod +from os import makedirs + +import torch +from torch import nn +from torch.optim import SGD, AdamW +from torch.utils.data import DataLoader +from torchvision.datasets import MNIST +from torchvision.transforms.v2 import Compose, ToDtype, ToImage +from torchvision.utils import make_grid, save_image + +from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine as GRBM +from dwave.plugins.torch.nn import (ConvolutionNetwork, FullyConnectedNetwork, LinearBlock, + MaximumMeanDiscrepancy, RadialBasis, StraightThroughTanh, + rands_like, zephyr_subgraph) +from dwave.system import DWaveSampler + + +@torch.compile +class Autoencoder(nn.Module): + def __init__(self, shape, n_bits): + super().__init__() + dim = prod(shape) + c, h, w = shape + chidden = 1 + depth_fcnn = 3 + depth_cnn = 3 + dropout = 0.0 + self.encoder = nn.Sequential( + ConvolutionNetwork([chidden]*depth_cnn, shape), + nn.Flatten(), + FullyConnectedNetwork(chidden*h*w, n_bits, depth_fcnn, False, dropout), + ) + self.mixer = LinearBlock(n_bits, n_bits, False, dropout) + self.binarizer = StraightThroughTanh() + self.decoder = nn.Sequential( + FullyConnectedNetwork(n_bits, chidden*h*w, depth_fcnn, False, dropout), + nn.Unflatten(1, (chidden, h, w)), + ConvolutionNetwork([chidden]*(depth_cnn-1) + [1], (chidden, h, w)) + ) + + def decode(self, q): + z = self.mixer(q) + xhat = self.decoder(z) + return z, xhat + + def forward(self, x): + spins = self.binarizer(self.encoder(x)) + z, xhat = self.decode(spins) + return spins, z, xhat + + +def collect_stats(model, grbm, x, q, compute_mmd): + s, z, xhat = model(x) + zgen, xgen = model.decode(q) + stats = { + "quasi": grbm.quasi_objective(s.detach(), q), + "bce": nn.functional.binary_cross_entropy_with_logits(xhat, x), + "mmd": compute_mmd(s, q), + "mmd2": compute_mmd(z, zgen), + } + return stats + + +def get_dataset(bs, data_dir="/tmp/"): + transforms = Compose([ToImage(), ToDtype(torch.float32, scale=True)]) + train_kwargs = dict(root=data_dir, download=True) + transforms = Compose([transforms, lambda x: 1 - x]) + data_train = MNIST(transform=transforms, **train_kwargs) + train_loader = DataLoader(data_train, bs, True) + return train_loader + + +def round_graph_down(graph, group_size): + n_in = graph.number_of_nodes() + no = group_size*(n_in//group_size) + return graph.subgraph(list(graph.nodes)[:no]) + + +def run(*, num_steps): + sampler = DWaveSampler(solver="Advantage2_system1.7") + sample_params = dict(num_reads=500, annealing_time=0.5, answer_mode="raw", auto_scale=False) + h_range, j_range = sampler.properties["h_range"], sampler.properties["j_range"] + outdir = "output/mmd_ae/" + makedirs(outdir, exist_ok=True) + + device = "cuda" + + # Setup data + train_loader = get_dataset(500) + + # Instantiate model + G = zephyr_subgraph(sampler.to_networkx_graph(), 4) + nodes = list(G.nodes) + edges = list(G.edges) + grbm = GRBM(nodes, edges).to(device) + model = Autoencoder((1, 28, 28), grbm.n_nodes).to(device) + model.train() + grbm.train() + + compute_mmd = MaximumMeanDiscrepancy(RadialBasis()).to(device) + + opt_grbm = SGD(grbm.parameters(), lr=1e-3) + opt_ae = AdamW(model.parameters(), lr=1e-3) + + for step, (x, y) in enumerate(cycle(train_loader)): + torch.cuda.empty_cache() + if step > num_steps: + break + # Send data to device + x = x.to(device).float() + + q = grbm.sample(sampler, prefactor=1, linear_range=h_range, quadratic_range=j_range, + device=device, sample_params=sample_params) + + # Train autoencoder + stats = collect_stats(model, grbm, x, q, compute_mmd) + opt_ae.zero_grad() + (stats["bce"] + stats["mmd"] + stats["mmd2"]).backward() + opt_ae.step() + + # Train GRBM + if step < 1000: + # NOTE: collecting stats because the autoencoder has been updated. + stats = collect_stats(model, grbm, x, q, compute_mmd) + opt_grbm.zero_grad() + stats['quasi'].backward() + opt_grbm.step() + print(step, {k: v.item() for k, v in stats.items()}) + if step % 10 == 0: + with torch.no_grad(): + grbm.eval() + xgen = model.decode(q[:100])[-1] + xuni = model.decode(rands_like(q[:100]))[-1] + xhat = model(x[:100])[-1] + save_image(make_grid(xgen.sigmoid(), 10, pad_value=1), outdir + "xgen.png") + save_image(make_grid(xhat.sigmoid(), 10, pad_value=1), outdir + "xhat.png") + save_image(make_grid(xuni.sigmoid(), 10, pad_value=1), outdir + "xuni.png") + grbm.train() + + +if __name__ == "__main__": + run(num_steps=10_000) From 8353c75be7fd0f63b2ffb68377848a1ea6fb7288 Mon Sep 17 00:00:00 2001 From: kchern Date: Thu, 20 Nov 2025 23:26:12 +0000 Subject: [PATCH 37/39] Simplify model and use public solver --- examples/mmd_ae.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/examples/mmd_ae.py b/examples/mmd_ae.py index 87beeee..f7bebf3 100644 --- a/examples/mmd_ae.py +++ b/examples/mmd_ae.py @@ -11,9 +11,9 @@ from torchvision.utils import make_grid, save_image from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine as GRBM -from dwave.plugins.torch.nn import (ConvolutionNetwork, FullyConnectedNetwork, LinearBlock, - MaximumMeanDiscrepancy, RadialBasis, StraightThroughTanh, - rands_like, zephyr_subgraph) +from dwave.plugins.torch.nn.modules import (ConvolutionNetwork, FullyConnectedNetwork, + MaximumMeanDiscrepancy, RadialBasis, + StraightThroughTanh, rands_like, zephyr_subgraph) from dwave.system import DWaveSampler @@ -32,7 +32,6 @@ def __init__(self, shape, n_bits): nn.Flatten(), FullyConnectedNetwork(chidden*h*w, n_bits, depth_fcnn, False, dropout), ) - self.mixer = LinearBlock(n_bits, n_bits, False, dropout) self.binarizer = StraightThroughTanh() self.decoder = nn.Sequential( FullyConnectedNetwork(n_bits, chidden*h*w, depth_fcnn, False, dropout), @@ -41,24 +40,21 @@ def __init__(self, shape, n_bits): ) def decode(self, q): - z = self.mixer(q) - xhat = self.decoder(z) - return z, xhat + xhat = self.decoder(q) + return xhat def forward(self, x): spins = self.binarizer(self.encoder(x)) - z, xhat = self.decode(spins) - return spins, z, xhat + xhat = self.decode(spins) + return spins, xhat def collect_stats(model, grbm, x, q, compute_mmd): - s, z, xhat = model(x) - zgen, xgen = model.decode(q) + s, xhat = model(x) stats = { "quasi": grbm.quasi_objective(s.detach(), q), "bce": nn.functional.binary_cross_entropy_with_logits(xhat, x), "mmd": compute_mmd(s, q), - "mmd2": compute_mmd(z, zgen), } return stats @@ -79,10 +75,10 @@ def round_graph_down(graph, group_size): def run(*, num_steps): - sampler = DWaveSampler(solver="Advantage2_system1.7") + sampler = DWaveSampler(solver="Advantage2_system1.8") sample_params = dict(num_reads=500, annealing_time=0.5, answer_mode="raw", auto_scale=False) h_range, j_range = sampler.properties["h_range"], sampler.properties["j_range"] - outdir = "output/mmd_ae/" + outdir = "output/example_mmd_ae/" makedirs(outdir, exist_ok=True) device = "cuda" @@ -117,7 +113,7 @@ def run(*, num_steps): # Train autoencoder stats = collect_stats(model, grbm, x, q, compute_mmd) opt_ae.zero_grad() - (stats["bce"] + stats["mmd"] + stats["mmd2"]).backward() + (stats["bce"] + stats["mmd"]).backward() opt_ae.step() # Train GRBM @@ -131,9 +127,10 @@ def run(*, num_steps): if step % 10 == 0: with torch.no_grad(): grbm.eval() - xgen = model.decode(q[:100])[-1] - xuni = model.decode(rands_like(q[:100]))[-1] + xgen = model.decode(q[:100]) + xuni = model.decode(rands_like(q[:100])) xhat = model(x[:100])[-1] + save_image(make_grid(x[:100], 10, pad_value=1), outdir + "x.png") save_image(make_grid(xgen.sigmoid(), 10, pad_value=1), outdir + "xgen.png") save_image(make_grid(xhat.sigmoid(), 10, pad_value=1), outdir + "xhat.png") save_image(make_grid(xuni.sigmoid(), 10, pad_value=1), outdir + "xuni.png") From 706bdcda9b0b65d268398754d7fefe53e1c08758 Mon Sep 17 00:00:00 2001 From: kchern Date: Fri, 13 Feb 2026 08:54:06 -0800 Subject: [PATCH 38/39] Revert to self-contained example --- examples/mmd_ae.py | 356 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 304 insertions(+), 52 deletions(-) diff --git a/examples/mmd_ae.py b/examples/mmd_ae.py index f7bebf3..7c7609e 100644 --- a/examples/mmd_ae.py +++ b/examples/mmd_ae.py @@ -1,7 +1,9 @@ from itertools import cycle from math import prod -from os import makedirs +import dwave_networkx as dnx +import matplotlib.pyplot as plt +import numpy as np import torch from torch import nn from torch.optim import SGD, AdamW @@ -11,14 +13,203 @@ from torchvision.utils import make_grid, save_image from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine as GRBM -from dwave.plugins.torch.nn.modules import (ConvolutionNetwork, FullyConnectedNetwork, - MaximumMeanDiscrepancy, RadialBasis, - StraightThroughTanh, rands_like, zephyr_subgraph) +from dwave.plugins.torch.nn.functional import bit2spin_soft, spin2bit_soft from dwave.system import DWaveSampler +class RadialBasisFunction(nn.Module): + + def __init__(self, n_kernels=5, mul_factor=2.0, bandwidth=None): + super().__init__() + bandwidth_multipliers = mul_factor ** (torch.arange(n_kernels) - n_kernels // 2) + self.register_buffer("bandwidth_multipliers", bandwidth_multipliers) + self.bandwidth = bandwidth + + def get_bandwidth(self, l2_dist): + if self.bandwidth is None: + n = l2_dist.shape[0] + avg = l2_dist.sum() / (n**2 - n) # (diagonal is zero) + return avg + + return self.bandwidth + + def forward(self, X): + l2 = torch.cdist(X, X) ** 2 + bandwidth = self.get_bandwidth(l2.detach()) * self.bandwidth_multipliers + res = torch.exp(-l2.unsqueeze(0) / bandwidth.reshape(-1, 1, 1)).sum(dim=0) + return res + + +class MMDLoss(nn.Module): + def __init__(self, kernel): + super().__init__() + self.kernel = kernel + + def forward(self, X, Y): + K = self.kernel(torch.vstack([X.flatten(1), Y.flatten(1)])) + n = X.shape[0] + m = Y.shape[0] + XX = (K[:n, :n].sum() - K[:n, :n].trace()) / (n*(n-1)) + YY = (K[n:, n:].sum() - K[n:, n:].trace()) / (m*(m-1)) + XY = K[:n, n:].mean() + mmd = XX - 2 * XY + YY + return mmd + + +class SkipLinear(nn.Module): + def __init__(self, din, dout) -> None: + super().__init__() + self.linear = nn.Linear(din, dout, bias=False) + + def forward(self, x): + return self.linear(x) + + +class LinearBlock(nn.Module): + def __init__(self, din, dout, sn, p, bias) -> None: + super().__init__() + self.skip = SkipLinear(din, dout) + linear_1 = nn.Linear(din, dout, bias) + linear_2 = nn.Linear(dout, dout, bias) + self.block = nn.Sequential( + nn.LayerNorm(din), + linear_1, + nn.Dropout(p), + nn.ReLU(), + nn.LayerNorm(dout), + linear_2, + ) + + def forward(self, x): + return self.block(x) + self.skip(x) + + +class ConvolutionBlock(nn.Module): + def __init__(self, input_shape: tuple[int, int, int], cout: int): + super().__init__() + input_shape = tuple(input_shape) + cin, hx, wx = input_shape + if hx != wx: + raise NotImplementedError("TODO") + + self.input_shape = tuple(input_shape) + self.cin = cin + self.cout = cout + + self.block = nn.Sequential( + nn.LayerNorm(input_shape), + nn.Conv2d(cin, cout, 3, 1, 1), + nn.ReLU(), + nn.LayerNorm((cout, hx, wx)), + nn.Conv2d(cout, cout, 3, 1, 1), + ) + self.skip = SkipConv2d(cin, cout) + + def forward(self, x): + return self.block(x) + self.skip(x) + + +class SkipConv2d(nn.Module): + def __init__(self, cin: int, cout: int): + super().__init__() + self.skip = nn.Conv2d(cin, cout, 1, bias=False) + + def forward(self, x): + return self.skip(x) + + +class ConvolutionNetwork(nn.Module): + def __init__( + self, channels: list[int], input_shape: tuple[int, int, int] + ): + super().__init__() + channels = channels.copy() + input_shape = tuple(input_shape) + cx, hx, wx = input_shape + if hx != wx: + raise NotImplementedError("TODO") + self.channels = channels + self.cin = cx + self.cout = self.channels[-1] + self.input_shape = input_shape + + channels_in = [cx] + channels[:-1] + self.blocks = nn.Sequential() + for cin, cout in zip(channels_in, channels): + self.blocks.append(ConvolutionBlock((cin, hx, wx), cout)) + self.blocks.append(nn.ReLU()) + self.blocks.pop(-1) + self.skip = SkipConv2d(cx, cout) + + def forward(self, x): + x = self.blocks(x) + self.skip(x) + return x + + +class FullyConnectedNetwork(nn.Module): + def __init__(self, din, dout, depth, sn, p, bias=True) -> None: + super().__init__() + if depth == 1: + raise ValueError("Depth must be at least 2.") + self.skip = SkipLinear(din, dout) + big_d = max(din, dout) + dims = [big_d]*(depth-1) + [dout] + self.blocks = nn.Sequential() + for d_in, d_out in zip([din]+dims[:-1], dims): + self.blocks.append(LinearBlock(d_in, d_out, sn, p, bias)) + self.blocks.append(nn.Dropout(p)) + self.blocks.append(nn.ReLU()) + # Remove the last ReLU and Dropout + self.blocks.pop(-1) + self.blocks.pop(-1) + + def forward(self, x): + return self.blocks(x) + self.skip(x) + + +def straight_through_bitrounding(fuzzy_bits): + if not ((fuzzy_bits >= 0) & (fuzzy_bits <= 1)).all(): + raise ValueError(f"Inputs should be in [0, 1]: {fuzzy_bits}") + bits = fuzzy_bits + (fuzzy_bits.round() - fuzzy_bits).detach() + return bits + + +class StraightThroughTanh(nn.Module): + def __init__(self): + super().__init__() + self.hth = nn.Tanh() + + def forward(self, x): + fuzzy_spins = self.hth(x) + fuzzy_bits = spin2bit_soft(fuzzy_spins) + bits = straight_through_bitrounding(fuzzy_bits) + spins = bit2spin_soft(bits) + return spins + + +def zephyr_subgraph(G, zephyr_m): + Z_m = dnx.zephyr_graph(zephyr_m) + zsm = next(dnx.zephyr_sublattice_mappings(Z_m, G)) + S = G.subgraph([zsm(z) for z in Z_m]) + original_m = S.graph['rows'] + if original_m == zephyr_m: + return G.copy() + S.graph = G.graph.copy() + S.graph['rows'] = zephyr_m + S.graph['columns'] = zephyr_m + S.graph['name'] = S.graph['name'].replace(f"({original_m},", "("+str(zephyr_m)+",") + S.graph['name'] = S.graph['name'] + "-subgraph of " + G.graph['name'] + return S + + +def subtile(G, num_tiles): + zc = dnx.zephyr_coordinates(G.graph['rows'], 4) + return G.subgraph([g for g in G if zc.linear_to_zephyr(g)[2] < num_tiles]) + + @torch.compile class Autoencoder(nn.Module): + def __init__(self, shape, n_bits): super().__init__() dim = prod(shape) @@ -36,7 +227,8 @@ def __init__(self, shape, n_bits): self.decoder = nn.Sequential( FullyConnectedNetwork(n_bits, chidden*h*w, depth_fcnn, False, dropout), nn.Unflatten(1, (chidden, h, w)), - ConvolutionNetwork([chidden]*(depth_cnn-1) + [1], (chidden, h, w)) + ConvolutionNetwork([chidden]*(depth_cnn-1) + [1], (chidden, h, w)), + # nn.Sigmoid() ) def decode(self, q): @@ -44,17 +236,20 @@ def decode(self, q): return xhat def forward(self, x): - spins = self.binarizer(self.encoder(x)) + z = self.encoder(x) + spins = self.binarizer(z) xhat = self.decode(spins) - return spins, xhat + return z, spins, xhat -def collect_stats(model, grbm, x, q, compute_mmd): - s, xhat = model(x) +def collect_stats(model, grbm, x, q, compute_mmd, compute_pkl): + z, s, xhat = model(x) stats = { "quasi": grbm.quasi_objective(s.detach(), q), + "mse": nn.functional.mse_loss(xhat.sigmoid(), x), "bce": nn.functional.binary_cross_entropy_with_logits(xhat, x), "mmd": compute_mmd(s, q), + "pkl": compute_pkl(grbm, z, s, q), } return stats @@ -65,77 +260,134 @@ def get_dataset(bs, data_dir="/tmp/"): transforms = Compose([transforms, lambda x: 1 - x]) data_train = MNIST(transform=transforms, **train_kwargs) train_loader = DataLoader(data_train, bs, True) - return train_loader + data_test = MNIST(transform=transforms, **train_kwargs, train=False) + test_loader = DataLoader(data_test, bs, True) + return train_loader, test_loader -def round_graph_down(graph, group_size): - n_in = graph.number_of_nodes() - no = group_size*(n_in//group_size) - return graph.subgraph(list(graph.nodes)[:no]) +def save_viz(step, grbm, model, x, q): + bs = min(x.shape[0], 500) + rows = int(bs**0.5) + with torch.no_grad(): + # Save images + xgen = model.decode(q[:bs]).sigmoid() + xuni = model.decode(bit2spin_soft(torch.randint_like(q[:bs], 2))).sigmoid() + z, s, xhat = model(x[:bs]) + xhat = xhat.sigmoid() + xgrid = make_grid(x[:bs], rows, pad_value=1) + xgengrid = make_grid(xgen, rows, pad_value=1) + xunigrid = make_grid(xuni, rows, pad_value=1) + xhatgrid = make_grid(xhat, rows, pad_value=1) + save_image(xgrid, "x.png") + save_image(xgengrid, "xgen.png") + save_image(xunigrid, "xuni.png") + save_image(xhatgrid, "xhat.png") -def run(*, num_steps): - sampler = DWaveSampler(solver="Advantage2_system1.8") - sample_params = dict(num_reads=500, annealing_time=0.5, answer_mode="raw", auto_scale=False) - h_range, j_range = sampler.properties["h_range"], sampler.properties["j_range"] - outdir = "output/example_mmd_ae/" - makedirs(outdir, exist_ok=True) - - device = "cuda" - - # Setup data - train_loader = get_dataset(500) - +def get_qpu_model_grbm(solver, device): + # Set up QPU and QPU parameters + qpu = DWaveSampler(solver=solver) # Instantiate model - G = zephyr_subgraph(sampler.to_networkx_graph(), 4) + # G = zephyr_subgraph(qpu.to_networkx_graph(), 4) + G = subtile(zephyr_subgraph(qpu.to_networkx_graph(), 5), 3) nodes = list(G.nodes) edges = list(G.edges) grbm = GRBM(nodes, edges).to(device) + # grbm.linear.data[:] = 0 + # grbm.quadratic.data[:] = 0 model = Autoencoder((1, 28, 28), grbm.n_nodes).to(device) + return qpu, model, grbm + + +def run(*, title, loss_fn, solver, stop_grbm, num_reads, + annealing_time, alpha, num_steps, args): + device = "cuda" + qpu, model, grbm = get_qpu_model_grbm(solver, device) + nprng = np.random.default_rng(8257213849) + grbm.linear.data[:] = 0.1 * bit2spin_soft(torch.tensor(nprng.binomial(1, 0.5, grbm.n_nodes))) + grbm.quadratic.data[:] = bit2spin_soft(torch.tensor(nprng.binomial(1, 0.5, grbm.n_edges))) + sampler = qpu + model.train() grbm.train() - compute_mmd = MaximumMeanDiscrepancy(RadialBasis()).to(device) - opt_grbm = SGD(grbm.parameters(), lr=1e-3) - opt_ae = AdamW(model.parameters(), lr=1e-3) + opt_model = AdamW(model.parameters(), lr=1e-3) + + sample_params = dict(num_reads=num_reads, annealing_time=annealing_time, + answer_mode="raw", auto_scale=False) + h_range, j_range = qpu.properties["h_range"], qpu.properties["j_range"] - for step, (x, y) in enumerate(cycle(train_loader)): + # Set up data + train_loader, test_loader = get_dataset(num_reads) + + compute_mmd = MMDLoss(RadialBasisFunction()).to(device) + + def compute_pkl(grbm: GRBM, logits_data: torch.Tensor, spins_data: torch.Tensor, + spins_model: torch.Tensor): + probabilities = torch.sigmoid(logits_data) + entropy = torch.nn.functional.binary_cross_entropy_with_logits(logits_data, probabilities) + # bce = p(log(q)) + (1-p) log(1-q) + cross_entropy = grbm.quasi_objective(spins_data, spins_model) + pkl = cross_entropy - entropy + return pkl + + for step, (x, _) in enumerate(cycle(train_loader), 1): torch.cuda.empty_cache() if step > num_steps: break # Send data to device x = x.to(device).float() - - q = grbm.sample(sampler, prefactor=1, linear_range=h_range, quadratic_range=j_range, + q = grbm.sample(sampler, prefactor=1, + linear_range=h_range, quadratic_range=j_range, device=device, sample_params=sample_params) # Train autoencoder - stats = collect_stats(model, grbm, x, q, compute_mmd) - opt_ae.zero_grad() - (stats["bce"] + stats["mmd"]).backward() - opt_ae.step() + stats = collect_stats(model, grbm, x, q, compute_mmd, compute_pkl) + opt_model.zero_grad() + (stats["bce"] + alpha*stats[loss_fn]).backward() + # alpha ~ 1e-6 + opt_model.step() # Train GRBM - if step < 1000: - # NOTE: collecting stats because the autoencoder has been updated. - stats = collect_stats(model, grbm, x, q, compute_mmd) + if step < stop_grbm: + # NOTE: collecting stats again because the autoencoder has been updated. + stats = collect_stats(model, grbm, x, q, compute_mmd, compute_pkl) opt_grbm.zero_grad() stats['quasi'].backward() opt_grbm.step() - print(step, {k: v.item() for k, v in stats.items()}) + + print(title, step, {k: f"{v.item():.4f}" + if isinstance(v, torch.Tensor) + else f"{v:.4f}" + for k, v in stats.items()}) + if step % 10 == 0: - with torch.no_grad(): - grbm.eval() - xgen = model.decode(q[:100]) - xuni = model.decode(rands_like(q[:100])) - xhat = model(x[:100])[-1] - save_image(make_grid(x[:100], 10, pad_value=1), outdir + "x.png") - save_image(make_grid(xgen.sigmoid(), 10, pad_value=1), outdir + "xgen.png") - save_image(make_grid(xhat.sigmoid(), 10, pad_value=1), outdir + "xhat.png") - save_image(make_grid(xuni.sigmoid(), 10, pad_value=1), outdir + "xuni.png") - grbm.train() + model.eval() + + xtest = next(iter(test_loader))[0].to(device) + q = grbm.sample(sampler, prefactor=1, + linear_range=h_range, quadratic_range=j_range, + device=device, sample_params=sample_params) + stats = collect_stats(model, grbm, xtest, q, compute_mmd, compute_pkl) + save_viz(step, grbm, model, x, q) + + model.train() if __name__ == "__main__": - run(num_steps=10_000) + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument("--title", type=str, default="NoExperimentName") + parser.add_argument("--annealing_time", type=float, default=0.5) + parser.add_argument("--alpha", type=float, default=1.0) + parser.add_argument("--num_steps", type=int, default=1_000) + parser.add_argument("--num_reads", type=int, default=1000) + parser.add_argument("--stop_grbm", type=int, default=500) + parser.add_argument("--loss_fn", type=str, default="mmd") + parser.add_argument("--solver", type=str, default="Advantage2_system1.11") + args_ = parser.parse_args() + + args_dict = vars(args_) + run(**args_dict, args=args_) + # postprocess(**args_dict, args=args_) From a0600c5286112c991fcc4178d975ee842532a48f Mon Sep 17 00:00:00 2001 From: kchern Date: Fri, 13 Feb 2026 10:41:30 -0800 Subject: [PATCH 39/39] Add model loading and saving snippet --- examples/mmd_ae.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/mmd_ae.py b/examples/mmd_ae.py index 7c7609e..4952a26 100644 --- a/examples/mmd_ae.py +++ b/examples/mmd_ae.py @@ -311,6 +311,10 @@ def run(*, title, loss_fn, solver, stop_grbm, num_reads, model.train() grbm.train() + # UNCOMMENT TO LOAD: + # grbm.load_state_dict(torch.load("grbm.pt")) + # model.load_state_dict(torch.load("model.pt")) + opt_grbm = SGD(grbm.parameters(), lr=1e-3) opt_model = AdamW(model.parameters(), lr=1e-3) @@ -373,6 +377,8 @@ def compute_pkl(grbm: GRBM, logits_data: torch.Tensor, spins_data: torch.Tensor, save_viz(step, grbm, model, x, q) model.train() + torch.save(grbm.state_dict(), "grbm.pt") + torch.save(model.state_dict(), "model.pt") if __name__ == "__main__":