"""Host-side neighbor-matrix kernels (aimnet/calculators/nb_kernel_rocm.py).

All routines are Numba-compiled and operate on NumPy arrays.  ``nbmat.py``
selects them when ``torch.cuda.is_available()`` is true but
``torch.version.cuda`` is ``None`` (presumably a ROCm/HIP build of PyTorch, as
the module name suggests) and moves the tensors to the CPU before calling.

Conventions shared by every kernel:

* ``nbmatX[i, :]`` is a fixed-capacity row of neighbor indices for atom ``i``.
* ``nnbX[i]`` counts the neighbors found for atom ``i``.  It keeps growing
  even when the row is full; only the write is skipped.  A count larger than
  the row capacity therefore signals overflow (presumably checked by the
  caller).
* Each kernel first records every pair once ("half list") and then mirrors the
  entries with ``_expand_nb`` / ``_expand_nb_pbc``.
"""

import numba
import numpy as np


@numba.njit(cache=True, parallel=False, fastmath=True)
def _nbmat_dual_rocm(
    coord: np.ndarray,  # float, (N, 3)
    cutoff1_squared: float,
    cutoff2_squared: float,
    mol_idx: np.ndarray,  # int, (N,)
    mol_end_idx: np.ndarray,  # int, (M,)
    nbmat1: np.ndarray,  # int, (N, maxnb1)
    nbmat2: np.ndarray,  # int, (N, maxnb2)
    nnb1: np.ndarray,  # int, zeros, (N,)
    nnb2: np.ndarray,  # int, zeros, (N,)
):
    """Fill two neighbor matrices (one per squared cutoff) without periodicity.

    For atom ``i`` only partners ``j`` with ``i < j < mol_end_idx[mol_idx[i]]``
    are tested, so pairs never cross the end of the segment ``i`` belongs to.
    ``nbmat1``/``nnb1`` receive pairs with squared distance below
    ``cutoff1_squared``; ``nbmat2``/``nnb2`` those below ``cutoff2_squared``.
    All four output arrays are modified in place; nothing is returned.
    """
    maxnb1 = nbmat1.shape[1]
    maxnb2 = nbmat2.shape[1]
    N = coord.shape[0]
    for i in range(N):
        # Scalar components avoid allocating a temporary 3-vector per pair.
        x_i = coord[i, 0]
        y_i = coord[i, 1]
        z_i = coord[i, 2]
        j_end = mol_end_idx[mol_idx[i]]
        for j in range(i + 1, j_end):
            dx = x_i - coord[j, 0]
            dy = y_i - coord[j, 1]
            dz = z_i - coord[j, 2]
            dist2 = dx * dx + dy * dy + dz * dz
            if dist2 < cutoff1_squared:
                pos = nnb1[i]
                nnb1[i] += 1  # counted even when the row is already full
                if pos < maxnb1:
                    nbmat1[i, pos] = j
            if dist2 < cutoff2_squared:
                pos = nnb2[i]
                nnb2[i] += 1
                if pos < maxnb2:
                    nbmat2[i, pos] = j
    # Turn the half lists (j > i only) into full, symmetric lists.
    _expand_nb(nnb1, nbmat1)
    _expand_nb(nnb2, nbmat2)


