Commit b1c5462
Enable op `Aten::_jagged_to_padded_dense_forward` (#1517)

Co-authored-by: Yutao Xu <[email protected]>

1 parent 07ca845 commit b1c5462
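For orientation, a minimal C++ smoke test of the newly enabled op might look like the sketch below. The shapes, the single jagged dimension, and the expectation that an XPU device is available are illustrative assumptions layered on top of this commit, not code it contains.

#include <ATen/ATen.h>
#include <vector>

int main() {
  // Hypothetical example: 3 jagged rows of lengths 2, 0 and 3 with inner dim 4,
  // so values holds 5 rows and offsets has B + 1 = 4 entries.
  auto opts = at::TensorOptions().device(at::kXPU);
  at::Tensor values = at::randn({5, 4}, opts);
  at::Tensor offsets =
      at::tensor(std::vector<int64_t>{0, 2, 2, 5}, opts.dtype(at::kLong));

  // Pad every row to length 3; the result should have shape [3, 3, 4].
  at::Tensor padded = at::_jagged_to_padded_dense_forward(
      values, {offsets}, /*max_lengths=*/{3}, /*padding_value=*/0.0);
  return padded.size(1) == 3 ? 0 : 1;
}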

File tree

5 files changed, +376 -1 lines changed


src/ATen/native/nested/xpu/NestedTensorTransformerFunctions.cpp

+22
@@ -4,6 +4,8 @@
 #include <ATen/native/nested/NestedTensorUtils.h>
 #include <ATen/native/nested/xpu/sycl/NestedTensorTransformerFunctionKernels.h>
 
+#include <comm/XPUGuard.h>
+
 namespace at::native {
 
 namespace {
@@ -205,4 +207,24 @@ Tensor NestedTensor_to_padded_tensor_xpu(
   return NestedTensor_to_padded_tensor_generic(t, padding, output_size);
 }
 
+at::Tensor _fbgemm_jagged_to_padded_dense_forward(
+    const Tensor& values,
+    TensorList offsets,
+    c10::IntArrayRef max_lengths,
+    const double padding_value) {
+  const size_t num_jagged_dim = offsets.size();
+
+  TORCH_CHECK(
+      max_lengths.size() == num_jagged_dim,
+      "max_lengths.size(), ",
+      max_lengths.size(),
+      " != num_jagged_dim, ",
+      num_jagged_dim);
+  at::xpu::OptionalXPUGuard device_guard;
+  device_guard.set_index(values.get_device());
+
+  return at::native::xpu::_fbgemm_jagged_to_padded_dense_forward_kernel(
+      values, offsets, max_lengths, padding_value);
+}
+
 } // namespace at::native
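As a cross-check for the wrapper above, a plain-ATen CPU reference for the single-jagged-dim case could look like the sketch below. The helper name, the CPU-only restriction, and the clamping of each row to max_length are assumptions made for illustration and are not part of the commit.

#include <ATen/ATen.h>
#include <algorithm>

// Reference: scatter each jagged row into a padded [B, max_length, D] tensor,
// leaving the remaining positions at padding_value.
at::Tensor jagged_to_padded_reference(
    const at::Tensor& values,  // [total_rows, D], CPU
    const at::Tensor& offsets, // [B + 1], int64, CPU
    int64_t max_length,
    double padding_value) {
  const int64_t B = offsets.numel() - 1;
  const int64_t D = values.size(-1);
  at::Tensor out =
      at::full({B, max_length, D}, padding_value, values.options());
  auto off = offsets.accessor<int64_t, 1>();
  for (int64_t b = 0; b < B; ++b) {
    const int64_t len = std::min<int64_t>(off[b + 1] - off[b], max_length);
    if (len > 0) {
      out[b].narrow(0, 0, len).copy_(values.narrow(0, off[b], len));
    }
  }
  return out;
}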

src/ATen/native/nested/xpu/sycl/NestedTensorTransformerFunctionKernels.cpp

+342
@@ -1,4 +1,7 @@
 #include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/core/TensorAccessor.h>
+#include <ATen/native/StridedRandomAccessor.h>
 #include <ATen/native/nested/xpu/sycl/NestedTensorTransformerFunctionKernels.h>
 #include <comm/SYCLContext.h>
 
@@ -613,4 +616,343 @@ void add_padding_kernel(
   });
 }
 
+#define JAGGED_TENSOR_DISPATCH_DIMS()                                         \
+  AT_DISPATCH_INDEX_TYPES(x_offsets[0].scalar_type(), "jagged_indices", [=] { \
+    switch (num_jagged_dim) {                                                 \
+      case 1:                                                                 \
+        INVOKE_KERNEL_WITH_DIM(1);                                            \
+        break;                                                                \
+      case 2:                                                                 \
+        INVOKE_KERNEL_WITH_DIM(2);                                            \
+        break;                                                                \
+      case 3:                                                                 \
+        INVOKE_KERNEL_WITH_DIM(3);                                            \
+        break;                                                                \
+      case 4:                                                                 \
+        INVOKE_KERNEL_WITH_DIM(4);                                            \
+        break;                                                                \
+      case 5:                                                                 \
+        INVOKE_KERNEL_WITH_DIM(5);                                            \
+        break;                                                                \
+      default:                                                                \
+        TORCH_CHECK(                                                          \
+            false, "unsupported number of jagged dim ", num_jagged_dim);      \
+    }                                                                         \
+  });
+
+inline std::string torch_tensor_device_name(const at::Tensor& ten) {
+  return c10::DeviceTypeName(ten.device().type());
+}
+
+inline std::string torch_tensor_device_name(
+    const std::optional<at::Tensor>& ten) {
+  if (ten.has_value()) {
+    return torch_tensor_device_name(ten.value());
+  } else {
+    return "N/A";
+  }
+}
+
+inline bool torch_tensor_on_xpu_gpu_check(const at::Tensor& ten) {
+  return ten.is_xpu();
+}
+
+inline bool torch_tensor_on_xpu_gpu_check(
+    const std::optional<at::Tensor>& ten) {
+  return !ten.has_value() || torch_tensor_on_xpu_gpu_check(ten.value());
+}
+
+#define TENSOR_ON_XPU_GPU(x)                                  \
+  TORCH_CHECK(                                                \
+      torch_tensor_on_xpu_gpu_check(x),                       \
+      #x " must be a XPU tensor; it is currently on device ", \
+      torch_tensor_device_name(x))
+
+// A wrapper class for passing dynamically sized dimension information (e.g.
+// tensor.dims()) from the host to device.
+constexpr size_t kStackArrayMaxDims = 5;
+
+template <typename T>
+struct StackArray {
+  T vals[kStackArrayMaxDims];
+  size_t ndim;
+};
+
+template <typename scalar_t>
+struct PaddingValueFuncutor {
+  scalar_t operator()(scalar_t x, scalar_t /*unused*/) const {
+    return x;
+  }
+};
+
+// Subgroup size
+static constexpr int32_t kSubgroupSize = 32;
+// Max thread num in one thread workgroup
+static constexpr int32_t kMaxThreads = 1024;
+
+inline int32_t div_round_up(int32_t a, int32_t b) {
+  return (a + b - 1) / b;
+}
+
+inline int32_t round_down(int32_t a, int32_t b) {
+  return a / b * b;
+}
+
+inline std::tuple<sycl::range<2>, sycl::range<2>, StackArray<int64_t>>
+check_shape_and_partition_(
+    const Tensor& values,
+    const std::vector<Tensor>& offsets,
+    const Tensor& dense_tensor) {
+  const int outer_dense_size = dense_tensor.size(0);
+  TORCH_CHECK(
+      outer_dense_size == offsets[0].numel() - 1,
+      "outer_dense_size, ",
+      outer_dense_size,
+      " != offsets[0].numel() - 1, ",
+      offsets[0].numel() - 1);
+  const int inner_dense_size = dense_tensor.size(-1);
+  TORCH_CHECK(
+      inner_dense_size == values.size(-1),
+      "inner_dense_size, ",
+      inner_dense_size,
+      " != values.size(-1), ",
+      values.size(-1));
+  const int jagged_folded_size =
+      dense_tensor.numel() / (outer_dense_size * inner_dense_size);
+
+  const int wg_size_x =
+      inner_dense_size >= kSubgroupSize / 2 ? kSubgroupSize : inner_dense_size;
+  const int wg_size_y = kMaxThreads / kSubgroupSize;
+  const int num_group =
+      div_round_up(outer_dense_size * jagged_folded_size, wg_size_y);
+
+  StackArray<int64_t> jagged_dims_tensor{};
+  const int num_jagged_dim = dense_tensor.dim() - 2;
+  TORCH_CHECK(num_jagged_dim <= static_cast<int>(kStackArrayMaxDims));
+  jagged_dims_tensor.ndim = num_jagged_dim;
+  std::memcpy(
+      &(jagged_dims_tensor.vals[0]),
+      dense_tensor.sizes().data() + 1,
+      num_jagged_dim * sizeof(int64_t));
+  return {
+      sycl::range<2>(wg_size_x, wg_size_y),
+      sycl::range<2>(num_group * wg_size_x, wg_size_y),
+      jagged_dims_tensor};
+}
+
+template <int NUM_JAGGED_DIM, typename index_t>
+inline bool walk_down_tensor_storage_tree_(
+    int& offset,
+    const int flattened_jagged_idx,
+    const StackArray<int64_t>& jagged_dims,
+    const StackArray<index_t*>& x_offsets) {
+  // compute coordinates
+  int jagged_coords[NUM_JAGGED_DIM];
+  int j_temp = flattened_jagged_idx;
+#pragma unroll
+  for (int d = NUM_JAGGED_DIM - 1; d >= 0; --d) {
+    const int jagged_size = jagged_dims.vals[d];
+    jagged_coords[d] = j_temp % jagged_size;
+    j_temp /= jagged_size;
+  }
+
+  // walk down the tree
+  bool is_zero = false;
+#pragma unroll
+  for (int d = 0; d < NUM_JAGGED_DIM; ++d) {
+    const int begin = x_offsets.vals[d][offset];
+    const int end = x_offsets.vals[d][offset + 1];
+    if (jagged_coords[d] >= end - begin) {
+      is_zero = true;
+      break;
+    }
+    offset = begin + jagged_coords[d];
+  }
+  return is_zero;
+}
+
+template <int NUM_JAGGED_DIM, typename index_t, typename scalar_t, typename F>
+struct JaggedDenseElementwiseDenseFunctor {
+  void operator()(sycl::nd_item<2> item) const {
+    const int outer_dense_size = y_.size(0);
+    const int jagged_folded_size = y_.size(1);
+    const int inner_dense_size = y_.size(2);
+    auto output = output_;
+    const int outer_begin =
+        item.get_group(0) * item.get_local_range(1) + item.get_local_id(1);
+    const int outer_stride = item.get_group_range(0) * item.get_local_range(1);
+    for (int outer = outer_begin; outer < outer_dense_size * jagged_folded_size;
+         outer += outer_stride) {
+      const int oidx = outer / jagged_folded_size;
+      const int jidx = outer % jagged_folded_size;
+
+      int offset = oidx;
+      const bool is_zero = walk_down_tensor_storage_tree_<NUM_JAGGED_DIM>(
+          offset, jidx, jagged_dims_, x_offsets_);
+
+      if (is_zero) {
+        int iidx;
+        for (iidx = item.get_local_id(0); iidx * 2 + 1 < inner_dense_size;
+             iidx += item.get_local_range(0)) {
+          output[oidx][jidx][2 * iidx] =
+              f_(padding_value_, y_[oidx][jidx][2 * iidx]);
+          output[oidx][jidx][2 * iidx + 1] =
+              f_(padding_value_, y_[oidx][jidx][2 * iidx + 1]);
+        }
+        if (iidx * 2 + 1 == inner_dense_size) {
+          output[oidx][jidx][2 * iidx] =
+              f_(padding_value_, y_[oidx][jidx][2 * iidx]);
+        }
+      } else {
+        int iidx;
+        for (iidx = item.get_local_id(0); iidx * 2 + 1 < inner_dense_size;
+             iidx += item.get_local_range(0)) {
+          output[oidx][jidx][2 * iidx] =
+              f_(x_values_[offset][2 * iidx], y_[oidx][jidx][2 * iidx]);
+          output[oidx][jidx][2 * iidx + 1] =
+              f_(x_values_[offset][2 * iidx + 1], y_[oidx][jidx][2 * iidx + 1]);
+        }
+        if (iidx * 2 + 1 == inner_dense_size) {
+          output[oidx][jidx][2 * iidx] =
+              f_(x_values_[offset][2 * iidx], y_[oidx][jidx][2 * iidx]);
+        }
+      }
+    }
+  }
+  JaggedDenseElementwiseDenseFunctor(
+      const at::PackedTensorAccessor32<scalar_t, 2, RestrictPtrTraits> x_values,
+      StackArray<index_t*> x_offsets,
+      const at::PackedTensorAccessor32<scalar_t, 3, RestrictPtrTraits> y,
+      at::PackedTensorAccessor32<scalar_t, 3, RestrictPtrTraits> output,
+      StackArray<int64_t> jagged_dims,
+      F f,
+      const scalar_t padding_value)
+      : x_values_(x_values),
+        x_offsets_(x_offsets),
+        y_(y),
+        output_(output),
+        jagged_dims_(jagged_dims),
+        f_(f),
+        padding_value_(padding_value) {}
+
+ private:
+  const at::PackedTensorAccessor32<scalar_t, 2, RestrictPtrTraits> x_values_;
+  StackArray<index_t*> x_offsets_;
+  const at::PackedTensorAccessor32<scalar_t, 3, RestrictPtrTraits> y_;
+  at::PackedTensorAccessor32<scalar_t, 3, RestrictPtrTraits> output_;
+  StackArray<int64_t> jagged_dims_;
+  F f_;
+  const scalar_t padding_value_;
+};
+
+template <typename scalar_t, typename F>
+void jagged_dense_elementwise_dense_template(
+    const Tensor& x_values,
+    const std::vector<Tensor>& x_offsets,
+    const Tensor& y,
+    const Tensor& output,
+    F f,
+    const scalar_t padding_value = static_cast<scalar_t>(0)) {
+  TENSOR_ON_XPU_GPU(x_values);
+  for (auto& x_offset : x_offsets) {
+    TENSOR_ON_XPU_GPU(x_offset);
+  }
+
+  const int num_jagged_dim = y.dim() - 2;
+  TORCH_CHECK(
+      x_offsets.size() == static_cast<size_t>(num_jagged_dim),
+      "x_offsets.size(), ",
+      x_offsets.size(),
+      " != num_jagged_dim ",
+      num_jagged_dim);
+
+  if (y.numel() == 0) {
+    return;
+  }
+
+  sycl::range<2> global_range, local_range;
+  StackArray<int64_t> jagged_dims_tensor;
+  std::tie(local_range, global_range, jagged_dims_tensor) =
+      check_shape_and_partition_(x_values, x_offsets, y);
+
+  // Canonicalize y and output to 3D, collapsing jagged dimensions.
+  const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)});
+  Tensor output_reshaped = output.view(y_reshaped.sizes());
+
+#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM)                                 \
+  {                                                                            \
+    std::vector<Tensor> x_offsets_contig;                                      \
+    x_offsets_contig.resize(num_jagged_dim);                                   \
+    StackArray<index_t*> x_offset_ptrs;                                        \
+    x_offset_ptrs.ndim = num_jagged_dim;                                       \
+    for (int d = 0; d < num_jagged_dim; ++d) {                                 \
+      x_offsets_contig[d] = x_offsets[d].contiguous();                         \
+      x_offset_ptrs.vals[d] =                                                  \
+          x_offsets_contig[d].template data_ptr<index_t>();                    \
+    }                                                                          \
+    auto kfn = JaggedDenseElementwiseDenseFunctor<                             \
+        NUM_JAGGED_DIM,                                                        \
+        index_t,                                                               \
+        scalar_t,                                                              \
+        F>(                                                                    \
+        x_values.packed_accessor32<scalar_t, 2, RestrictPtrTraits>(),          \
+        x_offset_ptrs,                                                         \
+        y_reshaped.packed_accessor32<scalar_t, 3, RestrictPtrTraits>(),        \
+        output_reshaped.packed_accessor32<scalar_t, 3, RestrictPtrTraits>(),   \
+        jagged_dims_tensor,                                                    \
+        f,                                                                     \
+        padding_value);                                                        \
+    sycl_kernel_submit(global_range, local_range, getCurrentSYCLQueue(), kfn); \
+  }
+
+  JAGGED_TENSOR_DISPATCH_DIMS();
+
+#undef INVOKE_KERNEL_WITH_DIM
+}
+
+at::Tensor _fbgemm_jagged_to_padded_dense_forward_kernel(
+    const Tensor& values,
+    TensorList offsets,
+    c10::IntArrayRef max_lengths,
+    const double padding_value) {
+  const Tensor values_canonicalized = values.view(
+      {values.size(0),
+       std::accumulate(
+           values.sizes().begin() + 1,
+           values.sizes().end(),
+           1,
+           std::multiplies<size_t>())});
+  at::SymDimVector padded_values_shape({at::SymInt(offsets[0].size(0) - 1)});
+  padded_values_shape.insert(
+      padded_values_shape.end(), max_lengths.begin(), max_lengths.end());
+
+  // Canonicalize padded_values by unsqueezing the last dim if the inner dense
+  // dimension is 1 and folded.
+  const bool D_folded = values.dim() == 1;
+  if (!D_folded) {
+    padded_values_shape.push_back(values.size(-1));
+  }
+  Tensor padded_values =
+      at::empty_symint(padded_values_shape, values.options());
+  Tensor padded_values_view =
+      D_folded ? padded_values.unsqueeze(-1) : padded_values;
+
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      values.scalar_type(),
+      "jagged_to_padded_dense_xpu",
+      [&] {
+        jagged_dense_elementwise_dense_template<scalar_t>(
+            values_canonicalized,
+            offsets.vec(),
+            padded_values_view, // dummy; its values are ignored by the functor
+            padded_values_view,
+            PaddingValueFuncutor<scalar_t>(),
+            static_cast<scalar_t>(padding_value));
+      });
+
+  return padded_values;
+}
+
 } // namespace at::native::xpu
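To make the partitioning in check_shape_and_partition_ above concrete, the launch-geometry arithmetic for one made-up shape works out as in the sketch below; the example sizes are assumptions chosen only to exercise the formulas.

#include <cstdint>
#include <iostream>

int main() {
  constexpr int32_t kSubgroupSize = 32;
  constexpr int32_t kMaxThreads = 1024;

  // Assumed canonicalized dense view [B, max_length, D] = [3, 3, 4].
  const int outer_dense_size = 3;   // B
  const int jagged_folded_size = 3; // product of the padded jagged dims
  const int inner_dense_size = 4;   // D

  // D < kSubgroupSize / 2, so the x extent stays at D.
  const int wg_size_x =
      inner_dense_size >= kSubgroupSize / 2 ? kSubgroupSize : inner_dense_size;
  const int wg_size_y = kMaxThreads / kSubgroupSize; // 32
  const int num_group =
      (outer_dense_size * jagged_folded_size + wg_size_y - 1) / wg_size_y; // 1

  std::cout << "local range  " << wg_size_x << " x " << wg_size_y << "\n"
            << "global range " << num_group * wg_size_x << " x " << wg_size_y
            << "\n";
  return 0;
}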

src/ATen/native/nested/xpu/sycl/NestedTensorTransformerFunctionKernels.h

+6
@@ -51,4 +51,10 @@ TORCH_XPU_API void add_padding_kernel(
     const int batch_size,
     const int output_batch_size);
 
+TORCH_XPU_API at::Tensor _fbgemm_jagged_to_padded_dense_forward_kernel(
+    const Tensor& values,
+    TensorList offsets,
+    c10::IntArrayRef max_lengths,
+    const double padding_value);
+
 } // namespace at::native::xpu
