# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

"""
Common base class and utilities shared between Iris and IrisGluon.

This module contains the shared implementation for initialization,
logging, device validation, and utility methods used by both
Triton and Gluon backends.
"""

import numpy as np
import math
import torch
import ctypes
import logging

from iris._distributed_helpers import (
    init_distributed,
    distributed_allgather,
    distributed_barrier,
    distributed_broadcast_scalar,
    distributed_broadcast_tensor,
)
from iris.hip import (
    set_device,
    get_cu_count,
    count_devices,
    get_ipc_handle,
    open_ipc_handle,
)
from iris.logging import logger


# Import tensor operations for use in IrisBase methods
from iris._tensor_ops import create_zeros, create_ones, create_full, create_zeros_like


class IrisBase:
    """
    Base class for Iris implementations containing shared functionality.

    This class provides common initialization, logging, device validation,
    and utility methods used by both Triton and Gluon backends.
    """

    def __init__(self, heap_size=1 << 30):
        """
        Initialize the Iris base class.

        Sets up the distributed environment, binds this rank to a local GPU,
        allocates the symmetric heap on that GPU, and exchanges IPC handles so
        every rank can address every other rank's heap.

        Args:
            heap_size (int): Size of the symmetric heap in bytes.
                Default: 1GB (2^30)
        """
        # Initialize distributed environment
        comm, cur_rank, num_ranks = init_distributed()
        num_gpus = count_devices()

        # Round-robin assignment of ranks to local GPUs
        gpu_id = cur_rank % num_gpus
        set_device(gpu_id)

        self.comm = comm
        self.num_ranks = num_ranks
        self.cur_rank = cur_rank
        self.gpu_id = gpu_id
        self.heap_size = heap_size
        # Bump-allocator state: next free byte offset into the heap
        self.heap_offset = 0
        # Every allocation is rounded up to a multiple of this many bytes
        self.alignment = 1024
        self.device = f"cuda:{gpu_id}"
        # The symmetric heap itself: one large byte buffer carved up by _allocate()
        self.memory_pool = torch.empty(heap_size, device=self.device, dtype=torch.int8)

        heap_base = self.memory_pool.data_ptr()
        heap_base_ptr = ctypes.c_void_p(heap_base)

        heap_bases = np.zeros(num_ranks, dtype=np.uint64)
        heap_bases[cur_rank] = heap_base
        # Export our heap as an IPC handle so peer ranks can map it
        ipc_handle = get_ipc_handle(heap_base_ptr, cur_rank)

        distributed_barrier()

        # Gather every rank's IPC handle (as raw bytes) on all ranks
        all_ipc_handles = distributed_allgather(np.frombuffer(ipc_handle, dtype=np.uint8).copy())

        distributed_barrier()

        # Map each remote heap into our address space; our own entry keeps the
        # local base address unchanged.
        ipc_heap_bases = np.zeros(num_ranks, dtype=np.uintp)
        for rank in range(num_ranks):
            if rank != cur_rank:
                handle = open_ipc_handle(all_ipc_handles[rank], cur_rank)
                ipc_heap_bases[rank] = int(handle)
            else:
                ipc_heap_bases[rank] = heap_bases[rank]

        for i in range(num_ranks):
            self.debug(f"GPU {i}: Heap base {hex(int(ipc_heap_bases[i]))}")

        distributed_barrier()
        # Device-resident table of heap base pointers, indexed by rank
        self.heap_bases = torch.from_numpy(ipc_heap_bases).to(device=self.device, dtype=torch.uint64)

        distributed_barrier()

    def _log_with_rank(self, level, message):
        """Helper method to log with rank information injected into the record."""
        # Build the record manually (instead of logger.log) so that rank
        # metadata can be attached before any handlers see it.
        if logger.isEnabledFor(level):
            record = logging.LogRecord(
                name=logger.name, level=level, pathname="", lineno=0, msg=message, args=(), exc_info=None
            )
            # Inject rank information into the record
            record.iris_rank = self.cur_rank
            record.iris_num_ranks = self.num_ranks
            logger.handle(record)

    def debug(self, message):
        """
        Log a debug message with rank information.

        Args:
            message (str): Human-readable message to log at debug level.
        """
        self._log_with_rank(logging.DEBUG, message)