@numba.njit(cache=True, parallel=False, fastmath=True)
def _nbmat_rocm(
    coord: np.ndarray,  # float, (N, 3)
    cutoff1_squared: float,
    mol_idx: np.ndarray,  # int, (N,)
    mol_end_idx: np.ndarray,  # int, (M,)
    nbmat1: np.ndarray,  # int, (N, maxnb1)
    nnb1: np.ndarray,  # int, zeros, (N,)
):
    """Fill one neighbor matrix without periodicity.

    Single-cutoff counterpart of ``_nbmat_dual_rocm``: for atom ``i`` the
    partners ``i < j < mol_end_idx[mol_idx[i]]`` whose squared distance is
    below ``cutoff1_squared`` are recorded.  ``nbmat1`` and ``nnb1`` are
    modified in place; nothing is returned.
    """
    maxnb1 = nbmat1.shape[1]
    N = coord.shape[0]
    for i in range(N):
        # Scalar components avoid allocating a temporary 3-vector per pair.
        x_i = coord[i, 0]
        y_i = coord[i, 1]
        z_i = coord[i, 2]
        j_end = mol_end_idx[mol_idx[i]]
        for j in range(i + 1, j_end):
            dx = x_i - coord[j, 0]
            dy = y_i - coord[j, 1]
            dz = z_i - coord[j, 2]
            dist2 = dx * dx + dy * dy + dz * dz
            if dist2 < cutoff1_squared:
                pos = nnb1[i]
                nnb1[i] += 1  # counted even when the row is already full
                if pos < maxnb1:
                    nbmat1[i, pos] = j
    # Turn the half list (j > i only) into a full, symmetric list.
    _expand_nb(nnb1, nbmat1)


@numba.njit(cache=True, inline="always")
def _expand_nb(nnb, nbmat):
    """Mirror a half neighbor list in place: for every stored ``i -> j`` add ``j -> i``.

    ``nnb`` is snapshotted first so that entries appended during mirroring are
    not mirrored again.  Counts are incremented even when the target row is
    full; only the write is skipped.
    """
    half_counts = nnb.copy()
    N = nnb.shape[0]
    maxnb = nbmat.shape[1]
    for i in range(N):
        # Pairs counted beyond the row capacity were never stored, so there is
        # nothing to mirror for them.
        stored = half_counts[i]
        if stored > maxnb:
            stored = maxnb
        for m in range(stored):
            j = nbmat[i, m]
            if j < N:  # ignore indices that do not refer to a row
                pos = nnb[j]
                nnb[j] += 1
                if pos < maxnb:
                    nbmat[j, pos] = i


@numba.njit(cache=True, inline="always")
def _expand_nb_pbc(nnb, nbmat, shifts):
    """Mirror a periodic half neighbor list in place.

    Works like ``_expand_nb`` and additionally stores the negated shift for
    the mirrored entry: if ``j`` is a neighbor of ``i`` under shift ``s``,
    then ``i`` is recorded as a neighbor of ``j`` under ``-s``.
    """
    half_counts = nnb.copy()
    N = nnb.shape[0]
    maxnb = nbmat.shape[1]
    ncomp = shifts.shape[2]
    for i in range(N):
        # Pairs counted beyond the row capacity were never stored.
        stored = half_counts[i]
        if stored > maxnb:
            stored = maxnb
        for m in range(stored):
            j = nbmat[i, m]
            if j < N:  # ignore indices that do not refer to a row
                pos = nnb[j]
                nnb[j] += 1
                if pos < maxnb:
                    nbmat[j, pos] = i
                    # Component-wise negation; no temporary array is created.
                    for k in range(ncomp):
                        shifts[j, pos, k] = -shifts[i, m, k]


@numba.njit(cache=True)
def _expand_shifts(nshift):
    """Enumerate one half of the integer shift lattice bounded by ``nshift``.

    A triple ``(k1, k2, k3)`` with ``|k_d| <= nshift[d]`` is kept when it is
    ``(0, 0, 0)`` or its first non-zero component is positive, so exactly one
    of ``s`` / ``-s`` appears.  The opposite shifts are produced later by
    ``_expand_nb_pbc``.

    Returns a ``float32`` array of shape ``(n_kept, 3)``; rows are ordered by
    ``k1``, then ``k2``, then ``k3``.
    """
    # Upper bound on the number of kept triples (k1 only takes nshift[0] + 1 values).
    tot_shifts = (nshift[0] + 1) * (2 * nshift[1] + 1) * (2 * nshift[2] + 1)
    shifts = np.zeros((tot_shifts, 3), dtype=np.float32)
    count = 0
    # Triples with k1 < 0 are never kept, so the k1 loop starts at 0.
    for k1 in range(0, nshift[0] + 1):
        for k2 in range(-nshift[1], nshift[1] + 1):
            for k3 in range(-nshift[2], nshift[2] + 1):
                if k1 > 0 or k2 > 0 or (k2 == 0 and k3 >= 0):
                    shifts[count, 0] = k1
                    shifts[count, 1] = k2
                    shifts[count, 2] = k3
                    count += 1
    return shifts[:count]


