forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTHCTensorMathPointwise.cuh
246 lines (199 loc) · 6.2 KB
/
THCTensorMathPointwise.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
#ifndef THC_TENSORMATH_POINTWISE_CUH
#define THC_TENSORMATH_POINTWISE_CUH
#include <type_traits>
#include <THC/THCTensorMath.h>
#include <THC/THCGeneral.h>
#include <TH/THHalf.h>
#include <THC/THCTensorCopy.h>
#include <THC/THCApply.cuh>
#include <THC/THCNumerics.cuh>
#include <THC/THCReduce.cuh>
// Pointwise functor for a scaled add with a fixed scalar `val`.
//   (out, in)       : *out is incremented by val * (*in)
//   (out, in1, in2) : *out is set to *in1 + val * (*in2)
template <typename T>
struct TensorCAddOp {
  // Captures the scalar that multiplies the second operand.
  TensorCAddOp(T v) : val(v) {}

  // Accumulating form: the current value of *out is the first addend.
  __device__ __forceinline__ void operator()(T* out, T* in) {
    T& dst = *out;
    dst += val * (*in);
  }

  // Out-of-place form: *out is overwritten with the combined result.
  __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) {
    const T& lhs = *in1;
    const T& rhs = *in2;
    *out = lhs + val * rhs;
  }

  // Scalar multiplier applied to the second operand.
  T val;
};
// Pointwise multiplication functor.
//   (out, in)       : *out is multiplied by *in in place
//   (out, in1, in2) : *out is set to (*in1) * (*in2)
template <typename T>
struct TensorMulOp {
  // In-place form: scales the value already stored in *out.
  __device__ __forceinline__ void operator()(T* out, T* in) {
    T& dst = *out;
    dst *= *in;
  }

  // Out-of-place form: writes the product of the two inputs to *out.
  __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) {
    const T& lhs = *in1;
    const T& rhs = *in2;
    *out = lhs * rhs;
  }
};
// Overload selected when std::is_signed<T>::value is true.
// Returns true when `a` is non-zero and exactly one of `a`, `b` is negative,
// i.e. the two values have opposite signs; returns false otherwise.
template <typename T>
static __device__ __forceinline__
typename std::enable_if<std::is_signed<T>::value, bool>::type
modulo_wrap(T a, T b) {
  // A zero value never needs wrapping, whatever the sign of b.
  if (a == 0) {
    return false;
  }
  const bool a_negative = a < 0;
  const bool b_negative = b < 0;
  return a_negative != b_negative;
}
// Overload selected when std::is_unsigned<T>::value is true.
// Always returns false; both arguments are ignored.
template <typename T>
static __device__ __forceinline__
typename std::enable_if<std::is_unsigned<T>::value, bool>::type
modulo_wrap(T /*a*/, T /*b*/) {
  return false;
}
// Generic remainder functor built on operator%.
// The result of `%` is shifted by one divisor whenever modulo_wrap() reports
// that the raw remainder is non-zero and differs in sign from the divisor.
//   (out, in)       : *out = remainder of *out by *in
//   (out, in1, in2) : *out = remainder of *in1 by *in2
// No check for a zero divisor is performed here.
template <typename T>
struct TensorCRemainderOp {
  // In-place form: the dividend is the value currently stored in *out.
  __device__ __forceinline__ void operator()(T* out, T* in) {
    const T divisor = *in;
    T rem = *out % divisor;
    if (modulo_wrap(rem, divisor)) {
      rem += divisor;
    }
    *out = rem;
  }

  // Out-of-place form: dividend and divisor come from in1 and in2.
  __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) {
    const T divisor = *in2;
    T rem = *in1 % divisor;
    if (modulo_wrap(rem, divisor)) {
      rem += divisor;
    }
    *out = rem;
  }
};
// float specialization: remainder computed as a - b * floorf(a / b).
// A divisor that compares equal to zero yields NAN.
template <>
struct TensorCRemainderOp<float> {
  // In-place form: the dividend is the value currently stored in *out.
  __device__ __forceinline__ void operator()(float* out, float* in) {
    if (*in == 0.f) {
      *out = NAN;
      return;
    }
    *out = *out - *in * floorf(*out / *in);
  }

