Add Intel® Gaudi® HPU Support #484

Open · wants to merge 11 commits into main
26 changes: 26 additions & 0 deletions Dockerfile.hpu
@@ -0,0 +1,26 @@
# Use the official Gaudi Docker image with PyTorch
FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest

# Set environment variables for Habana
ENV HABANA_VISIBLE_DEVICES=all
ENV OMPI_MCA_btl_vader_single_copy_mechanism=none
ENV PT_HPU_LAZY_ACC_PAR_MODE=0
ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=1

# Set timezone to UTC and install essential packages
ENV DEBIAN_FRONTEND="noninteractive" TZ=Etc/UTC
RUN apt-get update && apt-get install -y \
tzdata \
python3-pip \
&& rm -rf /var/lib/apt/lists/*

COPY . /workspace/clip
WORKDIR /workspace/clip

# Copy HPU requirements
COPY requirements_hpu.txt /workspace/requirements_hpu.txt

# Install Python packages
RUN pip install --upgrade pip \
    && pip install -r /workspace/requirements_hpu.txt \
    && pip install -e .
41 changes: 38 additions & 3 deletions README.md
@@ -29,7 +29,9 @@ import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
from clip.utils import get_device_initial

device = get_device_initial()  # "hpu" on Intel® Gaudi® HPU, "cuda" on a CUDA GPU, "cpu" otherwise
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
@@ -94,8 +96,10 @@ import clip
import torch
from torchvision.datasets import CIFAR100

from clip.utils import get_device_initial

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
device = get_device_initial()
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
@@ -153,8 +157,10 @@ from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm

from clip.utils import get_device_initial

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
device = get_device_initial()
model, preprocess = clip.load('ViT-B/32', device)

# Load the dataset
Expand Down Expand Up @@ -193,6 +199,35 @@ print(f"Accuracy = {accuracy:.3f}")
Note that the `C` value should be determined via a hyperparameter sweep using a validation split.


## Intel® Gaudi® HPU Usage

### Build the Docker Image
To run this project on Intel® Gaudi® HPU, start by building a Docker image with the appropriate environment setup.

```bash
docker build -t clip_hpu:latest -f Dockerfile.hpu .
```

In `Dockerfile.hpu`, we use the `vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest` base image. Ensure that the version matches your setup.
See the [PyTorch Docker Images for the Intel® Gaudi® Accelerator](https://developer.habana.ai/catalog/pytorch-container/) for more information.
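If you are not sure which Gaudi software release the host is running, the `hl-smi` utility that ships with the Habana driver tools reports the installed driver and firmware versions (assuming the tools are on your `PATH`):

```bash
hl-smi
```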

### Run the Container

```bash
docker run -it --runtime=habana clip_hpu:latest
```
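
Depending on your host setup, you may need to expose the devices and networking explicitly; Habana's container documentation commonly shows a fuller invocation along these lines (a sketch — adjust to your environment):

```bash
docker run -it --runtime=habana \
  -e HABANA_VISIBLE_DEVICES=all \
  --cap-add=sys_nice --net=host --ipc=host \
  clip_hpu:latest
```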

### Python Usage with Intel® Gaudi® HPU

No code changes are required to take advantage of Intel® Gaudi® HPU: `get_device_initial()` automatically detects the best available device and returns the appropriate device name.
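
For example, the Quick Start snippet from the top of this README runs unchanged whether `get_device_initial()` resolves to `"hpu"`, `"cuda"`, or `"cpu"`:

```python
import torch
import clip
from PIL import Image

from clip.utils import get_device_initial

device = get_device_initial()  # "hpu" on a Gaudi machine
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)
```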

### Run the Tests

```bash
pytest
```
This runs the test suite and verifies that the model works correctly, including a check that HPU and CPU outputs match.
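
To run only the HPU consistency test added in this change:

```bash
pytest tests/test_consistency.py -k test_hpu_support
```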

## See Also

* [OpenCLIP](https://github.com/mlfoundations/open_clip): includes larger and independently trained CLIP models up to ViT-G/14
68 changes: 55 additions & 13 deletions clip/clip.py
@@ -12,9 +12,11 @@

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
from .utils import get_device_initial

try:
from torchvision.transforms import InterpolationMode

BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
@@ -51,13 +53,24 @@ def _download(url: str, root: str):
raise RuntimeError(f"{download_target} exists and is not a regular file")

if os.path.isfile(download_target):
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
if (
hashlib.sha256(open(download_target, "rb").read()).hexdigest()
== expected_sha256
):
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)

with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
@@ -91,7 +104,12 @@ def available_models() -> List[str]:
return list(_MODELS.keys())


def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
def load(
name: str,
device: Union[str, torch.device] = get_device_initial(),
jit: bool = False,
download_root: str = None,
):
"""Load a CLIP model

Parameters
@@ -100,7 +118,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

device : Union[str, torch.device]
The device to put the loaded model
The device to put the loaded model on; defaults to the device returned by `get_device_initial()`

jit : bool
Whether to load the optimized JIT model or more hackable non-JIT model (default).
@@ -123,10 +141,12 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

with open(model_path, 'rb') as opened_file:
with open(model_path, "rb") as opened_file:
try:
# loading JIT archive
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
model = torch.jit.load(
opened_file, map_location=device if jit else "cpu"
).eval()
state_dict = None
except RuntimeError:
# loading saved state dict
@@ -136,13 +156,25 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
state_dict = torch.load(opened_file, map_location="cpu")

if not jit:
model = build_model(state_dict or model.state_dict()).to(device)
model = build_model(state_dict or model.state_dict())

if str(device) == "hpu":
from habana_frameworks.torch.utils.library_loader import load_habana_module

load_habana_module()
if torch.hpu.is_available():
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

model = wrap_in_hpu_graph(model)
model = model.eval().to(torch.device(device))
else:
model = model.to(device)
if str(device) == "cpu":
model.float()
return model, _transform(model.visual.input_resolution)

# patch the device names
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device("cpu" if device == "hpu" else device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

def _node_get(node: torch._C.Node, key: str):
@@ -171,9 +203,11 @@ def patch_device(module):
patch_device(model.encode_image)
patch_device(model.encode_text)

# patch dtype to float32 on CPU
if str(device) == "cpu":
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
# patch dtype to float32 on CPU and HPU
if str(device) in ["cpu", "hpu"]:
float_holder = torch.jit.trace(
lambda: torch.ones([]).float(), example_inputs=[]
)
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
float_node = float_input.node()

@@ -199,10 +233,18 @@ def patch_float(module):

model.float()

if str(device) == "hpu":
if torch.hpu.is_available():
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

model = wrap_in_hpu_graph(model)
model = model.eval().to(torch.device(device))
return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
def tokenize(
texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False
) -> Union[torch.IntTensor, torch.LongTensor]:
"""
Returns the tokenized representation of given input string(s)

30 changes: 30 additions & 0 deletions clip/utils.py
@@ -0,0 +1,30 @@
import importlib.util

import torch


def get_device_initial(preferred_device=None):
"""
Determine the appropriate device to use (cuda, hpu, or cpu).
Args:
preferred_device (str): User-preferred device ('cuda', 'hpu', or 'cpu').

Returns:
str: Device string ('cuda', 'hpu', or 'cpu').
"""
# Check for HPU support
if importlib.util.find_spec("habana_frameworks") is not None:
from habana_frameworks.torch.utils.library_loader import load_habana_module

load_habana_module()
if torch.hpu.is_available():
if preferred_device == "hpu" or preferred_device is None:
return "hpu"

# Check for CUDA (GPU support)
if torch.cuda.is_available():
if preferred_device == "cuda" or preferred_device is None:
return "cuda"

# Default to CPU
return "cpu"
3 changes: 3 additions & 0 deletions requirements_hpu.txt
@@ -0,0 +1,3 @@
-r requirements.txt
optimum-habana==1.14.1
pytest
22 changes: 21 additions & 1 deletion tests/test_consistency.py
@@ -2,11 +2,12 @@
import pytest
import torch
from PIL import Image
try:
    import habana_frameworks.torch  # noqa: F401  # registers the "hpu" backend with PyTorch
    HPU_AVAILABLE = torch.hpu.is_available()
except ImportError:
    HPU_AVAILABLE = False

import clip


@pytest.mark.parametrize('model_name', clip.available_models())
@pytest.mark.parametrize("model_name", clip.available_models())
def test_consistency(model_name):
device = "cpu"
jit_model, transform = clip.load(model_name, device=device, jit=True)
@@ -23,3 +24,22 @@ def test_consistency(model_name):
py_probs = logits_per_image.softmax(dim=-1).cpu().numpy()

assert np.allclose(jit_probs, py_probs, atol=0.01, rtol=0.1)


@pytest.mark.skipif(not HPU_AVAILABLE, reason="Intel Gaudi HPU is not available")
@pytest.mark.parametrize("model_name", clip.available_models())
def test_hpu_support(model_name):
devices = ["hpu", "cpu"]
all_probs = []
for device in devices:
print(f"=== Testing {model_name} on {device} ===")
model, transform = clip.load(model_name, device=device, jit=False)

image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
logits_per_image, _ = model(image, text)
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
all_probs.append(probs)

assert np.allclose(all_probs[0], all_probs[1], atol=0.01, rtol=0.1)