Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,13 @@ public CudaHATKernelBuilder hatVectorStoreOp(ScopedCodeBuilderContext buildConte
}

csbrace().cparen().osbrace().intConstZero().csbrace()
.space().equals().space()
.varName(hatVectorStoreView);
.space().equals().space();
// if the value to be stored is an operation, recurse on the operation
if (hatVectorStoreView.operands().get(1) instanceof Op.Result r && r.op() instanceof HATVectorBinaryOp) {
recurse(buildContext, r.op());
} else {
varName(hatVectorStoreView);
}

return self();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,14 @@ public OpenCLHATKernelBuilder hatVectorStoreOp(ScopedCodeBuilderContext buildCon
Value index = hatVectorStoreView.operands().get(2);

identifier("vstore" + hatVectorStoreView.vectorN())
.oparen()
.varName(hatVectorStoreView)
.comma()
.space()
.oparen();
// if the value to be stored is an operation, recurse on the operation
if (hatVectorStoreView.operands().get(1) instanceof Op.Result r && r.op() instanceof HATVectorBinaryOp) {
recurse(buildContext, r.op());
} else {
varName(hatVectorStoreView);
}
comma().space()
.intConstZero()
.comma()
.space()
Expand Down Expand Up @@ -156,17 +160,23 @@ public OpenCLHATKernelBuilder hatVectorLoadOp(ScopedCodeBuilderContext buildCont

@Override
public OpenCLHATKernelBuilder hatSelectLoadOp(ScopedCodeBuilderContext buildContext, HATVectorSelectLoadOp hatVSelectLoadOp) {
identifier(hatVSelectLoadOp.varName())
.dot()
.identifier(hatVSelectLoadOp.mapLane());
if (hatVSelectLoadOp.operands().getFirst() instanceof Op.Result res && res.op() instanceof HATVectorLoadOp vLoadOp) {
recurse(buildContext, vLoadOp);
} else {
identifier(hatVSelectLoadOp.varName());
}
dot().identifier(hatVSelectLoadOp.mapLane());
return self();
}

@Override
public OpenCLHATKernelBuilder hatSelectStoreOp(ScopedCodeBuilderContext buildContext, HATVectorSelectStoreOp hatVSelectStoreOp) {
identifier(hatVSelectStoreOp.varName())
.dot()
.identifier(hatVSelectStoreOp.mapLane())
if (hatVSelectStoreOp.operands().getFirst() instanceof Op.Result res && res.op() instanceof HATVectorLoadOp vLoadOp) {
recurse(buildContext, vLoadOp);
} else {
identifier(hatVSelectStoreOp.varName());
}
dot().identifier(hatVSelectStoreOp.mapLane())
.space().equals().space();
if (hatVSelectStoreOp.resultValue() != null) {
// We have detected a direct resolved result (resolved name)
Expand Down
6 changes: 2 additions & 4 deletions hat/core/src/main/java/hat/buffer/F32ArrayPadded.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,8 @@ default F32ArrayPadded copyTo(float[] floats) {
return this;
}

default float[] arrayView() {
float[] arr = new float[this.length()];
this.copyTo(arr);
return arr;
default Float4.MutableImpl[] float4ArrayView() {
return null;
}

default Float4.MutableImpl float4View(int index) {
Expand Down
159 changes: 1 addition & 158 deletions hat/core/src/main/java/hat/callgraph/KernelCallGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,7 @@ public Stream<KernelReachableResolvedMethodCall> kernelReachableResolvedStream()
initialModuleOp.functionTable().forEach((_, accessableFuncOp) ->
initialFuncOps.add( tier.apply(accessableFuncOp))
);
CoreOp.ModuleOp interiModuleOp = CoreOp.module(initialFuncOps);
CoreOp.FuncOp interimEntrypointFuncOp = convertArrayViewForFunc(computeContext.accelerator.lookup, entrypoint.funcOp());
entrypoint.funcOp(interimEntrypointFuncOp);


List<CoreOp.FuncOp> interimFuncOps = new ArrayList<>();
interiModuleOp.functionTable().forEach((_, accessableFuncOp) ->
interimFuncOps.add(convertArrayViewForFunc(computeContext.accelerator.lookup, accessableFuncOp))
);
setModuleOp(CoreOp.module(interimFuncOps));
setModuleOp(CoreOp.module(initialFuncOps));
}
/*
* A ResolvedKernelMethodCall (entrypoint or java method reachable from a compute entrypojnt) has the following calls
Expand Down Expand Up @@ -227,152 +218,4 @@ public void noconvertArrayView() {
// }
}
*/
public CoreOp.FuncOp convertArrayViewForFunc(MethodHandles.Lookup l, CoreOp.FuncOp entry) {
if (!OpTk.isArrayView(l, entry)) return entry;
usesArrayView = true;
// maps a replaced result to the result it should be replaced by
Map<Op.Result, Op.Result> replaced = new HashMap<>();
Map<Op, CoreOp.VarAccessOp.VarLoadOp> bufferVarLoads = new HashMap<>();

return entry.transform(entry.funcName(), (bb, op) -> {
switch (op) {
case JavaOp.InvokeOp iop -> {
if (OpTk.isBufferArray(iop) &&
OpTk.firstOperand(iop) instanceof Op.Result r) { // ensures we can use iop as key for replaced vvv
replaced.put(iop.result(), r);
bufferVarLoads.put(((Op.Result) OpTk.firstOperand(r.op())).op(), (CoreOp.VarAccessOp.VarLoadOp) r.op()); // map buffer VarOp to its corresponding VarLoadOp
return bb;
}
}
case CoreOp.VarOp vop -> {
if (OpTk.isBufferInitialize(vop) &&
OpTk.firstOperand(vop) instanceof Op.Result r) { // makes sure we don't process a new int[] for example
Op bufferLoad = replaced.get(r).op(); // gets the VarLoadOp associated w/ og buffer
replaced.put(vop.result(), (Op.Result) OpTk.firstOperand(bufferLoad)); // gets VarOp associated w/ og buffer
return bb;
}
}
case CoreOp.VarAccessOp.VarLoadOp vlop -> {
if (OpTk.isBufferInitialize(vlop) &&
OpTk.firstOperand(vlop) instanceof Op.Result r) {
if (r.op() instanceof CoreOp.VarOp) { // if this is the VarLoadOp after the .arrayView() InvokeOp
Op.Result replacement = (OpTk.notGlobalVarOp(vlop)) ?
(Op.Result) OpTk.firstOperand(((Op.Result) OpTk.firstOperand(r.op())).op()) :
bufferVarLoads.get(replaced.get(r).op()).result();
replaced.put(vlop.result(), replacement);
} else { // if this is a VarLoadOp loading in the buffer
Value loaded = OpTk.getValue(bb, replaced.get(r));
Op.Result newVlop = bb.op(CoreOp.VarAccessOp.varLoad(loaded));
bb.context().mapValue(vlop.result(), newVlop);
replaced.put(vlop.result(), newVlop);
}
return bb;
}
}
// handles only 1D and 2D arrays
case JavaOp.ArrayAccessOp.ArrayLoadOp alop -> {
if (OpTk.isBufferArray(alop) &&
OpTk.firstOperand(alop) instanceof Op.Result r) {
Op.Result buffer = replaced.getOrDefault(r, r);
if (((ArrayType) OpTk.firstOperand(op).type()).dimensions() == 1) { // we ignore the first array[][] load if using 2D arrays
if (r.op() instanceof JavaOp.ArrayAccessOp.ArrayLoadOp rowOp) {
// idea: we want to calculate the idx for the buffer access
// idx = (long) (((long) rowOp.idx * (long) buffer.width()) + alop.idx)
Op.Result x = (Op.Result) OpTk.getValue(bb, rowOp.operands().getLast());
Op.Result y = (Op.Result) OpTk.getValue(bb, alop.operands().getLast());
Op.Result ogBufferLoad = replaced.get((Op.Result) OpTk.firstOperand(rowOp));
Op.Result ogBuffer = replaced.getOrDefault((Op.Result) OpTk.firstOperand(ogBufferLoad.op()), (Op.Result) OpTk.firstOperand(ogBufferLoad.op()));
Op.Result bufferLoad = bb.op(CoreOp.VarAccessOp.varLoad(OpTk.getValue(bb, ogBuffer)));

Class<?> c = (Class<?>) OpTk.classTypeToTypeOrThrow(l, (ClassType) ((VarType) ogBuffer.type()).valueType());
MethodRef m = MethodRef.method(c, "width", int.class);
Op.Result width = bb.op(JavaOp.invoke(m, OpTk.getValue(bb, bufferLoad)));
Op.Result longX = bb.op(JavaOp.conv(JavaType.LONG, x));
Op.Result longY = bb.op(JavaOp.conv(JavaType.LONG, y));
Op.Result longWidth = bb.op(JavaOp.conv(JavaType.LONG, OpTk.getValue(bb, width)));
Op.Result mul = bb.op(JavaOp.mul(OpTk.getValue(bb, longY), OpTk.getValue(bb, longWidth)));
Op.Result idx = bb.op(JavaOp.add(OpTk.getValue(bb, longX), OpTk.getValue(bb, mul)));

Class<?> storedClass = OpTk.primitiveTypeToClass(alop.result().type());
MethodRef arrayMethod = MethodRef.method(c, "array", storedClass, long.class);
Op.Result invokeRes = bb.op(JavaOp.invoke(arrayMethod, OpTk.getValue(bb, ogBufferLoad), OpTk.getValue(bb, idx)));
bb.context().mapValue(alop.result(), invokeRes);
} else {
JavaOp.ConvOp conv = JavaOp.conv(JavaType.LONG, OpTk.getValue(bb, alop.operands().get(1)));
Op.Result convRes = bb.op(conv);

Class<?> c = (Class<?>) OpTk.classTypeToTypeOrThrow(l, (ClassType) buffer.type());
Class<?> storedClass = OpTk.primitiveTypeToClass(alop.result().type());
MethodRef m = MethodRef.method(c, "array", storedClass, long.class);
Op.Result invokeRes = bb.op(JavaOp.invoke(m, OpTk.getValue(bb, buffer), convRes));
bb.context().mapValue(alop.result(), invokeRes);
}
}
}
return bb;
}
// handles only 1D and 2D arrays
case JavaOp.ArrayAccessOp.ArrayStoreOp asop -> {
if (OpTk.isBufferArray( asop) &&
OpTk.firstOperand(asop) instanceof Op.Result r) {
Op.Result buffer = replaced.getOrDefault(r, r);
if (((ArrayType) OpTk.firstOperand(op).type()).dimensions() == 1) { // we ignore the first array[][] load if using 2D arrays
if (r.op() instanceof JavaOp.ArrayAccessOp.ArrayLoadOp rowOp) {
Op.Result x = (Op.Result) rowOp.operands().getLast();
Op.Result y = (Op.Result) asop.operands().get(1);
Op.Result ogBufferLoad = replaced.get((Op.Result) OpTk.firstOperand(rowOp));
Op.Result ogBuffer = replaced.getOrDefault((Op.Result) OpTk.firstOperand(ogBufferLoad.op()), (Op.Result) OpTk.firstOperand(ogBufferLoad.op()));
Op.Result bufferLoad = bb.op(CoreOp.VarAccessOp.varLoad(OpTk.getValue(bb, ogBuffer)));
Op.Result computed = (Op.Result) asop.operands().getLast();

Class<?> c = (Class<?>) OpTk.classTypeToTypeOrThrow(l, (ClassType) ((VarType) ogBuffer.type()).valueType());
MethodRef m = MethodRef.method(c, "width", int.class);
Op.Result width = bb.op(JavaOp.invoke(m, OpTk.getValue(bb, bufferLoad)));
Op.Result longX = bb.op(JavaOp.conv(JavaType.LONG, OpTk.getValue(bb, x)));
Op.Result longY = bb.op(JavaOp.conv(JavaType.LONG, OpTk.getValue(bb, y)));
Op.Result longWidth = bb.op(JavaOp.conv(JavaType.LONG, OpTk.getValue(bb, width)));
Op.Result mul = bb.op(JavaOp.mul(OpTk.getValue(bb, longY), OpTk.getValue(bb, longWidth)));
Op.Result idx = bb.op(JavaOp.add(OpTk.getValue(bb, longX), OpTk.getValue(bb, mul)));

MethodRef arrayMethod = MethodRef.method(c, "array", void.class, long.class, int.class);
Op.Result invokeRes = bb.op(JavaOp.invoke(arrayMethod, OpTk.getValue(bb, ogBufferLoad), OpTk.getValue(bb, idx), OpTk.getValue(bb, computed)));
bb.context().mapValue(asop.result(), invokeRes);
} else {
Op.Result idx = bb.op(JavaOp.conv(JavaType.LONG, OpTk.getValue(bb, asop.operands().get(1))));
Value val = OpTk.getValue(bb, asop.operands().getLast());

boolean noRootVlop = (buffer.op() instanceof CoreOp.VarOp);
ClassType classType = (noRootVlop) ?
(ClassType) ((CoreOp.VarOp) buffer.op()).varValueType() :
(ClassType) buffer.type();

Class<?> c = (Class<?>) OpTk.classTypeToTypeOrThrow(l, classType);
Class<?> storedClass = OpTk.primitiveTypeToClass(val.type());
MethodRef m = MethodRef.method(c, "array", void.class, long.class, storedClass);
Op.Result invokeRes = (noRootVlop) ?
bb.op(JavaOp.invoke(m, OpTk.getValue(bb, r), idx, val)) :
bb.op(JavaOp.invoke(m, OpTk.getValue(bb, buffer), idx, val));
bb.context().mapValue(asop.result(), invokeRes);
}
}
}
return bb;
}
case JavaOp.ArrayLengthOp alen -> {
if (OpTk.isBufferArray(alen) &&
OpTk.firstOperand(alen) instanceof Op.Result r) {
Op.Result buffer = replaced.get(r);
Class<?> c = (Class<?>) OpTk.classTypeToTypeOrThrow(l, (ClassType) buffer.type());
MethodRef m = MethodRef.method(c, "length", int.class);
Op.Result invokeRes = bb.op(JavaOp.invoke(m, OpTk.getValue(bb, buffer)));
bb.context().mapValue(alen.result(), invokeRes);
}
return bb;
}
default -> {}
}
bb.op(op);
return bb;
});
}
}
Loading