diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py
index 7f761d8f..90d7a1de 100644
--- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py
+++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py
@@ -121,6 +121,11 @@ def operands(self) -> tuple[cue.SegmentedOperand, ...]:
     def assert_valid(self):
         assert Subscripts.is_valid(self.subscripts)
 
+        if not all(map(lambda u: u.lower() == u, self.subscripts.modes())):
+            raise ValueError(
+                f"subscripts {self.subscripts} must contain only lowercase letters. (Capital letters are reserved for internal use.)"
+            )
+
         for m in self.subscripts.modes():
             if self.subscripts.count(m) == 1:
                 raise ValueError(
@@ -171,6 +176,11 @@ def from_subscripts(cls, subscripts: Subscripts) -> SegmentedTensorProduct:
             cue.SegmentedOperand(ndim=len(operand)) for operand in subscripts.operands
         ]
 
+        if not all(map(lambda u: u.lower() == u, subscripts.modes())):
+            raise ValueError(
+                f"subscripts {subscripts} must contain only lowercase letters. (Capital letters are reserved for internal use.)"
+            )
+
         return cls(
             operands_and_subscripts=list(zip(operands, subscripts.operands)),
             paths=[],
diff --git a/cuequivariance/cuequivariance/segmented_polynomials/subscripts.py b/cuequivariance/cuequivariance/segmented_polynomials/subscripts.py
index 89891a67..30fc88b2 100644
--- a/cuequivariance/cuequivariance/segmented_polynomials/subscripts.py
+++ b/cuequivariance/cuequivariance/segmented_polynomials/subscripts.py
@@ -53,7 +53,7 @@ def is_valid(subscripts: str) -> bool:
         """
         if not isinstance(subscripts, str):
             return False
-        mode = r"[a-z*]"
+        mode = r"[a-zA-Z*]"
        if re.match(rf"^{mode}*({SEP}{mode}*)*(\+{mode}*)?$", subscripts) is None:
             return False
         operands_and_coefficients = re.split(rf"[{SEP}+]", subscripts)
diff --git a/cuequivariance/tests/segmented_polynomials/subscripts_test.py b/cuequivariance/tests/segmented_polynomials/subscripts_test.py
index 303ba5aa..9082ec52 100644
--- a/cuequivariance/tests/segmented_polynomials/subscripts_test.py
+++ b/cuequivariance/tests/segmented_polynomials/subscripts_test.py
@@ -20,12 +20,6 @@ def test_subscripts():
     with pytest.raises(ValueError):
         sp.Subscripts("#$%@")
 
-    with pytest.raises(ValueError):
-        sp.Subscripts("Zu")  # uppercase not supported anymore
-
-    with pytest.raises(ValueError):
-        sp.Subscripts("uZ")  # uppercase after lowercase
-
     with pytest.raises(ValueError):
         sp.Subscripts("uZ+ij+kl")  # multiple + signs
 
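The descriptor-level checks added above pair with the relaxed Subscripts regex: uppercase modes are now syntactically allowed but reserved for internal use (they name batch axes in the gemm_grouped lowering further down), so user-facing subscripts must stay lowercase. A minimal sketch of the resulting behaviour, with illustrative subscripts strings (not taken from a real descriptor):

    import cuequivariance as cue

    # Illustrative subscripts only. Lowercase modes are accepted as before;
    # any uppercase mode now trips the new ValueError, since capital letters
    # are reserved for internally generated modes.
    cue.SegmentedTensorProduct.from_subscripts("uv,vw,uw")  # accepted
    cue.SegmentedTensorProduct.from_subscripts("uv,vw,uW")  # raises ValueError
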
- ``"indexed_linear"``: Uses a CUDA implementation for linear layers with indexed weights. + - ``"gemm_grouped"``: Uses a CUDA implementation for polynomials mappable to matrix multiplications. .. note:: The ``"fused_tp"`` method is only available in the PyTorch implementation. @@ -156,6 +160,7 @@ def segmented_polynomial( "To fix this, simply add a `method` parameter to your function call. Here are the available options:\n" "• 'naive' - Works everywhere but not optimized (good for testing)\n" "• 'uniform_1d' - Fast CUDA implementation for single uniform mode polynomials\n" + "• 'gemm_grouped' - Fast CUDA implementation for matrix multiplication patterns\n" "• 'indexed_linear' - Fast CUDA implementation for linear layers with indexed weights\n\n" "Example: outputs = segmented_polynomial(poly, inputs, outputs, method='naive')" ) @@ -486,7 +491,7 @@ def segmented_polynomial_impl( f"{name}: {fl / 1e9:.2f} GFLOP, {mem / 1e9:.2f} GB, arithmetic intensity: {fl / mem:.2f} FLOP/byte" ) - assert method in ("naive", "uniform_1d", "indexed_linear") + assert method in ("naive", "uniform_1d", "gemm_grouped", "indexed_linear") if platform != "cuda" and method != "naive": warnings.warn( f"Method '{method}' requires CUDA, but platform is '{platform}'. " @@ -518,6 +523,8 @@ def segmented_polynomial_impl( return execute_uniform_1d(**kwargs) case "indexed_linear": return execute_indexed_linear(**kwargs, index_mode=index_mode) + case "gemm_grouped": + return execute_gemm_grouped(**kwargs) def segmented_polynomial_jvp( diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_gemm_grouped.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_gemm_grouped.py new file mode 100644 index 00000000..61a9246c --- /dev/null +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_gemm_grouped.py @@ -0,0 +1,234 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import jax +import jax.numpy as jnp +import numpy as np + +import cuequivariance as cue + + +def _prepand_batch_modes(operand_subscript_pair): + array, subscript = operand_subscript_pair + batch_shape = array.shape[: -len(subscript)] if subscript else array.shape + + batch_modes = "" + for i, size in enumerate(batch_shape): + if size == 1: + batch_modes += "1" + else: + # Use letters A, B, C, ... 
diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_gemm_grouped.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_gemm_grouped.py
new file mode 100644
index 00000000..61a9246c
--- /dev/null
+++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_gemm_grouped.py
@@ -0,0 +1,234 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+import cuequivariance as cue
+
+
+def _prepand_batch_modes(operand_subscript_pair):
+    array, subscript = operand_subscript_pair
+    batch_shape = array.shape[: -len(subscript)] if subscript else array.shape
+
+    batch_modes = ""
+    for i, size in enumerate(batch_shape):
+        if size == 1:
+            batch_modes += "1"
+        else:
+            # Use letters A, B, C, ... for batch dimensions
+            batch_modes += chr(ord("A") + i)
+
+    return array, batch_modes + subscript
+
+
+def _squeeze_modes(operand_subscript_pair):
+    array, subscript = operand_subscript_pair
+
+    # Find positions of '1' in the subscript
+    squeeze_axes = []
+    new_subscript = ""
+
+    for i, char in enumerate(subscript):
+        if char == "1":
+            squeeze_axes.append(i)
+        else:
+            new_subscript += char
+
+    # Squeeze the array at the identified axes
+    squeezed_array = array
+    for axis in reversed(squeeze_axes):  # Reverse to maintain correct indices
+        squeezed_array = jnp.squeeze(squeezed_array, axis=axis)
+
+    return squeezed_array, new_subscript
+
+
+def _consolidate_pairs(operands):
+    if not operands:
+        return operands
+
+    # Find all consecutive character pairs across all subscripts
+    all_pairs = set()
+    for _, subscript in operands:
+        for i in range(len(subscript) - 1):
+            all_pairs.add(subscript[i : i + 2])
+
+    # Find a pair that can be consolidated (appears in all relevant subscripts)
+    for pair in all_pairs:
+        char1, char2 = pair
+        if all(
+            pair in sub or (char1 not in sub and char2 not in sub)
+            for _, sub in operands
+        ):
+            # Consolidate this pair
+            new_operands = []
+            for array, subscript in operands:
+                if pair in subscript:
+                    pos = subscript.index(pair)
+                    # Combine dimensions at pos and pos+1
+                    new_shape = list(array.shape)
+                    new_shape[pos] *= new_shape[pos + 1]
+                    new_shape.pop(pos + 1)
+                    array = jnp.reshape(array, new_shape)
+                    subscript = subscript.replace(pair, char1)
+                new_operands.append((array, subscript))
+            return _consolidate_pairs(new_operands)
+
+    return operands
+
+
+def execute_gemm_grouped(
+    inputs: list[jax.Array],  # shape (*batch_sizes, operand_size)
+    outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...],
+    indices: list[jax.Array],
+    index_configuration: tuple[tuple[int, ...], ...],
+    polynomial: cue.SegmentedPolynomial,
+    math_dtype: str | None,
+    name: str,
+) -> list[jax.Array]:
+    index_configuration = np.array(index_configuration)
+    num_batch_axes = index_configuration.shape[1]
+    assert (
+        polynomial.num_inputs + len(outputs_shape_dtype) == index_configuration.shape[0]
+    )
+    assert polynomial.num_outputs == len(outputs_shape_dtype)
+
+    assert math_dtype is None
+
+    if not all(x.dtype in {jnp.int32, jnp.int64} for x in indices):
+        raise ValueError("All indices must have dtype int32 or int64")
+
+    from cuequivariance_ops_jax import gemm_grouped
+
+    # index_configuration = np.concatenate(
+    #     [index_configuration, np.full((len(indices), num_batch_axes), -1, np.int32)]
+    # )
+
+    if not np.all(index_configuration == -1):
+        raise ValueError("method 'gemm_grouped' does not support indices (yet)")
+    if len(indices) != 0:
+        raise ValueError("method 'gemm_grouped' does not support indices (yet)")
+
+    gemms = []
+
+    nin = polynomial.num_inputs
+    for ope, stp in polynomial.operations:
+        assert stp.num_operands == 3, (
+            f"method 'gemm_grouped' requires STPs with 3 operands, got {stp.num_operands} for {ope}"
+        )
+        assert stp.coefficient_subscripts == "", (
+            f"method 'gemm_grouped' requires STPs without coefficient subscripts, got {stp.coefficient_subscripts} for {ope}"
+        )
+        oid, i = ope.output_operand_buffer(nin)
+        [AA, BB] = [inputs[i] for i in ope.input_buffers(nin)]
+        CC = outputs_shape_dtype[i - nin]
+        stp = stp.move_operand_last(oid)
+
+        Aslices = stp.operands[0].segment_slices()
+        Bslices = stp.operands[1].segment_slices()
+
+        for path in stp.paths:
+            A = AA[..., Aslices[path.indices[0]]]
+            B = BB[..., Bslices[path.indices[1]]]
+
+            A = jnp.reshape(A, A.shape[:-1] + stp.operands[0].segments[path.indices[0]])
+            B = jnp.reshape(B, B.shape[:-1] + stp.operands[1].segments[path.indices[1]])
+            C_shape = CC.shape[:-1] + stp.operands[2].segments[path.indices[2]]
+            C = jnp.zeros(C_shape, dtype=CC.dtype)
+
+            sa, sb, sc = stp.subscripts.operands
+            assert A.ndim == num_batch_axes + len(sa)
+            assert B.ndim == num_batch_axes + len(sb)
+            assert C.ndim == num_batch_axes + len(sc)
+
+            operands = [(A, sa), (B, sb), (C, sc)]
+            operands = list(map(_prepand_batch_modes, operands))
+            operands = list(map(_squeeze_modes, operands))
+            operands = _consolidate_pairs(operands)
+
+            [(A, sa), (B, sb), (C, sc)] = operands
+
+            if len(sc) >= 2:
+                u, v = sc[-2:]
+                if u in sb and v in sa:
+                    [(A, sa), (B, sb)] = [(B, sb), (A, sa)]
+            if len(sc) == 1:
+                if len(sa) == 2 and len(sb) == 1:
+                    [(A, sa), (B, sb)] = [(B, sb), (A, sa)]
+
+            [sa, sb, sc] = (
+                cue.segmented_polynomials.Subscripts.from_operands([sa, sb, sc])
+                .canonicalize()
+                .operands
+            )
+            contr = f"{sa},{sb}->{sc}"
+
+            gemm = None
+
+            if contr == "uvw,uav->uwa":
+                gemm = (A, B, True, True)
+            if contr == "uvw,uwa->uva":
+                gemm = (A, B, False, False)
+
+            if contr == "uv,vw->uw":
+                gemm = (A, B, False, False)
+            if contr == "uv,wv->uw":
+                gemm = (A, B, False, True)
+            if contr == "uv,uw->vw":
+                gemm = (A, B, True, False)
+            if contr == "uv,wu->vw":
+                gemm = (A, B, True, True)
+
+            if contr == "u,uv->v":
+                gemm = (A[None, :], B, False, False)
+            if contr == "u,vu->v":
+                gemm = (A[None, :], B, False, True)
+
+            if contr == "u,v->uv":
+                gemm = (A[:, None], B[None, :], False, False)
+
+            if gemm is None:
+                raise ValueError(
+                    f"gemm_grouped does not support: {A.shape} @ {B.shape} -> {C.shape} with contraction {sa},{sb}->{sc}"
+                )
+            gemms.append(gemm + (path.coefficients.item(),))
+
+    num_batch_axes = {A.ndim - 2 for A, _, _, _, _ in gemms}
+    assert len(num_batch_axes) == 1
+    num_batch_axes = num_batch_axes.pop()
+    gemm_outs = gemm_grouped(
+        gemms,
+        [],
+        np.full((2 * len(gemms), num_batch_axes), -1, np.int32),
+        use_tf32=False,
+    )
+
+    outputs = [jnp.zeros(x.shape, dtype=x.dtype) for x in outputs_shape_dtype]
+
+    for ope, stp in polynomial.operations:
+        oid, i = ope.output_operand_buffer(nin)
+        slices = stp.operands[oid].segment_slices()
+        segments = stp.operands[oid].segments
+
+        for path in stp.paths:
+            sid = path.indices[oid]
+            acc = outputs[i - nin]
+            outputs[i - nin] = acc.at[..., slices[sid]].add(
+                jnp.reshape(
+                    gemm_outs.pop(0), acc.shape[:-1] + (math.prod(segments[sid]),)
+                )
+            )
+    return outputs
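To make the lowering concrete, here is a small illustrative check (shapes invented) of what a single path computes once batch axes have been given uppercase mode names by _prepand_batch_modes, size-1 axes squeezed, and adjacent shared modes merged: the canonicalized contraction is dispatched as a (possibly batched) matrix multiplication with the transpose flags from the table above.

    import jax.numpy as jnp

    # Invented segment shapes: the two inputs carry subscripts "uv" and "vw"
    # and share one non-trivial batch axis, which gets the uppercase mode "A"
    # (hence the new lowercase-only rule for user-facing subscripts).
    A = jnp.ones((32, 8, 4))   # "Auv" after prepending the batch mode
    B = jnp.ones((32, 4, 16))  # "Avw"
    C = jnp.einsum("Auv,Avw->Auw", A, B)  # the contraction this path encodes
    assert jnp.allclose(C, A @ B)  # i.e. a batched GEMM with no transposes
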
diff --git a/cuequivariance_jax/cuequivariance_jax/triangle/_triangle_attention.py b/cuequivariance_jax/cuequivariance_jax/triangle/_triangle_attention.py
index 0db866fa..3e044a42 100644
--- a/cuequivariance_jax/cuequivariance_jax/triangle/_triangle_attention.py
+++ b/cuequivariance_jax/cuequivariance_jax/triangle/_triangle_attention.py
@@ -21,16 +21,6 @@
 
 from cuequivariance_jax.triangle._naive_batching import naive_batching_rule
 
-try:
-    from cuequivariance_ops_jax import (
-        triangle_attention_cuda_bwd,
-        triangle_attention_cuda_fwd,
-    )
-
-    HAS_CUE_OPS_JAX = True
-except ImportError:
-    HAS_CUE_OPS_JAX = False
-
 
 def triangle_attention(
     q: jax.Array,  # [B, N, H, S_qo, D]
@@ -157,9 +147,12 @@ def triangle_attention_fwd_impl(
     precision: jax.lax.Precision | None = None,
 ) -> tuple[jax.Array, jax.Array, jax.Array]:
     if platform == "cuda":
-        assert HAS_CUE_OPS_JAX, (
-            "Please install cuequivariance_ops_jax for CUDA support."
-        )
+        try:
+            from cuequivariance_ops_jax import triangle_attention_cuda_fwd
+        except ImportError as e:
+            raise ImportError(
+                "Please install cuequivariance_ops_jax for CUDA support."
+            ) from e
         return triangle_attention_cuda_fwd(q, k, v, mask, bias, scale, precision)
     else:
         return triangle_attention_jax_fwd(q, k, v, bias, mask, scale, precision)
@@ -180,9 +173,12 @@ def triangle_attention_bwd_impl(
     precision: jax.lax.Precision | None = None,
 ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
     if platform == "cuda":
-        assert HAS_CUE_OPS_JAX, (
-            "Please install cuequivariance_ops_jax for CUDA support."
-        )
+        try:
+            from cuequivariance_ops_jax import triangle_attention_cuda_bwd
+        except ImportError as e:
+            raise ImportError(
+                "Please install cuequivariance_ops_jax for CUDA support."
+            ) from e
         return triangle_attention_cuda_bwd(
             da, a, lse, q, k, v, mask, bias, scale, precision
         )
diff --git a/cuequivariance_jax/pyproject.toml b/cuequivariance_jax/pyproject.toml
index 79a5c17f..b577904f 100644
--- a/cuequivariance_jax/pyproject.toml
+++ b/cuequivariance_jax/pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
     "cuequivariance",
     "jax",
     "packaging",
+    "einops",
 ]
 classifiers = [
     "Intended Audience :: Developers",