From 7a341d99abbb58547b0002475e1a1bc32ec43d73 Mon Sep 17 00:00:00 2001 From: gushiqiao Date: Mon, 14 Apr 2025 02:48:04 +0000 Subject: [PATCH 1/5] add t5 trt infer and convert, refactored vae trt infer --- .gitignore | 3 + .../convert_t5xxl_to_trt_engine.py | 52 +++++++ examples/vae_trt/convert_vae_trt_engine.py | 4 +- .../backend_infer/trt/trt_infer_base.py | 130 ++++++++++++++++++ .../models/text_encoders/trt/t5/model.py | 51 +++++++ .../text_encoders/trt/t5/trt_t5_infer.py | 73 ++++++++++ .../models/video_encoders/hf/wan/vae.py | 14 +- .../autoencoder_kl_causal_3d/trt_vae_infer.py | 89 +----------- 8 files changed, 328 insertions(+), 88 deletions(-) create mode 100644 examples/text_encoder_trt/convert_t5xxl_to_trt_engine.py create mode 100644 lightx2v/common/backend_infer/trt/trt_infer_base.py create mode 100644 lightx2v/text2v/models/text_encoders/trt/t5/model.py create mode 100644 lightx2v/text2v/models/text_encoders/trt/t5/trt_t5_infer.py diff --git a/.gitignore b/.gitignore index ad9e8e0..d5596fb 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,6 @@ # just4dev devscripts/ +out* +mycode +work_dir diff --git a/examples/text_encoder_trt/convert_t5xxl_to_trt_engine.py b/examples/text_encoder_trt/convert_t5xxl_to_trt_engine.py new file mode 100644 index 0000000..13f89af --- /dev/null +++ b/examples/text_encoder_trt/convert_t5xxl_to_trt_engine.py @@ -0,0 +1,52 @@ +from pathlib import Path +import os +import argparse + +import torch +from loguru import logger + +from lightx2v.text2v.models.text_encoders.hf.t5.model import T5EncoderModel +from lightx2v.text2v.models.text_encoders.trt.t5.trt_t5_infer import T5TrtModelInfer + + +def parse_args(): + args = argparse.ArgumentParser() + args.add_argument("--model_path", help="", type=str, default="models/Wan2.1-T2V-1.3B") + args.add_argument("--dtype", default=torch.float16) + args.add_argument("--device", default="cuda", type=str) + return args.parse_args() + + +def convert_trt_engine(args): + t5_checkpoint_path = os.path.join(args.model_path, "models_t5_umt5-xxl-enc-bf16.pth") + t5_tokenizer_path = os.path.join(args.model_path, "google/umt5-xxl") + assert Path(t5_checkpoint_path).exists(), f"{t5_checkpoint_path} not exists." + model = T5EncoderModel( + text_len=512, + dtype=args.dtype, + device=args.device, + checkpoint_path=t5_checkpoint_path, + tokenizer_path=t5_tokenizer_path, + shard_fn=None + ) + texts = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." 
+ ids, mask = model.tokenizer(texts, return_mask=True, add_special_tokens=True) + ids = ids.to(args.device) + mask = mask.to(args.device) + onnx_path = T5TrtModelInfer.export_to_onnx(model.model, model_dir=args.model_path, ids=ids, mask=mask) + del model + torch.cuda.empty_cache() + engine_path = onnx_path.replace(".onnx", ".engine") + T5TrtModelInfer.convert_to_trt_engine(onnx_path, engine_path) + logger.info(f"ONNX: {onnx_path}") + logger.info(f"TRT Engine: {engine_path}") + return + + +def main(): + args = parse_args() + convert_trt_engine(args) + + +if __name__ == "__main__": + main() diff --git a/examples/vae_trt/convert_vae_trt_engine.py b/examples/vae_trt/convert_vae_trt_engine.py index 0882449..50ecf0f 100644 --- a/examples/vae_trt/convert_vae_trt_engine.py +++ b/examples/vae_trt/convert_vae_trt_engine.py @@ -17,7 +17,7 @@ def parse_args(): return args.parse_args() -def convert_vae_trt_engine(args): +def convert_trt_engine(args): vae_path = os.path.join(args.model_path, "hunyuan-video-t2v-720p/vae") assert Path(vae_path).exists(), f"{vae_path} not exists." config = AutoencoderKLCausal3D.load_config(vae_path) @@ -38,7 +38,7 @@ def convert_vae_trt_engine(args): def main(): args = parse_args() - convert_vae_trt_engine(args) + convert_trt_engine(args) if __name__ == "__main__": diff --git a/lightx2v/common/backend_infer/trt/trt_infer_base.py b/lightx2v/common/backend_infer/trt/trt_infer_base.py new file mode 100644 index 0000000..be280a9 --- /dev/null +++ b/lightx2v/common/backend_infer/trt/trt_infer_base.py @@ -0,0 +1,130 @@ +from pathlib import Path + +import numpy as np +import torch +import tensorrt as trt +from cuda import cudart +import torch.nn as nn +from loguru import logger + +from lightx2v.common.backend_infer.trt import common + +TRT_LOGGER = trt.Logger(trt.Logger.INFO) + + +np_torch_dtype_map = { + "float16": torch.float16, + "float32": torch.float32 +} + + +class TrtModelInferBase(nn.Module): + """ + Implements inference for the TensorRT engine. + """ + + def __init__(self, engine_path, **kwargs): + """ + :param engine_path: The path to the serialized engine to load from disk. 
+ """ + # Load TRT engine + if not Path(engine_path).exists(): + raise FileNotFoundError(f"Tensorrt engine `{str(engine_path)}` not exists.") + self.logger = trt.Logger(trt.Logger.ERROR) + with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime: + assert runtime + self.engine = runtime.deserialize_cuda_engine(f.read()) + assert self.engine + self.context = self.engine.create_execution_context() + assert self.context + logger.info(f"Loaded tensorrt engine from `{engine_path}`") + self.inp_list = [] + self.out_list = [] + self.get_io_properties() + + def alloc(self, shape_dict): + """ + Setup I/O bindings + """ + self.inputs = [] + self.outputs = [] + self.allocations = [] + for i in range(self.engine.num_io_tensors): + name = self.engine.get_tensor_name(i) + is_input = False + if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: + is_input = True + dtype = self.engine.get_tensor_dtype(name) + shape = shape_dict[name] + if is_input: + self.context.set_input_shape(name, shape) + self.batch_size = shape[0] + if dtype == trt.DataType.BF16: + dtype = trt.DataType.HALF + size = np.dtype(trt.nptype(dtype)).itemsize + for s in shape: + size *= s + allocation = common.cuda_call(cudart.cudaMalloc(size)) + binding = { + "index": i, + "name": name, + "dtype": np.dtype(trt.nptype(dtype)), + "shape": list(shape), + "allocation": allocation, + } + self.allocations.append(allocation) + if is_input: + self.inputs.append(binding) + else: + self.outputs.append(binding) + + assert self.batch_size > 0 + assert len(self.inputs) > 0 + assert len(self.outputs) > 0 + assert len(self.allocations) > 0 + + def input_spec(self): + """ + Get the specs for the input tensor of the network. Useful to prepare memory allocations. + :return: Two items, the shape of the input tensor and its (numpy) datatype. + """ + return self.inputs[0]["shape"], self.inputs[0]["dtype"] + + def get_io_properties(self): + for bind in self.engine: + mode = self.engine.get_tensor_mode(bind) + if mode.name == "INPUT": + self.inp_list.append( + { + "name": bind, + "shape": self.engine.get_tensor_shape(bind), + "dtype": self.engine.get_tensor_dtype(bind).name + } + ) + else: + self.out_list.append( + { + "name": bind, + "shape": self.engine.get_tensor_shape(bind), + "dtype": self.engine.get_tensor_dtype(bind).name + } + ) + return + + def output_spec(self): + """ + Get the specs for the output tensor of the network. Useful to prepare memory allocations. + :return: Two items, the shape of the output tensor and its (numpy) datatype. 
+ """ + return self.outputs[0]["shape"], self.outputs[0]["dtype"] + + def __call__(self, batch, *args, **kwargs): + pass + + @staticmethod + def export_to_onnx(model: torch.nn.Module, model_dir): + pass + + @staticmethod + def convert_to_trt_engine(onnx_path, engine_path): + pass \ No newline at end of file diff --git a/lightx2v/text2v/models/text_encoders/trt/t5/model.py b/lightx2v/text2v/models/text_encoders/trt/t5/model.py new file mode 100644 index 0000000..f41dd87 --- /dev/null +++ b/lightx2v/text2v/models/text_encoders/trt/t5/model.py @@ -0,0 +1,51 @@ +import logging + +import torch + +from ...hf.t5.tokenizer import HuggingfaceTokenizer +from .trt_t5_infer import T5TrtModelInfer + + + +class T5EncoderModel: + def __init__( + self, + text_len, + dtype=torch.bfloat16, + device=torch.cuda.current_device(), + engine_path=None, + checkpoint_path=None, + tokenizer_path=None, + **kwargs + ): + self.text_len = text_len + self.dtype = dtype + self.device = device + self.checkpoint_path = checkpoint_path + self.tokenizer_path = tokenizer_path + + # init model + self.model = T5TrtModelInfer(engine_path=engine_path) + # init tokenizer + self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace") + + def to_cpu(self): + self.model = self.model.to("cpu") + + def to_cuda(self): + self.model = self.model.to("cuda") + + def infer(self, texts, args): + if args.cpu_offload: + self.to_cuda() + + ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True) + ids = ids.cuda() + mask = mask.cuda() + seq_lens = mask.gt(0).sum(dim=1).long() + context = self.model(ids, mask) + + if args.cpu_offload: + self.to_cpu() + + return [u[:v] for u, v in zip(context, seq_lens)] \ No newline at end of file diff --git a/lightx2v/text2v/models/text_encoders/trt/t5/trt_t5_infer.py b/lightx2v/text2v/models/text_encoders/trt/t5/trt_t5_infer.py new file mode 100644 index 0000000..a89ecfb --- /dev/null +++ b/lightx2v/text2v/models/text_encoders/trt/t5/trt_t5_infer.py @@ -0,0 +1,73 @@ +import os +from pathlib import Path +from subprocess import Popen + +import torch +import tensorrt as trt +from loguru import logger +import numpy as np +from torch.nn.modules import Module + +from lightx2v.common.backend_infer.trt import common +from lightx2v.common.backend_infer.trt.trt_infer_base import TrtModelInferBase, np_torch_dtype_map + +TRT_LOGGER = trt.Logger(trt.Logger.INFO) + + +class T5TrtModelInfer(TrtModelInferBase): + def __init__(self, engine_path, **kwargs): + super().__init__(engine_path, **kwargs) + import onnxruntime as ort + + def __call__(self, ids, mask, *args, **kwargs): + device = ids.device + ids = ids.cpu().numpy() + mask = mask.cpu().numpy() + shp_dict = {i['name']: i['shape'] for i in self.inp_list} + shp_dict.update({i['name']: i['shape'] for i in self.out_list}) + self.alloc(shp_dict) + + out_list = [] + for o in self.outputs: + out_list.append(np.zeros(o["shape"], o['dtype'])) + for inp, data in zip(self.inputs, [ids, mask]): + common.memcpy_host_to_device(inp["allocation"], np.ascontiguousarray(data)) + self.context.execute_v2(self.allocations) + outs = [] + for i, out in enumerate(out_list): + common.memcpy_device_to_host(out, self.outputs[i]["allocation"]) + out = torch.from_numpy(out).to(device) + out = out.type(torch.bfloat16) + outs.append(out) + return outs[0] + + @staticmethod + def export_to_onnx(model: Module, model_dir, *args, **kwargs): + ids = kwargs.get("ids") + mask = kwargs.get("mask") + onnx_dir = Path(model_dir) / "onnx/t5" + 
onnx_dir.mkdir(parents=True, exist_ok=True) + onnx_path = str(onnx_dir/"t5.onnx") + torch.onnx.export( + model, + (ids, mask), + onnx_path, + opset_version=14 + ) + return onnx_path + + @staticmethod + def convert_to_trt_engine(onnx_path, engine_path, *args, **kwargs): + logger.info("Start to convert ONNX to tensorrt engine.") + cmd = ( + "trtexec " + f"--onnx={onnx_path} " + f"--saveEngine={engine_path} " + "--bf16 " + ) + p = Popen(cmd, shell=True) + p.wait() + if not Path(engine_path).exists(): + raise RuntimeError(f"Convert onnx({onnx_path}) to tensorrt engine failed.") + logger.info("Finish tensorrt converting.") + return engine_path \ No newline at end of file diff --git a/lightx2v/text2v/models/video_encoders/hf/wan/vae.py b/lightx2v/text2v/models/video_encoders/hf/wan/vae.py index d1a62f5..ab6ba8b 100755 --- a/lightx2v/text2v/models/video_encoders/hf/wan/vae.py +++ b/lightx2v/text2v/models/video_encoders/hf/wan/vae.py @@ -82,12 +82,12 @@ def __init__(self, dim, mode): # layers if mode == "upsample2d": self.resample = nn.Sequential( - Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + Upsample(scale_factor=(2.0, 2.0), mode="nearest"), nn.Conv2d(dim, dim // 2, 3, padding=1), ) elif mode == "upsample3d": self.resample = nn.Sequential( - Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + Upsample(scale_factor=(2.0, 2.0), mode="nearest"), nn.Conv2d(dim, dim // 2, 3, padding=1), ) self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) @@ -253,7 +253,8 @@ def forward(self, x): k, v, ) - x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + # x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + x = x.view(x.shape[0], x.shape[2], x.shape[3]).permute(0, 2, 1).reshape(b * t, c, h, w) # output x = self.proj(x) @@ -649,7 +650,8 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs): return model -class WanVAE: +class WanVAE(nn.Module): + def __init__( self, z_dim=16, @@ -658,6 +660,7 @@ def __init__( device="cuda", parallel=False, ): + super().__init__() self.dtype = dtype self.device = device self.parallel = parallel @@ -811,3 +814,6 @@ def decode(self, zs, generator, args): self.to_cpu() return images + + def forward(self, zs): + return self.decode(zs, None, None) \ No newline at end of file diff --git a/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py b/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py index 03664c2..cb43258 100644 --- a/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py +++ b/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py @@ -2,97 +2,26 @@ from pathlib import Path from subprocess import Popen -import numpy as np import torch +import numpy as np import tensorrt as trt -from cuda import cudart -import torch.nn as nn from loguru import logger -from lightx2v.common.backend_infer.trt import common +from lightx2v import common +from lightx2v.common.backend_infer.trt.trt_infer_base import TrtModelInferBase TRT_LOGGER = trt.Logger(trt.Logger.INFO) -class HyVaeTrtModelInfer(nn.Module): +class HyVaeTrtModelInfer(TrtModelInferBase): """ - Implements inference for the TensorRT engine. + Implements hunyuan vae inference for the TensorRT engine. """ def __init__(self, engine_path): - """ - :param engine_path: The path to the serialized engine to load from disk. 
- """ - # Load TRT engine - if not Path(engine_path).exists(): - # dir_name = str(Path(engine_path).parents) - # onnx_path = self.export_to_onnx(decoder, dir_name) - # self.convert_to_trt_engine(onnx_path, engine_path) - raise FileNotFoundError(f"VAE tensorrt engine `{str(engine_path)}` not exists.") - self.logger = trt.Logger(trt.Logger.ERROR) - with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime: - assert runtime - self.engine = runtime.deserialize_cuda_engine(f.read()) - assert self.engine - self.context = self.engine.create_execution_context() - assert self.context - logger.info(f"Loaded VAE tensorrt engine from `{engine_path}`") - - def alloc(self, shape_dict): - """ - Setup I/O bindings - """ - self.inputs = [] - self.outputs = [] - self.allocations = [] - for i in range(self.engine.num_io_tensors): - name = self.engine.get_tensor_name(i) - is_input = False - if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: - is_input = True - dtype = self.engine.get_tensor_dtype(name) - # shape = self.engine.get_tensor_shape(name) - shape = shape_dict[name] - if is_input: - self.context.set_input_shape(name, shape) - self.batch_size = shape[0] - size = np.dtype(trt.nptype(dtype)).itemsize - for s in shape: - size *= s - allocation = common.cuda_call(cudart.cudaMalloc(size)) - binding = { - "index": i, - "name": name, - "dtype": np.dtype(trt.nptype(dtype)), - "shape": list(shape), - "allocation": allocation, - } - self.allocations.append(allocation) - if is_input: - self.inputs.append(binding) - else: - self.outputs.append(binding) - - assert self.batch_size > 0 - assert len(self.inputs) > 0 - assert len(self.outputs) > 0 - assert len(self.allocations) > 0 - - def input_spec(self): - """ - Get the specs for the input tensor of the network. Useful to prepare memory allocations. - :return: Two items, the shape of the input tensor and its (numpy) datatype. - """ - return self.inputs[0]["shape"], self.inputs[0]["dtype"] - - def output_spec(self): - """ - Get the specs for the output tensor of the network. Useful to prepare memory allocations. - :return: Two items, the shape of the output tensor and its (numpy) datatype. - """ - return self.outputs[0]["shape"], self.outputs[0]["dtype"] + super().__init__(engine_path) - def __call__(self, batch, top=1): + def __call__(self, batch, *args, **kwargs): """ Execute inference """ @@ -132,11 +61,7 @@ def export_to_onnx(decoder: torch.nn.Module, model_dir): opset_version=14, dynamic_axes={"inp": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}, "out": {1: "c1", 2: "c2", 3: "c3", 4: "c4"}}, ) - # onnx_ori = onnx.load(out_path) os.system(f"onnxsim {out_path} {out_path}") - # onnx_opt, check = simplify(onnx_ori) - # assert check, f"Simplified ONNX model({out_path}) could not be validated." 
- # onnx.save(onnx_opt, out_path) logger.info("Finish VAE onnx exporting.") return out_path From 22d12ab29bae984497feb8ec3624bee0898406ff Mon Sep 17 00:00:00 2001 From: Wq-dd <1904007277@qq.com> Date: Mon, 14 Apr 2025 05:04:51 +0000 Subject: [PATCH 2/5] precommit --- .../convert_t5xxl_to_trt_engine.py | 13 +++------- .../backend_infer/trt/trt_infer_base.py | 23 ++++-------------- .../models/text_encoders/trt/t5/model.py | 14 ++--------- .../text_encoders/trt/t5/trt_t5_infer.py | 24 ++++++------------- 4 files changed, 16 insertions(+), 58 deletions(-) diff --git a/examples/text_encoder_trt/convert_t5xxl_to_trt_engine.py b/examples/text_encoder_trt/convert_t5xxl_to_trt_engine.py index 13f89af..69e2d46 100644 --- a/examples/text_encoder_trt/convert_t5xxl_to_trt_engine.py +++ b/examples/text_encoder_trt/convert_t5xxl_to_trt_engine.py @@ -19,19 +19,12 @@ def parse_args(): def convert_trt_engine(args): t5_checkpoint_path = os.path.join(args.model_path, "models_t5_umt5-xxl-enc-bf16.pth") - t5_tokenizer_path = os.path.join(args.model_path, "google/umt5-xxl") + t5_tokenizer_path = os.path.join(args.model_path, "google/umt5-xxl") assert Path(t5_checkpoint_path).exists(), f"{t5_checkpoint_path} not exists." - model = T5EncoderModel( - text_len=512, - dtype=args.dtype, - device=args.device, - checkpoint_path=t5_checkpoint_path, - tokenizer_path=t5_tokenizer_path, - shard_fn=None - ) + model = T5EncoderModel(text_len=512, dtype=args.dtype, device=args.device, checkpoint_path=t5_checkpoint_path, tokenizer_path=t5_tokenizer_path, shard_fn=None) texts = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." ids, mask = model.tokenizer(texts, return_mask=True, add_special_tokens=True) - ids = ids.to(args.device) + ids = ids.to(args.device) mask = mask.to(args.device) onnx_path = T5TrtModelInfer.export_to_onnx(model.model, model_dir=args.model_path, ids=ids, mask=mask) del model diff --git a/lightx2v/common/backend_infer/trt/trt_infer_base.py b/lightx2v/common/backend_infer/trt/trt_infer_base.py index be280a9..ef5453c 100644 --- a/lightx2v/common/backend_infer/trt/trt_infer_base.py +++ b/lightx2v/common/backend_infer/trt/trt_infer_base.py @@ -12,10 +12,7 @@ TRT_LOGGER = trt.Logger(trt.Logger.INFO) -np_torch_dtype_map = { - "float16": torch.float16, - "float32": torch.float32 -} +np_torch_dtype_map = {"float16": torch.float16, "float32": torch.float32} class TrtModelInferBase(nn.Module): @@ -94,21 +91,9 @@ def get_io_properties(self): for bind in self.engine: mode = self.engine.get_tensor_mode(bind) if mode.name == "INPUT": - self.inp_list.append( - { - "name": bind, - "shape": self.engine.get_tensor_shape(bind), - "dtype": self.engine.get_tensor_dtype(bind).name - } - ) + self.inp_list.append({"name": bind, "shape": self.engine.get_tensor_shape(bind), "dtype": self.engine.get_tensor_dtype(bind).name}) else: - self.out_list.append( - { - "name": bind, - "shape": self.engine.get_tensor_shape(bind), - "dtype": self.engine.get_tensor_dtype(bind).name - } - ) + self.out_list.append({"name": bind, "shape": self.engine.get_tensor_shape(bind), "dtype": self.engine.get_tensor_dtype(bind).name}) return def output_spec(self): @@ -127,4 +112,4 @@ def export_to_onnx(model: torch.nn.Module, model_dir): @staticmethod def convert_to_trt_engine(onnx_path, engine_path): - pass \ No newline at end of file + pass diff --git a/lightx2v/text2v/models/text_encoders/trt/t5/model.py b/lightx2v/text2v/models/text_encoders/trt/t5/model.py index f41dd87..5d021cb 100644 --- 
a/lightx2v/text2v/models/text_encoders/trt/t5/model.py +++ b/lightx2v/text2v/models/text_encoders/trt/t5/model.py @@ -6,18 +6,8 @@ from .trt_t5_infer import T5TrtModelInfer - class T5EncoderModel: - def __init__( - self, - text_len, - dtype=torch.bfloat16, - device=torch.cuda.current_device(), - engine_path=None, - checkpoint_path=None, - tokenizer_path=None, - **kwargs - ): + def __init__(self, text_len, dtype=torch.bfloat16, device=torch.cuda.current_device(), engine_path=None, checkpoint_path=None, tokenizer_path=None, **kwargs): self.text_len = text_len self.dtype = dtype self.device = device @@ -48,4 +38,4 @@ def infer(self, texts, args): if args.cpu_offload: self.to_cpu() - return [u[:v] for u, v in zip(context, seq_lens)] \ No newline at end of file + return [u[:v] for u, v in zip(context, seq_lens)] diff --git a/lightx2v/text2v/models/text_encoders/trt/t5/trt_t5_infer.py b/lightx2v/text2v/models/text_encoders/trt/t5/trt_t5_infer.py index a89ecfb..0918c75 100644 --- a/lightx2v/text2v/models/text_encoders/trt/t5/trt_t5_infer.py +++ b/lightx2v/text2v/models/text_encoders/trt/t5/trt_t5_infer.py @@ -23,13 +23,13 @@ def __call__(self, ids, mask, *args, **kwargs): device = ids.device ids = ids.cpu().numpy() mask = mask.cpu().numpy() - shp_dict = {i['name']: i['shape'] for i in self.inp_list} - shp_dict.update({i['name']: i['shape'] for i in self.out_list}) + shp_dict = {i["name"]: i["shape"] for i in self.inp_list} + shp_dict.update({i["name"]: i["shape"] for i in self.out_list}) self.alloc(shp_dict) out_list = [] for o in self.outputs: - out_list.append(np.zeros(o["shape"], o['dtype'])) + out_list.append(np.zeros(o["shape"], o["dtype"])) for inp, data in zip(self.inputs, [ids, mask]): common.memcpy_host_to_device(inp["allocation"], np.ascontiguousarray(data)) self.context.execute_v2(self.allocations) @@ -47,27 +47,17 @@ def export_to_onnx(model: Module, model_dir, *args, **kwargs): mask = kwargs.get("mask") onnx_dir = Path(model_dir) / "onnx/t5" onnx_dir.mkdir(parents=True, exist_ok=True) - onnx_path = str(onnx_dir/"t5.onnx") - torch.onnx.export( - model, - (ids, mask), - onnx_path, - opset_version=14 - ) + onnx_path = str(onnx_dir / "t5.onnx") + torch.onnx.export(model, (ids, mask), onnx_path, opset_version=14) return onnx_path @staticmethod def convert_to_trt_engine(onnx_path, engine_path, *args, **kwargs): logger.info("Start to convert ONNX to tensorrt engine.") - cmd = ( - "trtexec " - f"--onnx={onnx_path} " - f"--saveEngine={engine_path} " - "--bf16 " - ) + cmd = f"trtexec --onnx={onnx_path} --saveEngine={engine_path} --bf16 " p = Popen(cmd, shell=True) p.wait() if not Path(engine_path).exists(): raise RuntimeError(f"Convert onnx({onnx_path}) to tensorrt engine failed.") logger.info("Finish tensorrt converting.") - return engine_path \ No newline at end of file + return engine_path From dec0e5e3bd6dd9ba8eb62bb2cde96cef70be8373 Mon Sep 17 00:00:00 2001 From: Wq-dd <1904007277@qq.com> Date: Mon, 14 Apr 2025 05:27:38 +0000 Subject: [PATCH 3/5] refactor trt infer base class --- .../common/backend_infer/trt/trt_infer_base.py | 14 -------------- .../trt/autoencoder_kl_causal_3d/trt_vae_infer.py | 5 ++++- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/lightx2v/common/backend_infer/trt/trt_infer_base.py b/lightx2v/common/backend_infer/trt/trt_infer_base.py index ef5453c..df349b1 100644 --- a/lightx2v/common/backend_infer/trt/trt_infer_base.py +++ b/lightx2v/common/backend_infer/trt/trt_infer_base.py @@ -80,13 +80,6 @@ def alloc(self, shape_dict): assert len(self.outputs) > 0 
assert len(self.allocations) > 0 - def input_spec(self): - """ - Get the specs for the input tensor of the network. Useful to prepare memory allocations. - :return: Two items, the shape of the input tensor and its (numpy) datatype. - """ - return self.inputs[0]["shape"], self.inputs[0]["dtype"] - def get_io_properties(self): for bind in self.engine: mode = self.engine.get_tensor_mode(bind) @@ -96,13 +89,6 @@ def get_io_properties(self): self.out_list.append({"name": bind, "shape": self.engine.get_tensor_shape(bind), "dtype": self.engine.get_tensor_dtype(bind).name}) return - def output_spec(self): - """ - Get the specs for the output tensor of the network. Useful to prepare memory allocations. - :return: Two items, the shape of the output tensor and its (numpy) datatype. - """ - return self.outputs[0]["shape"], self.outputs[0]["dtype"] - def __call__(self, batch, *args, **kwargs): pass diff --git a/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py b/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py index cb43258..4eb26cd 100644 --- a/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py +++ b/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py @@ -21,6 +21,9 @@ class HyVaeTrtModelInfer(TrtModelInferBase): def __init__(self, engine_path): super().__init__(engine_path) + def output_spec(self): + return self.outputs[0]["shape"], self.outputs[0]["dtype"] + def __call__(self, batch, *args, **kwargs): """ Execute inference @@ -37,7 +40,7 @@ def get_output_shape(shp): shp_dict = {"inp": batch.shape, "out": get_output_shape(batch.shape)} self.alloc(shp_dict) - output = np.zeros(*self.output_spec()) + output = np.zeros(*self.out_list[0]["shape"], self.out_list[0]["dtype"]) # Process I/O and execute the network common.memcpy_host_to_device(self.inputs[0]["allocation"], np.ascontiguousarray(batch)) From 08a751ca63618f3d43e38f37dec0989e5bf4ed55 Mon Sep 17 00:00:00 2001 From: Wq-dd <1904007277@qq.com> Date: Mon, 14 Apr 2025 05:28:39 +0000 Subject: [PATCH 4/5] refactor trt infer base class --- .../trt/autoencoder_kl_causal_3d/trt_vae_infer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py b/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py index 4eb26cd..c7ad8aa 100644 --- a/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py +++ b/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py @@ -21,9 +21,6 @@ class HyVaeTrtModelInfer(TrtModelInferBase): def __init__(self, engine_path): super().__init__(engine_path) - def output_spec(self): - return self.outputs[0]["shape"], self.outputs[0]["dtype"] - def __call__(self, batch, *args, **kwargs): """ Execute inference From f7654bca9bb5833ec09626cf45c90f12ce9a04ce Mon Sep 17 00:00:00 2001 From: Wq-dd <1904007277@qq.com> Date: Mon, 14 Apr 2025 08:02:27 +0000 Subject: [PATCH 5/5] add clip trt infer --- .../convert_CLIP_L_to_trt_engine.py | 51 +++++++++++++ .../backend_infer/trt/trt_infer_base.py | 5 +- .../models/text_encoders/trt/clip/__init__.py | 0 .../models/text_encoders/trt/clip/model.py | 61 ++++++++++++++++ .../text_encoders/trt/clip/trt_clip_infer.py | 73 +++++++++++++++++++ .../models/text_encoders/trt/t5/model.py | 4 +- .../text_encoders/trt/t5/trt_t5_infer.py | 1 - .../trt/autoencoder_kl_causal_3d/model.py | 4 +- 
.../autoencoder_kl_causal_3d/trt_vae_infer.py | 7 +- 9 files changed, 195 insertions(+), 11 deletions(-) create mode 100644 examples/text_encoder_trt/convert_CLIP_L_to_trt_engine.py create mode 100644 lightx2v/text2v/models/text_encoders/trt/clip/__init__.py create mode 100644 lightx2v/text2v/models/text_encoders/trt/clip/model.py create mode 100644 lightx2v/text2v/models/text_encoders/trt/clip/trt_clip_infer.py diff --git a/examples/text_encoder_trt/convert_CLIP_L_to_trt_engine.py b/examples/text_encoder_trt/convert_CLIP_L_to_trt_engine.py new file mode 100644 index 0000000..0b3a662 --- /dev/null +++ b/examples/text_encoder_trt/convert_CLIP_L_to_trt_engine.py @@ -0,0 +1,51 @@ +import os +import argparse + +import torch +from loguru import logger + +from lightx2v.text2v.models.text_encoders.hf.clip.model import TextEncoderHFClipModel +from lightx2v.text2v.models.text_encoders.trt.clip.trt_clip_infer import CLIPTrtModelInfer + + +def parse_args(): + args = argparse.ArgumentParser() + args.add_argument("--model_path", help="", type=str, default="/mtc/yongyang/models/x2v_models/hunyuan/lightx2v_format/t2v") + args.add_argument("--dtype", default=torch.float32) + args.add_argument("--device", default="cuda", type=str) + return args.parse_args() + + +def convert_trt_engine(args): + init_device = torch.device(args.device) + text_encoder_2 = TextEncoderHFClipModel(os.path.join(args.model_path, "text_encoder_2"), init_device) + texts = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." + tokens = text_encoder_2.tokenizer( + texts, + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + truncation=True, + max_length=text_encoder_2.max_length, + padding="max_length", + return_tensors="pt", + ).to(init_device) + input_ids=tokens["input_ids"].to(init_device) + attention_mask=tokens["attention_mask"].to(init_device) + onnx_path = CLIPTrtModelInfer.export_to_onnx(text_encoder_2.model, model_dir=args.model_path, input_ids=input_ids, attention_mask=attention_mask) + del text_encoder_2 + torch.cuda.empty_cache() + engine_path = onnx_path.replace(".onnx", ".engine") + CLIPTrtModelInfer.convert_to_trt_engine(onnx_path, engine_path) + logger.info(f"ONNX: {onnx_path}") + logger.info(f"TRT Engine: {engine_path}") + return + + +def main(): + args = parse_args() + convert_trt_engine(args) + + +if __name__ == "__main__": + main() diff --git a/lightx2v/common/backend_infer/trt/trt_infer_base.py b/lightx2v/common/backend_infer/trt/trt_infer_base.py index df349b1..f1b6e44 100644 --- a/lightx2v/common/backend_infer/trt/trt_infer_base.py +++ b/lightx2v/common/backend_infer/trt/trt_infer_base.py @@ -83,10 +83,11 @@ def alloc(self, shape_dict): def get_io_properties(self): for bind in self.engine: mode = self.engine.get_tensor_mode(bind) + dtype = trt.nptype(self.engine.get_tensor_dtype(bind)) if mode.name == "INPUT": - self.inp_list.append({"name": bind, "shape": self.engine.get_tensor_shape(bind), "dtype": self.engine.get_tensor_dtype(bind).name}) + self.inp_list.append({"name": bind, "shape": self.engine.get_tensor_shape(bind), "dtype": dtype}) else: - self.out_list.append({"name": bind, "shape": self.engine.get_tensor_shape(bind), "dtype": self.engine.get_tensor_dtype(bind).name}) + self.out_list.append({"name": bind, "shape": self.engine.get_tensor_shape(bind), "dtype": dtype}) return def __call__(self, batch, *args, **kwargs): diff --git a/lightx2v/text2v/models/text_encoders/trt/clip/__init__.py 
b/lightx2v/text2v/models/text_encoders/trt/clip/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lightx2v/text2v/models/text_encoders/trt/clip/model.py b/lightx2v/text2v/models/text_encoders/trt/clip/model.py new file mode 100644 index 0000000..5f4856f --- /dev/null +++ b/lightx2v/text2v/models/text_encoders/trt/clip/model.py @@ -0,0 +1,61 @@ +import os + +import torch +from transformers import AutoTokenizer + +from .trt_clip_infer import CLIPTrtModelInfer + + +class TextEncoderHFClipModel: + def __init__(self, model_path, device, **kwargs): + self.device = device + self.model_path = model_path + self.engine_path = os.path.join(model_path, "onnx/clip_l/clip_l.engine") + self.init() + self.load() + + def init(self): + self.max_length = 77 + + def load(self): + self.model = CLIPTrtModelInfer(engine_path=self.engine_path) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, padding_side="right") + + def to_cpu(self): + self.model = self.model.to("cpu") + + def to_cuda(self): + self.model = self.model.to("cuda") + + @torch.no_grad() + def infer(self, text, args): + if args.cpu_offload: + self.to_cuda() + tokens = self.tokenizer( + text, + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + truncation=True, + max_length=self.max_length, + padding="max_length", + return_tensors="pt", + ).to("cuda") + + outputs = self.model( + ids=tokens["input_ids"], + mask=tokens["attention_mask"], + ) + + last_hidden_state = outputs["pooler_output"] + if args.cpu_offload: + self.to_cpu() + return last_hidden_state, tokens["attention_mask"] + + +if __name__ == "__main__": + model_path = "" + model = TextEncoderHFClipModel(model_path, torch.device("cuda")) + text = "A cat walks on the grass, realistic style." 
+ outputs = model.infer(text) + print(outputs) diff --git a/lightx2v/text2v/models/text_encoders/trt/clip/trt_clip_infer.py b/lightx2v/text2v/models/text_encoders/trt/clip/trt_clip_infer.py new file mode 100644 index 0000000..6c48306 --- /dev/null +++ b/lightx2v/text2v/models/text_encoders/trt/clip/trt_clip_infer.py @@ -0,0 +1,73 @@ +from pathlib import Path +from subprocess import Popen + +import torch +import tensorrt as trt +from loguru import logger +import numpy as np +from torch.nn.modules import Module + +from lightx2v.common.backend_infer.trt import common +from lightx2v.common.backend_infer.trt.trt_infer_base import TrtModelInferBase, np_torch_dtype_map + +TRT_LOGGER = trt.Logger(trt.Logger.INFO) + + + +class CLIPTrtModelInfer(TrtModelInferBase): + def __init__(self, engine_path, **kwargs): + super().__init__(engine_path, **kwargs) + + def __call__(self, ids, mask, *args, **kwargs): + device = ids.device + ids = ids.cpu().numpy() + mask = mask.cpu().numpy() + shp_dict = {i["name"]: i["shape"] for i in self.inp_list} + shp_dict.update({i["name"]: i["shape"] for i in self.out_list}) + self.alloc(shp_dict) + + out_list = [] + for o in self.outputs: + out_list.append(np.zeros(o["shape"], o["dtype"])) + for inp, data in zip(self.inputs, [ids, mask]): + common.memcpy_host_to_device(inp["allocation"], np.ascontiguousarray(data)) + self.context.execute_v2(self.allocations) + outs = [] + for i, out in enumerate(out_list): + common.memcpy_device_to_host(out, self.outputs[i]["allocation"]) + out = torch.from_numpy(out).to(device) + out = out.type(torch.bfloat16) + outs.append(out) + return {"pooler_output": outs[1]} + + @staticmethod + def export_to_onnx(model: Module, model_dir, *args, **kwargs): + ids = kwargs.get("input_ids") + mask = kwargs.get("attention_mask") + onnx_dir = Path(model_dir) / "text_encoder_2/onnx/clip_l" + onnx_dir.mkdir(parents=True, exist_ok=True) + onnx_path = str(onnx_dir / "clip_l.onnx") + + class ClipWrapper(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, input_ids, attention_mask, return_dict=False, output_hidden_states=False): + out = self.model(input_ids, attention_mask, return_dict=return_dict, output_hidden_states=output_hidden_states) + return out + + model_wrapped = ClipWrapper() + model_wrapped.model = model + torch.onnx.export(model_wrapped, (ids, mask), onnx_path, opset_version=14) + return onnx_path + + @staticmethod + def convert_to_trt_engine(onnx_path, engine_path, *args, **kwargs): + logger.info("Start to convert ONNX to tensorrt engine.") + cmd = f"trtexec --onnx={onnx_path} --saveEngine={engine_path} --bf16 " + p = Popen(cmd, shell=True) + p.wait() + if not Path(engine_path).exists(): + raise RuntimeError(f"Convert onnx({onnx_path}) to tensorrt engine failed.") + logger.info("Finish tensorrt converting.") + return engine_path diff --git a/lightx2v/text2v/models/text_encoders/trt/t5/model.py b/lightx2v/text2v/models/text_encoders/trt/t5/model.py index 5d021cb..6b3b6e4 100644 --- a/lightx2v/text2v/models/text_encoders/trt/t5/model.py +++ b/lightx2v/text2v/models/text_encoders/trt/t5/model.py @@ -1,5 +1,3 @@ -import logging - import torch from ...hf.t5.tokenizer import HuggingfaceTokenizer @@ -25,7 +23,7 @@ def to_cpu(self): def to_cuda(self): self.model = self.model.to("cuda") - def infer(self, texts, args): + def infer(self, texts, args, **kwargs): if args.cpu_offload: self.to_cuda() diff --git a/lightx2v/text2v/models/text_encoders/trt/t5/trt_t5_infer.py 
b/lightx2v/text2v/models/text_encoders/trt/t5/trt_t5_infer.py index 0918c75..496d69a 100644 --- a/lightx2v/text2v/models/text_encoders/trt/t5/trt_t5_infer.py +++ b/lightx2v/text2v/models/text_encoders/trt/t5/trt_t5_infer.py @@ -1,4 +1,3 @@ -import os from pathlib import Path from subprocess import Popen diff --git a/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/model.py b/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/model.py index 774168f..2ca8045 100755 --- a/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/model.py +++ b/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/model.py @@ -6,7 +6,7 @@ class VideoEncoderKLCausal3DModel: - def __init__(self, model_path, dtype, device): + def __init__(self, model_path, dtype, device, **kwargs): self.model_path = model_path self.dtype = dtype self.device = device @@ -24,7 +24,7 @@ def load(self): trt_decoder = trt_vae_infer.HyVaeTrtModelInfer(engine_path=os.path.join(self.vae_path, "vae_decoder.engine")) self.model.decoder = trt_decoder - def decode(self, latents, generator): + def decode(self, latents, generator, **kwargs): latents = latents / self.model.config.scaling_factor latents = latents.to(dtype=self.dtype, device=self.device) self.model.enable_tiling() diff --git a/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py b/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py index c7ad8aa..7c79560 100644 --- a/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py +++ b/lightx2v/text2v/models/video_encoders/trt/autoencoder_kl_causal_3d/trt_vae_infer.py @@ -7,7 +7,7 @@ import tensorrt as trt from loguru import logger -from lightx2v import common +from lightx2v.common.backend_infer.trt import common from lightx2v.common.backend_infer.trt.trt_infer_base import TrtModelInferBase TRT_LOGGER = trt.Logger(trt.Logger.INFO) @@ -35,9 +35,10 @@ def get_output_shape(shp): out = (b, 3, 4 * (t - 1) + 1, h * 8, w * 8) return out - shp_dict = {"inp": batch.shape, "out": get_output_shape(batch.shape)} + vae_out_shape = get_output_shape(batch.shape) + shp_dict = {"inp": batch.shape, "out": vae_out_shape} self.alloc(shp_dict) - output = np.zeros(*self.out_list[0]["shape"], self.out_list[0]["dtype"]) + output = np.zeros(vae_out_shape, self.out_list[0]["dtype"]) # Process I/O and execute the network common.memcpy_host_to_device(self.inputs[0]["allocation"], np.ascontiguousarray(batch))
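
Editor's note (not part of the patch series): the five commits above add TensorRT inference paths for the T5-XXL and CLIP-L text encoders, factor the shared engine handling into TrtModelInferBase, and adjust the Hunyuan VAE TRT wrapper to reuse it. The sketch below shows how the new T5 path is apparently intended to be driven, based only on the APIs introduced in these patches. It assumes the engine was already produced by examples/text_encoder_trt/convert_t5xxl_to_trt_engine.py (which exports to <model_path>/onnx/t5/t5.onnx and builds <model_path>/onnx/t5/t5.engine next to it); the model_path default and the argparse.Namespace stand-in for the runner's config are illustrative assumptions, not part of the patch.

    # Illustrative sketch, assuming the T5 engine was built by the convert script above.
    import argparse
    import torch
    from lightx2v.text2v.models.text_encoders.trt.t5.model import T5EncoderModel

    model_path = "models/Wan2.1-T2V-1.3B"  # same default as the convert script; adjust to your checkout
    encoder = T5EncoderModel(
        text_len=512,
        dtype=torch.bfloat16,
        device="cuda",
        engine_path=f"{model_path}/onnx/t5/t5.engine",       # assumed output location of convert_to_trt_engine
        tokenizer_path=f"{model_path}/google/umt5-xxl",       # same tokenizer dir the convert script uses
    )

    # The TRT-backed infer() only consults args.cpu_offload, so a bare Namespace suffices here.
    args = argparse.Namespace(cpu_offload=False)
    context = encoder.infer("A cat walks on the grass, realistic style.", args)
    print(context[0].shape, context[0].dtype)  # per-prompt embedding slice, cast to bfloat16 by T5TrtModelInfer

One design point worth noting for reviewers: both T5TrtModelInfer and CLIPTrtModelInfer round-trip tensors through host memory (cudaMalloc plus memcpy_host_to_device/device_to_host) on every call and re-run alloc() per invocation, which is simple but leaves some latency on the table compared with binding device pointers directly; that trade-off may be worth a follow-up if the text encoders end up on the hot path.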