diff --git a/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java b/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java index 8e665abe807..2683d1ac188 100644 --- a/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java +++ b/hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaHATKernelBuilder.java @@ -28,10 +28,11 @@ import hat.codebuilders.CodeBuilder; import hat.codebuilders.ScopedCodeBuilderContext; import hat.dialect.HATF16ConvOp; -import hat.dialect.HATVectorSelectLoadOp; -import hat.dialect.HATVectorSelectStoreOp; import hat.dialect.HATVectorBinaryOp; import hat.dialect.HATVectorLoadOp; +import hat.dialect.HATVectorOfOp; +import hat.dialect.HATVectorSelectLoadOp; +import hat.dialect.HATVectorSelectStoreOp; import hat.dialect.HATVectorStoreView; import hat.dialect.HATVectorVarOp; import jdk.incubator.code.Op; @@ -250,4 +251,10 @@ public CudaHATKernelBuilder hatVectorVarOp(ScopedCodeBuilderContext buildContext } return self(); } + + @Override + public CudaHATKernelBuilder genVectorIdentifier(ScopedCodeBuilderContext builderContext, HATVectorOfOp hatVectorOfOp) { + identifier("make_" + hatVectorOfOp.buildType()).oparen(); + return self(); + } } diff --git a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java index 0e2ad5d52c9..a52bc8eb404 100644 --- a/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java +++ b/hat/backends/ffi/opencl/src/main/java/hat/backend/ffi/OpenCLHATKernelBuilder.java @@ -28,15 +28,18 @@ import hat.codebuilders.CodeBuilder; import hat.codebuilders.ScopedCodeBuilderContext; import hat.dialect.HATF16ConvOp; -import hat.dialect.HATVectorSelectLoadOp; -import hat.dialect.HATVectorSelectStoreOp; import hat.dialect.HATVectorBinaryOp; import hat.dialect.HATVectorLoadOp; +import hat.dialect.HATVectorOfOp; +import hat.dialect.HATVectorSelectLoadOp; +import hat.dialect.HATVectorSelectStoreOp; import hat.dialect.HATVectorStoreView; import hat.dialect.HATVectorVarOp; import jdk.incubator.code.Op; import jdk.incubator.code.Value; +import java.util.List; + public class OpenCLHATKernelBuilder extends C99HATKernelBuilder { @Override @@ -83,7 +86,7 @@ public OpenCLHATKernelBuilder hatVectorStoreOp(ScopedCodeBuilderContext buildCon Value dest = hatVectorStoreView.operands().get(0); Value index = hatVectorStoreView.operands().get(2); - identifier("vstore" + hatVectorStoreView.storeN()) + identifier("vstore" + hatVectorStoreView.vectorN()) .oparen() .varName(hatVectorStoreView) .comma() @@ -131,7 +134,7 @@ public OpenCLHATKernelBuilder hatVectorLoadOp(ScopedCodeBuilderContext buildCont Value source = hatVectorLoadOp.operands().get(0); Value index = hatVectorLoadOp.operands().get(1); - identifier("vload" + hatVectorLoadOp.loadN()) + identifier("vload" + hatVectorLoadOp.vectorN()) .oparen() .intConstZero() .comma() @@ -180,13 +183,11 @@ public OpenCLHATKernelBuilder hatSelectStoreOp(ScopedCodeBuilderContext buildCon @Override public OpenCLHATKernelBuilder hatF16ConvOp(ScopedCodeBuilderContext buildContext, HATF16ConvOp hatF16ConvOp) { - oparen().typeName("half").cparen(); - // typeName("convert_half").oparen(); + oparen().halfType().cparen(); Value initValue = hatF16ConvOp.operands().getFirst(); if (initValue instanceof Op.Result r) { recurse(buildContext, r.op()); } - //cparen(); return self(); } @@ -204,4 +205,9 @@ public OpenCLHATKernelBuilder hatVectorVarOp(ScopedCodeBuilderContext buildConte return self(); } + @Override + public OpenCLHATKernelBuilder genVectorIdentifier(ScopedCodeBuilderContext builderContext, HATVectorOfOp hatVectorOfOp) { + oparen().identifier(hatVectorOfOp.buildType()).cparen().oparen(); + return self(); + } } diff --git a/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp b/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp index 115ebd1b3be..dac43b3c094 100644 --- a/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp +++ b/hat/backends/ffi/opencl/src/main/native/cpp/opencl_backend.cpp @@ -73,6 +73,11 @@ bool OpenCLBackend::getBufferFromDeviceIfDirty(void *memorySegment, long memoryS OpenCLBackend::OpenCLBackend(int configBits) : Backend(new Config(configBits), new OpenCLQueue(this)) { + + if (config->info) { + std::cerr << "[INFO] Config Bits = " << std::hex << configBits << std::dec << std::endl; + } + cl_int status; cl_uint platformc = 0; OPENCL_CHECK(clGetPlatformIDs(0, nullptr, &platformc), "clGetPlatformIDs"); @@ -303,7 +308,6 @@ const char *OpenCLBackend::errorMsg(cl_int status) { } extern "C" long getBackend(int configBits) { - std::cerr << "Opencl Driver =" << std::hex << configBits << std::dec << std::endl; return reinterpret_cast(new OpenCLBackend(configBits)); } diff --git a/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java b/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java index a33fc95417a..fbbfc2270f2 100644 --- a/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java +++ b/hat/backends/jextracted/opencl/src/main/java/hat/backend/jextracted/OpenCLHatKernelBuilder.java @@ -28,15 +28,18 @@ import hat.codebuilders.CodeBuilder; import hat.codebuilders.ScopedCodeBuilderContext; import hat.dialect.HATF16ConvOp; -import hat.dialect.HATVectorSelectLoadOp; -import hat.dialect.HATVectorSelectStoreOp; import hat.dialect.HATVectorBinaryOp; import hat.dialect.HATVectorLoadOp; +import hat.dialect.HATVectorOfOp; +import hat.dialect.HATVectorSelectLoadOp; +import hat.dialect.HATVectorSelectStoreOp; import hat.dialect.HATVectorStoreView; import hat.dialect.HATVectorVarOp; import jdk.incubator.code.Op; import jdk.incubator.code.Value; +import java.util.List; + public class OpenCLHatKernelBuilder extends C99HATKernelBuilder { @Override @@ -83,7 +86,7 @@ public OpenCLHatKernelBuilder hatVectorStoreOp(ScopedCodeBuilderContext buildCon Value dest = hatVectorStoreView.operands().get(0); Value index = hatVectorStoreView.operands().get(2); - identifier("vstore" + hatVectorStoreView.storeN()) + identifier("vstore" + hatVectorStoreView.vectorN()) .oparen() .varName(hatVectorStoreView) .comma() @@ -131,7 +134,7 @@ public OpenCLHatKernelBuilder hatVectorLoadOp(ScopedCodeBuilderContext buildCont Value source = hatVectorLoadOp.operands().get(0); Value index = hatVectorLoadOp.operands().get(1); - identifier("vload" + hatVectorLoadOp.loadN()) + identifier("vload" + hatVectorLoadOp.vectorN()) .oparen() .intConstZero() .comma() @@ -181,12 +184,10 @@ public OpenCLHatKernelBuilder hatSelectStoreOp(ScopedCodeBuilderContext buildCon @Override public OpenCLHatKernelBuilder hatF16ConvOp(ScopedCodeBuilderContext buildContext, HATF16ConvOp hatF16ConvOp) { oparen().typeName("half").cparen(); - // typeName("convert_half").oparen(); Value initValue = hatF16ConvOp.operands().getFirst(); if (initValue instanceof Op.Result r) { recurse(buildContext, r.op()); } - //cparen(); return self(); } @@ -204,4 +205,10 @@ public OpenCLHatKernelBuilder hatVectorVarOp(ScopedCodeBuilderContext buildConte return self(); } + @Override + public OpenCLHatKernelBuilder genVectorIdentifier(ScopedCodeBuilderContext builderContext, HATVectorOfOp hatVectorOfOp) { + oparen().identifier(hatVectorOfOp.buildType()).cparen().oparen(); + return self(); + } + } diff --git a/hat/core/src/main/java/hat/buffer/Buffer.java b/hat/core/src/main/java/hat/buffer/Buffer.java index e018b67395d..07efec4edbd 100644 --- a/hat/core/src/main/java/hat/buffer/Buffer.java +++ b/hat/core/src/main/java/hat/buffer/Buffer.java @@ -50,7 +50,8 @@ default void setState(int newState ){ default String getStateString(){ return BufferState.of(this).getStateString(); } - // default boolean isDeviceDirty(){ + + // default boolean isDeviceDirty(){ // return BufferState.of(this).isDeviceDirty(); // } // default boolean isHostChecked(){ diff --git a/hat/core/src/main/java/hat/buffer/F32ArrayPadded.java b/hat/core/src/main/java/hat/buffer/F32ArrayPadded.java index 87f7367420f..9c7b1d39a38 100644 --- a/hat/core/src/main/java/hat/buffer/F32ArrayPadded.java +++ b/hat/core/src/main/java/hat/buffer/F32ArrayPadded.java @@ -65,13 +65,18 @@ default float[] arrayView() { return arr; } - // This is an intrinsic for HAT to create views. It does not execute code - // on the host side, at least for now. - default Float4 float4View(int index) { - return null; + default Float4.MutableImpl float4View(int index) { +// MemorySegment memorySegment = Buffer.getMemorySegment(this); +// float f1 = memorySegment.get(JAVA_FLOAT, ARRAY_OFFSET + index + 0); +// float f2 = memorySegment.get(JAVA_FLOAT, ARRAY_OFFSET + index + 1); +// float f3 = memorySegment.get(JAVA_FLOAT, ARRAY_OFFSET + index + 2); +// float f4 = memorySegment.get(JAVA_FLOAT, ARRAY_OFFSET + index + 3); +// return Float4.makeMutable(Float4.of(f1, f2, f3, f4)); + return null; } - // This is an intrinsic for HAT to store views. It does not execute code - default void storeFloat4View(Float4 v, int index) {} + default void storeFloat4View(Float4 v, int index) { +// MemorySegment.copy(Buffer.getMemorySegment(this), JAVA_FLOAT, ARRAY_OFFSET, v.toArray(), index, 4); + } } diff --git a/hat/core/src/main/java/hat/buffer/Float4.java b/hat/core/src/main/java/hat/buffer/Float4.java index d7b0408976d..ffca332066e 100644 --- a/hat/core/src/main/java/hat/buffer/Float4.java +++ b/hat/core/src/main/java/hat/buffer/Float4.java @@ -24,8 +24,10 @@ */ package hat.buffer; -import hat.Accelerator; -import hat.ifacemapper.Schema; +import jdk.incubator.code.CodeReflection; + +import java.util.function.BiFunction; +import java.util.stream.IntStream; public interface Float4 extends HatVector { @@ -33,50 +35,76 @@ public interface Float4 extends HatVector { float y(); float z(); float w(); - void x(float x); - void y(float y); - void z(float z); - void w(float w); - Schema schema = Schema.of(Float4.class, - float4->float4.fields("x","y","z","w")); + record MutableImpl(float x, float y, float z, float w) implements Float4 { + public void x(float x) {} + public void y(float y) {} + public void z(float z) {} + public void w(float w) {} + } + + record ImmutableImpl(float x, float y, float z, float w) implements Float4 { + } + + /** + * Make a Mutable implementation (for the device side - e.g., the GPU) from an immutable implementation. + * + * @param float4 + * @return {@link Float4.MutableImpl} + */ + static Float4.MutableImpl makeMutable(Float4 float4) { + return new MutableImpl(float4.x(), float4.y(), float4.z(), float4.w()); + } + + static Float4 of(float x, float y, float z, float w) { + return new ImmutableImpl(x, y, z, w); + } - static Float4 create(Accelerator accelerator) { - return schema.allocate(accelerator, 1); + // Not implemented for the GPU yet + default Float4 lanewise(Float4 other, BiFunction f) { + float[] backA = this.toArray(); + float[] backB = other.toArray(); + float[] backC = new float[backA.length]; + IntStream.range(0, backA.length).forEach(j -> { + backC[j] = f.apply(backA[j], backB[j]); + }); + return of(backC[0], backC[1], backC[2], backC[3]); } static Float4 add(Float4 vA, Float4 vB) { - return null; + return vA.lanewise(vB, Float::sum); } static Float4 sub(Float4 vA, Float4 vB) { - return null; + return vA.lanewise(vB, (a, b) -> a - b); } static Float4 mul(Float4 vA, Float4 vB) { - return null; + return vA.lanewise(vB, (a, b) -> a * b); } static Float4 div(Float4 vA, Float4 vB) { - return null; + return vA.lanewise(vB, (a, b) -> a / b); } default Float4 add(Float4 vb) { - return null; + return Float4.add(this, vb); } default Float4 sub(Float4 vb) { - return null; + return Float4.sub(this, vb); } default Float4 mul(Float4 vb) { - return null; + return Float4.mul(this, vb); } default Float4 div(Float4 vb) { - return null; + return Float4.div(this, vb); } + // Not implemented for the GPU yet + @CodeReflection default float[] toArray() { return new float[] { x(), y(), z(), w() }; } diff --git a/hat/core/src/main/java/hat/codebuilders/BabylonOpBuilder.java b/hat/core/src/main/java/hat/codebuilders/BabylonOpBuilder.java index 20548818494..32213febee8 100644 --- a/hat/core/src/main/java/hat/codebuilders/BabylonOpBuilder.java +++ b/hat/core/src/main/java/hat/codebuilders/BabylonOpBuilder.java @@ -30,6 +30,8 @@ import hat.dialect.HATF16ConvOp; import hat.dialect.HATF16VarLoadOp; import hat.dialect.HATF16VarOp; +import hat.dialect.HATVectorMakeOfOp; +import hat.dialect.HATVectorOfOp; import hat.dialect.HATVectorSelectLoadOp; import hat.dialect.HATVectorSelectStoreOp; import hat.dialect.HATVectorBinaryOp; @@ -141,6 +143,10 @@ public interface BabylonOpBuilder> { T hatF16ConvOp(ScopedCodeBuilderContext buildContext, HATF16ConvOp hatF16ConvOp); + T hatVectorOfOps(ScopedCodeBuilderContext buildContext, HATVectorOfOp hatVectorOp); + + T hatVectorMakeOf(ScopedCodeBuilderContext builderContext, HATVectorMakeOfOp hatVectorMakeOfOp); + default T recurse(ScopedCodeBuilderContext buildContext, Op op) { switch (op) { case CoreOp.VarAccessOp.VarLoadOp $ -> varLoadOp(buildContext, $); @@ -183,10 +189,12 @@ default T recurse(ScopedCodeBuilderContext buildContext, Op op) { case HATVectorSelectLoadOp $ -> hatSelectLoadOp(buildContext, $); case HATVectorSelectStoreOp $ -> hatSelectStoreOp(buildContext, $); case HATVectorVarLoadOp $ -> hatVectorVarLoadOp(buildContext, $); + case HATVectorOfOp $ -> hatVectorOfOps(buildContext, $); case HATF16VarOp $ -> hatF16VarOp(buildContext, $); case HATF16BinaryOp $ -> hatF16BinaryOp(buildContext, $); case HATF16VarLoadOp $ -> hatF16VarLoadOp(buildContext, $); case HATF16ConvOp $ -> hatF16ConvOp(buildContext, $); + case HATVectorMakeOfOp $ -> hatVectorMakeOf(buildContext, $); default -> throw new IllegalStateException("handle nesting of op " + op); } return (T) this; diff --git a/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java b/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java index 0587c6a949d..2dd9676f930 100644 --- a/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java +++ b/hat/core/src/main/java/hat/codebuilders/C99HATKernelBuilder.java @@ -33,6 +33,8 @@ import hat.dialect.HATGlobalThreadIdOp; import hat.dialect.HATLocalSizeOp; import hat.dialect.HATLocalThreadIdOp; +import hat.dialect.HATVectorMakeOfOp; +import hat.dialect.HATVectorOfOp; import hat.dialect.HATVectorVarLoadOp; import hat.ifacemapper.MappableIface; import hat.optools.FuncOpParams; @@ -274,6 +276,35 @@ public T hatF16VarLoadOp(ScopedCodeBuilderContext buildContext, HATF16VarLoadOp return self(); } + @Override + public T hatVectorMakeOf(ScopedCodeBuilderContext builderContext, HATVectorMakeOfOp hatVectorMakeOfOp) { + identifier(hatVectorMakeOfOp.varName()); + return self(); + } + + public abstract T genVectorIdentifier(ScopedCodeBuilderContext builderContext, HATVectorOfOp hatVectorOfOp); + + @Override + public T hatVectorOfOps(ScopedCodeBuilderContext buildContext, HATVectorOfOp hatVectorOp) { + genVectorIdentifier(buildContext, hatVectorOp); + + List inputOperands = hatVectorOp.operands(); + int i; + for (i = 0; i < (inputOperands.size() - 1); i++) { + var operand = inputOperands.get(i); + if ((operand instanceof Op.Result r)) { + recurse(buildContext, r.op()); + } + comma().space(); + } + // Last parameter + var operand = inputOperands.get(i); + if ((operand instanceof Op.Result r)) { + recurse(buildContext, r.op()); + } + cparen(); + return self(); + } public T kernelDeclaration(CoreOp.FuncOp funcOp) { return kernelPrefix().voidType().space().funcName(funcOp); diff --git a/hat/core/src/main/java/hat/codebuilders/CodeBuilder.java b/hat/core/src/main/java/hat/codebuilders/CodeBuilder.java index 4e5ebaebab0..eacddd64e4c 100644 --- a/hat/core/src/main/java/hat/codebuilders/CodeBuilder.java +++ b/hat/core/src/main/java/hat/codebuilders/CodeBuilder.java @@ -507,6 +507,9 @@ public final T shortType() { return typeName("short"); } + public final T halfType() { + return typeName("half"); + } @Override public final T comment(String text) { @@ -531,6 +534,7 @@ public T label(String text) { public final T symbol(String text) { return emitText(text); } + @Override public final T typeName(String text) { return emitText(text); diff --git a/hat/core/src/main/java/hat/dialect/HATF16BinaryOp.java b/hat/core/src/main/java/hat/dialect/HATF16BinaryOp.java index cae5ae112e2..637a4d9f137 100644 --- a/hat/core/src/main/java/hat/dialect/HATF16BinaryOp.java +++ b/hat/core/src/main/java/hat/dialect/HATF16BinaryOp.java @@ -25,8 +25,6 @@ package hat.dialect; import jdk.incubator.code.CopyContext; -import jdk.incubator.code.Op; -import jdk.incubator.code.OpTransformer; import jdk.incubator.code.TypeElement; import jdk.incubator.code.Value; diff --git a/hat/core/src/main/java/hat/dialect/HATVectorBinaryOp.java b/hat/core/src/main/java/hat/dialect/HATVectorBinaryOp.java index 35364962b62..af2630e3af1 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorBinaryOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorBinaryOp.java @@ -30,7 +30,7 @@ import java.util.List; -public abstract class HATVectorBinaryOp extends HATVectorViewOp { +public abstract class HATVectorBinaryOp extends HATVectorOp { public enum OpType { ADD("+"), SUB("-"), @@ -50,21 +50,20 @@ public String symbol() { private final TypeElement elementType; private final OpType operationType; - private final int vectorN; public HATVectorBinaryOp(String varName, TypeElement typeElement, OpType operationType, List operands) { - super(varName, operands); + int l = typeElement.toString().length(); + int vectorN = Integer.parseInt(typeElement.toString().substring(l - 1, l)); + super(varName, typeElement, vectorN, operands); this.elementType = typeElement; this.operationType = operationType; - int l = typeElement.toString().length(); - vectorN = Integer.parseInt(typeElement.toString().substring(l - 1, l)); + } public HATVectorBinaryOp(HATVectorBinaryOp op, CopyContext copyContext) { super(op, copyContext); this.elementType = op.elementType; this.operationType = op.operationType; - this.vectorN = op.vectorN; } @Override @@ -72,23 +71,8 @@ public TypeElement resultType() { return this.elementType; } - // @Override - //public Map externalize() { - // return Map.of("hat.dialect.floatNOp." + varName(), elementType); - // } - public OpType operationType() { return operationType; } - public int vectorN() { - return vectorN; - } - - public String buildType() { - if (elementType.toString().startsWith("hat.buffer.Float")) { - return "float" + vectorN; - } - throw new RuntimeException("Unexpected vector type " + elementType); - } } diff --git a/hat/core/src/main/java/hat/dialect/HATVectorLoadOp.java b/hat/core/src/main/java/hat/dialect/HATVectorLoadOp.java index 402944c96e5..cc36d310b44 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorLoadOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorLoadOp.java @@ -33,7 +33,7 @@ import java.util.List; import java.util.Map; -public class HATVectorLoadOp extends HATVectorViewOp { +public class HATVectorLoadOp extends HATVectorOp { private final TypeElement typeElement; private final TypeElement vectorType; @@ -41,7 +41,7 @@ public class HATVectorLoadOp extends HATVectorViewOp { private final boolean isSharedOrPrivate; public HATVectorLoadOp(String varName, TypeElement typeElement, TypeElement vectorType, int loadN, boolean isShared, List operands) { - super(varName, operands); + super(varName, typeElement, loadN, operands); this.typeElement = typeElement; this.loadN = loadN; this.vectorType = vectorType; @@ -71,12 +71,8 @@ public Map externalize() { return Map.of("hat.dialect.vectorLoadView." + varName(), typeElement); } - public TypeElement vectorType() { - return vectorType; - } - - public int loadN() { - return loadN; + public boolean isSharedOrPrivate() { + return this.isSharedOrPrivate; } public String buildType() { @@ -86,8 +82,4 @@ public String buildType() { } throw new RuntimeException("Unexpected vector type " + vectorType); } - - public boolean isSharedOrPrivate() { - return this.isSharedOrPrivate; - } } diff --git a/hat/core/src/main/java/hat/dialect/HATVectorMakeOfOp.java b/hat/core/src/main/java/hat/dialect/HATVectorMakeOfOp.java new file mode 100644 index 00000000000..3e2902b25c0 --- /dev/null +++ b/hat/core/src/main/java/hat/dialect/HATVectorMakeOfOp.java @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package hat.dialect; + +import jdk.incubator.code.CopyContext; +import jdk.incubator.code.Op; +import jdk.incubator.code.OpTransformer; +import jdk.incubator.code.TypeElement; +import jdk.incubator.code.Value; + +import java.util.List; +import java.util.Map; + +public class HATVectorMakeOfOp extends HATVectorOp { + + private final TypeElement typeElement; + private final int loadN; + + public HATVectorMakeOfOp(String varName, TypeElement typeElement, int loadN, List operands) { + super(varName, typeElement, loadN, operands); + this.typeElement = typeElement; + this.loadN = loadN; + } + + public HATVectorMakeOfOp(HATVectorMakeOfOp op, CopyContext copyContext) { + super(op, copyContext); + this.typeElement = op.typeElement; + this.loadN = op.loadN; + } + + @Override + public Op transform(CopyContext copyContext, OpTransformer opTransformer) { + return new HATVectorMakeOfOp(this, copyContext); + } + + @Override + public TypeElement resultType() { + return typeElement; + } + + @Override + public Map externalize() { + return Map.of("hat.dialect.makeOf." + varName(), typeElement); + } + +} diff --git a/hat/core/src/main/java/hat/dialect/HATVectorOfOp.java b/hat/core/src/main/java/hat/dialect/HATVectorOfOp.java new file mode 100644 index 00000000000..442ad5a94a3 --- /dev/null +++ b/hat/core/src/main/java/hat/dialect/HATVectorOfOp.java @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package hat.dialect; + +import jdk.incubator.code.CopyContext; +import jdk.incubator.code.Op; +import jdk.incubator.code.OpTransformer; +import jdk.incubator.code.TypeElement; +import jdk.incubator.code.Value; + +import java.util.List; +import java.util.Map; + +public class HATVectorOfOp extends HATVectorOp { + + private final TypeElement typeElement; + private final int loadN; + + public HATVectorOfOp(TypeElement typeElement, int loadN, List operands) { + super("", typeElement, loadN, operands); + this.typeElement = typeElement; + this.loadN = loadN; + } + + public HATVectorOfOp(HATVectorOfOp op, CopyContext copyContext) { + super(op, copyContext); + this.typeElement = op.typeElement; + this.loadN = op.loadN; + } + + @Override + public Op transform(CopyContext copyContext, OpTransformer opTransformer) { + return new HATVectorOfOp(this, copyContext); + } + + @Override + public TypeElement resultType() { + return typeElement; + } + + @Override + public Map externalize() { + return Map.of("hat.dialect.vectorOf." + varName(), typeElement); + } + + +} diff --git a/hat/core/src/main/java/hat/dialect/HATVectorViewOp.java b/hat/core/src/main/java/hat/dialect/HATVectorOp.java similarity index 73% rename from hat/core/src/main/java/hat/dialect/HATVectorViewOp.java rename to hat/core/src/main/java/hat/dialect/HATVectorOp.java index 0cef7659239..0daea61e53a 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorViewOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorOp.java @@ -25,22 +25,29 @@ package hat.dialect; import jdk.incubator.code.CopyContext; +import jdk.incubator.code.TypeElement; import jdk.incubator.code.Value; import java.util.List; -public abstract class HATVectorViewOp extends HATOp { +public abstract class HATVectorOp extends HATOp { private String varName; + private final TypeElement typeElement; + private final int vectorN; - public HATVectorViewOp(String varName, List operands) { + public HATVectorOp(String varName, TypeElement typeElement, int vectorN, List operands) { super(operands); this.varName = varName; + this.typeElement = typeElement; + this.vectorN = vectorN; } - protected HATVectorViewOp(HATVectorViewOp that, CopyContext cc) { + protected HATVectorOp(HATVectorOp that, CopyContext cc) { super(that, cc); this.varName = that.varName; + this.typeElement = that.typeElement; + this.vectorN = that.vectorN; } public String varName() { @@ -74,4 +81,16 @@ public String type() { return type; } } + + public String buildType() { + // floatN + if (typeElement.toString().startsWith("hat.buffer.Float")) { + return "float" + vectorN; + } + throw new RuntimeException("Unexpected vector type " + typeElement); + } + + public int vectorN() { + return vectorN; + } } \ No newline at end of file diff --git a/hat/core/src/main/java/hat/dialect/HATVectorSelectLoadOp.java b/hat/core/src/main/java/hat/dialect/HATVectorSelectLoadOp.java index 1f74303e2fd..293ab369475 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorSelectLoadOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorSelectLoadOp.java @@ -33,13 +33,13 @@ import java.util.List; import java.util.Map; -public class HATVectorSelectLoadOp extends HATVectorViewOp { +public class HATVectorSelectLoadOp extends HATVectorOp { private final TypeElement elementType; private final int lane; public HATVectorSelectLoadOp(String varName, TypeElement typeElement, int lane, List operands) { - super(varName, operands); + super(varName, typeElement, -1, operands); this.elementType = typeElement; this.lane = lane; } @@ -66,12 +66,7 @@ public Map externalize() { } public String mapLane() { - return switch (lane) { - case 0 -> "x"; - case 1 -> "y"; - case 2 -> "z"; - case 3 -> "w"; - default -> throw new InternalError("Invalid lane: " + lane); - }; + return super.mapLane(lane); } + } diff --git a/hat/core/src/main/java/hat/dialect/HATVectorSelectStoreOp.java b/hat/core/src/main/java/hat/dialect/HATVectorSelectStoreOp.java index e527f2ca1dd..69ec2b10ecd 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorSelectStoreOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorSelectStoreOp.java @@ -34,14 +34,14 @@ import java.util.List; import java.util.Map; -public class HATVectorSelectStoreOp extends HATVectorViewOp { +public class HATVectorSelectStoreOp extends HATVectorOp { private final TypeElement elementType; private final int lane; private final CoreOp.VarOp resultVarOp; public HATVectorSelectStoreOp(String varName, TypeElement typeElement, int lane, CoreOp.VarOp resultVarOp, List operands) { - super(varName, operands); + super(varName, typeElement, -1, operands); this.elementType = typeElement; this.lane = lane; this.resultVarOp = resultVarOp; @@ -70,16 +70,11 @@ public Map externalize() { } public String mapLane() { - return switch (lane) { - case 0 -> "x"; - case 1 -> "y"; - case 2 -> "z"; - case 3 -> "w"; - default -> throw new InternalError("Invalid lane: " + lane); - }; + return super.mapLane(lane); } public CoreOp.VarOp resultValue() { return resultVarOp; } + } diff --git a/hat/core/src/main/java/hat/dialect/HATVectorStoreView.java b/hat/core/src/main/java/hat/dialect/HATVectorStoreView.java index a936422ca88..345cf3026d2 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorStoreView.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorStoreView.java @@ -33,7 +33,7 @@ import java.util.List; import java.util.Map; -public final class HATVectorStoreView extends HATVectorViewOp { +public final class HATVectorStoreView extends HATVectorOp { private final TypeElement elementType; private final int storeN; @@ -41,7 +41,7 @@ public final class HATVectorStoreView extends HATVectorViewOp { private final VectorType vectorType; public HATVectorStoreView(String varName, TypeElement elementType, int storeN, VectorType vectorType, boolean isSharedOrPrivate, List operands) { - super(varName, operands); + super(varName, elementType, storeN, operands); this.elementType = elementType; this.storeN = storeN; this.isSharedOrPrivate = isSharedOrPrivate; @@ -71,15 +71,13 @@ public Map externalize() { return Map.of("hat.dialect.floatNStoreView." + varName(), elementType); } - public int storeN() { - return storeN; - } - public boolean isSharedOrPrivate() { return this.isSharedOrPrivate; } + @Override public String buildType() { + // floatN return vectorType.type(); } } diff --git a/hat/core/src/main/java/hat/dialect/HATVectorVarLoadOp.java b/hat/core/src/main/java/hat/dialect/HATVectorVarLoadOp.java index c97086b863b..e5aa2551bff 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorVarLoadOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorVarLoadOp.java @@ -33,12 +33,12 @@ import java.util.List; import java.util.Map; -public class HATVectorVarLoadOp extends HATVectorViewOp { +public class HATVectorVarLoadOp extends HATVectorOp { private final TypeElement typeElement; public HATVectorVarLoadOp(String varName, TypeElement typeElement, List operands) { - super(varName, operands); + super(varName, typeElement, 0, operands); this.typeElement = typeElement; } diff --git a/hat/core/src/main/java/hat/dialect/HATVectorVarOp.java b/hat/core/src/main/java/hat/dialect/HATVectorVarOp.java index fa9fd4a0a8c..df3d0068a73 100644 --- a/hat/core/src/main/java/hat/dialect/HATVectorVarOp.java +++ b/hat/core/src/main/java/hat/dialect/HATVectorVarOp.java @@ -34,13 +34,13 @@ import java.util.List; import java.util.Map; -public class HATVectorVarOp extends HATVectorViewOp { +public class HATVectorVarOp extends HATVectorOp { private final VarType typeElement; private final int loadN; public HATVectorVarOp(String varName, VarType typeElement, int loadN, List operands) { - super(varName, operands); + super(varName, typeElement, loadN, operands); this.typeElement = typeElement; this.loadN = loadN; } @@ -66,6 +66,7 @@ public Map externalize() { return Map.of("hat.dialect.vectorVarOp." + varName(), typeElement); } + @Override public String buildType() { // floatN if (typeElement.valueType().toString().startsWith("hat.buffer.Float")) { @@ -73,5 +74,4 @@ public String buildType() { } throw new RuntimeException("Unexpected vector type " + typeElement); } - } diff --git a/hat/core/src/main/java/hat/phases/HATDialectifyTier.java b/hat/core/src/main/java/hat/phases/HATDialectifyTier.java index bc9432cf0a5..e6e86c99863 100644 --- a/hat/core/src/main/java/hat/phases/HATDialectifyTier.java +++ b/hat/core/src/main/java/hat/phases/HATDialectifyTier.java @@ -38,6 +38,7 @@ public class HATDialectifyTier implements Function private List hatPhases = new ArrayList<>(); public HATDialectifyTier(Accelerator accelerator) { + // barriers hatPhases.add(new HATDialectifyBarrierPhase(accelerator)); // Memory hatPhases.add(new HATDialectifyMemoryPhase.SharedPhase(accelerator)); @@ -50,16 +51,20 @@ public HATDialectifyTier(Accelerator accelerator) { hatPhases.add(new HATDialectifyThreadsPhase.LocalSizePhase(accelerator)); hatPhases.add(new HATDialectifyThreadsPhase.BlockPhase(accelerator)); - // views + // views for vector types hatPhases.add(new HATDialectifyVectorOpPhase.Float4LoadPhase(accelerator)); + hatPhases.add(new HATDialectifyVectorOpPhase.Float4OfPhase(accelerator)); hatPhases.add(new HATDialectifyVectorOpPhase.AddPhase(accelerator)); hatPhases.add(new HATDialectifyVectorOpPhase.SubPhase(accelerator)); hatPhases.add(new HATDialectifyVectorOpPhase.MulPhase(accelerator)); hatPhases.add(new HATDialectifyVectorOpPhase.DivPhase(accelerator)); - + hatPhases.add(new HATDialectifyVectorOpPhase.MakeMutable(accelerator)); hatPhases.add(new HATDialectifyVectorStorePhase.Float4StorePhase(accelerator)); + // Vector Select individual lines hatPhases.add(new HATDialectifyVectorSelectPhase(accelerator)); + + // F16 type hatPhases.add(new HATDialectifyFP16Phase(accelerator)); } diff --git a/hat/core/src/main/java/hat/phases/HATDialectifyVectorOpPhase.java b/hat/core/src/main/java/hat/phases/HATDialectifyVectorOpPhase.java index 5e3d35468ba..6b53aeb3484 100644 --- a/hat/core/src/main/java/hat/phases/HATDialectifyVectorOpPhase.java +++ b/hat/core/src/main/java/hat/phases/HATDialectifyVectorOpPhase.java @@ -30,11 +30,13 @@ import hat.dialect.HATVectorAddOp; import hat.dialect.HATVectorDivOp; import hat.dialect.HATVectorLoadOp; +import hat.dialect.HATVectorMakeOfOp; import hat.dialect.HATVectorMulOp; +import hat.dialect.HATVectorOfOp; import hat.dialect.HATVectorSubOp; import hat.dialect.HATVectorVarLoadOp; import hat.dialect.HATVectorVarOp; -import hat.dialect.HATVectorViewOp; +import hat.dialect.HATVectorOp; import hat.dialect.HATVectorBinaryOp; import hat.optools.OpTk; import jdk.incubator.code.CodeElement; @@ -53,7 +55,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -public abstract class HATDialectifyVectorOpPhase implements HATDialect{ +public abstract class HATDialectifyVectorOpPhase implements HATDialect { protected final Accelerator accelerator; @Override public Accelerator accelerator(){ @@ -80,10 +82,12 @@ private HATVectorBinaryOp.OpType getBinaryOpType(JavaOp.InvokeOp invokeOp) { public enum OpView { FLOAT4_LOAD("float4View"), + OF("of"), ADD("add"), SUB("sub"), MUL("mul"), - DIV("div"); + DIV("div"), + MAKE_MUTABLE("makeMutable"); final String methodName; OpView(String methodName) { this.methodName = methodName; @@ -107,8 +111,8 @@ private String findNameVector(Value v) { return findNameVector(varLoadOp); } else { // Leaf of tree - - if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorViewOp hatVectorViewOp) { - return hatVectorViewOp.varName(); + if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorOp hatVectorOp) { + return hatVectorOp.varName(); } return null; } @@ -175,7 +179,7 @@ private CoreOp.FuncOp dialectifyVectorLoad(CoreOp.FuncOp funcOp) { if (r.op() instanceof CoreOp.VarOp varOp) { List inputOperandsVarOp = invokeOp.operands(); List outputOperandsVarOp = context.getValues(inputOperandsVarOp); - HATVectorViewOp memoryViewOp = new HATVectorLoadOp(varOp.varName(), varOp.resultType(), invokeOp.resultType(), 4, isShared, outputOperandsVarOp); + HATVectorOp memoryViewOp = new HATVectorLoadOp(varOp.varName(), varOp.resultType(), invokeOp.resultType(), 4, isShared, outputOperandsVarOp); Op.Result hatLocalResult = blockBuilder.op(memoryViewOp); memoryViewOp.setLocation(varOp.location()); context.mapValue(invokeOp.result(), hatLocalResult); @@ -186,7 +190,7 @@ private CoreOp.FuncOp dialectifyVectorLoad(CoreOp.FuncOp funcOp) { //context.mapValue(varOp.result(), context.getValue(varOp.operands().getFirst())); List inputOperandsVarOp = varOp.operands(); List outputOperandsVarOp = context.getValues(inputOperandsVarOp); - HATVectorViewOp memoryViewOp = new HATVectorVarOp(varOp.varName(), varOp.resultType(), 4, outputOperandsVarOp); + HATVectorOp memoryViewOp = new HATVectorVarOp(varOp.varName(), varOp.resultType(), 4, outputOperandsVarOp); Op.Result hatLocalResult = blockBuilder.op(memoryViewOp); memoryViewOp.setLocation(varOp.location()); context.mapValue(varOp.result(), hatLocalResult); @@ -200,6 +204,51 @@ private CoreOp.FuncOp dialectifyVectorLoad(CoreOp.FuncOp funcOp) { return funcOp; } + private CoreOp.FuncOp dialectifyVectorOf(CoreOp.FuncOp funcOp) { + var here = OpTk.CallSite.of(this.getClass(), "dialectifyVectorOf" ); + before(here,funcOp); + Stream> float4NodesInvolved = funcOp.elements() + .mapMulti((codeElement, consumer) -> { + if (codeElement instanceof JavaOp.InvokeOp invokeOp) { + if (isVectorOperation(invokeOp)) { + consumer.accept(invokeOp); + Set uses = invokeOp.result().uses(); + for (Op.Result result : uses) { + if (result.op() instanceof CoreOp.VarOp varOp) { + consumer.accept(varOp); + } + } + } + } + }); + + Set> nodesInvolved = float4NodesInvolved.collect(Collectors.toSet()); + + funcOp = OpTk.transform(here, funcOp,(blockBuilder, op) -> { + CopyContext context = blockBuilder.context(); + if (!nodesInvolved.contains(op)) { + blockBuilder.op(op); + } else if (op instanceof JavaOp.InvokeOp invokeOp) { + List inputOperandsVarOp = invokeOp.operands(); + List outputOperandsVarOp = context.getValues(inputOperandsVarOp); + HATVectorOfOp memoryViewOp = new HATVectorOfOp(invokeOp.resultType(), 4, outputOperandsVarOp); + Op.Result hatLocalResult = blockBuilder.op(memoryViewOp); + memoryViewOp.setLocation(invokeOp.location()); + context.mapValue(invokeOp.result(), hatLocalResult); + } else if (op instanceof CoreOp.VarOp varOp) { + List inputOperandsVarOp = varOp.operands(); + List outputOperandsVarOp = context.getValues(inputOperandsVarOp); + HATVectorOp memoryViewOp = new HATVectorVarOp(varOp.varName(), varOp.resultType(), 4, outputOperandsVarOp); + Op.Result hatLocalResult = blockBuilder.op(memoryViewOp); + memoryViewOp.setLocation(varOp.location()); + context.mapValue(varOp.result(), hatLocalResult); + } + return blockBuilder; + }); + after(here, funcOp); + return funcOp; + } + private CoreOp.FuncOp dialectifyVectorBinaryOps(CoreOp.FuncOp funcOp) { var here = OpTk.CallSite.of(this.getClass(), "dialectifyVectorBinaryOps"); before(here, funcOp); @@ -228,9 +277,6 @@ private CoreOp.FuncOp dialectifyVectorBinaryOps(CoreOp.FuncOp funcOp) { funcOp = OpTk.transform(here, funcOp, nodesInvolved::contains, (blockBuilder, op) -> { CopyContext context = blockBuilder.context(); - // if (!nodesInvolved.contains(op)) { - // blockBuilder.op(op); - //} else if (op instanceof JavaOp.InvokeOp invokeOp) { Op.Result result = invokeOp.result(); List inputOperands = invokeOp.operands(); @@ -239,7 +285,7 @@ private CoreOp.FuncOp dialectifyVectorBinaryOps(CoreOp.FuncOp funcOp) { for (Op.Result r : collect) { if (r.op() instanceof CoreOp.VarOp varOp) { HATVectorBinaryOp.OpType binaryOpType = binaryOperation.get(invokeOp); - HATVectorViewOp memoryViewOp = buildVectorBinaryOp(binaryOpType, varOp.varName(), invokeOp.resultType(), outputOperands); + HATVectorOp memoryViewOp = buildVectorBinaryOp(binaryOpType, varOp.varName(), invokeOp.resultType(), outputOperands); Op.Result hatVectorOpResult = blockBuilder.op(memoryViewOp); memoryViewOp.setLocation(varOp.location()); context.mapValue(invokeOp.result(), hatVectorOpResult); @@ -249,7 +295,7 @@ private CoreOp.FuncOp dialectifyVectorBinaryOps(CoreOp.FuncOp funcOp) { } else if (op instanceof CoreOp.VarOp varOp) { List inputOperandsVarOp = varOp.operands(); List outputOperandsVarOp = context.getValues(inputOperandsVarOp); - HATVectorViewOp memoryViewOp = new HATVectorVarOp(varOp.varName(), varOp.resultType(), 4, outputOperandsVarOp); + HATVectorOp memoryViewOp = new HATVectorVarOp(varOp.varName(), varOp.resultType(), 4, outputOperandsVarOp); Op.Result hatVectorResult = blockBuilder.op(memoryViewOp); memoryViewOp.setLocation(varOp.location()); context.mapValue(varOp.result(), hatVectorResult); @@ -260,7 +306,53 @@ private CoreOp.FuncOp dialectifyVectorBinaryOps(CoreOp.FuncOp funcOp) { return funcOp; } - private CoreOp.FuncOp dialectifyVectorBinaryWithContatenationOps(CoreOp.FuncOp funcOp) { + private CoreOp.FuncOp dialectifyMutableOf(CoreOp.FuncOp funcOp) { + var here = OpTk.CallSite.of(this.getClass(), "dialectifyMutableOf" ); + before(here,funcOp); + Stream> float4NodesInvolved = funcOp.elements() + .mapMulti((codeElement, consumer) -> { + if (codeElement instanceof JavaOp.InvokeOp invokeOp) { + if (isVectorOperation(invokeOp)) { + consumer.accept(invokeOp); + Set uses = invokeOp.result().uses(); + for (Op.Result result : uses) { + if (result.op() instanceof CoreOp.VarOp varOp) { + consumer.accept(varOp); + } + } + } + } + }); + + Set> nodesInvolved = float4NodesInvolved.collect(Collectors.toSet()); + + funcOp = OpTk.transform(here, funcOp,(blockBuilder, op) -> { + CopyContext context = blockBuilder.context(); + if (!nodesInvolved.contains(op)) { + blockBuilder.op(op); + } else if (op instanceof JavaOp.InvokeOp invokeOp) { + List inputOperandsVarOp = invokeOp.operands(); + List outputOperandsVarOp = context.getValues(inputOperandsVarOp); + String varName = findNameVector(invokeOp.operands().getFirst()); + HATVectorMakeOfOp makeOf = new HATVectorMakeOfOp(varName, invokeOp.resultType(), 4, outputOperandsVarOp); + Op.Result hatLocalResult = blockBuilder.op(makeOf); + makeOf.setLocation(invokeOp.location()); + context.mapValue(invokeOp.result(), hatLocalResult); + } else if (op instanceof CoreOp.VarOp varOp) { + List inputOperandsVarOp = varOp.operands(); + List outputOperandsVarOp = context.getValues(inputOperandsVarOp); + HATVectorOp memoryViewOp = new HATVectorVarOp(varOp.varName(), varOp.resultType(), 4, outputOperandsVarOp); + Op.Result hatLocalResult = blockBuilder.op(memoryViewOp); + memoryViewOp.setLocation(varOp.location()); + context.mapValue(varOp.result(), hatLocalResult); + } + return blockBuilder; + }); + after(here, funcOp); + return funcOp; + } + + private CoreOp.FuncOp dialectifyVectorBinaryWithConcatenationOps(CoreOp.FuncOp funcOp) { var here = OpTk.CallSite.of(this.getClass(), "dialectifyBinaryWithConcatenation"); before(here, funcOp); Map binaryOperation = new HashMap<>(); @@ -299,7 +391,7 @@ private CoreOp.FuncOp dialectifyVectorBinaryWithContatenationOps(CoreOp.FuncOp f } else if (op instanceof JavaOp.InvokeOp invokeOp) { List inputOperands = invokeOp.operands(); List outputOperands = context.getValues(inputOperands); - HATVectorViewOp memoryViewOp = buildVectorBinaryOp(binaryOperation.get(invokeOp), "null", invokeOp.resultType(), outputOperands); + HATVectorOp memoryViewOp = buildVectorBinaryOp(binaryOperation.get(invokeOp), "null", invokeOp.resultType(), outputOperands); Op.Result hatVectorOpResult = blockBuilder.op(memoryViewOp); memoryViewOp.setLocation(invokeOp.location()); context.mapValue(invokeOp.result(), hatVectorOpResult); @@ -307,7 +399,7 @@ private CoreOp.FuncOp dialectifyVectorBinaryWithContatenationOps(CoreOp.FuncOp f List inputOperandsVarLoad = varLoadOp.operands(); List outputOperandsVarLoad = context.getValues(inputOperandsVarLoad); String varLoadName = findNameVector(varLoadOp); - HATVectorViewOp memoryViewOp = new HATVectorVarLoadOp(varLoadName, varLoadOp.resultType(), outputOperandsVarLoad); + HATVectorOp memoryViewOp = new HATVectorVarLoadOp(varLoadName, varLoadOp.resultType(), outputOperandsVarLoad); Op.Result hatVectorResult = blockBuilder.op(memoryViewOp); memoryViewOp.setLocation(varLoadOp.location()); context.mapValue(varLoadOp.result(), hatVectorResult); @@ -322,10 +414,14 @@ private CoreOp.FuncOp dialectifyVectorBinaryWithContatenationOps(CoreOp.FuncOp f public CoreOp.FuncOp apply(CoreOp.FuncOp funcOp) { if (Objects.requireNonNull(vectorOperation) == OpView.FLOAT4_LOAD) { funcOp = dialectifyVectorLoad(funcOp); + } else if (Objects.requireNonNull(vectorOperation) == OpView.OF) { + funcOp = dialectifyVectorOf(funcOp); + } else if (Objects.requireNonNull(vectorOperation) == OpView.MAKE_MUTABLE) { + funcOp = dialectifyMutableOf(funcOp); } else { // Find binary operations funcOp = dialectifyVectorBinaryOps(funcOp); - funcOp = dialectifyVectorBinaryWithContatenationOps(funcOp); + funcOp = dialectifyVectorBinaryWithConcatenationOps(funcOp); } return funcOp; } @@ -344,13 +440,27 @@ public DivPhase(Accelerator accelerator) { } } - public static class Float4LoadPhase extends HATDialectifyVectorOpPhase{ + public static class MakeMutable extends HATDialectifyVectorOpPhase{ + + public MakeMutable(Accelerator accelerator) { + super(accelerator, OpView.MAKE_MUTABLE); + } + } + + public static class Float4LoadPhase extends HATDialectifyVectorOpPhase { public Float4LoadPhase(Accelerator accelerator) { super(accelerator, OpView.FLOAT4_LOAD); } } + public static class Float4OfPhase extends HATDialectifyVectorOpPhase { + + public Float4OfPhase(Accelerator accelerator) { + super(accelerator, OpView.OF); + } + } + public static class MulPhase extends HATDialectifyVectorOpPhase{ public MulPhase(Accelerator accelerator) { diff --git a/hat/core/src/main/java/hat/phases/HATDialectifyVectorSelectPhase.java b/hat/core/src/main/java/hat/phases/HATDialectifyVectorSelectPhase.java index ac476a86e4b..0d50b4c3d7c 100644 --- a/hat/core/src/main/java/hat/phases/HATDialectifyVectorSelectPhase.java +++ b/hat/core/src/main/java/hat/phases/HATDialectifyVectorSelectPhase.java @@ -27,7 +27,7 @@ import hat.Accelerator; import hat.dialect.HATVectorSelectLoadOp; import hat.dialect.HATVectorSelectStoreOp; -import hat.dialect.HATVectorViewOp; +import hat.dialect.HATVectorOp; import hat.optools.OpTk; import jdk.incubator.code.CodeElement; import jdk.incubator.code.CopyContext; @@ -85,7 +85,7 @@ private String findNameVector(Value v) { if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { return findNameVector(varLoadOp); } else { - if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorViewOp vectorViewOp) { + if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorOp vectorViewOp) { return vectorViewOp.varName(); } return null; @@ -141,7 +141,7 @@ private CoreOp.FuncOp vloadSelectPhase(CoreOp.FuncOp funcOp) { if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { List outputOperandsInvokeOp = context.getValues(inputInvokeOp); int lane = getLane(invokeOp.invokeDescriptor().name()); - HATVectorViewOp vSelectOp; + HATVectorOp vSelectOp; String name = findNameVector(varLoadOp); if (invokeOp.resultType() != JavaType.VOID) { vSelectOp = new HATVectorSelectLoadOp(name, invokeOp.resultType(), lane, outputOperandsInvokeOp); @@ -198,7 +198,7 @@ private CoreOp.FuncOp vstoreSelectPhase(CoreOp.FuncOp funcOp) { if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { List outputOperandsInvokeOp = context.getValues(inputInvokeOp); int lane = getLane(invokeOp.invokeDescriptor().name()); - HATVectorViewOp vSelectOp; + HATVectorOp vSelectOp; String name = findNameVector(varLoadOp); if (invokeOp.resultType() == JavaType.VOID) { // The operand 1 in the store is the address (lane) diff --git a/hat/core/src/main/java/hat/phases/HATDialectifyVectorStorePhase.java b/hat/core/src/main/java/hat/phases/HATDialectifyVectorStorePhase.java index 4ef09df7db2..f6df79913fc 100644 --- a/hat/core/src/main/java/hat/phases/HATDialectifyVectorStorePhase.java +++ b/hat/core/src/main/java/hat/phases/HATDialectifyVectorStorePhase.java @@ -28,7 +28,7 @@ import hat.dialect.HATLocalVarOp; import hat.dialect.HATPrivateVarOp; import hat.dialect.HATVectorStoreView; -import hat.dialect.HATVectorViewOp; +import hat.dialect.HATVectorOp; import hat.optools.OpTk; import jdk.incubator.code.CodeElement; import jdk.incubator.code.CopyContext; @@ -85,8 +85,8 @@ private String findNameVector(CoreOp.VarAccessOp.VarLoadOp varLoadOp) { private String findNameVector(Value v) { if (v instanceof Op.Result r && r.op() instanceof CoreOp.VarAccessOp.VarLoadOp varLoadOp) { return findNameVector(varLoadOp); - } else if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorViewOp hatVectorViewOp) { - return hatVectorViewOp.varName(); + } else if (v instanceof CoreOp.Result r && r.op() instanceof HATVectorOp hatVectorOp) { + return hatVectorOp.varName(); }else{ return null; } @@ -134,8 +134,8 @@ public CoreOp.FuncOp apply(CoreOp.FuncOp funcOp) { String name = findNameVector(v); boolean isSharedOrPrivate = findIsSharedOrPrivateSpace(invokeOp.operands().get(0)); - HATVectorViewOp storeView = switch (vectorOperation) { - case FLOAT4_STORE -> new HATVectorStoreView(name, invokeOp.resultType(), 4, HATVectorViewOp.VectorType.FLOAT4, isSharedOrPrivate, outputOperandsVarOp); + HATVectorOp storeView = switch (vectorOperation) { + case FLOAT4_STORE -> new HATVectorStoreView(name, invokeOp.resultType(), 4, HATVectorOp.VectorType.FLOAT4, isSharedOrPrivate, outputOperandsVarOp); }; Op.Result hatLocalResult = blockBuilder.op(storeView); storeView.setLocation(invokeOp.location()); diff --git a/hat/tests/src/main/java/hat/test/TestVectorTypes.java b/hat/tests/src/main/java/hat/test/TestVectorTypes.java index 0f7a2196e4d..165d9da8f5f 100644 --- a/hat/tests/src/main/java/hat/test/TestVectorTypes.java +++ b/hat/tests/src/main/java/hat/test/TestVectorTypes.java @@ -24,7 +24,12 @@ */ package hat.test; -import hat.*; +import hat.Accelerator; +import hat.ComputeContext; +import hat.ComputeRange; +import hat.GlobalMesh1D; +import hat.KernelContext; +import hat.LocalMesh1D; import hat.backend.Backend; import hat.buffer.Buffer; import hat.buffer.F32ArrayPadded; @@ -56,7 +61,7 @@ public static void vectorOps01(@RO KernelContext kernelContext, @RO F32ArrayPadd public static void vectorOps02(@RO KernelContext kernelContext, @RO F32ArrayPadded a, @RW F32ArrayPadded b) { if (kernelContext.gix < kernelContext.gsx) { int index = kernelContext.gix; - Float4 vA = a.float4View(index * 4); + Float4.MutableImpl vA = a.float4View(index * 4); float scaleX = vA.x() * 10.0f; vA.x(scaleX); b.storeFloat4View(vA, index * 4); @@ -67,16 +72,22 @@ public static void vectorOps02(@RO KernelContext kernelContext, @RO F32ArrayPadd public static void vectorOps03(@RO KernelContext kernelContext, @RO F32ArrayPadded a, @RW F32ArrayPadded b) { if (kernelContext.gix < kernelContext.gsx) { int index = kernelContext.gix; + + // Obtain a view of the input data as a float4 and + // store that view in private memory Float4 vA = a.float4View(index * 4); + + // operate with the float4 float scaleX = vA.x() * 10.0f; float scaleY = vA.y() * 20.0f; float scaleZ = vA.z() * 30.0f; float scaleW = vA.w() * 40.0f; - vA.x(scaleX); - vA.y(scaleY); - vA.z(scaleZ); - vA.w(scaleW); - b.storeFloat4View(vA, index * 4); + + // Create a float4 within the device code + Float4 vResult = Float4.of(scaleX, scaleY, scaleZ, scaleW); + + // store the float4 from private memory to global memory + b.storeFloat4View(vResult, index * 4); } } @@ -84,7 +95,7 @@ public static void vectorOps03(@RO KernelContext kernelContext, @RO F32ArrayPadd public static void vectorOps04(@RO KernelContext kernelContext, @RO F32ArrayPadded a, @RW F32ArrayPadded b) { if (kernelContext.gix < kernelContext.gsx) { int index = kernelContext.gix; - Float4 vA = a.float4View(index * 4); + Float4.MutableImpl vA = a.float4View(index * 4); vA.x(vA.x() * 10.0f); vA.y(vA.y() * 20.0f); vA.z(vA.z() * 30.0f); @@ -230,6 +241,17 @@ public static void vectorOps12(@RO KernelContext kernelContext, @RO F32ArrayPadd } } + @CodeReflection + public static void vectorOps14(@RO KernelContext kernelContext, @RW F32ArrayPadded a) { + if (kernelContext.gix < kernelContext.gsx) { + int index = kernelContext.gix; + Float4 vA = a.float4View(index * 4); + Float4.MutableImpl vB = Float4.makeMutable(vA); + vB.x(10.0f); + a.storeFloat4View(vB, index * 4); + } + } + @CodeReflection public static void computeGraph01(@RO ComputeContext cc, @RO F32ArrayPadded a, @RO F32ArrayPadded b, @RW F32ArrayPadded c, int size) { // Note: we need to launch N threads / vectorWidth -> size / 4 for this example @@ -315,6 +337,13 @@ public static void computeGraph12(@RO ComputeContext cc, @RO F32ArrayPadded a, cc.dispatchKernel(computeRange, kernelContext -> TestVectorTypes.vectorOps12(kernelContext, a, b)); } + @CodeReflection + public static void computeGraph14(@RO ComputeContext cc, @RW F32ArrayPadded a, int size) { + // Note: we need to launch N threads / vectorWidth -> size / 4 for this example + ComputeRange computeRange = new ComputeRange(new GlobalMesh1D(size/4)); + cc.dispatchKernel(computeRange, kernelContext -> TestVectorTypes.vectorOps14(kernelContext, a)); + } + @HatTest public void testVectorTypes01() { final int size = 1024; @@ -569,4 +598,65 @@ public void testVectorTypes12() { HatAsserts.assertEquals(arrayA.array(i), arrayB.array(i), 0.001f); } } + + @HatTest + public void testVectorTypes13() { + // Test the CPU implementation of Float4 + Float4 vA = Float4.of(1, 2, 3, 4); + Float4 vB = Float4.of(4, 3, 2, 1); + Float4 vC = Float4.add(vA, vB); + Float4 expectedSum = Float4.of(vA.x() + vB.x(), + vA.y() + vB.y(), + vA.z() + vB.z(), + vA.w() + vB.w() + ); + HatAsserts.assertEquals(expectedSum, vC, 0.001f); + + Float4 vD = Float4.sub(vA, vB); + Float4 expectedSub = Float4.of( + vA.x() - vB.x(), + vA.y() - vB.y(), + vA.z() - vB.z(), + vA.w() - vB.w() + ); + HatAsserts.assertEquals(expectedSub, vD, 0.001f); + + Float4 vE = Float4.mul(vA, vB); + Float4 expectedMul = Float4.of( + vA.x() * vB.x(), + vA.y() * vB.y(), + vA.z() * vB.z(), + vA.w() * vB.w() + ); + HatAsserts.assertEquals(expectedMul, vE, 0.001f); + + Float4 vF = Float4.div(vA, vB); + Float4 expectedDiv = Float4.of( + vA.x() / vB.x(), + vA.y() / vB.y(), + vA.z() / vB.z(), + vA.w() / vB.w() + ); + HatAsserts.assertEquals(expectedDiv, vF, 0.001f); + } + + @HatTest + public void testVectorTypes14() { + final int size = 1024; + var accelerator = new Accelerator(MethodHandles.lookup(), Backend.FIRST); + var arrayA = F32ArrayPadded.create(accelerator, size); + + Random r = new Random(73); + for (int i = 0; i < size; i++) { + arrayA.array(i, r.nextFloat()); + } + + accelerator.compute(cc -> TestVectorTypes.computeGraph14(cc, arrayA, size)); + + for (int i = 0; i < size; i += 4) { + HatAsserts.assertEquals(10.0f, arrayA.array(i), 0.001f); + } + } + } + diff --git a/hat/tests/src/main/java/hat/test/engine/HatAsserts.java b/hat/tests/src/main/java/hat/test/engine/HatAsserts.java index 853c9f2388b..cee0b7ccdb9 100644 --- a/hat/tests/src/main/java/hat/test/engine/HatAsserts.java +++ b/hat/tests/src/main/java/hat/test/engine/HatAsserts.java @@ -24,6 +24,8 @@ */ package hat.test.engine; +import hat.buffer.Float4; + public class HatAsserts { public static void assertEquals(int expected, int actual) { @@ -50,6 +52,18 @@ public static void assertEquals(double expected, double actual, double delta) { } } + public static void assertEquals(Float4 expected, Float4 actual, float delta) { + float[] arrayExpected = expected.toArray(); + float[] arrayActual = actual.toArray(); + for (int i = 0; i < 4; i++) { + var expectedValue = arrayExpected[i]; + var actualValue = arrayActual[i]; + if (Math.abs(expectedValue - actualValue) > delta) { + throw new HatAssertionError("Expected: " + expectedValue + " != actual: " + actualValue); + } + } + } + public static void assertTrue(boolean isCorrect) { if (!isCorrect) { throw new HatAssertionError("Expected: " + isCorrect); diff --git a/hat/tools/src/main/java/hat/tools/text/JavaHATCodeBuilder.java b/hat/tools/src/main/java/hat/tools/text/JavaHATCodeBuilder.java index 991814b17a6..0f8d5b3ec76 100644 --- a/hat/tools/src/main/java/hat/tools/text/JavaHATCodeBuilder.java +++ b/hat/tools/src/main/java/hat/tools/text/JavaHATCodeBuilder.java @@ -28,6 +28,9 @@ import hat.codebuilders.HATCodeBuilderWithContext; import hat.dialect.HATBlockThreadIdOp; import hat.dialect.HATF16ConvOp; +import hat.dialect.HATVectorMakeOfOp; +import hat.dialect.HATVectorOfOp; +import hat.dialect.HATVectorOp; import hat.dialect.HATVectorSelectLoadOp; import hat.dialect.HATVectorSelectStoreOp; import hat.dialect.HATF16BinaryOp; @@ -203,6 +206,18 @@ public T hatF16ConvOp(ScopedCodeBuilderContext buildContext, HATF16ConvOp hatF16 return self(); } + @Override + public T hatVectorOfOps(ScopedCodeBuilderContext buildContext, HATVectorOfOp hatVectorOp) { + blockComment("Vector Of Ops Not Implemented"); + return self(); + } + + @Override + public T hatVectorMakeOf(ScopedCodeBuilderContext builderContext, HATVectorMakeOfOp hatVectorMakeOfOp) { + blockComment("Vector Make Of Op Not Implemented"); + return self(); + } + public T createJava(ScopedCodeBuilderContext buildContext) { buildContext.funcScope(buildContext.funcOp, () -> { typeName(buildContext.funcOp.resultType().toString()).space().funcName(buildContext.funcOp);