2 changes: 2 additions & 0 deletions .gitignore
@@ -2,3 +2,5 @@
*/*/__pycache__/
*/*/*/*/__pycache__/
*/*/*/*/*/__pycache__/
*.pyc
temp/
153 changes: 153 additions & 0 deletions reflexion_oneshot_tritonbench_iter1_4/exec/embedding_triton_kernel.py
@@ -0,0 +1,153 @@
import torch
import triton
import triton.language as tl

@triton.jit
def embedding_kernel(weight, input_ids, out, vob_start_id, vob_end_id, stride_weight_row, stride_out_row, n_ctx, hidden_size, BLOCK_N: tl.constexpr, BLOCK_DMODEL: tl.constexpr):
    # Each program handles BLOCK_N consecutive tokens.
    pid = tl.program_id(0)
    offs_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    mask_n = offs_n < n_ctx
    # Out-of-range lanes load vob_end_id so they fail the validity check below.
    token_ids_raw = tl.load(input_ids + offs_n, mask=mask_n, other=vob_end_id)
    # Ids outside [vob_start_id, vob_end_id) produce zero rows in the output.
    valid_id_mask = (token_ids_raw >= vob_start_id) & (token_ids_raw < vob_end_id)
    token_ids_clamped = tl.where(valid_id_mask, token_ids_raw - vob_start_id, 0)
    offs_vec = token_ids_clamped[:, None] * stride_weight_row + offs_d[None, :]
    load_mask = valid_id_mask[:, None] & (offs_d[None, :] < hidden_size)
    vec = tl.load(weight + offs_vec, mask=load_mask, other=0.0)
    vec = tl.where(valid_id_mask[:, None], vec, 0.0)
    dest_offs = offs_n[:, None] * stride_out_row + offs_d[None, :]
    store_mask = mask_n[:, None] & (offs_d[None, :] < hidden_size)
    tl.store(out + dest_offs, vec, mask=store_mask)

@torch.no_grad()
def embedding(input_ids: torch.Tensor, weight: torch.Tensor, vob_start_id: int, vob_end_id: int, out: torch.Tensor):
    assert input_ids.ndim == 1
    assert weight.ndim == 2
    assert out.ndim == 2 and out.shape[0] == input_ids.shape[0] and out.shape[1] == weight.shape[1]
    n_ctx = input_ids.shape[0]
    # BLOCK_DMODEL must be a power of two for tl.arange; excess lanes are masked off.
    BLOCK_DMODEL = triton.next_power_of_2(weight.shape[1])
    BLOCK_N = 128
    grid = (triton.cdiv(n_ctx, BLOCK_N),)
    embedding_kernel[grid](weight, input_ids, out, vob_start_id, vob_end_id, weight.stride(0), out.stride(0), n_ctx, weight.shape[1], BLOCK_N=BLOCK_N, BLOCK_DMODEL=BLOCK_DMODEL, num_warps=4, num_stages=1)
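
# A minimal reference check (a sketch, not part of the original file): for ids
# inside [vob_start_id, vob_end_id) the kernel should match a plain gather on
# the offset ids, with rows for out-of-range ids zeroed. `_embedding_reference`
# is a hypothetical helper added here for illustration only.
def _embedding_reference(input_ids, weight, vob_start_id, vob_end_id):
    valid = (input_ids >= vob_start_id) & (input_ids < vob_end_id)
    ids = torch.where(valid, input_ids - vob_start_id, torch.zeros_like(input_ids))
    ref = weight[ids.long()]   # advanced indexing returns a copy, safe to mutate
    ref[~valid] = 0.0          # invalid ids map to zero rows, as in the kernel
    return ref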

##################################################################################################################################################


import torch


def test_embedding():
    # Parameter definitions
    vocab_size = 1000       # vocabulary size
    embedding_dim = 512     # embedding dimension
    sequence_length = 128   # input sequence length
    vob_start_id = 10       # vocabulary start ID
    vob_end_id = 1000       # vocabulary end ID

    # Create test input tensors
    input_ids = torch.randint(
        vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda'
    )
    weight = torch.randn(
        vocab_size, embedding_dim, dtype=torch.float32, device='cuda'
    )
    out = torch.zeros(
        sequence_length, embedding_dim, dtype=torch.float32, device='cuda'
    )

    # Call the embedding function
    embedding(input_ids, weight, vob_start_id, vob_end_id, out)

    # Save the result
    results = {}
    results['test_case_1'] = out.clone()

    # Test with different input
    input_ids = torch.randint(
        vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda'
    )
    embedding(input_ids, weight, vob_start_id, vob_end_id, out)
    results['test_case_2'] = out.clone()

    # Test with a different vocabulary range
    vob_start_id = 0
    vob_end_id = 500
    input_ids = torch.randint(
        vob_start_id, vob_end_id, (sequence_length,), dtype=torch.int32, device='cuda'
    )
    embedding(input_ids, weight, vob_start_id, vob_end_id, out)
    results['test_case_3'] = out.clone()

    # Test with a different embedding dimension
    embedding_dim = 256
    weight = torch.randn(
        vocab_size, embedding_dim, dtype=torch.float32, device='cuda'
    )
    out = torch.zeros(
        sequence_length, embedding_dim, dtype=torch.float32, device='cuda'
    )
    embedding(input_ids, weight, vob_start_id, vob_end_id, out)
    results['test_case_4'] = out.clone()

    return results


result_gold = test_embedding()
144 changes: 144 additions & 0 deletions reflexion_oneshot_tritonbench_iter1_4/exec/flash_decode2_phi.py
@@ -0,0 +1,144 @@
import torch
import triton
import triton.language as tl

@triton.jit
def _fwd_kernel_flash_decode_stage2(B_Seqlen, Mid_O, Mid_O_LogExpSum, Out, stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, stride_obs, stride_oh, stride_od, head_dim, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    block_n_size = (cur_batch_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ
    sum_exp = 0.0
    max_logic = -float('inf')
    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
    # Merge the per-block partial outputs with a numerically stable
    # online-softmax recurrence over the sequence blocks.
    for block_seq_n in range(0, block_n_size):
        tv = tl.load(Mid_O + cur_batch * stride_mid_ob + cur_head * stride_mid_oh + block_seq_n * stride_mid_os + offs_d, mask=offs_d < head_dim, other=0.0)
        tlogic = tl.load(Mid_O_LogExpSum + cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + block_seq_n)
        new_max_logic = tl.maximum(tlogic, max_logic)
        # Rescale the running accumulator whenever a new maximum appears.
        old_scale = tl.exp(max_logic - new_max_logic)
        acc *= old_scale
        exp_logic = tl.exp(tlogic - new_max_logic)
        acc += exp_logic * tv
        sum_exp = sum_exp * old_scale + exp_logic
        max_logic = new_max_logic
    tl.store(Out + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp, mask=offs_d < head_dim)

@torch.no_grad()
def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, Out, block_seq):
    Lk = mid_out.shape[-1]
    head_dim = Lk
    batch, head_num = mid_out.shape[0], mid_out.shape[1]
    BLOCK_DMODEL = triton.next_power_of_2(head_dim)
    grid = (batch, head_num)
    _fwd_kernel_flash_decode_stage2[grid](B_Seqlen, mid_out, mid_out_logexpsum, Out, mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3), mid_out_logexpsum.stride(0), mid_out_logexpsum.stride(1), mid_out_logexpsum.stride(2), Out.stride(0), Out.stride(1), Out.stride(2), head_dim, BLOCK_SEQ=block_seq, BLOCK_DMODEL=BLOCK_DMODEL, num_warps=4, num_stages=2)
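
# A minimal reference check (a sketch, not part of the original file): the
# online recurrence above is algebraically a softmax over the per-block
# log-exp-sums, used as merge weights for the partial outputs. So for each
# (batch, head) the kernel should agree with this explicit computation.
# `_stage2_reference` is a hypothetical helper and assumes the number of
# blocks implied by b_seqlen/block_seq does not exceed mid_out's block axis.
def _stage2_reference(mid_out, mid_out_logexpsum, b_seqlen, block_seq):
    batch, head_num, _, head_dim = mid_out.shape
    out = torch.empty((batch, head_num, head_dim), dtype=mid_out.dtype, device=mid_out.device)
    for b in range(batch):
        n_blocks = (int(b_seqlen[b]) + block_seq - 1) // block_seq
        logics = mid_out_logexpsum[b, :, :n_blocks]            # (head, n_blocks)
        w = torch.softmax(logics, dim=-1)                      # merge weights per block
        out[b] = (w[:, :, None] * mid_out[b, :, :n_blocks]).sum(dim=1)
    return out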

##################################################################################################################################################


import torch


# Define the test function
def test_flash_decode_stage2():
    # Define the parameters for different test cases
    batch_size = 2
    head_num = 4
    seq_block_num = 3
    head_dim = 64
    block_seq = 16

    test_cases = {
        "test_case_1": {
            "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),
            "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),
            "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),
            "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),
            "block_seq": block_seq
        },
        "test_case_2": {
            "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),
            "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),
            "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),
            "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),
            "block_seq": block_seq + 1  # Different block size
        },
        "test_case_3": {
            "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),
            "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),
            "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),
            "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),
            "block_seq": block_seq // 2  # Different block size
        },
        "test_case_4": {
            "B_Seqlen": torch.randint(1, seq_block_num * block_seq, (batch_size,), dtype=torch.int32, device='cuda'),
            "mid_out": torch.randn((batch_size, head_num, seq_block_num, head_dim), dtype=torch.float32, device='cuda'),
            "mid_out_logexpsum": torch.randn((batch_size, head_num, seq_block_num), dtype=torch.float32, device='cuda'),
            "Out": torch.zeros((batch_size, head_num, head_dim), dtype=torch.float32, device='cuda'),
            "block_seq": block_seq * 2  # Different block size
        }
    }

    # Execute the function for all test cases
    results = {}
    for key, test_case in test_cases.items():
        flash_decode_stage2(test_case["mid_out"], test_case["mid_out_logexpsum"], test_case["B_Seqlen"], test_case["Out"], test_case["block_seq"])
        results[key] = test_case["Out"]

    return results


# Run the test
result_gold = test_flash_decode_stage2()
@@ -0,0 +1,20 @@
{
"speed_up": [
3.2874,
4.8564,
2.4925,
2.0037,
1.809,
5.3753,
0.6082
],
"efficiency": [
88.5857,
67.5919,
1.5125,
1.4315,
1.076,
72.3781,
13.8734
]
}