@numba.njit(cache=True, parallel=False, fastmath=True)
def shift_coords(coord, cell, shifts):
    """Return every atom position translated by every lattice shift.

    ``out[i, s] = coord[i] + shifts[s, 0] * cell[0] + shifts[s, 1] * cell[1]
    + shifts[s, 2] * cell[2]``, i.e. the rows of ``cell`` act as lattice
    vectors.  The result has shape ``(N, S, 3)`` and the dtype of ``coord``.
    """
    N = coord.shape[0]
    S = shifts.shape[0]
    coord_shifted = np.empty((N, S, 3), dtype=coord.dtype)
    for i in range(N):
        for s in range(S):
            s0 = shifts[s, 0]
            s1 = shifts[s, 1]
            s2 = shifts[s, 2]
            coord_shifted[i, s, 0] = coord[i, 0] + s0 * cell[0, 0] + s1 * cell[1, 0] + s2 * cell[2, 0]
            coord_shifted[i, s, 1] = coord[i, 1] + s0 * cell[0, 1] + s1 * cell[1, 1] + s2 * cell[2, 1]
            coord_shifted[i, s, 2] = coord[i, 2] + s0 * cell[0, 2] + s1 * cell[1, 2] + s2 * cell[2, 2]
    return coord_shifted


@numba.njit(cache=True, parallel=False, fastmath=True)
def _nbmat_pbc_rocm(
    coord: np.ndarray,  # float, (N, 3)
    cell: np.ndarray,  # float, (3, 3)
    cutoff1_squared: float,
    shifts: np.ndarray,  # float, (S, 3)
    nnb1: np.ndarray,  # int, zeros, (N,)
    nbmat1: np.ndarray,  # int, (N, M)
    shifts1: np.ndarray,  # int, (N, M, 3)
):
    """Fill one neighbor matrix under periodic boundary conditions.

    Atom ``i`` is compared with the image of atom ``j`` translated by each row
    of ``shifts``.  For the zero shift only ``j < i`` is tested (each pair
    once, no self pair); for any other shift all ``j`` are tested, including
    ``j == i`` (an atom and its own periodic image).  For every hit the
    partner index goes to ``nbmat1`` and the shift to ``shifts1``.
    ``nnb1``, ``nbmat1`` and ``shifts1`` are modified in place; nothing is
    returned.
    """
    maxnb1 = nbmat1.shape[1]
    N = coord.shape[0]
    S = shifts.shape[0]

    # Translated positions are computed once: (N, S, 3).
    coord_shifted = shift_coords(coord, cell, shifts)

    for i in range(N):
        x_i = coord[i, 0]
        y_i = coord[i, 1]
        z_i = coord[i, 2]
        for s in range(S):
            shift = shifts[s]
            zero_shift = shift[0] == 0 and shift[1] == 0 and shift[2] == 0
            j_end = i if zero_shift else N
            for j in range(j_end):
                dx = x_i - coord_shifted[j, s, 0]
                dy = y_i - coord_shifted[j, s, 1]
                dz = z_i - coord_shifted[j, s, 2]
                r2 = dx * dx + dy * dy + dz * dz
                if r2 < cutoff1_squared:
                    pos = nnb1[i]
                    nnb1[i] += 1  # counted even when the row is already full
                    if pos < maxnb1:
                        nbmat1[i, pos] = j
                        shifts1[i, pos] = shift
    # Add the reverse entries (partner -> i with the opposite shift).
    _expand_nb_pbc(nnb1, nbmat1, shifts1)


