Skip to content

Commit 6eaa324

Browse files
mfkasim1facebook-github-bot
authored andcommitted
Implement torch.igamma (pytorch#46183)
Summary: Fixes pytorch#41637 This is regularized lower incomplete gamma function, equivalent to scipy's `gammainc` and tensorflow `igamma`. cc fritzo mruberry Pull Request resolved: pytorch#46183 Reviewed By: gchanan Differential Revision: D24479126 Pulled By: mruberry fbshipit-source-id: fdf8ea289fe4ca1b408810732192411e948fcdfe
1 parent dd95bf6 commit 6eaa324

26 files changed

+1590
-8
lines changed

NOTICE

+106
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,112 @@ Apache License Version 2.0:
284284
incurred by, or claims asserted against, such Contributor by reason
285285
of your accepting any such warranty or additional liability.
286286

287+
=======================================================================
288+
Cephes's 3-Clause BSD License
289+
=======================================================================
290+
291+
Code derived from implementations in the Cephes Math Library should mention
292+
its derivation and reference the following license:
293+
294+
3-Clause BSD License for the Cephes Math Library
295+
Copyright (c) 2018, Steven Moshier
296+
All rights reserved.
297+
298+
Redistribution and use in source and binary forms, with or without
299+
modification, are permitted provided that the following conditions are met:
300+
301+
* Redistributions of source code must retain the above copyright
302+
notice, this list of conditions and the following disclaimer.
303+
304+
* Redistributions in binary form must reproduce the above copyright
305+
notice, this list of conditions and the following disclaimer in the
306+
documentation and/or other materials provided with the distribution.
307+
308+
* Neither the name of the nor the
309+
names of its contributors may be used to endorse or promote products
310+
derived from this software without specific prior written permission.
311+
312+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
313+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
314+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
315+
DISCLAIMED. IN NO EVENT SHALL Steven Moshier BE LIABLE FOR ANY
316+
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
317+
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
318+
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
319+
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
320+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
321+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
322+
323+
324+
=======================================================================
325+
SciPy's 3-Clause BSD License
326+
=======================================================================
327+
328+
Code derived from implementations in SciPy should mention its derivation
329+
and reference the following license:
330+
331+
Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers.
332+
All rights reserved.
333+
334+
Redistribution and use in source and binary forms, with or without
335+
modification, are permitted provided that the following conditions
336+
are met:
337+
338+
1. Redistributions of source code must retain the above copyright
339+
notice, this list of conditions and the following disclaimer.
340+
341+
2. Redistributions in binary form must reproduce the above
342+
copyright notice, this list of conditions and the following
343+
disclaimer in the documentation and/or other materials provided
344+
with the distribution.
345+
346+
3. Neither the name of the copyright holder nor the names of its
347+
contributors may be used to endorse or promote products derived
348+
from this software without specific prior written permission.
349+
350+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
351+
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
352+
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
353+
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
354+
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
355+
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
356+
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
357+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
358+
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
359+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
360+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
361+
362+
=======================================================================
363+
Boost's 1.0 Software License
364+
=======================================================================
365+
366+
Code derived from implementations in Boost 1.0 should mention its
367+
derivation and reference the following license:
368+
369+
Boost Software License - Version 1.0 - August 17th, 2003
370+
371+
Permission is hereby granted, free of charge, to any person or organization
372+
obtaining a copy of the software and accompanying documentation covered by
373+
this license (the "Software") to use, reproduce, display, distribute,
374+
execute, and transmit the Software, and to prepare derivative works of the
375+
Software, and to permit third-parties to whom the Software is furnished to
376+
do so, all subject to the following:
377+
378+
The copyright notices in the Software and this entire statement, including
379+
the above license grant, this restriction and the following disclaimer,
380+
must be included in all copies of the Software, in whole or in part, and
381+
all derivative works of the Software, unless such copies or derivative
382+
works are solely in the form of machine-executable object code generated by
383+
a source language processor.
384+
385+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
386+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
387+
FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
388+
SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
389+
FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
390+
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
391+
DEALINGS IN THE SOFTWARE.
392+
287393
END OF TERMS AND CONDITIONS
288394

289395
APPENDIX: How to apply the Apache License to your work.

aten/src/ATen/core/NamedRegistrations.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,9 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
210210
m.impl("i0", CppFunction::makeFallthrough());
211211
m.impl("i0.out", CppFunction::makeFallthrough());
212212
m.impl("i0_", CppFunction::makeFallthrough());
213+
m.impl("igamma", CppFunction::makeFallthrough());
214+
m.impl("igamma.out", CppFunction::makeFallthrough());
215+
m.impl("igamma_", CppFunction::makeFallthrough());
213216
m.impl("imag", CppFunction::makeFallthrough());
214217
m.impl("index_fill.Dimname_Scalar", CppFunction::makeFallthrough());
215218
m.impl("index_fill.Dimname_Tensor", CppFunction::makeFallthrough());

aten/src/ATen/core/aten_interned_strings.h

+2
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,8 @@ _(aten, hstack) \
371371
_(aten, hypot) \
372372
_(aten, i0) \
373373
_(aten, i0_) \
374+
_(aten, igamma) \
375+
_(aten, igamma_) \
374376
_(aten, ifft) \
375377
_(aten, index) \
376378
_(aten, index_add) \

aten/src/ATen/cpu/vec256/vec256_base.h

+7
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,13 @@ struct Vec256 {
394394
Vec256<T> i0() const {
395395
return map(calc_i0);
396396
}
397+
Vec256<T> igamma(const Vec256<T> &x) const {
398+
Vec256<T> ret;
399+
for (int64_t i = 0; i < size(); i++) {
400+
ret[i] = calc_igamma(values[i], x[i]);
401+
}
402+
return ret;
403+
}
397404
Vec256<T> neg() const {
398405
// NB: the trailing return type is needed because we need to coerce the
399406
// return value back to T in the case of unary operator- incuring a

aten/src/ATen/cpu/vec256/vec256_bfloat16.h

+19
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,25 @@ template <> class Vec256<BFloat16> {
290290
auto o2 = _mm256_loadu_ps(tmp2);
291291
return cvtfp32_bf16(o1, o2);
292292
}
293+
Vec256<BFloat16> igamma(const Vec256<BFloat16> &x) const {
294+
__m256 lo, hi;
295+
__m256 xlo, xhi;
296+
cvtbf16_fp32(values, lo, hi);
297+
cvtbf16_fp32(x.values, xlo, xhi);
298+
__at_align32__ float tmp1[size() / 2], tmp2[size() / 2];
299+
_mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
300+
_mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
301+
__at_align32__ float tmpx1[size() / 2], tmpx2[size() / 2];
302+
_mm256_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
303+
_mm256_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
304+
for (int64_t i = 0; i < size() / 2; ++i) {
305+
tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]);
306+
tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]);
307+
}
308+
auto o1 = _mm256_loadu_ps(tmp1);
309+
auto o2 = _mm256_loadu_ps(tmp2);
310+
return cvtfp32_bf16(o1, o2);
311+
}
293312
Vec256<BFloat16> log() const {
294313
return map(Sleef_logf8_u10);
295314
}

