Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
import hat.codebuilders.CodeBuilder;
import hat.codebuilders.ScopedCodeBuilderContext;
import hat.dialect.HATF16ConvOp;
import hat.dialect.HATVectorSelectLoadOp;
import hat.dialect.HATVectorSelectStoreOp;
import hat.dialect.HATVectorBinaryOp;
import hat.dialect.HATVectorLoadOp;
import hat.dialect.HATVectorOfOp;
import hat.dialect.HATVectorSelectLoadOp;
import hat.dialect.HATVectorSelectStoreOp;
import hat.dialect.HATVectorStoreView;
import hat.dialect.HATVectorVarOp;
import jdk.incubator.code.Op;
Expand Down Expand Up @@ -250,4 +251,10 @@ public CudaHATKernelBuilder hatVectorVarOp(ScopedCodeBuilderContext buildContext
}
return self();
}

@Override
public CudaHATKernelBuilder genVectorIdentifier(ScopedCodeBuilderContext builderContext, HATVectorOfOp hatVectorOfOp) {
identifier("make_" + hatVectorOfOp.buildType()).oparen();
return self();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,18 @@
import hat.codebuilders.CodeBuilder;
import hat.codebuilders.ScopedCodeBuilderContext;
import hat.dialect.HATF16ConvOp;
import hat.dialect.HATVectorSelectLoadOp;
import hat.dialect.HATVectorSelectStoreOp;
import hat.dialect.HATVectorBinaryOp;
import hat.dialect.HATVectorLoadOp;
import hat.dialect.HATVectorOfOp;
import hat.dialect.HATVectorSelectLoadOp;
import hat.dialect.HATVectorSelectStoreOp;
import hat.dialect.HATVectorStoreView;
import hat.dialect.HATVectorVarOp;
import jdk.incubator.code.Op;
import jdk.incubator.code.Value;

import java.util.List;

public class OpenCLHATKernelBuilder extends C99HATKernelBuilder<OpenCLHATKernelBuilder> {

@Override
Expand Down Expand Up @@ -83,7 +86,7 @@ public OpenCLHATKernelBuilder hatVectorStoreOp(ScopedCodeBuilderContext buildCon
Value dest = hatVectorStoreView.operands().get(0);
Value index = hatVectorStoreView.operands().get(2);

identifier("vstore" + hatVectorStoreView.storeN())
identifier("vstore" + hatVectorStoreView.vectorN())
.oparen()
.varName(hatVectorStoreView)
.comma()
Expand Down Expand Up @@ -131,7 +134,7 @@ public OpenCLHATKernelBuilder hatVectorLoadOp(ScopedCodeBuilderContext buildCont
Value source = hatVectorLoadOp.operands().get(0);
Value index = hatVectorLoadOp.operands().get(1);

identifier("vload" + hatVectorLoadOp.loadN())
identifier("vload" + hatVectorLoadOp.vectorN())
.oparen()
.intConstZero()
.comma()
Expand Down Expand Up @@ -180,13 +183,11 @@ public OpenCLHATKernelBuilder hatSelectStoreOp(ScopedCodeBuilderContext buildCon

@Override
public OpenCLHATKernelBuilder hatF16ConvOp(ScopedCodeBuilderContext buildContext, HATF16ConvOp hatF16ConvOp) {
oparen().typeName("half").cparen();
// typeName("convert_half").oparen();
oparen().halfType().cparen();
Value initValue = hatF16ConvOp.operands().getFirst();
if (initValue instanceof Op.Result r) {
recurse(buildContext, r.op());
}
//cparen();
return self();
}

Expand All @@ -204,4 +205,9 @@ public OpenCLHATKernelBuilder hatVectorVarOp(ScopedCodeBuilderContext buildConte
return self();
}

@Override
public OpenCLHATKernelBuilder genVectorIdentifier(ScopedCodeBuilderContext builderContext, HATVectorOfOp hatVectorOfOp) {
oparen().identifier(hatVectorOfOp.buildType()).cparen().oparen();
return self();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ bool OpenCLBackend::getBufferFromDeviceIfDirty(void *memorySegment, long memoryS

OpenCLBackend::OpenCLBackend(int configBits)
: Backend(new Config(configBits), new OpenCLQueue(this)) {

if (config->info) {
std::cerr << "[INFO] Config Bits = " << std::hex << configBits << std::dec << std::endl;
}

cl_int status;
cl_uint platformc = 0;
OPENCL_CHECK(clGetPlatformIDs(0, nullptr, &platformc), "clGetPlatformIDs");
Expand Down Expand Up @@ -303,7 +308,6 @@ const char *OpenCLBackend::errorMsg(cl_int status) {
}

extern "C" long getBackend(int configBits) {
std::cerr << "Opencl Driver =" << std::hex << configBits << std::dec << std::endl;
return reinterpret_cast<long>(new OpenCLBackend(configBits));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,18 @@
import hat.codebuilders.CodeBuilder;
import hat.codebuilders.ScopedCodeBuilderContext;
import hat.dialect.HATF16ConvOp;
import hat.dialect.HATVectorSelectLoadOp;
import hat.dialect.HATVectorSelectStoreOp;
import hat.dialect.HATVectorBinaryOp;
import hat.dialect.HATVectorLoadOp;
import hat.dialect.HATVectorOfOp;
import hat.dialect.HATVectorSelectLoadOp;
import hat.dialect.HATVectorSelectStoreOp;
import hat.dialect.HATVectorStoreView;
import hat.dialect.HATVectorVarOp;
import jdk.incubator.code.Op;
import jdk.incubator.code.Value;

import java.util.List;

public class OpenCLHatKernelBuilder extends C99HATKernelBuilder<OpenCLHatKernelBuilder> {

@Override
Expand Down Expand Up @@ -83,7 +86,7 @@ public OpenCLHatKernelBuilder hatVectorStoreOp(ScopedCodeBuilderContext buildCon
Value dest = hatVectorStoreView.operands().get(0);
Value index = hatVectorStoreView.operands().get(2);

identifier("vstore" + hatVectorStoreView.storeN())
identifier("vstore" + hatVectorStoreView.vectorN())
.oparen()
.varName(hatVectorStoreView)
.comma()
Expand Down Expand Up @@ -131,7 +134,7 @@ public OpenCLHatKernelBuilder hatVectorLoadOp(ScopedCodeBuilderContext buildCont
Value source = hatVectorLoadOp.operands().get(0);
Value index = hatVectorLoadOp.operands().get(1);

identifier("vload" + hatVectorLoadOp.loadN())
identifier("vload" + hatVectorLoadOp.vectorN())
.oparen()
.intConstZero()
.comma()
Expand Down Expand Up @@ -181,12 +184,10 @@ public OpenCLHatKernelBuilder hatSelectStoreOp(ScopedCodeBuilderContext buildCon
@Override
public OpenCLHatKernelBuilder hatF16ConvOp(ScopedCodeBuilderContext buildContext, HATF16ConvOp hatF16ConvOp) {
oparen().typeName("half").cparen();
// typeName("convert_half").oparen();
Value initValue = hatF16ConvOp.operands().getFirst();
if (initValue instanceof Op.Result r) {
recurse(buildContext, r.op());
}
//cparen();
return self();
}

Expand All @@ -204,4 +205,10 @@ public OpenCLHatKernelBuilder hatVectorVarOp(ScopedCodeBuilderContext buildConte
return self();
}

@Override
public OpenCLHatKernelBuilder genVectorIdentifier(ScopedCodeBuilderContext builderContext, HATVectorOfOp hatVectorOfOp) {
oparen().identifier(hatVectorOfOp.buildType()).cparen().oparen();
return self();
}

}
37 changes: 37 additions & 0 deletions hat/core/src/main/java/hat/annotations/HATVectorType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package hat.annotations;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE_USE})
public @interface HATVectorType {
String primitiveType() default "float";
int lanes() default 0;
}
3 changes: 2 additions & 1 deletion hat/core/src/main/java/hat/buffer/Buffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ default void setState(int newState ){
default String getStateString(){
return BufferState.of(this).getStateString();
}
// default boolean isDeviceDirty(){

// default boolean isDeviceDirty(){
// return BufferState.of(this).isDeviceDirty();
// }
// default boolean isHostChecked(){
Expand Down
17 changes: 11 additions & 6 deletions hat/core/src/main/java/hat/buffer/F32ArrayPadded.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,18 @@ default float[] arrayView() {
return arr;
}

// This is an intrinsic for HAT to create views. It does not execute code
// on the host side, at least for now.
default Float4 float4View(int index) {
return null;
default Float4.MutableImpl float4View(int index) {
// MemorySegment memorySegment = Buffer.getMemorySegment(this);
// float f1 = memorySegment.get(JAVA_FLOAT, ARRAY_OFFSET + index + 0);
// float f2 = memorySegment.get(JAVA_FLOAT, ARRAY_OFFSET + index + 1);
// float f3 = memorySegment.get(JAVA_FLOAT, ARRAY_OFFSET + index + 2);
// float f4 = memorySegment.get(JAVA_FLOAT, ARRAY_OFFSET + index + 3);
// return Float4.makeMutable(Float4.of(f1, f2, f3, f4));
return null;
}

// This is an intrinsic for HAT to store views. It does not execute code
default void storeFloat4View(Float4 v, int index) {}
default void storeFloat4View(Float4 v, int index) {
// MemorySegment.copy(Buffer.getMemorySegment(this), JAVA_FLOAT, ARRAY_OFFSET, v.toArray(), index, 4);
}

}
85 changes: 66 additions & 19 deletions hat/core/src/main/java/hat/buffer/Float4.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,59 +24,106 @@
*/
package hat.buffer;

import hat.Accelerator;
import hat.ifacemapper.Schema;
import hat.annotations.HATVectorType;
import hat.types._V4;
import jdk.incubator.code.CodeReflection;
import jdk.incubator.code.TypeElement;
import jdk.incubator.code.dialect.java.JavaType;
import jdk.incubator.code.dialect.java.PrimitiveType;

public interface Float4 extends HatVector {
import java.util.function.BiFunction;
import java.util.stream.IntStream;

@HATVectorType(primitiveType = "float", lanes = 4)
public interface Float4 extends HATVector {

float x();
float y();
float z();
float w();
void x(float x);
void y(float y);
void z(float z);
void w(float w);

Schema<Float4> schema = Schema.of(Float4.class,
float4->float4.fields("x","y","z","w"));
// @CodeReflection
// @Override
// default PrimitiveType type() {
// return JavaType.FLOAT;
// }
//
// @CodeReflection
// @Override
// default int width() {
// return 4;
// }

@HATVectorType(primitiveType = "float", lanes = 4)
record MutableImpl(float x, float y, float z, float w) implements Float4 {
public void x(float x) {}
public void y(float y) {}
public void z(float z) {}
public void w(float w) {}
}

@HATVectorType(primitiveType = "float", lanes = 4)
record ImmutableImpl(float x, float y, float z, float w) implements Float4 {
}

/**
* Make a Mutable implementation (for the device side - e.g., the GPU) from an immutable implementation.
*
* @param float4
* @return {@link Float4.MutableImpl}
*/
static Float4.MutableImpl makeMutable(Float4 float4) {
return new MutableImpl(float4.x(), float4.y(), float4.z(), float4.w());
}

static Float4 of(float x, float y, float z, float w) {
return new ImmutableImpl(x, y, z, w);
}

static Float4 create(Accelerator accelerator) {
return schema.allocate(accelerator, 1);
// Not implemented for the GPU yet
default Float4 lanewise(Float4 other, BiFunction<Float, Float, Float> f) {
float[] backA = this.toArray();
float[] backB = other.toArray();
float[] backC = new float[backA.length];
IntStream.range(0, backA.length).forEach(j -> {
backC[j] = f.apply(backA[j], backB[j]);
});
return of(backC[0], backC[1], backC[2], backC[3]);
}

static Float4 add(Float4 vA, Float4 vB) {
return null;
return vA.lanewise(vB, Float::sum);
}

static Float4 sub(Float4 vA, Float4 vB) {
return null;
return vA.lanewise(vB, (a, b) -> a - b);
}

static Float4 mul(Float4 vA, Float4 vB) {
return null;
return vA.lanewise(vB, (a, b) -> a * b);
}

static Float4 div(Float4 vA, Float4 vB) {
return null;
return vA.lanewise(vB, (a, b) -> a / b);
}

default Float4 add(Float4 vb) {
return null;
return Float4.add(this, vb);
}

default Float4 sub(Float4 vb) {
return null;
return Float4.sub(this, vb);
}

default Float4 mul(Float4 vb) {
return null;
return Float4.mul(this, vb);
}

default Float4 div(Float4 vb) {
return null;
return Float4.div(this, vb);
}

// Not implemented for the GPU yet
default float[] toArray() {
return new float[] { x(), y(), z(), w() };
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@
*/
package hat.buffer;

public interface HatVector extends Buffer {
public interface HATVector {

}
Loading