@numba.njit(cache=True, parallel=False, fastmath=True)
def _nbmat_dual_pbc_rocm(
    coord: np.ndarray,  # float, (N, 3)
    cell: np.ndarray,  # float, (3, 3)
    cutoff1_squared: float,
    cutoff2_squared: float,
    shifts: np.ndarray,  # float, (S, 3)
    nnb1: np.ndarray,  # int, zeros, (N,)
    nnb2: np.ndarray,  # int, zeros, (N,)
    nbmat1: np.ndarray,  # int, (N, M)
    nbmat2: np.ndarray,  # int, (N, M)
    shifts1: np.ndarray,  # int, (N, M, 3)
    shifts2: np.ndarray,  # int, (N, M, 3)
):
    """Fill two neighbor matrices (one per squared cutoff) under periodic boundaries.

    Pair enumeration is the same as in ``_nbmat_pbc_rocm``; each candidate
    distance is tested against both cutoffs and recorded independently in
    ``(nnb1, nbmat1, shifts1)`` and ``(nnb2, nbmat2, shifts2)``.  All six
    output arrays are modified in place; nothing is returned.
    """
    maxnb1 = nbmat1.shape[1]
    maxnb2 = nbmat2.shape[1]
    N = coord.shape[0]
    S = shifts.shape[0]

    # Translated positions are computed once: (N, S, 3).
    coord_shifted = shift_coords(coord, cell, shifts)

    for i in range(N):
        x_i = coord[i, 0]
        y_i = coord[i, 1]
        z_i = coord[i, 2]
        for s in range(S):
            shift = shifts[s]
            zero_shift = shift[0] == 0 and shift[1] == 0 and shift[2] == 0
            j_end = i if zero_shift else N
            for j in range(j_end):
                dx = x_i - coord_shifted[j, s, 0]
                dy = y_i - coord_shifted[j, s, 1]
                dz = z_i - coord_shifted[j, s, 2]
                r2 = dx * dx + dy * dy + dz * dz
                if r2 < cutoff1_squared:
                    pos = nnb1[i]
                    nnb1[i] += 1  # counted even when the row is already full
                    if pos < maxnb1:
                        nbmat1[i, pos] = j
                        shifts1[i, pos] = shift
                if r2 < cutoff2_squared:
                    pos = nnb2[i]
                    nnb2[i] += 1
                    if pos < maxnb2:
                        nbmat2[i, pos] = j
                        shifts2[i, pos] = shift

    # Add the reverse entries (partner -> i with the opposite shift).
    _expand_nb_pbc(nnb1, nbmat1, shifts1)
    _expand_nb_pbc(nnb2, nbmat2, shifts2)