aten/src/ATen/cpu/vec256/vec256_complex_double.h

+3
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,9 @@ template <> class Vec256<c10::complex<double>> {
252252
Vec256<c10::complex<double>> hypot(const Vec256<c10::complex<double>> &b) const {
253253
AT_ERROR("not supported for complex numbers");
254254
}
255+
Vec256<c10::complex<double>> igamma(const Vec256<c10::complex<double>> &x) const {
256+
AT_ERROR("not supported for complex numbers");
257+
}
255258
Vec256<c10::complex<double>> neg() const {
256259
auto zero = _mm256_setzero_pd();
257260
return _mm256_sub_pd(zero, values);

aten/src/ATen/cpu/vec256/vec256_complex_float.h

+3
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,9 @@ template <> class Vec256<c10::complex<float>> {
290290
Vec256<c10::complex<float>> hypot(const Vec256<c10::complex<float>> &b) const {
291291
AT_ERROR("not supported for complex numbers");
292292
}
293+
Vec256<c10::complex<float>> igamma(const Vec256<c10::complex<float>> &x) const {
294+
AT_ERROR("not supported for complex numbers");
295+
}
293296
Vec256<c10::complex<float>> neg() const {
294297
auto zero = _mm256_setzero_ps();
295298
return _mm256_sub_ps(zero, values);

aten/src/ATen/cpu/vec256/vec256_double.h

+10
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,16 @@ template <> class Vec256<double> {
155155
Vec256<double> i0() const {
156156
return map(calc_i0);
157157
}
158+
Vec256<double> igamma(const Vec256<double> &x) const {
159+
__at_align32__ double tmp[size()];
160+
__at_align32__ double tmp_x[size()];
161+
store(tmp);
162+
x.store(tmp_x);
163+
for (int64_t i = 0; i < size(); i++) {
164+
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
165+
}
166+
return loadu(tmp);
167+
}
158168
Vec256<double> log() const {
159169
return Vec256<double>(Sleef_logd4_u10(values));
160170
}

aten/src/ATen/cpu/vec256/vec256_float.h

+10
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,16 @@ template <> class Vec256<float> {
193193
Vec256<float> i0() const {
194194
return map(calc_i0);
195195
}
196+
Vec256<float> igamma(const Vec256<float> &x) const {
197+
__at_align32__ float tmp[size()];
198+
__at_align32__ float tmp_x[size()];
199+
store(tmp);
200+
x.store(tmp_x);
201+
for (int64_t i = 0; i < size(); i++) {
202+
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
203+
}
204+
return loadu(tmp);
205+
}
196206
Vec256<float> neg() const {
197207
return _mm256_xor_ps(_mm256_set1_ps(-0.f), values);
198208
}

aten/src/ATen/cpu/vec256/vec256_float_neon.h

+10
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,16 @@ template <> class Vec256<float> {
362362
Vec256<float> i0() const {
363363
return map(calc_i0);
364364
}
365+
Vec256<float> igamma(const Vec256<float> &x) const {
366+
__at_align32__ float tmp[size()];
367+
__at_align32__ float tmp_x[size()];
368+
store(tmp);
369+
x.store(tmp_x);
370+
for (int64_t i = 0; i < size(); i++) {
371+
tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
372+
}
373+
return loadu(tmp);
374+
}
365375
Vec256<float> log() const {
366376
return map(std::log);
367377
}

aten/src/ATen/native/BinaryOps.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ DEFINE_DISPATCH(logaddexp2_stub);
4646
DEFINE_DISPATCH(gcd_stub);
4747
DEFINE_DISPATCH(lcm_stub);
4848
DEFINE_DISPATCH(hypot_stub);
49+
DEFINE_DISPATCH(igamma_stub);
4950
DEFINE_DISPATCH(nextafter_stub);
5051
DEFINE_DISPATCH(heaviside_stub);
5152

@@ -968,6 +969,23 @@ Tensor& hypot_(Tensor& self, const Tensor& other) {
968969
return at::hypot_out(self, self, other);
969970
}
970971

972+
Tensor& igamma_out(Tensor& result, const Tensor& self, const Tensor& other) {
973+
auto iter = TensorIterator::binary_op(result, self, other);
974+
igamma_stub(iter.device_type(), iter);
975+
return result;
976+
}
977+
978+
Tensor igamma(const Tensor& self, const Tensor& other) {
979+
Tensor result;
980+
auto iter = TensorIterator::binary_op(result, self, other);
981+
igamma_stub(iter.device_type(), iter);
982+
return iter.output();
983+
}
984+
985+
Tensor& igamma_(Tensor& self, const Tensor& other) {
986+
return at::igamma_out(self, self, other);
987+
}
988+
971989
Tensor& nextafter_out(Tensor& result, const Tensor& self, const Tensor& other) {
972990
auto iter = TensorIterator::binary_op(result, self, other);
973991
nextafter_stub(iter.device_type(), iter);

aten/src/ATen/native/BinaryOps.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace at { namespace native {
1010
inline void alpha_check(const ScalarType dtype, Scalar alpha) {
1111
TORCH_CHECK(! alpha.isBoolean() || dtype == ScalarType::Bool,
1212
"Boolean alpha only supported for Boolean results.");
13-
TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype)
13+
TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype)
1414
|| alpha.isIntegral(true),
1515
"For integral input tensors, argument alpha must not be a floating point number.");
1616
}
@@ -68,6 +68,7 @@ DECLARE_DISPATCH(binary_fn, logaddexp2_stub);
6868
DECLARE_DISPATCH(binary_fn, gcd_stub);
6969
DECLARE_DISPATCH(binary_fn, lcm_stub);
7070
DECLARE_DISPATCH(binary_fn, hypot_stub);
71+
DECLARE_DISPATCH(binary_fn, igamma_stub);
7172
DECLARE_DISPATCH(binary_fn, nextafter_stub);
7273
DECLARE_DISPATCH(binary_fn, heaviside_stub);
7374

0 commit comments

Comments
 (0)