 #include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/core/TensorAccessor.h>
+#include <ATen/native/StridedRandomAccessor.h>
 #include <ATen/native/nested/xpu/sycl/NestedTensorTransformerFunctionKernels.h>
 #include <comm/SYCLContext.h>
+#include <cstring> // std::memcpy
+#include <numeric> // std::accumulate

@@ -613,4 +616,343 @@ void add_padding_kernel(
   });
 }

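+// Dispatches on the index type of x_offsets[0] and on num_jagged_dim (1-5),
+// expanding a caller-supplied INVOKE_KERNEL_WITH_DIM(N) macro for each case.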
+#define JAGGED_TENSOR_DISPATCH_DIMS() \
+  AT_DISPATCH_INDEX_TYPES(x_offsets[0].scalar_type(), "jagged_indices", [=] { \
+    switch (num_jagged_dim) { \
+      case 1: \
+        INVOKE_KERNEL_WITH_DIM(1); \
+        break; \
+      case 2: \
+        INVOKE_KERNEL_WITH_DIM(2); \
+        break; \
+      case 3: \
+        INVOKE_KERNEL_WITH_DIM(3); \
+        break; \
+      case 4: \
+        INVOKE_KERNEL_WITH_DIM(4); \
+        break; \
+      case 5: \
+        INVOKE_KERNEL_WITH_DIM(5); \
+        break; \
+      default: \
+        TORCH_CHECK( \
+            false, "unsupported number of jagged dim ", num_jagged_dim); \
+    } \
+  });
+
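+// Host-side helpers for validating that input tensors live on an XPU device;
+// absent optional tensors pass the check.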
+inline std::string torch_tensor_device_name(const at::Tensor& ten) {
+  return c10::DeviceTypeName(ten.device().type());
+}
+
+inline std::string torch_tensor_device_name(
+    const std::optional<at::Tensor>& ten) {
+  if (ten.has_value()) {
+    return torch_tensor_device_name(ten.value());
+  } else {
+    return "N/A";
+  }
+}
+
+inline bool torch_tensor_on_xpu_gpu_check(const at::Tensor& ten) {
+  return ten.is_xpu();
+}
+
+inline bool torch_tensor_on_xpu_gpu_check(
+    const std::optional<at::Tensor>& ten) {
+  return !ten.has_value() || torch_tensor_on_xpu_gpu_check(ten.value());
+}
+
+#define TENSOR_ON_XPU_GPU(x) \
+  TORCH_CHECK( \
+      torch_tensor_on_xpu_gpu_check(x), \
+      #x " must be an XPU tensor; it is currently on device ", \
+      torch_tensor_device_name(x))
+
+// A wrapper struct for passing dynamically sized dimension information (e.g.
+// tensor.sizes()) from the host to the device.
+constexpr size_t kStackArrayMaxDims = 5;
+
+template <typename T>
+struct StackArray {
+  T vals[kStackArrayMaxDims];
+  size_t ndim;
+};
+
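+// Elementwise functor that returns x and ignores y; used with
+// jagged_dense_elementwise_dense_template so the output takes the jagged
+// value where one exists and padding_value elsewhere.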
+template <typename scalar_t>
+struct PaddingValueFunctor {
+  scalar_t operator()(scalar_t x, scalar_t /*unused*/) const {
+    return x;
+  }
+};
+
+// Subgroup size
+static constexpr int32_t kSubgroupSize = 32;
+// Maximum number of threads in one work-group
+static constexpr int32_t kMaxThreads = 1024;
+
+inline int32_t div_round_up(int32_t a, int32_t b) {
+  return (a + b - 1) / b;
+}
+
+inline int32_t round_down(int32_t a, int32_t b) {
+  return a / b * b;
+}
+
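+// Validates that the dense tensor's outer/inner sizes match the offsets and
+// values, and derives the SYCL local and global ranges plus the jagged
+// dimension sizes used by the kernels below.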
+inline std::tuple<sycl::range<2>, sycl::range<2>, StackArray<int64_t>>
+check_shape_and_partition_(
+    const Tensor& values,
+    const std::vector<Tensor>& offsets,
+    const Tensor& dense_tensor) {
+  const int outer_dense_size = dense_tensor.size(0);
+  TORCH_CHECK(
+      outer_dense_size == offsets[0].numel() - 1,
+      "outer_dense_size, ",
+      outer_dense_size,
+      " != offsets[0].numel() - 1, ",
+      offsets[0].numel() - 1);
+  const int inner_dense_size = dense_tensor.size(-1);
+  TORCH_CHECK(
+      inner_dense_size == values.size(-1),
+      "inner_dense_size, ",
+      inner_dense_size,
+      " != values.size(-1), ",
+      values.size(-1));
+  const int jagged_folded_size =
+      dense_tensor.numel() / (outer_dense_size * inner_dense_size);
+
+  const int wg_size_x =
+      inner_dense_size >= kSubgroupSize / 2 ? kSubgroupSize : inner_dense_size;
+  const int wg_size_y = kMaxThreads / kSubgroupSize;
+  const int num_group =
+      div_round_up(outer_dense_size * jagged_folded_size, wg_size_y);
+
+  StackArray<int64_t> jagged_dims_tensor{};
+  const int num_jagged_dim = dense_tensor.dim() - 2;
+  TORCH_CHECK(num_jagged_dim <= static_cast<int>(kStackArrayMaxDims));
+  jagged_dims_tensor.ndim = num_jagged_dim;
+  std::memcpy(
+      &(jagged_dims_tensor.vals[0]),
+      dense_tensor.sizes().data() + 1,
+      num_jagged_dim * sizeof(int64_t));
+  return {
+      sycl::range<2>(wg_size_x, wg_size_y),
+      sycl::range<2>(num_group * wg_size_x, wg_size_y),
+      jagged_dims_tensor};
+}
+
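+// Decomposes a flattened jagged index into per-dimension coordinates and
+// follows the per-dimension offsets down to the row in the jagged values
+// tensor, updating `offset` in place. Returns true if the coordinate falls
+// outside the jagged region (i.e. in the padded area).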
+template <int NUM_JAGGED_DIM, typename index_t>
+inline bool walk_down_tensor_storage_tree_(
+    int& offset,
+    const int flattened_jagged_idx,
+    const StackArray<int64_t>& jagged_dims,
+    const StackArray<index_t*>& x_offsets) {
+  // compute coordinates
+  int jagged_coords[NUM_JAGGED_DIM];
+  int j_temp = flattened_jagged_idx;
+#pragma unroll
+  for (int d = NUM_JAGGED_DIM - 1; d >= 0; --d) {
+    const int jagged_size = jagged_dims.vals[d];
+    jagged_coords[d] = j_temp % jagged_size;
+    j_temp /= jagged_size;
+  }
+
+  // walk down the tree
+  bool is_zero = false;
+#pragma unroll
+  for (int d = 0; d < NUM_JAGGED_DIM; ++d) {
+    const int begin = x_offsets.vals[d][offset];
+    const int end = x_offsets.vals[d][offset + 1];
+    if (jagged_coords[d] >= end - begin) {
+      is_zero = true;
+      break;
+    }
+    offset = begin + jagged_coords[d];
+  }
+  return is_zero;
+}
+
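+// SYCL kernel functor computing output = f(x, y) elementwise, where x is a
+// jagged tensor given by x_values/x_offsets and y/output are dense, viewed as
+// (outer_dense_size, jagged_folded_size, inner_dense_size). For positions
+// outside the jagged region, padding_value is used in place of x. Each
+// work-item strides over (outer, jagged) slices and covers the inner
+// dimension two elements at a time.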
+template <int NUM_JAGGED_DIM, typename index_t, typename scalar_t, typename F>
+struct JaggedDenseElementwiseDenseFunctor {
+  void operator()(sycl::nd_item<2> item) const {
+    const int outer_dense_size = y_.size(0);
+    const int jagged_folded_size = y_.size(1);
+    const int inner_dense_size = y_.size(2);
+    auto output = output_;
+    const int outer_begin =
+        item.get_group(0) * item.get_local_range(1) + item.get_local_id(1);
+    const int outer_stride = item.get_group_range(0) * item.get_local_range(1);
+    for (int outer = outer_begin; outer < outer_dense_size * jagged_folded_size;
+         outer += outer_stride) {
+      const int oidx = outer / jagged_folded_size;
+      const int jidx = outer % jagged_folded_size;
+
+      int offset = oidx;
+      const bool is_zero = walk_down_tensor_storage_tree_<NUM_JAGGED_DIM>(
+          offset, jidx, jagged_dims_, x_offsets_);
+
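+      // In the padded region, combine padding_value with y; otherwise read
+      // the jagged row at `offset`. The loop handles two inner elements per
+      // iteration, and the trailing `if` covers an odd-sized tail element.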
+      if (is_zero) {
+        int iidx;
+        for (iidx = item.get_local_id(0); iidx * 2 + 1 < inner_dense_size;
+             iidx += item.get_local_range(0)) {
+          output[oidx][jidx][2 * iidx] =
+              f_(padding_value_, y_[oidx][jidx][2 * iidx]);
+          output[oidx][jidx][2 * iidx + 1] =
+              f_(padding_value_, y_[oidx][jidx][2 * iidx + 1]);
+        }
+        if (iidx * 2 + 1 == inner_dense_size) {
+          output[oidx][jidx][2 * iidx] =
+              f_(padding_value_, y_[oidx][jidx][2 * iidx]);
+        }
+      } else {
+        int iidx;
+        for (iidx = item.get_local_id(0); iidx * 2 + 1 < inner_dense_size;
+             iidx += item.get_local_range(0)) {
+          output[oidx][jidx][2 * iidx] =
+              f_(x_values_[offset][2 * iidx], y_[oidx][jidx][2 * iidx]);
+          output[oidx][jidx][2 * iidx + 1] =
+              f_(x_values_[offset][2 * iidx + 1], y_[oidx][jidx][2 * iidx + 1]);
+        }
+        if (iidx * 2 + 1 == inner_dense_size) {
+          output[oidx][jidx][2 * iidx] =
+              f_(x_values_[offset][2 * iidx], y_[oidx][jidx][2 * iidx]);
+        }
+      }
+    }
+  }
+  JaggedDenseElementwiseDenseFunctor(
+      const at::PackedTensorAccessor32<scalar_t, 2, RestrictPtrTraits> x_values,
+      StackArray<index_t*> x_offsets,
+      const at::PackedTensorAccessor32<scalar_t, 3, RestrictPtrTraits> y,
+      at::PackedTensorAccessor32<scalar_t, 3, RestrictPtrTraits> output,
+      StackArray<int64_t> jagged_dims,
+      F f,
+      const scalar_t padding_value)
+      : x_values_(x_values),
+        x_offsets_(x_offsets),
+        y_(y),
+        output_(output),
+        jagged_dims_(jagged_dims),
+        f_(f),
+        padding_value_(padding_value) {}
+
+ private:
+  const at::PackedTensorAccessor32<scalar_t, 2, RestrictPtrTraits> x_values_;
+  StackArray<index_t*> x_offsets_;
+  const at::PackedTensorAccessor32<scalar_t, 3, RestrictPtrTraits> y_;
+  at::PackedTensorAccessor32<scalar_t, 3, RestrictPtrTraits> output_;
+  StackArray<int64_t> jagged_dims_;
+  F f_;
+  const scalar_t padding_value_;
+};
+
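+// Computes output = f(x, y) for a jagged tensor x (x_values + x_offsets) and a
+// dense tensor y, writing a dense result; positions absent from the jagged
+// tensor use padding_value as x. y and output are viewed as 3D
+// (outer, flattened jagged dims, inner) before the kernel launch.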
+template <typename scalar_t, typename F>
+void jagged_dense_elementwise_dense_template(
+    const Tensor& x_values,
+    const std::vector<Tensor>& x_offsets,
+    const Tensor& y,
+    const Tensor& output,
+    F f,
+    const scalar_t padding_value = static_cast<scalar_t>(0)) {
+  TENSOR_ON_XPU_GPU(x_values);
+  for (auto& x_offset : x_offsets) {
+    TENSOR_ON_XPU_GPU(x_offset);
+  }
+
+  const int num_jagged_dim = y.dim() - 2;
+  TORCH_CHECK(
+      x_offsets.size() == static_cast<size_t>(num_jagged_dim),
+      "x_offsets.size(), ",
+      x_offsets.size(),
+      " != num_jagged_dim ",
+      num_jagged_dim);
+
+  if (y.numel() == 0) {
+    return;
+  }
+
+  sycl::range<2> global_range, local_range;
+  StackArray<int64_t> jagged_dims_tensor;
+  std::tie(local_range, global_range, jagged_dims_tensor) =
+      check_shape_and_partition_(x_values, x_offsets, y);
+
+  // Canonicalize y and output to 3D, collapsing jagged dimensions.
+  const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)});
+  Tensor output_reshaped = output.view(y_reshaped.sizes());
+
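+// For a given compile-time number of jagged dims, makes the offsets
+// contiguous, collects their raw pointers into a StackArray, instantiates the
+// functor with packed accessors, and submits it to the current SYCL queue.
+// Consumed by JAGGED_TENSOR_DISPATCH_DIMS() below.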
+#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \
+  { \
+    std::vector<Tensor> x_offsets_contig; \
+    x_offsets_contig.resize(num_jagged_dim); \
+    StackArray<index_t*> x_offset_ptrs; \
+    x_offset_ptrs.ndim = num_jagged_dim; \
+    for (int d = 0; d < num_jagged_dim; ++d) { \
+      x_offsets_contig[d] = x_offsets[d].contiguous(); \
+      x_offset_ptrs.vals[d] = \
+          x_offsets_contig[d].template data_ptr<index_t>(); \
+    } \
+    auto kfn = JaggedDenseElementwiseDenseFunctor< \
+        NUM_JAGGED_DIM, \
+        index_t, \
+        scalar_t, \
+        F>( \
+        x_values.packed_accessor32<scalar_t, 2, RestrictPtrTraits>(), \
+        x_offset_ptrs, \
+        y_reshaped.packed_accessor32<scalar_t, 3, RestrictPtrTraits>(), \
+        output_reshaped.packed_accessor32<scalar_t, 3, RestrictPtrTraits>(), \
+        jagged_dims_tensor, \
+        f, \
+        padding_value); \
+    sycl_kernel_submit(global_range, local_range, getCurrentSYCLQueue(), kfn); \
+  }
+
+  JAGGED_TENSOR_DISPATCH_DIMS();
+
+#undef INVOKE_KERNEL_WITH_DIM
+}
+
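+// XPU forward kernel for fbgemm's jagged_to_padded_dense: expands the jagged
+// `values` (described by `offsets`) into a dense tensor of shape
+// [batch, max_lengths..., D], filling positions with no jagged entry with
+// `padding_value`.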
+at::Tensor _fbgemm_jagged_to_padded_dense_forward_kernel(
+    const Tensor& values,
+    TensorList offsets,
+    c10::IntArrayRef max_lengths,
+    const double padding_value) {
+  const Tensor values_canonicalized = values.view(
+      {values.size(0),
+       std::accumulate(
+           values.sizes().begin() + 1,
+           values.sizes().end(),
+           1,
+           std::multiplies<size_t>())});
+  at::SymDimVector padded_values_shape({at::SymInt(offsets[0].size(0) - 1)});
+  padded_values_shape.insert(
+      padded_values_shape.end(), max_lengths.begin(), max_lengths.end());
+
+  // Canonicalize padded_values by unsqueezing the last dim if the inner dense
+  // dimension is 1 and folded.
+  const bool D_folded = values.dim() == 1;
+  if (!D_folded) {
+    padded_values_shape.push_back(values.size(-1));
+  }
+  Tensor padded_values =
+      at::empty_symint(padded_values_shape, values.options());
+  Tensor padded_values_view =
+      D_folded ? padded_values.unsqueeze(-1) : padded_values;
+
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      values.scalar_type(),
+      "jagged_to_padded_dense_xpu",
+      [&] {
+        jagged_dense_elementwise_dense_template<scalar_t>(
+            values_canonicalized,
+            offsets.vec(),
+            padded_values_view, // dummy y: ignored by PaddingValueFunctor
+            padded_values_view,
+            PaddingValueFunctor<scalar_t>(),
+            static_cast<scalar_t>(padding_value));
+      });
+
+  return padded_values;
+}
+
 } // namespace at::native::xpu