Add trt decoder #307
.gitattributes (new file, tracks ONNX models with Git LFS):

```diff
@@ -0,0 +1 @@
+*.onnx filter=lfs diff=lfs merge=lfs -text
```
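For reference, this attribute line is exactly what the Git LFS CLI writes when a pattern is tracked; a minimal local equivalent (assuming `git-lfs` is installed):

```sh
# One-time per clone: install the LFS hooks.
git lfs install
# Track ONNX models; this appends the line above to .gitattributes.
git lfs track "*.onnx"
```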
CI build workflow:

```diff
@@ -64,7 +64,21 @@ jobs:
       - name: Install build requirements
         run: |
-          apt install -y --no-install-recommends gfortran libblas-dev
+          apt install -y --no-install-recommends gfortran libblas-dev wget
+
+      - name: Install TensorRT (amd64)
+        if: matrix.platform == 'amd64'
+        run: |
+          apt-cache search tensorrt | awk '{print "Package: "$1"\nPin: version *+cuda${{ matrix.cuda_version }}\nPin-Priority: 1001\n"}' | tee /etc/apt/preferences.d/tensorrt-cuda${{ matrix.cuda_version }}.pref > /dev/null
+          apt update
+          apt install -y tensorrt-dev
```
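For context, the `awk` pipeline above writes one apt pin stanza per package returned by `apt-cache search tensorrt`; a Pin-Priority above 1000 forces the pinned `+cuda<version>` builds to be installed even over newer candidates. With an illustrative package `tensorrt-dev` and a `matrix.cuda_version` of `12.6`, the generated `/etc/apt/preferences.d/tensorrt-cuda12.6.pref` would contain stanzas like:

```text
Package: tensorrt-dev
Pin: version *+cuda12.6
Pin-Priority: 1001
```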
```diff
+      - name: Install TensorRT (arm64)
+        if: matrix.platform == 'arm64'
+        run: |
+          apt-cache search tensorrt | awk '{print "Package: "$1"\nPin: version *+cuda13.0\nPin-Priority: 1001\n"}' | tee /etc/apt/preferences.d/tensorrt-cuda13.0.pref > /dev/null
+          apt update
+          apt install -y tensorrt-dev
```
Collaborator (on the `Pin: version *+cuda13.0` line):

This line is installing CUDA 13 regardless of what `matrix.cuda_version` is set to. Point of reference: https://github.com/NVIDIA/cudaqx/actions/runs/18989982159/job/54240883357#step:12:41 shows the CUDA 13 version being installed in an ARM CUDA 12.6 job. (I found this because our GitLab pipeline is broken for ARM right now, and I am still investigating.)
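A minimal fix sketch, assuming the arm64 matrix entries also carry the intended CUDA version in `matrix.cuda_version` (not shown in this diff), would mirror the amd64 step instead of hardcoding `13.0`:

```diff
-          apt-cache search tensorrt | awk '{print "Package: "$1"\nPin: version *+cuda13.0\nPin-Priority: 1001\n"}' | tee /etc/apt/preferences.d/tensorrt-cuda13.0.pref > /dev/null
+          apt-cache search tensorrt | awk '{print "Package: "$1"\nPin: version *+cuda${{ matrix.cuda_version }}\nPin-Priority: 1001\n"}' | tee /etc/apt/preferences.d/tensorrt-cuda${{ matrix.cuda_version }}.pref > /dev/null
```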
Remaining hunk context:

```diff
       - name: Build
         id: build
```
```diff
@@ -92,7 +106,7 @@ jobs:
         LD_LIBRARY_PATH: ${{ env.MPI_PATH }}/lib:${{ env.LD_LIBRARY_PATH }}
         shell: bash
         run: |
-          pip install numpy pytest cupy-cuda${{ steps.config.outputs.cuda_major }}x cuquantum-cu${{ steps.config.outputs.cuda_major }} torch lightning ml_collections mpi4py transformers quimb opt_einsum torch nvidia-cublas-cu${{ steps.config.outputs.cuda_major }} cuquantum-python-cu${{ steps.config.outputs.cuda_major }}==25.09
+          pip install numpy pytest cupy-cuda${{ steps.config.outputs.cuda_major }}x cuquantum-cu${{ steps.config.outputs.cuda_major }} torch lightning ml_collections mpi4py transformers quimb opt_einsum torch nvidia-cublas cuquantum-python-cu${{ steps.config.outputs.cuda_major }}==25.09
           # The following tests are needed for docs/sphinx/examples/qec/python/tensor_network_decoder.py.
           if [ "$(uname -m)" == "x86_64" ]; then
             # Stim is not currently available on manylinux ARM wheels, so only
```
New file (TRT decoder internal header):

```cpp
/*******************************************************************************
 * Copyright (c) 2024 - 2025 NVIDIA Corporation & Affiliates.                  *
 * All rights reserved.                                                        *
 *                                                                             *
 * This source code and the accompanying materials are made available under    *
 * the terms of the Apache License 2.0 which accompanies this distribution.    *
 ******************************************************************************/

#pragma once

#include "cudaq/qec/decoder.h"
#include <memory>
#include <string>
#include <vector>

#include "NvInfer.h"
#include "NvOnnxParser.h"

namespace cudaq::qec::trt_decoder_internal {

/// @brief Validates TRT decoder parameters
/// @param params The parameter map to validate
/// @throws std::runtime_error if parameters are invalid
void validate_trt_decoder_parameters(const cudaqx::heterogeneous_map &params);

/// @brief Loads a binary file into memory
/// @param filename Path to the file to load
/// @return Vector containing the file contents
/// @throws std::runtime_error if file cannot be opened
std::vector<char> load_file(const std::string &filename);

/// @brief Builds a TensorRT engine from an ONNX model
/// @param onnx_model_path Path to the ONNX model file
/// @param params Configuration parameters
/// @param logger TensorRT logger instance
/// @return Unique pointer to the built TensorRT engine
/// @throws std::runtime_error if engine building fails
std::unique_ptr<nvinfer1::ICudaEngine>
build_engine_from_onnx(const std::string &onnx_model_path,
                       const cudaqx::heterogeneous_map &params,
                       nvinfer1::ILogger &logger);

/// @brief Saves a TensorRT engine to a file
/// @param engine The engine to save
/// @param file_path Path where to save the engine
/// @throws std::runtime_error if saving fails
void save_engine_to_file(nvinfer1::ICudaEngine *engine,
                         const std::string &file_path);

/// @brief Parses and configures precision settings for TensorRT
/// @param precision The precision string (fp16, bf16, int8, fp8, noTF32, best)
/// @param config TensorRT builder config instance
void parse_precision(const std::string &precision,
                     nvinfer1::IBuilderConfig *config);

} // namespace cudaq::qec::trt_decoder_internal
```
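A minimal usage sketch of these internal helpers, assuming a concrete `nvinfer1::ILogger` implementation; the include path, `StderrLogger` name, and file paths below are hypothetical, not part of this PR:

```cpp
#include "cudaq/qec/trt_decoder_internal.h" // hypothetical include path

#include <iostream>

namespace {
// TensorRT requires a concrete ILogger; this minimal one prints
// warnings and errors to stderr and drops info/verbose messages.
class StderrLogger final : public nvinfer1::ILogger {
  void log(Severity severity, const char *msg) noexcept override {
    if (severity <= Severity::kWARNING)
      std::cerr << "[TRT] " << msg << "\n";
  }
};
} // namespace

// Build an engine from an ONNX model and cache it on disk so later runs
// can skip the (slow) builder step. Paths here are illustrative.
void build_and_cache(const cudaqx::heterogeneous_map &params) {
  using namespace cudaq::qec::trt_decoder_internal;

  validate_trt_decoder_parameters(params); // throws on invalid params

  StderrLogger logger;
  auto engine = build_engine_from_onnx("decoder.onnx", params, logger);
  save_engine_to_file(engine.get(), "decoder.trt"); // throws on failure
}
```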