@@ -1580,3 +1580,298 @@ def fused_rotary_emb(
15801580 outputs = {"q_out" : q_out , "k_out" : k_out , "v_out" : v_out },
15811581 )
15821582 return q_out , k_out , v_out
1583+
1584+
1585+ ########################### split concat ###############################
1586+ split_concat_template = (
1587+ """
1588+ std::vector<paddle::Tensor> ${op_name}_func(
1589+ const paddle::Tensor &x,
1590+ const paddle::Tensor &y) {
1591+
1592+ int batch = x.dims()[0];
1593+
1594+ int seq_qkv = x.dims()[1];
1595+ int seq_eqkv = y.dims()[1];
1596+ int output_hidden = x.dims()[2] / 3;
1597+
1598+
1599+ auto qkv = get_tensor_ptr(x);
1600+ auto eqkv = get_tensor_ptr(y);
1601+
1602+
1603+ auto out0_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place());
1604+ auto out1_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place());
1605+ auto out2_tensor = paddle::empty({batch, seq_qkv+seq_eqkv, output_hidden}, x.dtype(), x.place());
1606+
1607+ auto out0 = get_tensor_ptr(out0_tensor);
1608+ auto out1 = get_tensor_ptr(out1_tensor);
1609+ auto out2 = get_tensor_ptr(out2_tensor);
1610+
1611+
1612+ auto run_stream = out0_tensor.stream();
1613+
1614+ """
1615+ + tune_and_invoke_part
1616+ + """
1617+ return {out0_tensor, out1_tensor, out2_tensor};
1618+ }
1619+
1620+ std::vector<std::vector<int64_t>> ${op_name}_InferShape(
1621+ const std::vector<int64_t>& A_shape, const std::vector<int64_t>& B_shape) {
1622+
1623+ int64_t seq1 = A_shape[1];
1624+ int64_t seq2 = B_shape[1];
1625+ int64_t seq = -1;
1626+ if (seq1 > 0 && seq2 > 0){
1627+ seq = seq1 + seq2;
1628+ }
1629+ std::vector<int64_t> out_shape = {A_shape[0], seq, A_shape[2]/3};
1630+
1631+ return {out_shape, out_shape, out_shape};
1632+ }
1633+
1634+ std::vector<paddle::DataType> ${op_name}_InferDtype(const paddle::DataType& A_dtype) {
1635+ return {A_dtype, A_dtype, A_dtype};
1636+ }
1637+
1638+ PD_BUILD_OP(${op_name})
1639+ .Inputs({"x", "y"})
1640+ .Outputs({"out0_tensor", "out1_tensor", "out2_tensor"})
1641+ .SetKernelFn(PD_KERNEL(${op_name}_func))
1642+ .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
1643+ .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
1644+ """
1645+ )
1646+
1647+
1648+ @paddle_use_triton (
1649+ custom_op_template = split_concat_template ,
1650+ key = ["1" ],
1651+ )
1652+ def split_concat_kernel (
1653+ out0 ,
1654+ out1 ,
1655+ out2 ,
1656+ qkv ,
1657+ eqkv ,
1658+ batch ,
1659+ seq_qkv ,
1660+ seq_eqkv ,
1661+ output_hidden ,
1662+ BLOCK_SIZE : tl .constexpr ,
1663+ ):
1664+ out_id = tl .program_id (axis = 0 )
1665+ batch = tl .program_id (axis = 1 )
1666+ out_row = tl .program_id (axis = 2 )
1667+ if out_row < seq_qkv :
1668+ read_ptr = out_id * output_hidden + out_row * 3 * output_hidden + batch * seq_qkv * output_hidden * 3 + qkv
1669+ else :
1670+ read_ptr = (
1671+ out_id * output_hidden
1672+ + (out_row - seq_qkv ) * 3 * output_hidden
1673+ + batch * seq_eqkv * output_hidden * 3
1674+ + eqkv
1675+ )
1676+
1677+ read_offsets = tl .arange (0 , BLOCK_SIZE )
1678+ mask = read_offsets < output_hidden
1679+ read_data = tl .load (read_ptr + read_offsets , mask = mask )
1680+
1681+ real_output = out0
1682+ if out_id == 1 :
1683+ real_output = out1
1684+ elif out_id == 2 :
1685+ real_output = out2
1686+
1687+ write_ptr = batch * (seq_qkv + seq_eqkv ) * output_hidden + out_row * output_hidden + real_output + read_offsets
1688+
1689+ tl .store (write_ptr , read_data , mask = mask )
1690+
1691+
1692+ def split_concat (x , y ):
1693+ assert len (x .shape ) == 3
1694+ assert len (y .shape ) == 3
1695+
1696+ assert x .shape [0 ] == y .shape [0 ]
1697+ assert x .shape [2 ] == y .shape [2 ]
1698+
1699+ batch = x .shape [0 ]
1700+ seq_qkv = x .shape [1 ]
1701+ hidd_x = x .shape [2 ]
1702+ seq_eqkv = y .shape [1 ]
1703+ ouput_hidden = hidd_x // 3
1704+ BLOCK_SIZE = triton .next_power_of_2 (ouput_hidden )
1705+ op_name = "split_concat"
1706+ op_name += get_dtype_str (x .dtype )
1707+ op_name += f"_{ BLOCK_SIZE } "
1708+
1709+ if op_name not in OpProtoHolder .instance ().op_proto_map .keys ():
1710+ out0 = paddle .empty (shape = [batch , seq_qkv + seq_eqkv , ouput_hidden ], dtype = x .dtype )
1711+ out1 = paddle .empty (shape = [batch , seq_qkv + seq_eqkv , ouput_hidden ], dtype = x .dtype )
1712+ out2 = paddle .empty (shape = [batch , seq_qkv + seq_eqkv , ouput_hidden ], dtype = x .dtype )
1713+ grid = ("3" , "batch" , "seq_qkv + seq_eqkv" )
1714+
1715+ split_concat_kernel [(op_name , grid )](
1716+ out0 , out1 , out2 , x , y , batch , seq_qkv , seq_eqkv , ouput_hidden , BLOCK_SIZE = BLOCK_SIZE
1717+ )
1718+
1719+ if in_dynamic_or_pir_mode ():
1720+ print (f"== we are in dynamic mode, op_name: { op_name } " )
1721+ outs = _C_ops ._run_custom_op (
1722+ op_name ,
1723+ x ,
1724+ y ,
1725+ )
1726+ return outs [0 ], outs [1 ], outs [2 ]
1727+ else :
1728+ print (f"== we are in dynamic to static mode, op_name: { op_name } " )
1729+ helper = LayerHelper (op_name , ** locals ())
1730+ inputs = {
1731+ "x" : x ,
1732+ "y" : y ,
1733+ }
1734+ out0 = helper .create_variable_for_type_inference (dtype = x .dtype )
1735+ out1 = helper .create_variable_for_type_inference (dtype = x .dtype )
1736+ out2 = helper .create_variable_for_type_inference (dtype = x .dtype )
1737+
1738+ helper .append_op (
1739+ type = op_name ,
1740+ inputs = inputs ,
1741+ outputs = {"out0_tensor" : out0 , "out1_tensor" : out1 , "out2_tensor" : out2 },
1742+ )
1743+ return out0 , out1 , out2
1744+
1745+
1746+ ########################### triton split ###############################
1747+ triton_split_template = (
1748+ """
1749+ std::vector<paddle::Tensor> ${op_name}_func(
1750+ const paddle::Tensor &x,
1751+ const std::vector<int64_t> num_or_sections,
1752+ const int64_t axis) {
1753+
1754+ int output_batch = x.dims()[0];
1755+ int output_seq0 = num_or_sections[0];
1756+ int output_seq1 = num_or_sections[1];
1757+ int output_hidden = x.dims()[2];
1758+
1759+ auto out0_tensor = paddle::empty({output_batch, output_seq0, output_hidden}, x.dtype(), x.place());
1760+ auto out1_tensor = paddle::empty({output_batch, output_seq1, output_hidden}, x.dtype(), x.place());
1761+
1762+ auto out0 = get_tensor_ptr(out0_tensor);
1763+ auto out1 = get_tensor_ptr(out1_tensor);
1764+
1765+ auto input = get_tensor_ptr(x);
1766+
1767+ auto run_stream = out0_tensor.stream();
1768+
1769+ """
1770+ + tune_and_invoke_part
1771+ + """
1772+ return {out0_tensor, out1_tensor};
1773+ }
1774+
1775+ std::vector<std::vector<int64_t>> ${op_name}_InferShape(
1776+ const std::vector<int64_t>& A_shape) {
1777+
1778+ std::vector<int64_t> out_shape0 = {A_shape[0], 1024, A_shape[2]};
1779+ std::vector<int64_t> out_shape1 = {A_shape[0], 154, A_shape[2]};
1780+
1781+ return {out_shape0, out_shape1};
1782+ }
1783+
1784+ std::vector<paddle::DataType> ${op_name}_InferDtype(const paddle::DataType& A_dtype) {
1785+ return {A_dtype, A_dtype};
1786+ }
1787+
1788+ PD_BUILD_OP(${op_name})
1789+ .Inputs({"x"})
1790+ .Outputs({"out0_tensor", "out1_tensor"})
1791+ .SetKernelFn(PD_KERNEL(${op_name}_func))
1792+ .Attrs({"num_or_sections: std::vector<int64_t>", "axis: int64_t"})
1793+ .SetInferDtypeFn(PD_INFER_DTYPE(${op_name}_InferDtype))
1794+ .SetInferShapeFn(PD_INFER_SHAPE(${op_name}_InferShape));
1795+ """
1796+ )
1797+
1798+
1799+ @paddle_use_triton (
1800+ custom_op_template = triton_split_template ,
1801+ key = ["1" ],
1802+ )
1803+ def triton_split_kernel (
1804+ out0 ,
1805+ out1 ,
1806+ input ,
1807+ output_seq0 ,
1808+ output_seq1 ,
1809+ output_batch ,
1810+ output_hidden ,
1811+ BLOCK_SIZE : tl .constexpr ,
1812+ ):
1813+ batch = tl .program_id (axis = 0 )
1814+ out_row = tl .program_id (axis = 1 )
1815+ read_ptr = out_row * output_hidden + batch * (output_seq0 + output_seq1 ) * output_hidden + input
1816+
1817+ read_offsets = tl .arange (0 , BLOCK_SIZE )
1818+ mask = read_offsets < output_hidden
1819+ read_data = tl .load (read_ptr + read_offsets , mask = mask )
1820+
1821+ if out_row < output_seq0 :
1822+ write_ptr = batch * output_seq0 * output_hidden + out_row * output_hidden + out0 + read_offsets
1823+ else :
1824+ write_ptr = batch * output_seq1 * output_hidden + (out_row - output_seq0 ) * output_hidden + out1 + read_offsets
1825+
1826+ tl .store (write_ptr , read_data , mask = mask )
1827+
1828+
1829+ def triton_split (x , num_or_sections = [- 1 , - 1 ], axis = 1 ):
1830+ assert len (x .shape ) == 3
1831+ output_batch = x .shape [0 ]
1832+ output_seq0 = num_or_sections [0 ]
1833+ output_seq1 = num_or_sections [1 ]
1834+ output_hidden = x .shape [2 ]
1835+
1836+ BLOCK_SIZE = triton .next_power_of_2 (output_hidden )
1837+ op_name = "triton_split"
1838+ op_name += get_dtype_str (x .dtype )
1839+ op_name += f"_{ BLOCK_SIZE } "
1840+
1841+ if op_name not in OpProtoHolder .instance ().op_proto_map .keys ():
1842+ out0 = paddle .empty (shape = [output_batch , output_seq0 , output_hidden ], dtype = x .dtype )
1843+ out1 = paddle .empty (shape = [output_batch , output_seq1 , output_hidden ], dtype = x .dtype )
1844+ grid = ("output_batch" , "output_seq0+output_seq1" )
1845+
1846+ triton_split_kernel [(op_name , grid )](
1847+ out0 , out1 , x , output_seq0 , output_seq1 , output_batch , output_hidden , BLOCK_SIZE = 2048
1848+ )
1849+
1850+ if in_dynamic_or_pir_mode ():
1851+ print (f"== we are in dynamic mode, op_name: { op_name } " )
1852+ outs = _C_ops ._run_custom_op (
1853+ op_name ,
1854+ x ,
1855+ num_or_sections ,
1856+ axis ,
1857+ )
1858+ return outs [0 ], outs [1 ]
1859+ else :
1860+ print (f"== we are in dynamic to static mode, op_name: { op_name } " )
1861+ helper = LayerHelper (op_name , ** locals ())
1862+ inputs = {
1863+ "x" : x ,
1864+ }
1865+ out0 = helper .create_variable_for_type_inference (dtype = x .dtype )
1866+ out1 = helper .create_variable_for_type_inference (dtype = x .dtype )
1867+
1868+ helper .append_op (
1869+ type = op_name ,
1870+ inputs = inputs ,
1871+ attrs = {
1872+ "num_or_sections" : num_or_sections ,
1873+ "axis" : axis ,
1874+ },
1875+ outputs = {"out0_tensor" : out0 , "out1_tensor" : out1 },
1876+ )
1877+ return out0 , out1
0 commit comments