Skip to content

Add support for UAV Counters #80

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/Support/Pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ struct Buffer {
std::unique_ptr<char[]> Data;
size_t Size;
OutputProperties OutputProps;
uint32_t Counter;

uint32_t size() const { return Size; }

Expand Down Expand Up @@ -108,6 +109,7 @@ struct Resource {
std::string Name;
DirectXBinding DXBinding;
Buffer *BufferPtr = nullptr;
bool HasCounter;

bool isRaw() const {
switch (Kind) {
Expand Down
44 changes: 33 additions & 11 deletions lib/API/DX/Device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,19 @@ static DXGI_FORMAT getRawDXFormat(Resource &R) {
return DXGI_FORMAT_UNKNOWN;
}

static uint32_t getUAVBufferSize(Resource &R) {
Copy link

@alsepkow alsepkow May 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uint32_t

nit: Use size_t instead? Same below.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This value is a pseudo size value for Resource R and the type of R.size() is uint32_t and imo it makes the most sense to keep them aligned

return R.HasCounter
? llvm::alignTo(R.size(), D3D12_UAV_COUNTER_PLACEMENT_ALIGNMENT) +
sizeof(uint32_t)
: R.size();
}

static uint32_t getUAVBufferCounterOffset(Resource &R) {
return R.HasCounter
? llvm::alignTo(R.size(), D3D12_UAV_COUNTER_PLACEMENT_ALIGNMENT)
: 0;
}

namespace {

enum DXResourceKind { UAV, SRV, CBV };
Expand Down Expand Up @@ -450,9 +463,10 @@ class DXDevice : public offloadtest::Device {
}

llvm::Expected<ResourceSet> createUAV(Resource &R, InvocationState &IS) {
llvm::outs() << "Creating UAV: { Size = " << R.size() << ", Register = u"
const uint32_t BufferSize = getUAVBufferSize(R);
llvm::outs() << "Creating UAV: { Size = " << BufferSize << ", Register = u"
<< R.DXBinding.Register << ", Space = " << R.DXBinding.Space
<< " }\n";
<< ", HasCounter = " << R.HasCounter << " }\n";
ComPtr<ID3D12Resource> Buffer;
ComPtr<ID3D12Resource> UploadBuffer;
ComPtr<ID3D12Resource> ReadBackBuffer;
Expand All @@ -462,7 +476,7 @@ class DXDevice : public offloadtest::Device {
const D3D12_RESOURCE_DESC ResDesc = {
D3D12_RESOURCE_DIMENSION_BUFFER,
0,
R.size(),
BufferSize,
1,
1,
1,
Expand All @@ -481,7 +495,7 @@ class DXDevice : public offloadtest::Device {
const D3D12_HEAP_PROPERTIES UploadHeapProp =
CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD);
const D3D12_RESOURCE_DESC UploadResDesc =
CD3DX12_RESOURCE_DESC::Buffer(R.size());
CD3DX12_RESOURCE_DESC::Buffer(BufferSize);

if (auto Err =
HR::toError(Device->CreateCommittedResource(
Expand All @@ -496,7 +510,7 @@ class DXDevice : public offloadtest::Device {
const D3D12_RESOURCE_DESC ReadBackResDesc = {
D3D12_RESOURCE_DIMENSION_BUFFER,
0,
R.size(),
BufferSize,
1,
1,
1,
Expand Down Expand Up @@ -530,24 +544,27 @@ class DXDevice : public offloadtest::Device {
ComPtr<ID3D12Resource> Buffer) {
const uint32_t EltSize = R.getElementSize();
const uint32_t NumElts = R.size() / EltSize;
ID3D12Resource *CounterBuffer = R.HasCounter ? Buffer.Get() : nullptr;
const uint32_t CounterOffset = getUAVBufferCounterOffset(R);
DXGI_FORMAT EltFormat =
R.isRaw() ? getRawDXFormat(R)
: getDXFormat(R.BufferPtr->Format, R.BufferPtr->Channels);
const D3D12_UNORDERED_ACCESS_VIEW_DESC UAVDesc = {
EltFormat,
D3D12_UAV_DIMENSION_BUFFER,
{D3D12_BUFFER_UAV{0, NumElts, R.isStructuredBuffer() ? EltSize : 0, 0,
R.isByteAddressBuffer()
? D3D12_BUFFER_UAV_FLAG_RAW
: D3D12_BUFFER_UAV_FLAG_NONE}}};
{D3D12_BUFFER_UAV{
0, NumElts, R.isStructuredBuffer() ? EltSize : 0, CounterOffset,
R.isByteAddressBuffer() ? D3D12_BUFFER_UAV_FLAG_RAW
: D3D12_BUFFER_UAV_FLAG_NONE}}};

llvm::outs() << "UAV: HeapIdx = " << HeapIdx << " EltSize = " << EltSize
<< " NumElts = " << NumElts << "\n";
<< " NumElts = " << NumElts << " HasCounter = " << R.HasCounter
<< "\n";
D3D12_CPU_DESCRIPTOR_HANDLE UAVHandle =
IS.DescHeap->GetCPUDescriptorHandleForHeapStart();
UAVHandle.ptr += HeapIdx * Device->GetDescriptorHandleIncrementSize(
D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
Device->CreateUnorderedAccessView(Buffer.Get(), nullptr, &UAVDesc,
Device->CreateUnorderedAccessView(Buffer.Get(), CounterBuffer, &UAVDesc,
UAVHandle);
}

Expand Down Expand Up @@ -895,6 +912,11 @@ class DXDevice : public offloadtest::Device {
"Failed to map result."))
return Err;
memcpy(R.first->BufferPtr->Data.get(), DataPtr, R.first->size());
if (R.first->HasCounter)
memcpy(&R.first->BufferPtr->Counter,
static_cast<char *>(DataPtr) +
getUAVBufferCounterOffset(*R.first),
sizeof(uint32_t));
R.second.Readback->Unmap(0, nullptr);
return llvm::Error::success();
};
Expand Down
2 changes: 2 additions & 0 deletions lib/Support/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ void MappingTraits<offloadtest::Buffer>::mapping(IO &I,
I.mapRequired("Format", B.Format);
I.mapOptional("Channels", B.Channels, 1);
I.mapOptional("Stride", B.Stride, 0);
I.mapOptional("Counter", B.Counter, 0);
if (!I.outputting() && B.Stride != 0 && B.Channels != 1)
I.setError("Cannot set a structure stride and more than one channel.");
switch (B.Format) {
Expand Down Expand Up @@ -123,6 +124,7 @@ void MappingTraits<offloadtest::Resource>::mapping(IO &I,
offloadtest::Resource &R) {
I.mapRequired("Name", R.Name);
I.mapRequired("Kind", R.Kind);
I.mapOptional("HasCounter", R.HasCounter, 0);
I.mapRequired("DirectXBinding", R.DXBinding);
}

Expand Down
49 changes: 49 additions & 0 deletions test/Feature/StructuredBuffer/dec_counter.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#--- source.hlsl
RWStructuredBuffer<uint> Out : register(u0);

[numthreads(1,1,1)]
void main(uint GI : SV_GroupIndex) {
Out.DecrementCounter();
Out.DecrementCounter();
Out.DecrementCounter();
Out[GI] = Out.DecrementCounter();
}

//--- pipeline.yaml
---
Shaders:
- Stage: Compute
Entry: main
DispatchSize: [1, 1, 1]
Buffers:
- Name: Out
Format: Hex32
Stride: 4
ZeroInitSize: 4
DescriptorSets:
- Resources:
- Name: Out
Kind: RWStructuredBuffer
HasCounter: true
DirectXBinding:
Register: 0
Space: 0
...
#--- end

# UNSUPPORTED: Vulkan
# UNSUPPORTED: Metal
# UNSUPPORTED: Clang

# RUN: split-file %s %t
# RUN: %dxc_target -T cs_6_0 -Fo %t.o %t/source.hlsl
# RUN: %offloader %t/pipeline.yaml %t.o | FileCheck %s

# CHECK: Creating UAV: { Size = 4100, Register = u0, Space = 0, HasCounter = 1 }
# CHECK: UAV: HeapIdx = 0 EltSize = 4 NumElts = 1 HasCounter = 1

# CHECK: Name: Out
# CHECK: Counter: 4294967292
# CHECK: Data: [
# CHECK: 0xFFFFFFFC
# CHECK: ]
49 changes: 49 additions & 0 deletions test/Feature/StructuredBuffer/inc_counter.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#--- source.hlsl
RWStructuredBuffer<int> Out : register(u0);

[numthreads(1,1,1)]
void main(uint GI : SV_GroupIndex) {
Out.IncrementCounter();
Out.IncrementCounter();
Out.IncrementCounter();
Out[GI] = Out.IncrementCounter();
}

//--- pipeline.yaml
---
Shaders:
- Stage: Compute
Entry: main
DispatchSize: [1, 1, 1]
Buffers:
- Name: Out
Format: Hex32
Stride: 4
ZeroInitSize: 4
DescriptorSets:
- Resources:
- Name: Out
Kind: RWStructuredBuffer
HasCounter: true
DirectXBinding:
Register: 0
Space: 0
...
#--- end

# UNSUPPORTED: Vulkan
# UNSUPPORTED: Metal
# UNSUPPORTED: Clang

# RUN: split-file %s %t
# RUN: %dxc_target -T cs_6_0 -Fo %t.o %t/source.hlsl
# RUN: %offloader %t/pipeline.yaml %t.o | FileCheck %s

# CHECK: Creating UAV: { Size = 4100, Register = u0, Space = 0, HasCounter = 1 }
# CHECK: UAV: HeapIdx = 0 EltSize = 4 NumElts = 1 HasCounter = 1

# CHECK: Name: Out
# CHECK: Counter: 4
# CHECK: Data: [
# CHECK: 0x3
# CHECK: ]
Loading