
Zoom backend #9


Open

wants to merge 23 commits into base: master
Commits (23)
61ec568
init minimal zoom backend
Dec 22, 2024
16d3bea
resolve some build deps
123epsilon Dec 23, 2024
53deb95
minimize, fix build, torchgen logic
123epsilon Dec 25, 2024
0b7cc75
add kernel deps for llama3
123epsilon Dec 28, 2024
b33031d
fix matmul kernel
123epsilon Jan 13, 2025
1aa7e92
some ops + llama script running
123epsilon Jan 16, 2025
aaef6b9
remove deps on hipblas, hipblaslt, hipsparse, hipsolver, hipfft, roct…
123epsilon Jan 17, 2025
ac54e3e
llama example working, bmm triton kernel
123epsilon Jan 26, 2025
74b12b7
add build and llama3 demo instructions
123epsilon Jan 27, 2025
e7b9919
add range factories
123epsilon Jan 28, 2025
1eae71d
adjust find_package calls for zoom in cmake
123epsilon Jan 28, 2025
2ca34c8
add sudo to build whl
123epsilon Feb 4, 2025
5d099e9
chmod build script
123epsilon Feb 4, 2025
51f6432
CI checkout recursive
123epsilon Feb 19, 2025
6c373c5
clang-19 compat in intrusive_ptr
123epsilon Feb 19, 2025
12f62e2
add venv to audit build step
123epsilon Feb 19, 2025
a20c493
add more kernels for autograd examples
123epsilon Mar 18, 2025
358886c
add distributed support
123epsilon Mar 18, 2025
0a1a39b
add hipblas and hipblaslt with build flags
123epsilon Mar 31, 2025
fcc1e4a
add back binary, bitwise, polynomial, and conv kernels
Apr 14, 2025
7a16442
add back distance, distributions, embedding, fused, gridsampler, herm…
Apr 14, 2025
d31d5df
add back index, loss, transposed and dilated conv, normalization, and…
Apr 14, 2025
019c170
get to 100% device tests; add back Convolution, randperm, reflection/…
Apr 14, 2025
126 changes: 126 additions & 0 deletions .github/workflows/build_zoom_backend.yml
@@ -0,0 +1,126 @@
name: "Build PyTorch"

on:
workflow_dispatch:
inputs:
force_debug_with_tmate:
type: boolean
description: 'Run the build with tmate session'
required: false
default: false
debug_with_tmate:
type: boolean
description: 'Run the build with a tmate session ONLY in case of failure'
required: false
default: false
pull_request:
push:
branches:
- main

concurrency:
group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
cancel-in-progress: true

jobs:
build:

strategy:
fail-fast: false
matrix:
include:
- name: "ubuntu-22.04"
runs-on: "mi300"
# container: "rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0"
# runs-on: "nod-ai-shared-cpubuilder-manylinux-x86_64"

runs-on: ${{ matrix.runs-on }}

name: ${{ matrix.name }}

env:
CACHE_DIR: ${{ github.workspace }}/.container-cache
# either the PR number or `branch-N` where N always increments
CACHE_KEY: linux-build-test-cpp-asserts-manylinux-v2-${{ format('{0}-{1}', github.ref_name, github.run_number) }}

defaults:
run:
shell: bash

permissions:
id-token: write
contents: write

container:
image: ${{ matrix.container }}

steps:
- name: "Check out repository"
uses: actions/[email protected]
with:
submodules: recursive

- name: Enable cache
uses: actions/cache/restore@v3
with:
path: ${{ env.CACHE_DIR }}
key: ${{ env.CACHE_KEY }}
restore-keys: linux-build-test-cpp-

- name: "Build PyTorch"
id: build
run: |

export CCACHE_DIR="${{ env.CACHE_DIR }}"
export CMAKE_C_COMPILER_LAUNCHER=ccache
export CMAKE_CXX_COMPILER_LAUNCHER=ccache
export CCACHE_SLOPPINESS=include_file_ctime,include_file_mtime,time_macros

python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
chmod +x ./build.sh
./build.sh

- name: "Audit"
id: audit
run: |

sudo apt install patchelf
python -m venv venv
source venv/bin/activate
pip install auditwheel
auditwheel repair -w dist --plat manylinux_2_39_x86_64 dist/torch*

- name: Save cache
uses: actions/cache/save@v3
if: ${{ !cancelled() }}
with:
path: ${{ env.CACHE_DIR }}
key: ${{ env.CACHE_KEY }}

- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: ${{ matrix.name }}_artifact
path: dist
if-no-files-found: warn

- name: Release current commit
uses: ncipollo/[email protected]
with:
artifacts: "dist/torch*.whl"
token: "${{ secrets.GITHUB_TOKEN }}"
tag: "latest"
name: "latest"
removeArtifacts: false
allowUpdates: true
replacesArtifacts: true
makeLatest: true

- name: "Setup tmate session"
if: ${{ (failure() && inputs.debug_with_tmate) || inputs.force_debug_with_tmate }}
uses: mxschmitt/[email protected]
with:
limit-access-to-actor: true
install-dependencies: ${{ startsWith(matrix.runs-on, 'macos') || startsWith(matrix.runs-on, 'windows') }}
11 changes: 9 additions & 2 deletions BUILD.bazel
@@ -9,7 +9,7 @@ load("@pytorch//tools/config:defs.bzl", "if_cuda")
load("@pytorch//:aten.bzl", "generate_aten", "intern_build_aten_ops")
load(":build.bzl", "GENERATED_AUTOGRAD_CPP", "GENERATED_AUTOGRAD_PYTHON", "define_targets")
load(":build_variables.bzl", "jit_core_sources", "lazy_tensor_ts_sources", "libtorch_core_sources", "libtorch_cuda_sources", "libtorch_distributed_sources", "libtorch_extra_sources", "libtorch_python_core_sources", "torch_cpp_srcs", "libtorch_python_cuda_sources", "libtorch_python_distributed_sources")
load(":ufunc_defs.bzl", "aten_ufunc_generated_cpu_kernel_sources", "aten_ufunc_generated_cpu_sources", "aten_ufunc_generated_cuda_sources")
load(":ufunc_defs.bzl", "aten_ufunc_generated_cpu_kernel_sources", "aten_ufunc_generated_cpu_sources", "aten_ufunc_generated_cuda_sources", "aten_ufunc_generated_zoom_sources")
load("//:tools/bazel.bzl", "rules")

define_targets(rules = rules)
@@ -104,6 +104,12 @@ generated_cuda_cpp = [
"aten/src/ATen/RegisterSparseCsrCUDA.cpp",
]

generated_zoom_cpp = [
    "aten/src/ATen/ZoomFunctions.h",
    "aten/src/ATen/ZoomFunctions_inl.h",
    "aten/src/ATen/RegisterPrivateUse1.cpp",
]

generate_aten(
name = "generated_aten_cpp",
srcs = aten_generation_srcs,
@@ -112,7 +118,8 @@ generate_aten(
generated_cuda_cpp +
aten_ufunc_generated_cpu_sources("aten/src/ATen/{}") +
aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}") +
aten_ufunc_generated_cuda_sources("aten/src/ATen/{}") + [
aten_ufunc_generated_cuda_sources("aten/src/ATen/{}") +
aten_ufunc_generated_zoom_sources("aten/src/ATen/{}") + [
"aten/src/ATen/Declarations.yaml",
]
),
136 changes: 136 additions & 0 deletions BuildingZoom.md
@@ -0,0 +1,136 @@
# Setup Python Env

To start out, we just follow the normal procedure for building PyTorch from source. For convenience, I've included these steps here:

```bash
conda create -n nod-pytorch python==3.10
conda activate nod-pytorch
conda install cmake ninja
pip install -r requirements.txt
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
python setup.py develop
```

# CMake Build

Using the `USE_ZOOM` flag with CMake enables building with HIP for ROCm without requiring any of the "HIPify" scripts. This includes the HIP libraries and populates `torch.version.hip` appropriately. The flag is NOT yet wired into the `setup.py` script, so for now it needs to be passed manually via `cmake` or `ccmake`.

You'll need to set the `ROCM_PATH` and `HIP_ROOT_DIR` environment variables appropriately; on Linux these default to `/opt/rocm/` and `/opt/rocm/hip`, respectively.

If you're running on Linux, you can just use `build.sh` to build:
```bash
cd pytorch/
source build.sh
```

Alternatively, if you want to set up your CMake build manually, you can use the following commands:

```bash
cd build/
export PYTORCH_ROCM_ARCH=gfx90a
export ROCM_PATH=/opt/rocm
export HIP_ROOT_DIR=/opt/rocm/hip
cmake -DUSE_ZOOM=ON ..            # configure with the Zoom backend enabled
cmake --build . --target install  # then build and install
```

# Running PyTorch with Zoom

Programs using the Zoom backend must be prefaced with this stub until a proper dispatch key is registered in PyTorch:

```python
import torch
import torch.zoom
torch.utils.rename_privateuse1_backend('zoom')
torch.utils.generate_methods_for_privateuse1_backend(unsupported_dtype=None)
```
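
Once the stub has run, tensors can be allocated on the renamed device directly. The snippet below is a minimal smoke test rather than part of the backend itself; it assumes the factory, matmul, and reduction kernels are registered for the `zoom` device (the commit history above adds these):

```python
import torch
import torch.zoom
torch.utils.rename_privateuse1_backend('zoom')
torch.utils.generate_methods_for_privateuse1_backend(unsupported_dtype=None)

# Allocate two tensors on the Zoom (HIP) device and run a matmul.
x = torch.randn(64, 64, device='zoom')
y = torch.randn(64, 64, device='zoom')
z = x @ y

# Sanity-check the device kernel against the CPU result.
print(z.device)
print(torch.allclose(z.cpu(), x.cpu() @ y.cpu(), atol=1e-4))
```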

# Installing Triton

Since mainline Triton currently treats ROCm as if it's masquerading as `torch.cuda`, we need a custom installation:

```bash
git clone https://github.com/123epsilon/triton.git
cd triton/
git checkout zoom
pip install pybind11
pip install python/
```
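
To verify the install, the standard vector-add kernel from the Triton tutorials can be launched on Zoom tensors. This is a sketch under the assumption that the `zoom` branch accepts tensors on the renamed PrivateUse1 device the same way upstream Triton accepts CUDA tensors:

```python
import torch
import torch.zoom
import triton
import triton.language as tl

torch.utils.rename_privateuse1_backend('zoom')

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the final partial block
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.rand(4096, device='zoom')
y = torch.rand(4096, device='zoom')
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)
print(torch.allclose(out, x + y))
```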

# Running Llama 3 with Triton using Liger Kernel and HuggingFace

```bash
pip install liger-kernel
```

```python
# Run Llama 3
import torch
from transformers import AutoTokenizer
from liger_kernel.transformers import AutoLigerKernelForCausalLM
from time import perf_counter as pf
torch.utils.rename_privateuse1_backend('zoom')

# Set up the model and tokenizer
model_id = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoLigerKernelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map="zoom"
)

# Function to generate text
def generate_text(prompt, max_length=30):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=max_length)
return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Example usage
prompt = "Hey, how are you doing today?"
s = pf()
response = generate_text(prompt)
e = pf()
print(f"Prompt: {prompt}")
print(f"Response: {response}")

print(f"{e-s} seconds")
```

```python
# Or run the instruct-tuned variant
import torch
import transformers
from liger_kernel.transformers import apply_liger_kernel_to_llama
torch.utils.rename_privateuse1_backend('zoom')

apply_liger_kernel_to_llama()
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

pipeline = transformers.pipeline(
"text-generation",
model=model_id,
model_kwargs={"torch_dtype": torch.bfloat16},
device_map="zoom",
)

messages = [
{"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
{"role": "user", "content": "Who are you?"},
]

terminators = [
pipeline.tokenizer.eos_token_id,
pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

outputs = pipeline(
messages,
max_new_tokens=30,
eos_token_id=terminators,
do_sample=True,
temperature=0.6,
top_p=0.9,
)
print(outputs[0]["generated_text"][-1])

```
5 changes: 4 additions & 1 deletion CMakeLists.txt
@@ -203,6 +203,7 @@ option(USE_CPP_CODE_COVERAGE "Compile C/C++ with code coverage flags" OFF)
option(USE_COLORIZE_OUTPUT "Colorize output during compilation" ON)
option(USE_ASAN "Use Address+Undefined Sanitizers" OFF)
option(USE_TSAN "Use Thread Sanitizer" OFF)
option(USE_ZOOM "Use ZOOM HIP Backend" OFF)
option(USE_CUDA "Use CUDA" ON)
cmake_dependent_option(
USE_XPU "Use XPU. Only available on Linux." ON
@@ -231,12 +232,14 @@ option(USE_MAGMA "Use MAGMA" ON)
option(USE_PYTORCH_METAL "Use Metal for PyTorch iOS build" OFF)
option(USE_PYTORCH_METAL_EXPORT "Export Metal models on MacOSX desktop" OFF)
option(USE_NATIVE_ARCH "Use -march=native" OFF)
option(ENABLE_ZOOM_BLAS "Use HIPBlas Kernels in the ZOOM backend" ON)
option(DISABLE_HIPBLASLT "Disable HIPBlasLt Kernels in the ZOOM backend" OFF)
cmake_dependent_option(
USE_MPS "Use MPS for macOS build" ON
"MPS_FOUND" OFF)
cmake_dependent_option(
USE_NCCL "Use NCCL" ON
"USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF)
"USE_CUDA OR USE_ROCM OR USE_ZOOM;UNIX;NOT APPLE" OFF)
cmake_dependent_option(USE_RCCL "Use RCCL" ON
USE_NCCL OFF)
cmake_dependent_option(
19 changes: 19 additions & 0 deletions aten/CMakeLists.txt
@@ -30,10 +30,13 @@ set(ATen_CUDA_SRCS_W_SORT_BY_KEY)
set(ATen_CUDA_TEST_SRCS)
set(ATen_CUDA_INCLUDE)
set(ATen_NVRTC_STUB_SRCS)
set(ATen_HIPRTC_STUB_SRCS)
set(ATen_HIP_SRCS)
set(ATen_ZOOM_SRCS)
set(ATen_HIP_SRCS_W_SORT_BY_KEY)
set(ATen_HIP_TEST_SRCS)
set(ATen_HIP_INCLUDE)
set(ATen_ZOOM_INCLUDE)
set(ATen_MPS_SRCS)
set(ATen_MPS_TEST_SRCS)
set(ATen_XPU_SRCS)
@@ -44,6 +47,7 @@ set(ATen_CPU_DEPENDENCY_LIBS)
set(ATen_XPU_DEPENDENCY_LIBS)
set(ATen_CUDA_DEPENDENCY_LIBS)
set(ATen_HIP_DEPENDENCY_LIBS)
set(ATen_ZOOM_DEPENDENCY_LIBS)
set(ATen_PUBLIC_CUDA_DEPENDENCY_LIBS)
set(ATen_PUBLIC_HIP_DEPENDENCY_LIBS)
set(ATEN_INSTALL_BIN_SUBDIR "bin" CACHE PATH "ATen install binary subdirectory")
@@ -70,6 +74,17 @@ if(USE_ROCM)
endif()
endif()

if(USE_ZOOM)
  include(LoadHIP)
  if(NOT PYTORCH_FOUND_HIP)
    message(WARNING "Could not load HIP, setting USE_ZOOM = OFF")
    set(USE_ZOOM OFF)
  else()
    message(STATUS "Loaded HIP, Zoom Enabled")
  endif()
endif()


# Both CUDA and ROCM are enabled and found. Report an error.
if(USE_CUDA AND USE_ROCM)
message(FATAL_ERROR "Both CUDA and ROCm are enabled and found. PyTorch can only be built with either of them. Please turn one off by using either USE_CUDA=OFF or USE_ROCM=OFF.")
@@ -116,12 +131,14 @@ set(ATen_CUDA_LINALG_SRCS ${ATen_CUDA_LINALG_SRCS} PARENT_SCOPE)
set(ATen_CUDA_SRCS_W_SORT_BY_KEY ${ATen_CUDA_SRCS_W_SORT_BY_KEY} PARENT_SCOPE)
set(ATen_CUDA_CU_SRCS_W_SORT_BY_KEY ${ATen_CUDA_CU_SRCS_W_SORT_BY_KEY} PARENT_SCOPE)
set(ATen_HIP_SRCS ${ATen_HIP_SRCS} PARENT_SCOPE)
set(ATen_ZOOM_SRCS ${ATen_ZOOM_SRCS} PARENT_SCOPE)
set(ATen_MPS_SRCS ${ATen_MPS_SRCS} PARENT_SCOPE)
set(ATen_MPS_TEST_SRCS ${ATen_MPS_TEST_SRCS} PARENT_SCOPE)
set(ATen_HIP_SRCS_W_SORT_BY_KEY ${ATen_HIP_SRCS_W_SORT_BY_KEY} PARENT_SCOPE)
set(ATen_XPU_SRCS ${ATen_XPU_SRCS} PARENT_SCOPE)
set(ATen_XPU_TEST_SRCS ${ATen_XPU_TEST_SRCS} PARENT_SCOPE)
set(ATen_NVRTC_STUB_SRCS ${ATen_NVRTC_STUB_SRCS} PARENT_SCOPE)
set(ATen_HIPRTC_STUB_SRCS ${ATen_HIPRTC_STUB_SRCS} PARENT_SCOPE)
set(ATen_CPU_TEST_SRCS ${ATen_CPU_TEST_SRCS} PARENT_SCOPE)
set(ATen_CUDA_TEST_SRCS ${ATen_CUDA_TEST_SRCS} PARENT_SCOPE)
set(ATen_HIP_TEST_SRCS ${ATen_HIP_TEST_SRCS} PARENT_SCOPE)
@@ -132,12 +149,14 @@ set(ATen_VEC_TEST_SRCS ${ATen_VEC_TEST_SRCS} PARENT_SCOPE)
set(ATen_CPU_INCLUDE ${ATen_CPU_INCLUDE} PARENT_SCOPE)
set(ATen_CUDA_INCLUDE ${ATen_CUDA_INCLUDE} PARENT_SCOPE)
set(ATen_HIP_INCLUDE ${ATen_HIP_INCLUDE} PARENT_SCOPE)
set(ATen_ZOOM_INCLUDE ${ATen_ZOOM_INCLUDE} PARENT_SCOPE)
set(ATen_XPU_INCLUDE ${ATen_XPU_INCLUDE} PARENT_SCOPE)
set(ATen_THIRD_PARTY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE} PARENT_SCOPE)
set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_XPU_DEPENDENCY_LIBS ${ATen_XPU_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_ZOOM_DEPENDENCY_LIBS ${ATen_ZOOM_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_CORE_TEST_SRCS ${ATen_CORE_TEST_SRCS} PARENT_SCOPE)
set(FLASH_ATTENTION_CUDA_SOURCES ${FLASH_ATTENTION_CUDA_SOURCES} PARENT_SCOPE)
set(MEM_EFF_ATTENTION_CUDA_SOURCES ${MEM_EFF_ATTENTION_CUDA_SOURCES} PARENT_SCOPE)