  // Out-of-place form: dividend from in1, divisor from in2.
  __device__ __forceinline__ void operator()(float* out, float* in1, float* in2) {
    if (*in2 == 0.f) {
      *out = NAN;
      return;
    }
    *out = *in1 - *in2 * floorf(*in1 / *in2);
  }
};
// double specialization: remainder computed as a - b * floor(a / b).
// A divisor that compares equal to zero yields NAN.
template <>
struct TensorCRemainderOp<double> {
  // In-place form: the dividend is the value currently stored in *out.
  __device__ __forceinline__ void operator()(double* out, double* in) {
    if (*in == 0.) {
      *out = NAN;
      return;
    }
    *out = *out - *in * floor(*out / *in);
  }

  // Out-of-place form: dividend from in1, divisor from in2.
  __device__ __forceinline__ void operator()(double* out, double* in1, double* in2) {
    if (*in2 == 0.) {
      *out = NAN;
      return;
    }
    *out = *in1 - *in2 * floor(*in1 / *in2);
  }
};
// at::Half specialization: same formula as the float version,
// a - b * floorf(a / b), with NAN produced when the divisor compares equal
// to zero. The arithmetic expression is kept as a single expression so the
// operand types of every sub-expression are exactly those of the operators
// applied to at::Half values.
template <>
struct TensorCRemainderOp<at::Half> {
  // In-place form: the dividend is the value currently stored in *out.
  __device__ __forceinline__ void operator()(at::Half* out, at::Half* in) {
    if (*in != 0.f) {
      *out = *out - *in * floorf(*out / *in);
    } else {
      *out = NAN;
    }
  }

  // Out-of-place form: dividend from in1, divisor from in2.
  __device__ __forceinline__ void operator()(at::Half* out, at::Half* in1, at::Half* in2) {
    if (*in2 != 0.f) {
      *out = *in1 - *in2 * floorf(*in1 / *in2);
    } else {
      *out = NAN;
    }
  }
};
// Generic fmod functor: applies operator% directly, with no sign adjustment.
//   (out, in)       : *out = *out % *in
//   (out, in1, in2) : *out = *in1 % *in2
// No check for a zero divisor is performed here.
template <typename T>
struct TensorCFmodOp {
  // In-place form: the dividend is the value currently stored in *out.
  __device__ __forceinline__ void operator()(T* out, T* in) {
    T& dst = *out;
    dst = dst % *in;
  }

  // Out-of-place form: dividend from in1, divisor from in2.
  __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) {
    const T& dividend = *in1;
    const T& divisor = *in2;
    *out = dividend % divisor;
  }
};
// float specialization: delegates to fmodf.
template <>
struct TensorCFmodOp<float> {
  // In-place form: *out is replaced by fmodf(*out, *in).
  __device__ __forceinline__ void operator()(float* out, float* in) {
    float& dst = *out;
    dst = fmodf(dst, *in);
  }

  // Out-of-place form: *out receives fmodf(*in1, *in2).
  __device__ __forceinline__ void operator()(float* out, float* in1, float* in2) {
    const float dividend = *in1;
    const float divisor = *in2;
    *out = fmodf(dividend, divisor);
  }
};
// double specialization: delegates to fmod.
template <>
struct TensorCFmodOp<double> {
  // In-place form: *out is replaced by fmod(*out, *in).
  __device__ __forceinline__ void operator()(double* out, double* in) {
    double& dst = *out;
    dst = fmod(dst, *in);
  }