+ """ + self._log_with_rank(logging.DEBUG, message) + + def info(self, message): + """ + Log an info message with rank information. + + Args: + message (str): Human-readable message to log at info level. + """ + self._log_with_rank(logging.INFO, message) + + def warning(self, message): + """ + Log a warning message with rank information. + + Args: + message (str): Human-readable message to log at warning level. + """ + self._log_with_rank(logging.WARNING, message) + + def error(self, message): + """ + Log an error message with rank information. + + Args: + message (str): Human-readable message to log at error level. + """ + self._log_with_rank(logging.ERROR, message) + + def broadcast(self, value, source_rank): + """ + Broadcast a value from one rank to all ranks. + + This method automatically detects the type of value and uses the appropriate + broadcast mechanism: + - For tensors and arrays: uses efficient PyTorch distributed tensor collectives + - For scalars and other objects: uses object broadcast + + Args: + value (Any): The value to broadcast. Can be a scalar, tensor, numpy array, + or any picklable object. Only the ``source_rank`` value is used; + other ranks should pass a placeholder (e.g., ``None``). + source_rank (int): Rank id that holds the authoritative value. + + Returns: + Any: The value broadcast to all ranks. Tensors and arrays are returned as + numpy arrays; scalars and objects are returned in their original type. 
+ """ + # Check if the value on source_rank is a tensor or array-like + if self.cur_rank == source_rank and value is not None: + # Explicitly exclude strings and non-numeric types + if isinstance(value, (str, dict, bool)): + is_tensor = False + elif isinstance(value, torch.Tensor): + is_tensor = True + elif isinstance(value, np.ndarray): + is_tensor = True + elif isinstance(value, (list, tuple)): + # Try to convert list/tuple to tensor to check if it's numeric + try: + torch.as_tensor(value) + is_tensor = True + except (TypeError, ValueError): + is_tensor = False + else: + # For other types, try to convert and check + try: + test_array = np.asarray(value) + # Check if it's a numeric dtype that torch can handle + if np.issubdtype(test_array.dtype, np.number): + torch.as_tensor(test_array) + is_tensor = True + else: + is_tensor = False + except (TypeError, ValueError): + is_tensor = False + else: + is_tensor = False + + # Broadcast the type decision to all ranks + is_tensor = distributed_broadcast_scalar(is_tensor, source_rank) + + if is_tensor: + return distributed_broadcast_tensor(value, root=source_rank) + else: + return distributed_broadcast_scalar(value, source_rank) + + def _allocate(self, num_elements, dtype): + """ + Internal method to allocate memory from the symmetric heap. 
+ + Args: + num_elements (int): Number of elements to allocate + dtype (torch.dtype): Data type of the elements + + Returns: + torch.Tensor: Allocated tensor on the symmetric heap + """ + self.debug(f"allocate: num_elements = {num_elements}, dtype = {dtype}") + + element_size = torch.tensor([], dtype=dtype).element_size() + size_in_bytes = num_elements * element_size + aligned_size = math.ceil(size_in_bytes / self.alignment) * self.alignment + + if self.heap_offset + aligned_size > self.heap_size: + raise MemoryError("Heap out of memory") + + start = self.heap_offset + self.heap_offset += aligned_size + + sub_buffer = self.memory_pool[start : start + size_in_bytes].view(dtype) + return sub_buffer.reshape((num_elements,)) + + def _parse_size(self, size): + """ + Parse size parameter and calculate number of elements. + + Args: + size (tuple): Size specification (can be nested) + + Returns: + tuple: (parsed_size, num_elements) + """ + # Handle nested tuples/lists by flattening them recursively + while len(size) == 1 and isinstance(size[0], (tuple, list)): + size = size[0] + num_elements = math.prod(size) + return size, num_elements + + def _throw_if_invalid_output_tensor(self, tensor: torch.Tensor, num_elements: int, dtype: torch.dtype): + """ + Validate that an output tensor meets requirements. + + Args: + tensor: The tensor to validate + num_elements: Expected number of elements + dtype: Expected data type + + Raises: + RuntimeError: If validation fails + """ + if not self._tensor_on_device(tensor): + raise RuntimeError( + f"The output tensor is not on the same device as the Iris instance. " + f"The Iris instance is on device {self.device} but the output tensor is on device {tensor.device}" + ) + if not self._on_symmetric_heap(tensor): + raise RuntimeError( + f"The output tensor is not on the symmetric heap. 
" + f"The Iris instance is on heap base {self.heap_bases[self.cur_rank]} " + f"but the output tensor is on heap base {tensor.data_ptr()}" + ) + if tensor.numel() != num_elements: + raise RuntimeError(f"The output tensor has {tensor.numel()} elements, but {num_elements} are required") + if tensor.dtype != dtype: + raise RuntimeError(f"The output tensor has dtype {tensor.dtype}, but {dtype} is required") + + def _throw_if_invalid_device(self, device): + """ + Throw a RuntimeError if the requested device is not compatible with this Iris instance. + + Args: + device: The requested device (can be string, torch.device, or None) + + Raises: + RuntimeError: If the device is not compatible + """ + if not self._is_valid_device(device): + raise RuntimeError( + f"Device mismatch: requested device {device} but Iris instance is on device {self.device}. " + f"Iris only supports tensors on its own device." + ) + + def _apply_layout(self, tensor: torch.Tensor, layout: torch.layout) -> torch.Tensor: + """ + Apply the requested layout to a tensor. + + Args: + tensor: The tensor to modify + layout: The desired layout + + Returns: + Tensor with the requested layout + """ + if layout == torch.strided: + # Strided layout is the default - no changes needed + return tensor + else: + # Only support strided layout for now + raise ValueError(f"Layout {layout} not supported. Only torch.strided is currently supported.") + + def _tensor_on_device(self, tensor: torch.Tensor): + """ + Check if a tensor is on the same device as this Iris instance. 
+ + Args: + tensor: The tensor to check + + Returns: + bool: True if tensor is on compatible device + """ + # Get the Iris device from memory_pool.device + iris_device = self.get_device() + tensor_device = tensor.device + + # For CUDA devices, check if they're compatible + if tensor_device.type == "cuda" and iris_device.type == "cuda": + if iris_device.index is None: + return True + return tensor_device.index == iris_device.index + + # For non-CUDA devices, they must be exactly equal + return tensor_device == iris_device + + def _on_symmetric_heap(self, tensor: torch.Tensor): + """ + Check if a tensor is allocated on the symmetric heap. + + Args: + tensor: The tensor to check + + Returns: + bool: True if tensor is on symmetric heap + """ + # Special case for empty tensors - they might not have a valid data_ptr + if tensor.numel() == 0: + self.debug("Empty tensor detected, skipping heap check") + return True + + # Convert CUDA pointer to integer for comparison + tensor_ptr = int(tensor.data_ptr()) + heap_base = int(self.heap_bases[self.cur_rank]) + + result = tensor_ptr >= heap_base and tensor_ptr < heap_base + self.heap_size + + return result + + def _is_valid_device(self, device) -> bool: + """ + Check if the requested device is compatible with this Iris instance. 
+ + Args: + device: The requested device (can be string, torch.device, or None) + + Returns: + bool: True if the device is compatible, False otherwise + """ + if device is None: + return True # None means use default device + + # Convert device strings to torch.device objects for proper comparison + requested_device = torch.device(device) if isinstance(device, str) else device + iris_device = self.get_device() + + # Check if both are CUDA devices + if requested_device.type == "cuda" and iris_device.type == "cuda": + # Check if index matches or if requested is "cuda" (any index) + if requested_device.index is None: + return True + else: + return requested_device.index == iris_device.index + + # For non-CUDA devices, always return False + return False + + def get_heap_bases(self): + """ + Return the tensor of symmetric heap base addresses for all ranks. + + Returns: + torch.Tensor: A 1D tensor of ``uint64`` heap base addresses of size ``num_ranks`` + on the Iris device. + """ + return self.heap_bases + + def barrier(self, stream=None): + """ + Synchronize all ranks and their CUDA devices. + + This first calls ``torch.cuda.synchronize()`` or ``stream.synchronize()`` to ensure the local GPU has + finished all queued work, then performs a global distributed barrier so that all + ranks reach the same point before proceeding. + + Args: + stream: If stream is given: wait only for that stream before barrier. + If stream is None: legacy behavior (device-wide sync). + """ + # Wait for all GPUs to finish work + if stream is None: + torch.cuda.synchronize() + else: + stream.synchronize() + + # Distributed barrier + distributed_barrier() + + def get_device(self): + """ + Get the underlying device where the Iris symmetric heap resides. + + Returns: + torch.device: The CUDA device of Iris-managed memory. + """ + return self.memory_pool.device + + def get_cu_count(self): + """ + Get the number of compute units (CUs) for the current GPU. 
+ + Returns: + int: Number of compute units on this rank's GPU. + """ + return get_cu_count(self.gpu_id) + + def get_rank(self): + """ + Get this process's rank id in the distributed communicator. + + Returns: + int: Zero-based rank id of the current process. + """ + return self.cur_rank + + def get_num_ranks(self): + """ + Get the total number of ranks in the distributed communicator. + + Returns: + int: World size (number of ranks). + """ + return self.num_ranks + + def _apply_memory_format( + self, tensor: torch.Tensor, size: tuple, memory_format: torch.memory_format, input_tensor: torch.Tensor = None + ): + """ + Apply the requested memory format to a tensor by setting appropriate strides. + This keeps the tensor on the symmetric heap while changing how PyTorch interprets the memory layout. + + Args: + tensor: The tensor to modify + size: The tensor's size/dimensions + memory_format: The desired memory format + input_tensor: The original input tensor (needed for preserve_format detection) + """ + if memory_format == torch.contiguous_format: + # Default format, no changes needed + return tensor + elif memory_format == torch.channels_last and len(size) == 4: + # For channels_last format: preserve shape (N, C, H, W) but change strides + # channels_last strides: [C*H*W, 1, C*W, C] for shape (N, C, H, W) + N, C, H, W = size[0], size[1], size[2], size[3] + # Keep the original shape (N, C, H, W) but use channels_last strides + tensor = self._create_tensor_with_strides(tensor, size, (C * H * W, 1, C * W, C)) + return tensor + elif memory_format == torch.channels_last_3d and len(size) == 5: + # For channels_last_3d format: preserve shape (N, C, D, H, W) but change strides + # channels_last_3d strides: [C*D*H*W, 1, C*D*W, C*W, C] for shape (N, C, D, H, W) + N, C, D, H, W = size[0], size[1], size[2], size[3], size[4] + # Keep the original shape (N, C, D, H, W) but use channels_last_3d strides + tensor = self._create_tensor_with_strides(tensor, size, (C * D * H * W, 1, C * 
D * W, C * W, C)) + return tensor + elif memory_format == torch.preserve_format: + # For preserve_format, we need to detect the input tensor's memory format + # and apply the same format to the output + if input_tensor is not None: + # Check the actual memory format of the input tensor + if len(size) == 4: + # Check if input tensor is in channels_last format by examining strides + # channels_last format has strides[1] == 1 (channels dimension is contiguous) + input_strides = input_tensor.stride() + if len(input_strides) == 4 and input_strides[1] == 1: + # Input is in channels_last format, preserve it + # Use the input tensor's actual shape, not the size parameter + input_shape = input_tensor.shape + if len(input_shape) == 4: + # Input is already in channels_last format (N, H, W, C) + new_size = input_shape + # Use the input tensor's strides directly + tensor = self._create_tensor_with_strides(tensor, new_size, input_strides) + return tensor + elif len(size) == 5: + # Check if input tensor is in channels_last_3d format + input_strides = input_tensor.stride() + if len(input_strides) == 5 and input_strides[1] == 1: + # Input is in channels_last_3d format, preserve it + # Use the input tensor's actual shape, not the size parameter + input_shape = input_tensor.shape + if len(input_shape) == 5: + # Input is already in channels_last_3d format (N, D, H, W, C) + new_size = input_shape + # Use the input tensor's strides directly + tensor = self._create_tensor_with_strides(tensor, new_size, input_strides) + return tensor + # If no special format detected or no input tensor provided, use contiguous format + return tensor + else: + # Unsupported format or dimension combination + self.debug( + f"Warning: Memory format {memory_format} not supported for {len(size)}D tensor, using contiguous format" + ) + # For unsupported formats, return the tensor as-is (contiguous) + return tensor + + def _create_tensor_with_strides(self, original_tensor: torch.Tensor, size: tuple, strides: 
tuple) -> torch.Tensor: + """ + Create a new tensor with the specified strides while keeping the data on the symmetric heap. + + Args: + original_tensor: The original tensor (source of data and heap allocation) + size: The tensor's size/dimensions + strides: The desired strides for the new memory format + + Returns: + A new tensor with the specified strides, data copied from original, on the same heap + """ + # First, create a temporary tensor with the correct strides using PyTorch + temp_tensor = torch.empty_strided(size, strides, dtype=original_tensor.dtype, device=original_tensor.device) + + # Handle different cases based on whether size changes and what the strides indicate + if size != original_tensor.shape: + # Size is different - this might be a format change that requires permutation + # Check if this is a channels_last format by comparing strides + if len(size) == 4: + # For channels_last: expected strides are [H*W*C, 1, W*C, C] for shape (N, H, W, C) + N, H, W, C = size[0], size[1], size[2], size[3] + expected_strides = (H * W * C, 1, W * C, C) + if strides == expected_strides: + permuted = original_tensor.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + else: + # If the size differs for other reasons, do not permute; just reshape if possible + try: + permuted = original_tensor.reshape(size) + except Exception: + raise ValueError( + "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." 
+ ) + elif len(size) == 5: + # For channels_last_3d: expected strides are [D*H*W*C, 1, H*W*C, W*C, C] for shape (N, D, H, W, C) + N, D, H, W, C = size[0], size[1], size[2], size[3], size[4] + expected_strides = (D * H * W * C, 1, H * W * C, W * C, C) + if strides == expected_strides: + permuted = original_tensor.permute(0, 2, 3, 4, 1) # (N, C, D, H, W) -> (N, D, H, W, C) + else: + # If the size differs for other reasons, do not permute; just reshape if possible + try: + permuted = original_tensor.reshape(size) + except Exception: + raise ValueError( + "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." + ) + else: + # For other dimensions, just try to reshape + try: + permuted = original_tensor.reshape(size) + except Exception: + raise ValueError( + "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." + ) + else: + # Size is the same - this is a stride-only change (like channels_last with preserved shape) + # We need to reorder the data to match the new stride pattern + if len(size) == 4: + # Check if this is channels_last format with preserved shape + N, C, H, W = size[0], size[1], size[2], size[3] + expected_strides = (C * H * W, 1, C * W, C) + if strides == expected_strides: + permuted = original_tensor + else: + permuted = original_tensor + elif len(size) == 5: + # Check if this is channels_last_3d format with preserved shape + N, C, D, H, W = size[0], size[1], size[2], size[3], size[4] + expected_strides = (C * D * H * W, 1, C * D * W, C * W, C) + if strides == expected_strides: + permuted = original_tensor + else: + permuted = original_tensor + else: + permuted = original_tensor + + # Copy the permuted data to the temporary tensor + temp_tensor.copy_(permuted) + + # Now allocate a new tensor on our symmetric heap + num_elements = math.prod(size) + heap_tensor = self._allocate(num_elements, original_tensor.dtype) + + # Reshape to the desired size + heap_tensor = 
heap_tensor.reshape(size) + + # Copy the data from the temporary tensor to our heap tensor + heap_tensor.copy_(temp_tensor) + + # Clean up the temporary tensor + del temp_tensor + + # Now we need to create a view with the correct strides + # We can't use as_strided directly on our heap tensor, but we can + # create a new tensor with the right strides and copy the data again + final_tensor = torch.as_strided(heap_tensor, size, strides) + + return final_tensor + + def zeros(self, *size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + """ + Returns a tensor filled with the scalar value 0, with the shape defined by the variable argument size. + The tensor is allocated on the Iris symmetric heap. + + Args: + *size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword Arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.set_default_dtype()). + layout (torch.layout, optional): the desired layout of returned Tensor. + Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. 
+ + Returns: + torch.Tensor: Zero-initialized tensor on the symmetric heap + + Example: + >>> import iris # or: from iris.experimental import iris_gluon + >>> ctx = iris.Iris(1 << 20) # or: ctx = iris_gluon.IrisGluon(1 << 20) + >>> tensor = ctx.zeros(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([0., 0., 0.], device='cuda:0') + """ + return create_zeros( + self, *size, out=out, dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + def ones(self, *size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + """ + Returns a tensor filled with the scalar value 1, with the shape defined by the variable argument size. + The tensor is allocated on the Iris symmetric heap. + + Args: + *size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword Arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.set_default_dtype()). + layout (torch.layout, optional): the desired layout of returned Tensor. + Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. 
+ + Returns: + torch.Tensor: Ones-initialized tensor on the symmetric heap + + Example: + >>> import iris # or: from iris.experimental import iris_gluon + >>> ctx = iris.Iris(1 << 20) # or: ctx = iris_gluon.IrisGluon(1 << 20) + >>> tensor = ctx.ones(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([1., 1., 1.], device='cuda:0') + """ + return create_ones(self, *size, out=out, dtype=dtype, layout=layout, device=device, requires_grad=requires_grad) + + def full(self, size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + """ + Creates a tensor of size size filled with fill_value. The tensor's dtype is inferred from fill_value. + The tensor is allocated on the Iris symmetric heap. + + Args: + size (int...): a list, tuple, or torch.Size of integers defining the shape of the output tensor. + fill_value (Scalar): the value to fill the output tensor with. + + Keyword Arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.set_default_dtype()). + layout (torch.layout, optional): the desired layout of returned Tensor. + Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. 
+ + Returns: + torch.Tensor: Tensor filled with fill_value + + Example: + >>> import iris # or: from iris.experimental import iris_gluon + >>> ctx = iris.Iris(1 << 20) # or: ctx = iris_gluon.IrisGluon(1 << 20) + >>> tensor = ctx.full((2, 3), 3.14) + >>> print(tensor.shape) # torch.Size([2, 3]) + >>> print(tensor[0]) # tensor([3.1400, 3.1400, 3.1400], device='cuda:0') + """ + return create_full( + self, size, fill_value, out=out, dtype=dtype, layout=layout, device=device, requires_grad=requires_grad + ) + + def zeros_like( + self, + input, + *, + dtype=None, + layout=None, + device=None, + requires_grad=False, + memory_format=torch.preserve_format, + ): + """ + Returns a tensor filled with the scalar value 0, with the same size as input, + allocated on the Iris symmetric heap. + + Args: + input (Tensor): the size of input will determine size of the output tensor. + + Keyword Arguments: + dtype (torch.dtype, optional): the desired data type of returned Tensor. + Default: if None, defaults to the dtype of input. + layout (torch.layout, optional): the desired layout of returned tensor. + Default: if None, defaults to the layout of input. Note: Iris tensors are always contiguous (strided). + device (torch.device, optional): the desired device of returned tensor. + Default: if None, defaults to the device of input. Must be compatible with this Iris instance. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. + memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. + Default: torch.preserve_format. 
+ + Returns: + torch.Tensor: Zero-initialized tensor with same shape as input + + Example: + >>> import iris # or: from iris.experimental import iris_gluon + >>> ctx = iris.Iris(1 << 20) # or: ctx = iris_gluon.IrisGluon(1 << 20) + >>> input_tensor = ctx.ones(2, 3) + >>> zeros_tensor = ctx.zeros_like(input_tensor) + >>> print(zeros_tensor.shape) # torch.Size([2, 3]) + """ + return create_zeros_like( + self, + input, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + memory_format=memory_format, + ) + + def arange( + self, start=0, end=None, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False + ): + """ + Returns a 1-D tensor of size ⌈(end - start) / step⌉ with values from the interval [start, end) + taken with common difference step beginning from start. The tensor is allocated on the symmetric heap. + + Note: When using floating-point dtypes (especially reduced precision types like bfloat16), + the results may be affected by floating-point rounding behavior. Some values in the sequence + might not be exactly representable in certain floating-point formats, which can lead to + repeated values or unexpected rounding. For precise sequences, it is recommended to use + integer dtypes instead of floating-point dtypes. + + Note that non-integer step is subject to floating point rounding errors when comparing + against end; to avoid inconsistency, we advise subtracting a small epsilon from end in such cases. + + Args: + start (Number, optional): the starting value for the set of points. Default: 0. + end (Number): the ending value for the set of points + step (Number, optional): the gap between each pair of adjacent points. Default: 1. + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. 
+ device (torch.device, optional): the desired device of returned tensor. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + + Returns: + torch.Tensor: 1-D tensor with evenly spaced values + + Example: + >>> import iris # or: from iris.experimental import iris_gluon + >>> ctx = iris.Iris(1 << 20) # or: ctx = iris_gluon.IrisGluon(1 << 20) + >>> tensor = ctx.arange(0, 10, 2) # [0, 2, 4, 6, 8] + >>> print(tensor.shape) # torch.Size([5]) + """ + self.debug(f"arange: start = {start}, end = {end}, step = {step}, dtype = {dtype}, device = {device}") + + # Handle the case where only one argument is provided (end) + if end is None: + end = start + start = 0 + + # Validate inputs + if step == 0: + raise ValueError("step must be non-zero") + + # Validate step direction consistency + if step > 0 and start >= end: + raise ValueError(f"Invalid range: start >= end with positive step (start={start}, end={end}, step={step})") + elif step < 0 and start <= end: + raise ValueError(f"Invalid range: start <= end with negative step (start={start}, end={end}, step={step})") + + # Calculate the number of elements + num_elements = math.ceil((end - start) / step) + + # Infer dtype if not provided + if dtype is None: + if any(isinstance(x, float) for x in [start, end, step]): + dtype = torch.get_default_dtype() + else: + dtype = torch.int64 + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self._throw_if_invalid_device(device) + + if out is not None: + self._throw_if_invalid_output_tensor(out, num_elements, dtype) + tensor = out + else: + tensor = self._allocate(num_elements=num_elements, dtype=dtype) + + target_device = tensor.device + arange_tensor = torch.arange(start, end, step, dtype=dtype, device=target_device) + + tensor[:] = arange_tensor + + tensor = self._apply_layout(tensor, layout) + + if requires_grad: + tensor.requires_grad_() + + 
return tensor + + def randn( + self, + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + ): + """ + Returns a tensor filled with random numbers from a normal distribution with mean 0 and variance 1. + The tensor is allocated on the Iris symmetric heap. + + Args: + *size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword Arguments: + generator (torch.Generator, optional): a pseudorandom number generator for sampling + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. + device (torch.device, optional): the desired device of returned tensor. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. Default: False. 
+ + Returns: + torch.Tensor: Tensor filled with random numbers from normal distribution + + Example: + >>> import iris # or: from iris.experimental import iris_gluon + >>> ctx = iris.Iris(1 << 20) # or: ctx = iris_gluon.IrisGluon(1 << 20) + >>> tensor = ctx.randn(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) + """ + self.debug( + f"randn: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" + ) + + # Use global default dtype if None is provided + if dtype is None: + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self._throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = self._parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self._throw_if_invalid_output_tensor(out, num_elements, dtype) + random_data = torch.randn(num_elements, generator=generator, dtype=dtype, device=device, layout=layout) + out.copy_(random_data) + tensor = out.view(size) + else: + tensor = self._allocate(num_elements=num_elements, dtype=dtype) + random_data = torch.randn(num_elements, generator=generator, dtype=dtype, device=device, layout=layout) + tensor.copy_(random_data) + tensor = tensor.reshape(size) + + # Apply the requested layout + tensor = self._apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def uniform(self, size, low=0.0, high=1.0, dtype=torch.float): + """ + Returns a tensor filled with random numbers from a uniform distribution, allocated on the Iris symmetric heap. + + Args: + size (int or tuple of ints): the size of the output tensor. + low (float, optional): the lower bound of the uniform distribution. Default: 0.0. + high (float, optional): the upper bound of the uniform distribution. Default: 1.0. 
+ dtype (torch.dtype, optional): the desired data type of returned tensor. Default: torch.float. + + Returns: + torch.Tensor: A tensor filled with random numbers from a uniform distribution. + + Example: + >>> import iris # or: from iris.experimental import iris_gluon + >>> ctx = iris.Iris(1 << 20) # or: ctx = iris_gluon.IrisGluon(1 << 20) + >>> tensor = ctx.uniform((2, 3), low=0.0, high=1.0) + >>> print(tensor.shape) # torch.Size([2, 3]) + """ + self.debug(f"uniform: size = {size}, low = {low}, high = {high}, dtype = {dtype}") + size, num_elements = self._parse_size(size) + tensor = self._allocate(num_elements=num_elements, dtype=dtype) + tensor.uniform_(low, high) + return tensor.reshape(size) + + def empty( + self, + *size, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + ): + """ + Returns a tensor filled with uninitialized data. The shape of the tensor is defined by the variable argument size. + The tensor is allocated on the Iris symmetric heap. + + Args: + *size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword Arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. + device (torch.device, optional): the desired device of returned tensor. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. Default: False. + memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. Default: torch.contiguous_format. 
+ + Returns: + torch.Tensor: Tensor with uninitialized data + + Example: + >>> import iris # or: from iris.experimental import iris_gluon + >>> ctx = iris.Iris(1 << 20) # or: ctx = iris_gluon.IrisGluon(1 << 20) + >>> tensor = ctx.empty(2, 3) + >>> print(tensor.shape) # torch.Size([2, 3]) + """ + self.debug( + f"empty: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" + ) + + # Use global default dtype if None is provided + if dtype is None: + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self._throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = self._parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self._throw_if_invalid_output_tensor(out, num_elements, dtype) + tensor = out.view(size) + else: + tensor = self._allocate(num_elements=num_elements, dtype=dtype) + tensor = tensor.reshape(size) + + # Apply the requested memory format + tensor = self._apply_memory_format(tensor, size, memory_format) + + # Apply the requested layout + tensor = self._apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def randint( + self, *args, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False + ): + """ + Returns a tensor filled with random integers generated uniformly between low (inclusive) and high (exclusive). + The shape of the tensor is defined by the variable argument size. The tensor is allocated on the Iris symmetric heap. + + Args: + low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. + high (int): One above the highest integer to be drawn from the distribution. + size (tuple): a tuple defining the shape of the output tensor. 
+ + Keyword Arguments: + generator (torch.Generator, optional): a pseudorandom number generator for sampling. + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): if None, this function returns a tensor with dtype torch.int64. + layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. + device (torch.device, optional): the desired device of returned tensor. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + + Returns: + torch.Tensor: Tensor filled with random integers + + Example: + >>> import iris # or: from iris.experimental import iris_gluon + >>> ctx = iris.Iris(1 << 20) # or: ctx = iris_gluon.IrisGluon(1 << 20) + >>> tensor = ctx.randint(0, 10, (2, 3)) # Random integers [0, 10) + >>> print(tensor.shape) # torch.Size([2, 3]) + """ + self.debug(f"randint: args = {args}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") + + # Parse arguments to determine low, high, and size + if len(args) == 2: + high, size = args + low = 0 + elif len(args) == 3: + low, high, size = args + else: + raise ValueError(f"randint expects 2 or 3 positional arguments, got {len(args)}") + + # Use default dtype if None is provided + if dtype is None: + dtype = torch.int64 + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self._throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = self._parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self._throw_if_invalid_output_tensor(out, num_elements, dtype) + tensor = out.view(size) + else: + tensor = self._allocate(num_elements=num_elements, dtype=dtype) + tensor = tensor.reshape(size) + + # Generate random integers using PyTorch's randint + target_device = device if device is not None else self.device + + # Handle 
generator parameter + if generator is not None: + torch.randint(low, high, size, generator=generator, out=tensor, dtype=dtype, device=target_device) + else: + torch.randint(low, high, size, out=tensor, dtype=dtype, device=target_device) + + # Apply the requested layout + tensor = self._apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def linspace(self, start, end, steps, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + """ + Creates a one-dimensional tensor of size steps whose values are evenly spaced from start to end, inclusive. + The tensor is allocated on the Iris symmetric heap. + + Args: + start (float or Tensor): the starting value for the set of points. + end (float or Tensor): the ending value for the set of points. + steps (int): size of the constructed tensor. + + Keyword Arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the data type to perform the computation in. + layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. + device (torch.device, optional): the desired device of returned tensor. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. 
+ + Returns: + torch.Tensor: 1-D tensor with evenly spaced values from start to end + + Example: + >>> import iris # or: from iris.experimental import iris_gluon + >>> ctx = iris.Iris(1 << 20) # or: ctx = iris_gluon.IrisGluon(1 << 20) + >>> tensor = ctx.linspace(0, 10, 5) # [0, 2.5, 5, 7.5, 10] + >>> print(tensor) + """ + self.debug( + f"linspace: start = {start}, end = {end}, steps = {steps}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" + ) + + # Use global default dtype if None is provided + if dtype is None: + # Check if start or end are complex numbers + start_is_complex = isinstance(start, complex) or (hasattr(start, "dtype") and torch.is_complex(start)) + end_is_complex = isinstance(end, complex) or (hasattr(end, "dtype") and torch.is_complex(end)) + + if start_is_complex or end_is_complex: + dtype = torch.complex64 if torch.get_default_dtype() == torch.float32 else torch.complex128 + else: + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self._throw_if_invalid_device(device) + + # Parse steps and extract the integer value + if isinstance(steps, (tuple, list)): + if len(steps) == 1: + steps_int = steps[0] + if isinstance(steps_int, (tuple, list)): + steps_int = steps_int[0] + else: + size, num_elements = self._parse_size(steps) + steps_int = num_elements + else: + steps_int = steps + + steps_int = int(steps_int) + size = (steps_int,) + num_elements = steps_int + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self._throw_if_invalid_output_tensor(out, num_elements, dtype) + tensor = out.view(size) + else: + tensor = self._allocate(num_elements=num_elements, dtype=dtype) + tensor = tensor.reshape(size) + + # Generate linspace using PyTorch's linspace + target_device = device if device is not None else self.device + torch.linspace(start, end, steps_int, out=tensor, dtype=dtype, 
device=target_device) + + # Apply the requested layout + tensor = self._apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def rand( + self, + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + ): + """ + Returns a tensor filled with random numbers from a uniform distribution on the interval [0, 1). + The tensor is allocated on the Iris symmetric heap. + + Args: + *size (int...): a sequence of integers defining the shape of the output tensor. + + Keyword Arguments: + generator (torch.Generator, optional): a pseudorandom number generator for sampling. + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. + device (torch.device, optional): the desired device of returned tensor. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. + pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. Default: False. 
+ + Returns: + torch.Tensor: Tensor filled with random values in [0, 1) + + Example: + >>> import iris # or: from iris.experimental import iris_gluon + >>> ctx = iris.Iris(1 << 20) # or: ctx = iris_gluon.IrisGluon(1 << 20) + >>> tensor = ctx.rand(2, 3) # Random values in [0, 1) + >>> print(tensor.shape) # torch.Size([2, 3]) + """ + self.debug( + f"rand: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" + ) + + # Use global default dtype if None is provided + if dtype is None: + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = self.device + + # Validate device compatibility with Iris + self._throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = self._parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + self._throw_if_invalid_output_tensor(out, num_elements, dtype) + tensor = out.view(size) + else: + tensor = self._allocate(num_elements=num_elements, dtype=dtype) + tensor = tensor.reshape(size) + + # Generate random numbers using PyTorch's rand + if generator is not None: + torch.rand(size, generator=generator, out=tensor, dtype=dtype, device=device) + else: + torch.rand(size, out=tensor, dtype=dtype, device=device) + + # Apply the requested layout + tensor = self._apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + def get_device_context(self): + """ + Get the device context tensor for kernels. + + Returns a tensor encoding: `[cur_rank, num_ranks, heap_base_0, heap_base_1, ...]` + + This method is useful for both Gluon kernels and future Triton backends that + utilize aggregates for passing context information. 
+ + Returns: + torch.Tensor: Encoded context data as int64 tensor on device + + Example: + >>> import iris # or: from iris.experimental import iris_gluon + >>> ctx = iris.Iris(1 << 20) # or: ctx = iris_gluon.IrisGluon(1 << 20) + >>> context_tensor = ctx.get_device_context() + >>> + >>> @gluon.jit + >>> def kernel(IrisDeviceCtx: gl.constexpr, context_tensor): + >>> ctx = IrisDeviceCtx.initialize(context_tensor) + >>> data = ctx.load(buffer, 1) + """ + # Convert heap_bases to a list for concatenation + heap_bases_list = self.heap_bases.tolist() + + # Create context tensor: [cur_rank, num_ranks, heap_base_0, heap_base_1, ...] + context_data = [self.cur_rank, self.num_ranks] + heap_bases_list + context_tensor = torch.tensor(context_data, dtype=torch.int64, device=self.device) + + return context_tensor + + def get_backend(self): + """ + Legacy method for backward compatibility. + Use get_device_context() for kernel context. + + Returns: + torch.Tensor: Device context tensor + """ + return self.get_device_context() + + +class CCLBase: + """ + Base Collective Communication Library (CCL) interface. + + Provides collective operations that can be called as methods on Iris instances. + This base class contains common CCL operations shared by both Triton and Gluon backends. + """ + + def __init__(self, iris_instance): + """ + Initialize CCL with a reference to the parent Iris instance. + + Args: + iris_instance: The parent Iris instance (either Iris or IrisGluon) + """ + self._iris = iris_instance + + def all_to_all(self, output_tensor, input_tensor, config=None, async_op=False): + """ + All-to-all collective operation. + + Each rank sends a tensor chunk to each other rank and receives + a tensor chunk from each other rank. Input/output tensors should have + shape (M, N * world_size) where each chunk of N columns corresponds to one rank. 
+ + Args: + output_tensor: Output tensor of shape (M, N * world_size) + input_tensor: Input tensor of shape (M, N * world_size) + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + Set config.use_gluon=True to use Gluon implementation with traffic shaping. + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. + + Example: + >>> shmem = iris.iris() + >>> shmem.ccl.all_to_all(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(block_size_m=128, block_size_n=32) + >>> shmem.ccl.all_to_all(output_tensor, input_tensor, config=config) + """ + from iris.ccl.all_to_all import all_to_all as _all_to_all + + _all_to_all(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) + + def all_gather(self, output_tensor, input_tensor, config=None, async_op=False): + """ + All-gather collective operation. + + Each rank sends its input tensor to all ranks, and all ranks receive + and concatenate all input tensors along dimension 0 (rows), matching + torch.distributed.all_gather_into_tensor behavior. + + Args: + output_tensor: Output tensor of shape (world_size * M, N) - will contain concatenated inputs + input_tensor: Input tensor of shape (M, N) - local rank's data to send + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. 
+ + Example: + >>> shmem = iris.iris() + >>> # Input: (M, N), Output: (world_size * M, N) + >>> shmem.ccl.all_gather(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(block_size_m=128, block_size_n=32) + >>> shmem.ccl.all_gather(output_tensor, input_tensor, config=config) + """ + from iris.ccl.all_gather import all_gather as _all_gather + + _all_gather(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) + + def reduce_scatter(self, output_tensor, input_tensor, config=None, async_op=False): + """ + Reduce-scatter collective operation. + + Each rank reduces its assigned tiles from all ranks' inputs and stores + the result only to its own output tensor. This is similar to all-reduce + but without broadcasting the result to all ranks. + + Args: + output_tensor: Output tensor of shape (M, N) - will contain reduced tiles for this rank + input_tensor: Input tensor of shape (M, N) - local rank's partial data + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + Only supports reduce_scatter_variant="two_shot". + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. + + Example: + >>> shmem = iris.iris() + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(reduce_scatter_variant="two_shot", all_reduce_distribution=1) + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor, config=config) + """ + from iris.ccl.reduce_scatter import reduce_scatter as _reduce_scatter + + _reduce_scatter(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) diff --git a/iris/_tensor_ops.py b/iris/_tensor_ops.py new file mode 100644 index 00000000..e0db0b42 --- /dev/null +++ b/iris/_tensor_ops.py @@ -0,0 +1,286 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. 
All rights reserved. + +""" +Common tensor operations for Iris implementations. + +This module contains shared tensor creation methods used by both +Triton and Gluon backends. +""" + +import torch + + +def create_zeros(iris_instance, *size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + """ + Returns a tensor filled with the scalar value 0, with the shape defined by the variable argument size. + The tensor is allocated on the Iris symmetric heap. + + Args: + iris_instance: The Iris instance (Iris or IrisGluon) + *size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword Arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.set_default_dtype()). + layout (torch.layout, optional): the desired layout of returned Tensor. + Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. 
+ + Returns: + torch.Tensor: Zero-initialized tensor + """ + iris_instance.debug(f"zeros: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") + + # Use global default dtype if None is provided + if dtype is None: + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = iris_instance.device + + # Validate device compatibility with Iris + iris_instance._throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = iris_instance._parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + iris_instance._throw_if_invalid_output_tensor(out, num_elements, dtype) + # Fill with zeros + out.zero_() + # Create a reshaped view of the out tensor + tensor = out.view(size) + else: + tensor = iris_instance._allocate(num_elements=num_elements, dtype=dtype) + # Fill with zeros + tensor.zero_() + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Apply the requested layout + tensor = iris_instance._apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + +def create_ones(iris_instance, *size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): + """ + Returns a tensor filled with the scalar value 1, with the shape defined by the variable argument size. + The tensor is allocated on the Iris symmetric heap. + + Args: + iris_instance: The Iris instance (Iris or IrisGluon) + *size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword Arguments: + out (Tensor, optional): the output tensor. + dtype (torch.dtype, optional): the desired data type of returned tensor. + Default: if None, uses a global default (see torch.set_default_dtype()). 
+ layout (torch.layout, optional): the desired layout of returned Tensor. + Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. + device (torch.device, optional): the desired device of returned tensor. + Default: if None, uses the current device for the default tensor type. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. + + Returns: + torch.Tensor: Ones-initialized tensor + """ + iris_instance.debug(f"ones: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") + + # Use global default dtype if None is provided + if dtype is None: + dtype = torch.get_default_dtype() + + # Use current device if none specified + if device is None: + device = iris_instance.device + + # Validate device compatibility with Iris + iris_instance._throw_if_invalid_device(device) + + # Parse size and calculate number of elements + size, num_elements = iris_instance._parse_size(size) + + # If out is provided, use it; otherwise allocate new tensor + if out is not None: + iris_instance._throw_if_invalid_output_tensor(out, num_elements, dtype) + # Fill with ones + out.fill_(1) + # Create a reshaped view of the out tensor + tensor = out.view(size) + else: + tensor = iris_instance._allocate(num_elements=num_elements, dtype=dtype) + # Fill with ones + tensor.fill_(1) + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Apply the requested layout + tensor = iris_instance._apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + +def create_full( + iris_instance, size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False +): + """ + Creates a tensor of size size filled with fill_value. The tensor's dtype is inferred from fill_value. + The tensor is allocated on the Iris symmetric heap. 
+
+    Args:
+        iris_instance: The Iris instance (Iris or IrisGluon)
+        size (int...): a list, tuple, or torch.Size of integers defining the shape of the output tensor.
+        fill_value (Scalar): the value to fill the output tensor with.
+
+    Keyword Arguments:
+        out (Tensor, optional): the output tensor.
+        dtype (torch.dtype, optional): the desired data type of returned tensor.
+            Default: if None, inferred from fill_value (float -> global default dtype, int -> torch.int64).
+        layout (torch.layout, optional): the desired layout of returned Tensor.
+            Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter.
+        device (torch.device, optional): the desired device of returned tensor.
+            Default: if None, uses the current device for the default tensor type.
+        requires_grad (bool, optional): If autograd should record operations on the returned tensor.
+            Default: False.
+
+    Returns:
+        torch.Tensor: Tensor filled with fill_value
+    """
+    iris_instance.debug(
+        f"full: size = {size}, fill_value = {fill_value}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}"
+    )
+
+    # Infer dtype from fill_value if not provided
+    if dtype is None:
+        if isinstance(fill_value, (int, float)):
+            if isinstance(fill_value, float):
+                dtype = torch.get_default_dtype()
+            else:
+                dtype = torch.int64
+        else:
+            # Other fill_value types fall back to the global default dtype (their own dtype is not inspected)
+            dtype = torch.get_default_dtype()
+
+    # Use current device if none specified
+    if device is None:
+        device = iris_instance.device
+
+    # Validate device compatibility with Iris
+    iris_instance._throw_if_invalid_device(device)
+
+    # Parse size and calculate number of elements
+    size, num_elements = iris_instance._parse_size(size)
+
+    # If out is provided, use it; otherwise allocate new tensor
+    if out is not None:
+        iris_instance._throw_if_invalid_output_tensor(out, num_elements, dtype)
+        # Fill with the specified value
+        out.fill_(fill_value)
+        # Create a reshaped view of the out tensor
+        tensor = out.view(size)
+    else:
+ tensor = iris_instance._allocate(num_elements=num_elements, dtype=dtype) + # Fill with the specified value + tensor.fill_(fill_value) + # Reshape to the desired size + tensor = tensor.reshape(size) + + # Apply the requested layout + tensor = iris_instance._apply_layout(tensor, layout) + + # Set requires_grad if specified + if requires_grad: + tensor.requires_grad_() + + return tensor + + +def create_zeros_like( + iris_instance, + input, + *, + dtype=None, + layout=None, + device=None, + requires_grad=False, + memory_format=torch.preserve_format, +): + """ + Returns a tensor filled with the scalar value 0, with the same size as input, + allocated on the Iris symmetric heap. + + Args: + iris_instance: The Iris instance (Iris or IrisGluon) + input (Tensor): the size of input will determine size of the output tensor. + + Keyword Arguments: + dtype (torch.dtype, optional): the desired data type of returned Tensor. + Default: if None, defaults to the dtype of input. + layout (torch.layout, optional): the desired layout of returned tensor. + Default: if None, defaults to the layout of input. Note: Iris tensors are always contiguous (strided). + device (torch.device, optional): the desired device of returned tensor. + Default: if None, defaults to the device of input. Must be compatible with this Iris instance. + requires_grad (bool, optional): If autograd should record operations on the returned tensor. + Default: False. + memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. + Default: torch.preserve_format. 
+ + Returns: + torch.Tensor: Zero-initialized tensor with same shape as input + """ + iris_instance.debug( + f"zeros_like: input_shape = {input.shape}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" + ) + + # Use input's properties as defaults if not specified + if dtype is None: + dtype = input.dtype + if layout is None: + layout = input.layout + if device is None: + device = input.device + + # Validate device compatibility with Iris + iris_instance._throw_if_invalid_device(device) + + # Get the size from input tensor + size = input.size() + num_elements = input.numel() + + # Allocate new tensor with the same size + new_tensor = iris_instance._allocate(num_elements, dtype) + new_tensor.zero_() + + # Reshape to match input size + new_tensor = new_tensor.reshape(size) + + # Apply the requested memory format + new_tensor = iris_instance._apply_memory_format(new_tensor, size, memory_format, input) + + # Apply the requested layout + new_tensor = iris_instance._apply_layout(new_tensor, layout) + + # Set requires_grad if specified + if requires_grad: + new_tensor.requires_grad_() + + return new_tensor diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index 3bf3b1fc..ea9bf702 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -39,28 +39,7 @@ import triton.language as tl -from iris._distributed_helpers import ( - init_distributed, - distributed_allgather, - distributed_barrier, - distributed_broadcast_scalar, - distributed_broadcast_tensor, -) -from iris.hip import ( - set_device, - get_cu_count, - count_devices, - get_ipc_handle, - open_ipc_handle, -) -import numpy as np -import math -import torch -import ctypes -import logging - -# Import logging functionality from the separate logging module -from ..logging import logger +from iris._common import IrisBase, CCLBase @aggregate @@ -469,7 +448,7 @@ def atomic_max(self, pointer, val, to_rank, mask=None, sem=None, scope=None): return 
gl.atomic_max(translated_ptr, val, mask=mask, sem=sem, scope=scope) -class IrisGluon: +class IrisGluon(IrisBase): """ Gluon-based Iris class for multi-GPU communication and memory management. @@ -486,682 +465,11 @@ class IrisGluon: """ def __init__(self, heap_size=1 << 30): - # Initialize (same as original Iris) - comm, cur_rank, num_ranks = init_distributed() - num_gpus = count_devices() - - gpu_id = cur_rank % num_gpus - set_device(gpu_id) - - self.comm = comm - self.num_ranks = num_ranks - self.cur_rank = cur_rank - self.gpu_id = gpu_id - self.heap_size = heap_size - self.heap_offset = 0 - self.alignment = 1024 - self.device = f"cuda:{gpu_id}" - self.memory_pool = torch.empty(heap_size, device=self.device, dtype=torch.int8) - - heap_base = self.memory_pool.data_ptr() - heap_base_ptr = ctypes.c_void_p(heap_base) - - heap_bases = np.zeros(num_ranks, dtype=np.uint64) - heap_bases[cur_rank] = heap_base - ipc_handles = np.zeros((num_ranks, 64), dtype=np.uint8) - ipc_handle = get_ipc_handle(heap_base_ptr, cur_rank) - - distributed_barrier() - - all_ipc_handles = distributed_allgather(np.frombuffer(ipc_handle, dtype=np.uint8)) - all_heap_bases = distributed_allgather(np.array([heap_bases[cur_rank]], dtype=np.uint64)) - - distributed_barrier() - - ipc_heap_bases = np.zeros(num_ranks, dtype=np.uintp) - for rank in range(num_ranks): - if rank != cur_rank: - handle = open_ipc_handle(all_ipc_handles[rank], cur_rank) - ipc_heap_bases[rank] = int(handle) - else: - ipc_heap_bases[rank] = heap_bases[rank] - - for i in range(num_ranks): - self.debug(f"GPU {i}: Heap base {hex(int(ipc_heap_bases[i]))}") - - distributed_barrier() - self.heap_bases = torch.from_numpy(ipc_heap_bases).to(device=self.device, dtype=torch.uint64) - - distributed_barrier() + # Initialize base class + super().__init__(heap_size) # Initialize CCL interface - self.ccl = self.CCL(self) - - class CCL: - """ - Collective Communication Library (CCL) interface for IrisGluon. 
- - Provides collective operations that can be called as methods on the IrisGluon instance. - Example usage: - >>> shmem = iris_gluon.iris() - >>> shmem.ccl.all_to_all(output_tensor, input_tensor) - """ - - def __init__(self, iris_instance): - """ - Initialize CCL with a reference to the parent IrisGluon instance. - - Args: - iris_instance: The parent IrisGluon instance - """ - self._iris = iris_instance - - def all_to_all(self, output_tensor, input_tensor, config=None, async_op=False): - """ - All-to-all collective operation. - - Each rank sends a tensor chunk to each other rank and receives - a tensor chunk from each other rank. Input/output tensors should have - shape (M, N * world_size) where each chunk of N columns corresponds to one rank. - - Args: - output_tensor: Output tensor of shape (M, N * world_size) - input_tensor: Input tensor of shape (M, N * world_size) - config: Config instance with kernel parameters (default: None). - If None, uses default Config values. - Set config.use_gluon=True to use Gluon implementation with traffic shaping. - async_op: If False, performs a barrier at the end. If True, returns immediately. - Default: False. - - Example: - >>> shmem = iris_gluon.iris() - >>> shmem.ccl.all_to_all(output_tensor, input_tensor) - - >>> # Custom configuration with Gluon traffic shaping - >>> from iris.ccl import Config - >>> config = Config(use_gluon=True, block_size_m=128, block_size_n=32) - >>> shmem.ccl.all_to_all(output_tensor, input_tensor, config=config) - """ - from iris.ccl.all_to_all import all_to_all as _all_to_all - - _all_to_all(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) - - def all_gather(self, output_tensor, input_tensor, config=None, async_op=False): - """ - All-gather collective operation. - - Each rank sends its input tensor to all ranks, and all ranks receive - and concatenate all input tensors along dimension 0 (rows), matching - torch.distributed.all_gather_into_tensor behavior. 
- - Args: - output_tensor: Output tensor of shape (world_size * M, N) - will contain concatenated inputs - input_tensor: Input tensor of shape (M, N) - local rank's data to send - config: Config instance with kernel parameters (default: None). - If None, uses default Config values. - async_op: If False, performs a barrier at the end. If True, returns immediately. - Default: False. - - Example: - >>> shmem = iris_gluon.iris() - >>> # Input: (M, N), Output: (world_size * M, N) - >>> shmem.ccl.all_gather(output_tensor, input_tensor) - - >>> # Custom configuration - >>> from iris.ccl import Config - >>> config = Config(block_size_m=128, block_size_n=32) - >>> shmem.ccl.all_gather(output_tensor, input_tensor, config=config) - """ - from iris.ccl.all_gather import all_gather as _all_gather - - _all_gather(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) - - def reduce_scatter(self, output_tensor, input_tensor, config=None, async_op=False): - """ - Reduce-scatter collective operation. - - Each rank reduces its assigned tiles from all ranks' inputs and stores - the result only to its own output tensor. This is similar to all-reduce - but without broadcasting the result to all ranks. - - Args: - output_tensor: Output tensor of shape (M, N) - will contain reduced tiles for this rank - input_tensor: Input tensor of shape (M, N) - local rank's partial data - config: Config instance with kernel parameters (default: None). - If None, uses default Config values. - Only supports reduce_scatter_variant="two_shot". - async_op: If False, performs a barrier at the end. If True, returns immediately. - Default: False. 
- - Example: - >>> shmem = iris_gluon.iris() - >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor) - - >>> # Custom configuration - >>> from iris.ccl import Config - >>> config = Config(reduce_scatter_variant="two_shot", all_reduce_distribution=1) - >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor, config=config) - """ - from iris.ccl.reduce_scatter import reduce_scatter as _reduce_scatter - - _reduce_scatter(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) - - def _log_with_rank(self, level, message): - """Helper method to log with rank information injected into the record.""" - extra = {"iris_rank": self.cur_rank, "iris_num_ranks": self.num_ranks} - logger.log(level, message, extra=extra) - - def debug(self, message): - """Log a debug message with rank information.""" - self._log_with_rank(logging.DEBUG, message) - - def info(self, message): - """Log an info message with rank information.""" - self._log_with_rank(logging.INFO, message) - - def warning(self, message): - """Log a warning message with rank information.""" - self._log_with_rank(logging.WARNING, message) - - def error(self, message): - """Log an error message with rank information.""" - self._log_with_rank(logging.ERROR, message) - - def get_device_context(self): - """ - Get the device context tensor for Gluon kernels. - - Returns a tensor encoding: `[cur_rank, num_ranks, heap_base_0, heap_base_1, ...]` - - Returns: - torch.Tensor: Encoded context data as int64 tensor on device - - Example: - >>> ctx = iris_gluon.iris() - >>> context_tensor = ctx.get_device_context() - >>> - >>> @gluon.jit - >>> def kernel(IrisDeviceCtx: gl.constexpr, context_tensor): - >>> ctx = IrisDeviceCtx.initialize(context_tensor) - >>> data = ctx.load(buffer, 1) - """ - # Convert heap_bases to a list for concatenation - heap_bases_list = self.heap_bases.tolist() - - # Create context tensor: [cur_rank, num_ranks, heap_base_0, heap_base_1, ...] 
- context_data = [self.cur_rank, self.num_ranks] + heap_bases_list - context_tensor = torch.tensor(context_data, dtype=torch.int64, device=self.device) - - return context_tensor - - def get_backend(self): - """ - Legacy method for backward compatibility. - Use get_device_context() for Gluon kernels. - - Returns: - torch.Tensor: Device context tensor - """ - return self.get_device_context() - - def get_heap_bases(self): - """ - Return the tensor of symmetric heap base addresses for all ranks. - - Returns: - torch.Tensor: A 1D tensor of uint64 heap base addresses - """ - return self.heap_bases - - def barrier(self): - """ - Synchronize all ranks using a distributed barrier. - """ - distributed_barrier() - - def get_device(self): - """ - Get the underlying device where the Iris symmetric heap resides. - - Returns: - torch.device: The CUDA device of Iris-managed memory - """ - return self.memory_pool.device - - def get_cu_count(self): - """ - Get the number of compute units (CUs) for the current GPU. - - Returns: - int: Number of compute units on this rank's GPU - """ - return get_cu_count(self.gpu_id) - - def get_rank(self): - """ - Get the current rank ID. - - Returns: - int: The current rank ID - """ - return self.cur_rank - - def get_num_ranks(self): - """ - Get the total number of ranks. - - Returns: - int: The total number of ranks in the distributed system - """ - return self.num_ranks - - def broadcast(self, data, src_rank=0): - """ - Broadcast data from source rank to all ranks. 
- - Args: - data: Data to broadcast (scalar or tensor) - src_rank: Source rank for broadcast (default: 0) - - Returns: - The broadcasted data - """ - # Check if the value on src_rank is a tensor or array-like - if self.cur_rank == src_rank and data is not None: - # Explicitly exclude strings and non-numeric types - if isinstance(data, (str, dict, bool)): - is_tensor = False - elif isinstance(data, torch.Tensor): - is_tensor = True - elif isinstance(data, np.ndarray): - is_tensor = True - elif isinstance(data, (list, tuple)): - # Try to convert list/tuple to tensor to check if it's numeric - try: - torch.as_tensor(data) - is_tensor = True - except (TypeError, ValueError): - is_tensor = False - else: - # For other types, try to convert and check - try: - test_array = np.asarray(data) - # Check if it's a numeric dtype that torch can handle - if np.issubdtype(test_array.dtype, np.number): - torch.as_tensor(test_array) - is_tensor = True - else: - is_tensor = False - except (TypeError, ValueError): - is_tensor = False - else: - is_tensor = False - - # Broadcast the type decision to all ranks - is_tensor = distributed_broadcast_scalar(is_tensor, src_rank) - - if is_tensor: - return distributed_broadcast_tensor(data, root=src_rank) - else: - return distributed_broadcast_scalar(data, src_rank) - - def __allocate(self, num_elements, dtype): - """Internal method to allocate memory from the symmetric heap.""" - self.debug(f"allocate: num_elements = {num_elements}, dtype = {dtype}") - - element_size = torch.tensor([], dtype=dtype).element_size() - size_in_bytes = num_elements * element_size - aligned_size = math.ceil(size_in_bytes / self.alignment) * self.alignment - - if self.heap_offset + aligned_size > self.heap_size: - raise MemoryError("Heap out of memory") - - start = self.heap_offset - self.heap_offset += aligned_size - - sub_buffer = self.memory_pool[start : start + size_in_bytes].view(dtype) - return sub_buffer.reshape((num_elements,)) - - def __parse_size(self, 
size): - """Parse size parameter and calculate number of elements.""" - # Handle nested tuples/lists by flattening them recursively - while len(size) == 1 and isinstance(size[0], (tuple, list)): - size = size[0] - num_elements = math.prod(size) - return size, num_elements - - def __throw_if_invalid_device(self, device): - """Check if the requested device is compatible with this Iris instance.""" - if not self.__is_valid_device(device): - raise ValueError( - f"Requested device {device} does not match Iris device {self.get_device()}. " - f"All Iris tensors must be on the same device as the Iris symmetric heap." - ) - - def __is_valid_device(self, device) -> bool: - """Check if the requested device is compatible with this Iris instance.""" - if device is None: - return True # None means use default device - - # Convert device strings to torch.device objects for proper comparison - requested_device = torch.device(device) if isinstance(device, str) else device - iris_device = self.get_device() - - # Check if both are CUDA devices - if requested_device.type == "cuda" and iris_device.type == "cuda": - # Check if index matches or if requested is "cuda" (any index) - if requested_device.index is None: - return True - else: - return requested_device.index == iris_device.index - - # For non-CUDA devices, always return False - return False - - def __apply_layout(self, tensor, layout): - """Apply the requested layout to the tensor.""" - if layout == torch.strided: - return tensor - else: - raise ValueError(f"Unsupported layout: {layout}") - - def zeros( - self, - *size, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - ): - """ - Create a tensor filled with zeros on the symmetric heap. 
- - Args: - size: Shape of the tensor - dtype: Data type (default: torch.float32) - device: Device (must match Iris device) - layout: Layout (default: torch.strided) - requires_grad: Whether to track gradients - - Returns: - torch.Tensor: Zero-initialized tensor on the symmetric heap - """ - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # Allocate memory from symmetric heap - tensor = self.__allocate(num_elements, dtype) - - # Zero-initialize - tensor.zero_() - - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - - def ones( - self, - *size, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - ): - """ - Returns a tensor filled with the scalar value 1, with the shape defined by the variable argument size. - The tensor is allocated on the Iris symmetric heap. - - Args: - *size (int...): a sequence of integers defining the shape of the output tensor. - Can be a variable number of arguments or a collection like a list or tuple. - - Keyword Arguments: - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): the desired data type of returned tensor. - Default: if None, uses a global default (see torch.set_default_dtype()). - layout (torch.layout, optional): the desired layout of returned Tensor. - Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. - device (torch.device, optional): the desired device of returned tensor. 
- Default: if None, uses the current device for the default tensor type. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. - Default: False. - - Example: - >>> ctx = iris_gluon.iris(1 << 20) - >>> tensor = ctx.ones(2, 3) - >>> print(tensor.shape) # torch.Size([2, 3]) - >>> print(tensor[0]) # tensor([1., 1., 1.], device='cuda:0') - """ - self.debug(f"ones: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") - - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Fill with ones - out.fill_(1) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Fill with ones - tensor.fill_(1) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - - def full( - self, - size, - fill_value, - *, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - ): - """ - Creates a tensor of size size filled with fill_value. The tensor's dtype is inferred from fill_value. - The tensor is allocated on the Iris symmetric heap. - - Args: - size (int...): a list, tuple, or torch.Size of integers defining the shape of the output tensor. - fill_value (Scalar): the value to fill the output tensor with. 
- - Keyword Arguments: - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): the desired data type of returned tensor. - Default: if None, uses a global default (see torch.set_default_dtype()). - layout (torch.layout, optional): the desired layout of returned Tensor. - Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. - device (torch.device, optional): the desired device of returned tensor. - Default: if None, uses the current device for the default tensor type. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. - Default: False. - - Example: - >>> ctx = iris_gluon.iris(1 << 20) - >>> tensor = ctx.full((2, 3), 3.14) - >>> print(tensor.shape) # torch.Size([2, 3]) - >>> print(tensor[0]) # tensor([3.1400, 3.1400, 3.1400], device='cuda:0') - """ - self.debug( - f"full: size = {size}, fill_value = {fill_value}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" - ) - - # Infer dtype from fill_value if not provided - if dtype is None: - if isinstance(fill_value, (int, float)): - if isinstance(fill_value, float): - dtype = torch.get_default_dtype() - else: - dtype = torch.int64 - else: - # For other types (like tensors), use their dtype - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Fill with the specified value - out.fill_(fill_value) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Fill with the specified value - 
tensor.fill_(fill_value) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - - def zeros_like( - self, - input, - *, - dtype=None, - layout=None, - device=None, - requires_grad=False, - memory_format=torch.preserve_format, - ): - """ - Returns a tensor filled with the scalar value 0, with the same size as input, allocated on the Iris symmetric heap. - - Args: - input (Tensor): the size of input will determine size of the output tensor. - - Keyword Arguments: - dtype (torch.dtype, optional): the desired data type of returned Tensor. - Default: if None, defaults to the dtype of input. - layout (torch.layout, optional): the desired layout of returned tensor. - Default: if None, defaults to the layout of input. Note: Iris tensors are always contiguous (strided). - device (torch.device, optional): the desired device of returned tensor. - Default: if None, defaults to the device of input. Must be compatible with this Iris instance. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. - Default: False. - memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. - Default: torch.preserve_format. 
- - Example: - >>> ctx = iris_gluon.iris(1 << 20) - >>> input_tensor = ctx.ones(2, 3) - >>> zeros_tensor = ctx.zeros_like(input_tensor) - >>> print(zeros_tensor.shape) # torch.Size([2, 3]) - """ - self.debug( - f"zeros_like: input_shape = {input.shape}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" - ) - - # Use input's properties as defaults if not specified - if dtype is None: - dtype = input.dtype - if layout is None: - layout = input.layout - if device is None: - device = input.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Get the size from input tensor - size = input.size() - num_elements = input.numel() - - # Allocate new tensor with the same size - new_tensor = self.__allocate(num_elements, dtype) - new_tensor.zero_() - - # Reshape to match input size - new_tensor = new_tensor.reshape(size) - - # Apply the requested layout - new_tensor = self.__apply_layout(new_tensor, layout) - - # Set requires_grad if specified - if requires_grad: - new_tensor.requires_grad_() - - return new_tensor - - def __throw_if_invalid_output_tensor(self, out, num_elements, dtype): - """Check if the output tensor is valid.""" - if out.numel() != num_elements: - raise RuntimeError(f"The output tensor has {out.numel()} elements, but {num_elements} are required") - - if out.dtype != dtype: - raise RuntimeError(f"The output tensor has dtype {out.dtype}, but {dtype} is required") - - if not self.__on_symmetric_heap(out): - raise RuntimeError("The output tensor is not on the symmetric heap") - - def __on_symmetric_heap(self, tensor): - """Check if tensor is allocated on the symmetric heap.""" - heap_start = self.memory_pool.data_ptr() - heap_end = heap_start + self.heap_size - tensor_ptr = tensor.data_ptr() - return heap_start <= tensor_ptr < heap_end + self.ccl = CCLBase(self) def iris(heap_size=1 << 30): diff --git a/iris/iris.py b/iris/iris.py index 156950e4..223aacd3 100644 --- a/iris/iris.py +++ 
b/iris/iris.py @@ -25,32 +25,10 @@ import triton import triton.language as tl -from iris._distributed_helpers import ( - init_distributed, - distributed_allgather, - distributed_barrier, - distributed_broadcast_scalar, - distributed_broadcast_tensor, -) -from iris.hip import ( - set_device, - get_cu_count, - count_devices, - get_ipc_handle, - open_ipc_handle, - get_ipc_handle_size, -) -import numpy as np -import math -import torch -import ctypes -import logging - -# Import logging functionality from the separate logging module -from .logging import logger - - -class Iris: +from iris._common import IrisBase, CCLBase + + +class Iris(IrisBase): """ Main Iris class for multi-GPU communication and memory management. @@ -67,1519 +45,26 @@ class Iris: """ def __init__(self, heap_size=1 << 30): - # Initialize - comm, cur_rank, num_ranks = init_distributed() - num_gpus = count_devices() - - gpu_id = cur_rank % num_gpus - set_device(gpu_id) - - self.comm = comm - self.num_ranks = num_ranks - self.cur_rank = cur_rank - self.gpu_id = gpu_id - self.heap_size = heap_size - self.heap_offset = 0 - self.alignment = 1024 - self.device = f"cuda:{gpu_id}" - self.memory_pool = torch.empty(heap_size, device=self.device, dtype=torch.int8) - - heap_base = self.memory_pool.data_ptr() - heap_base_ptr = ctypes.c_void_p(heap_base) - - heap_bases = np.zeros(num_ranks, dtype=np.uint64) - heap_bases[cur_rank] = heap_base - ipc_handle_size = get_ipc_handle_size() - ipc_handles = np.zeros((num_ranks, ipc_handle_size), dtype=np.uint8) - ipc_handle = get_ipc_handle(heap_base_ptr, cur_rank) - - distributed_barrier() - - all_ipc_handles = distributed_allgather(np.frombuffer(ipc_handle, dtype=np.uint8).copy()) - heap_base_bytes = np.array([heap_bases[cur_rank]], dtype=np.uint64).tobytes() - all_heap_bases_bytes = distributed_allgather(np.frombuffer(heap_base_bytes, dtype=np.uint8).copy()) - all_heap_bases = np.frombuffer(all_heap_bases_bytes.tobytes(), dtype=np.uint64).reshape(num_ranks, -1) - - 
distributed_barrier() - - ipc_heap_bases = np.zeros(num_ranks, dtype=np.uintp) - for rank in range(num_ranks): - if rank != cur_rank: - handle = open_ipc_handle(all_ipc_handles[rank], cur_rank) - ipc_heap_bases[rank] = int(handle) - else: - ipc_heap_bases[rank] = heap_bases[rank] - - for i in range(num_ranks): - self.debug(f"GPU {i}: Heap base {hex(int(ipc_heap_bases[i]))}") - - distributed_barrier() - self.heap_bases = torch.from_numpy(ipc_heap_bases).to(device=self.device, dtype=torch.uint64) - - distributed_barrier() + # Initialize base class + super().__init__(heap_size) # Initialize CCL interface self.ccl = self.CCL(self) - def _log_with_rank(self, level, message): - """Helper method to log with rank information injected into the record.""" - if logger.isEnabledFor(level): - record = logging.LogRecord( - name=logger.name, level=level, pathname="", lineno=0, msg=message, args=(), exc_info=None - ) - # Inject rank information into the record - record.iris_rank = self.cur_rank - record.iris_num_ranks = self.num_ranks - logger.handle(record) - - def debug(self, message): - """ - Log a debug message with rank information. - - Args: - message (str): Human-readable message to log at debug level. - - Notes: - The log record is enriched with ``iris_rank`` and ``iris_num_ranks`` so - formatters can display the originating rank and world size. - - Example: - >>> ctx = iris.iris() - >>> iris.set_logger_level(iris.DEBUG) - >>> ctx.debug("Allocating buffers") # [Iris] [0/1] Allocating buffers - """ - self._log_with_rank(logging.DEBUG, message) - - def info(self, message): - """ - Log an info message with rank information. - - Args: - message (str): Human-readable message to log at info level. - - Example: - >>> ctx = iris.iris() - >>> ctx.info("Starting iteration 0") # [Iris] [0/1] Starting iteration 0 - """ - self._log_with_rank(logging.INFO, message) - - def warning(self, message): - """ - Log a warning message with rank information. 
- - Args: - message (str): Human-readable message to log at warning level. - - Example: - >>> ctx = iris.iris() - >>> ctx.warning("Memory usage is high") # [Iris] [0/1] Memory usage is high - """ - self._log_with_rank(logging.WARNING, message) - - def error(self, message): - """ - Log an error message with rank information. - - Args: - message (str): Human-readable message to log at error level. - - Example: - >>> ctx = iris.iris() - >>> ctx.error("Failed to allocate memory") # [Iris] [0/1] Failed to allocate memory - """ - self._log_with_rank(logging.ERROR, message) - - def broadcast(self, value, source_rank): - """ - Broadcast a value from one rank to all ranks. - - This method automatically detects the type of value and uses the appropriate - broadcast mechanism: - - For tensors and arrays: uses efficient PyTorch distributed tensor collectives - - For scalars and other objects: uses object broadcast - - Args: - value (Any): The value to broadcast. Can be a scalar, tensor, numpy array, - or any picklable object. Only the ``source_rank`` value is used; - other ranks should pass a placeholder (e.g., ``None``). - source_rank (int): Rank id that holds the authoritative value. - - Returns: - Any: The value broadcast to all ranks. Tensors and arrays are returned as - numpy arrays; scalars and objects are returned in their original type. 
- - Examples: - >>> ctx = iris.iris() - >>> # Broadcasting a scalar - >>> value = 42 if ctx.cur_rank == 0 else None - >>> value = ctx.broadcast(value, source_rank=0) # All ranks get 42 - >>> - >>> # Broadcasting a tensor - >>> if ctx.cur_rank == 0: - >>> data = torch.randn(10, 10) - >>> else: - >>> data = None - >>> data = ctx.broadcast(data, source_rank=0) # All ranks get the same array - """ - # Check if the value on source_rank is a tensor or array-like - if self.cur_rank == source_rank and value is not None: - # Explicitly exclude strings and non-numeric types - if isinstance(value, (str, dict, bool)): - is_tensor = False - elif isinstance(value, torch.Tensor): - is_tensor = True - elif isinstance(value, np.ndarray): - is_tensor = True - elif isinstance(value, (list, tuple)): - # Try to convert list/tuple to tensor to check if it's numeric - try: - torch.as_tensor(value) - is_tensor = True - except (TypeError, ValueError): - is_tensor = False - else: - # For other types, try to convert and check - try: - test_array = np.asarray(value) - # Check if it's a numeric dtype that torch can handle - if np.issubdtype(test_array.dtype, np.number): - torch.as_tensor(test_array) - is_tensor = True - else: - is_tensor = False - except (TypeError, ValueError): - is_tensor = False - else: - is_tensor = False - - # Broadcast the type decision to all ranks - is_tensor = distributed_broadcast_scalar(is_tensor, source_rank) - - if is_tensor: - return distributed_broadcast_tensor(value, root=source_rank) - else: - return distributed_broadcast_scalar(value, source_rank) - - def __allocate(self, num_elements, dtype): - self.debug(f"allocate: num_elements = {num_elements}, dtype = {dtype}") - - element_size = torch.tensor([], dtype=dtype).element_size() - size_in_bytes = num_elements * element_size - aligned_size = math.ceil(size_in_bytes / self.alignment) * self.alignment - - if self.heap_offset + aligned_size > self.heap_size: - raise MemoryError("Heap out of memory") - - start = 
self.heap_offset - self.heap_offset += aligned_size - - sub_buffer = self.memory_pool[start : start + size_in_bytes].view(dtype) - return sub_buffer.reshape((num_elements,)) - - def __parse_size(self, size): - # Handle nested tuples/lists by flattening them recursively - while len(size) == 1 and isinstance(size[0], (tuple, list)): - size = size[0] - num_elements = math.prod(size) - return size, num_elements - - def zeros_like( - self, input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format - ): - """ - Returns a tensor filled with the scalar value 0, with the same size as input, allocated on the Iris symmetric heap. - - Args: - input (Tensor): the size of input will determine size of the output tensor. - - Keyword Arguments: - dtype (torch.dtype, optional): the desired data type of returned Tensor. - Default: if None, defaults to the dtype of input. - layout (torch.layout, optional): the desired layout of returned tensor. - Default: if None, defaults to the layout of input. Note: Iris tensors are always contiguous (strided). - device (torch.device, optional): the desired device of returned tensor. - Default: if None, defaults to the device of input. Must be compatible with this Iris instance. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. - Default: False. - memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. - Default: torch.preserve_format. 
- - Example: - >>> ctx = iris.iris(1 << 20) - >>> input_tensor = ctx.ones(2, 3) - >>> zeros_tensor = ctx.zeros_like(input_tensor) - >>> print(zeros_tensor.shape) # torch.Size([2, 3]) - """ - self.debug( - f"zeros_like: input_shape = {input.shape}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" - ) - - # Use input's properties as defaults if not specified - if dtype is None: - dtype = input.dtype - if layout is None: - layout = input.layout - if device is None: - device = input.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Get the size from input tensor - size = input.size() - num_elements = input.numel() - - # Allocate new tensor with the same size - new_tensor = self.__allocate(num_elements, dtype) - new_tensor.zero_() - - # Reshape to match input size - new_tensor = new_tensor.reshape(size) - - # Apply the requested memory format - new_tensor = self.__apply_memory_format(new_tensor, size, memory_format, input) - - # Apply the requested layout - new_tensor = self.__apply_layout(new_tensor, layout) - - # Set requires_grad if specified - if requires_grad: - new_tensor.requires_grad_() - - return new_tensor - - def arange( - self, start=0, end=None, step=1, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False - ): - """ - Returns a 1-D tensor of size ⌈(end - start) / step⌉ with values from the interval [start, end) - taken with common difference step beginning from start. The tensor is allocated on the symmetric heap. - - Note: When using floating-point dtypes (especially reduced precision types like bfloat16), - the results may be affected by floating-point rounding behavior. Some values in the sequence - might not be exactly representable in certain floating-point formats, which can lead to - repeated values or unexpected rounding. For precise sequences, it is recommended to use - integer dtypes instead of floating-point dtypes. 
- - Note that non-integer step is subject to floating point rounding errors when comparing - against end; to avoid inconsistency, we advise subtracting a small epsilon from end in such cases. - - Args: - start (Number, optional): the starting value for the set of points. Default: 0. - end (Number): the ending value for the set of points - step (Number, optional): the gap between each pair of adjacent points. Default: 1. - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): the desired data type of returned tensor. - Default: if None, uses a global default (see torch.get_default_dtype()). - If dtype is not given, infer the data type from the other input arguments. - If any of start, end, or step are floating-point, the dtype is inferred - be the default dtype, see get_default_dtype(). Otherwise, the dtype is inferred - to be torch.int64. - layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. - Note: Iris tensors always use `torch.strided` regardless of this parameter. - device (torch.device, optional): the desired device of returned tensor. - Default: if None, uses the current device for the default tensor type. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. 
- - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.arange(0, 10, 2) # [0, 2, 4, 6, 8] - >>> print(tensor.shape) # torch.Size([5]) - """ - self.debug(f"arange: start = {start}, end = {end}, step = {step}, dtype = {dtype}, device = {device}") - - # Handle the case where only one argument is provided (end) - if end is None: - end = start - start = 0 - - # Validate inputs - if step == 0: - raise ValueError("step must be non-zero") - - # Validate step direction consistency - if step > 0 and start >= end: - raise ValueError(f"Invalid range: start >= end with positive step (start={start}, end={end}, step={step})") - elif step < 0 and start <= end: - raise ValueError(f"Invalid range: start <= end with negative step (start={start}, end={end}, step={step})") - - # Calculate the number of elements - num_elements = math.ceil((end - start) / step) - - # Infer dtype if not provided - if dtype is None: - if any(isinstance(x, float) for x in [start, end, step]): - dtype = torch.get_default_dtype() - else: - dtype = torch.int64 - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - tensor = out - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - - target_device = tensor.device - arange_tensor = torch.arange(start, end, step, dtype=dtype, device=target_device) - - tensor[:] = arange_tensor - - tensor = self.__apply_layout(tensor, layout) - - if requires_grad: - tensor.requires_grad_() - - return tensor - - def zeros(self, *size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): - """ - Returns a tensor filled with the scalar value 0, with the shape defined by the variable argument size. - The tensor is allocated on the Iris symmetric heap. 
- - Args: - *size (int...): a sequence of integers defining the shape of the output tensor. - Can be a variable number of arguments or a collection like a list or tuple. - - Keyword Arguments: - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): the desired data type of returned tensor. - Default: if None, uses a global default (see torch.set_default_dtype()). - layout (torch.layout, optional): the desired layout of returned Tensor. - Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. - device (torch.device, optional): the desired device of returned tensor. - Default: if None, uses the current device for the default tensor type. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. - Default: False. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.zeros(2, 3) - >>> print(tensor.shape) # torch.Size([2, 3]) - >>> print(tensor[0]) # tensor([0., 0., 0.], device='cuda:0') - """ - self.debug(f"zeros: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") - - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Fill with zeros - out.zero_() - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Fill with zeros - tensor.zero_() - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested layout - tensor = 
self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - - def randn( - self, - *size, - generator=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, - ): - """ - Returns a tensor filled with random numbers from a normal distribution with mean 0 and variance 1 - (also called the standard normal distribution). The tensor is allocated on the Iris symmetric heap. - - .. math:: - \\text{out}_i \\sim \\mathcal{N}(0, 1) - - For complex dtypes, the tensor is i.i.d. sampled from a complex normal distribution with zero mean - and unit variance as - - .. math:: - \\text{out}_i \\sim \\mathcal{CN}(0, 1) - - This is equivalent to separately sampling the real :math:`(\\text{Re})` and imaginary :math:`(\\text{Im})` - part of :math:`\\text{out}_i` as - - .. math:: - \\text{Re}(\\text{out}_i) \\sim \\mathcal{N}(0, \\frac{1}{2}), \\quad \\text{Im}(\\text{out}_i) \\sim \\mathcal{N}(0, \\frac{1}{2}) - - The shape of the tensor is defined by the variable argument size. - - Args: - *size (int...): a sequence of integers defining the shape of the output tensor. - Can be a variable number of arguments or a collection like a list or tuple. - - Keyword Arguments: - generator (torch.Generator, optional): a pseudorandom number generator for sampling - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): the desired data type of returned tensor. - Default: if None, uses a global default (see torch.set_default_dtype()). - layout (torch.layout, optional): the desired layout of returned Tensor. - Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. - device (torch.device, optional): the desired device of returned tensor. - Default: if None, uses the current device for the default tensor type (see torch.set_default_device()). 
- device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. - Default: False. - pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. - Works only for CPU tensors. Default: False. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.randn(2, 3) - >>> print(tensor.shape) # torch.Size([2, 3]) - >>> print(tensor[0]) # tensor([ 0.3982, -0.0059, -0.4365], device='cuda:0') - """ - self.debug( - f"randn: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" - ) - - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Generate random data and copy to out tensor - random_data = torch.randn(num_elements, generator=generator, dtype=dtype, device=device, layout=layout) - out.copy_(random_data) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Generate random data and copy to tensor - random_data = torch.randn(num_elements, generator=generator, dtype=dtype, device=device, layout=layout) - tensor.copy_(random_data) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - - def ones(self, 
*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): - """ - Returns a tensor filled with the scalar value 1, with the shape defined by the variable argument size. - The tensor is allocated on the Iris symmetric heap. - - Args: - *size (int...): a sequence of integers defining the shape of the output tensor. - Can be a variable number of arguments or a collection like a list or tuple. - - Keyword Arguments: - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): the desired data type of returned tensor. - Default: if None, uses a global default (see torch.set_default_dtype()). - layout (torch.layout, optional): the desired layout of returned Tensor. - Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. - device (torch.device, optional): the desired device of returned tensor. - Default: if None, uses the current device for the default tensor type. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. - Default: False. 
- - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.ones(2, 3) - >>> print(tensor.shape) # torch.Size([2, 3]) - >>> print(tensor[0]) # tensor([1., 1., 1.], device='cuda:0') - """ - self.debug(f"ones: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") - - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Fill with ones - out.fill_(1) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Fill with ones - tensor.fill_(1) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - - def full(self, size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): - """ - Creates a tensor of size size filled with fill_value. The tensor's dtype is inferred from fill_value. - The tensor is allocated on the Iris symmetric heap. - - Args: - size (int...): a list, tuple, or torch.Size of integers defining the shape of the output tensor. - fill_value (Scalar): the value to fill the output tensor with. - - Keyword Arguments: - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): the desired data type of returned tensor. - Default: if None, uses a global default (see torch.set_default_dtype()). 
- layout (torch.layout, optional): the desired layout of returned Tensor. - Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. - device (torch.device, optional): the desired device of returned tensor. - Default: if None, uses the current device for the default tensor type. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. - Default: False. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.full((2, 3), 3.14) - >>> print(tensor.shape) # torch.Size([2, 3]) - >>> print(tensor[0]) # tensor([3.1400, 3.1400, 3.1400], device='cuda:0') - """ - self.debug( - f"full: size = {size}, fill_value = {fill_value}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" - ) - - # Infer dtype from fill_value if not provided - if dtype is None: - if isinstance(fill_value, (int, float)): - if isinstance(fill_value, float): - dtype = torch.get_default_dtype() - else: - dtype = torch.int64 - else: - # For other types (like tensors), use their dtype - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Fill with the specified value - out.fill_(fill_value) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Fill with the specified value - tensor.fill_(fill_value) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - 
tensor.requires_grad_() - - return tensor - - def uniform(self, size, low=0.0, high=1.0, dtype=torch.float): - """ - Returns a tensor filled with random numbers from a uniform distribution, allocated on the Iris symmetric heap. - - Args: - size (int or tuple of ints): the size of the output tensor. - low (float, optional): the lower bound of the uniform distribution. Default: 0.0. - high (float, optional): the upper bound of the uniform distribution. Default: 1.0. - dtype (torch.dtype, optional): the desired data type of returned tensor. Default: torch.float. - - Returns: - Tensor: A tensor filled with random numbers from a uniform distribution. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.uniform((2, 3), low=0.0, high=1.0) - >>> print(tensor.shape) # torch.Size([2, 3]) - >>> print(tensor[0]) # tensor([0.1234, 0.5678, 0.9012], device='cuda:0') - """ - self.debug(f"uniform: size = {size}, low = {low}, high = {high}, dtype = {dtype}") - size, num_elements = self.__parse_size(size) - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - tensor.uniform_(low, high) - return tensor.reshape(size) - - def empty( - self, - *size, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - ): - """ - Returns a tensor filled with uninitialized data. The shape of the tensor is defined by the variable argument size. - The tensor is allocated on the Iris symmetric heap. - - Note: - If torch.use_deterministic_algorithms() and torch.utils.deterministic.fill_uninitialized_memory are both set to True, - the output tensor is initialized to prevent any possible nondeterministic behavior from using the data as an input to an operation. - Floating point and complex tensors are filled with NaN, and integer tensors are filled with the maximum value. - - Args: - *size (int...): a sequence of integers defining the shape of the output tensor. 
- Can be a variable number of arguments or a collection like a list or tuple. - - Keyword Arguments: - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): the desired data type of returned tensor. - Default: if None, uses a global default (see torch.set_default_dtype()). - layout (torch.layout, optional): the desired layout of returned Tensor. - Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. - device (torch.device, optional): the desired device of returned tensor. - Default: if None, uses the current device for the default tensor type. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. - Default: False. - pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. - Works only for CPU tensors. Default: False. Note: Iris tensors are always on GPU. - memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. - Default: torch.contiguous_format. 
- - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.empty(2, 3) - >>> print(tensor.shape) # torch.Size([2, 3]) - """ - self.debug( - f"empty: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" - ) - - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Apply the requested memory format - tensor = self.__apply_memory_format(tensor, size, memory_format) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - - def randint( - self, *args, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False - ): - """ - Returns a tensor filled with random integers generated uniformly between low (inclusive) and high (exclusive). - The shape of the tensor is defined by the variable argument size. - The tensor is allocated on the Iris symmetric heap. - - Note: - With the global dtype default (torch.float32), this function returns a tensor with dtype torch.int64. - - Args: - low (int, optional): Lowest integer to be drawn from the distribution. Default: 0. - high (int): One above the highest integer to be drawn from the distribution. 
- size (tuple): a tuple defining the shape of the output tensor. - - Keyword Arguments: - generator (torch.Generator, optional): a pseudorandom number generator for sampling. - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): if None, this function returns a tensor with dtype torch.int64. - layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. - device (torch.device, optional): the desired device of returned tensor. Default: if None, uses the current device. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.randint(0, 10, (2, 3)) # Random integers [0, 10) - >>> print(tensor.shape) # torch.Size([2, 3]) - >>> print(tensor[0]) # tensor([7, 2, 9], device='cuda:0') - """ - self.debug(f"randint: args = {args}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}") - - # Parse arguments to determine low, high, and size - # PyTorch randint signatures: - # randint(high, size) - where high is the upper bound and size is the shape - # randint(low, high, size) - where low and high are bounds, size is the shape - if len(args) == 2: - # randint(high, size) - high, size = args - low = 0 - elif len(args) == 3: - # randint(low, high, size) - low, high, size = args - else: - raise ValueError(f"randint expects 2 or 3 positional arguments, got {len(args)}") - - # Use default dtype if None is provided - if dtype is None: - dtype = torch.int64 - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Create a reshaped 
view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Generate random integers using PyTorch's randint - # Use specified device or fall back to current device - target_device = device if device is not None else self.device - - # Handle generator parameter - if generator is not None: - torch.randint(low, high, size, generator=generator, out=tensor, dtype=dtype, device=target_device) - else: - torch.randint(low, high, size, out=tensor, dtype=dtype, device=target_device) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - - def linspace(self, start, end, steps, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): - """ - Creates a one-dimensional tensor of size steps whose values are evenly spaced from start to end, inclusive. - The tensor is allocated on the Iris symmetric heap. - - The values are: - (start, start + (end-start)/(steps-1), ..., start + (steps-2)*(end-start)/(steps-1), end) - - Args: - start (float or Tensor): the starting value for the set of points. If Tensor, it must be 0-dimensional. - end (float or Tensor): the ending value for the set of points. If Tensor, it must be 0-dimensional. - steps (int): size of the constructed tensor. - - Keyword Arguments: - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): the data type to perform the computation in. - Default: if None, uses the global default dtype when both start and end are real, - and corresponding complex dtype when either is complex. - layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided. - device (torch.device, optional): the desired device of returned tensor. Default: if None, uses the current device. 
- requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.linspace(0, 10, 5) # [0, 2.5, 5, 7.5, 10] - >>> print(tensor) # tensor([ 0.0000, 2.5000, 5.0000, 7.5000, 10.0000], device='cuda:0') - """ - self.debug( - f"linspace: start = {start}, end = {end}, steps = {steps}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}" - ) - - # Use global default dtype if None is provided - if dtype is None: - # Check if start or end are complex numbers - start_is_complex = isinstance(start, complex) or (hasattr(start, "dtype") and torch.is_complex(start)) - end_is_complex = isinstance(end, complex) or (hasattr(end, "dtype") and torch.is_complex(end)) - - if start_is_complex or end_is_complex: - # Infer complex dtype based on default dtype - dtype = torch.complex64 if torch.get_default_dtype() == torch.float32 else torch.complex128 - else: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse steps and extract the integer value - if isinstance(steps, (tuple, list)): - if len(steps) == 1: - # Single-element tuple/list like (5,) or [5] - steps_int = steps[0] - # Handle nested tuples like ((5,),) - if isinstance(steps_int, (tuple, list)): - steps_int = steps_int[0] - else: - # Multi-element tuple/list - use __parse_size for compatibility - size, num_elements = self.__parse_size(steps) - steps_int = num_elements - else: - # steps is a single integer - steps_int = steps - - # Ensure steps_int is an integer - steps_int = int(steps_int) - size = (steps_int,) - num_elements = steps_int - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Create a reshaped view of the out tensor - tensor = 
out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Generate linspace using PyTorch's linspace - # Use specified device or fall back to current device - target_device = device if device is not None else self.device - torch.linspace(start, end, steps_int, out=tensor, dtype=dtype, device=target_device) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - - def rand( - self, - *size, - generator=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, - ): - """ - Returns a tensor filled with random numbers from a uniform distribution on the interval [0, 1). - The tensor is allocated on the Iris symmetric heap. - - Args: - *size (int...): a sequence of integers defining the shape of the output tensor. - Can be a variable number of arguments or a collection like a list or tuple. - - Keyword Arguments: - generator (torch.Generator, optional): a pseudorandom number generator for sampling. - out (Tensor, optional): the output tensor. - dtype (torch.dtype, optional): the desired data type of returned tensor. - Default: if None, uses a global default (see torch.set_default_dtype()). - layout (torch.layout, optional): the desired layout of returned Tensor. - Default: torch.strided. Note: Iris tensors always use `torch.strided` regardless of this parameter. - device (torch.device, optional): the desired device of returned tensor. - Default: if None, uses the current device for the default tensor type. - requires_grad (bool, optional): If autograd should record operations on the returned tensor. - Default: False. - pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. - Works only for CPU tensors. Default: False. 
Note: Iris tensors are always on GPU. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> tensor = ctx.rand(2, 3) # Random values in [0, 1) - >>> print(tensor.shape) # torch.Size([2, 3]) - >>> print(tensor[0]) # tensor([0.1234, 0.5678, 0.9012], device='cuda:0') - """ - self.debug( - f"rand: size = {size}, dtype = {dtype}, device = {device}, requires_grad = {requires_grad}, pin_memory = {pin_memory}" - ) - - # Use global default dtype if None is provided - if dtype is None: - dtype = torch.get_default_dtype() - - # Use current device if none specified - if device is None: - device = self.device - - # Validate device compatibility with Iris - self.__throw_if_invalid_device(device) - - # Parse size and calculate number of elements - size, num_elements = self.__parse_size(size) - - # If out is provided, use it; otherwise allocate new tensor - if out is not None: - self.__throw_if_invalid_output_tensor(out, num_elements, dtype) - # Create a reshaped view of the out tensor - tensor = out.view(size) - else: - tensor = self.__allocate(num_elements=num_elements, dtype=dtype) - # Reshape to the desired size - tensor = tensor.reshape(size) - - # Generate random numbers using PyTorch's rand - # Use specified device (already validated and set above) - - # Handle generator parameter - if generator is not None: - torch.rand(size, generator=generator, out=tensor, dtype=dtype, device=device) - else: - torch.rand(size, out=tensor, dtype=dtype, device=device) - - # Apply the requested layout - tensor = self.__apply_layout(tensor, layout) - - # Set requires_grad if specified - if requires_grad: - tensor.requires_grad_() - - return tensor - def __deallocate(self, pointer): pass - def get_heap_bases(self): - """ - Return the tensor of symmetric heap base addresses for all ranks. - - Returns: - torch.Tensor: A 1D tensor of ``uint64`` heap base addresses of size ``num_ranks`` - on the Iris device. Pass this to device-side Triton kernels that require - heap translation. 
- - Example: - >>> ctx = iris.iris(1 << 20) - >>> heap_bases = ctx.get_heap_bases() - >>> print(heap_bases.shape) # torch.Size([num_ranks]) - """ - return self.heap_bases - - def barrier(self, stream=None): - """ - Synchronize all ranks and their CUDA devices. - - This first calls ``torch.cuda.synchronize()`` or ``stream.synchronize()`` to ensure the local GPU has - finished all queued work, then performs a global distributed barrier so that all - ranks reach the same point before proceeding. - Args: - stream: If stream is given: wait only for that stream before barrier. If stream is None: legacy behavior (device-wide sync). - - Example: - >>> ctx = iris.iris(1 << 20) - >>> ctx.barrier() # Synchronize all ranks - """ - # Wait for all GPUs to finish work - if stream is None: - torch.cuda.synchronize() - else: - stream.synchronize() - - # Distributed barrier - distributed_barrier() - - def get_device(self): - """ - Get the underlying device where the Iris symmetric heap resides. - - Returns: - torch.device: The CUDA device of Iris-managed memory. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> device = ctx.get_device() - >>> print(device) # cuda:0 - """ - return self.memory_pool.device - - def get_cu_count(self): - """ - Get the number of compute units (CUs) for the current GPU. - - Returns: - int: Number of compute units on this rank's GPU. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> cu_count = ctx.get_cu_count() - >>> print(f"GPU has {cu_count} CUs") # GPU has 304 CUs - """ - return get_cu_count(self.gpu_id) - - def get_rank(self): - """ - Get this process's rank id in the distributed communicator. - - Returns: - int: Zero-based rank id of the current process. - - Example: - >>> ctx = iris.iris(1 << 20) - >>> rank = ctx.get_rank() - >>> print(f"This is rank {rank}") # This is rank 0 - """ - return self.cur_rank - - def get_num_ranks(self): - """ - Get the total number of ranks in the distributed communicator. 
- - Returns: - int: World size (number of ranks). - - Example: - >>> ctx = iris.iris(1 << 20) - >>> num_ranks = ctx.get_num_ranks() - >>> print(f"Total ranks: {num_ranks}") # Total ranks: 1 - """ - return self.num_ranks - - def __throw_if_invalid_output_tensor(self, tensor: torch.Tensor, num_elements: int, dtype: torch.dtype): - if not self.__tensor_on_device(tensor): - raise RuntimeError( - f"The output tensor is not on the same device as the Iris instance. The Iris instance is on device {self.device} but the output tensor is on device {tensor.device}" - ) - if not self.__on_symmetric_heap(tensor): - raise RuntimeError( - f"The output tensor is not on the symmetric heap. The Iris instance is on heap base {self.heap_bases[self.cur_rank]} but the output tensor is on heap base {tensor.data_ptr()}" - ) - if tensor.numel() != num_elements: - raise RuntimeError(f"The output tensor has {tensor.numel()} elements, but {num_elements} are required") - if tensor.dtype != dtype: - raise RuntimeError(f"The output tensor has dtype {tensor.dtype}, but {dtype} is required") - - def __throw_if_invalid_device(self, device): - """ - Throw a RuntimeError if the requested device is not compatible with this Iris instance. - - Args: - device: The requested device (can be string, torch.device, or None) - - Raises: - RuntimeError: If the device is not compatible - """ - if not self.__is_valid_device(device): - raise RuntimeError( - f"Device mismatch: requested device {device} but Iris instance is on device {self.device}. " - f"Iris only supports tensors on its own device." - ) - - def __apply_memory_format( - self, tensor: torch.Tensor, size: tuple, memory_format: torch.memory_format, input_tensor: torch.Tensor = None - ): - """ - Apply the requested memory format to a tensor by setting appropriate strides. - This keeps the tensor on the symmetric heap while changing how PyTorch interprets the memory layout. 
- - Args: - tensor: The tensor to modify - size: The tensor's size/dimensions - memory_format: The desired memory format - input_tensor: The original input tensor (needed for preserve_format detection) - """ - if memory_format == torch.contiguous_format: - # Default format, no changes needed - return tensor - elif memory_format == torch.channels_last and len(size) == 4: - # For channels_last format: preserve shape (N, C, H, W) but change strides - # channels_last strides: [C*H*W, 1, C*W, C] for shape (N, C, H, W) - N, C, H, W = size[0], size[1], size[2], size[3] - # Keep the original shape (N, C, H, W) but use channels_last strides - tensor = self.__create_tensor_with_strides(tensor, size, (C * H * W, 1, C * W, C)) - return tensor - elif memory_format == torch.channels_last_3d and len(size) == 5: - # For channels_last_3d format: preserve shape (N, C, D, H, W) but change strides - # channels_last_3d strides: [C*D*H*W, 1, C*D*W, C*W, C] for shape (N, C, D, H, W) - N, C, D, H, W = size[0], size[1], size[2], size[3], size[4] - # Keep the original shape (N, C, D, H, W) but use channels_last_3d strides - tensor = self.__create_tensor_with_strides(tensor, size, (C * D * H * W, 1, C * D * W, C * W, C)) - return tensor - elif memory_format == torch.preserve_format: - # For preserve_format, we need to detect the input tensor's memory format - # and apply the same format to the output - if input_tensor is not None: - # Check the actual memory format of the input tensor - if len(size) == 4: - # Check if input tensor is in channels_last format by examining strides - # channels_last format has strides[1] == 1 (channels dimension is contiguous) - input_strides = input_tensor.stride() - if len(input_strides) == 4 and input_strides[1] == 1: - # Input is in channels_last format, preserve it - # Use the input tensor's actual shape, not the size parameter - input_shape = input_tensor.shape - if len(input_shape) == 4: - # Input is already in channels_last format (N, H, W, C) - new_size 
= input_shape - # Use the input tensor's strides directly - tensor = self.__create_tensor_with_strides(tensor, new_size, input_strides) - return tensor - elif len(size) == 5: - # Check if input tensor is in channels_last_3d format - input_strides = input_tensor.stride() - if len(input_strides) == 5 and input_strides[1] == 1: - # Input is in channels_last_3d format, preserve it - # Use the input tensor's actual shape, not the size parameter - input_shape = input_tensor.shape - if len(input_shape) == 5: - # Input is already in channels_last_3d format (N, D, H, W, C) - new_size = input_shape - # Use the input tensor's strides directly - tensor = self.__create_tensor_with_strides(tensor, new_size, input_strides) - return tensor - # If no special format detected or no input tensor provided, use contiguous format - return tensor - else: - # Unsupported format or dimension combination - self.debug( - f"Warning: Memory format {memory_format} not supported for {len(size)}D tensor, using contiguous format" - ) - # For unsupported formats, return the tensor as-is (contiguous) - return tensor - - def __create_tensor_with_strides(self, original_tensor: torch.Tensor, size: tuple, strides: tuple) -> torch.Tensor: - """ - Create a new tensor with the specified strides while keeping the data on the symmetric heap. 
- - Args: - original_tensor: The original tensor (source of data and heap allocation) - size: The tensor's size/dimensions - strides: The desired strides for the new memory format - - Returns: - A new tensor with the specified strides, data copied from original, on the same heap - """ - - # First, create a temporary tensor with the correct strides using PyTorch - temp_tensor = torch.empty_strided(size, strides, dtype=original_tensor.dtype, device=original_tensor.device) - - # Handle different cases based on whether size changes and what the strides indicate - if size != original_tensor.shape: - # Size is different - this might be a format change that requires permutation - # Check if this is a channels_last format by comparing strides - if len(size) == 4: - # For channels_last: expected strides are [H*W*C, 1, W*C, C] for shape (N, H, W, C) - N, H, W, C = size[0], size[1], size[2], size[3] - expected_strides = (H * W * C, 1, W * C, C) - if strides == expected_strides: - permuted = original_tensor.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) - else: - # If the size differs for other reasons, do not permute; just reshape if possible - try: - permuted = original_tensor.reshape(size) - except Exception: - raise ValueError( - "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." - ) - elif len(size) == 5: - # For channels_last_3d: expected strides are [D*H*W*C, 1, H*W*C, W*C, C] for shape (N, D, H, W, C) - N, D, H, W, C = size[0], size[1], size[2], size[3], size[4] - expected_strides = (D * H * W * C, 1, H * W * C, W * C, C) - if strides == expected_strides: - permuted = original_tensor.permute(0, 2, 3, 4, 1) # (N, C, D, H, W) -> (N, D, H, W, C) - else: - # If the size differs for other reasons, do not permute; just reshape if possible - try: - permuted = original_tensor.reshape(size) - except Exception: - raise ValueError( - "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." 
- ) - else: - # For other dimensions, just try to reshape - try: - permuted = original_tensor.reshape(size) - except Exception: - raise ValueError( - "Cannot safely permute or reshape tensor: size differs from original shape for unknown reason." - ) - else: - # Size is the same - this is a stride-only change (like channels_last with preserved shape) - # We need to reorder the data to match the new stride pattern - if len(size) == 4: - # Check if this is channels_last format with preserved shape - N, C, H, W = size[0], size[1], size[2], size[3] - expected_strides = (C * H * W, 1, C * W, C) - if strides == expected_strides: - permuted = original_tensor - else: - permuted = original_tensor - elif len(size) == 5: - # Check if this is channels_last_3d format with preserved shape - N, C, D, H, W = size[0], size[1], size[2], size[3], size[4] - expected_strides = (C * D * H * W, 1, C * D * W, C * W, C) - if strides == expected_strides: - permuted = original_tensor - else: - permuted = original_tensor - else: - permuted = original_tensor - - # Copy the permuted data to the temporary tensor - temp_tensor.copy_(permuted) - - # Now allocate a new tensor on our symmetric heap - num_elements = math.prod(size) - heap_tensor = self.__allocate(num_elements, original_tensor.dtype) - - # Reshape to the desired size - heap_tensor = heap_tensor.reshape(size) - - # Copy the data from the temporary tensor to our heap tensor - heap_tensor.copy_(temp_tensor) - - # Clean up the temporary tensor - del temp_tensor - - # Now we need to create a view with the correct strides - # We can't use as_strided directly on our heap tensor, but we can - # create a new tensor with the right strides and copy the data again - final_tensor = torch.as_strided(heap_tensor, size, strides) - - return final_tensor - - def __apply_layout(self, tensor: torch.Tensor, layout: torch.layout) -> torch.Tensor: - """ - Apply the requested layout to a tensor. 
- - Args: - tensor: The tensor to modify - layout: The desired layout - - Returns: - Tensor with the requested layout - """ - - if layout == torch.strided: - # Strided layout is the default - no changes needed - return tensor - else: - # Only support strided layout for now - raise ValueError(f"Layout {layout} not supported. Only torch.strided is currently supported.") - - def __tensor_on_device(self, tensor: torch.Tensor): - # Get the Iris device from memory_pool.device - iris_device = self.get_device() - tensor_device = tensor.device - - # For CUDA devices, check if they're compatible - if tensor_device.type == "cuda" and iris_device.type == "cuda": - if iris_device.index is None: - return True - return tensor_device.index == iris_device.index - - # For non-CUDA devices, they must be exactly equal - return tensor_device == iris_device - - def __on_symmetric_heap(self, tensor: torch.Tensor): - # Special case for empty tensors - they might not have a valid data_ptr - if tensor.numel() == 0: - self.debug("Empty tensor detected, skipping heap check") - return True - - # Convert CUDA pointer to integer for comparison - tensor_ptr = int(tensor.data_ptr()) - heap_base = int(self.heap_bases[self.cur_rank]) - - result = tensor_ptr >= heap_base and tensor_ptr < heap_base + self.heap_size - - return result - - def __is_valid_device(self, device) -> bool: - """ - Check if the requested device is compatible with this Iris instance. 
- - Args: - device: The requested device (can be string, torch.device, or None) - - Returns: - bool: True if the device is compatible, False otherwise - """ - if device is None: - return True # None means use default device - - # Convert device strings to torch.device objects for proper comparison - requested_device = torch.device(device) if isinstance(device, str) else device - iris_device = self.get_device() - - # Check if both are CUDA devices - if requested_device.type == "cuda" and iris_device.type == "cuda": - # Check if index matches or if requested is "cuda" (any index) - if requested_device.index is None: - return True - else: - return requested_device.index == iris_device.index - - # For non-CUDA devices, always return False - return False - - class CCL: + class CCL(CCLBase): """ Collective Communication Library (CCL) interface for Iris. + Extends CCLBase with Triton-specific all_reduce operations. Provides collective operations that can be called as methods on the Iris instance. Example usage: >>> shmem = iris.iris() >>> shmem.ccl.all_to_all(output_tensor, input_tensor) """ - def __init__(self, iris_instance): - """ - Initialize CCL with a reference to the parent Iris instance. - - Args: - iris_instance: The parent Iris instance - """ - self._iris = iris_instance - - def all_to_all(self, output_tensor, input_tensor, config=None, async_op=False): - """ - All-to-all collective operation. - - Each rank sends a tensor chunk to each other rank and receives - a tensor chunk from each other rank. Input/output tensors should have - shape (M, N * world_size) where each chunk of N columns corresponds to one rank. - - Args: - output_tensor: Output tensor of shape (M, N * world_size) - input_tensor: Input tensor of shape (M, N * world_size) - config: Config instance with kernel parameters (default: None). - If None, uses default Config values. - async_op: If False, performs a barrier at the end. If True, returns immediately. - Default: False. 
- - Example: - >>> shmem = iris.iris() - >>> shmem.ccl.all_to_all(output_tensor, input_tensor) - - >>> # Custom configuration - >>> from iris.ccl import Config - >>> config = Config(block_size_m=128, block_size_n=32) - >>> shmem.ccl.all_to_all(output_tensor, input_tensor, config=config) - - >>> # Async operation (no barrier) - >>> shmem.ccl.all_to_all(output_tensor, input_tensor, async_op=True) - """ - from iris.ccl.all_to_all import all_to_all as _all_to_all - - _all_to_all(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) - - def all_gather(self, output_tensor, input_tensor, config=None, async_op=False): - """ - All-gather collective operation. - - Each rank sends its input tensor to all ranks, and all ranks receive - and concatenate all input tensors along dimension 0 (rows), matching - torch.distributed.all_gather_into_tensor behavior. - - Args: - output_tensor: Output tensor of shape (world_size * M, N) - will contain concatenated inputs - input_tensor: Input tensor of shape (M, N) - local rank's data to send - config: Config instance with kernel parameters (default: None). - If None, uses default Config values. - async_op: If False, performs a barrier at the end. If True, returns immediately. - Default: False. 
- - Example: - >>> shmem = iris.iris() - >>> # Input: (M, N), Output: (world_size * M, N) - >>> shmem.ccl.all_gather(output_tensor, input_tensor) - - >>> # Custom configuration - >>> from iris.ccl import Config - >>> config = Config(block_size_m=128, block_size_n=32) - >>> shmem.ccl.all_gather(output_tensor, input_tensor, config=config) - - >>> # Async operation (no barrier) - >>> shmem.ccl.all_gather(output_tensor, input_tensor, async_op=True) - """ - from iris.ccl.all_gather import all_gather as _all_gather - - _all_gather(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) - def all_reduce_preamble(self, output_tensor, input_tensor, config=None, workspace=None): """ Prepare reusable workspace for all-reduce. @@ -1648,36 +133,6 @@ def all_reduce(self, output_tensor, input_tensor, config=None, async_op=False, w workspace=workspace, ) - def reduce_scatter(self, output_tensor, input_tensor, config=None, async_op=False): - """ - Reduce-scatter collective operation. - - Each rank reduces its assigned tiles from all ranks' inputs and stores - the result only to its own output tensor. This is similar to all-reduce - but without broadcasting the result to all ranks. - - Args: - output_tensor: Output tensor of shape (M, N) - will contain reduced tiles for this rank - input_tensor: Input tensor of shape (M, N) - local rank's partial data - config: Config instance with kernel parameters (default: None). - If None, uses default Config values. - Only supports reduce_scatter_variant="two_shot". - async_op: If False, performs a barrier at the end. If True, returns immediately. - Default: False. 
- - Example: - >>> shmem = iris.iris() - >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor) - - >>> # Custom configuration - >>> from iris.ccl import Config - >>> config = Config(reduce_scatter_variant="two_shot", all_reduce_distribution=1) - >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor, config=config) - """ - from iris.ccl.reduce_scatter import reduce_scatter as _reduce_scatter - - _reduce_scatter(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) - @triton.jit def __translate(ptr, from_rank, to_rank, heap_bases): diff --git a/tests/unittests/test_arange.py b/tests/unittests/test_arange.py index e3183faf..2d1e7b96 100644 --- a/tests/unittests/test_arange.py +++ b/tests/unittests/test_arange.py @@ -15,27 +15,27 @@ def test_arange_basic_functionality(): assert result1.shape == (5,) assert torch.all(result1 == torch.tensor([0, 1, 2, 3, 4], device=result1.device)) assert result1.dtype == torch.int64 - assert shmem._Iris__on_symmetric_heap(result1) + assert shmem._on_symmetric_heap(result1) # Test 2: arange(start, end) - two arguments result2 = shmem.arange(1, 4) assert result2.shape == (3,) assert torch.all(result2 == torch.tensor([1, 2, 3], device=result2.device)) assert result2.dtype == torch.int64 - assert shmem._Iris__on_symmetric_heap(result2) + assert shmem._on_symmetric_heap(result2) # Test 3: arange(start, end, step) - three arguments result3 = shmem.arange(1, 2.5, 0.5) assert result3.shape == (3,) assert torch.allclose(result3, torch.tensor([1.0, 1.5, 2.0], device=result3.device)) assert result3.dtype == torch.float32 - assert shmem._Iris__on_symmetric_heap(result3) + assert shmem._on_symmetric_heap(result3) # Test 4: arange with negative step result4 = shmem.arange(5, 0, -1) assert result4.shape == (5,) assert torch.all(result4 == torch.tensor([5, 4, 3, 2, 1], device=result4.device)) - assert shmem._Iris__on_symmetric_heap(result4) + assert shmem._on_symmetric_heap(result4) def test_arange_dtype_inference(): @@ 
-45,22 +45,22 @@ def test_arange_dtype_inference(): # Test integer dtype inference result_int = shmem.arange(3) assert result_int.dtype == torch.int64 - assert shmem._Iris__on_symmetric_heap(result_int) + assert shmem._on_symmetric_heap(result_int) # Test float dtype inference result_float = shmem.arange(1.0, 3.0) assert result_float.dtype == torch.float32 - assert shmem._Iris__on_symmetric_heap(result_float) + assert shmem._on_symmetric_heap(result_float) # Test explicit dtype override result_explicit = shmem.arange(3, dtype=torch.float64) assert result_explicit.dtype == torch.float64 - assert shmem._Iris__on_symmetric_heap(result_explicit) + assert shmem._on_symmetric_heap(result_explicit) # Test mixed types (should infer float) result_mixed = shmem.arange(1, 3.5, 0.5) assert result_mixed.dtype == torch.float32 - assert shmem._Iris__on_symmetric_heap(result_mixed) + assert shmem._on_symmetric_heap(result_mixed) def test_arange_device_handling(): @@ -70,18 +70,18 @@ def test_arange_device_handling(): # Test default device (should use Iris device) result_default = shmem.arange(3) assert str(result_default.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result_default) + assert shmem._on_symmetric_heap(result_default) # Test explicit device iris_device = str(shmem.get_device()) result_explicit = shmem.arange(3, device=iris_device) assert str(result_explicit.device) == iris_device - assert shmem._Iris__on_symmetric_heap(result_explicit) + assert shmem._on_symmetric_heap(result_explicit) # Test device=None (should use Iris device) result_none = shmem.arange(3, device=None) assert str(result_none.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result_none) + assert shmem._on_symmetric_heap(result_none) def test_arange_layout_handling(): @@ -91,7 +91,7 @@ def test_arange_layout_handling(): # Test default layout (strided) result_strided = shmem.arange(3, layout=torch.strided) assert result_strided.layout == 
torch.strided - assert shmem._Iris__on_symmetric_heap(result_strided) + assert shmem._on_symmetric_heap(result_strided) def test_arange_requires_grad(): @@ -101,17 +101,17 @@ def test_arange_requires_grad(): # Test default (False) result_default = shmem.arange(3) assert not result_default.requires_grad - assert shmem._Iris__on_symmetric_heap(result_default) + assert shmem._on_symmetric_heap(result_default) # Test True result_true = shmem.arange(3, dtype=torch.float32, requires_grad=True) assert result_true.requires_grad - assert shmem._Iris__on_symmetric_heap(result_true) + assert shmem._on_symmetric_heap(result_true) # Test False explicitly result_false = shmem.arange(3, requires_grad=False) assert not result_false.requires_grad - assert shmem._Iris__on_symmetric_heap(result_false) + assert shmem._on_symmetric_heap(result_false) def test_arange_out_parameter(): @@ -119,20 +119,20 @@ def test_arange_out_parameter(): shmem = iris.iris(1 << 20) # Test with out parameter - out_tensor = shmem._Iris__allocate(3, torch.int64) + out_tensor = shmem._allocate(3, torch.int64) result = shmem.arange(3, out=out_tensor) # Should return the same tensor object assert result is out_tensor assert torch.all(result == torch.tensor([0, 1, 2], device=result.device)) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test with different dtype out tensor - out_tensor_float = shmem._Iris__allocate(3, torch.float32) + out_tensor_float = shmem._allocate(3, torch.float32) result_float = shmem.arange(3, dtype=torch.float32, out=out_tensor_float) assert result_float is out_tensor_float assert result_float.dtype == torch.float32 - assert shmem._Iris__on_symmetric_heap(result_float) + assert shmem._on_symmetric_heap(result_float) def test_arange_error_handling(): @@ -164,7 +164,7 @@ def test_arange_edge_cases(): assert result_single.shape == (1,) assert result_single.numel() == 1 assert result_single[0] == 1 - assert 
shmem._Iris__on_symmetric_heap(result_single) + assert shmem._on_symmetric_heap(result_single) # Test large tensor result_large = shmem.arange(1000) @@ -172,14 +172,14 @@ def test_arange_edge_cases(): assert result_large.numel() == 1000 assert result_large[0] == 0 assert result_large[-1] == 999 - assert shmem._Iris__on_symmetric_heap(result_large) + assert shmem._on_symmetric_heap(result_large) # Test floating point precision result_float = shmem.arange(0, 1, 0.1) assert result_float.shape == (10,) assert torch.allclose(result_float[0], torch.tensor(0.0)) assert torch.allclose(result_float[-1], torch.tensor(0.9)) - assert shmem._Iris__on_symmetric_heap(result_float) + assert shmem._on_symmetric_heap(result_float) def test_arange_pytorch_equivalence(): @@ -229,7 +229,7 @@ def test_arange_parameter_combinations(params): # Verify basic properties assert result.dtype == params["dtype"] - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Verify values match PyTorch pytorch_result = torch.arange( @@ -274,7 +274,7 @@ def test_arange_symmetric_heap_verification(arange_args, kwargs): result = shmem.arange(*arange_args, **kwargs) # Verify symmetric heap allocation - assert shmem._Iris__on_symmetric_heap(result), ( + assert shmem._on_symmetric_heap(result), ( f"Tensor {result} with args={arange_args}, kwargs={kwargs} is not on symmetric heap" ) diff --git a/tests/unittests/test_empty.py b/tests/unittests/test_empty.py index e51fb4c2..9507e71d 100644 --- a/tests/unittests/test_empty.py +++ b/tests/unittests/test_empty.py @@ -41,7 +41,7 @@ def test_empty_basic(dtype, size): assert result.dtype == dtype # Verify tensor is on symmetric heap - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Note: We don't check the values since they are uninitialized @@ -53,7 +53,7 @@ def test_empty_default_dtype(): result = shmem.empty(2, 3) expected_dtype = torch.get_default_dtype() assert result.dtype == 
expected_dtype - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) @pytest.mark.parametrize( @@ -71,7 +71,7 @@ def test_empty_requires_grad(requires_grad): # Verify requires_grad is set assert result.requires_grad == requires_grad - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_empty_device_handling(): @@ -80,23 +80,23 @@ def test_empty_device_handling(): # Test default behavior (should use Iris device) result = shmem.empty(3, 3) assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test explicit device result = shmem.empty(3, 3, device=shmem.device) assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that "cuda" shorthand works (should use current CUDA device) if shmem.device.startswith("cuda:"): result = shmem.empty(3, 3, device="cuda") assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test None device defaults to Iris device result = shmem.empty(3, 3, device=None) assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that different device throws error different_device = "cpu" # CPU is always different from CUDA @@ -117,7 +117,7 @@ def test_empty_layout_handling(): # Test with strided layout (default) result = shmem.empty(2, 4, layout=torch.strided) assert result.layout == torch.strided - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that unsupported layout throws error with pytest.raises(ValueError): @@ -128,20 +128,20 @@ def test_empty_out_parameter(): shmem = iris.iris(1 << 20) # Test with out parameter - out_tensor = shmem._Iris__allocate(6, 
torch.float32) + out_tensor = shmem._allocate(6, torch.float32) result = shmem.empty(2, 3, out=out_tensor) # Should share the same underlying data (same data_ptr) assert result.data_ptr() == out_tensor.data_ptr() assert result.shape == (2, 3) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test with different dtype out tensor - out_tensor_int = shmem._Iris__allocate(6, torch.int32) + out_tensor_int = shmem._allocate(6, torch.int32) result_int = shmem.empty(2, 3, dtype=torch.int32, out=out_tensor_int) assert result_int.data_ptr() == out_tensor_int.data_ptr() assert result_int.dtype == torch.int32 - assert shmem._Iris__on_symmetric_heap(result_int) + assert shmem._on_symmetric_heap(result_int) def test_empty_size_variations(): @@ -150,22 +150,22 @@ def test_empty_size_variations(): # Test single dimension result1 = shmem.empty(5) assert result1.shape == (5,) - assert shmem._Iris__on_symmetric_heap(result1) + assert shmem._on_symmetric_heap(result1) # Test multiple dimensions result2 = shmem.empty(2, 3, 4) assert result2.shape == (2, 3, 4) - assert shmem._Iris__on_symmetric_heap(result2) + assert shmem._on_symmetric_heap(result2) # Test with tuple as single argument result3 = shmem.empty((3, 4)) assert result3.shape == (3, 4) - assert shmem._Iris__on_symmetric_heap(result3) + assert shmem._on_symmetric_heap(result3) # Test with list as single argument result4 = shmem.empty([2, 5]) assert result4.shape == (2, 5) - assert shmem._Iris__on_symmetric_heap(result4) + assert shmem._on_symmetric_heap(result4) def test_empty_edge_cases(): @@ -175,25 +175,25 @@ def test_empty_edge_cases(): empty_result = shmem.empty(0) assert empty_result.shape == (0,) assert empty_result.numel() == 0 - assert shmem._Iris__on_symmetric_heap(empty_result) + assert shmem._on_symmetric_heap(empty_result) # Single element tensor single_result = shmem.empty(1) assert single_result.shape == (1,) assert single_result.numel() == 1 - assert 
shmem._Iris__on_symmetric_heap(single_result) + assert shmem._on_symmetric_heap(single_result) # Large tensor large_result = shmem.empty(100, 100) assert large_result.shape == (100, 100) assert large_result.numel() == 10000 - assert shmem._Iris__on_symmetric_heap(large_result) + assert shmem._on_symmetric_heap(large_result) # Zero-dimensional tensor (scalar) scalar_result = shmem.empty(()) assert scalar_result.shape == () assert scalar_result.numel() == 1 - assert shmem._Iris__on_symmetric_heap(scalar_result) + assert shmem._on_symmetric_heap(scalar_result) def test_empty_pytorch_equivalence(): @@ -243,7 +243,7 @@ def test_empty_parameter_combinations(params): # Verify basic functionality assert result.shape == (3, 3) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Verify dtype if specified if "dtype" in params: @@ -278,7 +278,7 @@ def test_empty_symmetric_heap_shapes_dtypes(size, dtype): result = shmem.empty(*size, dtype=dtype) # Verify tensor is on symmetric heap - assert shmem._Iris__on_symmetric_heap(result), f"Tensor with size {size}, dtype {dtype} is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), f"Tensor with size {size}, dtype {dtype} is NOT on symmetric heap!" # Also verify basic functionality assert result.shape == size @@ -291,7 +291,7 @@ def test_empty_symmetric_heap_dtype_override(dtype): shmem = iris.iris(1 << 20) result = shmem.empty(3, 3, dtype=dtype) - assert shmem._Iris__on_symmetric_heap(result), f"Tensor with dtype {dtype} is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), f"Tensor with dtype {dtype} is NOT on symmetric heap!" assert result.dtype == dtype @@ -301,20 +301,20 @@ def test_empty_symmetric_heap_other_params(): # Test with requires_grad result = shmem.empty(3, 3, dtype=torch.float32, requires_grad=True) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with requires_grad=True is NOT on symmetric heap!" 
+ assert shmem._on_symmetric_heap(result), "Tensor with requires_grad=True is NOT on symmetric heap!" # Test with device override result = shmem.empty(3, 3, device=shmem.device) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with device override is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with device override is NOT on symmetric heap!" # Test with layout override (only strided is supported) result = shmem.empty(3, 3, layout=torch.strided) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with layout override is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with layout override is NOT on symmetric heap!" # Test with out parameter - out_tensor = shmem._Iris__allocate(9, torch.float32) + out_tensor = shmem._allocate(9, torch.float32) result = shmem.empty(3, 3, out=out_tensor) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with out parameter is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with out parameter is NOT on symmetric heap!" 
def test_empty_invalid_output_tensor(): @@ -322,12 +322,12 @@ def test_empty_invalid_output_tensor(): shmem = iris.iris(1 << 20) # Test with wrong size output tensor - wrong_size_tensor = shmem._Iris__allocate(4, torch.float32) # Wrong size for (3, 3) + wrong_size_tensor = shmem._allocate(4, torch.float32) # Wrong size for (3, 3) with pytest.raises(RuntimeError): shmem.empty(3, 3, out=wrong_size_tensor) # Test with wrong dtype output tensor - wrong_dtype_tensor = shmem._Iris__allocate(9, torch.int32) # Wrong dtype + wrong_dtype_tensor = shmem._allocate(9, torch.int32) # Wrong dtype with pytest.raises(RuntimeError): shmem.empty(3, 3, dtype=torch.float32, out=wrong_dtype_tensor) @@ -393,17 +393,17 @@ def test_empty_memory_format(): # Test contiguous format (default) result_contig = shmem.empty(2, 3, 4, memory_format=torch.contiguous_format) assert result_contig.is_contiguous() - assert shmem._Iris__on_symmetric_heap(result_contig) + assert shmem._on_symmetric_heap(result_contig) # Test channels_last format (should work for 4D tensors) result_cl = shmem.empty(2, 3, 4, 5, memory_format=torch.channels_last) assert result_cl.shape == (2, 3, 4, 5) - assert shmem._Iris__on_symmetric_heap(result_cl) + assert shmem._on_symmetric_heap(result_cl) # Test channels_last_3d format (should work for 5D tensors) result_cl3d = shmem.empty(2, 3, 4, 5, 6, memory_format=torch.channels_last_3d) assert result_cl3d.shape == (2, 3, 4, 5, 6) - assert shmem._Iris__on_symmetric_heap(result_cl3d) + assert shmem._on_symmetric_heap(result_cl3d) def test_empty_pin_memory(): @@ -413,7 +413,7 @@ def test_empty_pin_memory(): # Test with pin_memory=True (should work but be ignored since Iris tensors are on GPU) result = shmem.empty(2, 3, pin_memory=True) assert result.shape == (2, 3) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Note: pin_memory is ignored for GPU tensors, so we just verify it doesn't cause errors @@ -424,7 +424,7 @@ def 
test_empty_deterministic_behavior(): # Test that empty works regardless of deterministic settings result = shmem.empty(2, 3) assert result.shape == (2, 3) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Note: The actual deterministic behavior (filling with NaN/max values) # is handled by PyTorch internally, so we just verify our function works diff --git a/tests/unittests/test_full.py b/tests/unittests/test_full.py index a42d4ddb..399151c5 100644 --- a/tests/unittests/test_full.py +++ b/tests/unittests/test_full.py @@ -44,7 +44,7 @@ def test_full_basic(fill_value, size): assert torch.all(result == fill_value) # Verify tensor is on symmetric heap - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_full_dtype_inference(): @@ -54,19 +54,19 @@ def test_full_dtype_inference(): result_int = shmem.full((2, 3), 42) assert result_int.dtype == torch.int64 assert torch.all(result_int == 42) - assert shmem._Iris__on_symmetric_heap(result_int) + assert shmem._on_symmetric_heap(result_int) # Test float fill_value (should infer default float dtype) result_float = shmem.full((2, 3), 3.141592) assert result_float.dtype == torch.get_default_dtype() assert torch.allclose(result_float, torch.tensor(3.141592)) - assert shmem._Iris__on_symmetric_heap(result_float) + assert shmem._on_symmetric_heap(result_float) # Test explicit dtype override result_explicit = shmem.full((2, 3), 42, dtype=torch.float32) assert result_explicit.dtype == torch.float32 assert torch.all(result_explicit == 42) - assert shmem._Iris__on_symmetric_heap(result_explicit) + assert shmem._on_symmetric_heap(result_explicit) @pytest.mark.parametrize( @@ -85,7 +85,7 @@ def test_full_requires_grad(requires_grad): # Verify requires_grad is set assert result.requires_grad == requires_grad assert torch.all(result == 1.5) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def 
test_full_device_handling(): @@ -95,26 +95,26 @@ def test_full_device_handling(): result = shmem.full((3, 3), 2.5) assert str(result.device) == str(shmem.get_device()) assert torch.all(result == 2.5) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test explicit device result = shmem.full((3, 3), 2.5, device=shmem.device) assert str(result.device) == str(shmem.get_device()) assert torch.all(result == 2.5) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that "cuda" shorthand works (should use current CUDA device) if shmem.device.startswith("cuda:"): result = shmem.full((3, 3), 2.5, device="cuda") assert str(result.device) == str(shmem.get_device()) assert torch.all(result == 2.5) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test None device defaults to Iris device result = shmem.full((3, 3), 2.5, device=None) assert str(result.device) == str(shmem.get_device()) assert torch.all(result == 2.5) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that different device throws error different_device = "cpu" # CPU is always different from CUDA @@ -136,7 +136,7 @@ def test_full_layout_handling(): result = shmem.full((2, 4), 1.0, layout=torch.strided) assert result.layout == torch.strided assert torch.all(result == 1.0) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that unsupported layout throws error with pytest.raises(ValueError): @@ -147,22 +147,22 @@ def test_full_out_parameter(): shmem = iris.iris(1 << 20) # Test with out parameter - out_tensor = shmem._Iris__allocate(6, torch.float32) + out_tensor = shmem._allocate(6, torch.float32) result = shmem.full((2, 3), 3.141592, out=out_tensor) # Should share the same underlying data (same data_ptr) assert result.data_ptr() == out_tensor.data_ptr() assert torch.allclose(result, 
torch.tensor(3.141592)) assert result.shape == (2, 3) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test with different dtype out tensor - out_tensor_int = shmem._Iris__allocate(6, torch.int32) + out_tensor_int = shmem._allocate(6, torch.int32) result_int = shmem.full((2, 3), 42, dtype=torch.int32, out=out_tensor_int) assert result_int.data_ptr() == out_tensor_int.data_ptr() assert result_int.dtype == torch.int32 assert torch.all(result_int == 42) - assert shmem._Iris__on_symmetric_heap(result_int) + assert shmem._on_symmetric_heap(result_int) def test_full_size_variations(): @@ -172,25 +172,25 @@ def test_full_size_variations(): result1 = shmem.full((5,), 2.0) assert result1.shape == (5,) assert torch.all(result1 == 2.0) - assert shmem._Iris__on_symmetric_heap(result1) + assert shmem._on_symmetric_heap(result1) # Test multiple dimensions result2 = shmem.full((2, 3, 4), 1.5) assert result2.shape == (2, 3, 4) assert torch.all(result2 == 1.5) - assert shmem._Iris__on_symmetric_heap(result2) + assert shmem._on_symmetric_heap(result2) # Test with tuple as single argument result3 = shmem.full((3, 4), 0.5) assert result3.shape == (3, 4) assert torch.all(result3 == 0.5) - assert shmem._Iris__on_symmetric_heap(result3) + assert shmem._on_symmetric_heap(result3) # Test with list as single argument result4 = shmem.full([2, 5], -1.0) assert result4.shape == (2, 5) assert torch.all(result4 == -1.0) - assert shmem._Iris__on_symmetric_heap(result4) + assert shmem._on_symmetric_heap(result4) def test_full_edge_cases(): @@ -200,28 +200,28 @@ def test_full_edge_cases(): empty_result = shmem.full((0,), 1.0) assert empty_result.shape == (0,) assert empty_result.numel() == 0 - assert shmem._Iris__on_symmetric_heap(empty_result) + assert shmem._on_symmetric_heap(empty_result) # Single element tensor single_result = shmem.full((1,), 5.0) assert single_result.shape == (1,) assert single_result.numel() == 1 assert single_result[0] == 5.0 - 
assert shmem._Iris__on_symmetric_heap(single_result) + assert shmem._on_symmetric_heap(single_result) # Large tensor large_result = shmem.full((100, 100), 0.1) assert large_result.shape == (100, 100) assert large_result.numel() == 10000 assert torch.all(large_result == 0.1) - assert shmem._Iris__on_symmetric_heap(large_result) + assert shmem._on_symmetric_heap(large_result) # Zero-dimensional tensor (scalar) scalar_result = shmem.full((), 2.718) assert scalar_result.shape == () assert scalar_result.numel() == 1 assert torch.allclose(scalar_result, torch.tensor(2.718)) - assert shmem._Iris__on_symmetric_heap(scalar_result) + assert shmem._on_symmetric_heap(scalar_result) def test_full_pytorch_equivalence(): @@ -280,7 +280,7 @@ def test_full_parameter_combinations(params): else: # For integer dtypes, the fill value gets truncated assert torch.all(result == 2) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Verify dtype if specified if "dtype" in params: @@ -315,7 +315,7 @@ def test_full_symmetric_heap_shapes_dtypes(size, fill_value, dtype): result = shmem.full(size, fill_value, dtype=dtype) # Verify tensor is on symmetric heap - assert shmem._Iris__on_symmetric_heap(result), ( + assert shmem._on_symmetric_heap(result), ( f"Tensor with size {size}, fill_value {fill_value}, dtype {dtype} is NOT on symmetric heap!" ) @@ -331,7 +331,7 @@ def test_full_symmetric_heap_dtype_override(dtype): shmem = iris.iris(1 << 20) result = shmem.full((3, 3), 1.5, dtype=dtype) - assert shmem._Iris__on_symmetric_heap(result), f"Tensor with dtype {dtype} is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), f"Tensor with dtype {dtype} is NOT on symmetric heap!" 
assert result.dtype == dtype @@ -341,20 +341,20 @@ def test_full_symmetric_heap_other_params(): # Test with requires_grad result = shmem.full((3, 3), 1.5, dtype=torch.float32, requires_grad=True) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with requires_grad=True is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with requires_grad=True is NOT on symmetric heap!" # Test with device override result = shmem.full((3, 3), 1.5, device=shmem.device) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with device override is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with device override is NOT on symmetric heap!" # Test with layout override (only strided is supported) result = shmem.full((3, 3), 1.5, layout=torch.strided) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with layout override is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with layout override is NOT on symmetric heap!" # Test with out parameter - out_tensor = shmem._Iris__allocate(9, torch.float32) + out_tensor = shmem._allocate(9, torch.float32) result = shmem.full((3, 3), 1.5, out=out_tensor) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with out parameter is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with out parameter is NOT on symmetric heap!" 
def test_full_invalid_output_tensor(): @@ -362,12 +362,12 @@ def test_full_invalid_output_tensor(): shmem = iris.iris(1 << 20) # Test with wrong size output tensor - wrong_size_tensor = shmem._Iris__allocate(4, torch.float32) # Wrong size for (3, 3) + wrong_size_tensor = shmem._allocate(4, torch.float32) # Wrong size for (3, 3) with pytest.raises(RuntimeError): shmem.full((3, 3), 1.5, out=wrong_size_tensor) # Test with wrong dtype output tensor - wrong_dtype_tensor = shmem._Iris__allocate(9, torch.int32) # Wrong dtype + wrong_dtype_tensor = shmem._allocate(9, torch.int32) # Wrong dtype with pytest.raises(RuntimeError): shmem.full((3, 3), 1.5, dtype=torch.float32, out=wrong_dtype_tensor) @@ -412,7 +412,7 @@ def test_full_examples(): expected = torch.tensor([[3.141592, 3.141592, 3.141592], [3.141592, 3.141592, 3.141592]], device=result.device) assert result.shape == (2, 3) assert torch.allclose(result, expected) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_full_different_fill_values(): @@ -436,7 +436,7 @@ def test_full_different_fill_values(): result = shmem.full((2, 2), fill_value) assert result.dtype == expected_dtype assert torch.allclose(result, torch.tensor(fill_value, dtype=expected_dtype)) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_full_dtype_override(): @@ -447,10 +447,10 @@ def test_full_dtype_override(): result = shmem.full((2, 2), 42, dtype=torch.float32) assert result.dtype == torch.float32 assert torch.allclose(result, torch.tensor(42.0, dtype=torch.float32)) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Float fill_value with int dtype result = shmem.full((2, 2), 3.14, dtype=torch.int32) assert result.dtype == torch.int32 assert torch.all(result == 3) # Truncated to int - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) diff --git 
a/tests/unittests/test_linspace.py b/tests/unittests/test_linspace.py index 02d26b24..7f01fc1e 100644 --- a/tests/unittests/test_linspace.py +++ b/tests/unittests/test_linspace.py @@ -42,7 +42,7 @@ def test_linspace_basic(dtype, start, end, steps): assert torch.allclose(result[-1], torch.tensor(end, dtype=dtype)) # Verify tensor is on symmetric heap - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_linspace_default_dtype(): @@ -52,7 +52,7 @@ def test_linspace_default_dtype(): result = shmem.linspace(0.0, 1.0, 5) expected_dtype = torch.get_default_dtype() assert result.dtype == expected_dtype - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) @pytest.mark.parametrize( @@ -70,7 +70,7 @@ def test_linspace_requires_grad(requires_grad): # Verify requires_grad is set assert result.requires_grad == requires_grad - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_linspace_device_handling(): @@ -79,23 +79,23 @@ def test_linspace_device_handling(): # Test default behavior (should use Iris device) result = shmem.linspace(0.0, 1.0, 5) assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test explicit device result = shmem.linspace(0.0, 1.0, 5, device=shmem.device) assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that "cuda" shorthand works (should use current CUDA device) if shmem.device.startswith("cuda:"): result = shmem.linspace(0.0, 1.0, 5, device="cuda") assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test None device defaults to Iris device result = shmem.linspace(0.0, 1.0, 5, device=None) assert str(result.device) == str(shmem.get_device()) - assert 
shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that different device throws error different_device = "cpu" # CPU is always different from CUDA @@ -116,7 +116,7 @@ def test_linspace_layout_handling(): # Test with strided layout (default) result = shmem.linspace(0.0, 1.0, 5, layout=torch.strided) assert result.layout == torch.strided - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that unsupported layout throws error with pytest.raises(ValueError): @@ -127,7 +127,7 @@ def test_linspace_out_parameter(): shmem = iris.iris(1 << 20) # Test with out parameter - out_tensor = shmem._Iris__allocate(5, torch.float32) + out_tensor = shmem._allocate(5, torch.float32) result = shmem.linspace(0.0, 1.0, 5, out=out_tensor) # Should share the same underlying data (same data_ptr) @@ -135,14 +135,14 @@ def test_linspace_out_parameter(): assert result.shape == (5,) assert torch.allclose(result[0], torch.tensor(0.0)) assert torch.allclose(result[-1], torch.tensor(1.0)) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test with different dtype out tensor - out_tensor_float64 = shmem._Iris__allocate(5, torch.float64) + out_tensor_float64 = shmem._allocate(5, torch.float64) result_float64 = shmem.linspace(0.0, 1.0, 5, dtype=torch.float64, out=out_tensor_float64) assert result_float64.data_ptr() == out_tensor_float64.data_ptr() assert result_float64.dtype == torch.float64 - assert shmem._Iris__on_symmetric_heap(result_float64) + assert shmem._on_symmetric_heap(result_float64) def test_linspace_steps_variations(): @@ -152,24 +152,24 @@ def test_linspace_steps_variations(): result1 = shmem.linspace(0.0, 1.0, 1) assert result1.shape == (1,) assert torch.allclose(result1[0], torch.tensor(0.0)) - assert shmem._Iris__on_symmetric_heap(result1) + assert shmem._on_symmetric_heap(result1) # Test multiple steps result2 = shmem.linspace(0.0, 1.0, 10) assert result2.shape 
== (10,) assert torch.allclose(result2[0], torch.tensor(0.0)) assert torch.allclose(result2[-1], torch.tensor(1.0)) - assert shmem._Iris__on_symmetric_heap(result2) + assert shmem._on_symmetric_heap(result2) # Test with tuple as steps argument result3 = shmem.linspace(0.0, 1.0, (5,)) assert result3.shape == (5,) - assert shmem._Iris__on_symmetric_heap(result3) + assert shmem._on_symmetric_heap(result3) # Test with list as steps argument result4 = shmem.linspace(0.0, 1.0, [5]) assert result4.shape == (5,) - assert shmem._Iris__on_symmetric_heap(result4) + assert shmem._on_symmetric_heap(result4) def test_linspace_edge_cases(): @@ -179,28 +179,28 @@ def test_linspace_edge_cases(): single_result = shmem.linspace(5.0, 5.0, 1) assert single_result.shape == (1,) assert torch.allclose(single_result[0], torch.tensor(5.0)) - assert shmem._Iris__on_symmetric_heap(single_result) + assert shmem._on_symmetric_heap(single_result) # Two steps two_result = shmem.linspace(0.0, 1.0, 2) assert two_result.shape == (2,) assert torch.allclose(two_result[0], torch.tensor(0.0)) assert torch.allclose(two_result[1], torch.tensor(1.0)) - assert shmem._Iris__on_symmetric_heap(two_result) + assert shmem._on_symmetric_heap(two_result) # Large number of steps large_result = shmem.linspace(0.0, 100.0, 1000) assert large_result.shape == (1000,) assert torch.allclose(large_result[0], torch.tensor(0.0)) assert torch.allclose(large_result[-1], torch.tensor(100.0)) - assert shmem._Iris__on_symmetric_heap(large_result) + assert shmem._on_symmetric_heap(large_result) # Negative range neg_result = shmem.linspace(-10.0, -5.0, 6) assert neg_result.shape == (6,) assert torch.allclose(neg_result[0], torch.tensor(-10.0)) assert torch.allclose(neg_result[-1], torch.tensor(-5.0)) - assert shmem._Iris__on_symmetric_heap(neg_result) + assert shmem._on_symmetric_heap(neg_result) def test_linspace_pytorch_equivalence(): @@ -252,7 +252,7 @@ def test_linspace_parameter_combinations(params): assert result.shape == 
(5,) assert torch.allclose(result[0], torch.tensor(0.0, dtype=result.dtype)) assert torch.allclose(result[-1], torch.tensor(1.0, dtype=result.dtype)) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Verify dtype if specified if "dtype" in params: @@ -285,7 +285,7 @@ def test_linspace_symmetric_heap_shapes_dtypes(start, end, steps, dtype): result = shmem.linspace(start, end, steps, dtype=dtype) # Verify tensor is on symmetric heap - assert shmem._Iris__on_symmetric_heap(result), ( + assert shmem._on_symmetric_heap(result), ( f"Tensor with start={start}, end={end}, steps={steps}, dtype={dtype} is NOT on symmetric heap!" ) @@ -302,7 +302,7 @@ def test_linspace_symmetric_heap_dtype_override(dtype): shmem = iris.iris(1 << 20) result = shmem.linspace(0.0, 1.0, 5, dtype=dtype) - assert shmem._Iris__on_symmetric_heap(result), f"Tensor with dtype {dtype} is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), f"Tensor with dtype {dtype} is NOT on symmetric heap!" assert result.dtype == dtype @@ -312,20 +312,20 @@ def test_linspace_symmetric_heap_other_params(): # Test with requires_grad result = shmem.linspace(0.0, 1.0, 5, dtype=torch.float32, requires_grad=True) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with requires_grad=True is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with requires_grad=True is NOT on symmetric heap!" # Test with device override result = shmem.linspace(0.0, 1.0, 5, device=shmem.device) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with device override is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with device override is NOT on symmetric heap!" # Test with layout override (only strided is supported) result = shmem.linspace(0.0, 1.0, 5, layout=torch.strided) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with layout override is NOT on symmetric heap!" 
+ assert shmem._on_symmetric_heap(result), "Tensor with layout override is NOT on symmetric heap!" # Test with out parameter - out_tensor = shmem._Iris__allocate(5, torch.float32) + out_tensor = shmem._allocate(5, torch.float32) result = shmem.linspace(0.0, 1.0, 5, out=out_tensor) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with out parameter is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with out parameter is NOT on symmetric heap!" def test_linspace_invalid_output_tensor(): @@ -333,12 +333,12 @@ def test_linspace_invalid_output_tensor(): shmem = iris.iris(1 << 20) # Test with wrong size output tensor - wrong_size_tensor = shmem._Iris__allocate(3, torch.float32) # Wrong size for 5 steps + wrong_size_tensor = shmem._allocate(3, torch.float32) # Wrong size for 5 steps with pytest.raises(RuntimeError): shmem.linspace(0.0, 1.0, 5, out=wrong_size_tensor) # Test with wrong dtype output tensor - wrong_dtype_tensor = shmem._Iris__allocate(5, torch.int32) # Wrong dtype + wrong_dtype_tensor = shmem._allocate(5, torch.int32) # Wrong dtype with pytest.raises(RuntimeError): shmem.linspace(0.0, 1.0, 5, dtype=torch.float32, out=wrong_dtype_tensor) @@ -407,12 +407,12 @@ def test_linspace_complex_numbers(): assert result.dtype == torch.complex64 assert torch.allclose(result[0], torch.tensor(0.0 + 0.0j, dtype=torch.complex64)) assert torch.allclose(result[-1], torch.tensor(1.0 + 1.0j, dtype=torch.complex64)) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test with complex dtype inference result = shmem.linspace(0.0 + 0.0j, 1.0 + 1.0j, 5) assert result.dtype == torch.complex64 # Should infer complex dtype - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_linspace_tensor_inputs(): @@ -427,7 +427,7 @@ def test_linspace_tensor_inputs(): assert result.shape == (5,) assert torch.allclose(result[0], torch.tensor(0.0)) assert torch.allclose(result[-1], 
torch.tensor(1.0)) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test with complex tensor inputs start_complex = torch.tensor(0.0 + 0.0j, device="cuda") @@ -436,7 +436,7 @@ def test_linspace_tensor_inputs(): result_complex = shmem.linspace(start_complex, end_complex, 5) assert result_complex.shape == (5,) assert result_complex.dtype == torch.complex64 - assert shmem._Iris__on_symmetric_heap(result_complex) + assert shmem._on_symmetric_heap(result_complex) def test_linspace_accuracy(): @@ -473,4 +473,4 @@ def test_linspace_deterministic_behavior(): assert result.shape == (5,) assert torch.allclose(result[0], torch.tensor(0.0)) assert torch.allclose(result[-1], torch.tensor(1.0)) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) diff --git a/tests/unittests/test_ones.py b/tests/unittests/test_ones.py index e70c63f8..1c4edfe3 100644 --- a/tests/unittests/test_ones.py +++ b/tests/unittests/test_ones.py @@ -44,7 +44,7 @@ def test_ones_basic(dtype, size): assert torch.all(result == 1) # Verify tensor is on symmetric heap - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_ones_default_dtype(): @@ -55,7 +55,7 @@ def test_ones_default_dtype(): expected_dtype = torch.get_default_dtype() assert result.dtype == expected_dtype assert torch.all(result == 1) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) @pytest.mark.parametrize( @@ -74,7 +74,7 @@ def test_ones_requires_grad(requires_grad): # Verify requires_grad is set assert result.requires_grad == requires_grad assert torch.all(result == 1) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_ones_device_handling(): @@ -84,26 +84,26 @@ def test_ones_device_handling(): result = shmem.ones(3, 3) assert str(result.device) == str(shmem.get_device()) assert torch.all(result == 1) - assert 
shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test explicit device result = shmem.ones(3, 3, device=shmem.device) assert str(result.device) == str(shmem.get_device()) assert torch.all(result == 1) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that "cuda" shorthand works (should use current CUDA device) if shmem.device.startswith("cuda:"): result = shmem.ones(3, 3, device="cuda") assert str(result.device) == str(shmem.get_device()) assert torch.all(result == 1) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test None device defaults to Iris device result = shmem.ones(3, 3, device=None) assert str(result.device) == str(shmem.get_device()) assert torch.all(result == 1) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that different device throws error different_device = "cpu" # CPU is always different from CUDA @@ -125,7 +125,7 @@ def test_ones_layout_handling(): result = shmem.ones(2, 4, layout=torch.strided) assert result.layout == torch.strided assert torch.all(result == 1) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that unsupported layout throws error with pytest.raises(ValueError): @@ -136,22 +136,22 @@ def test_ones_out_parameter(): shmem = iris.iris(1 << 20) # Test with out parameter - out_tensor = shmem._Iris__allocate(6, torch.float32) + out_tensor = shmem._allocate(6, torch.float32) result = shmem.ones(2, 3, out=out_tensor) # Should share the same underlying data (same data_ptr) assert result.data_ptr() == out_tensor.data_ptr() assert torch.all(result == 1) assert result.shape == (2, 3) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test with different dtype out tensor - out_tensor_int = shmem._Iris__allocate(6, torch.int32) + out_tensor_int = shmem._allocate(6, torch.int32) 
result_int = shmem.ones(2, 3, dtype=torch.int32, out=out_tensor_int) assert result_int.data_ptr() == out_tensor_int.data_ptr() assert result_int.dtype == torch.int32 assert torch.all(result_int == 1) - assert shmem._Iris__on_symmetric_heap(result_int) + assert shmem._on_symmetric_heap(result_int) def test_ones_size_variations(): @@ -161,25 +161,25 @@ def test_ones_size_variations(): result1 = shmem.ones(5) assert result1.shape == (5,) assert torch.all(result1 == 1) - assert shmem._Iris__on_symmetric_heap(result1) + assert shmem._on_symmetric_heap(result1) # Test multiple dimensions result2 = shmem.ones(2, 3, 4) assert result2.shape == (2, 3, 4) assert torch.all(result2 == 1) - assert shmem._Iris__on_symmetric_heap(result2) + assert shmem._on_symmetric_heap(result2) # Test with tuple as single argument result3 = shmem.ones((3, 4)) assert result3.shape == (3, 4) assert torch.all(result3 == 1) - assert shmem._Iris__on_symmetric_heap(result3) + assert shmem._on_symmetric_heap(result3) # Test with list as single argument result4 = shmem.ones([2, 5]) assert result4.shape == (2, 5) assert torch.all(result4 == 1) - assert shmem._Iris__on_symmetric_heap(result4) + assert shmem._on_symmetric_heap(result4) def test_ones_edge_cases(): @@ -189,28 +189,28 @@ def test_ones_edge_cases(): empty_result = shmem.ones(0) assert empty_result.shape == (0,) assert empty_result.numel() == 0 - assert shmem._Iris__on_symmetric_heap(empty_result) + assert shmem._on_symmetric_heap(empty_result) # Single element tensor single_result = shmem.ones(1) assert single_result.shape == (1,) assert single_result.numel() == 1 assert single_result[0] == 1 - assert shmem._Iris__on_symmetric_heap(single_result) + assert shmem._on_symmetric_heap(single_result) # Large tensor large_result = shmem.ones(100, 100) assert large_result.shape == (100, 100) assert large_result.numel() == 10000 assert torch.all(large_result == 1) - assert shmem._Iris__on_symmetric_heap(large_result) + assert 
shmem._on_symmetric_heap(large_result) # Zero-dimensional tensor (scalar) scalar_result = shmem.ones(()) assert scalar_result.shape == () assert scalar_result.numel() == 1 assert scalar_result.item() == 1 - assert shmem._Iris__on_symmetric_heap(scalar_result) + assert shmem._on_symmetric_heap(scalar_result) def test_ones_pytorch_equivalence(): @@ -262,7 +262,7 @@ def test_ones_parameter_combinations(params): # Verify basic functionality assert result.shape == (3, 3) assert torch.all(result == 1) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Verify dtype if specified if "dtype" in params: @@ -297,7 +297,7 @@ def test_ones_symmetric_heap_shapes_dtypes(size, dtype): result = shmem.ones(*size, dtype=dtype) # Verify tensor is on symmetric heap - assert shmem._Iris__on_symmetric_heap(result), f"Tensor with size {size}, dtype {dtype} is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), f"Tensor with size {size}, dtype {dtype} is NOT on symmetric heap!" # Also verify basic functionality assert result.shape == size @@ -311,7 +311,7 @@ def test_ones_symmetric_heap_dtype_override(dtype): shmem = iris.iris(1 << 20) result = shmem.ones(3, 3, dtype=dtype) - assert shmem._Iris__on_symmetric_heap(result), f"Tensor with dtype {dtype} is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), f"Tensor with dtype {dtype} is NOT on symmetric heap!" assert result.dtype == dtype @@ -321,20 +321,20 @@ def test_ones_symmetric_heap_other_params(): # Test with requires_grad result = shmem.ones(3, 3, dtype=torch.float32, requires_grad=True) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with requires_grad=True is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with requires_grad=True is NOT on symmetric heap!" 
# Test with device override result = shmem.ones(3, 3, device=shmem.device) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with device override is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with device override is NOT on symmetric heap!" # Test with layout override (only strided is supported) result = shmem.ones(3, 3, layout=torch.strided) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with layout override is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with layout override is NOT on symmetric heap!" # Test with out parameter - out_tensor = shmem._Iris__allocate(9, torch.float32) + out_tensor = shmem._allocate(9, torch.float32) result = shmem.ones(3, 3, out=out_tensor) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with out parameter is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with out parameter is NOT on symmetric heap!" def test_ones_invalid_output_tensor(): @@ -342,12 +342,12 @@ def test_ones_invalid_output_tensor(): shmem = iris.iris(1 << 20) # Test with wrong size output tensor - wrong_size_tensor = shmem._Iris__allocate(4, torch.float32) # Wrong size for (3, 3) + wrong_size_tensor = shmem._allocate(4, torch.float32) # Wrong size for (3, 3) with pytest.raises(RuntimeError): shmem.ones(3, 3, out=wrong_size_tensor) # Test with wrong dtype output tensor - wrong_dtype_tensor = shmem._Iris__allocate(9, torch.int32) # Wrong dtype + wrong_dtype_tensor = shmem._allocate(9, torch.int32) # Wrong dtype with pytest.raises(RuntimeError): shmem.ones(3, 3, dtype=torch.float32, out=wrong_dtype_tensor) @@ -415,11 +415,11 @@ def test_ones_examples(): expected1 = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], device=result1.device) assert result1.shape == (2, 3) assert torch.all(result1 == expected1) - assert shmem._Iris__on_symmetric_heap(result1) + assert shmem._on_symmetric_heap(result1) # Example 2: torch.ones(5) result2 = shmem.ones(5) expected2 = 
torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0], device=result2.device) assert result2.shape == (5,) assert torch.all(result2 == expected2) - assert shmem._Iris__on_symmetric_heap(result2) + assert shmem._on_symmetric_heap(result2) diff --git a/tests/unittests/test_rand.py b/tests/unittests/test_rand.py index 75b6968b..d641baf7 100644 --- a/tests/unittests/test_rand.py +++ b/tests/unittests/test_rand.py @@ -40,7 +40,7 @@ def test_rand_basic(dtype, size): assert torch.all(result < 1) # Verify tensor is on symmetric heap - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_rand_default_dtype(): @@ -50,7 +50,7 @@ def test_rand_default_dtype(): result = shmem.rand(2, 3) expected_dtype = torch.get_default_dtype() assert result.dtype == expected_dtype - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) @pytest.mark.parametrize( @@ -68,7 +68,7 @@ def test_rand_requires_grad(requires_grad): # Verify requires_grad is set assert result.requires_grad == requires_grad - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_rand_device_handling(): @@ -77,23 +77,23 @@ def test_rand_device_handling(): # Test default behavior (should use Iris device) result = shmem.rand(3, 3) assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test explicit device result = shmem.rand(3, 3, device=shmem.device) assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that "cuda" shorthand works (should use current CUDA device) if shmem.device.startswith("cuda:"): result = shmem.rand(3, 3, device="cuda") assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test None device defaults to Iris device result = 
shmem.rand(3, 3, device=None) assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that different device throws error different_device = "cpu" # CPU is always different from CUDA @@ -114,7 +114,7 @@ def test_rand_layout_handling(): # Test with strided layout (default) result = shmem.rand(2, 4, layout=torch.strided) assert result.layout == torch.strided - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that unsupported layout throws error with pytest.raises(ValueError): @@ -125,7 +125,7 @@ def test_rand_out_parameter(): shmem = iris.iris(1 << 20) # Test with out parameter - out_tensor = shmem._Iris__allocate(6, torch.float32) + out_tensor = shmem._allocate(6, torch.float32) result = shmem.rand(2, 3, out=out_tensor) # Should share the same underlying data (same data_ptr) @@ -133,14 +133,14 @@ def test_rand_out_parameter(): assert result.shape == (2, 3) assert torch.all(result >= 0) assert torch.all(result < 1) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test with different dtype out tensor - out_tensor_float64 = shmem._Iris__allocate(6, torch.float64) + out_tensor_float64 = shmem._allocate(6, torch.float64) result_float64 = shmem.rand(2, 3, dtype=torch.float64, out=out_tensor_float64) assert result_float64.data_ptr() == out_tensor_float64.data_ptr() assert result_float64.dtype == torch.float64 - assert shmem._Iris__on_symmetric_heap(result_float64) + assert shmem._on_symmetric_heap(result_float64) def test_rand_size_variations(): @@ -151,28 +151,28 @@ def test_rand_size_variations(): assert result1.shape == (5,) assert torch.all(result1 >= 0) assert torch.all(result1 < 1) - assert shmem._Iris__on_symmetric_heap(result1) + assert shmem._on_symmetric_heap(result1) # Test multiple dimensions result2 = shmem.rand(2, 3, 4) assert result2.shape == (2, 3, 4) assert torch.all(result2 >= 
0) assert torch.all(result2 < 1) - assert shmem._Iris__on_symmetric_heap(result2) + assert shmem._on_symmetric_heap(result2) # Test with tuple as single argument result3 = shmem.rand((3, 4)) assert result3.shape == (3, 4) assert torch.all(result3 >= 0) assert torch.all(result3 < 1) - assert shmem._Iris__on_symmetric_heap(result3) + assert shmem._on_symmetric_heap(result3) # Test with list as single argument result4 = shmem.rand([2, 5]) assert result4.shape == (2, 5) assert torch.all(result4 >= 0) assert torch.all(result4 < 1) - assert shmem._Iris__on_symmetric_heap(result4) + assert shmem._on_symmetric_heap(result4) def test_rand_edge_cases(): @@ -182,7 +182,7 @@ def test_rand_edge_cases(): empty_result = shmem.rand(0) assert empty_result.shape == (0,) assert empty_result.numel() == 0 - assert shmem._Iris__on_symmetric_heap(empty_result) + assert shmem._on_symmetric_heap(empty_result) # Single element tensor single_result = shmem.rand(1) @@ -190,7 +190,7 @@ def test_rand_edge_cases(): assert single_result.numel() == 1 assert torch.all(single_result >= 0) assert torch.all(single_result < 1) - assert shmem._Iris__on_symmetric_heap(single_result) + assert shmem._on_symmetric_heap(single_result) # Large tensor large_result = shmem.rand(50, 50) @@ -198,7 +198,7 @@ def test_rand_edge_cases(): assert large_result.numel() == 2500 assert torch.all(large_result >= 0) assert torch.all(large_result < 1) - assert shmem._Iris__on_symmetric_heap(large_result) + assert shmem._on_symmetric_heap(large_result) # Zero-dimensional tensor (scalar) scalar_result = shmem.rand(()) @@ -206,7 +206,7 @@ def test_rand_edge_cases(): assert scalar_result.numel() == 1 assert torch.all(scalar_result >= 0) assert torch.all(scalar_result < 1) - assert shmem._Iris__on_symmetric_heap(scalar_result) + assert shmem._on_symmetric_heap(scalar_result) def test_rand_pytorch_equivalence(): @@ -255,7 +255,7 @@ def test_rand_parameter_combinations(params): assert result.shape == (3, 3) assert torch.all(result 
>= 0) assert torch.all(result < 1) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Verify dtype if specified if "dtype" in params: @@ -290,7 +290,7 @@ def test_rand_symmetric_heap_shapes_dtypes(size, dtype): result = shmem.rand(*size, dtype=dtype) # Verify tensor is on symmetric heap - assert shmem._Iris__on_symmetric_heap(result), f"Tensor with size {size}, dtype {dtype} is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), f"Tensor with size {size}, dtype {dtype} is NOT on symmetric heap!" # Also verify basic functionality assert result.shape == size @@ -305,7 +305,7 @@ def test_rand_symmetric_heap_dtype_override(dtype): shmem = iris.iris(1 << 20) result = shmem.rand(3, 3, dtype=dtype) - assert shmem._Iris__on_symmetric_heap(result), f"Tensor with dtype {dtype} is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), f"Tensor with dtype {dtype} is NOT on symmetric heap!" assert result.dtype == dtype @@ -315,20 +315,20 @@ def test_rand_symmetric_heap_other_params(): # Test with requires_grad result = shmem.rand(3, 3, dtype=torch.float32, requires_grad=True) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with requires_grad=True is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with requires_grad=True is NOT on symmetric heap!" # Test with device override result = shmem.rand(3, 3, device=shmem.device) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with device override is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with device override is NOT on symmetric heap!" # Test with layout override (only strided is supported) result = shmem.rand(3, 3, layout=torch.strided) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with layout override is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with layout override is NOT on symmetric heap!" 
# Test with out parameter - out_tensor = shmem._Iris__allocate(9, torch.float32) + out_tensor = shmem._allocate(9, torch.float32) result = shmem.rand(3, 3, out=out_tensor) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with out parameter is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with out parameter is NOT on symmetric heap!" def test_rand_invalid_output_tensor(): @@ -336,12 +336,12 @@ def test_rand_invalid_output_tensor(): shmem = iris.iris(1 << 20) # Test with wrong size output tensor - wrong_size_tensor = shmem._Iris__allocate(4, torch.float32) # Wrong size for (3, 3) + wrong_size_tensor = shmem._allocate(4, torch.float32) # Wrong size for (3, 3) with pytest.raises(RuntimeError): shmem.rand(3, 3, out=wrong_size_tensor) # Test with wrong dtype output tensor - wrong_dtype_tensor = shmem._Iris__allocate(9, torch.int32) # Wrong dtype + wrong_dtype_tensor = shmem._allocate(9, torch.int32) # Wrong dtype with pytest.raises(RuntimeError): shmem.rand(3, 3, dtype=torch.float32, out=wrong_dtype_tensor) @@ -411,14 +411,14 @@ def test_rand_generator(): assert result1.shape == (3, 3) assert torch.all(result1 >= 0) assert torch.all(result1 < 1) - assert shmem._Iris__on_symmetric_heap(result1) + assert shmem._on_symmetric_heap(result1) # Test without generator (should still work) result2 = shmem.rand(3, 3) assert result2.shape == (3, 3) assert torch.all(result2 >= 0) assert torch.all(result2 < 1) - assert shmem._Iris__on_symmetric_heap(result2) + assert shmem._on_symmetric_heap(result2) # Test that generator produces reproducible results generator1 = torch.Generator(device="cuda") @@ -442,7 +442,7 @@ def test_rand_pin_memory(): assert result.shape == (2, 3) assert torch.all(result >= 0) assert torch.all(result < 1) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Note: pin_memory is ignored for GPU tensors, so we just verify it doesn't cause errors @@ -468,7 +468,7 @@ def 
test_rand_distribution(): # Should have some values close to 1 assert max_val > 0.9, f"Maximum value {max_val} is too low" - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_rand_deterministic_behavior(): @@ -480,4 +480,4 @@ def test_rand_deterministic_behavior(): assert result.shape == (2, 3) assert torch.all(result >= 0) assert torch.all(result < 1) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) diff --git a/tests/unittests/test_randint.py b/tests/unittests/test_randint.py index a636be38..ec1fffef 100644 --- a/tests/unittests/test_randint.py +++ b/tests/unittests/test_randint.py @@ -42,7 +42,7 @@ def test_randint_basic(dtype, size): assert torch.all(result < 10) # Verify tensor is on symmetric heap - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_randint_default_dtype(): @@ -51,7 +51,7 @@ def test_randint_default_dtype(): # Test with default dtype (should use torch.int64) result = shmem.randint(0, 10, (2, 3)) assert result.dtype == torch.int64 - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) @pytest.mark.parametrize( @@ -69,7 +69,7 @@ def test_randint_requires_grad(requires_grad): # Verify requires_grad is set assert result.requires_grad == requires_grad - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_randint_device_handling(): @@ -78,23 +78,23 @@ def test_randint_device_handling(): # Test default behavior (should use Iris device) result = shmem.randint(0, 10, (3, 3)) assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test explicit device result = shmem.randint(0, 10, (3, 3), device=shmem.device) assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # 
Test that "cuda" shorthand works (should use current CUDA device) if shmem.device.startswith("cuda:"): result = shmem.randint(0, 10, (3, 3), device="cuda") assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test None device defaults to Iris device result = shmem.randint(0, 10, (3, 3), device=None) assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that different device throws error different_device = "cpu" # CPU is always different from CUDA @@ -115,7 +115,7 @@ def test_randint_layout_handling(): # Test with strided layout (default) result = shmem.randint(0, 10, (2, 4), layout=torch.strided) assert result.layout == torch.strided - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that unsupported layout throws error with pytest.raises(ValueError): @@ -126,7 +126,7 @@ def test_randint_out_parameter(): shmem = iris.iris(1 << 20) # Test with out parameter - out_tensor = shmem._Iris__allocate(6, torch.int64) + out_tensor = shmem._allocate(6, torch.int64) result = shmem.randint(0, 10, (2, 3), out=out_tensor) # Should share the same underlying data (same data_ptr) @@ -134,14 +134,14 @@ def test_randint_out_parameter(): assert result.shape == (2, 3) assert torch.all(result >= 0) assert torch.all(result < 10) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test with explicit dtype - out_tensor_int32 = shmem._Iris__allocate(6, torch.int32) + out_tensor_int32 = shmem._allocate(6, torch.int32) result_int32 = shmem.randint(0, 10, (2, 3), dtype=torch.int32, out=out_tensor_int32) assert result_int32.data_ptr() == out_tensor_int32.data_ptr() assert result_int32.dtype == torch.int32 - assert shmem._Iris__on_symmetric_heap(result_int32) + assert shmem._on_symmetric_heap(result_int32) def 
test_randint_size_variations(): @@ -152,28 +152,28 @@ def test_randint_size_variations(): assert result1.shape == (5,) assert torch.all(result1 >= 0) assert torch.all(result1 < 5) - assert shmem._Iris__on_symmetric_heap(result1) + assert shmem._on_symmetric_heap(result1) # Test multiple dimensions result2 = shmem.randint(0, 10, (2, 3, 4)) assert result2.shape == (2, 3, 4) assert torch.all(result2 >= 0) assert torch.all(result2 < 10) - assert shmem._Iris__on_symmetric_heap(result2) + assert shmem._on_symmetric_heap(result2) # Test with tuple as single argument result3 = shmem.randint(0, 10, (3, 4)) assert result3.shape == (3, 4) assert torch.all(result3 >= 0) assert torch.all(result3 < 10) - assert shmem._Iris__on_symmetric_heap(result3) + assert shmem._on_symmetric_heap(result3) # Test with list as single argument result4 = shmem.randint(0, 10, [2, 5]) assert result4.shape == (2, 5) assert torch.all(result4 >= 0) assert torch.all(result4 < 10) - assert shmem._Iris__on_symmetric_heap(result4) + assert shmem._on_symmetric_heap(result4) def test_randint_edge_cases(): @@ -183,7 +183,7 @@ def test_randint_edge_cases(): empty_result = shmem.randint(0, 5, (0,)) assert empty_result.shape == (0,) assert empty_result.numel() == 0 - assert shmem._Iris__on_symmetric_heap(empty_result) + assert shmem._on_symmetric_heap(empty_result) # Single element tensor single_result = shmem.randint(0, 10, (1,)) @@ -191,7 +191,7 @@ def test_randint_edge_cases(): assert single_result.numel() == 1 assert torch.all(single_result >= 0) assert torch.all(single_result < 10) - assert shmem._Iris__on_symmetric_heap(single_result) + assert shmem._on_symmetric_heap(single_result) # Large tensor large_result = shmem.randint(0, 100, (100, 100)) @@ -199,7 +199,7 @@ def test_randint_edge_cases(): assert large_result.numel() == 10000 assert torch.all(large_result >= 0) assert torch.all(large_result < 100) - assert shmem._Iris__on_symmetric_heap(large_result) + assert shmem._on_symmetric_heap(large_result) 
# Zero-dimensional tensor (scalar) scalar_result = shmem.randint(0, 10, ()) @@ -207,7 +207,7 @@ def test_randint_edge_cases(): assert scalar_result.numel() == 1 assert torch.all(scalar_result >= 0) assert torch.all(scalar_result < 10) - assert shmem._Iris__on_symmetric_heap(scalar_result) + assert shmem._on_symmetric_heap(scalar_result) def test_randint_pytorch_equivalence(): @@ -257,7 +257,7 @@ def test_randint_parameter_combinations(params): assert result.shape == (3, 3) assert torch.all(result >= 0) assert torch.all(result < 10) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Verify dtype if specified if "dtype" in params: @@ -292,7 +292,7 @@ def test_randint_symmetric_heap_shapes_dtypes(size, dtype): result = shmem.randint(0, 10, size, dtype=dtype) # Verify tensor is on symmetric heap - assert shmem._Iris__on_symmetric_heap(result), f"Tensor with size {size}, dtype {dtype} is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), f"Tensor with size {size}, dtype {dtype} is NOT on symmetric heap!" # Also verify basic functionality assert result.shape == size @@ -307,7 +307,7 @@ def test_randint_symmetric_heap_dtype_override(dtype): shmem = iris.iris(1 << 20) result = shmem.randint(0, 10, (3, 3), dtype=dtype) - assert shmem._Iris__on_symmetric_heap(result), f"Tensor with dtype {dtype} is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), f"Tensor with dtype {dtype} is NOT on symmetric heap!" assert result.dtype == dtype @@ -317,20 +317,20 @@ def test_randint_symmetric_heap_other_params(): # Test with requires_grad result = shmem.randint(0, 10, (3, 3), dtype=torch.float32, requires_grad=True) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with requires_grad=True is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with requires_grad=True is NOT on symmetric heap!" 
# Test with device override result = shmem.randint(0, 10, (3, 3), device=shmem.device) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with device override is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with device override is NOT on symmetric heap!" # Test with layout override (only strided is supported) result = shmem.randint(0, 10, (3, 3), layout=torch.strided) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with layout override is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with layout override is NOT on symmetric heap!" # Test with out parameter - out_tensor = shmem._Iris__allocate(9, torch.int64) # Use default dtype + out_tensor = shmem._allocate(9, torch.int64) # Use default dtype result = shmem.randint(0, 10, (3, 3), out=out_tensor) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with out parameter is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with out parameter is NOT on symmetric heap!" 
def test_randint_invalid_output_tensor(): @@ -338,12 +338,12 @@ def test_randint_invalid_output_tensor(): shmem = iris.iris(1 << 20) # Test with wrong size output tensor - wrong_size_tensor = shmem._Iris__allocate(4, torch.int32) # Wrong size for (3, 3) + wrong_size_tensor = shmem._allocate(4, torch.int32) # Wrong size for (3, 3) with pytest.raises(RuntimeError): shmem.randint(0, 10, (3, 3), out=wrong_size_tensor) # Test with wrong dtype output tensor - wrong_dtype_tensor = shmem._Iris__allocate(9, torch.float32) # Wrong dtype + wrong_dtype_tensor = shmem._allocate(9, torch.float32) # Wrong dtype with pytest.raises(RuntimeError): shmem.randint(0, 10, (3, 3), dtype=torch.int32, out=wrong_dtype_tensor) @@ -399,14 +399,14 @@ def test_randint_generator(): assert result1.shape == (3, 3) assert torch.all(result1 >= 0) assert torch.all(result1 < 10) - assert shmem._Iris__on_symmetric_heap(result1) + assert shmem._on_symmetric_heap(result1) # Test without generator (should still work) result2 = shmem.randint(0, 10, (3, 3)) assert result2.shape == (3, 3) assert torch.all(result2 >= 0) assert torch.all(result2 < 10) - assert shmem._Iris__on_symmetric_heap(result2) + assert shmem._on_symmetric_heap(result2) def test_randint_argument_validation(): @@ -457,14 +457,14 @@ def test_randint_pytorch_signatures(): assert result1.shape == (2, 3) assert torch.all(result1 >= 0) assert torch.all(result1 < 10) - assert shmem._Iris__on_symmetric_heap(result1) + assert shmem._on_symmetric_heap(result1) # Test randint(low, high, size) signature result2 = shmem.randint(5, 15, (2, 3)) assert result2.shape == (2, 3) assert torch.all(result2 >= 5) assert torch.all(result2 < 15) - assert shmem._Iris__on_symmetric_heap(result2) + assert shmem._on_symmetric_heap(result2) # Both should work correctly assert result1.shape == result2.shape @@ -480,4 +480,4 @@ def test_randint_deterministic_behavior(): assert result.shape == (2, 3) assert torch.all(result >= 0) assert torch.all(result < 10) - assert 
shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) diff --git a/tests/unittests/test_randn.py b/tests/unittests/test_randn.py index cb20ec9a..42f5c3b6 100644 --- a/tests/unittests/test_randn.py +++ b/tests/unittests/test_randn.py @@ -36,7 +36,7 @@ def test_randn_basic(dtype, size): assert result.dtype == dtype # Verify tensor is on symmetric heap - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_randn_default_dtype(): @@ -46,7 +46,7 @@ def test_randn_default_dtype(): result = shmem.randn(2, 3) expected_dtype = torch.get_default_dtype() assert result.dtype == expected_dtype - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) @pytest.mark.parametrize( @@ -64,7 +64,7 @@ def test_randn_requires_grad(requires_grad): # Verify requires_grad is set assert result.requires_grad == requires_grad - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_randn_device_handling(): @@ -73,23 +73,23 @@ def test_randn_device_handling(): # Test default behavior (should use Iris device) result = shmem.randn(3, 3) assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test explicit device result = shmem.randn(3, 3, device=shmem.device) assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that "cuda" shorthand works (should use current CUDA device) if shmem.device.startswith("cuda:"): result = shmem.randn(3, 3, device="cuda") assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test None device defaults to Iris device result = shmem.randn(3, 3, device=None) assert str(result.device) == str(shmem.get_device()) - assert shmem._Iris__on_symmetric_heap(result) + 
assert shmem._on_symmetric_heap(result) # Test that different device throws error different_device = "cpu" # CPU is always different from CUDA @@ -111,27 +111,27 @@ def test_randn_layout_handling(): # Test with strided layout (default) result = shmem.randn(2, 4, layout=torch.strided) assert result.layout == torch.strided - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_randn_out_parameter(): shmem = iris.iris(1 << 20) # Test with out parameter - out_tensor = shmem._Iris__allocate(6, torch.float32) + out_tensor = shmem._allocate(6, torch.float32) result = shmem.randn(2, 3, out=out_tensor) # Should share the same underlying data (same data_ptr) assert result.data_ptr() == out_tensor.data_ptr() assert result.shape == (2, 3) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test with different dtype out tensor (float32) - out_tensor_float = shmem._Iris__allocate(6, torch.float32) + out_tensor_float = shmem._allocate(6, torch.float32) result_float = shmem.randn(2, 3, dtype=torch.float32, out=out_tensor_float) assert result_float.data_ptr() == out_tensor_float.data_ptr() assert result_float.dtype == torch.float32 - assert shmem._Iris__on_symmetric_heap(result_float) + assert shmem._on_symmetric_heap(result_float) def test_randn_size_variations(): @@ -140,22 +140,22 @@ def test_randn_size_variations(): # Test single dimension result1 = shmem.randn(5) assert result1.shape == (5,) - assert shmem._Iris__on_symmetric_heap(result1) + assert shmem._on_symmetric_heap(result1) # Test multiple dimensions result2 = shmem.randn(2, 3, 4) assert result2.shape == (2, 3, 4) - assert shmem._Iris__on_symmetric_heap(result2) + assert shmem._on_symmetric_heap(result2) # Test with tuple as single argument result3 = shmem.randn((3, 4)) assert result3.shape == (3, 4) - assert shmem._Iris__on_symmetric_heap(result3) + assert shmem._on_symmetric_heap(result3) # Test with list as single argument result4 = 
shmem.randn([2, 5]) assert result4.shape == (2, 5) - assert shmem._Iris__on_symmetric_heap(result4) + assert shmem._on_symmetric_heap(result4) def test_randn_edge_cases(): @@ -165,25 +165,25 @@ def test_randn_edge_cases(): empty_result = shmem.randn(0) assert empty_result.shape == (0,) assert empty_result.numel() == 0 - assert shmem._Iris__on_symmetric_heap(empty_result) + assert shmem._on_symmetric_heap(empty_result) # Single element tensor single_result = shmem.randn(1) assert single_result.shape == (1,) assert single_result.numel() == 1 - assert shmem._Iris__on_symmetric_heap(single_result) + assert shmem._on_symmetric_heap(single_result) # Large tensor large_result = shmem.randn(50, 50) assert large_result.shape == (50, 50) assert large_result.numel() == 2500 - assert shmem._Iris__on_symmetric_heap(large_result) + assert shmem._on_symmetric_heap(large_result) # Zero-dimensional tensor (scalar) scalar_result = shmem.randn(()) assert scalar_result.shape == () assert scalar_result.numel() == 1 - assert shmem._Iris__on_symmetric_heap(scalar_result) + assert shmem._on_symmetric_heap(scalar_result) def test_randn_pytorch_equivalence(): @@ -230,7 +230,7 @@ def test_randn_parameter_combinations(params): # Verify basic functionality assert result.shape == (3, 3) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Verify dtype if specified if "dtype" in params: @@ -265,7 +265,7 @@ def test_randn_symmetric_heap_shapes_dtypes(size, dtype): result = shmem.randn(*size, dtype=dtype) # Verify tensor is on symmetric heap - assert shmem._Iris__on_symmetric_heap(result), f"Tensor with size {size}, dtype {dtype} is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), f"Tensor with size {size}, dtype {dtype} is NOT on symmetric heap!" 
# Also verify basic functionality assert result.shape == size @@ -278,7 +278,7 @@ def test_randn_symmetric_heap_dtype_override(dtype): shmem = iris.iris(1 << 20) result = shmem.randn(3, 3, dtype=dtype) - assert shmem._Iris__on_symmetric_heap(result), f"Tensor with dtype {dtype} is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), f"Tensor with dtype {dtype} is NOT on symmetric heap!" assert result.dtype == dtype @@ -288,20 +288,20 @@ def test_randn_symmetric_heap_other_params(): # Test with requires_grad result = shmem.randn(3, 3, dtype=torch.float32, requires_grad=True) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with requires_grad=True is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with requires_grad=True is NOT on symmetric heap!" # Test with device override result = shmem.randn(3, 3, device=shmem.device) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with device override is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with device override is NOT on symmetric heap!" # Test with layout override (only strided is supported) result = shmem.randn(3, 3, layout=torch.strided) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with layout override is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with layout override is NOT on symmetric heap!" # Test with out parameter - out_tensor = shmem._Iris__allocate(9, torch.float32) + out_tensor = shmem._allocate(9, torch.float32) result = shmem.randn(3, 3, out=out_tensor) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with out parameter is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with out parameter is NOT on symmetric heap!" 
def test_randn_invalid_output_tensor(): @@ -309,12 +309,12 @@ def test_randn_invalid_output_tensor(): shmem = iris.iris(1 << 20) # Test with wrong size output tensor - wrong_size_tensor = shmem._Iris__allocate(4, torch.float32) # Wrong size for (3, 3) + wrong_size_tensor = shmem._allocate(4, torch.float32) # Wrong size for (3, 3) with pytest.raises(RuntimeError): shmem.randn(3, 3, out=wrong_size_tensor) # Test with wrong dtype output tensor - wrong_dtype_tensor = shmem._Iris__allocate(9, torch.float64) # Wrong dtype + wrong_dtype_tensor = shmem._allocate(9, torch.float64) # Wrong dtype with pytest.raises(RuntimeError): shmem.randn(3, 3, dtype=torch.float32, out=wrong_dtype_tensor) @@ -382,12 +382,12 @@ def test_randn_generator(): generator.manual_seed(42) result1 = shmem.randn(3, 3, generator=generator) assert result1.shape == (3, 3) - assert shmem._Iris__on_symmetric_heap(result1) + assert shmem._on_symmetric_heap(result1) # Test without generator (should still work) result2 = shmem.randn(3, 3) assert result2.shape == (3, 3) - assert shmem._Iris__on_symmetric_heap(result2) + assert shmem._on_symmetric_heap(result2) # Test that generator produces reproducible results generator1 = torch.Generator(device="cuda") @@ -409,12 +409,12 @@ def test_randn_pin_memory(): # Test with pin_memory=True result = shmem.randn(3, 3, pin_memory=True) assert result.shape == (3, 3) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test with pin_memory=False result = shmem.randn(3, 3, pin_memory=False) assert result.shape == (3, 3) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Note: pin_memory is ignored for GPU tensors, so we just verify it doesn't cause errors @@ -428,7 +428,7 @@ def test_randn_deterministic_behavior(): try: result = shmem.randn(3, 3) assert result.shape == (3, 3) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) finally: 
torch.use_deterministic_algorithms(False) @@ -440,9 +440,9 @@ def test_randn_examples(): # Example 1: torch.randn(4) result1 = shmem.randn(4) assert result1.shape == (4,) - assert shmem._Iris__on_symmetric_heap(result1) + assert shmem._on_symmetric_heap(result1) # Example 2: torch.randn(2, 3) result2 = shmem.randn(2, 3) assert result2.shape == (2, 3) - assert shmem._Iris__on_symmetric_heap(result2) + assert shmem._on_symmetric_heap(result2) diff --git a/tests/unittests/test_zeros.py b/tests/unittests/test_zeros.py index 51126fed..73bad9e5 100644 --- a/tests/unittests/test_zeros.py +++ b/tests/unittests/test_zeros.py @@ -44,7 +44,7 @@ def test_zeros_basic(dtype, size): assert torch.all(result == 0) # Verify tensor is on symmetric heap - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_zeros_default_dtype(): @@ -55,7 +55,7 @@ def test_zeros_default_dtype(): expected_dtype = torch.get_default_dtype() assert result.dtype == expected_dtype assert torch.all(result == 0) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) @pytest.mark.parametrize( @@ -74,7 +74,7 @@ def test_zeros_requires_grad(requires_grad): # Verify requires_grad is set assert result.requires_grad == requires_grad assert torch.all(result == 0) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) def test_zeros_device_handling(): @@ -84,26 +84,26 @@ def test_zeros_device_handling(): result = shmem.zeros(3, 3) assert str(result.device) == str(shmem.get_device()) assert torch.all(result == 0) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test explicit device result = shmem.zeros(3, 3, device=shmem.device) assert str(result.device) == str(shmem.get_device()) assert torch.all(result == 0) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that "cuda" shorthand works (should use current CUDA 
device) if shmem.device.startswith("cuda:"): result = shmem.zeros(3, 3, device="cuda") assert str(result.device) == str(shmem.get_device()) assert torch.all(result == 0) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test None device defaults to Iris device result = shmem.zeros(3, 3, device=None) assert str(result.device) == str(shmem.get_device()) assert torch.all(result == 0) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that different device throws error different_device = "cpu" # CPU is always different from CUDA @@ -125,7 +125,7 @@ def test_zeros_layout_handling(): result = shmem.zeros(2, 4, layout=torch.strided) assert result.layout == torch.strided assert torch.all(result == 0) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test that unsupported layout throws error with pytest.raises(ValueError): @@ -136,22 +136,22 @@ def test_zeros_out_parameter(): shmem = iris.iris(1 << 20) # Test with out parameter - out_tensor = shmem._Iris__allocate(6, torch.float32) + out_tensor = shmem._allocate(6, torch.float32) result = shmem.zeros(2, 3, out=out_tensor) # Should share the same underlying data (same data_ptr) assert result.data_ptr() == out_tensor.data_ptr() assert torch.all(result == 0) assert result.shape == (2, 3) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Test with different dtype out tensor - out_tensor_int = shmem._Iris__allocate(6, torch.int32) + out_tensor_int = shmem._allocate(6, torch.int32) result_int = shmem.zeros(2, 3, dtype=torch.int32, out=out_tensor_int) assert result_int.data_ptr() == out_tensor_int.data_ptr() assert result_int.dtype == torch.int32 assert torch.all(result_int == 0) - assert shmem._Iris__on_symmetric_heap(result_int) + assert shmem._on_symmetric_heap(result_int) def test_zeros_size_variations(): @@ -161,25 +161,25 @@ def 
test_zeros_size_variations(): result1 = shmem.zeros(5) assert result1.shape == (5,) assert torch.all(result1 == 0) - assert shmem._Iris__on_symmetric_heap(result1) + assert shmem._on_symmetric_heap(result1) # Test multiple dimensions result2 = shmem.zeros(2, 3, 4) assert result2.shape == (2, 3, 4) assert torch.all(result2 == 0) - assert shmem._Iris__on_symmetric_heap(result2) + assert shmem._on_symmetric_heap(result2) # Test with tuple/list as single argument result3 = shmem.zeros((3, 4)) assert result3.shape == (3, 4) assert torch.all(result3 == 0) - assert shmem._Iris__on_symmetric_heap(result3) + assert shmem._on_symmetric_heap(result3) # Test with list as single argument result4 = shmem.zeros([2, 5]) assert result4.shape == (2, 5) assert torch.all(result4 == 0) - assert shmem._Iris__on_symmetric_heap(result4) + assert shmem._on_symmetric_heap(result4) def test_zeros_edge_cases(): @@ -189,28 +189,28 @@ def test_zeros_edge_cases(): empty_result = shmem.zeros(0) assert empty_result.shape == (0,) assert empty_result.numel() == 0 - assert shmem._Iris__on_symmetric_heap(empty_result) + assert shmem._on_symmetric_heap(empty_result) # Single element tensor single_result = shmem.zeros(1) assert single_result.shape == (1,) assert single_result.numel() == 1 assert single_result[0] == 0 - assert shmem._Iris__on_symmetric_heap(single_result) + assert shmem._on_symmetric_heap(single_result) # Large tensor large_result = shmem.zeros(100, 100) assert large_result.shape == (100, 100) assert large_result.numel() == 10000 assert torch.all(large_result == 0) - assert shmem._Iris__on_symmetric_heap(large_result) + assert shmem._on_symmetric_heap(large_result) # Zero-dimensional tensor (scalar) scalar_result = shmem.zeros(()) assert scalar_result.shape == () assert scalar_result.numel() == 1 assert scalar_result.item() == 0 - assert shmem._Iris__on_symmetric_heap(scalar_result) + assert shmem._on_symmetric_heap(scalar_result) def test_zeros_pytorch_equivalence(): @@ -262,7 +262,7 @@ 
def test_zeros_parameter_combinations(params): # Verify basic functionality assert result.shape == (3, 3) assert torch.all(result == 0) - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) # Verify dtype if specified if "dtype" in params: @@ -297,7 +297,7 @@ def test_zeros_symmetric_heap_shapes_dtypes(size, dtype): result = shmem.zeros(*size, dtype=dtype) # Verify tensor is on symmetric heap - assert shmem._Iris__on_symmetric_heap(result), f"Tensor with size {size}, dtype {dtype} is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), f"Tensor with size {size}, dtype {dtype} is NOT on symmetric heap!" # Also verify basic functionality assert result.shape == size @@ -311,7 +311,7 @@ def test_zeros_symmetric_heap_dtype_override(dtype): shmem = iris.iris(1 << 20) result = shmem.zeros(3, 3, dtype=dtype) - assert shmem._Iris__on_symmetric_heap(result), f"Tensor with dtype {dtype} is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), f"Tensor with dtype {dtype} is NOT on symmetric heap!" assert result.dtype == dtype @@ -321,20 +321,20 @@ def test_zeros_symmetric_heap_other_params(): # Test with requires_grad result = shmem.zeros(3, 3, dtype=torch.float32, requires_grad=True) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with requires_grad=True is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with requires_grad=True is NOT on symmetric heap!" # Test with device override result = shmem.zeros(3, 3, device=shmem.device) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with device override is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with device override is NOT on symmetric heap!" # Test with layout override (only strided is supported) result = shmem.zeros(3, 3, layout=torch.strided) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with layout override is NOT on symmetric heap!" 
+ assert shmem._on_symmetric_heap(result), "Tensor with layout override is NOT on symmetric heap!" # Test with out parameter - out_tensor = shmem._Iris__allocate(9, torch.float32) + out_tensor = shmem._allocate(9, torch.float32) result = shmem.zeros(3, 3, out=out_tensor) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with out parameter is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with out parameter is NOT on symmetric heap!" def test_zeros_invalid_output_tensor(): @@ -342,12 +342,12 @@ def test_zeros_invalid_output_tensor(): shmem = iris.iris(1 << 20) # Test with wrong size output tensor - wrong_size_tensor = shmem._Iris__allocate(4, torch.float32) # Wrong size for (3, 3) + wrong_size_tensor = shmem._allocate(4, torch.float32) # Wrong size for (3, 3) with pytest.raises(RuntimeError, match="The output tensor has 4 elements, but 9 are required"): shmem.zeros(3, 3, out=wrong_size_tensor) # Test with wrong dtype output tensor - wrong_dtype_tensor = shmem._Iris__allocate(9, torch.int32) # Wrong dtype + wrong_dtype_tensor = shmem._allocate(9, torch.int32) # Wrong dtype with pytest.raises(RuntimeError, match="The output tensor has dtype torch.int32, but torch.float32 is required"): shmem.zeros(3, 3, dtype=torch.float32, out=wrong_dtype_tensor) diff --git a/tests/unittests/test_zeros_like.py b/tests/unittests/test_zeros_like.py index b7a0ff0c..07a80826 100644 --- a/tests/unittests/test_zeros_like.py +++ b/tests/unittests/test_zeros_like.py @@ -227,9 +227,9 @@ def test_zeros_like_memory_format(): ) # Verify all results are on the symmetric heap - assert shmem._Iris__on_symmetric_heap(result_4d) - assert shmem._Iris__on_symmetric_heap(result_5d) - assert shmem._Iris__on_symmetric_heap(result_preserve_channels_last) + assert shmem._on_symmetric_heap(result_4d) + assert shmem._on_symmetric_heap(result_5d) + assert shmem._on_symmetric_heap(result_preserve_channels_last) def test_channels_last_format_shape_preservation(): @@ -285,8 
+285,8 @@ def test_channels_last_format_shape_preservation(): ) # Verify tensors are on symmetric heap - assert shmem._Iris__on_symmetric_heap(result_4d) - assert shmem._Iris__on_symmetric_heap(result_5d) + assert shmem._on_symmetric_heap(result_4d) + assert shmem._on_symmetric_heap(result_5d) def test_zeros_like_pytorch_equivalence(): @@ -345,9 +345,9 @@ def test_zeros_like_edge_cases(): assert torch.all(large_result == 0) # Verify all edge case results are on symmetric heap - assert shmem._Iris__on_symmetric_heap(empty_result) - assert shmem._Iris__on_symmetric_heap(single_result) - assert shmem._Iris__on_symmetric_heap(large_result) + assert shmem._on_symmetric_heap(empty_result) + assert shmem._on_symmetric_heap(single_result) + assert shmem._on_symmetric_heap(large_result) @pytest.mark.parametrize( @@ -382,7 +382,7 @@ def test_zeros_like_parameter_combinations(params): assert result.requires_grad == params["requires_grad"] # Verify tensor is on symmetric heap - assert shmem._Iris__on_symmetric_heap(result) + assert shmem._on_symmetric_heap(result) @pytest.mark.parametrize( @@ -422,7 +422,7 @@ def test_zeros_like_symmetric_heap_shapes_dtypes(shape, dtype): result = shmem.zeros_like(input_tensor, memory_format=memory_format) # Verify tensor is on symmetric heap - assert shmem._Iris__on_symmetric_heap(result), ( + assert shmem._on_symmetric_heap(result), ( f"Tensor with shape {shape}, dtype {dtype}, memory_format {memory_format} is NOT on symmetric heap!" ) @@ -440,7 +440,7 @@ def test_zeros_like_symmetric_heap_dtype_override(dtype): input_tensor = shmem.full((3, 3), 1, dtype=torch.float32) result = shmem.zeros_like(input_tensor, dtype=dtype) - assert shmem._Iris__on_symmetric_heap(result), f"Tensor with dtype {dtype} is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), f"Tensor with dtype {dtype} is NOT on symmetric heap!" 
assert result.dtype == dtype @@ -451,12 +451,12 @@ def test_zeros_like_symmetric_heap_other_params(): # Test with requires_grad result = shmem.zeros_like(input_tensor, requires_grad=True) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with requires_grad=True is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with requires_grad=True is NOT on symmetric heap!" # Test with device override result = shmem.zeros_like(input_tensor, device=shmem.device) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with device override is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with device override is NOT on symmetric heap!" # Test with layout override (only strided is supported) result = shmem.zeros_like(input_tensor, layout=torch.strided) - assert shmem._Iris__on_symmetric_heap(result), "Tensor with layout override is NOT on symmetric heap!" + assert shmem._on_symmetric_heap(result), "Tensor with layout override is NOT on symmetric heap!"