From 02310ffce3a8acdd69358cafa5982884d0230dff Mon Sep 17 00:00:00 2001 From: Juan Fumero Date: Fri, 24 Oct 2025 17:10:59 +0200 Subject: [PATCH 01/14] [hat] Float4.of supported for OpenCL --- .../backend/ffi/OpenCLHATKernelBuilder.java | 1 + hat/core/src/main/java/hat/buffer/Buffer.java | 3 +- hat/core/src/main/java/hat/buffer/Float4.java | 49 +++++++---- .../hat/codebuilders/BabylonOpBuilder.java | 4 + .../hat/codebuilders/C99HATKernelBuilder.java | 23 ++++++ .../java/hat/dialect/HATVectorBinaryOp.java | 2 +- .../java/hat/dialect/HATVectorLoadOp.java | 2 +- .../main/java/hat/dialect/HATVectorOfOp.java | 76 +++++++++++++++++ ...{HATVectorViewOp.java => HATVectorOp.java} | 6 +- .../hat/dialect/HATVectorSelectLoadOp.java | 2 +- .../hat/dialect/HATVectorSelectStoreOp.java | 2 +- .../java/hat/dialect/HATVectorStoreView.java | 2 +- .../java/hat/dialect/HATVectorVarLoadOp.java | 2 +- .../main/java/hat/dialect/HATVectorVarOp.java | 2 +- .../java/hat/phases/HATDialectifyTier.java | 8 +- .../phases/HATDialectifyVectorOpPhase.java | 82 ++++++++++++++++--- .../HATDialectifyVectorSelectPhase.java | 8 +- .../phases/HATDialectifyVectorStorePhase.java | 10 +-- .../main/java/hat/test/TestVectorTypes.java | 49 +++++++++-- .../main/java/hat/test/engine/HatAsserts.java | 14 ++++ .../hat/tools/text/JavaHATCodeBuilder.java | 8 ++ 21 files changed, 301 insertions(+), 54 deletions(-) create mode 100644 hat/core/src/main/java/hat/dialect/HATVectorOfOp.java rename hat/core/src/main/java/hat/dialect/{HATVectorViewOp.java => HATVectorOp.java} (91%) diff --git a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java index 0e2ad5d52c9..7c975ef13fc 100644 --- a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java +++ b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java @@ -28,6 +28,7 @@ import hat.codebuilders.CodeBuilder; import hat.codebuilders.ScopedCodeBuilderContext; import hat.dialect.HATF16ConvOp; +import hat.dialect.HATVectorOp; import hat.dialect.HATVectorSelectLoadOp; import hat.dialect.HATVectorSelectStoreOp; import hat.dialect.HATVectorBinaryOp; diff --git a/hat/core/src/main/java/hat/buffer/Buffer.java b/hat/core/src/main/java/hat/buffer/Buffer.java index e018b67395d..07efec4edbd 100644 --- a/hat/core/src/main/java/hat/buffer/Buffer.java +++ b/hat/core/src/main/java/hat/buffer/Buffer.java @@ -50,7 +50,8 @@ default void setState(int newState ){ default String getStateString(){ return BufferState.of(this).getStateString(); } - // default boolean isDeviceDirty(){ + + // default boolean isDeviceDirty(){ // return BufferState.of(this).isDeviceDirty(); // } // default boolean isHostChecked(){ diff --git a/hat/core/src/main/java/hat/buffer/Float4.java b/hat/core/src/main/java/hat/buffer/Float4.java index d7b0408976d..8cd32dd35c8 100644 --- a/hat/core/src/main/java/hat/buffer/Float4.java +++ b/hat/core/src/main/java/hat/buffer/Float4.java @@ -24,8 +24,7 @@ */ package hat.buffer; -import hat.Accelerator; -import hat.ifacemapper.Schema; +import java.util.function.BiFunction; public interface Float4 extends HatVector { @@ -38,43 +37,65 @@ public interface Float4 extends HatVector { void z(float z); void w(float w); - Schema schema = Schema.of(Float4.class, - float4->float4.fields("x","y","z","w")); + record Float4Impl(float x, float y, float z, float w) implements Float4 { + @Override + public void x(float x) {} - static Float4 create(Accelerator accelerator) { - return schema.allocate(accelerator, 1); + @Override + public void y(float y) {} + + @Override + public void z(float z) {} + + @Override + public void w(float w) {} + } + + static Float4 of(float x, float y, float z, float w) { + return new Float4Impl(x, y, z, w); + } + + default Float4 lanewise(Float4 other, BiFunction f) { + float[] backA = this.toArray(); + float[] backB = other.toArray(); + float[] backC = new float[backA.length]; + for (int j = 0; j < backA.length; j++) { + var r = f.apply(backA[j], backB[j]); + backC[j] = r; + } + return of(backC[0], backC[1], backC[2], backC[3]); } static Float4 add(Float4 vA, Float4 vB) { - return null; + return vA.lanewise(vB, Float::sum); } static Float4 sub(Float4 vA, Float4 vB) { - return null; + return vA.lanewise(vB, (a, b) -> a - b); } static Float4 mul(Float4 vA, Float4 vB) { - return null; + return vA.lanewise(vB, (a, b) -> a * b); } static Float4 div(Float4 vA, Float4 vB) { - return null; + return vA.lanewise(vB, (a, b) -> a / b); } default Float4 add(Float4 vb) { - return null; + return Float4.add(this, vb); } default Float4 sub(Float4 vb) { - return null; + return Float4.sub(this, vb); } default Float4 mul(Float4 vb) { - return null; + return Float4.mul(this, vb); } default Float4 div(Float4 vb) { - return null; + return Float4.div(this, vb); } default float[] toArray() { diff --git a/hat/core/src/main/java/hat/codebuilders/BabylonOpBuilder.java b/hat/core/src/main/java/hat/codebuilders/BabylonOpBuilder.java index 20548818494..393375f00f3 100644 --- a/hat/core/src/main/java/hat/codebuilders/BabylonOpBuilder.java +++ b/hat/core/src/main/java/hat/codebuilders/BabylonOpBuilder.java @@ -30,6 +30,7 @@ import hat.dialect.HATF16ConvOp; import hat.dialect.HATF16VarLoadOp; import hat.dialect.HATF16VarOp; +import hat.dialect.HATVectorOfOp; import hat.dialect.HATVectorSelectLoadOp; import hat.dialect.HATVectorSelectStoreOp; import hat.dialect.HATVectorBinaryOp; @@ -141,6 +142,8 @@ public interface BabylonOpBuilder> { T hatF16ConvOp(ScopedCodeBuilderContext buildContext, HATF16ConvOp hatF16ConvOp); + T hatVectorOfOps(ScopedCodeBuilderContext buildContext, HATVectorOfOp hatVectorOp); + default T recurse(ScopedCodeBuilderContext buildContext, Op op) { switch (op) { case CoreOp.VarAccessOp.VarLoadOp $ -> varLoadOp(buildContext, $); @@ -183,6 +186,7 @@ default T recurse(ScopedCodeBuilderContext buildContext, Op op) { case HATVectorSelectLoadOp $ -> hatSelectLoadOp(buildContext, $); case HATVectorSelectStoreOp $ -> hatSelectStoreOp(buildContext, $); case HATVectorVarLoadOp $ -> hatVectorVarLoadOp(buildContext, $); + case HATVectorOfOp $ -> hatVectorOfOps(buildContext, $); case HATF16VarOp $ -> hatF16VarOp(buildContext, $); case HATF16BinaryOp $ -> hatF16BinaryOp(buildContext, $); case HATF16VarLoadOp $ -> hatF16VarLoadOp(buildContext, $); diff --git a/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java b/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java index 0587c6a949d..0a82d1dae99 100644 --- a/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java +++ b/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java @@ -33,6 +33,7 @@ import hat.dialect.HATGlobalThreadIdOp; import hat.dialect.HATLocalSizeOp; import hat.dialect.HATLocalThreadIdOp; +import hat.dialect.HATVectorOfOp; import hat.dialect.HATVectorVarLoadOp; import hat.ifacemapper.MappableIface; import hat.optools.FuncOpParams; @@ -268,6 +269,28 @@ public T hatF16BinaryOp(ScopedCodeBuilderContext buildContext, HATF16BinaryOp ha return self(); } + @Override + public T hatVectorOfOps(ScopedCodeBuilderContext buildContext, HATVectorOfOp hatVectorOp) { + oparen().identifier(hatVectorOp.buildType()).cparen().oparen(); + + List inputOperands = hatVectorOp.operands(); + int i; + for (i = 0; i < (inputOperands.size() - 1); i++) { + var operand = inputOperands.get(i); + if ((operand instanceof Op.Result r)) { + recurse(buildContext, r.op()); + } + comma().space(); + } + // Last parameter + var operand = inputOperands.get(i); + if ((operand instanceof Op.Result r)) { + recurse(buildContext, r.op()); + } + cparen(); + return self(); + } + @Override public T hatF16VarLoadOp(ScopedCodeBuilderContext buildContext, HATF16VarLoadOp hatF16VarLoadOp) { identifier(hatF16VarLoadOp.varName()); diff --git a/hat/core/src/main/java/hat/dialect/HATVectorBinaryOp.java b/hat/core/src/main/java/hat/dialect/HATVectorBinaryOp.java index 35364962b62..3caddfa7d22 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorBinaryOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorBinaryOp.java @@ -30,7 +30,7 @@ import java.util.List; -public abstract class HATVectorBinaryOp extends HATVectorViewOp { +public abstract class HATVectorBinaryOp extends HATVectorOp { public enum OpType { ADD("+"), SUB("-"), diff --git a/hat/core/src/main/java/hat/dialect/HATVectorLoadOp.java b/hat/core/src/main/java/hat/dialect/HATVectorLoadOp.java index 402944c96e5..89ca2ec0ee5 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorLoadOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorLoadOp.java @@ -33,7 +33,7 @@ import java.util.List; import java.util.Map; -public class HATVectorLoadOp extends HATVectorViewOp { +public class HATVectorLoadOp extends HATVectorOp { private final TypeElement typeElement; private final TypeElement vectorType; diff --git a/hat/core/src/main/java/hat/dialect/HATVectorOfOp.java b/hat/core/src/main/java/hat/dialect/HATVectorOfOp.java new file mode 100644 index 00000000000..1331ebf2909 --- /dev/null +++ b/hat/core/src/main/java/hat/dialect/HATVectorOfOp.java @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package hat.dialect; + +import jdk.incubator.code.CopyContext; +import jdk.incubator.code.Op; +import jdk.incubator.code.OpTransformer; +import jdk.incubator.code.TypeElement; +import jdk.incubator.code.Value; + +import java.util.List; +import java.util.Map; + +public class HATVectorOfOp extends HATVectorOp { + + private final TypeElement typeElement; + private final int loadN; + + public HATVectorOfOp(TypeElement typeElement, int loadN, List operands) { + super("", operands); + this.typeElement = typeElement; + this.loadN = loadN; + } + + public HATVectorOfOp(HATVectorOfOp op, CopyContext copyContext) { + super(op, copyContext); + this.typeElement = op.typeElement; + this.loadN = op.loadN; + } + + @Override + public Op transform(CopyContext copyContext, OpTransformer opTransformer) { + return new HATVectorOfOp(this, copyContext); + } + + @Override + public TypeElement resultType() { + return typeElement; + } + + @Override + public Map externalize() { + return Map.of("hat.dialect.vectorOf." + varName(), typeElement); + } + + public String buildType() { + // floatN + if (typeElement.toString().startsWith("hat.buffer.Float")) { + return "float" + loadN; + } + throw new RuntimeException("Unexpected vector type " + typeElement); + } + +} diff --git a/hat/core/src/main/java/hat/dialect/HATVectorViewOp.java b/hat/core/src/main/java/hat/dialect/HATVectorOp.java similarity index 91% rename from hat/core/src/main/java/hat/dialect/HATVectorViewOp.java rename to hat/core/src/main/java/hat/dialect/HATVectorOp.java index 0cef7659239..bc270440bbe 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorViewOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorOp.java @@ -29,16 +29,16 @@ import java.util.List; -public abstract class HATVectorViewOp extends HATOp { +public abstract class HATVectorOp extends HATOp { private String varName; - public HATVectorViewOp(String varName, List operands) { + public HATVectorOp(String varName, List operands) { super(operands); this.varName = varName; } - protected HATVectorViewOp(HATVectorViewOp that, CopyContext cc) { + protected HATVectorOp(HATVectorOp that, CopyContext cc) { super(that, cc); this.varName = that.varName; } diff --git a/hat/core/src/main/java/hat/dialect/HATVectorSelectLoadOp.java b/hat/core/src/main/java/hat/dialect/HATVectorSelectLoadOp.java index 1f74303e2fd..6d60400a91b 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorSelectLoadOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorSelectLoadOp.java @@ -33,7 +33,7 @@ import java.util.List; import java.util.Map; -public class HATVectorSelectLoadOp extends HATVectorViewOp { +public class HATVectorSelectLoadOp extends HATVectorOp { private final TypeElement elementType; private final int lane; diff --git a/hat/core/src/main/java/hat/dialect/HATVectorSelectStoreOp.java b/hat/core/src/main/java/hat/dialect/HATVectorSelectStoreOp.java index e527f2ca1dd..6b5ff6d093b 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorSelectStoreOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorSelectStoreOp.java @@ -34,7 +34,7 @@ import java.util.List; import java.util.Map; -public class HATVectorSelectStoreOp extends HATVectorViewOp { +public class HATVectorSelectStoreOp extends HATVectorOp { private final TypeElement elementType; private final int lane; diff --git a/hat/core/src/main/java/hat/dialect/HATVectorStoreView.java b/hat/core/src/main/java/hat/dialect/HATVectorStoreView.java index a936422ca88..e69219260d1 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorStoreView.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorStoreView.java @@ -33,7 +33,7 @@ import java.util.List; import java.util.Map; -public final class HATVectorStoreView extends HATVectorViewOp { +public final class HATVectorStoreView extends HATVectorOp { private final TypeElement elementType; private final int storeN; diff --git a/hat/core/src/main/java/hat/dialect/HATVectorVarLoadOp.java b/hat/core/src/main/java/hat/dialect/HATVectorVarLoadOp.java index c97086b863b..427d205ae18 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorVarLoadOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorVarLoadOp.java @@ -33,7 +33,7 @@ import java.util.List; import java.util.Map; -public class HATVectorVarLoadOp extends HATVectorViewOp { +public class HATVectorVarLoadOp extends HATVectorOp { private final TypeElement typeElement; diff --git a/hat/core/src/main/java/hat/dialect/HATVectorVarOp.java b/hat/core/src/main/java/hat/dialect/HATVectorVarOp.java index fa9fd4a0a8c..d49c82d4742 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorVarOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorVarOp.java @@ -34,7 +34,7 @@ import java.util.List; import java.util.Map; -public class HATVectorVarOp extends HATVectorViewOp { +public class HATVectorVarOp extends HATVectorOp { private final VarType typeElement; private final int loadN; diff --git a/hat/core/src/main/java/hat/phases/HATDialectifyTier.java b/hat/core/src/main/java/hat/phases/HATDialectifyTier.java index bc9432cf0a5..fdde6854070 100644 --- a/hat/core/src/main/java/hat/phases/HATDialectifyTier.java +++ b/hat/core/src/main/java/hat/phases/HATDialectifyTier.java @@ -38,6 +38,7 @@ public class HATDialectifyTier implements Function private List hatPhases = new ArrayList<>(); public HATDialectifyTier(Accelerator accelerator) { + // barriers hatPhases.add(new HATDialectifyBarrierPhase(accelerator)); // Memory hatPhases.add(new HATDialectifyMemoryPhase.SharedPhase(accelerator)); @@ -50,16 +51,19 @@ public HATDialectifyTier(Accelerator accelerator) { hatPhases.add(new HATDialectifyThreadsPhase.LocalSizePhase(accelerator)); hatPhases.add(new HATDialectifyThreadsPhase.BlockPhase(accelerator)); - // views + // views for vector types hatPhases.add(new HATDialectifyVectorOpPhase.Float4LoadPhase(accelerator)); + hatPhases.add(new HATDialectifyVectorOpPhase.Float4OfPhase(accelerator)); hatPhases.add(new HATDialectifyVectorOpPhase.AddPhase(accelerator)); hatPhases.add(new HATDialectifyVectorOpPhase.SubPhase(accelerator)); hatPhases.add(new HATDialectifyVectorOpPhase.MulPhase(accelerator)); hatPhases.add(new HATDialectifyVectorOpPhase.DivPhase(accelerator)); - hatPhases.add(new HATDialectifyVectorStorePhase.Float4StorePhase(accelerator)); + // Vector Select individual lines hatPhases.add(new HATDialectifyVectorSelectPhase(accelerator)); + + // F16 type hatPhases.add(new HATDialectifyFP16Phase(accelerator)); } diff --git a/hat/core/src/main/java/hat/phases/HATDialectifyVectorOpPhase.java b/hat/core/src/main/java/hat/phases/HATDialectifyVectorOpPhase.java index 5e3d35468ba..7a1caf8d7fe 100644 --- a/hat/core/src/main/java/hat/phases/HATDialectifyVectorOpPhase.java +++ b/hat/core/src/main/java/hat/phases/HATDialectifyVectorOpPhase.java @@ -31,10 +31,11 @@ import hat.dialect.HATVectorDivOp; import hat.dialect.HATVectorLoadOp; import hat.dialect.HATVectorMulOp; +import hat.dialect.HATVectorOfOp; import hat.dialect.HATVectorSubOp; import hat.dialect.HATVectorVarLoadOp; import hat.dialect.HATVectorVarOp; -import hat.dialect.HATVectorViewOp; +import hat.dialect.HATVectorOp; import hat.dialect.HATVectorBinaryOp; import hat.optools.OpTk; import jdk.incubator.code.CodeElement; @@ -53,7 +54,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -public abstract class HATDialectifyVectorOpPhase implements HATDialect{ +public abstract class HATDialectifyVectorOpPhase implements HATDialect { protected final Accelerator accelerator; @Override public Accelerator accelerator(){ @@ -80,6 +81,7 @@ private HATVectorBinaryOp.OpType getBinaryOpType(JavaOp.InvokeOp invokeOp) { public enum OpView { FLOAT4_LOAD("float4View"), + OF("of"), ADD("add"), SUB("sub"), MUL("mul"), @@ -107,8 +109,8 @@ private String findNameVector(Value v) { return findNameVector(varLoadOp); } else { // Leaf of tree - - if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorViewOp hatVectorViewOp) { - return hatVectorViewOp.varName(); + if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorOp hatVectorOp) { + return hatVectorOp.varName(); } return null; } @@ -175,7 +177,7 @@ private CoreOp.FuncOp dialectifyVectorLoad(CoreOp.FuncOp funcOp) { if (r.op() instanceof CoreOp.VarOp varOp) { List inputOperandsVarOp = invokeOp.operands(); List outputOperandsVarOp = context.getValues(inputOperandsVarOp); - HATVectorViewOp memoryViewOp = new HATVectorLoadOp(varOp.varName(), varOp.resultType(), invokeOp.resultType(), 4, isShared, outputOperandsVarOp); + HATVectorOp memoryViewOp = new HATVectorLoadOp(varOp.varName(), varOp.resultType(), invokeOp.resultType(), 4, isShared, outputOperandsVarOp); Op.Result hatLocalResult = blockBuilder.op(memoryViewOp); memoryViewOp.setLocation(varOp.location()); context.mapValue(invokeOp.result(), hatLocalResult); @@ -186,7 +188,7 @@ private CoreOp.FuncOp dialectifyVectorLoad(CoreOp.FuncOp funcOp) { //context.mapValue(varOp.result(), context.getValue(varOp.operands().getFirst())); List inputOperandsVarOp = varOp.operands(); List outputOperandsVarOp = context.getValues(inputOperandsVarOp); - HATVectorViewOp memoryViewOp = new HATVectorVarOp(varOp.varName(), varOp.resultType(), 4, outputOperandsVarOp); + HATVectorOp memoryViewOp = new HATVectorVarOp(varOp.varName(), varOp.resultType(), 4, outputOperandsVarOp); Op.Result hatLocalResult = blockBuilder.op(memoryViewOp); memoryViewOp.setLocation(varOp.location()); context.mapValue(varOp.result(), hatLocalResult); @@ -200,6 +202,51 @@ private CoreOp.FuncOp dialectifyVectorLoad(CoreOp.FuncOp funcOp) { return funcOp; } + private CoreOp.FuncOp dialectifyVectorOf(CoreOp.FuncOp funcOp) { + var here = OpTk.CallSite.of(this.getClass(), "dialectifyVectorOf" ); + before(here,funcOp); + Stream> float4NodesInvolved = funcOp.elements() + .mapMulti((codeElement, consumer) -> { + if (codeElement instanceof JavaOp.InvokeOp invokeOp) { + if (isVectorOperation(invokeOp)) { + consumer.accept(invokeOp); + } + Set uses = invokeOp.result().uses(); + for (Op.Result result : uses) { + if (result.op() instanceof CoreOp.VarOp varOp) { + consumer.accept(varOp); + } + } + } + }); + + Set> nodesInvolved = float4NodesInvolved.collect(Collectors.toSet()); + + funcOp = OpTk.transform(here, funcOp,(blockBuilder, op) -> { + CopyContext context = blockBuilder.context(); + if (!nodesInvolved.contains(op)) { + blockBuilder.op(op); + } else if (op instanceof JavaOp.InvokeOp invokeOp) { + List inputOperandsVarOp = invokeOp.operands(); + List outputOperandsVarOp = context.getValues(inputOperandsVarOp); + HATVectorOfOp memoryViewOp = new HATVectorOfOp(invokeOp.resultType(), 4, outputOperandsVarOp); + Op.Result hatLocalResult = blockBuilder.op(memoryViewOp); + memoryViewOp.setLocation(invokeOp.location()); + context.mapValue(invokeOp.result(), hatLocalResult); + } else if (op instanceof CoreOp.VarOp varOp) { + List inputOperandsVarOp = varOp.operands(); + List outputOperandsVarOp = context.getValues(inputOperandsVarOp); + HATVectorOp memoryViewOp = new HATVectorVarOp(varOp.varName(), varOp.resultType(), 4, outputOperandsVarOp); + Op.Result hatLocalResult = blockBuilder.op(memoryViewOp); + memoryViewOp.setLocation(varOp.location()); + context.mapValue(varOp.result(), hatLocalResult); + } + return blockBuilder; + }); + after(here, funcOp); + return funcOp; + } + private CoreOp.FuncOp dialectifyVectorBinaryOps(CoreOp.FuncOp funcOp) { var here = OpTk.CallSite.of(this.getClass(), "dialectifyVectorBinaryOps"); before(here, funcOp); @@ -239,7 +286,7 @@ private CoreOp.FuncOp dialectifyVectorBinaryOps(CoreOp.FuncOp funcOp) { for (Op.Result r : collect) { if (r.op() instanceof CoreOp.VarOp varOp) { HATVectorBinaryOp.OpType binaryOpType = binaryOperation.get(invokeOp); - HATVectorViewOp memoryViewOp = buildVectorBinaryOp(binaryOpType, varOp.varName(), invokeOp.resultType(), outputOperands); + HATVectorOp memoryViewOp = buildVectorBinaryOp(binaryOpType, varOp.varName(), invokeOp.resultType(), outputOperands); Op.Result hatVectorOpResult = blockBuilder.op(memoryViewOp); memoryViewOp.setLocation(varOp.location()); context.mapValue(invokeOp.result(), hatVectorOpResult); @@ -249,7 +296,7 @@ private CoreOp.FuncOp dialectifyVectorBinaryOps(CoreOp.FuncOp funcOp) { } else if (op instanceof CoreOp.VarOp varOp) { List inputOperandsVarOp = varOp.operands(); List outputOperandsVarOp = context.getValues(inputOperandsVarOp); - HATVectorViewOp memoryViewOp = new HATVectorVarOp(varOp.varName(), varOp.resultType(), 4, outputOperandsVarOp); + HATVectorOp memoryViewOp = new HATVectorVarOp(varOp.varName(), varOp.resultType(), 4, outputOperandsVarOp); Op.Result hatVectorResult = blockBuilder.op(memoryViewOp); memoryViewOp.setLocation(varOp.location()); context.mapValue(varOp.result(), hatVectorResult); @@ -260,7 +307,7 @@ private CoreOp.FuncOp dialectifyVectorBinaryOps(CoreOp.FuncOp funcOp) { return funcOp; } - private CoreOp.FuncOp dialectifyVectorBinaryWithContatenationOps(CoreOp.FuncOp funcOp) { + private CoreOp.FuncOp dialectifyVectorBinaryWithConcatenationOps(CoreOp.FuncOp funcOp) { var here = OpTk.CallSite.of(this.getClass(), "dialectifyBinaryWithConcatenation"); before(here, funcOp); Map binaryOperation = new HashMap<>(); @@ -299,7 +346,7 @@ private CoreOp.FuncOp dialectifyVectorBinaryWithContatenationOps(CoreOp.FuncOp f } else if (op instanceof JavaOp.InvokeOp invokeOp) { List inputOperands = invokeOp.operands(); List outputOperands = context.getValues(inputOperands); - HATVectorViewOp memoryViewOp = buildVectorBinaryOp(binaryOperation.get(invokeOp), "null", invokeOp.resultType(), outputOperands); + HATVectorOp memoryViewOp = buildVectorBinaryOp(binaryOperation.get(invokeOp), "null", invokeOp.resultType(), outputOperands); Op.Result hatVectorOpResult = blockBuilder.op(memoryViewOp); memoryViewOp.setLocation(invokeOp.location()); context.mapValue(invokeOp.result(), hatVectorOpResult); @@ -307,7 +354,7 @@ private CoreOp.FuncOp dialectifyVectorBinaryWithContatenationOps(CoreOp.FuncOp f List inputOperandsVarLoad = varLoadOp.operands(); List outputOperandsVarLoad = context.getValues(inputOperandsVarLoad); String varLoadName = findNameVector(varLoadOp); - HATVectorViewOp memoryViewOp = new HATVectorVarLoadOp(varLoadName, varLoadOp.resultType(), outputOperandsVarLoad); + HATVectorOp memoryViewOp = new HATVectorVarLoadOp(varLoadName, varLoadOp.resultType(), outputOperandsVarLoad); Op.Result hatVectorResult = blockBuilder.op(memoryViewOp); memoryViewOp.setLocation(varLoadOp.location()); context.mapValue(varLoadOp.result(), hatVectorResult); @@ -322,10 +369,12 @@ private CoreOp.FuncOp dialectifyVectorBinaryWithContatenationOps(CoreOp.FuncOp f public CoreOp.FuncOp apply(CoreOp.FuncOp funcOp) { if (Objects.requireNonNull(vectorOperation) == OpView.FLOAT4_LOAD) { funcOp = dialectifyVectorLoad(funcOp); + } else if (Objects.requireNonNull(vectorOperation) == OpView.OF) { + funcOp = dialectifyVectorOf(funcOp); } else { // Find binary operations funcOp = dialectifyVectorBinaryOps(funcOp); - funcOp = dialectifyVectorBinaryWithContatenationOps(funcOp); + funcOp = dialectifyVectorBinaryWithConcatenationOps(funcOp); } return funcOp; } @@ -344,13 +393,20 @@ public DivPhase(Accelerator accelerator) { } } - public static class Float4LoadPhase extends HATDialectifyVectorOpPhase{ + public static class Float4LoadPhase extends HATDialectifyVectorOpPhase { public Float4LoadPhase(Accelerator accelerator) { super(accelerator, OpView.FLOAT4_LOAD); } } + public static class Float4OfPhase extends HATDialectifyVectorOpPhase { + + public Float4OfPhase(Accelerator accelerator) { + super(accelerator, OpView.OF); + } + } + public static class MulPhase extends HATDialectifyVectorOpPhase{ public MulPhase(Accelerator accelerator) { diff --git a/hat/core/src/main/java/hat/phases/HATDialectifyVectorSelectPhase.java b/hat/core/src/main/java/hat/phases/HATDialectifyVectorSelectPhase.java index ac476a86e4b..0d50b4c3d7c 100644 --- a/hat/core/src/main/java/hat/phases/HATDialectifyVectorSelectPhase.java +++ b/hat/core/src/main/java/hat/phases/HATDialectifyVectorSelectPhase.java @@ -27,7 +27,7 @@ import hat.Accelerator; import hat.dialect.HATVectorSelectLoadOp; import hat.dialect.HATVectorSelectStoreOp; -import hat.dialect.HATVectorViewOp; +import hat.dialect.HATVectorOp; import hat.optools.OpTk; import jdk.incubator.code.CodeElement; import jdk.incubator.code.CopyContext; @@ -85,7 +85,7 @@ private String findNameVector(Value v) { if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { return findNameVector(varLoadOp); } else { - if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorViewOp vectorViewOp) { + if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorOp vectorViewOp) { return vectorViewOp.varName(); } return null; @@ -141,7 +141,7 @@ private CoreOp.FuncOp vloadSelectPhase(CoreOp.FuncOp funcOp) { if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { List outputOperandsInvokeOp = context.getValues(inputInvokeOp); int lane = getLane(invokeOp.invokeDescriptor().name()); - HATVectorViewOp vSelectOp; + HATVectorOp vSelectOp; String name = findNameVector(varLoadOp); if (invokeOp.resultType() != JavaType.VOID) { vSelectOp = new HATVectorSelectLoadOp(name, invokeOp.resultType(), lane, outputOperandsInvokeOp); @@ -198,7 +198,7 @@ private CoreOp.FuncOp vstoreSelectPhase(CoreOp.FuncOp funcOp) { if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { List outputOperandsInvokeOp = context.getValues(inputInvokeOp); int lane = getLane(invokeOp.invokeDescriptor().name()); - HATVectorViewOp vSelectOp; + HATVectorOp vSelectOp; String name = findNameVector(varLoadOp); if (invokeOp.resultType() == JavaType.VOID) { // The operand 1 in the store is the address (lane) diff --git a/hat/core/src/main/java/hat/phases/HATDialectifyVectorStorePhase.java b/hat/core/src/main/java/hat/phases/HATDialectifyVectorStorePhase.java index 4ef09df7db2..f6df79913fc 100644 --- a/hat/core/src/main/java/hat/phases/HATDialectifyVectorStorePhase.java +++ b/hat/core/src/main/java/hat/phases/HATDialectifyVectorStorePhase.java @@ -28,7 +28,7 @@ import hat.dialect.HATLocalVarOp; import hat.dialect.HATPrivateVarOp; import hat.dialect.HATVectorStoreView; -import hat.dialect.HATVectorViewOp; +import hat.dialect.HATVectorOp; import hat.optools.OpTk; import jdk.incubator.code.CodeElement; import jdk.incubator.code.CopyContext; @@ -85,8 +85,8 @@ private String findNameVector(CoreOp.VarAccessOp.VarLoadOp varLoadOp) { private String findNameVector(Value v) { if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { return findNameVector(varLoadOp); - } else if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorViewOp hatVectorViewOp) { - return hatVectorViewOp.varName(); + } else if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorOp hatVectorOp) { + return hatVectorOp.varName(); }else{ return null; } @@ -134,8 +134,8 @@ public CoreOp.FuncOp apply(CoreOp.FuncOp funcOp) { String name = findNameVector(v); boolean isSharedOrPrivate = findIsSharedOrPrivateSpace(invokeOp.operands().get(0)); - HATVectorViewOp storeView = switch (vectorOperation) { - case FLOAT4_STORE -> new HATVectorStoreView(name, invokeOp.resultType(), 4, HATVectorViewOp.VectorType.FLOAT4, isSharedOrPrivate, outputOperandsVarOp); + HATVectorOp storeView = switch (vectorOperation) { + case FLOAT4_STORE -> new HATVectorStoreView(name, invokeOp.resultType(), 4, HATVectorOp.VectorType.FLOAT4, isSharedOrPrivate, outputOperandsVarOp); }; Op.Result hatLocalResult = blockBuilder.op(storeView); storeView.setLocation(invokeOp.location()); diff --git a/hat/tests/src/main/java/hat/test/TestVectorTypes.java b/hat/tests/src/main/java/hat/test/TestVectorTypes.java index 0f7a2196e4d..e2a68db9c4c 100644 --- a/hat/tests/src/main/java/hat/test/TestVectorTypes.java +++ b/hat/tests/src/main/java/hat/test/TestVectorTypes.java @@ -72,11 +72,8 @@ public static void vectorOps03(@RO KernelContext kernelContext, @RO F32ArrayPadd float scaleY = vA.y() * 20.0f; float scaleZ = vA.z() * 30.0f; float scaleW = vA.w() * 40.0f; - vA.x(scaleX); - vA.y(scaleY); - vA.z(scaleZ); - vA.w(scaleW); - b.storeFloat4View(vA, index * 4); + Float4 vResult = Float4.of(scaleX, scaleY, scaleZ, scaleW); + b.storeFloat4View(vResult, index * 4); } } @@ -569,4 +566,46 @@ public void testVectorTypes12() { HatAsserts.assertEquals(arrayA.array(i), arrayB.array(i), 0.001f); } } + + @HatTest + public void testVectorTypes13() { + // Test the CPU implementation of Float4 + Float4 vA = Float4.of(1, 2, 3, 4); + Float4 vB = Float4.of(4, 3, 2, 1); + Float4 vC = Float4.add(vA, vB); + Float4 expectedSum = Float4.of(vA.x() + vB.x(), + vA.y() + vB.y(), + vA.z() + vB.z(), + vA.w() + vB.w() + ); + HatAsserts.assertEquals(expectedSum, vC, 0.001f); + + Float4 vD = Float4.sub(vA, vB); + Float4 expectedSub = Float4.of( + vA.x() - vB.x(), + vA.y() - vB.y(), + vA.z() - vB.z(), + vA.w() - vB.w() + ); + HatAsserts.assertEquals(expectedSub, vD, 0.001f); + + Float4 vE = Float4.mul(vA, vB); + Float4 expectedMul = Float4.of( + vA.x() * vB.x(), + vA.y() * vB.y(), + vA.z() * vB.z(), + vA.w() * vB.w() + ); + HatAsserts.assertEquals(expectedMul, vE, 0.001f); + + Float4 vF = Float4.div(vA, vB); + Float4 expectedDiv = Float4.of( + vA.x() / vB.x(), + vA.y() / vB.y(), + vA.z() / vB.z(), + vA.w() / vB.w() + ); + HatAsserts.assertEquals(expectedDiv, vF, 0.001f); + } } + diff --git a/hat/tests/src/main/java/hat/test/engine/HatAsserts.java b/hat/tests/src/main/java/hat/test/engine/HatAsserts.java index 853c9f2388b..cee0b7ccdb9 100644 --- a/hat/tests/src/main/java/hat/test/engine/HatAsserts.java +++ b/hat/tests/src/main/java/hat/test/engine/HatAsserts.java @@ -24,6 +24,8 @@ */ package hat.test.engine; +import hat.buffer.Float4; + public class HatAsserts { public static void assertEquals(int expected, int actual) { @@ -50,6 +52,18 @@ public static void assertEquals(double expected, double actual, double delta) { } } + public static void assertEquals(Float4 expected, Float4 actual, float delta) { + float[] arrayExpected = expected.toArray(); + float[] arrayActual = actual.toArray(); + for (int i = 0; i < 4; i++) { + var expectedValue = arrayExpected[i]; + var actualValue = arrayActual[i]; + if (Math.abs(expectedValue - actualValue) > delta) { + throw new HatAssertionError("Expected: " + expectedValue + " != actual: " + actualValue); + } + } + } + public static void assertTrue(boolean isCorrect) { if (!isCorrect) { throw new HatAssertionError("Expected: " + isCorrect); diff --git a/hat/tools/src/main/java/hat/tools/text/JavaHATCodeBuilder.java b/hat/tools/src/main/java/hat/tools/text/JavaHATCodeBuilder.java index 991814b17a6..3c7c03b4071 100644 --- a/hat/tools/src/main/java/hat/tools/text/JavaHATCodeBuilder.java +++ b/hat/tools/src/main/java/hat/tools/text/JavaHATCodeBuilder.java @@ -28,6 +28,8 @@ import hat.codebuilders.HATCodeBuilderWithContext; import hat.dialect.HATBlockThreadIdOp; import hat.dialect.HATF16ConvOp; +import hat.dialect.HATVectorOfOp; +import hat.dialect.HATVectorOp; import hat.dialect.HATVectorSelectLoadOp; import hat.dialect.HATVectorSelectStoreOp; import hat.dialect.HATF16BinaryOp; @@ -203,6 +205,12 @@ public T hatF16ConvOp(ScopedCodeBuilderContext buildContext, HATF16ConvOp hatF16 return self(); } + @Override + public T hatVectorOfOps(ScopedCodeBuilderContext buildContext, HATVectorOfOp hatVectorOp) { + blockComment("Vector Of Ops Not Implemented"); + return self(); + } + public T createJava(ScopedCodeBuilderContext buildContext) { buildContext.funcScope(buildContext.funcOp, () -> { typeName(buildContext.funcOp.resultType().toString()).space().funcName(buildContext.funcOp); From c822d54a1a0f616c7835105fe93db7e3c8a30d3b Mon Sep 17 00:00:00 2001 From: Juan Fumero Date: Fri, 24 Oct 2025 17:22:53 +0200 Subject: [PATCH 02/14] [hat] fix dialectify Float.of phase --- .../java/hat/backend/ffi/OpenCLHATKernelBuilder.java | 1 - .../java/hat/phases/HATDialectifyVectorOpPhase.java | 10 +++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java index 7c975ef13fc..0e2ad5d52c9 100644 --- a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java +++ b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java @@ -28,7 +28,6 @@ import hat.codebuilders.CodeBuilder; import hat.codebuilders.ScopedCodeBuilderContext; import hat.dialect.HATF16ConvOp; -import hat.dialect.HATVectorOp; import hat.dialect.HATVectorSelectLoadOp; import hat.dialect.HATVectorSelectStoreOp; import hat.dialect.HATVectorBinaryOp; diff --git a/hat/core/src/main/java/hat/phases/HATDialectifyVectorOpPhase.java b/hat/core/src/main/java/hat/phases/HATDialectifyVectorOpPhase.java index 7a1caf8d7fe..f1d9924249a 100644 --- a/hat/core/src/main/java/hat/phases/HATDialectifyVectorOpPhase.java +++ b/hat/core/src/main/java/hat/phases/HATDialectifyVectorOpPhase.java @@ -210,11 +210,11 @@ private CoreOp.FuncOp dialectifyVectorOf(CoreOp.FuncOp funcOp) { if (codeElement instanceof JavaOp.InvokeOp invokeOp) { if (isVectorOperation(invokeOp)) { consumer.accept(invokeOp); - } - Set uses = invokeOp.result().uses(); - for (Op.Result result : uses) { - if (result.op() instanceof CoreOp.VarOp varOp) { - consumer.accept(varOp); + Set uses = invokeOp.result().uses(); + for (Op.Result result : uses) { + if (result.op() instanceof CoreOp.VarOp varOp) { + consumer.accept(varOp); + } } } } From 087083324c4dbffedfa32602925cc53ef4f8208b Mon Sep 17 00:00:00 2001 From: Juan Fumero Date: Fri, 24 Oct 2025 17:35:23 +0200 Subject: [PATCH 03/14] [hat] CUDA backend Float4.of supported --- .../hat/backend/ffi/CudaHATKernelBuilder.java | 30 +++++++++++++++++-- .../backend/ffi/OpenCLHATKernelBuilder.java | 30 +++++++++++++++++-- .../jextracted/OpenCLHatKernelBuilder.java | 29 ++++++++++++++++-- .../hat/codebuilders/C99HATKernelBuilder.java | 22 -------------- 4 files changed, 83 insertions(+), 28 deletions(-) diff --git a/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java b/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java index 8e665abe807..b1b21d03c4e 100644 --- a/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java +++ b/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java @@ -28,15 +28,18 @@ import hat.codebuilders.CodeBuilder; import hat.codebuilders.ScopedCodeBuilderContext; import hat.dialect.HATF16ConvOp; -import hat.dialect.HATVectorSelectLoadOp; -import hat.dialect.HATVectorSelectStoreOp; import hat.dialect.HATVectorBinaryOp; import hat.dialect.HATVectorLoadOp; +import hat.dialect.HATVectorOfOp; +import hat.dialect.HATVectorSelectLoadOp; +import hat.dialect.HATVectorSelectStoreOp; import hat.dialect.HATVectorStoreView; import hat.dialect.HATVectorVarOp; import jdk.incubator.code.Op; import jdk.incubator.code.Value; +import java.util.List; + public class CudaHATKernelBuilder extends C99HATKernelBuilder { private CudaHATKernelBuilder threadDimId(int id) { @@ -250,4 +253,27 @@ public CudaHATKernelBuilder hatVectorVarOp(ScopedCodeBuilderContext buildContext } return self(); } + + @Override + public CudaHATKernelBuilder hatVectorOfOps(ScopedCodeBuilderContext buildContext, HATVectorOfOp hatVectorOp) { + identifier("make_" + hatVectorOp.buildType()).oparen(); + + List inputOperands = hatVectorOp.operands(); + int i; + for (i = 0; i < (inputOperands.size() - 1); i++) { + var operand = inputOperands.get(i); + if ((operand instanceof Op.Result r)) { + recurse(buildContext, r.op()); + } + comma().space(); + } + // Last parameter + var operand = inputOperands.get(i); + if ((operand instanceof Op.Result r)) { + recurse(buildContext, r.op()); + } + cparen(); + return self(); + } + } diff --git a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java index 0e2ad5d52c9..964dc157cde 100644 --- a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java +++ b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java @@ -28,15 +28,18 @@ import hat.codebuilders.CodeBuilder; import hat.codebuilders.ScopedCodeBuilderContext; import hat.dialect.HATF16ConvOp; -import hat.dialect.HATVectorSelectLoadOp; -import hat.dialect.HATVectorSelectStoreOp; import hat.dialect.HATVectorBinaryOp; import hat.dialect.HATVectorLoadOp; +import hat.dialect.HATVectorOfOp; +import hat.dialect.HATVectorSelectLoadOp; +import hat.dialect.HATVectorSelectStoreOp; import hat.dialect.HATVectorStoreView; import hat.dialect.HATVectorVarOp; import jdk.incubator.code.Op; import jdk.incubator.code.Value; +import java.util.List; + public class OpenCLHATKernelBuilder extends C99HATKernelBuilder { @Override @@ -204,4 +207,27 @@ public OpenCLHATKernelBuilder hatVectorVarOp(ScopedCodeBuilderContext buildConte return self(); } + @Override + public OpenCLHATKernelBuilder hatVectorOfOps(ScopedCodeBuilderContext buildContext, HATVectorOfOp hatVectorOp) { + oparen().identifier(hatVectorOp.buildType()).cparen().oparen(); + + List inputOperands = hatVectorOp.operands(); + int i; + for (i = 0; i < (inputOperands.size() - 1); i++) { + var operand = inputOperands.get(i); + if ((operand instanceof Op.Result r)) { + recurse(buildContext, r.op()); + } + comma().space(); + } + // Last parameter + var operand = inputOperands.get(i); + if ((operand instanceof Op.Result r)) { + recurse(buildContext, r.op()); + } + cparen(); + return self(); + } + + } diff --git a/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java b/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java index a33fc95417a..d17b313ee99 100644 --- a/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java +++ b/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java @@ -28,15 +28,18 @@ import hat.codebuilders.CodeBuilder; import hat.codebuilders.ScopedCodeBuilderContext; import hat.dialect.HATF16ConvOp; -import hat.dialect.HATVectorSelectLoadOp; -import hat.dialect.HATVectorSelectStoreOp; import hat.dialect.HATVectorBinaryOp; import hat.dialect.HATVectorLoadOp; +import hat.dialect.HATVectorOfOp; +import hat.dialect.HATVectorSelectLoadOp; +import hat.dialect.HATVectorSelectStoreOp; import hat.dialect.HATVectorStoreView; import hat.dialect.HATVectorVarOp; import jdk.incubator.code.Op; import jdk.incubator.code.Value; +import java.util.List; + public class OpenCLHatKernelBuilder extends C99HATKernelBuilder { @Override @@ -204,4 +207,26 @@ public OpenCLHatKernelBuilder hatVectorVarOp(ScopedCodeBuilderContext buildConte return self(); } + @Override + public OpenCLHatKernelBuilder hatVectorOfOps(ScopedCodeBuilderContext buildContext, HATVectorOfOp hatVectorOp) { + oparen().identifier(hatVectorOp.buildType()).cparen().oparen(); + + List inputOperands = hatVectorOp.operands(); + int i; + for (i = 0; i < (inputOperands.size() - 1); i++) { + var operand = inputOperands.get(i); + if ((operand instanceof Op.Result r)) { + recurse(buildContext, r.op()); + } + comma().space(); + } + // Last parameter + var operand = inputOperands.get(i); + if ((operand instanceof Op.Result r)) { + recurse(buildContext, r.op()); + } + cparen(); + return self(); + } + } diff --git a/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java b/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java index 0a82d1dae99..6b1c7f0c141 100644 --- a/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java +++ b/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java @@ -269,28 +269,6 @@ public T hatF16BinaryOp(ScopedCodeBuilderContext buildContext, HATF16BinaryOp ha return self(); } - @Override - public T hatVectorOfOps(ScopedCodeBuilderContext buildContext, HATVectorOfOp hatVectorOp) { - oparen().identifier(hatVectorOp.buildType()).cparen().oparen(); - - List inputOperands = hatVectorOp.operands(); - int i; - for (i = 0; i < (inputOperands.size() - 1); i++) { - var operand = inputOperands.get(i); - if ((operand instanceof Op.Result r)) { - recurse(buildContext, r.op()); - } - comma().space(); - } - // Last parameter - var operand = inputOperands.get(i); - if ((operand instanceof Op.Result r)) { - recurse(buildContext, r.op()); - } - cparen(); - return self(); - } - @Override public T hatF16VarLoadOp(ScopedCodeBuilderContext buildContext, HATF16VarLoadOp hatF16VarLoadOp) { identifier(hatF16VarLoadOp.varName()); From ba1005b7c48cb26314d25854857737cbd7fc3bb0 Mon Sep 17 00:00:00 2001 From: Juan Fumero Date: Mon, 27 Oct 2025 09:24:00 +0100 Subject: [PATCH 04/14] [hat] Float4 views docs --- hat/core/src/main/java/hat/dialect/HATF16BinaryOp.java | 2 -- hat/tests/src/main/java/hat/test/TestVectorTypes.java | 9 +++++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/hat/core/src/main/java/hat/dialect/HATF16BinaryOp.java b/hat/core/src/main/java/hat/dialect/HATF16BinaryOp.java index cae5ae112e2..637a4d9f137 100644 --- a/hat/core/src/main/java/hat/dialect/HATF16BinaryOp.java +++ b/hat/core/src/main/java/hat/dialect/HATF16BinaryOp.java @@ -25,8 +25,6 @@ package hat.dialect; import jdk.incubator.code.CopyContext; -import jdk.incubator.code.Op; -import jdk.incubator.code.OpTransformer; import jdk.incubator.code.TypeElement; import jdk.incubator.code.Value; diff --git a/hat/tests/src/main/java/hat/test/TestVectorTypes.java b/hat/tests/src/main/java/hat/test/TestVectorTypes.java index e2a68db9c4c..2bb59be137f 100644 --- a/hat/tests/src/main/java/hat/test/TestVectorTypes.java +++ b/hat/tests/src/main/java/hat/test/TestVectorTypes.java @@ -67,12 +67,21 @@ public static void vectorOps02(@RO KernelContext kernelContext, @RO F32ArrayPadd public static void vectorOps03(@RO KernelContext kernelContext, @RO F32ArrayPadded a, @RW F32ArrayPadded b) { if (kernelContext.gix < kernelContext.gsx) { int index = kernelContext.gix; + + // Obtain a view of the input data as a float4 and + // store that view in private memory Float4 vA = a.float4View(index * 4); + + // operate with the float4 float scaleX = vA.x() * 10.0f; float scaleY = vA.y() * 20.0f; float scaleZ = vA.z() * 30.0f; float scaleW = vA.w() * 40.0f; + + // Create a float4 within the device code Float4 vResult = Float4.of(scaleX, scaleY, scaleZ, scaleW); + + // store the float4 from private memory to global memory b.storeFloat4View(vResult, index * 4); } } From 924f142ce0e1265090becc0355be9e9d707851de Mon Sep 17 00:00:00 2001 From: Juan Fumero Date: Mon, 27 Oct 2025 09:41:20 +0100 Subject: [PATCH 05/14] [hat][wip] float4 implementation with mutable/immutable variants --- .../main/java/hat/buffer/F32ArrayPadded.java | 2 +- hat/core/src/main/java/hat/buffer/Float4.java | 29 ++++++++++--------- .../main/java/hat/test/TestVectorTypes.java | 4 +-- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/hat/core/src/main/java/hat/buffer/F32ArrayPadded.java b/hat/core/src/main/java/hat/buffer/F32ArrayPadded.java index 87f7367420f..c793efac1fc 100644 --- a/hat/core/src/main/java/hat/buffer/F32ArrayPadded.java +++ b/hat/core/src/main/java/hat/buffer/F32ArrayPadded.java @@ -67,7 +67,7 @@ default float[] arrayView() { // This is an intrinsic for HAT to create views. It does not execute code // on the host side, at least for now. - default Float4 float4View(int index) { + default Float4.MutableImpl float4View(int index) { return null; } diff --git a/hat/core/src/main/java/hat/buffer/Float4.java b/hat/core/src/main/java/hat/buffer/Float4.java index 8cd32dd35c8..25546bca476 100644 --- a/hat/core/src/main/java/hat/buffer/Float4.java +++ b/hat/core/src/main/java/hat/buffer/Float4.java @@ -32,27 +32,29 @@ public interface Float4 extends HatVector { float y(); float z(); float w(); - void x(float x); - void y(float y); - void z(float z); - void w(float w); - record Float4Impl(float x, float y, float z, float w) implements Float4 { - @Override + record MutableImpl(float x, float y, float z, float w) implements Float4 { public void x(float x) {} - - @Override public void y(float y) {} - - @Override public void z(float z) {} - - @Override public void w(float w) {} } + record ImmutableImpl(float x, float y, float z, float w) implements Float4 { + } + + /** + * Make a Mutable implementation (for the device side - e.g., the GPU) from an immutable implementation. + * + * @param float4Immutable + * @return {@link Float4.MutableImpl} + */ + static Float4.MutableImpl makeMutable(ImmutableImpl float4Immutable) { + return new MutableImpl(float4Immutable.x(), float4Immutable.y(), float4Immutable.z(), float4Immutable.w()); + } + static Float4 of(float x, float y, float z, float w) { - return new Float4Impl(x, y, z, w); + return new ImmutableImpl(x, y, z, w); } default Float4 lanewise(Float4 other, BiFunction f) { @@ -98,6 +100,7 @@ default Float4 div(Float4 vb) { return Float4.div(this, vb); } + // Not implemented for the GPU yet. default float[] toArray() { return new float[] { x(), y(), z(), w() }; } diff --git a/hat/tests/src/main/java/hat/test/TestVectorTypes.java b/hat/tests/src/main/java/hat/test/TestVectorTypes.java index 2bb59be137f..41e2ca27427 100644 --- a/hat/tests/src/main/java/hat/test/TestVectorTypes.java +++ b/hat/tests/src/main/java/hat/test/TestVectorTypes.java @@ -56,7 +56,7 @@ public static void vectorOps01(@RO KernelContext kernelContext, @RO F32ArrayPadd public static void vectorOps02(@RO KernelContext kernelContext, @RO F32ArrayPadded a, @RW F32ArrayPadded b) { if (kernelContext.gix < kernelContext.gsx) { int index = kernelContext.gix; - Float4 vA = a.float4View(index * 4); + Float4.MutableImpl vA = a.float4View(index * 4); float scaleX = vA.x() * 10.0f; vA.x(scaleX); b.storeFloat4View(vA, index * 4); @@ -90,7 +90,7 @@ public static void vectorOps03(@RO KernelContext kernelContext, @RO F32ArrayPadd public static void vectorOps04(@RO KernelContext kernelContext, @RO F32ArrayPadded a, @RW F32ArrayPadded b) { if (kernelContext.gix < kernelContext.gsx) { int index = kernelContext.gix; - Float4 vA = a.float4View(index * 4); + Float4.MutableImpl vA = a.float4View(index * 4); vA.x(vA.x() * 10.0f); vA.y(vA.y() * 20.0f); vA.z(vA.z() * 30.0f); From 63970115bcb75515db4c4ab63151a3cf02b0065e Mon Sep 17 00:00:00 2001 From: Juan Fumero Date: Mon, 27 Oct 2025 10:36:38 +0100 Subject: [PATCH 06/14] [hat] makeFloat4 from immutable supported on GPU --- .../main/java/hat/buffer/F32ArrayPadded.java | 15 ++-- hat/core/src/main/java/hat/buffer/Float4.java | 20 +++-- .../hat/codebuilders/BabylonOpBuilder.java | 4 + .../hat/codebuilders/C99HATKernelBuilder.java | 6 ++ .../java/hat/dialect/HATVectorMakeOfOp.java | 76 +++++++++++++++++++ .../java/hat/phases/HATDialectifyTier.java | 1 + .../phases/HATDialectifyVectorOpPhase.java | 62 ++++++++++++++- .../main/java/hat/test/TestVectorTypes.java | 37 +++++++++ .../hat/tools/text/JavaHATCodeBuilder.java | 7 ++ 9 files changed, 211 insertions(+), 17 deletions(-) create mode 100644 hat/core/src/main/java/hat/dialect/HATVectorMakeOfOp.java diff --git a/hat/core/src/main/java/hat/buffer/F32ArrayPadded.java b/hat/core/src/main/java/hat/buffer/F32ArrayPadded.java index c793efac1fc..9c7b1d39a38 100644 --- a/hat/core/src/main/java/hat/buffer/F32ArrayPadded.java +++ b/hat/core/src/main/java/hat/buffer/F32ArrayPadded.java @@ -65,13 +65,18 @@ default float[] arrayView() { return arr; } - // This is an intrinsic for HAT to create views. It does not execute code - // on the host side, at least for now. default Float4.MutableImpl float4View(int index) { - return null; +// MemorySegment memorySegment = Buffer.getMemorySegment(this); +// float f1 = memorySegment.get(JAVA_FLOAT, ARRAY_OFFSET + index + 0); +// float f2 = memorySegment.get(JAVA_FLOAT, ARRAY_OFFSET + index + 1); +// float f3 = memorySegment.get(JAVA_FLOAT, ARRAY_OFFSET + index + 2); +// float f4 = memorySegment.get(JAVA_FLOAT, ARRAY_OFFSET + index + 3); +// return Float4.makeMutable(Float4.of(f1, f2, f3, f4)); + return null; } - // This is an intrinsic for HAT to store views. It does not execute code - default void storeFloat4View(Float4 v, int index) {} + default void storeFloat4View(Float4 v, int index) { +// MemorySegment.copy(Buffer.getMemorySegment(this), JAVA_FLOAT, ARRAY_OFFSET, v.toArray(), index, 4); + } } diff --git a/hat/core/src/main/java/hat/buffer/Float4.java b/hat/core/src/main/java/hat/buffer/Float4.java index 25546bca476..ffca332066e 100644 --- a/hat/core/src/main/java/hat/buffer/Float4.java +++ b/hat/core/src/main/java/hat/buffer/Float4.java @@ -24,7 +24,10 @@ */ package hat.buffer; +import jdk.incubator.code.CodeReflection; + import java.util.function.BiFunction; +import java.util.stream.IntStream; public interface Float4 extends HatVector { @@ -46,25 +49,25 @@ record ImmutableImpl(float x, float y, float z, float w) implements Float4 { /** * Make a Mutable implementation (for the device side - e.g., the GPU) from an immutable implementation. * - * @param float4Immutable + * @param float4 * @return {@link Float4.MutableImpl} */ - static Float4.MutableImpl makeMutable(ImmutableImpl float4Immutable) { - return new MutableImpl(float4Immutable.x(), float4Immutable.y(), float4Immutable.z(), float4Immutable.w()); + static Float4.MutableImpl makeMutable(Float4 float4) { + return new MutableImpl(float4.x(), float4.y(), float4.z(), float4.w()); } static Float4 of(float x, float y, float z, float w) { return new ImmutableImpl(x, y, z, w); } + // Not implemented for the GPU yet default Float4 lanewise(Float4 other, BiFunction f) { float[] backA = this.toArray(); float[] backB = other.toArray(); float[] backC = new float[backA.length]; - for (int j = 0; j < backA.length; j++) { - var r = f.apply(backA[j], backB[j]); - backC[j] = r; - } + IntStream.range(0, backA.length).forEach(j -> { + backC[j] = f.apply(backA[j], backB[j]); + }); return of(backC[0], backC[1], backC[2], backC[3]); } @@ -100,7 +103,8 @@ default Float4 div(Float4 vb) { return Float4.div(this, vb); } - // Not implemented for the GPU yet. + // Not implemented for the GPU yet + @CodeReflection default float[] toArray() { return new float[] { x(), y(), z(), w() }; } diff --git a/hat/core/src/main/java/hat/codebuilders/BabylonOpBuilder.java b/hat/core/src/main/java/hat/codebuilders/BabylonOpBuilder.java index 393375f00f3..32213febee8 100644 --- a/hat/core/src/main/java/hat/codebuilders/BabylonOpBuilder.java +++ b/hat/core/src/main/java/hat/codebuilders/BabylonOpBuilder.java @@ -30,6 +30,7 @@ import hat.dialect.HATF16ConvOp; import hat.dialect.HATF16VarLoadOp; import hat.dialect.HATF16VarOp; +import hat.dialect.HATVectorMakeOfOp; import hat.dialect.HATVectorOfOp; import hat.dialect.HATVectorSelectLoadOp; import hat.dialect.HATVectorSelectStoreOp; @@ -144,6 +145,8 @@ public interface BabylonOpBuilder> { T hatVectorOfOps(ScopedCodeBuilderContext buildContext, HATVectorOfOp hatVectorOp); + T hatVectorMakeOf(ScopedCodeBuilderContext builderContext, HATVectorMakeOfOp hatVectorMakeOfOp); + default T recurse(ScopedCodeBuilderContext buildContext, Op op) { switch (op) { case CoreOp.VarAccessOp.VarLoadOp $ -> varLoadOp(buildContext, $); @@ -191,6 +194,7 @@ default T recurse(ScopedCodeBuilderContext buildContext, Op op) { case HATF16BinaryOp $ -> hatF16BinaryOp(buildContext, $); case HATF16VarLoadOp $ -> hatF16VarLoadOp(buildContext, $); case HATF16ConvOp $ -> hatF16ConvOp(buildContext, $); + case HATVectorMakeOfOp $ -> hatVectorMakeOf(buildContext, $); default -> throw new IllegalStateException("handle nesting of op " + op); } return (T) this; diff --git a/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java b/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java index 6b1c7f0c141..c501df371d7 100644 --- a/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java +++ b/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java @@ -33,6 +33,7 @@ import hat.dialect.HATGlobalThreadIdOp; import hat.dialect.HATLocalSizeOp; import hat.dialect.HATLocalThreadIdOp; +import hat.dialect.HATVectorMakeOfOp; import hat.dialect.HATVectorOfOp; import hat.dialect.HATVectorVarLoadOp; import hat.ifacemapper.MappableIface; @@ -275,6 +276,11 @@ public T hatF16VarLoadOp(ScopedCodeBuilderContext buildContext, HATF16VarLoadOp return self(); } + @Override + public T hatVectorMakeOf(ScopedCodeBuilderContext builderContext, HATVectorMakeOfOp hatVectorMakeOfOp) { + identifier(hatVectorMakeOfOp.varName()); + return self(); + } public T kernelDeclaration(CoreOp.FuncOp funcOp) { return kernelPrefix().voidType().space().funcName(funcOp); diff --git a/hat/core/src/main/java/hat/dialect/HATVectorMakeOfOp.java b/hat/core/src/main/java/hat/dialect/HATVectorMakeOfOp.java new file mode 100644 index 00000000000..3905aec61f4 --- /dev/null +++ b/hat/core/src/main/java/hat/dialect/HATVectorMakeOfOp.java @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package hat.dialect; + +import jdk.incubator.code.CopyContext; +import jdk.incubator.code.Op; +import jdk.incubator.code.OpTransformer; +import jdk.incubator.code.TypeElement; +import jdk.incubator.code.Value; + +import java.util.List; +import java.util.Map; + +public class HATVectorMakeOfOp extends HATVectorOp { + + private final TypeElement typeElement; + private final int loadN; + + public HATVectorMakeOfOp(String varName, TypeElement typeElement, int loadN, List operands) { + super(varName, operands); + this.typeElement = typeElement; + this.loadN = loadN; + } + + public HATVectorMakeOfOp(HATVectorMakeOfOp op, CopyContext copyContext) { + super(op, copyContext); + this.typeElement = op.typeElement; + this.loadN = op.loadN; + } + + @Override + public Op transform(CopyContext copyContext, OpTransformer opTransformer) { + return new HATVectorMakeOfOp(this, copyContext); + } + + @Override + public TypeElement resultType() { + return typeElement; + } + + @Override + public Map externalize() { + return Map.of("hat.dialect.makeOf." + varName(), typeElement); + } + + public String buildType() { + // floatN + if (typeElement.toString().startsWith("hat.buffer.Float")) { + return "float" + loadN; + } + throw new RuntimeException("Unexpected vector type " + typeElement); + } + +} diff --git a/hat/core/src/main/java/hat/phases/HATDialectifyTier.java b/hat/core/src/main/java/hat/phases/HATDialectifyTier.java index fdde6854070..e6e86c99863 100644 --- a/hat/core/src/main/java/hat/phases/HATDialectifyTier.java +++ b/hat/core/src/main/java/hat/phases/HATDialectifyTier.java @@ -58,6 +58,7 @@ public HATDialectifyTier(Accelerator accelerator) { hatPhases.add(new HATDialectifyVectorOpPhase.SubPhase(accelerator)); hatPhases.add(new HATDialectifyVectorOpPhase.MulPhase(accelerator)); hatPhases.add(new HATDialectifyVectorOpPhase.DivPhase(accelerator)); + hatPhases.add(new HATDialectifyVectorOpPhase.MakeMutable(accelerator)); hatPhases.add(new HATDialectifyVectorStorePhase.Float4StorePhase(accelerator)); // Vector Select individual lines diff --git a/hat/core/src/main/java/hat/phases/HATDialectifyVectorOpPhase.java b/hat/core/src/main/java/hat/phases/HATDialectifyVectorOpPhase.java index f1d9924249a..6b53aeb3484 100644 --- a/hat/core/src/main/java/hat/phases/HATDialectifyVectorOpPhase.java +++ b/hat/core/src/main/java/hat/phases/HATDialectifyVectorOpPhase.java @@ -30,6 +30,7 @@ import hat.dialect.HATVectorAddOp; import hat.dialect.HATVectorDivOp; import hat.dialect.HATVectorLoadOp; +import hat.dialect.HATVectorMakeOfOp; import hat.dialect.HATVectorMulOp; import hat.dialect.HATVectorOfOp; import hat.dialect.HATVectorSubOp; @@ -85,7 +86,8 @@ public enum OpView { ADD("add"), SUB("sub"), MUL("mul"), - DIV("div"); + DIV("div"), + MAKE_MUTABLE("makeMutable"); final String methodName; OpView(String methodName) { this.methodName = methodName; @@ -275,9 +277,6 @@ private CoreOp.FuncOp dialectifyVectorBinaryOps(CoreOp.FuncOp funcOp) { funcOp = OpTk.transform(here, funcOp, nodesInvolved::contains, (blockBuilder, op) -> { CopyContext context = blockBuilder.context(); - // if (!nodesInvolved.contains(op)) { - // blockBuilder.op(op); - //} else if (op instanceof JavaOp.InvokeOp invokeOp) { Op.Result result = invokeOp.result(); List inputOperands = invokeOp.operands(); @@ -307,6 +306,52 @@ private CoreOp.FuncOp dialectifyVectorBinaryOps(CoreOp.FuncOp funcOp) { return funcOp; } + private CoreOp.FuncOp dialectifyMutableOf(CoreOp.FuncOp funcOp) { + var here = OpTk.CallSite.of(this.getClass(), "dialectifyMutableOf" ); + before(here,funcOp); + Stream> float4NodesInvolved = funcOp.elements() + .mapMulti((codeElement, consumer) -> { + if (codeElement instanceof JavaOp.InvokeOp invokeOp) { + if (isVectorOperation(invokeOp)) { + consumer.accept(invokeOp); + Set uses = invokeOp.result().uses(); + for (Op.Result result : uses) { + if (result.op() instanceof CoreOp.VarOp varOp) { + consumer.accept(varOp); + } + } + } + } + }); + + Set> nodesInvolved = float4NodesInvolved.collect(Collectors.toSet()); + + funcOp = OpTk.transform(here, funcOp,(blockBuilder, op) -> { + CopyContext context = blockBuilder.context(); + if (!nodesInvolved.contains(op)) { + blockBuilder.op(op); + } else if (op instanceof JavaOp.InvokeOp invokeOp) { + List inputOperandsVarOp = invokeOp.operands(); + List outputOperandsVarOp = context.getValues(inputOperandsVarOp); + String varName = findNameVector(invokeOp.operands().getFirst()); + HATVectorMakeOfOp makeOf = new HATVectorMakeOfOp(varName, invokeOp.resultType(), 4, outputOperandsVarOp); + Op.Result hatLocalResult = blockBuilder.op(makeOf); + makeOf.setLocation(invokeOp.location()); + context.mapValue(invokeOp.result(), hatLocalResult); + } else if (op instanceof CoreOp.VarOp varOp) { + List inputOperandsVarOp = varOp.operands(); + List outputOperandsVarOp = context.getValues(inputOperandsVarOp); + HATVectorOp memoryViewOp = new HATVectorVarOp(varOp.varName(), varOp.resultType(), 4, outputOperandsVarOp); + Op.Result hatLocalResult = blockBuilder.op(memoryViewOp); + memoryViewOp.setLocation(varOp.location()); + context.mapValue(varOp.result(), hatLocalResult); + } + return blockBuilder; + }); + after(here, funcOp); + return funcOp; + } + private CoreOp.FuncOp dialectifyVectorBinaryWithConcatenationOps(CoreOp.FuncOp funcOp) { var here = OpTk.CallSite.of(this.getClass(), "dialectifyBinaryWithConcatenation"); before(here, funcOp); @@ -371,6 +416,8 @@ public CoreOp.FuncOp apply(CoreOp.FuncOp funcOp) { funcOp = dialectifyVectorLoad(funcOp); } else if (Objects.requireNonNull(vectorOperation) == OpView.OF) { funcOp = dialectifyVectorOf(funcOp); + } else if (Objects.requireNonNull(vectorOperation) == OpView.MAKE_MUTABLE) { + funcOp = dialectifyMutableOf(funcOp); } else { // Find binary operations funcOp = dialectifyVectorBinaryOps(funcOp); @@ -393,6 +440,13 @@ public DivPhase(Accelerator accelerator) { } } + public static class MakeMutable extends HATDialectifyVectorOpPhase{ + + public MakeMutable(Accelerator accelerator) { + super(accelerator, OpView.MAKE_MUTABLE); + } + } + public static class Float4LoadPhase extends HATDialectifyVectorOpPhase { public Float4LoadPhase(Accelerator accelerator) { diff --git a/hat/tests/src/main/java/hat/test/TestVectorTypes.java b/hat/tests/src/main/java/hat/test/TestVectorTypes.java index 41e2ca27427..a51e6d13ba7 100644 --- a/hat/tests/src/main/java/hat/test/TestVectorTypes.java +++ b/hat/tests/src/main/java/hat/test/TestVectorTypes.java @@ -236,6 +236,17 @@ public static void vectorOps12(@RO KernelContext kernelContext, @RO F32ArrayPadd } } + @CodeReflection + public static void vectorOps14(@RO KernelContext kernelContext, @RW F32ArrayPadded a) { + if (kernelContext.gix < kernelContext.gsx) { + int index = kernelContext.gix; + Float4 vA = a.float4View(index * 4); + Float4.MutableImpl vB = Float4.makeMutable(vA); + vB.x(10.0f); + a.storeFloat4View(vB, index * 4); + } + } + @CodeReflection public static void computeGraph01(@RO ComputeContext cc, @RO F32ArrayPadded a, @RO F32ArrayPadded b, @RW F32ArrayPadded c, int size) { // Note: we need to launch N threads / vectorWidth -> size / 4 for this example @@ -321,6 +332,13 @@ public static void computeGraph12(@RO ComputeContext cc, @RO F32ArrayPadded a, cc.dispatchKernel(computeRange, kernelContext -> TestVectorTypes.vectorOps12(kernelContext, a, b)); } + @CodeReflection + public static void computeGraph14(@RO ComputeContext cc, @RW F32ArrayPadded a, int size) { + // Note: we need to launch N threads / vectorWidth -> size / 4 for this example + ComputeRange computeRange = new ComputeRange(new GlobalMesh1D(size/4)); + cc.dispatchKernel(computeRange, kernelContext -> TestVectorTypes.vectorOps14(kernelContext, a)); + } + @HatTest public void testVectorTypes01() { final int size = 1024; @@ -616,5 +634,24 @@ public void testVectorTypes13() { ); HatAsserts.assertEquals(expectedDiv, vF, 0.001f); } + + @HatTest + public void testVectorTypes14() { + final int size = 1024; + var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST); + var arrayA = F32ArrayPadded.create(accelerator, size); + + Random r = new Random(73); + for (int i = 0; i < size; i++) { + arrayA.array(i, r.nextFloat()); + } + + accelerator.compute(cc -> TestVectorTypes.computeGraph14(cc, arrayA, size)); + + for (int i = 0; i < size; i += 4) { + HatAsserts.assertEquals(10.0f, arrayA.array(i), 0.001f); + } + } + } diff --git a/hat/tools/src/main/java/hat/tools/text/JavaHATCodeBuilder.java b/hat/tools/src/main/java/hat/tools/text/JavaHATCodeBuilder.java index 3c7c03b4071..0f8d5b3ec76 100644 --- a/hat/tools/src/main/java/hat/tools/text/JavaHATCodeBuilder.java +++ b/hat/tools/src/main/java/hat/tools/text/JavaHATCodeBuilder.java @@ -28,6 +28,7 @@ import hat.codebuilders.HATCodeBuilderWithContext; import hat.dialect.HATBlockThreadIdOp; import hat.dialect.HATF16ConvOp; +import hat.dialect.HATVectorMakeOfOp; import hat.dialect.HATVectorOfOp; import hat.dialect.HATVectorOp; import hat.dialect.HATVectorSelectLoadOp; @@ -211,6 +212,12 @@ public T hatVectorOfOps(ScopedCodeBuilderContext buildContext, HATVectorOfOp hat return self(); } + @Override + public T hatVectorMakeOf(ScopedCodeBuilderContext builderContext, HATVectorMakeOfOp hatVectorMakeOfOp) { + blockComment("Vector Make Of Op Not Implemented"); + return self(); + } + public T createJava(ScopedCodeBuilderContext buildContext) { buildContext.funcScope(buildContext.funcOp, () -> { typeName(buildContext.funcOp.resultType().toString()).space().funcName(buildContext.funcOp); From 31e0db215f12cff766b059cc750669be5adbdfad Mon Sep 17 00:00:00 2001 From: Juan Fumero Date: Mon, 27 Oct 2025 10:39:10 +0100 Subject: [PATCH 07/14] [hat] minor change --- hat/tests/src/main/java/hat/test/TestVectorTypes.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/hat/tests/src/main/java/hat/test/TestVectorTypes.java b/hat/tests/src/main/java/hat/test/TestVectorTypes.java index a51e6d13ba7..165d9da8f5f 100644 --- a/hat/tests/src/main/java/hat/test/TestVectorTypes.java +++ b/hat/tests/src/main/java/hat/test/TestVectorTypes.java @@ -24,7 +24,12 @@ */ package hat.test; -import hat.*; +import hat.Accelerator; +import hat.ComputeContext; +import hat.ComputeRange; +import hat.GlobalMesh1D; +import hat.KernelContext; +import hat.LocalMesh1D; import hat.backend.Backend; import hat.buffer.Buffer; import hat.buffer.F32ArrayPadded; @@ -652,6 +657,6 @@ public void testVectorTypes14() { HatAsserts.assertEquals(10.0f, arrayA.array(i), 0.001f); } } - + } From 2f954e0b4042f13ba0e7ea5255bbda9a5f04bd1b Mon Sep 17 00:00:00 2001 From: Juan Fumero Date: Mon, 27 Oct 2025 11:32:33 +0100 Subject: [PATCH 08/14] [hat] Move vectorOf codegen method to common C99 codegen --- .../hat/backend/ffi/CudaHATKernelBuilder.java | 21 ++-------------- .../backend/ffi/OpenCLHATKernelBuilder.java | 21 ++-------------- .../jextracted/OpenCLHatKernelBuilder.java | 20 ++-------------- .../hat/codebuilders/C99HATKernelBuilder.java | 24 +++++++++++++++++++ 4 files changed, 30 insertions(+), 56 deletions(-) diff --git a/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java b/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java index b1b21d03c4e..cbeccb87049 100644 --- a/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java +++ b/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java @@ -255,25 +255,8 @@ public CudaHATKernelBuilder hatVectorVarOp(ScopedCodeBuilderContext buildContext } @Override - public CudaHATKernelBuilder hatVectorOfOps(ScopedCodeBuilderContext buildContext, HATVectorOfOp hatVectorOp) { - identifier("make_" + hatVectorOp.buildType()).oparen(); - - List inputOperands = hatVectorOp.operands(); - int i; - for (i = 0; i < (inputOperands.size() - 1); i++) { - var operand = inputOperands.get(i); - if ((operand instanceof Op.Result r)) { - recurse(buildContext, r.op()); - } - comma().space(); - } - // Last parameter - var operand = inputOperands.get(i); - if ((operand instanceof Op.Result r)) { - recurse(buildContext, r.op()); - } - cparen(); + public CudaHATKernelBuilder genVectorIdentifier(ScopedCodeBuilderContext builderContext, HATVectorOfOp hatVectorOfOp) { + identifier("make_" + hatVectorOfOp.buildType()).oparen(); return self(); } - } diff --git a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java index 964dc157cde..c3503c57e3b 100644 --- a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java +++ b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java @@ -208,26 +208,9 @@ public OpenCLHATKernelBuilder hatVectorVarOp(ScopedCodeBuilderContext buildConte } @Override - public OpenCLHATKernelBuilder hatVectorOfOps(ScopedCodeBuilderContext buildContext, HATVectorOfOp hatVectorOp) { - oparen().identifier(hatVectorOp.buildType()).cparen().oparen(); - - List inputOperands = hatVectorOp.operands(); - int i; - for (i = 0; i < (inputOperands.size() - 1); i++) { - var operand = inputOperands.get(i); - if ((operand instanceof Op.Result r)) { - recurse(buildContext, r.op()); - } - comma().space(); - } - // Last parameter - var operand = inputOperands.get(i); - if ((operand instanceof Op.Result r)) { - recurse(buildContext, r.op()); - } - cparen(); + public OpenCLHATKernelBuilder genVectorIdentifier(ScopedCodeBuilderContext builderContext, HATVectorOfOp hatVectorOfOp) { + oparen().identifier(hatVectorOfOp.buildType()).cparen().oparen(); return self(); } - } diff --git a/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java b/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java index d17b313ee99..0442cf8ffb8 100644 --- a/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java +++ b/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java @@ -208,24 +208,8 @@ public OpenCLHatKernelBuilder hatVectorVarOp(ScopedCodeBuilderContext buildConte } @Override - public OpenCLHatKernelBuilder hatVectorOfOps(ScopedCodeBuilderContext buildContext, HATVectorOfOp hatVectorOp) { - oparen().identifier(hatVectorOp.buildType()).cparen().oparen(); - - List inputOperands = hatVectorOp.operands(); - int i; - for (i = 0; i < (inputOperands.size() - 1); i++) { - var operand = inputOperands.get(i); - if ((operand instanceof Op.Result r)) { - recurse(buildContext, r.op()); - } - comma().space(); - } - // Last parameter - var operand = inputOperands.get(i); - if ((operand instanceof Op.Result r)) { - recurse(buildContext, r.op()); - } - cparen(); + public OpenCLHatKernelBuilder genVectorIdentifier(ScopedCodeBuilderContext builderContext, HATVectorOfOp hatVectorOfOp) { + oparen().identifier(hatVectorOfOp.buildType()).cparen().oparen(); return self(); } diff --git a/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java b/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java index c501df371d7..2dd9676f930 100644 --- a/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java +++ b/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java @@ -282,6 +282,30 @@ public T hatVectorMakeOf(ScopedCodeBuilderContext builderContext, HATVectorMakeO return self(); } + public abstract T genVectorIdentifier(ScopedCodeBuilderContext builderContext, HATVectorOfOp hatVectorOfOp); + + @Override + public T hatVectorOfOps(ScopedCodeBuilderContext buildContext, HATVectorOfOp hatVectorOp) { + genVectorIdentifier(buildContext, hatVectorOp); + + List inputOperands = hatVectorOp.operands(); + int i; + for (i = 0; i < (inputOperands.size() - 1); i++) { + var operand = inputOperands.get(i); + if ((operand instanceof Op.Result r)) { + recurse(buildContext, r.op()); + } + comma().space(); + } + // Last parameter + var operand = inputOperands.get(i); + if ((operand instanceof Op.Result r)) { + recurse(buildContext, r.op()); + } + cparen(); + return self(); + } + public T kernelDeclaration(CoreOp.FuncOp funcOp) { return kernelPrefix().voidType().space().funcName(funcOp); } From e9864a7798c0418646099458d20e698717f80b40 Mon Sep 17 00:00:00 2001 From: Juan Fumero Date: Mon, 27 Oct 2025 11:36:18 +0100 Subject: [PATCH 09/14] [hat] codegen - cleanup old OpenCL typecasts --- .../src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java | 3 --- .../java/hat/backend/jextracted/OpenCLHatKernelBuilder.java | 2 -- 2 files changed, 5 deletions(-) diff --git a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java index c3503c57e3b..c3ca1c49cbb 100644 --- a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java +++ b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java @@ -184,12 +184,10 @@ public OpenCLHATKernelBuilder hatSelectStoreOp(ScopedCodeBuilderContext buildCon @Override public OpenCLHATKernelBuilder hatF16ConvOp(ScopedCodeBuilderContext buildContext, HATF16ConvOp hatF16ConvOp) { oparen().typeName("half").cparen(); - // typeName("convert_half").oparen(); Value initValue = hatF16ConvOp.operands().getFirst(); if (initValue instanceof Op.Result r) { recurse(buildContext, r.op()); } - //cparen(); return self(); } @@ -212,5 +210,4 @@ public OpenCLHATKernelBuilder genVectorIdentifier(ScopedCodeBuilderContext build oparen().identifier(hatVectorOfOp.buildType()).cparen().oparen(); return self(); } - } diff --git a/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java b/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java index 0442cf8ffb8..c20ce527281 100644 --- a/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java +++ b/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java @@ -184,12 +184,10 @@ public OpenCLHatKernelBuilder hatSelectStoreOp(ScopedCodeBuilderContext buildCon @Override public OpenCLHatKernelBuilder hatF16ConvOp(ScopedCodeBuilderContext buildContext, HATF16ConvOp hatF16ConvOp) { oparen().typeName("half").cparen(); - // typeName("convert_half").oparen(); Value initValue = hatF16ConvOp.operands().getFirst(); if (initValue instanceof Op.Result r) { recurse(buildContext, r.op()); } - //cparen(); return self(); } From fd9c6d3c1751b130bd5efbb2b0246792dedeed2b Mon Sep 17 00:00:00 2001 From: Juan Fumero Date: Mon, 27 Oct 2025 12:14:20 +0100 Subject: [PATCH 10/14] [hat] minor refactor ocl-codegen --- .../src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java | 2 -- .../main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java | 2 +- .../ffi/opencl/src/main/native/cpp/opencl_backend.cpp | 5 ++++- hat/core/src/main/java/hat/codebuilders/CodeBuilder.java | 4 ++++ 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java b/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java index cbeccb87049..2683d1ac188 100644 --- a/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java +++ b/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java @@ -38,8 +38,6 @@ import jdk.incubator.code.Op; import jdk.incubator.code.Value; -import java.util.List; - public class CudaHATKernelBuilder extends C99HATKernelBuilder { private CudaHATKernelBuilder threadDimId(int id) { diff --git a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java index c3ca1c49cbb..6df87b2e034 100644 --- a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java +++ b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java @@ -183,7 +183,7 @@ public OpenCLHATKernelBuilder hatSelectStoreOp(ScopedCodeBuilderContext buildCon @Override public OpenCLHATKernelBuilder hatF16ConvOp(ScopedCodeBuilderContext buildContext, HATF16ConvOp hatF16ConvOp) { - oparen().typeName("half").cparen(); + oparen().halfType().cparen(); Value initValue = hatF16ConvOp.operands().getFirst(); if (initValue instanceof Op.Result r) { recurse(buildContext, r.op()); diff --git a/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp b/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp index 115ebd1b3be..11dc319f179 100644 --- a/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp +++ b/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp @@ -303,7 +303,10 @@ const char *OpenCLBackend::errorMsg(cl_int status) { } extern "C" long getBackend(int configBits) { - std::cerr << "Opencl Driver =" << std::hex << configBits << std::dec << std::endl; + Backend::Config config(configBits); + if (config.info) { + std::cerr << "Opencl Driver =" << std::hex << configBits << std::dec << std::endl; + } return reinterpret_cast(new OpenCLBackend(configBits)); } diff --git a/hat/core/src/main/java/hat/codebuilders/CodeBuilder.java b/hat/core/src/main/java/hat/codebuilders/CodeBuilder.java index 4e5ebaebab0..eacddd64e4c 100644 --- a/hat/core/src/main/java/hat/codebuilders/CodeBuilder.java +++ b/hat/core/src/main/java/hat/codebuilders/CodeBuilder.java @@ -507,6 +507,9 @@ public final T shortType() { return typeName("short"); } + public final T halfType() { + return typeName("half"); + } @Override public final T comment(String text) { @@ -531,6 +534,7 @@ public T label(String text) { public final T symbol(String text) { return emitText(text); } + @Override public final T typeName(String text) { return emitText(text); From 3c45324c39c29e05ee03772287819308e67c30a7 Mon Sep 17 00:00:00 2001 From: Juan Fumero Date: Mon, 27 Oct 2025 12:35:47 +0100 Subject: [PATCH 11/14] [hat] Dialect for handling vector types simplified --- .../backend/ffi/OpenCLHATKernelBuilder.java | 4 ++-- .../jextracted/OpenCLHatKernelBuilder.java | 4 ++-- .../java/hat/dialect/HATVectorBinaryOp.java | 24 ++++--------------- .../java/hat/dialect/HATVectorLoadOp.java | 18 +------------- .../java/hat/dialect/HATVectorMakeOfOp.java | 10 +------- .../main/java/hat/dialect/HATVectorOfOp.java | 9 +------ .../main/java/hat/dialect/HATVectorOp.java | 21 +++++++++++++++- .../hat/dialect/HATVectorSelectLoadOp.java | 11 +++------ .../hat/dialect/HATVectorSelectStoreOp.java | 10 ++------ .../java/hat/dialect/HATVectorStoreView.java | 9 +------ .../java/hat/dialect/HATVectorVarLoadOp.java | 2 +- .../main/java/hat/dialect/HATVectorVarOp.java | 4 ++-- 12 files changed, 40 insertions(+), 86 deletions(-) diff --git a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java index 6df87b2e034..a52bc8eb404 100644 --- a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java +++ b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java @@ -86,7 +86,7 @@ public OpenCLHATKernelBuilder hatVectorStoreOp(ScopedCodeBuilderContext buildCon Value dest = hatVectorStoreView.operands().get(0); Value index = hatVectorStoreView.operands().get(2); - identifier("vstore" + hatVectorStoreView.storeN()) + identifier("vstore" + hatVectorStoreView.vectorN()) .oparen() .varName(hatVectorStoreView) .comma() @@ -134,7 +134,7 @@ public OpenCLHATKernelBuilder hatVectorLoadOp(ScopedCodeBuilderContext buildCont Value source = hatVectorLoadOp.operands().get(0); Value index = hatVectorLoadOp.operands().get(1); - identifier("vload" + hatVectorLoadOp.loadN()) + identifier("vload" + hatVectorLoadOp.vectorN()) .oparen() .intConstZero() .comma() diff --git a/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java b/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java index c20ce527281..fbbfc2270f2 100644 --- a/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java +++ b/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java @@ -86,7 +86,7 @@ public OpenCLHatKernelBuilder hatVectorStoreOp(ScopedCodeBuilderContext buildCon Value dest = hatVectorStoreView.operands().get(0); Value index = hatVectorStoreView.operands().get(2); - identifier("vstore" + hatVectorStoreView.storeN()) + identifier("vstore" + hatVectorStoreView.vectorN()) .oparen() .varName(hatVectorStoreView) .comma() @@ -134,7 +134,7 @@ public OpenCLHatKernelBuilder hatVectorLoadOp(ScopedCodeBuilderContext buildCont Value source = hatVectorLoadOp.operands().get(0); Value index = hatVectorLoadOp.operands().get(1); - identifier("vload" + hatVectorLoadOp.loadN()) + identifier("vload" + hatVectorLoadOp.vectorN()) .oparen() .intConstZero() .comma() diff --git a/hat/core/src/main/java/hat/dialect/HATVectorBinaryOp.java b/hat/core/src/main/java/hat/dialect/HATVectorBinaryOp.java index 3caddfa7d22..af2630e3af1 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorBinaryOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorBinaryOp.java @@ -50,21 +50,20 @@ public String symbol() { private final TypeElement elementType; private final OpType operationType; - private final int vectorN; public HATVectorBinaryOp(String varName, TypeElement typeElement, OpType operationType, List operands) { - super(varName, operands); + int l = typeElement.toString().length(); + int vectorN = Integer.parseInt(typeElement.toString().substring(l - 1, l)); + super(varName, typeElement, vectorN, operands); this.elementType = typeElement; this.operationType = operationType; - int l = typeElement.toString().length(); - vectorN = Integer.parseInt(typeElement.toString().substring(l - 1, l)); + } public HATVectorBinaryOp(HATVectorBinaryOp op, CopyContext copyContext) { super(op, copyContext); this.elementType = op.elementType; this.operationType = op.operationType; - this.vectorN = op.vectorN; } @Override @@ -72,23 +71,8 @@ public TypeElement resultType() { return this.elementType; } - // @Override - //public Map externalize() { - // return Map.of("hat.dialect.floatNOp." + varName(), elementType); - // } - public OpType operationType() { return operationType; } - public int vectorN() { - return vectorN; - } - - public String buildType() { - if (elementType.toString().startsWith("hat.buffer.Float")) { - return "float" + vectorN; - } - throw new RuntimeException("Unexpected vector type " + elementType); - } } diff --git a/hat/core/src/main/java/hat/dialect/HATVectorLoadOp.java b/hat/core/src/main/java/hat/dialect/HATVectorLoadOp.java index 89ca2ec0ee5..79e831acdf5 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorLoadOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorLoadOp.java @@ -41,7 +41,7 @@ public class HATVectorLoadOp extends HATVectorOp { private final boolean isSharedOrPrivate; public HATVectorLoadOp(String varName, TypeElement typeElement, TypeElement vectorType, int loadN, boolean isShared, List operands) { - super(varName, operands); + super(varName, typeElement, loadN, operands); this.typeElement = typeElement; this.loadN = loadN; this.vectorType = vectorType; @@ -71,22 +71,6 @@ public Map externalize() { return Map.of("hat.dialect.vectorLoadView." + varName(), typeElement); } - public TypeElement vectorType() { - return vectorType; - } - - public int loadN() { - return loadN; - } - - public String buildType() { - // floatN - if (vectorType.toString().startsWith("hat.buffer.Float")) { - return "float" + loadN; - } - throw new RuntimeException("Unexpected vector type " + vectorType); - } - public boolean isSharedOrPrivate() { return this.isSharedOrPrivate; } diff --git a/hat/core/src/main/java/hat/dialect/HATVectorMakeOfOp.java b/hat/core/src/main/java/hat/dialect/HATVectorMakeOfOp.java index 3905aec61f4..3e2902b25c0 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorMakeOfOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorMakeOfOp.java @@ -39,7 +39,7 @@ public class HATVectorMakeOfOp extends HATVectorOp { private final int loadN; public HATVectorMakeOfOp(String varName, TypeElement typeElement, int loadN, List operands) { - super(varName, operands); + super(varName, typeElement, loadN, operands); this.typeElement = typeElement; this.loadN = loadN; } @@ -65,12 +65,4 @@ public Map externalize() { return Map.of("hat.dialect.makeOf." + varName(), typeElement); } - public String buildType() { - // floatN - if (typeElement.toString().startsWith("hat.buffer.Float")) { - return "float" + loadN; - } - throw new RuntimeException("Unexpected vector type " + typeElement); - } - } diff --git a/hat/core/src/main/java/hat/dialect/HATVectorOfOp.java b/hat/core/src/main/java/hat/dialect/HATVectorOfOp.java index 1331ebf2909..442ad5a94a3 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorOfOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorOfOp.java @@ -39,7 +39,7 @@ public class HATVectorOfOp extends HATVectorOp { private final int loadN; public HATVectorOfOp(TypeElement typeElement, int loadN, List operands) { - super("", operands); + super("", typeElement, loadN, operands); this.typeElement = typeElement; this.loadN = loadN; } @@ -65,12 +65,5 @@ public Map externalize() { return Map.of("hat.dialect.vectorOf." + varName(), typeElement); } - public String buildType() { - // floatN - if (typeElement.toString().startsWith("hat.buffer.Float")) { - return "float" + loadN; - } - throw new RuntimeException("Unexpected vector type " + typeElement); - } } diff --git a/hat/core/src/main/java/hat/dialect/HATVectorOp.java b/hat/core/src/main/java/hat/dialect/HATVectorOp.java index bc270440bbe..0daea61e53a 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorOp.java @@ -25,6 +25,7 @@ package hat.dialect; import jdk.incubator.code.CopyContext; +import jdk.incubator.code.TypeElement; import jdk.incubator.code.Value; import java.util.List; @@ -32,15 +33,21 @@ public abstract class HATVectorOp extends HATOp { private String varName; + private final TypeElement typeElement; + private final int vectorN; - public HATVectorOp(String varName, List operands) { + public HATVectorOp(String varName, TypeElement typeElement, int vectorN, List operands) { super(operands); this.varName = varName; + this.typeElement = typeElement; + this.vectorN = vectorN; } protected HATVectorOp(HATVectorOp that, CopyContext cc) { super(that, cc); this.varName = that.varName; + this.typeElement = that.typeElement; + this.vectorN = that.vectorN; } public String varName() { @@ -74,4 +81,16 @@ public String type() { return type; } } + + public String buildType() { + // floatN + if (typeElement.toString().startsWith("hat.buffer.Float")) { + return "float" + vectorN; + } + throw new RuntimeException("Unexpected vector type " + typeElement); + } + + public int vectorN() { + return vectorN; + } } \ No newline at end of file diff --git a/hat/core/src/main/java/hat/dialect/HATVectorSelectLoadOp.java b/hat/core/src/main/java/hat/dialect/HATVectorSelectLoadOp.java index 6d60400a91b..293ab369475 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorSelectLoadOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorSelectLoadOp.java @@ -39,7 +39,7 @@ public class HATVectorSelectLoadOp extends HATVectorOp { private final int lane; public HATVectorSelectLoadOp(String varName, TypeElement typeElement, int lane, List operands) { - super(varName, operands); + super(varName, typeElement, -1, operands); this.elementType = typeElement; this.lane = lane; } @@ -66,12 +66,7 @@ public Map externalize() { } public String mapLane() { - return switch (lane) { - case 0 -> "x"; - case 1 -> "y"; - case 2 -> "z"; - case 3 -> "w"; - default -> throw new InternalError("Invalid lane: " + lane); - }; + return super.mapLane(lane); } + } diff --git a/hat/core/src/main/java/hat/dialect/HATVectorSelectStoreOp.java b/hat/core/src/main/java/hat/dialect/HATVectorSelectStoreOp.java index 6b5ff6d093b..39902ece89f 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorSelectStoreOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorSelectStoreOp.java @@ -41,7 +41,7 @@ public class HATVectorSelectStoreOp extends HATVectorOp { private final CoreOp.VarOp resultVarOp; public HATVectorSelectStoreOp(String varName, TypeElement typeElement, int lane, CoreOp.VarOp resultVarOp, List operands) { - super(varName, operands); + super(varName, typeElement, -1, operands); this.elementType = typeElement; this.lane = lane; this.resultVarOp = resultVarOp; @@ -70,13 +70,7 @@ public Map externalize() { } public String mapLane() { - return switch (lane) { - case 0 -> "x"; - case 1 -> "y"; - case 2 -> "z"; - case 3 -> "w"; - default -> throw new InternalError("Invalid lane: " + lane); - }; + return super.mapLane(lane); } public CoreOp.VarOp resultValue() { diff --git a/hat/core/src/main/java/hat/dialect/HATVectorStoreView.java b/hat/core/src/main/java/hat/dialect/HATVectorStoreView.java index e69219260d1..9a9ec155388 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorStoreView.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorStoreView.java @@ -41,7 +41,7 @@ public final class HATVectorStoreView extends HATVectorOp { private final VectorType vectorType; public HATVectorStoreView(String varName, TypeElement elementType, int storeN, VectorType vectorType, boolean isSharedOrPrivate, List operands) { - super(varName, operands); + super(varName, elementType, storeN, operands); this.elementType = elementType; this.storeN = storeN; this.isSharedOrPrivate = isSharedOrPrivate; @@ -71,15 +71,8 @@ public Map externalize() { return Map.of("hat.dialect.floatNStoreView." + varName(), elementType); } - public int storeN() { - return storeN; - } - public boolean isSharedOrPrivate() { return this.isSharedOrPrivate; } - public String buildType() { - return vectorType.type(); - } } diff --git a/hat/core/src/main/java/hat/dialect/HATVectorVarLoadOp.java b/hat/core/src/main/java/hat/dialect/HATVectorVarLoadOp.java index 427d205ae18..e5aa2551bff 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorVarLoadOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorVarLoadOp.java @@ -38,7 +38,7 @@ public class HATVectorVarLoadOp extends HATVectorOp { private final TypeElement typeElement; public HATVectorVarLoadOp(String varName, TypeElement typeElement, List operands) { - super(varName, operands); + super(varName, typeElement, 0, operands); this.typeElement = typeElement; } diff --git a/hat/core/src/main/java/hat/dialect/HATVectorVarOp.java b/hat/core/src/main/java/hat/dialect/HATVectorVarOp.java index d49c82d4742..df3d0068a73 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorVarOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorVarOp.java @@ -40,7 +40,7 @@ public class HATVectorVarOp extends HATVectorOp { private final int loadN; public HATVectorVarOp(String varName, VarType typeElement, int loadN, List operands) { - super(varName, operands); + super(varName, typeElement, loadN, operands); this.typeElement = typeElement; this.loadN = loadN; } @@ -66,6 +66,7 @@ public Map externalize() { return Map.of("hat.dialect.vectorVarOp." + varName(), typeElement); } + @Override public String buildType() { // floatN if (typeElement.valueType().toString().startsWith("hat.buffer.Float")) { @@ -73,5 +74,4 @@ public String buildType() { } throw new RuntimeException("Unexpected vector type " + typeElement); } - } From 251f1a52e19fb2d3cb564386c8ccb8f3929b2cc9 Mon Sep 17 00:00:00 2001 From: Juan Fumero Date: Mon, 27 Oct 2025 13:03:41 +0100 Subject: [PATCH 12/14] [hat] Log Config bits for OpenCL moved to OpenCL backend instantiation --- .../ffi/opencl/src/main/native/cpp/opencl_backend.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp b/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp index 11dc319f179..31f8e65129b 100644 --- a/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp +++ b/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp @@ -73,6 +73,11 @@ bool OpenCLBackend::getBufferFromDeviceIfDirty(void *memorySegment, long memoryS OpenCLBackend::OpenCLBackend(int configBits) : Backend(new Config(configBits), new OpenCLQueue(this)) { + + if (config->info) { + std::cerr << "Opencl Driver =" << std::hex << configBits << std::dec << std::endl; + } + cl_int status; cl_uint platformc = 0; OPENCL_CHECK(clGetPlatformIDs(0, nullptr, &platformc), "clGetPlatformIDs"); @@ -303,10 +308,6 @@ const char *OpenCLBackend::errorMsg(cl_int status) { } extern "C" long getBackend(int configBits) { - Backend::Config config(configBits); - if (config.info) { - std::cerr << "Opencl Driver =" << std::hex << configBits << std::dec << std::endl; - } return reinterpret_cast(new OpenCLBackend(configBits)); } From df4c78fee802adb6929c13cf9e80b7d2460ecc56 Mon Sep 17 00:00:00 2001 From: Juan Fumero Date: Mon, 27 Oct 2025 13:48:44 +0100 Subject: [PATCH 13/14] [hat] config bits message fixed --- hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp b/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp index 31f8e65129b..dac43b3c094 100644 --- a/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp +++ b/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp @@ -75,7 +75,7 @@ OpenCLBackend::OpenCLBackend(int configBits) : Backend(new Config(configBits), new OpenCLQueue(this)) { if (config->info) { - std::cerr << "Opencl Driver =" << std::hex << configBits << std::dec << std::endl; + std::cerr << "[INFO] Config Bits = " << std::hex << configBits << std::dec << std::endl; } cl_int status; From dfb349ce91e256d580e70c0c4db6b0ff1ab6e5c4 Mon Sep 17 00:00:00 2001 From: Juan Fumero Date: Fri, 31 Oct 2025 09:24:14 +0100 Subject: [PATCH 14/14] [hat][cuda] buildType method for vload/vstore restored --- hat/core/src/main/java/hat/dialect/HATVectorLoadOp.java | 8 ++++++++ .../src/main/java/hat/dialect/HATVectorSelectStoreOp.java | 1 + .../src/main/java/hat/dialect/HATVectorStoreView.java | 5 +++++ 3 files changed, 14 insertions(+) diff --git a/hat/core/src/main/java/hat/dialect/HATVectorLoadOp.java b/hat/core/src/main/java/hat/dialect/HATVectorLoadOp.java index 79e831acdf5..cc36d310b44 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorLoadOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorLoadOp.java @@ -74,4 +74,12 @@ public Map externalize() { public boolean isSharedOrPrivate() { return this.isSharedOrPrivate; } + + public String buildType() { + // floatN + if (vectorType.toString().startsWith("hat.buffer.Float")) { + return "float" + loadN; + } + throw new RuntimeException("Unexpected vector type " + vectorType); + } } diff --git a/hat/core/src/main/java/hat/dialect/HATVectorSelectStoreOp.java b/hat/core/src/main/java/hat/dialect/HATVectorSelectStoreOp.java index 39902ece89f..69ec2b10ecd 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorSelectStoreOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorSelectStoreOp.java @@ -76,4 +76,5 @@ public String mapLane() { public CoreOp.VarOp resultValue() { return resultVarOp; } + } diff --git a/hat/core/src/main/java/hat/dialect/HATVectorStoreView.java b/hat/core/src/main/java/hat/dialect/HATVectorStoreView.java index 9a9ec155388..345cf3026d2 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorStoreView.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorStoreView.java @@ -75,4 +75,9 @@ public boolean isSharedOrPrivate() { return this.isSharedOrPrivate; } + @Override + public String buildType() { + // floatN + return vectorType.type(); + } }