Commit 243bf9c

Merge branch 'main' into slogdet
2 parents: 6bd93e0 + f977292

31 files changed: +1732 -326 lines

.github/scripts/env.sh (+2)

@@ -14,6 +14,8 @@ export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="\
 impi-devel==2021.14.2 |\
 oneccl-devel==2021.14.1 |\
 mkl-devel==2025.0.1 |\
+onemkl-sycl-blas==2025.0.1 |\
 onemkl-sycl-dft==2025.0.1 |\
+onemkl-sycl-lapack==2025.0.1 |\
 tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.2 \
 "

src/ATen/native/nested/xpu/NestedTensorTransformerFunctions.cpp (+1 -1)

@@ -54,7 +54,7 @@ Tensor nested_from_padded_xpu(

   Tensor metadata =
       at::cat({target_size_sizes, padded_sizes_tensor, target_offsets});
-  metadata = metadata.to(at::Device(kCUDA), kInt, true, true);
+  metadata = metadata.to(at::Device(kXPU), kInt, true, true);

   auto output_size_ptr = metadata.data_ptr<int>();
   auto input_size_ptr = output_size_ptr + target_size_sizes.numel();
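The one-line fix above matters because the metadata tensor is consumed by an XPU kernel: copying it to a CUDA device (left over from the ported CUDA code) targets the wrong backend. A minimal sketch of the corrected call, assuming an XPU-enabled ATen build; the wrapper function is illustrative only, not part of the diff:

#include <ATen/ATen.h>

// Illustrative wrapper: move packed size/offset metadata to the XPU device
// as int32, asynchronously and as a fresh copy, matching the fixed line.
at::Tensor metadata_to_xpu(const at::Tensor& metadata) {
  return metadata.to(
      at::Device(at::kXPU), at::kInt, /*non_blocking=*/true, /*copy=*/true);
}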

src/ATen/native/sparse/xpu/sycl/SparseCsrTensorMathKernels.cpp (-1)

@@ -248,7 +248,6 @@ Tensor reduce_sparse_csr_dim1_xpu_template(
   auto ioptions = crow_indices.options();
   Tensor values = sparse.values();
   auto nrows = sparse.size(0);
-  auto numel = values.numel();

   Tensor new_crow_indices = at::empty({crow_indices.numel()}, ioptions);
   Tensor new_col_indices = at::empty({}, ioptions);

src/ATen/native/sparse/xpu/sycl/SparseSoftmaxKernels.cpp (+2 -3)

@@ -380,7 +380,7 @@ Tensor get_offsets(
   }
   // auto strides = host_strides;
   auto strides = at::empty({ndim}, indices.options());
-  auto strides_ptr = strides.data_ptr<int64_t>();
+  // auto strides_ptr = strides.data_ptr<int64_t>();

   // syclMemcpyAsync(
   //     strides_ptr,
@@ -392,11 +392,10 @@ Tensor get_offsets(
     strides[kk] = host_strides[kk];
   }

-  auto indices_accessor = indices.packed_accessor64<int64_t, 2>();
+  // auto indices_accessor = indices.packed_accessor64<int64_t, 2>();
   Tensor offsets = at::ones({nnz}, indices.options());

   for (int i = 0; i < nnz; i++) {
-    int64_t pool_index = 0;
     for (int64_t j = 0; j < ndim; j++) {
       if (j != dim) {
         offsets[i] += (strides[j] * indices[j][i]);

src/ATen/native/xpu/BatchNorm.cpp (+6 -10)

@@ -240,7 +240,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_xpu(
   c10::MaybeOwned<Tensor> weight_maybe_owned =
       at::borrow_from_optional_tensor(weight_opt);
   const Tensor& weight = *weight_maybe_owned;
-  const Tensor& bias = c10::value_or_else(bias_opt, [] { return Tensor(); });
+  const Tensor& bias = bias_opt.value_or(Tensor());
   Tensor reserve;

   reserve = at::empty({0}, input.options().dtype(kByte));
@@ -284,7 +284,7 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> _batch_norm_with_update_xpu_out(
   c10::MaybeOwned<Tensor> weight_maybe_owned =
       at::borrow_from_optional_tensor(weight_opt);
   const Tensor& weight = *weight_maybe_owned;
-  const Tensor& bias = c10::value_or_else(bias_opt, [] { return Tensor(); });
+  const Tensor& bias = bias_opt.value_or(Tensor());

   std::tie(out, save_mean, save_var) = xpu::batch_norm_kernel(
       input,
@@ -315,14 +315,10 @@ std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_xpu(
     double eps,
     std::array<bool, 3> grad_input_mask,
     const Tensor& reserve) {
-  const Tensor& running_mean =
-      c10::value_or_else(running_mean_opt, [] { return Tensor(); });
-  const Tensor& running_var =
-      c10::value_or_else(running_var_opt, [] { return Tensor(); });
-  const Tensor& save_mean =
-      c10::value_or_else(save_mean_opt, [] { return Tensor(); });
-  const Tensor& save_var =
-      c10::value_or_else(save_var_opt, [] { return Tensor(); });
+  const Tensor& running_mean = running_mean_opt.value_or(Tensor());
+  const Tensor& running_var = running_var_opt.value_or(Tensor());
+  const Tensor& save_mean = save_mean_opt.value_or(Tensor());
+  const Tensor& save_var = save_var_opt.value_or(Tensor());
   return xpu::batch_norm_backward_kernel(
       grad_output,
       input,
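All of the hunks in this file make the same mechanical substitution: the c10::value_or_else helper, which takes a fallback-producing lambda, is replaced by std::optional::value_or with a default-constructed (undefined) Tensor. A minimal sketch of the pattern, using a hypothetical free function purely for illustration:

#include <ATen/core/Tensor.h>
#include <optional>

// Hypothetical free function, for illustration only.
void use_optional_bias(const std::optional<at::Tensor>& bias_opt) {
  // Old: const at::Tensor& bias =
  //          c10::value_or_else(bias_opt, [] { return at::Tensor(); });
  // New: std::optional::value_or with an undefined Tensor as the fallback;
  // lifetime extension keeps the temporary alive for the reference's scope,
  // and callers keep testing bias.defined() as before.
  const at::Tensor& bias = bias_opt.value_or(at::Tensor());
  (void)bias;
}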

src/ATen/native/xpu/Copy.cpp (+4 -4)

@@ -249,7 +249,7 @@ void _copy_xpu(TensorIterator& iter, bool non_blocking) {
  if (copy_kind == _H2D_) {
    if (at::detail::getXPUHooks().isPinnedPtr(src)) {
      q.memcpy(dst, src, nbytes);
-      at::xpu::CachingHostAllocator_recordEvent(
+      at::getHostAllocator(at::kXPU)->record_event(
          const_cast<void*>(src),
          iter.tensor(1).storage().data_ptr().get_context(),
          at::xpu::getCurrentXPUStream());
@@ -259,7 +259,7 @@ void _copy_xpu(TensorIterator& iter, bool non_blocking) {
      // by CPU tensor factory won't be cached in CPU allocator. When host
      // memory is freed with CPU tensor dtor at the end of train main loop,
      // but the corresponding H2D copy might not have been executed yet.
-      auto stage_mem_dptr = at::xpu::HostAlloc(nbytes);
+      auto stage_mem_dptr = at::getHostAllocator(at::kXPU)->allocate(nbytes);
      void* stage_mem = stage_mem_dptr.get();
      if (!stage_mem) {
        throw std::runtime_error(
@@ -268,15 +268,15 @@ void _copy_xpu(TensorIterator& iter, bool non_blocking) {

      std::memcpy(stage_mem, src, nbytes);
      q.memcpy(dst, stage_mem, nbytes);
-      at::xpu::CachingHostAllocator_recordEvent(
+      at::getHostAllocator(at::kXPU)->record_event(
          const_cast<void*>(stage_mem),
          stage_mem_dptr.get_context(),
          at::xpu::getCurrentXPUStream());
    }
  } else {
    q.memcpy(dst, src, nbytes);
    if (at::detail::getXPUHooks().isPinnedPtr(dst)) {
-      at::xpu::CachingHostAllocator_recordEvent(
+      at::getHostAllocator(at::kXPU)->record_event(
          const_cast<void*>(dst),
          iter.tensor(0).storage().data_ptr().get_context(),
          at::xpu::getCurrentXPUStream());
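These hunks swap the XPU-specific host-allocator helpers (at::xpu::HostAlloc, at::xpu::CachingHostAllocator_recordEvent) for the device-generic at::getHostAllocator(at::kXPU) interface. A sketch of the staged host-to-device branch in the new style; it assumes the same headers and calling context as _copy_xpu, and the helper name is made up for illustration:

// Sketch only; headers and the SYCL queue come from Copy.cpp.
void staged_h2d_copy(sycl::queue& q, void* dst, const void* src, size_t nbytes) {
  auto host_allocator = at::getHostAllocator(at::kXPU);

  // Pinned staging buffer owned by a DataPtr; it returns to the cache when
  // the DataPtr is destroyed.
  auto stage_mem_dptr = host_allocator->allocate(nbytes);
  void* stage_mem = stage_mem_dptr.get();

  std::memcpy(stage_mem, src, nbytes); // host -> pinned staging buffer
  q.memcpy(dst, stage_mem, nbytes);    // async pinned -> device copy

  // Record the consuming stream so the caching allocator does not hand the
  // buffer out again before the async memcpy has finished.
  host_allocator->record_event(
      stage_mem, stage_mem_dptr.get_context(), at::xpu::getCurrentXPUStream());
}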

src/ATen/native/xpu/NMS.cpp (-2)

@@ -36,8 +36,6 @@ Tensor nms(const Tensor& dets, const Tensor& scores, double iou_threshold_) {
     return at::empty({0}, dets.options().dtype(at::kLong));
   }

-  constexpr int nms_items_per_group = sizeof(unsigned long long) * 8;
-
   auto order_t = std::get<1>(
       scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
   auto dets_sorted = dets.index_select(0, order_t).contiguous();

src/ATen/native/xpu/SoftMax.cpp (+17 -2)

@@ -38,8 +38,16 @@ TORCH_IMPL_FUNC(softmax_backward_xpu_out)
       "grad_output");
   c10::impl::check_and_update_common_device(
       common_device, output, "xpu::_softmax_backward_data_out_out", "output");
+  bool half_to_float = grad.scalar_type() != input_dtype;
+  if (half_to_float) {
+    TORCH_CHECK(
+        (grad.scalar_type() == ScalarType::Float &&
+         input_dtype == ScalarType::Half),
+        "expected input and grad types to match, or input to be at::Half and grad to be at::Float");
+  }

-  native::xpu::_softmax_backward_kernel(grad, output, dim, false, grad_input);
+  native::xpu::_softmax_backward_kernel(
+      grad, output, dim, half_to_float, grad_input);
 }

 TORCH_IMPL_FUNC(log_softmax_backward_xpu_out)
@@ -64,8 +72,15 @@ TORCH_IMPL_FUNC(log_softmax_backward_xpu_out)
       output,
       "xpu::_log_softmax_backward_data_out_out",
       "output");
+  bool half_to_float = grad.scalar_type() != input_dtype;
+  if (half_to_float) {
+    TORCH_CHECK(
+        (grad.scalar_type() == ScalarType::Float &&
+         input_dtype == ScalarType::Half),
+        "expected input and grad types to match, or input to be at::Half and grad to be at::Float");
+  }
   native::xpu::_log_softmax_backward_kernel(
-      grad, output, dim, false, grad_input);
+      grad, output, dim, half_to_float, grad_input);
 }

 TORCH_IMPL_FUNC(log_softmax_xpu_out)
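Both backward paths now derive half_to_float instead of hard-coding false, so the mixed-precision case (Half input, Float grad) reaches the kernels with the right flag. A minimal sketch of the check as a standalone helper; the helper itself is hypothetical, but the logic mirrors the hunks above:

#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>

// Hypothetical helper mirroring the added check: mixed dtypes are only legal
// for the Half-input / Float-grad combination.
bool resolve_half_to_float(c10::ScalarType grad_type, c10::ScalarType input_dtype) {
  bool half_to_float = grad_type != input_dtype;
  if (half_to_float) {
    TORCH_CHECK(
        grad_type == c10::ScalarType::Float &&
            input_dtype == c10::ScalarType::Half,
        "expected input and grad types to match, or input to be at::Half and grad to be at::Float");
  }
  return half_to_float;
}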

src/ATen/native/xpu/SpectralOps.cpp (+36)

@@ -4,6 +4,7 @@
 #include <ATen/native/Resize.h>
 #include <ATen/ops/_fft_c2c_native.h>
 #include <ATen/ops/_fft_c2r_native.h>
+#include <ATen/ops/_fft_r2c_native.h>
 #endif // USE_ONEMKL

 namespace at::native {
@@ -79,4 +80,39 @@ Tensor& _fft_c2r_xpu_out(
 #endif // USE_ONEMKL
 }

+Tensor _fft_r2c_xpu(
+    const Tensor& self,
+    IntArrayRef dim,
+    int64_t normalization,
+    bool onesided) {
+  TORCH_CHECK(self.is_floating_point());
+
+#if defined(USE_ONEMKL)
+  return native::xpu::_fft_r2c_mkl(self, dim, normalization, onesided);
+#else
+  Tensor out_cpu = native::_fft_r2c_mkl(
+      self.to(Device(at::kCPU)), dim, normalization, onesided);
+  return out_cpu.to(Device(at::kXPU));
+#endif // USE_ONEMKL
+}
+
+Tensor& _fft_r2c_xpu_out(
+    const Tensor& self,
+    IntArrayRef dim,
+    int64_t normalization,
+    bool onesided,
+    Tensor& out) {
+  TORCH_CHECK(self.is_floating_point());
+
+#if defined(USE_ONEMKL)
+  return native::xpu::_fft_r2c_mkl_out(self, dim, normalization, onesided, out);
+#else
+  Tensor out_cpu = native::_fft_r2c_mkl(
+      self.to(Device(at::kCPU)), dim, normalization, onesided);
+  at::native::resize_output(out, out_cpu.sizes());
+  out.copy_(out_cpu);
+  return out;
+#endif // USE_ONEMKL
+}
+
 } // namespace at::native
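With these entry points in place (and "_fft_r2c" dropped from the XPU fallback list in the next file), real-to-complex FFTs on XPU tensors dispatch here: under USE_ONEMKL they run on the device through the new _fft_r2c_mkl path, otherwise they round-trip through the CPU oneMKL implementation. A small, hedged usage sketch assuming an XPU-enabled libtorch build; fft_rfft is the public op that lowers to aten::_fft_r2c:

#include <ATen/ATen.h>

int main() {
  // Real 1-D signal on the XPU device.
  at::Tensor x = at::randn({1024}, at::device(at::kXPU).dtype(at::kFloat));

  // fft_rfft lowers to aten::_fft_r2c, which now resolves to _fft_r2c_xpu.
  at::Tensor spectrum = at::fft_rfft(x);

  // Onesided output keeps 1024 / 2 + 1 = 513 complex bins.
  return spectrum.size(0) == 513 ? 0 : 1;
}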

src/ATen/native/xpu/XPUFallback.template (-1)

@@ -190,7 +190,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "_cholesky_solve_helper",
       "dot",
       "_efficient_attention_forward",
-      "_fft_r2c",
       "_flash_attention_forward",
       "geqrf",
       "linalg_cholesky_ex.L",

src/ATen/native/xpu/mkl/SpectralOps.cpp (+113)

@@ -2,6 +2,7 @@
 #include <ATen/native/Resize.h>
 #include <ATen/native/SpectralOpsUtils.h>
 #include <ATen/native/xpu/mkl/SpectralOps.h>
+#include <ATen/native/xpu/sycl/FFTKernelFunctor.h>
 #include <ATen/ops/complex.h>
 #include <ATen/ops/imag.h>
 #include <ATen/ops/mul.h>
@@ -52,6 +53,7 @@ void _mkl_dft(

   auto istrides = input.strides();
   auto ostrides = output.strides();
+
   int64_t idist = istrides[0];
   int64_t odist = ostrides[0];

@@ -477,5 +479,116 @@ Tensor& _fft_c2r_mkl_out(
       out, result, normalization, result.sizes(), dim);
 }

+REGISTER_XPU_DISPATCH(
+    fft_fill_with_conjugate_symmetry_stub,
+    &_fft_fill_with_conjugate_symmetry_xpu);
+
+Tensor _fft_r2c_mkl(
+    const Tensor& self,
+    IntArrayRef dim,
+    int64_t normalization,
+    bool onesided) {
+  if (dim.empty()) {
+    return self.clone();
+  }
+
+  auto input_sizes = self.sizes();
+  DimVector onesided_sizes(input_sizes.begin(), input_sizes.end());
+  auto last_dim = dim.back();
+  auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
+  onesided_sizes[last_dim] = last_dim_halfsize;
+
+  IntArrayRef out_sizes = onesided ? onesided_sizes : input_sizes;
+
+  auto sorted_dims = impl::_sort_dims(self, dim, /*exclude_last=*/true);
+  auto out = at::empty(
+      out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type())));
+
+  auto working_tensor = self.clone(MemoryFormat::Contiguous);
+
+  // First do the R2C transform on the last dimension
+  impl::_exec_fft(
+      out, working_tensor, out_sizes, last_dim, onesided, /*forward=*/true);
+
+  if (dim.size() > 1) {
+    working_tensor = at::empty(
+        out_sizes,
+        self.options().dtype(c10::toComplexType(self.scalar_type())));
+  }
+
+  sorted_dims.resize(sorted_dims.size() - 1);
+
+  while (!sorted_dims.empty()) {
+    if (working_tensor.is_same(self)) {
+      working_tensor = std::move(out);
+      out = at::empty(
+          out_sizes,
+          self.options().dtype(c10::toComplexType(self.scalar_type())));
+    } else {
+      std::swap(out, working_tensor);
+    }
+
+    const auto max_dims =
+        std::min(static_cast<size_t>(impl::mkl_max_ndim), sorted_dims.size());
+    auto fft_dims =
+        IntArrayRef(sorted_dims).slice(sorted_dims.size() - max_dims, max_dims);
+    impl::_exec_fft(
+        out,
+        working_tensor,
+        out_sizes,
+        fft_dims,
+        onesided,
+        /*forward=*/true);
+    sorted_dims.resize(sorted_dims.size() - max_dims);
+
+    if (sorted_dims.empty()) {
+      break;
+    }
+
+    sorted_dims = impl::_sort_dims(self, sorted_dims);
+  }
+
+  // Only need to normalize the onesided slice since data in the other half is
+  // overwritten
+  auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);
+  working_tensor = self;
+  if (!onesided) {
+    if (out.sizes()[last_dim] != out_sizes[last_dim]) {
+      working_tensor.resize_(out_sizes, MemoryFormat::Contiguous);
+      working_tensor.slice(last_dim, 0, last_dim_halfsize).copy_(out);
+      out = std::move(working_tensor);
+    }
+    at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
+  }
+
+  return impl::_fft_apply_normalization(out, normalization, input_sizes, dim);
+}
+
+Tensor& _fft_r2c_mkl_out(
+    const Tensor& self,
+    IntArrayRef dim,
+    int64_t normalization,
+    bool onesided,
+    Tensor& out) {
+  auto result = _fft_r2c_mkl(
+      self, dim, static_cast<int64_t>(fft_norm_mode::none), /*onesided=*/true);
+
+  if (onesided) {
+    return impl::_fft_apply_normalization_out(
+        out, result, normalization, self.sizes(), dim);
+  }
+
+  at::native::resize_output(out, self.sizes());
+
+  auto last_dim = dim.back();
+  auto last_dim_halfsize = result.sizes()[last_dim];
+  auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);
+
+  impl::_fft_apply_normalization_out(
+      out_slice, result, normalization, self.sizes(), dim);
+  at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
+  return out;
+}
+
 } // namespace at::native::xpu
 #endif // USE_ONEMKL
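The new _fft_r2c_mkl performs a onesided R2C transform along the last dimension and then batches C2C transforms over the remaining dimensions (at most impl::mkl_max_ndim at a time); when the full spectrum is requested, the second half is filled by the conjugate-symmetry kernel registered above rather than recomputed. A host-side illustration of that symmetry, using a hypothetical standalone helper that is not part of the diff:

#include <complex>
#include <cstdint>
#include <vector>

// Hypothetical helper: rebuild a full length-n spectrum from the onesided
// half, i.e. the relation X[n - k] = conj(X[k]) that
// _fft_fill_with_conjugate_symmetry_ applies on the device when
// onesided == false.
std::vector<std::complex<float>> fill_conjugate_symmetry(
    const std::vector<std::complex<float>>& onesided, int64_t n) {
  std::vector<std::complex<float>> full(n);
  // onesided.size() == n / 2 + 1
  for (int64_t k = 0; k < static_cast<int64_t>(onesided.size()); ++k) {
    full[k] = onesided[k];
  }
  for (int64_t k = static_cast<int64_t>(onesided.size()); k < n; ++k) {
    full[k] = std::conj(onesided[n - k]); // X[n - k] = conj(X[k])
  }
  return full;
}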

src/ATen/native/xpu/mkl/SpectralOps.h (+13)

@@ -30,4 +30,17 @@ TORCH_XPU_API Tensor& _fft_c2r_mkl_out(
     int64_t last_dim_size,
     Tensor& out);

+TORCH_XPU_API Tensor _fft_r2c_mkl(
+    const Tensor& self,
+    IntArrayRef dim,
+    int64_t normalization,
+    bool onesided);
+
+TORCH_XPU_API Tensor& _fft_r2c_mkl_out(
+    const Tensor& self,
+    IntArrayRef dim,
+    int64_t normalization,
+    bool onesided,
+    Tensor& out);
+
 } // namespace at::native::xpu

src/ATen/native/xpu/sycl/BatchNormKernels.cpp (+2 -4)

@@ -5416,8 +5416,7 @@ std::tuple<Tensor, Tensor> batch_norm_gather_stats_with_counts_kernel(
   c10::MaybeOwned<Tensor> running_mean_maybe_owned =
       at::borrow_from_optional_tensor(running_mean_opt);
   const Tensor& running_mean = *running_mean_maybe_owned;
-  const Tensor& running_var =
-      c10::value_or_else(running_var_opt, [] { return Tensor(); });
+  const Tensor& running_var = running_var_opt.value_or(Tensor());

   auto scalar_type =
       running_mean.defined() ? running_mean.scalar_type() : self.scalar_type();
@@ -5471,8 +5470,7 @@ std::tuple<Tensor, Tensor> batch_norm_gather_stats_kernel(
   c10::MaybeOwned<Tensor> running_mean_maybe_owned =
       at::borrow_from_optional_tensor(running_mean_opt);
   const Tensor& running_mean = *running_mean_maybe_owned;
-  const Tensor& running_var =
-      c10::value_or_else(running_var_opt, [] { return Tensor(); });
+  const Tensor& running_var = running_var_opt.value_or(Tensor());

   std::vector<int64_t> counts(mean.size(0), count);
   Tensor counts_ = at::from_blob(

src/ATen/native/xpu/sycl/Dequant_int4.cpp (-1)

@@ -30,7 +30,6 @@ struct DequantInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
     static_assert(TileN == SgSize);
     static_assert(TileK == 1);
     int nsg_k = k / GroupK;
-    int nsg_n = n / GroupN;

     int g_idx = it.get_group(0);
     auto sg = it.get_sub_group();
