Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions libkineto/src/CuptiActivity.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,11 +439,38 @@ inline std::string getGraphNodeMetadata(const T& activity) {
#endif
}

// Convert limitingFactors bitmask to human-readable string
// Based on cudaOccLimitingFactor enum from cuda_occupancy.h
// This can be found in the CUDA toolkit typically /usr/local/cuda/targets/x86_64-linux/include/cuda_occupancy.h
// Render a limitingFactors bitmask (cudaOccLimitingFactor values from
// cuda_occupancy.h, shipped with the CUDA toolkit, e.g.
// /usr/local/cuda/targets/x86_64-linux/include/cuda_occupancy.h) as a
// pipe-separated human-readable string such as "WARPS|SMEM".
// Returns "none" for an empty mask.
inline std::string limitingFactorsToString(unsigned int factors) {
  if (factors == 0) {
    return "none";
  }
  constexpr std::pair<unsigned int, const char*> kFactors[] = {
      {OCC_LIMIT_WARPS, "WARPS"},
      {OCC_LIMIT_REGISTERS, "REGS"},
      {OCC_LIMIT_SHARED_MEMORY, "SMEM"},
      {OCC_LIMIT_BLOCKS, "BLOCKS"},
      {OCC_LIMIT_BARRIERS, "BARRIERS"},
  };
  std::string out;
  const char* sep = "";
  for (const auto& entry : kFactors) {
    if ((factors & entry.first) != 0) {
      out += sep;
      out += entry.second;
      sep = "|";  // separator precedes every factor after the first
    }
  }
  return out;
}