  // Out-of-place form: *out receives fmod(*in1, *in2).
  __device__ __forceinline__ void operator()(double* out, double* in1, double* in2) {
    const double dividend = *in1;
    const double divisor = *in2;
    *out = fmod(dividend, divisor);
  }
};
// at::Half specialization: passes the at::Half operands to fmodf and stores
// the result back into an at::Half.
template <>
struct TensorCFmodOp<at::Half> {
  // In-place form: *out is replaced by fmodf(*out, *in).
  __device__ __forceinline__ void operator()(at::Half* out, at::Half* in) {
    at::Half& dst = *out;
    dst = fmodf(dst, *in);
  }

  // Out-of-place form: *out receives fmodf(*in1, *in2).
  __device__ __forceinline__ void operator()(at::Half* out, at::Half* in1, at::Half* in2) {
    at::Half& dst = *out;
    dst = fmodf(*in1, *in2);
  }
};
// Clamps a value to the range described by minValue / maxValue using
// THCNumerics<T>::lt and THCNumerics<T>::gt.
//   (out, in) : *out = clamp(*in)
//   (v)       : *v   = clamp(*v)
// The lower bound is applied first and the upper bound second, so when the
// bounds are inverted the upper bound decides the result.
template <typename T>
struct TensorClampOp {
  TensorClampOp(T min, T max) : minValue(min), maxValue(max) {}

  // Out-of-place form.
  __device__ __forceinline__ void operator()(T* out, T* in) {
    *out = clampValue(*in);
  }

  // In-place form.
  __device__ __forceinline__ void operator()(T* v) {
    *v = clampValue(*v);
  }

  const T minValue;
  const T maxValue;

 private:
  // Shared implementation of both call forms: raise to minValue if below it,
  // then cap at maxValue if above it. A value for which neither comparison
  // holds is returned unchanged.
  __device__ __forceinline__ T clampValue(T x) const {
    T raised = x;
    if (THCNumerics<T>::lt(x, minValue)) {
      raised = minValue;
    }
    if (THCNumerics<T>::gt(raised, maxValue)) {
      return maxValue;
    }
    return raised;
  }
};
// Computes the cross product of two 3-component vectors.
// Component k of x, y and out is read/written at element offset k * sx,
// k * sy and k * so respectively. All six inputs are loaded before any
// output element is stored.
template <typename T>
struct TensorCrossOp {
  TensorCrossOp(int64_t strideX, int64_t strideY, int64_t strideOut)
      : sx(strideX), sy(strideY), so(strideOut) {}

  __device__ __forceinline__ void operator()(T* out, T* x, T* y) {
    // Gather the three components of each operand.
    const T x0 = x[0 * sx];
    const T x1 = x[1 * sx];
    const T x2 = x[2 * sx];
    const T y0 = y[0 * sy];
    const T y1 = y[1 * sy];
    const T y2 = y[2 * sy];

    // out = (x1*y2 - x2*y1, x2*y0 - x0*y2, x0*y1 - x1*y0)
    const T c0 = THCNumerics<T>::sub(THCNumerics<T>::mul(x1, y2),
                                     THCNumerics<T>::mul(x2, y1));
    const T c1 = THCNumerics<T>::sub(THCNumerics<T>::mul(x2, y0),
                                     THCNumerics<T>::mul(x0, y2));
    const T c2 = THCNumerics<T>::sub(THCNumerics<T>::mul(x0, y1),
                                     THCNumerics<T>::mul(x1, y0));

    out[0 * so] = c0;
    out[1 * so] = c1;
    out[2 * so] = c2;
  }

  // Element strides between consecutive components of x, y and out.
  const int64_t sx, sy, so;
};
// Element-wise maximum of two values.
//   (out, in)       : *out = max(*out, *in)
//   (out, in1, in2) : *out = max(*in1, *in2)
//
// NaN handling: the previous form `gt(a, b) ? a : b` returned `b` whenever
// `a` was NaN (the comparison is false), so a NaN in the first operand was
// silently dropped while a NaN in the second operand was propagated. The
// selection below also picks `a` when `a` is not equal to itself, so a NaN in
// either operand reaches the output. For operands that are equal to
// themselves the result is unchanged from the previous form.
template <typename T>
struct TensorMaxOp {
  // In-place form: the first operand is the value currently stored in *out.
  __device__ __forceinline__ void operator()(T* out, T* in) {
    *out = pick(*out, *in);
  }

