Wan vae trt #17

Open · wants to merge 5 commits into main
3 changes: 3 additions & 0 deletions .gitignore
@@ -25,3 +25,6 @@

# just4dev
devscripts/
out*
mycode
work_dir
51 changes: 51 additions & 0 deletions examples/text_encoder_trt/convert_CLIP_L_to_trt_engine.py
@@ -0,0 +1,51 @@
import os
import argparse

import torch
from loguru import logger

from lightx2v.text2v.models.text_encoders.hf.clip.model import TextEncoderHFClipModel
from lightx2v.text2v.models.text_encoders.trt.clip.trt_clip_infer import CLIPTrtModelInfer


def parse_args():
args = argparse.ArgumentParser()
args.add_argument("--model_path", help="", type=str, default="/mtc/yongyang/models/x2v_models/hunyuan/lightx2v_format/t2v")
args.add_argument("--dtype", default=torch.float32)
args.add_argument("--device", default="cuda", type=str)
return args.parse_args()


def convert_trt_engine(args):
init_device = torch.device(args.device)
text_encoder_2 = TextEncoderHFClipModel(os.path.join(args.model_path, "text_encoder_2"), init_device)
texts = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
tokens = text_encoder_2.tokenizer(
texts,
return_length=False,
return_overflowing_tokens=False,
return_attention_mask=True,
truncation=True,
max_length=text_encoder_2.max_length,
padding="max_length",
return_tensors="pt",
).to(init_device)
    input_ids = tokens["input_ids"].to(init_device)
    attention_mask = tokens["attention_mask"].to(init_device)
onnx_path = CLIPTrtModelInfer.export_to_onnx(text_encoder_2.model, model_dir=args.model_path, input_ids=input_ids, attention_mask=attention_mask)
del text_encoder_2
torch.cuda.empty_cache()
engine_path = onnx_path.replace(".onnx", ".engine")
CLIPTrtModelInfer.convert_to_trt_engine(onnx_path, engine_path)
logger.info(f"ONNX: {onnx_path}")
logger.info(f"TRT Engine: {engine_path}")
return


def main():
args = parse_args()
convert_trt_engine(args)


if __name__ == "__main__":
main()
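A minimal sketch of how to exercise this converter, assuming a HunyuanVideo-style model directory whose `text_encoder_2/` subfolder holds the CLIP-L weights (the path is a placeholder):

python examples/text_encoder_trt/convert_CLIP_L_to_trt_engine.py --model_path /path/to/t2v --device cuda

On success it writes `text_encoder_2/onnx/clip_l/clip_l.onnx` under the model directory, then converts it in place to `clip_l.engine` via trtexec.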
45 changes: 45 additions & 0 deletions examples/text_encoder_trt/convert_t5xxl_to_trt_engine.py
@@ -0,0 +1,45 @@
from pathlib import Path
import os
import argparse

import torch
from loguru import logger

from lightx2v.text2v.models.text_encoders.hf.t5.model import T5EncoderModel
from lightx2v.text2v.models.text_encoders.trt.t5.trt_t5_infer import T5TrtModelInfer


def parse_args():
args = argparse.ArgumentParser()
args.add_argument("--model_path", help="", type=str, default="models/Wan2.1-T2V-1.3B")
args.add_argument("--dtype", default=torch.float16)
args.add_argument("--device", default="cuda", type=str)
return args.parse_args()


def convert_trt_engine(args):
t5_checkpoint_path = os.path.join(args.model_path, "models_t5_umt5-xxl-enc-bf16.pth")
t5_tokenizer_path = os.path.join(args.model_path, "google/umt5-xxl")
    assert Path(t5_checkpoint_path).exists(), f"{t5_checkpoint_path} does not exist."
model = T5EncoderModel(text_len=512, dtype=args.dtype, device=args.device, checkpoint_path=t5_checkpoint_path, tokenizer_path=t5_tokenizer_path, shard_fn=None)
texts = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
ids, mask = model.tokenizer(texts, return_mask=True, add_special_tokens=True)
ids = ids.to(args.device)
mask = mask.to(args.device)
onnx_path = T5TrtModelInfer.export_to_onnx(model.model, model_dir=args.model_path, ids=ids, mask=mask)
del model
torch.cuda.empty_cache()
engine_path = onnx_path.replace(".onnx", ".engine")
T5TrtModelInfer.convert_to_trt_engine(onnx_path, engine_path)
logger.info(f"ONNX: {onnx_path}")
logger.info(f"TRT Engine: {engine_path}")
return


def main():
args = parse_args()
convert_trt_engine(args)


if __name__ == "__main__":
main()
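The T5 converter is driven the same way; it assumes the Wan2.1 layout, i.e. `models_t5_umt5-xxl-enc-bf16.pth` plus a `google/umt5-xxl` tokenizer folder under the model path (placeholder shown):

python examples/text_encoder_trt/convert_t5xxl_to_trt_engine.py --model_path models/Wan2.1-T2V-1.3B --device cuda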
4 changes: 2 additions & 2 deletions examples/vae_trt/convert_vae_trt_engine.py
@@ -17,7 +17,7 @@ def parse_args():
return args.parse_args()


def convert_vae_trt_engine(args):
def convert_trt_engine(args):
vae_path = os.path.join(args.model_path, "hunyuan-video-t2v-720p/vae")
assert Path(vae_path).exists(), f"{vae_path} not exists."
config = AutoencoderKLCausal3D.load_config(vae_path)
@@ -38,7 +38,7 @@ def convert_vae_trt_engine(args):

def main():
args = parse_args()
convert_vae_trt_engine(args)
convert_trt_engine(args)


if __name__ == "__main__":
102 changes: 102 additions & 0 deletions lightx2v/common/backend_infer/trt/trt_infer_base.py
@@ -0,0 +1,102 @@
from pathlib import Path

import numpy as np
import torch
import tensorrt as trt
from cuda import cudart
import torch.nn as nn
from loguru import logger

from lightx2v.common.backend_infer.trt import common

TRT_LOGGER = trt.Logger(trt.Logger.INFO)


np_torch_dtype_map = {"float16": torch.float16, "float32": torch.float32}


class TrtModelInferBase(nn.Module):
"""
Implements inference for the TensorRT engine.
"""

def __init__(self, engine_path, **kwargs):
"""
:param engine_path: The path to the serialized engine to load from disk.
"""
# Load TRT engine
        if not Path(engine_path).exists():
            raise FileNotFoundError(f"TensorRT engine `{engine_path}` does not exist.")
self.logger = trt.Logger(trt.Logger.ERROR)
with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime:
assert runtime
self.engine = runtime.deserialize_cuda_engine(f.read())
assert self.engine
self.context = self.engine.create_execution_context()
assert self.context
logger.info(f"Loaded tensorrt engine from `{engine_path}`")
self.inp_list = []
self.out_list = []
self.get_io_properties()

def alloc(self, shape_dict):
"""
Setup I/O bindings
"""
self.inputs = []
self.outputs = []
self.allocations = []
for i in range(self.engine.num_io_tensors):
name = self.engine.get_tensor_name(i)
is_input = False
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
is_input = True
dtype = self.engine.get_tensor_dtype(name)
shape = shape_dict[name]
if is_input:
self.context.set_input_shape(name, shape)
self.batch_size = shape[0]
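            # trt.nptype() has no bfloat16 mapping, so size BF16 buffers as HALF (same 2-byte width).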
if dtype == trt.DataType.BF16:
dtype = trt.DataType.HALF
size = np.dtype(trt.nptype(dtype)).itemsize
for s in shape:
size *= s
allocation = common.cuda_call(cudart.cudaMalloc(size))
binding = {
"index": i,
"name": name,
"dtype": np.dtype(trt.nptype(dtype)),
"shape": list(shape),
"allocation": allocation,
}
self.allocations.append(allocation)
if is_input:
self.inputs.append(binding)
else:
self.outputs.append(binding)

assert self.batch_size > 0
assert len(self.inputs) > 0
assert len(self.outputs) > 0
assert len(self.allocations) > 0

def get_io_properties(self):
for bind in self.engine:
mode = self.engine.get_tensor_mode(bind)
dtype = trt.nptype(self.engine.get_tensor_dtype(bind))
if mode.name == "INPUT":
self.inp_list.append({"name": bind, "shape": self.engine.get_tensor_shape(bind), "dtype": dtype})
else:
self.out_list.append({"name": bind, "shape": self.engine.get_tensor_shape(bind), "dtype": dtype})
return

def __call__(self, batch, *args, **kwargs):
pass

@staticmethod
def export_to_onnx(model: torch.nn.Module, model_dir):
pass

@staticmethod
def convert_to_trt_engine(onnx_path, engine_path):
pass
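The base class leaves `__call__`, `export_to_onnx`, and `convert_to_trt_engine` as stubs. A minimal subclass sketch, inferred from the CLIP implementation below (the class name and the single-input/single-output assumption are hypothetical):

import numpy as np
import torch

from lightx2v.common.backend_infer.trt import common
from lightx2v.common.backend_infer.trt.trt_infer_base import TrtModelInferBase


class SingleTensorTrtInfer(TrtModelInferBase):
    def __call__(self, x: torch.Tensor):
        # Build the name -> shape map the allocator expects, then allocate device buffers.
        shapes = {t["name"]: t["shape"] for t in self.inp_list}
        shapes.update({t["name"]: t["shape"] for t in self.out_list})
        self.alloc(shapes)
        # Stage the input on the device, run the engine, copy the first output back to the host.
        common.memcpy_host_to_device(self.inputs[0]["allocation"], np.ascontiguousarray(x.cpu().numpy()))
        self.context.execute_v2(self.allocations)
        host_out = np.zeros(self.outputs[0]["shape"], self.outputs[0]["dtype"])
        common.memcpy_device_to_host(host_out, self.outputs[0]["allocation"])
        return torch.from_numpy(host_out).to(x.device)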
Empty file.
61 changes: 61 additions & 0 deletions lightx2v/text2v/models/text_encoders/trt/clip/model.py
@@ -0,0 +1,61 @@
import argparse
import os

import torch
from transformers import AutoTokenizer

from .trt_clip_infer import CLIPTrtModelInfer


class TextEncoderHFClipModel:
def __init__(self, model_path, device, **kwargs):
self.device = device
self.model_path = model_path
self.engine_path = os.path.join(model_path, "onnx/clip_l/clip_l.engine")
self.init()
self.load()

def init(self):
self.max_length = 77

def load(self):
self.model = CLIPTrtModelInfer(engine_path=self.engine_path)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, padding_side="right")

def to_cpu(self):
self.model = self.model.to("cpu")

def to_cuda(self):
self.model = self.model.to("cuda")

@torch.no_grad()
def infer(self, text, args):
if args.cpu_offload:
self.to_cuda()
tokens = self.tokenizer(
text,
return_length=False,
return_overflowing_tokens=False,
return_attention_mask=True,
truncation=True,
max_length=self.max_length,
padding="max_length",
return_tensors="pt",
).to("cuda")

outputs = self.model(
ids=tokens["input_ids"],
mask=tokens["attention_mask"],
)

        pooler_output = outputs["pooler_output"]
        if args.cpu_offload:
            self.to_cpu()
        return pooler_output, tokens["attention_mask"]


if __name__ == "__main__":
model_path = ""
model = TextEncoderHFClipModel(model_path, torch.device("cuda"))
text = "A cat walks on the grass, realistic style."
outputs = model.infer(text)
print(outputs)
73 changes: 73 additions & 0 deletions lightx2v/text2v/models/text_encoders/trt/clip/trt_clip_infer.py
@@ -0,0 +1,73 @@
from pathlib import Path
from subprocess import Popen

import torch
import tensorrt as trt
from loguru import logger
import numpy as np
from torch.nn import Module

from lightx2v.common.backend_infer.trt import common
from lightx2v.common.backend_infer.trt.trt_infer_base import TrtModelInferBase, np_torch_dtype_map

TRT_LOGGER = trt.Logger(trt.Logger.INFO)



class CLIPTrtModelInfer(TrtModelInferBase):
def __init__(self, engine_path, **kwargs):
super().__init__(engine_path, **kwargs)

def __call__(self, ids, mask, *args, **kwargs):
device = ids.device
ids = ids.cpu().numpy()
mask = mask.cpu().numpy()
shp_dict = {i["name"]: i["shape"] for i in self.inp_list}
shp_dict.update({i["name"]: i["shape"] for i in self.out_list})
self.alloc(shp_dict)

out_list = []
for o in self.outputs:
out_list.append(np.zeros(o["shape"], o["dtype"]))
for inp, data in zip(self.inputs, [ids, mask]):
common.memcpy_host_to_device(inp["allocation"], np.ascontiguousarray(data))
self.context.execute_v2(self.allocations)
outs = []
for i, out in enumerate(out_list):
common.memcpy_device_to_host(out, self.outputs[i]["allocation"])
out = torch.from_numpy(out).to(device)
out = out.type(torch.bfloat16)
outs.append(out)
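        # Output ordering follows the ONNX export: index 1 is assumed to be the HF pooler_output.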
return {"pooler_output": outs[1]}

@staticmethod
def export_to_onnx(model: Module, model_dir, *args, **kwargs):
ids = kwargs.get("input_ids")
mask = kwargs.get("attention_mask")
onnx_dir = Path(model_dir) / "text_encoder_2/onnx/clip_l"
onnx_dir.mkdir(parents=True, exist_ok=True)
onnx_path = str(onnx_dir / "clip_l.onnx")

        class ClipWrapper(torch.nn.Module):
            def __init__(self, model) -> None:
                super().__init__()
                self.model = model

            def forward(self, input_ids, attention_mask, return_dict=False, output_hidden_states=False):
                return self.model(input_ids, attention_mask, return_dict=return_dict, output_hidden_states=output_hidden_states)

        model_wrapped = ClipWrapper(model)
torch.onnx.export(model_wrapped, (ids, mask), onnx_path, opset_version=14)
return onnx_path

@staticmethod
def convert_to_trt_engine(onnx_path, engine_path, *args, **kwargs):
logger.info("Start to convert ONNX to tensorrt engine.")
cmd = f"trtexec --onnx={onnx_path} --saveEngine={engine_path} --bf16 "
p = Popen(cmd, shell=True)
p.wait()
if not Path(engine_path).exists():
raise RuntimeError(f"Convert onnx({onnx_path}) to tensorrt engine failed.")
logger.info("Finish tensorrt converting.")
return engine_path
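End to end, the new CLIP path can be smoke-tested as sketched below; the model path is a placeholder and `args` is a stand-in namespace, since `infer()` only reads `cpu_offload`:

import argparse

import torch

from lightx2v.text2v.models.text_encoders.trt.clip.model import TextEncoderHFClipModel

# The engine is expected at <model_path>/onnx/clip_l/clip_l.engine, as hard-coded in model.py.
model = TextEncoderHFClipModel("/path/to/t2v/text_encoder_2", torch.device("cuda"))
args = argparse.Namespace(cpu_offload=False)
emb, mask = model.infer("A cat walks on the grass, realistic style.", args)
print(emb.shape, emb.dtype)  # expect a bfloat16 pooler-output tensor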