[ptx] Support for CUDA JIT compiler flags #713
base: develop
… flags can be passed to the CUDA JIT compiler. Implement the JNI binding for cuModuleLoad so that TornadoVM prebuilt can read .cubin files. Currently both the compiler flags and the path of the .cubin are hardcoded; they should be passed from the API.
@yrq0208 What does it need in order to accept compiler flags, as in the OpenCL backend?
…ornadoVM into PTX_cuModuleLoadDataEx
…to process the CUDA JIT flags passed from TornadoOption.java to the relevant CUDA function
Pull Request Overview
This PR implements support for passing CUDA JIT compiler flags through the cuModuleLoadDataEx() method, enabling explicit control over PTX compilation optimization levels and other performance-related flags. Previously, the implementation only used cuModuleLoadData() which didn't allow passing compiler flags.
Key Changes:
- Added new JNI method `cuModuleLoadDataEx` to handle CUDA JIT compiler flags
- Updated default PTX compiler flags to include `CU_JIT_OPTIMIZATION_LEVEL 4` for optimal performance
- Modified PTX module loading pipeline to pass `TaskDataContext` metadata through the compilation chain
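As a rough illustration of the calling convention behind the first bullet: cuModuleLoadDataEx() takes the JIT options as two parallel arrays, one of option identifiers and one of `void *` option values, where scalar values are carried in the pointer itself. The sketch below is not the code from this PR. `JitOption`, `PackedOptions`, and `pack_options` are hypothetical names, and the enum is a stand-in for CUDA's `CUjit_option` so that the example compiles without `cuda.h`.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for CUDA's CUjit_option enum. The numeric values
// here are NOT the values defined in cuda.h.
enum class JitOption : int { OptimizationLevel, CacheMode, MaxRegisters };

// Two parallel arrays, in the shape cuModuleLoadDataEx() expects.
struct PackedOptions {
    std::vector<JitOption> options; // would be CUjit_option[] in real code
    std::vector<void *> values;     // would be the void *optionValues[] array
};

// Packs (option, value) pairs into the two parallel arrays. Each scalar
// value is stored in the pointer itself, not behind it.
PackedOptions pack_options(const std::vector<std::pair<JitOption, unsigned int>> &pairs) {
    PackedOptions packed;
    for (const auto &p : pairs) {
        packed.options.push_back(p.first);
        packed.values.push_back(
            reinterpret_cast<void *>(static_cast<std::uintptr_t>(p.second)));
    }
    assert(packed.options.size() == packed.values.size());
    return packed;
}
```

In real code the two arrays and their common length would be handed to cuModuleLoadDataEx() together with the module handle and the PTX image.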
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/common/TornadoOptions.java | Updated default PTX compiler flags and added documentation on flag format |
| tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/tests/TestPTXTornadoCompiler.java | Updated installSource call to pass metadata parameter |
| tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/tests/TestPTXJITCompiler.java | Updated installCode call to pass task metadata |
| tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/runtime/PTXTornadoDevice.java | Modified compilation methods to pass task metadata for compiler flags |
| tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXModule.java | Added new constructor parameter for compiler flags and native method declaration |
| tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXDeviceContext.java | Updated installCode methods to accept and forward task metadata |
| tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/PTXCodeCache.java | Modified to extract compiler flags from task metadata and pass to PTXModule |
| tornado-drivers/ptx-jni/src/main/cpp/source/PTXModule.h | Added JNI header declaration for cuModuleLoadDataEx |
| tornado-drivers/ptx-jni/src/main/cpp/source/PTXModule.cpp | Implemented cuModuleLoadDataEx with CUDA JIT flag parsing and application |
char ptx[ptx_length + 1];
#endif
env->GetByteArrayRegion(source, 0, ptx_length, reinterpret_cast<jbyte *>(ptx));
ptx[ptx_length] = 0; // Make sure string terminates with a 0
Is this necessary? Is it a CUDA function requirement?
As far as I can see, it is not used again after line 95. Are you trying to reset it?
It looks like this is necessary: if it is removed, cuModuleLoadDataEx() fails with error 218, which means invalid PTX code input.
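For context on the exchange above: cuModuleLoadDataEx() receives the PTX image as a bare pointer with no length argument, so PTX text has to be passed as a NUL-terminated string. The following is a minimal sketch of that copy-and-terminate step, not the code from this PR. It uses `std::vector<char>` rather than the variable-length array in the snippet, and `make_terminated_ptx` is a hypothetical helper name; `src` stands in for the bytes that the JNI code reads with GetByteArrayRegion.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Copies `length` bytes of PTX text into a buffer that is one byte longer
// and ends with '\0', so the result can be passed where a C string is expected.
std::vector<char> make_terminated_ptx(const char *src, std::size_t length) {
    assert(src != nullptr || length == 0);
    std::vector<char> ptx(length + 1); // value-initialized, so all bytes start as 0
    if (length > 0) {
        std::memcpy(ptx.data(), src, length);
    }
    ptx[length] = '\0'; // explicit terminator, mirroring ptx[ptx_length] = 0
    return ptx;
}
```

A Java `byte[]` carries no terminator of its own, which is why the extra byte has to be added on the native side.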
Description
This patch enables the explicit passing of CUDA JIT compiler flags by using the cuModuleLoadDataEx() method, with the flags configured through TornadoOptions.java.
Problem description
The current Java_uk_ac_manchester_tornado_drivers_ptx_PTXModule_cuModuleLoadData implementation does not allow the explicit passing of CUDA JIT compiler flags.
Backend/s tested
Mark the backends affected by this PR.
OS tested
Mark the OS where this PR is tested.
Did you check on FPGAs?
This patch is not applicable to FPGAs
How to test the new patch?
The CUDA JIT flags are passed via the TornadoVM CLI as a single string. Please refer to the documentation for the list of currently supported CUDA JIT flags; feel free to try other flags. By default, the TestCompilerFlagsAPI unit test for PTX uses optimization level 0. When I pass optimization level 4 via the CLI, I observe a speedup of around 3.8x in TASK_KERNEL_TIME on my machine, which suggests that the flags are passed successfully and the optimization level is overridden from 0 to 4:
tornado-test -V --jvm="-Ds0.t0.device=0:0 -Dtornado.ptx.compiler.flags=CU_JIT_OPTIMIZATION_LEVEL\ 4\ CU_JIT_CACHE_MODE\ 0" --enableProfiler console uk.ac.manchester.tornado.unittests.compiler.TestCompilerFlagsAPI#testPTX
Here is another example command with all the flags. Set your own CU_JIT_TARGET accordingly: older versions of CUDA might not support GPUs with compute capability 12.0.
tornado-test -V --jvm="-Ds0.t0.device=0:0 -Dtornado.ptx.compiler.flags=CU_JIT_OPTIMIZATION_LEVEL\ 4\ CU_JIT_CACHE_MODE\ 0\ CU_JIT_MAX_REGISTERS\ 255\ CU_JIT_TARGET\ 120\ CU_JIT_GENERATE_DEBUG_INFO\ 0\ CU_JIT_LOG_VERBOSE\ 0\ CU_JIT_GENERATE_LINE_INFO\ 0" --enableProfiler console uk.ac.manchester.tornado.unittests.compiler.TestCompilerFlagsAPI#testPTX
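The flag strings in the commands above alternate option names and integer values, separated by spaces. A minimal sketch of how such a string could be split into (name, value) pairs is shown below. It is an illustration only and not the parser implemented in PTXModule.cpp; `parse_jit_flags` is a hypothetical name, and the sketch assumes every value is a non-negative integer.

```cpp
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Splits "NAME VALUE NAME VALUE ..." into (name, value) pairs.
// Throws std::invalid_argument if a name has no value or if a value is
// not a non-negative integer that fits in unsigned long.
std::vector<std::pair<std::string, unsigned int>> parse_jit_flags(const std::string &flags) {
    std::vector<std::pair<std::string, unsigned int>> result;
    std::istringstream in(flags);
    std::string name, value;
    while (in >> name) {
        if (!(in >> value)) {
            throw std::invalid_argument("missing value for " + name);
        }
        if (value.find_first_not_of("0123456789") != std::string::npos) {
            throw std::invalid_argument("non-numeric value for " + name);
        }
        unsigned long parsed = 0;
        try {
            parsed = std::stoul(value);
        } catch (const std::out_of_range &) {
            throw std::invalid_argument("value out of range for " + name);
        }
        result.emplace_back(name, static_cast<unsigned int>(parsed));
    }
    return result;
}
```

For example, `parse_jit_flags("CU_JIT_OPTIMIZATION_LEVEL 4 CU_JIT_CACHE_MODE 0")` yields two pairs, and a trailing name without a value is rejected instead of being silently dropped.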