diff --git a/ctgan/synthesizers/base.py b/ctgan/synthesizers/base.py index add0dd7e..2fb49db3 100644 --- a/ctgan/synthesizers/base.py +++ b/ctgan/synthesizers/base.py @@ -105,7 +105,13 @@ def __setstate__(self, state): state['random_states'] = (current_numpy_state, current_torch_state) self.__dict__ = state - device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + # Prioritize CUDA if available, then MPS, finally CPU + if torch.cuda.is_available(): + device = torch.device('cuda:0') + elif torch.backends.mps.is_available(): + device = torch.device('mps') + else: + device = torch.device('cpu') self.set_device(device) def save(self, path): @@ -118,11 +124,33 @@ def save(self, path): @classmethod def load(cls, path): """Load the model stored in the passed `path`.""" - device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + # Prioritize CUDA if available, then MPS, finally CPU + if torch.cuda.is_available(): + device = torch.device('cuda:0') + elif torch.backends.mps.is_available(): + device = torch.device('mps') + else: + device = torch.device('cpu') model = torch.load(path) model.set_device(device) return model + def set_device(self, device): + """Set the `device` to be used ('GPU' or 'CPU'). + + Args: + device (torch.device): + Device the generator is moved to: CUDA, MPS or CPU. + """ + self._device = device + if self._generator is None: + return + + # `Module.to` relocates every parameter and buffer in place, so one + # call covers CUDA, MPS and CPU alike; CPU must not be skipped or the + # generator stays on the accelerator while `_device` says otherwise. + self._generator.to(self._device) + def set_random_state(self, random_state): """Set the random state. 
diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index 29606a34..b7e9d0a0 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -142,6 +142,10 @@ class CTGAN(BaseSynthesizer): Whether to attempt to use cuda for GPU computation. If this is False or CUDA is not available, CPU will be used. Defaults to ``True``. + mps (bool): + Whether to attempt to use Apple MPS for GPU computation. When True and + MPS is available it takes priority over ``cuda``; otherwise ``cuda`` decides. + Defaults to ``False``. """ def __init__( @@ -160,6 +164,7 @@ def __init__( epochs=300, pac=10, cuda=True, + mps=False, ): assert batch_size % 2 == 0 @@ -179,12 +184,16 @@ def __init__( self._epochs = epochs self.pac = pac - if not cuda or not torch.cuda.is_available(): + # MPS is opt-in and checked first; every other case keeps the original + # CUDA handling, including explicit device strings such as 'cuda:1'. + if mps and torch.backends.mps.is_available(): + device = 'mps' + elif not cuda or not torch.cuda.is_available(): device = 'cpu' elif isinstance(cuda, str): device = cuda else: device = 'cuda' self._device = torch.device(device) diff --git a/tests/integration/synthesizer/test_ctgan_apple_mps.py b/tests/integration/synthesizer/test_ctgan_apple_mps.py new file mode 100644 index 00000000..2a3a9a08 --- /dev/null +++ b/tests/integration/synthesizer/test_ctgan_apple_mps.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +"""Integration tests for ctgan. + +These tests only ensure that the software does not crash and that +the API works as expected in terms of input and output data formats, +but correctness of the data values and the internal behavior of the +model are not checked. 
+""" + +import os +import tempfile as tf + +import numpy as np +import pandas as pd +import pytest +import torch + +from ctgan.synthesizers.ctgan import CTGAN + +@pytest.fixture +def random_state(): + return 42 + +@pytest.fixture +def train_data(): + size = 100 + # One continuous, one string-categorical and one binary column + df = pd.DataFrame({ + 'continuous': np.random.normal(size=size), + 'categorical': np.random.choice(['a', 'b', 'c'], size=size), + 'binary': np.random.choice([0, 1], size=size).astype(int) + }) + return df + +@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available") +def test_ctgan_fit_sample_apple_mps_hardware(tmpdir, train_data, random_state): + """Test the CTGAN can fit and sample.""" + # Specify discrete columns explicitly + discrete_columns = ['categorical', 'binary'] + ctgan = CTGAN(cuda=False, mps=True, epochs=1) + ctgan.set_random_state(random_state) + ctgan.fit(train_data, discrete_columns=discrete_columns) + sampled = ctgan.sample(1000) + assert sampled.shape == (1000, train_data.shape[1]) + + # Save and load + path = os.path.join(tmpdir, 'test_ctgan.pkl') + ctgan.save(path) + ctgan = CTGAN.load(path) + + sampled = ctgan.sample(1000) + assert sampled.shape == (1000, train_data.shape[1]) + + + +@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available") +def test_mps_training_apple_mps_hardware(tmpdir, train_data, random_state): + """Test CTGAN training on MPS device.""" + ctgan = CTGAN(cuda=False, mps=True, epochs=1) + ctgan.set_random_state(random_state) + discrete_columns = ['categorical', 'binary'] # Explicitly specify discrete columns + + # Check the configured device before training + assert ctgan._device.type == 'mps' + # Generator parameters are checked after fit(); `_generator` may still be None here + + ctgan.fit(train_data, discrete_columns=discrete_columns) + + # Check device of model components after training + assert 
next(ctgan._generator.parameters()).device.type == 'mps' + + sampled = ctgan.sample(100) + assert sampled.shape == (100, train_data.shape[1]) + + +@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available") +def test_save_load_apple_mps_hardware(tmpdir, train_data, random_state): + """Test the CTGAN saves and loads correctly.""" + ctgan = CTGAN(cuda=False, epochs=1) + ctgan.set_random_state(random_state) + discrete_columns = ['categorical', 'binary'] # Explicitly specify discrete columns + + ctgan.fit(train_data, discrete_columns=discrete_columns) + + # Save and load + path = os.path.join(tmpdir, 'test_ctgan.pkl') + ctgan.save(path) + ctgan = CTGAN.load(path) + + # Trained on CPU above; load() re-selects the device (CUDA, then MPS, then CPU) + if torch.cuda.is_available(): + assert ctgan._device.type == 'cuda' + assert next(ctgan._generator.parameters()).device.type == 'cuda' + elif torch.backends.mps.is_available(): + assert ctgan._device.type == 'mps' + assert next(ctgan._generator.parameters()).device.type == 'mps' + else: + assert ctgan._device.type == 'cpu' + assert next(ctgan._generator.parameters()).device.type == 'cpu' + + sampled = ctgan.sample(1000) + assert sampled.shape == (1000, train_data.shape[1])