Skip to content

Commit 198bc35

Browse files
[InsertGPUAllocs] Use gpu.memcpy for opencl instead of memref.copy
1 parent 5c647e7 commit 198bc35

File tree

4 files changed

+15
-10
lines changed

4 files changed

+15
-10
lines changed

lib/Transforms/InsertGPUAllocs.cpp

+10-5
Original file line number | Diff line number | Diff line change
@@ -360,8 +360,10 @@ class InsertGPUAllocsPass final
360360
auto newAlloc = builder.create<mlir::memref::AllocOp>(
361361
loc, alloc.getType(), alloc.getDynamicSizes(),
362362
alloc.getSymbolOperands());
363-
builder.create<mlir::memref::CopyOp>(loc, allocResult,
364-
newAlloc.getResult());
363+
builder.create<mlir::gpu::MemcpyOp>(
364+
loc, /*asyncToken*/ static_cast<mlir::Type>(nullptr),
365+
/*asyncDependencies*/ std::nullopt, newAlloc.getResult(),
366+
allocResult);
365367
use.set(newAlloc.getResult());
366368
}
367369
}
@@ -401,8 +403,9 @@ class InsertGPUAllocsPass final
401403
/*symbolOperands*/ std::nullopt, hostShared);
402404
auto allocResult = gpuAlloc.getResult(0);
403405
if (access.hostWrite && access.deviceRead) {
404-
auto copy =
405-
builder.create<mlir::memref::CopyOp>(loc, op, allocResult);
406+
auto copy = builder.create<mlir::gpu::MemcpyOp>(
407+
loc, /*asyncToken*/ static_cast<mlir::Type>(nullptr),
408+
/*asyncDependencies*/ std::nullopt, allocResult, op);
406409
filter.insert(copy);
407410
}
408411

@@ -421,7 +424,9 @@ class InsertGPUAllocsPass final
421424
op.replaceAllUsesExcept(allocResult, filter);
422425
builder.setInsertionPoint(term);
423426
if (access.hostRead && access.deviceWrite) {
424-
builder.create<mlir::memref::CopyOp>(loc, allocResult, op);
427+
builder.create<mlir::gpu::MemcpyOp>(
428+
loc, /*asyncToken*/ static_cast<mlir::Type>(nullptr),
429+
/*asyncDependencies*/ std::nullopt, op, allocResult);
425430
}
426431
builder.create<mlir::gpu::DeallocOp>(loc, std::nullopt, allocResult);
427432
}

test/Transforms/InsertGpuAllocs/add-gpu-alloc.mlir

+2-2
Original file line number | Diff line number | Diff line change
@@ -7,9 +7,9 @@ func.func @addt(%arg0: memref<2x5xf32>, %arg1: memref<2x5xf32>) -> memref<2x5xf3
77
%c1 = arith.constant 1 : index
88
%c5 = arith.constant 5 : index
99
// OPENCL: %[[MEMREF0:.*]] = gpu.alloc host_shared () : memref<2x5xf32>
10-
// OPENCL: memref.copy %arg1, %[[MEMREF0]] : memref<2x5xf32> to memref<2x5xf32>
10+
// OPENCL: gpu.memcpy %[[MEMREF0]], %arg1 : memref<2x5xf32>, memref<2x5xf32>
1111
// OPENCL: %[[MEMREF1:.*]] = gpu.alloc host_shared () : memref<2x5xf32>
12-
// OPENCL: memref.copy %arg0, %[[MEMREF1]] : memref<2x5xf32> to memref<2x5xf32>
12+
// OPENCL: gpu.memcpy %[[MEMREF1]], %arg0 : memref<2x5xf32>, memref<2x5xf32>
1313
// VULKAN: %[[MEMREF0:.*]] = memref.alloc() : memref<2x5xf32>
1414
// VULKAN: memref.copy %arg1, %[[MEMREF0]] : memref<2x5xf32> to memref<2x5xf32>
1515
// VULKAN: %[[MEMREF1:.*]] = memref.alloc() : memref<2x5xf32>

test/Transforms/InsertGpuAllocs/memref-get-global.mlir

+2-2
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,10 @@ func.func @addt(%arg0: memref<2x5xf32>, %arg1: memref<2x5xf32>) -> memref<2x5xf3
1717

1818
// OPENCL: [[VAR0:%.*]] = memref.get_global @__constant_2x5xf32 : memref<2x5xf32>
1919
// OPENCL: %[[MEMREF0:.*]] = gpu.alloc host_shared () : memref<2x5xf32>
20-
// OPENCL: memref.copy [[VAR0]], %[[MEMREF0]] : memref<2x5xf32> to memref<2x5xf32>
20+
// OPENCL: gpu.memcpy %[[MEMREF0]], [[VAR0]] : memref<2x5xf32>, memref<2x5xf32>
2121
// OPENCL: [[VAR1:%.*]] = memref.get_global @__constant_2x5xf32_0 : memref<2x5xf32>
2222
// OPENCL: %[[MEMREF1:.*]] = gpu.alloc host_shared () : memref<2x5xf32>
23-
// OPENCL: memref.copy [[VAR1]], %[[MEMREF1]] : memref<2x5xf32> to memref<2x5xf32>
23+
// OPENCL: gpu.memcpy %[[MEMREF1]], [[VAR1]] : memref<2x5xf32>, memref<2x5xf32>
2424
// OPENCL: %[[MEMREF2:.*]] = gpu.alloc host_shared () : memref<2x5xf32>
2525
// VULKAN: [[VAR0:%.*]] = memref.get_global @__constant_2x5xf32 : memref<2x5xf32>
2626
// VULKAN: %[[MEMREF0:.*]] = memref.alloc() : memref<2x5xf32>

test/Transforms/InsertGpuAllocs/memref-returned-from-call.mlir

+1-1
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@ func.func @main() {
1212
// OPENCL: func.func @main()
1313
%0 = func.call @alloc_buffer() : () -> memref<8xf32>
1414
// OPENCL: %[[MEMREF:.*]] = gpu.alloc host_shared () : memref<8xf32>
15-
// OPENCL: memref.copy %0, %[[MEMREF]] : memref<8xf32> to memref<8xf32>
15+
// OPENCL: gpu.memcpy %[[MEMREF]], %0 : memref<8xf32>, memref<8xf32>
1616
%1 = memref.alloc() : memref<8xf32>
1717
%2 = memref.alloc() : memref<8xf32>
1818
gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %c8, %arg7 = %c1, %arg8 = %c1) threads(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c1, %arg11 = %c1) {

0 commit comments

Comments (0)