_expand_nb_pbc(nnb1, nbmat1, shifts1) + _expand_nb_pbc(nnb2, nbmat2, shifts2) diff --git a/aimnet/calculators/nbmat.py b/aimnet/calculators/nbmat.py index 6d42998..380355f 100644 --- a/aimnet/calculators/nbmat.py +++ b/aimnet/calculators/nbmat.py @@ -11,17 +11,26 @@ class TooManyNeighborsError(Exception): if torch.cuda.is_available(): - import numba.cuda + if torch.version.cuda is None: + _numba_cuda_available = False + from .nb_kernel_rocm import _nbmat_dual_pbc_rocm, _nbmat_dual_rocm, _nbmat_pbc_rocm, _nbmat_rocm - if not numba.cuda.is_available(): - raise ImportError("PyTorch CUDA is available, but Numba CUDA is not available.") - _numba_cuda_available = True - from .nb_kernel_cuda import _nbmat_cuda, _nbmat_dual_cuda, _nbmat_pbc_cuda, _nbmat_pbc_dual_cuda + _kernel_nbmat = _nbmat_rocm + _kernel_nbmat_dual = _nbmat_dual_rocm + _kernel_nbmat_pbc = _nbmat_pbc_rocm + _kernel_nbmat_pbc_dual = _nbmat_dual_pbc_rocm + else: + import numba.cuda + + if not numba.cuda.is_available(): + raise ImportError("PyTorch CUDA is available, but Numba CUDA is not available.") + _numba_cuda_available = True + from .nb_kernel_cuda import _nbmat_cuda, _nbmat_dual_cuda, _nbmat_pbc_cuda, _nbmat_pbc_dual_cuda - _kernel_nbmat = _nbmat_cuda - _kernel_nbmat_dual = _nbmat_dual_cuda - _kernel_nbmat_pbc = _nbmat_pbc_cuda - _kernel_nbmat_pbc_dual = _nbmat_pbc_dual_cuda + _kernel_nbmat = _nbmat_cuda + _kernel_nbmat_dual = _nbmat_dual_cuda + _kernel_nbmat_pbc = _nbmat_pbc_cuda + _kernel_nbmat_pbc_dual = _nbmat_pbc_dual_cuda else: _numba_cuda_available = False from .nb_kernel_cpu import _nbmat_cpu, _nbmat_dual_cpu, _nbmat_dual_pbc_cpu, _nbmat_pbc_cpu @@ -57,6 +66,8 @@ def calc_nbmat( raise ValueError("Numba CUDA is available, but the input tensors are not on CUDA.") _cuda = device.type == "cuda" and _numba_cuda_available + _rocm = device.type == "cuda" and not _numba_cuda_available + _dual_cutoff = cutoffs[1] is not None if _dual_cutoff and maxnb[1] is None: raise ValueError("maxnb[1] must be 
specified for dual cutoff.") @@ -147,6 +158,79 @@ def calc_nbmat( _nbmat1, _nnb1, ) + elif _rocm: + coord = coord.cpu() + mol_idx = mol_idx.cpu() + mol_end_idx = mol_end_idx.cpu() + nnb1 = nnb1.cpu() + nbmat1 = nbmat1.cpu() + _coord = coord.numpy() + _mol_idx = mol_idx.numpy() + _mol_end_idx = mol_end_idx.numpy() + _nnb1 = nnb1.numpy() + _nbmat1 = nbmat1.numpy() + if _dual_cutoff: + nnb2 = nnb2.cpu() + nbmat2 = nbmat2.cpu() + _nnb2 = nnb2.numpy() + _nbmat2 = nbmat2.numpy() + if _pbc: + cell = cell.cpu() + _cell = cell.numpy() # type: ignore[union-attr] + shifts = shifts.cpu() + _shifts = shifts.numpy() + + if _pbc: + shifts1 = shifts1.cpu() + _shifts1 = shifts1.numpy() + if _dual_cutoff: + shifts2 = shifts2.cpu() + _shifts2 = shifts2.numpy() + _kernel_nbmat_pbc_dual( + _coord, + _cell, + cutoffs[0] ** 2, + cutoffs[1] ** 2, # type: ignore + _shifts, + _nnb1, + _nnb2, + _nbmat1, + _nbmat2, + _shifts1, # type: ignore + _shifts2, # type: ignore + ) # type: ignore + else: + _kernel_nbmat_pbc(_coord, _cell, cutoffs[0] ** 2, _shifts, _nnb1, _nbmat1, _shifts1) # type: ignore + else: + if _dual_cutoff: + _kernel_nbmat_dual( + _coord, + cutoffs[0] ** 2, + cutoffs[1] ** 2, # type: ignore + _mol_idx, + _mol_end_idx, + _nbmat1, + _nbmat2, + _nnb1, + _nnb2, + ) # type: ignore + else: + _kernel_nbmat(_coord, cutoffs[0] ** 2, _mol_idx, _mol_end_idx, _nbmat1, _nnb1) + + coord = coord.cuda() + mol_idx = mol_idx.cuda() + mol_end_idx = mol_end_idx.cuda() + nnb1 = nnb1.cuda() + nbmat1 = nbmat1.cuda() + if _dual_cutoff: + nnb2 = nnb2.cuda() + nbmat2 = nbmat2.cuda() + if _pbc: + cell = cell.cuda() + shifts = shifts.cuda() + shifts1 = shifts1.cuda() + if _dual_cutoff: + shifts2 = shifts2.cuda() else: _coord = coord.numpy()