Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions KLR/Trace/ISA.lean
Original file line number Diff line number Diff line change
Expand Up @@ -170,24 +170,48 @@ nki builtin.isa.nc_transpose
match engine with
| .pe =>
let N := data.shapePure.freeDims.getLast!
let id : TensorRef := <- match <- lookup_global? (.num `identity 0) with
| some (.access acc) => return .abstract acc
| some _ => throw "identity has wrong type"
| none => throw "identity not defined"
let idName : TensorName <- match id with
| .abstract $ .simple t => pure t
| .abstract $ .basic t => pure t.tensor
| .abstract $ .pattern t => pure t.tensor
| _ => throw "Expected identity matrix to be a ref"
let idSlice : TensorRef := .abstract $ .basic $ <- AccessBasic.make idName [
let dtype := data.tensor.dtype
-- Create identity matrix inline with the data's dtype
-- using memset(0) + affine_select, matching beta3 approach.
let idShape := Core.Shape.mk 128 [128]
let idName := (<- genName).toString
let idAddr : Address := {
name := idName,
memory := .sbuf,
parSize := 128
freeSize := 128 * dtype.size
}
let idTensor <- Core.TensorName.make idName dtype idShape (some idAddr) (<- flags.address_rotation)
let idRef : TensorRef := .abstract (.simple idTensor)
-- Zero the identity tensor
Trace.add_stmt $ .oper (.memSet {
dst := idRef,
value := .float 0.0,
dtype := dtype,
engine := .unassigned
}) (<- genLabel `memset_id)
-- Write 1.0 on the diagonal using affine_select:
-- pattern [[0,1],[0,1],[0,1],[1,128]] with channel_multiplier=-1
-- produces (free_idx - partition_idx); cmp_op=not_equal keeps
-- on_true_tile (zeros) off-diagonal, writes on_false_value (1.0)
-- on the diagonal where the expression equals 0.
Trace.add_stmt $ .oper (.ncAffineSelect {
dst := idRef,
pred := ⟨0, [⟨0, 1, 0⟩, ⟨0, 1, 0⟩, ⟨0, 1, 0⟩, ⟨1, 128, 0⟩], -1⟩,
onTrueTile := idRef,
onFalseValue := .float 1.0,
dtype := some dtype,
cmpOp := .not_equal,
}) (<- genLabel `affsel_id)
let idSlice : TensorRef := .abstract $ .basic $ <- AccessBasic.make idTensor [
.slice $ Slice.make! 0 N 1,
.slice $ Slice.make! 0 N 1
]
Trace.add_stmt $ .oper (.ncMatMul {
dst := .abstract dst,
stationary := idSlice,
moving := .abstract data,
isStationaryOneZero := false,
isStationaryOneZero := true,
isMovingZero := false,
isTranspose := true,
tilePosition := [],
Expand Down
1 change: 0 additions & 1 deletion KLR/Trace/NKI.lean
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,6 @@ partial def lowerRes (t: Term) : Trace (List Core.Access) := do

def traceKernel (k : Kernel) : Trace Core.Kernel := do
let _ <- beginBlock (<- genLabel `main)
addId
globals k
flags k.flags
match k.funs.find? fun f => f.name == k.entry with
Expand Down
Loading