Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
import hat.codebuilders.CodeBuilder;
import hat.codebuilders.ScopedCodeBuilderContext;
import hat.dialect.HATF16ConvOp;
import hat.dialect.HATVectorSelectLoadOp;
import hat.dialect.HATVectorSelectStoreOp;
import hat.dialect.HATVectorBinaryOp;
import hat.dialect.HATVectorLoadOp;
import hat.dialect.HATVectorOfOp;
import hat.dialect.HATVectorSelectLoadOp;
import hat.dialect.HATVectorSelectStoreOp;
import hat.dialect.HATVectorStoreView;
import hat.dialect.HATVectorVarOp;
import jdk.incubator.code.Op;
Expand Down Expand Up @@ -250,4 +251,10 @@ public CudaHATKernelBuilder hatVectorVarOp(ScopedCodeBuilderContext buildContext
}
return self();
}

/**
 * Emits the opening of a vector construction for the CUDA backend: the
 * identifier {@code make_<buildType>} followed by an opening parenthesis.
 * The operands and the closing parenthesis are not emitted here.
 *
 * @param builderContext the current code-builder scope (not read by this method)
 * @param hatVectorOfOp  the vector-of op whose build type names the constructor function
 * @return this builder
 */
@Override
public CudaHATKernelBuilder genVectorIdentifier(ScopedCodeBuilderContext builderContext, HATVectorOfOp hatVectorOfOp) {
    final String constructorName = "make_" + hatVectorOfOp.buildType();
    identifier(constructorName).oparen();
    return self();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,18 @@
import hat.codebuilders.CodeBuilder;
import hat.codebuilders.ScopedCodeBuilderContext;
import hat.dialect.HATF16ConvOp;
import hat.dialect.HATVectorSelectLoadOp;
import hat.dialect.HATVectorSelectStoreOp;
import hat.dialect.HATVectorBinaryOp;
import hat.dialect.HATVectorLoadOp;
import hat.dialect.HATVectorOfOp;
import hat.dialect.HATVectorSelectLoadOp;
import hat.dialect.HATVectorSelectStoreOp;
import hat.dialect.HATVectorStoreView;
import hat.dialect.HATVectorVarOp;
import jdk.incubator.code.Op;
import jdk.incubator.code.Value;

import java.util.List;

public class OpenCLHATKernelBuilder extends C99HATKernelBuilder<OpenCLHATKernelBuilder> {

@Override
Expand Down Expand Up @@ -83,7 +86,7 @@ public OpenCLHATKernelBuilder hatVectorStoreOp(ScopedCodeBuilderContext buildCon
Value dest = hatVectorStoreView.operands().get(0);
Value index = hatVectorStoreView.operands().get(2);

identifier("vstore" + hatVectorStoreView.storeN())
identifier("vstore" + hatVectorStoreView.vectorN())
.oparen()
.varName(hatVectorStoreView)
.comma()
Expand Down Expand Up @@ -131,7 +134,7 @@ public OpenCLHATKernelBuilder hatVectorLoadOp(ScopedCodeBuilderContext buildCont
Value source = hatVectorLoadOp.operands().get(0);
Value index = hatVectorLoadOp.operands().get(1);

identifier("vload" + hatVectorLoadOp.loadN())
identifier("vload" + hatVectorLoadOp.vectorN())
.oparen()
.intConstZero()
.comma()
Expand Down Expand Up @@ -180,13 +183,11 @@ public OpenCLHATKernelBuilder hatSelectStoreOp(ScopedCodeBuilderContext buildCon

// Emits a parenthesised half type in front of the converted value, then generates
// code for the op that produced the first operand (when that operand is an op result).
@Override
public OpenCLHATKernelBuilder hatF16ConvOp(ScopedCodeBuilderContext buildContext, HATF16ConvOp hatF16ConvOp) {
// NOTE(review): this typeName("half") line and the halfType() line below both emit a
// parenthesised half type; they look like the old/new variants of the same statement.
// Confirm that only one of them is intended to remain.
oparen().typeName("half").cparen();
// typeName("convert_half").oparen();
oparen().halfType().cparen();
// Only an operand that is the result of an op has something to recurse into;
// any other kind of value emits nothing here.
Value initValue = hatF16ConvOp.operands().getFirst();
if (initValue instanceof Op.Result r) {
recurse(buildContext, r.op());
}
//cparen();
return self();
}

Expand All @@ -204,4 +205,9 @@ public OpenCLHATKernelBuilder hatVectorVarOp(ScopedCodeBuilderContext buildConte
return self();
}

/**
 * Emits the opening of a vector construction for the OpenCL backend: the
 * build type wrapped in parentheses, followed by an opening parenthesis,
 * i.e. {@code (<buildType>)(}. The operands and the closing parenthesis
 * are not emitted here.
 *
 * @param builderContext the current code-builder scope (not read by this method)
 * @param hatVectorOfOp  the vector-of op that supplies the build type
 * @return this builder
 */
@Override
public OpenCLHATKernelBuilder genVectorIdentifier(ScopedCodeBuilderContext builderContext, HATVectorOfOp hatVectorOfOp) {
    var vectorType = hatVectorOfOp.buildType();
    oparen().identifier(vectorType).cparen().oparen();
    return self();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ bool OpenCLBackend::getBufferFromDeviceIfDirty(void *memorySegment, long memoryS

OpenCLBackend::OpenCLBackend(int configBits)
: Backend(new Config(configBits), new OpenCLQueue(this)) {

if (config->info) {
std::cerr << "[INFO] Config Bits = " << std::hex << configBits << std::dec << std::endl;
}

cl_int status;
cl_uint platformc = 0;
OPENCL_CHECK(clGetPlatformIDs(0, nullptr, &platformc), "clGetPlatformIDs");
Expand Down Expand Up @@ -303,7 +308,6 @@ const char *OpenCLBackend::errorMsg(cl_int status) {
}

// C-linkage entry point: allocates a new OpenCLBackend configured with
// `configBits` and returns its address converted to a long.
extern "C" long getBackend(int configBits) {
// Print the configuration bits in hexadecimal, then switch std::cerr back to decimal.
std::cerr << "Opencl Driver =" << std::hex << configBits << std::dec << std::endl;
// The heap-allocated backend is not freed here; the pointer value is the return value.
return reinterpret_cast<long>(new OpenCLBackend(configBits));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,18 @@
import hat.codebuilders.CodeBuilder;
import hat.codebuilders.ScopedCodeBuilderContext;
import hat.dialect.HATF16ConvOp;
import hat.dialect.HATVectorSelectLoadOp;
import hat.dialect.HATVectorSelectStoreOp;
import hat.dialect.HATVectorBinaryOp;
import hat.dialect.HATVectorLoadOp;
import hat.dialect.HATVectorOfOp;
import hat.dialect.HATVectorSelectLoadOp;
import hat.dialect.HATVectorSelectStoreOp;
import hat.dialect.HATVectorStoreView;
import hat.dialect.HATVectorVarOp;
import jdk.incubator.code.Op;
import jdk.incubator.code.Value;

import java.util.List;

public class OpenCLHatKernelBuilder extends C99HATKernelBuilder<OpenCLHatKernelBuilder> {

@Override
Expand Down Expand Up @@ -83,7 +86,7 @@ public OpenCLHatKernelBuilder hatVectorStoreOp(ScopedCodeBuilderContext buildCon
Value dest = hatVectorStoreView.operands().get(0);
Value index = hatVectorStoreView.operands().get(2);

identifier("vstore" + hatVectorStoreView.storeN())
identifier("vstore" + hatVectorStoreView.vectorN())
.oparen()
.varName(hatVectorStoreView)
.comma()
Expand Down Expand Up @@ -131,7 +134,7 @@ public OpenCLHatKernelBuilder hatVectorLoadOp(ScopedCodeBuilderContext buildCont
Value source = hatVectorLoadOp.operands().get(0);
Value index = hatVectorLoadOp.operands().get(1);

identifier("vload" + hatVectorLoadOp.loadN())
identifier("vload" + hatVectorLoadOp.vectorN())
.oparen()
.intConstZero()
.comma()
Expand Down Expand Up @@ -181,12 +184,10 @@ public OpenCLHatKernelBuilder hatSelectStoreOp(ScopedCodeBuilderContext buildCon
/**
 * Emits a {@code (half)} cast and then generates code for the op that
 * produced the value being converted.
 *
 * @param buildContext the current code-builder scope, passed on to {@code recurse}
 * @param hatF16ConvOp the conversion op; only its first operand is used
 * @return this builder
 */
@Override
public OpenCLHatKernelBuilder hatF16ConvOp(ScopedCodeBuilderContext buildContext, HATF16ConvOp hatF16ConvOp) {
    // Cast prefix: (half)
    oparen().typeName("half").cparen();
    // Only an operand that is an op result has a producing op to expand;
    // any other kind of value emits nothing after the cast.
    if (hatF16ConvOp.operands().getFirst() instanceof Op.Result producer) {
        recurse(buildContext, producer.op());
    }
    return self();
}

Expand All @@ -204,4 +205,10 @@ public OpenCLHatKernelBuilder hatVectorVarOp(ScopedCodeBuilderContext buildConte
return self();
}

/**
 * Emits the opening of a vector construction: {@code (<buildType>)(}.
 * The caller is responsible for the operands and the closing parenthesis.
 *
 * @param builderContext the current code-builder scope (not read by this method)
 * @param hatVectorOfOp  the vector-of op that supplies the build type
 * @return this builder
 */
@Override
public OpenCLHatKernelBuilder genVectorIdentifier(ScopedCodeBuilderContext builderContext, HATVectorOfOp hatVectorOfOp) {
    var castType = hatVectorOfOp.buildType();
    oparen().identifier(castType).cparen().oparen();
    return self();
}

}
3 changes: 2 additions & 1 deletion hat/core/src/main/java/hat/buffer/Buffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ default void setState(int newState ){
default String getStateString(){
return BufferState.of(this).getStateString();
}
// default boolean isDeviceDirty(){

// default boolean isDeviceDirty(){
// return BufferState.of(this).isDeviceDirty();
// }
// default boolean isHostChecked(){
Expand Down
17 changes: 11 additions & 6 deletions hat/core/src/main/java/hat/buffer/F32ArrayPadded.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,18 @@ default float[] arrayView() {
return arr;
}

// This is an intrinsic for HAT to create views. It does not execute code
// on the host side, at least for now.
default Float4 float4View(int index) {
return null;
default Float4.MutableImpl float4View(int index) {
// MemorySegment memorySegment = Buffer.getMemorySegment(this);
// float f1 = memorySegment.get(JAVA_FLOAT, ARRAY_OFFSET + index + 0);
// float f2 = memorySegment.get(JAVA_FLOAT, ARRAY_OFFSET + index + 1);
// float f3 = memorySegment.get(JAVA_FLOAT, ARRAY_OFFSET + index + 2);
// float f4 = memorySegment.get(JAVA_FLOAT, ARRAY_OFFSET + index + 3);
// return Float4.makeMutable(Float4.of(f1, f2, f3, f4));
return null;
}

// This is an intrinsic for HAT to store views. It does not execute code
default void storeFloat4View(Float4 v, int index) {}
default void storeFloat4View(Float4 v, int index) {
// MemorySegment.copy(Buffer.getMemorySegment(this), JAVA_FLOAT, ARRAY_OFFSET, v.toArray(), index, 4);
}

}
64 changes: 46 additions & 18 deletions hat/core/src/main/java/hat/buffer/Float4.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,59 +24,87 @@
*/
package hat.buffer;

import hat.Accelerator;
import hat.ifacemapper.Schema;
import jdk.incubator.code.CodeReflection;

import java.util.function.BiFunction;
import java.util.stream.IntStream;

public interface Float4 extends HatVector {

float x();
float y();
float z();
float w();
void x(float x);
void y(float y);
void z(float z);
void w(float w);

Schema<Float4> schema = Schema.of(Float4.class,
float4->float4.fields("x","y","z","w"));
record MutableImpl(float x, float y, float z, float w) implements Float4 {
public void x(float x) {}
public void y(float y) {}
public void z(float z) {}
public void w(float w) {}
}

record ImmutableImpl(float x, float y, float z, float w) implements Float4 {
}

/**
* Make a Mutable implementation (for the device side - e.g., the GPU) from an immutable implementation.
*
* @param float4
* @return {@link Float4.MutableImpl}
*/
static Float4.MutableImpl makeMutable(Float4 float4) {
return new MutableImpl(float4.x(), float4.y(), float4.z(), float4.w());
}

static Float4 of(float x, float y, float z, float w) {
return new ImmutableImpl(x, y, z, w);
}

static Float4 create(Accelerator accelerator) {
return schema.allocate(accelerator, 1);
// Not implemented for the GPU yet
default Float4 lanewise(Float4 other, BiFunction<Float, Float, Float> f) {
float[] backA = this.toArray();
float[] backB = other.toArray();
float[] backC = new float[backA.length];
IntStream.range(0, backA.length).forEach(j -> {
backC[j] = f.apply(backA[j], backB[j]);
});
return of(backC[0], backC[1], backC[2], backC[3]);
}

static Float4 add(Float4 vA, Float4 vB) {
return null;
return vA.lanewise(vB, Float::sum);
}

static Float4 sub(Float4 vA, Float4 vB) {
return null;
return vA.lanewise(vB, (a, b) -> a - b);
}

static Float4 mul(Float4 vA, Float4 vB) {
return null;
return vA.lanewise(vB, (a, b) -> a * b);
}

static Float4 div(Float4 vA, Float4 vB) {
return null;
return vA.lanewise(vB, (a, b) -> a / b);
}

default Float4 add(Float4 vb) {
return null;
return Float4.add(this, vb);
}

default Float4 sub(Float4 vb) {
return null;
return Float4.sub(this, vb);
}

default Float4 mul(Float4 vb) {
return null;
return Float4.mul(this, vb);
}

default Float4 div(Float4 vb) {
return null;
return Float4.div(this, vb);
}

// Not implemented for the GPU yet
@CodeReflection
default float[] toArray() {
return new float[] { x(), y(), z(), w() };
}
Expand Down
8 changes: 8 additions & 0 deletions hat/core/src/main/java/hat/codebuilders/BabylonOpBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import hat.dialect.HATF16ConvOp;
import hat.dialect.HATF16VarLoadOp;
import hat.dialect.HATF16VarOp;
import hat.dialect.HATVectorMakeOfOp;
import hat.dialect.HATVectorOfOp;
import hat.dialect.HATVectorSelectLoadOp;
import hat.dialect.HATVectorSelectStoreOp;
import hat.dialect.HATVectorBinaryOp;
Expand Down Expand Up @@ -141,6 +143,10 @@ public interface BabylonOpBuilder<T extends HATCodeBuilderWithContext<?>> {

T hatF16ConvOp(ScopedCodeBuilderContext buildContext, HATF16ConvOp hatF16ConvOp);

T hatVectorOfOps(ScopedCodeBuilderContext buildContext, HATVectorOfOp hatVectorOp);

T hatVectorMakeOf(ScopedCodeBuilderContext builderContext, HATVectorMakeOfOp hatVectorMakeOfOp);

default T recurse(ScopedCodeBuilderContext buildContext, Op op) {
switch (op) {
case CoreOp.VarAccessOp.VarLoadOp $ -> varLoadOp(buildContext, $);
Expand Down Expand Up @@ -183,10 +189,12 @@ default T recurse(ScopedCodeBuilderContext buildContext, Op op) {
case HATVectorSelectLoadOp $ -> hatSelectLoadOp(buildContext, $);
case HATVectorSelectStoreOp $ -> hatSelectStoreOp(buildContext, $);
case HATVectorVarLoadOp $ -> hatVectorVarLoadOp(buildContext, $);
case HATVectorOfOp $ -> hatVectorOfOps(buildContext, $);
case HATF16VarOp $ -> hatF16VarOp(buildContext, $);
case HATF16BinaryOp $ -> hatF16BinaryOp(buildContext, $);
case HATF16VarLoadOp $ -> hatF16VarLoadOp(buildContext, $);
case HATF16ConvOp $ -> hatF16ConvOp(buildContext, $);
case HATVectorMakeOfOp $ -> hatVectorMakeOf(buildContext, $);
default -> throw new IllegalStateException("handle nesting of op " + op);
}
return (T) this;
Expand Down
31 changes: 31 additions & 0 deletions hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import hat.dialect.HATGlobalThreadIdOp;
import hat.dialect.HATLocalSizeOp;
import hat.dialect.HATLocalThreadIdOp;
import hat.dialect.HATVectorMakeOfOp;
import hat.dialect.HATVectorOfOp;
import hat.dialect.HATVectorVarLoadOp;
import hat.ifacemapper.MappableIface;
import hat.optools.FuncOpParams;
Expand Down Expand Up @@ -274,6 +276,35 @@ public T hatF16VarLoadOp(ScopedCodeBuilderContext buildContext, HATF16VarLoadOp
return self();
}

/**
 * Emits the variable name carried by the given make-of op as an identifier.
 *
 * @param builderContext    the current code-builder scope (not read by this method)
 * @param hatVectorMakeOfOp the op whose {@code varName()} is emitted
 * @return this builder
 */
@Override
public T hatVectorMakeOf(ScopedCodeBuilderContext builderContext, HATVectorMakeOfOp hatVectorMakeOfOp) {
    var vectorVarName = hatVectorMakeOfOp.varName();
    identifier(vectorVarName);
    return self();
}

/**
 * Emits the backend-specific opening of a vector construction expression.
 * {@code hatVectorOfOps} calls this first, then emits the operands separated
 * by commas and finally the closing parenthesis, so an implementation is
 * expected to leave the output just after an opening parenthesis.
 *
 * @param builderContext the current code-builder scope
 * @param hatVectorOfOp  the vector-of op being generated
 * @return this builder
 */
public abstract T genVectorIdentifier(ScopedCodeBuilderContext builderContext, HATVectorOfOp hatVectorOfOp);

/**
 * Generates a vector construction expression: the backend-specific opening
 * (via {@code genVectorIdentifier}), the operands separated by {@code ", "},
 * and a closing parenthesis.
 *
 * <p>An operand that is an {@code Op.Result} is expanded by recursing into the
 * op that produced it; any other kind of operand emits nothing, but its
 * separator is still written. An empty operand list yields just the opening
 * and the closing parenthesis instead of failing with an
 * {@code IndexOutOfBoundsException}.
 *
 * @param buildContext the current code-builder scope, passed on to {@code recurse}
 * @param hatVectorOp  the vector-of op whose operands become the constructor arguments
 * @return this builder
 */
@Override
public T hatVectorOfOps(ScopedCodeBuilderContext buildContext, HATVectorOfOp hatVectorOp) {
    genVectorIdentifier(buildContext, hatVectorOp);

    boolean firstOperand = true;
    for (Value operand : hatVectorOp.operands()) {
        // The separator goes between operands, never before the first or after the last.
        if (!firstOperand) {
            comma().space();
        }
        firstOperand = false;
        if (operand instanceof Op.Result r) {
            recurse(buildContext, r.op());
        }
    }
    cparen();
    return self();
}

public T kernelDeclaration(CoreOp.FuncOp funcOp) {
return kernelPrefix().voidType().space().funcName(funcOp);
Expand Down
Loading