template <>
inline const std::string GpuActivity<CUpti_ActivityKernelType>::metadataJson() const {
const CUpti_ActivityKernelType& kernel = raw();
float blocksPerSmVal = blocksPerSm(kernel);
float warpsPerSmVal = warpsPerSm(kernel);
OccupancyMetrics occMetrics = computeOccupancyMetrics(kernel);

// clang-format off

Expand All @@ -456,7 +483,18 @@ inline const std::string GpuActivity<CUpti_ActivityKernelType>::metadataJson() c
"warps per SM": {},
"grid": [{}, {}, {}],
"block": [{}, {}, {}],
"est. achieved occupancy %": {}{})JSON",
"est. achieved occupancy %": {},
"occupancy": {{
"activeBlocksPerMultiprocessor": {},
"limitingFactors": "{}",
"blockLimitRegs": {},
"blockLimitSharedMem": {},
"blockLimitWarps": {},
"blockLimitBlocks": {},
"blockLimitBarriers": {},
"allocatedRegistersPerBlock": {},
"allocatedSharedMemPerBlock": {}
}}{})JSON",
kernel.queued, kernel.deviceId, kernel.contextId,
kernel.streamId, kernel.correlationId,
kernel.registersPerThread,
Expand All @@ -465,7 +503,16 @@ inline const std::string GpuActivity<CUpti_ActivityKernelType>::metadataJson() c
std::isinf(warpsPerSmVal) ? "\"inf\"" : std::to_string(warpsPerSmVal),
kernel.gridX, kernel.gridY, kernel.gridZ,
kernel.blockX, kernel.blockY, kernel.blockZ,
(int) (0.5 + (kernelOccupancy(kernel) * 100.0)),
static_cast<int>(std::lround(occMetrics.occupancy * 100.0)),
occMetrics.result.activeBlocksPerMultiprocessor,
limitingFactorsToString(occMetrics.result.limitingFactors),
occMetrics.result.blockLimitRegs,
occMetrics.result.blockLimitSharedMem,
occMetrics.result.blockLimitWarps,
occMetrics.result.blockLimitBlocks,
occMetrics.result.blockLimitBarriers,
occMetrics.result.allocatedRegistersPerBlock,
occMetrics.result.allocatedSharedMemPerBlock,
getGraphNodeMetadata(kernel)
);
// clang-format on
Expand Down
105 changes: 47 additions & 58 deletions libkineto/src/DeviceProperties.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,12 @@ int smCount([[maybe_unused]] uint32_t deviceId) {

#ifdef HAS_CUPTI
// Average number of thread blocks per SM for this kernel launch
// (total grid blocks divided by the device's SM count).
// Returns +inf when the SM count is unknown/zero to avoid division by zero;
// callers (e.g. metadataJson) render the inf case specially.
float blocksPerSm(const CUpti_ActivityKernelType& kernel) {
  int sm_count = smCount(kernel.deviceId);
  if (sm_count == 0) {
    return std::numeric_limits<float>::infinity();
  }
  return (kernel.gridX * kernel.gridY * kernel.gridZ) /
      static_cast<float>(sm_count);
}

float warpsPerSm(const CUpti_ActivityKernelType& kernel) {
Expand All @@ -154,67 +158,52 @@ float warpsPerSm(const CUpti_ActivityKernelType& kernel) {
threads_per_warp;
}

float kernelOccupancy(const CUpti_ActivityKernelType& kernel) {
float blocks_per_sm = -1.0;
OccupancyMetrics computeOccupancyMetrics(
const CUpti_ActivityKernelType& kernel) {
OccupancyMetrics metrics;
const std::vector<cudaDeviceProp>& props = deviceProps();
if (kernel.deviceId >= props.size()) {
LOG(ERROR) << "Invalid deviceId " << kernel.deviceId
<< " exceeds available devices (" << props.size()
<< "), skipping occupancy calculation";
return metrics;
}

float blocksPerSm = -1.0;
int sm_count = smCount(kernel.deviceId);
if (sm_count) {
blocks_per_sm =
(kernel.gridX * kernel.gridY * kernel.gridZ) / (float)sm_count;
if (sm_count != 0) {
blocksPerSm = (kernel.gridX * kernel.gridY * kernel.gridZ) /
static_cast<float>(sm_count);
}
return kernelOccupancy(
kernel.deviceId,
kernel.registersPerThread,
kernel.staticSharedMemory,
kernel.dynamicSharedMemory,
kernel.blockX,
kernel.blockY,
kernel.blockZ,
blocks_per_sm);
}

float kernelOccupancy(
uint32_t deviceId,
uint16_t registersPerThread,
int32_t staticSharedMemory,
int32_t dynamicSharedMemory,
int32_t blockX,
int32_t blockY,
int32_t blockZ,
float blocksPerSm) {
// Calculate occupancy
float occupancy = -1.0;
const std::vector<cudaDeviceProp>& props = deviceProps();
if (deviceId < props.size()) {
cudaOccFuncAttributes occFuncAttr;
occFuncAttr.maxThreadsPerBlock = INT_MAX;
occFuncAttr.numRegs = registersPerThread;
occFuncAttr.sharedSizeBytes = staticSharedMemory;
occFuncAttr.partitionedGCConfig = PARTITIONED_GC_OFF;
occFuncAttr.shmemLimitConfig = FUNC_SHMEM_LIMIT_DEFAULT;
occFuncAttr.maxDynamicSharedSizeBytes = 0;
const cudaOccDeviceState occDeviceState = {};
int blockSize = blockX * blockY * blockZ;
size_t dynamicSmemSize = dynamicSharedMemory;
cudaOccResult occ_result;
cudaOccDeviceProp prop(props[deviceId]);
cudaOccError status = cudaOccMaxActiveBlocksPerMultiprocessor(
&occ_result,
&prop,
&occFuncAttr,
&occDeviceState,
blockSize,
dynamicSmemSize);
if (status == CUDA_OCC_SUCCESS) {
blocksPerSm = std::min<float>(
occ_result.activeBlocksPerMultiprocessor, blocksPerSm);
occupancy = blocksPerSm * blockSize /
(float)props[deviceId].maxThreadsPerMultiProcessor;
} else {
LOG_EVERY_N(ERROR, 1000)
<< "Failed to calculate occupancy, status = " << status;
}
cudaOccFuncAttributes occFuncAttr;
occFuncAttr.maxThreadsPerBlock = INT_MAX;
occFuncAttr.numRegs = kernel.registersPerThread;
occFuncAttr.sharedSizeBytes = kernel.staticSharedMemory;
occFuncAttr.partitionedGCConfig = PARTITIONED_GC_OFF;
occFuncAttr.shmemLimitConfig = FUNC_SHMEM_LIMIT_DEFAULT;
occFuncAttr.maxDynamicSharedSizeBytes = 0;
const cudaOccDeviceState occDeviceState = {};
int blockSize = kernel.blockX * kernel.blockY * kernel.blockZ;
size_t dynamicSmemSize = kernel.dynamicSharedMemory;
cudaOccDeviceProp prop(props[kernel.deviceId]);
cudaOccError status = cudaOccMaxActiveBlocksPerMultiprocessor(
&metrics.result,
&prop,
&occFuncAttr,
&occDeviceState,
blockSize,
dynamicSmemSize);
if (status == CUDA_OCC_SUCCESS) {
float effectiveBlocksPerSm = std::min<float>(
metrics.result.activeBlocksPerMultiprocessor, blocksPerSm);
metrics.occupancy = effectiveBlocksPerSm * blockSize /
static_cast<float>(props[kernel.deviceId].maxThreadsPerMultiProcessor);
} else {
LOG_EVERY_N(ERROR, 1000)
<< "Failed to calculate occupancy, status = " << status;
}
return occupancy;
return metrics;
}
#endif // HAS_CUPTI

Expand Down
20 changes: 10 additions & 10 deletions libkineto/src/DeviceProperties.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <string>

#ifdef HAS_CUPTI
#include <cuda_occupancy.h>
#include <cupti.h>
#endif

Expand Down Expand Up @@ -41,16 +42,15 @@ using CUpti_ActivityMemsetType = CUpti_ActivityMemset;
float blocksPerSm(const CUpti_ActivityKernelType& kernel);
float warpsPerSm(const CUpti_ActivityKernelType& kernel);

// Return estimated achieved occupancy for a kernel
float kernelOccupancy(const CUpti_ActivityKernelType& kernel);
float kernelOccupancy(uint32_t deviceId,
uint16_t registersPerThread,
int32_t staticSharedMemory,
int32_t dynamicSharedMemory,
int32_t blockX,
int32_t blockY,
int32_t blockZ,
float blocks_per_sm);
// Occupancy results from the CUDA occupancy calculator.
// Bundles the raw cudaOccResult from cuda_occupancy.h with a derived
// occupancy estimate computed by computeOccupancyMetrics().
struct OccupancyMetrics {
float occupancy = -1.0f; // Estimated achieved occupancy as a fraction of maxThreadsPerMultiProcessor; -1.0 if it could not be computed
cudaOccResult result = {}; // Raw results from cudaOccMaxActiveBlocksPerMultiprocessor (zeroed on failure)
};

// Compute detailed occupancy metrics (active blocks/SM, limiting factors,
// per-resource block limits) for a CUPTI kernel activity record.
OccupancyMetrics computeOccupancyMetrics(const CUpti_ActivityKernelType& kernel);
#endif

} // namespace KINETO_NAMESPACE
53 changes: 53 additions & 0 deletions libkineto/test/DevicePropertiesTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <gtest/gtest.h>

#include "src/DeviceProperties.h"

using namespace KINETO_NAMESPACE;

class OccupancyMetricsTest : public ::testing::Test {};

#ifdef HAS_CUPTI

// Verify all cudaOccResult fields are mapped to OccupancyMetrics
TEST_F(OccupancyMetricsTest, AllFieldsPopulated) {
  if (smCount(0) == 0) {
    GTEST_SKIP() << "No GPU available";
  }

  // Describe a simple launch on device 0: 100 blocks of 256 threads,
  // 32 registers per thread, no shared memory.
  CUpti_ActivityKernelType kernel = {};
  kernel.deviceId = 0;
  kernel.registersPerThread = 32;
  kernel.staticSharedMemory = 0;
  kernel.dynamicSharedMemory = 0;
  kernel.gridX = 100;
  kernel.gridY = 1;
  kernel.gridZ = 1;
  kernel.blockX = 256;
  kernel.blockY = 1;
  kernel.blockZ = 1;

  const OccupancyMetrics occ = computeOccupancyMetrics(kernel);

  // All fields from cudaOccResult should be populated (non-default)
  EXPECT_NE(occ.occupancy, -1.0f);
  EXPECT_NE(occ.result.activeBlocksPerMultiprocessor, 0);
  // limitingFactors can legitimately be 0 if nothing is limiting
  EXPECT_NE(occ.result.blockLimitRegs, 0);
  EXPECT_NE(occ.result.blockLimitSharedMem, 0);
  EXPECT_NE(occ.result.blockLimitWarps, 0);
  EXPECT_NE(occ.result.blockLimitBlocks, 0);
  // blockLimitBarriers can be 0 if no barriers used
  EXPECT_NE(occ.result.allocatedRegistersPerBlock, 0);
  // allocatedSharedMemPerBlock can be 0 if no shared mem used
  EXPECT_EQ(occ.result.partitionedGCConfig, PARTITIONED_GC_OFF);
}

#endif // HAS_CUPTI
Loading