  // Out-of-place form.
  __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) {
    *out = pick(*in1, *in2);
  }

 private:
  // Returns `a` when it is greater than `b` or when it fails self-equality
  // (expected to happen only for NaN); otherwise returns `b`, which also
  // covers the case where `b` itself is NaN.
  static __device__ __forceinline__ T pick(T a, T b) {
    const bool takeFirst =
        THCNumerics<T>::gt(a, b) || THCNumerics<T>::ne(a, a);
    return takeFirst ? a : b;
  }
};
// Element-wise minimum of two values.
//   (out, in)       : *out = min(*out, *in)
//   (out, in1, in2) : *out = min(*in1, *in2)
//
// NaN handling: the previous form `lt(a, b) ? a : b` returned `b` whenever
// `a` was NaN (the comparison is false), so a NaN in the first operand was
// silently dropped while a NaN in the second operand was propagated. The
// selection below also picks `a` when `a` is not equal to itself, so a NaN in
// either operand reaches the output. For operands that are equal to
// themselves the result is unchanged from the previous form.
template <typename T>
struct TensorMinOp {
  // In-place form: the first operand is the value currently stored in *out.
  __device__ __forceinline__ void operator()(T* out, T* in) {
    *out = pick(*out, *in);
  }

  // Out-of-place form.
  __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) {
    *out = pick(*in1, *in2);
  }

 private:
  // Returns `a` when it is less than `b` or when it fails self-equality
  // (expected to happen only for NaN); otherwise returns `b`, which also
  // covers the case where `b` itself is NaN.
  static __device__ __forceinline__ T pick(T a, T b) {
    const bool takeFirst =
        THCNumerics<T>::lt(a, b) || THCNumerics<T>::ne(a, a);
    return takeFirst ? a : b;
  }
};
// Element-wise maximum against a fixed scalar `val`.
//   (out)     : *out = max(*out, val)
//   (out, in) : *out = max(*in, val)
template <typename T>
struct TensorMaxValueOp {
  TensorMaxValueOp(T v) : val(v) {}

  // In-place form.
  __device__ __forceinline__ void operator()(T* out) {
    *out = atLeast(*out);
  }

  // Out-of-place form.
  __device__ __forceinline__ void operator()(T* out, T* in) {
    *out = atLeast(*in);
  }

  // Scalar lower bound applied to every element.
  T val;

 private:
  // The element is the left operand of lt() and the fall-through result;
  // per the original note, this order propagates NaN.
  __device__ __forceinline__ T atLeast(T x) const {
    if (THCNumerics<T>::lt(x, val)) {
      return val;
    }
    return x;
  }
};
// Element-wise minimum against a fixed scalar `val`.
//   (out)     : *out = min(*out, val)
//   (out, in) : *out = min(*in, val)
template <typename T>
struct TensorMinValueOp {
  TensorMinValueOp(T v) : val(v) {}

  // In-place form.
  __device__ __forceinline__ void operator()(T* out) {
    *out = atMost(*out);
  }

  // Out-of-place form.
  __device__ __forceinline__ void operator()(T* out, T* in) {
    *out = atMost(*in);
  }

  // Scalar upper bound applied to every element.
  T val;

 private:
  // The element is the left operand of gt() and the fall-through result;
  // per the original note, this order propagates NaN.
  __device__ __forceinline__ T atMost(T x) const {
    if (THCNumerics<T>::gt(x, val)) {
      return val;
    }
    return x;
  }
};
#endif // THC_TENSORMATH_POINTWISE_CUH