Skip to content

Commit a3494bd

Browse files
dylanbespalkofacebook-github-bot
authored andcommitted
CPU-Strided-Complex Fixes for real and imag ops (pytorch#29840)
Summary: In-tree changes to pytorch to support complex numbers are being submitted here. Out-of-tree support for complex numbers is here: [pytorch-cpu-strided-complex extension](https://gitlab.com/pytorch-complex/pytorch-cpu-strided-complex) - [x] Replaced std:real(a) with a.real() in kernel level code. - [x] Fixed Vec256_base implementation of complex ops so that it works correctly on Non-AVX devices. - [x] Fix NumericUtils.h cc: iotamudelta, ezyang, bddppq, zasdfgbnm Pull Request resolved: pytorch#29840 Differential Revision: D18531274 Pulled By: ezyang fbshipit-source-id: 0fa842c68e4bd55134fe0271880e2d15fe692b7f
1 parent 7d28768 commit a3494bd

File tree

6 files changed

+116
-72
lines changed

6 files changed

+116
-72
lines changed

aten/src/ATen/NumericUtils.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ inline C10_HOST_DEVICE bool _isnan(T val) {
3434
}
3535

3636
template <typename T,
37-
typename std::enable_if<std::is_complex_t<T>::value, int>::type = 0>
37+
typename std::enable_if<c10::is_complex_t<T>::value, int>::type = 0>
3838
inline bool _isnan(T val) {
39-
return std::isnan(std::real(val)) || std::isnan(std::imag(val));
39+
return std::isnan(val.real()) || std::isnan(val.imag());
4040
}
4141

4242
inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {

aten/src/ATen/cpu/vec256/vec256_base.h

+72-28
Original file line numberDiff line numberDiff line change
@@ -179,42 +179,86 @@ struct Vec256 {
179179
}
180180
return ret;
181181
}
182-
template <typename other_t = T,
183-
typename std::enable_if<!std::is_floating_point<other_t>::value && !std::is_complex_t<other_t>::value, int>::type = 0>
182+
template <typename other_t_abs = T,
183+
typename std::enable_if<!std::is_floating_point<other_t_abs>::value && !c10::is_complex_t<other_t_abs>::value, int>::type = 0>
184184
Vec256<T> abs() const {
185-
// other_t is for SFINAE and clarity. Make sure it is not changed.
186-
static_assert(std::is_same<other_t, T>::value, "other_t must be T");
187-
return map([](T x) -> T { return x < static_cast<other_t>(0) ? -x : x; });
185+
// other_t_abs is for SFINAE and clarity. Make sure it is not changed.
186+
static_assert(std::is_same<other_t_abs, T>::value, "other_t_abs must be T");
187+
return map([](T x) -> T { return x < static_cast<T>(0) ? -x : x; });
188188
}
189-
template <typename float_t = T,
190-
typename std::enable_if<std::is_floating_point<float_t>::value, int>::type = 0>
189+
template <typename float_t_abs = T,
190+
typename std::enable_if<std::is_floating_point<float_t_abs>::value, int>::type = 0>
191191
Vec256<T> abs() const {
192-
// float_t is for SFINAE and clarity. Make sure it is not changed.
193-
static_assert(std::is_same<float_t, T>::value, "float_t must be T");
192+
// float_t_abs is for SFINAE and clarity. Make sure it is not changed.
193+
static_assert(std::is_same<float_t_abs, T>::value, "float_t_abs must be T");
194194
// Specifically deal with floating-point because the generic code above won't handle -0.0 (which should result in
195195
// 0.0) properly.
196196
return map(std::abs);
197197
}
198-
template <typename complex_t = T,
199-
typename std::enable_if<std::is_complex_t<complex_t>::value, int>::type = 0>
198+
template <typename complex_t_abs = T,
199+
typename std::enable_if<c10::is_complex_t<complex_t_abs>::value, int>::type = 0>
200200
Vec256<T> abs() const {
201-
// complex_t is for SFINAE and clarity. Make sure it is not changed.
202-
static_assert(std::is_same<complex_t, T>::value, "complex_t must be T");
201+
// complex_t_abs is for SFINAE and clarity. Make sure it is not changed.
202+
static_assert(std::is_same<complex_t_abs, T>::value, "complex_t_abs must be T");
203203
// Specifically map() does not perform the type conversion needed by abs.
204-
return map([](T x) { return (T)std::abs(x); });
204+
return map([](T x) { return static_cast<T>(std::abs(x)); });
205205
}
206+
template <typename other_t_angle = T,
207+
typename std::enable_if<!c10::is_complex_t<other_t_angle>::value, int>::type = 0>
206208
Vec256<T> angle() const {
207-
return *this;
209+
// other_t_angle is for SFINAE and clarity. Make sure it is not changed.
210+
static_assert(std::is_same<other_t_angle, T>::value, "other_t_angle must be T");
211+
return Vec256(0);
212+
}
213+
template <typename complex_t_angle = T,
214+
typename std::enable_if<c10::is_complex_t<complex_t_angle>::value, int>::type = 0>
215+
Vec256<T> angle() const {
216+
// complex_t_angle is for SFINAE and clarity. Make sure it is not changed.
217+
static_assert(std::is_same<complex_t_angle, T>::value, "complex_t_angle must be T");
218+
return map([](T x) { return static_cast<T>(std::arg(x)); });
208219
}
220+
template <typename other_t_real = T,
221+
typename std::enable_if<!c10::is_complex_t<other_t_real>::value, int>::type = 0>
209222
Vec256<T> real() const {
223+
// other_t_real is for SFINAE and clarity. Make sure it is not changed.
224+
static_assert(std::is_same<other_t_real, T>::value, "other_t_real must be T");
210225
return *this;
211226
}
227+
template <typename complex_t_real = T,
228+
typename std::enable_if<c10::is_complex_t<complex_t_real>::value, int>::type = 0>
229+
Vec256<T> real() const {
230+
// complex_t_real is for SFINAE and clarity. Make sure it is not changed.
231+
static_assert(std::is_same<complex_t_real, T>::value, "complex_t_real must be T");
232+
return map([](T x) { return static_cast<T>(x.real()); });
233+
}
234+
template <typename other_t_imag = T,
235+
typename std::enable_if<!c10::is_complex_t<other_t_imag>::value, int>::type = 0>
212236
Vec256<T> imag() const {
213-
return *this;
237+
// other_t_imag is for SFINAE and clarity. Make sure it is not changed.
238+
static_assert(std::is_same<other_t_imag, T>::value, "other_t_imag must be T");
239+
return Vec256(0);
240+
}
241+
template <typename complex_t_imag = T,
242+
typename std::enable_if<c10::is_complex_t<complex_t_imag>::value, int>::type = 0>
243+
Vec256<T> imag() const {
244+
// complex_t_imag is for SFINAE and clarity. Make sure it is not changed.
245+
static_assert(std::is_same<complex_t_imag, T>::value, "complex_t_imag must be T");
246+
return map([](T x) { return static_cast<T>(x.imag()); });
214247
}
248+
template <typename other_t_conj = T,
249+
typename std::enable_if<!c10::is_complex_t<other_t_conj>::value, int>::type = 0>
215250
Vec256<T> conj() const {
251+
// other_t_conj is for SFINAE and clarity. Make sure it is not changed.
252+
static_assert(std::is_same<other_t_conj, T>::value, "other_t_conj must be T");
216253
return *this;
217254
}
255+
template <typename complex_t_conj = T,
256+
typename std::enable_if<c10::is_complex_t<complex_t_conj>::value, int>::type = 0>
257+
Vec256<T> conj() const {
258+
// complex_t_conj is for SFINAE and clarity. Make sure it is not changed.
259+
static_assert(std::is_same<complex_t_conj, T>::value, "complex_t_conj must be T");
260+
return map([](T x) { return static_cast<T>(std::conj(x)); });
261+
}
218262
Vec256<T> acos() const {
219263
return map(std::acos);
220264
}
@@ -259,14 +303,14 @@ struct Vec256 {
259303
return map(std::log1p);
260304
}
261305
template <typename other_t_log2 = T,
262-
typename std::enable_if<!std::is_complex_t<other_t_log2>::value, int>::type = 0>
306+
typename std::enable_if<!c10::is_complex_t<other_t_log2>::value, int>::type = 0>
263307
Vec256<T> log2() const {
264308
// other_t_log2 is for SFINAE and clarity. Make sure it is not changed.
265309
static_assert(std::is_same<other_t_log2, T>::value, "other_t_log2 must be T");
266310
return map(std::log2);
267311
}
268312
template <typename complex_t_log2 = T,
269-
typename std::enable_if<std::is_complex_t<complex_t_log2>::value, int>::type = 0>
313+
typename std::enable_if<c10::is_complex_t<complex_t_log2>::value, int>::type = 0>
270314
Vec256<T> log2() const {
271315
// complex_t_log2 is for SFINAE and clarity. Make sure it is not changed.
272316
static_assert(std::is_same<complex_t_log2, T>::value, "complex_t_log2 must be T");
@@ -395,7 +439,7 @@ template <class T> Vec256<T> inline operator||(
395439
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
396440
// either input is a NaN.
397441
template <class T,
398-
typename std::enable_if<!std::is_complex_t<T>::value, int>::type = 0>
442+
typename std::enable_if<!c10::is_complex_t<T>::value, int>::type = 0>
399443
Vec256<T> inline maximum(const Vec256<T> &a, const Vec256<T> &b) {
400444
Vec256<T> c = Vec256<T>();
401445
for (int i = 0; i != Vec256<T>::size(); i++) {
@@ -411,7 +455,7 @@ Vec256<T> inline maximum(const Vec256<T> &a, const Vec256<T> &b) {
411455
}
412456

413457
template <class T,
414-
typename std::enable_if<std::is_complex_t<T>::value, int>::type = 0>
458+
typename std::enable_if<c10::is_complex_t<T>::value, int>::type = 0>
415459
Vec256<T> inline maximum(const Vec256<T> &a, const Vec256<T> &b) {
416460
Vec256<T> c = Vec256<T>();
417461
for (int i = 0; i != Vec256<T>::size(); i++) {
@@ -438,7 +482,7 @@ inline T maximum(const T& a, const T& b) {
438482
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
439483
// either input is a NaN.
440484
template <class T,
441-
typename std::enable_if<!std::is_complex_t<T>::value, int>::type = 0>
485+
typename std::enable_if<!c10::is_complex_t<T>::value, int>::type = 0>
442486
Vec256<T> inline minimum(const Vec256<T> &a, const Vec256<T> &b) {
443487
Vec256<T> c = Vec256<T>();
444488
for (int i = 0; i != Vec256<T>::size(); i++) {
@@ -454,7 +498,7 @@ Vec256<T> inline minimum(const Vec256<T> &a, const Vec256<T> &b) {
454498
}
455499

456500
template <class T,
457-
typename std::enable_if<std::is_complex_t<T>::value, int>::type = 0>
501+
typename std::enable_if<c10::is_complex_t<T>::value, int>::type = 0>
458502
Vec256<T> inline minimum(const Vec256<T> &a, const Vec256<T> &b) {
459503
Vec256<T> c = Vec256<T>();
460504
for (int i = 0; i != Vec256<T>::size(); i++) {
@@ -480,7 +524,7 @@ inline T minimum(const T& a, const T& b) {
480524

481525
// To save BC, it will not propagate NaN based on IEEE 754 201X
482526
template <class T,
483-
typename std::enable_if<!std::is_complex_t<T>::value, int>::type = 0>
527+
typename std::enable_if<!c10::is_complex_t<T>::value, int>::type = 0>
484528
Vec256<T> inline clamp(const Vec256<T> &a, const Vec256<T> &min_vec, const Vec256<T> &max_vec) {
485529
Vec256<T> c = Vec256<T>();
486530
for (int i = 0; i != Vec256<T>::size(); i++) {
@@ -490,7 +534,7 @@ Vec256<T> inline clamp(const Vec256<T> &a, const Vec256<T> &min_vec, const Vec25
490534
}
491535

492536
template <class T,
493-
typename std::enable_if<std::is_complex_t<T>::value, int>::type = 0>
537+
typename std::enable_if<c10::is_complex_t<T>::value, int>::type = 0>
494538
Vec256<T> inline clamp(const Vec256<T> &a, const Vec256<T> &min_vec, const Vec256<T> &max_vec) {
495539
Vec256<T> c = Vec256<T>();
496540
for (int i = 0; i != Vec256<T>::size(); i++) {
@@ -500,7 +544,7 @@ Vec256<T> inline clamp(const Vec256<T> &a, const Vec256<T> &min_vec, const Vec25
500544
}
501545

502546
template <class T,
503-
typename std::enable_if<!std::is_complex_t<T>::value, int>::type = 0>
547+
typename std::enable_if<!c10::is_complex_t<T>::value, int>::type = 0>
504548
Vec256<T> inline clamp_max(const Vec256<T> &a, const Vec256<T> &max_vec) {
505549
Vec256<T> c = Vec256<T>();
506550
for (int i = 0; i != Vec256<T>::size(); i++) {
@@ -510,7 +554,7 @@ Vec256<T> inline clamp_max(const Vec256<T> &a, const Vec256<T> &max_vec) {
510554
}
511555

512556
template <class T,
513-
typename std::enable_if<std::is_complex_t<T>::value, int>::type = 0>
557+
typename std::enable_if<c10::is_complex_t<T>::value, int>::type = 0>
514558
Vec256<T> inline clamp_max(const Vec256<T> &a, const Vec256<T> &max_vec) {
515559
Vec256<T> c = Vec256<T>();
516560
for (int i = 0; i != Vec256<T>::size(); i++) {
@@ -520,7 +564,7 @@ Vec256<T> inline clamp_max(const Vec256<T> &a, const Vec256<T> &max_vec) {
520564
}
521565

522566
template <class T,
523-
typename std::enable_if<!std::is_complex_t<T>::value, int>::type = 0>
567+
typename std::enable_if<!c10::is_complex_t<T>::value, int>::type = 0>
524568
Vec256<T> inline clamp_min(const Vec256<T> &a, const Vec256<T> &min_vec) {
525569
Vec256<T> c = Vec256<T>();
526570
for (int i = 0; i != Vec256<T>::size(); i++) {
@@ -530,7 +574,7 @@ Vec256<T> inline clamp_min(const Vec256<T> &a, const Vec256<T> &min_vec) {
530574
}
531575

532576
template <class T,
533-
typename std::enable_if<std::is_complex_t<T>::value, int>::type = 0>
577+
typename std::enable_if<c10::is_complex_t<T>::value, int>::type = 0>
534578
Vec256<T> inline clamp_min(const Vec256<T> &a, const Vec256<T> &min_vec) {
535579
Vec256<T> c = Vec256<T>();
536580
for (int i = 0; i != Vec256<T>::size(); i++) {

aten/src/ATen/cpu/vec256/vec256_complex_double.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@ template <> class Vec256<std::complex<double>> {
2424
Vec256() {}
2525
Vec256(__m256d v) : values(v) {}
2626
Vec256(std::complex<double> val) {
27-
double real_value = std::real(val);
28-
double imag_value = std::imag(val);
27+
double real_value = val.real();
28+
double imag_value = val.imag();
2929
values = _mm256_setr_pd(real_value, imag_value,
3030
real_value, imag_value);
3131
}
3232
Vec256(std::complex<double> val1, std::complex<double> val2) {
33-
values = _mm256_setr_pd(std::real(val1), std::imag(val1),
34-
std::real(val2), std::imag(val2));
33+
values = _mm256_setr_pd(val1.real(), val1.imag(),
34+
val2.real(), val2.imag());
3535
}
3636
operator __m256d() const {
3737
return values;

aten/src/ATen/cpu/vec256/vec256_complex_float.h

+6-6
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,19 @@ template <> class Vec256<std::complex<float>> {
2424
Vec256() {}
2525
Vec256(__m256 v) : values(v) {}
2626
Vec256(std::complex<float> val) {
27-
float real_value = std::real(val);
28-
float imag_value = std::imag(val);
27+
float real_value = val.real();
28+
float imag_value = val.imag();
2929
values = _mm256_setr_ps(real_value, imag_value,
3030
real_value, imag_value,
3131
real_value, imag_value,
3232
real_value, imag_value
3333
);
3434
}
3535
Vec256(std::complex<float> val1, std::complex<float> val2, std::complex<float> val3, std::complex<float> val4) {
36-
values = _mm256_setr_ps(std::real(val1), std::imag(val1),
37-
std::real(val2), std::imag(val2),
38-
std::real(val3), std::imag(val3),
39-
std::real(val4), std::imag(val4)
36+
values = _mm256_setr_ps(val1.real(), val1.imag(),
37+
val2.real(), val2.imag(),
38+
val3.real(), val3.imag(),
39+
val4.real(), val4.imag()
4040
);
4141
}
4242
operator __m256() const {

0 commit comments

Comments
 (0)