forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSparseCsrTensorImpl.cpp
153 lines (141 loc) · 4.98 KB
/
SparseCsrTensorImpl.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#include <ATen/ATen.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/core/LegacyTypeDispatch.h>
namespace at {
namespace {
// Maps the sparse-CSR backend key found in `key_set` to the DeviceType on
// which the CSR component tensors are allocated.
//
// SparseCsrCPU is tested first, so a key set containing both CSR backend keys
// resolves to kCPU. A key set with neither key is rejected via TORCH_CHECK.
DeviceType SparseCsrTensorSetToDeviceType(DispatchKeySet key_set) {
  if (key_set.has(DispatchKey::SparseCsrCPU)) {
    return kCPU;
  }
  if (key_set.has(DispatchKey::SparseCsrCUDA)) {
    return kCUDA;
  }
  // No CSR backend key present: unconditionally raise.
  TORCH_CHECK(
      false,
      "Cannot construct SparseCsrTensor with non-sparse tensor type ID ",
      key_set);
}
} // namespace
SparseCsrTensorImpl::SparseCsrTensorImpl(
at::DispatchKeySet key_set,
const caffe2::TypeMeta data_type)
: SparseCsrTensorImpl(
key_set,
data_type,
at::empty(
{0},
at::initialTensorOptions()
.device(SparseCsrTensorSetToDeviceType(key_set))
.dtype(ScalarType::Int)) // crow_indices
,
at::empty(
{0},
at::initialTensorOptions()
.device(SparseCsrTensorSetToDeviceType(key_set))
.dtype(ScalarType::Int)) // col_indices
,
at::empty(
{0},
at::initialTensorOptions()
.device(SparseCsrTensorSetToDeviceType(key_set))
.dtype(data_type)) // values
) {}
SparseCsrTensorImpl::SparseCsrTensorImpl(
at::DispatchKeySet key_set,
const caffe2::TypeMeta data_type,
at::Tensor crow_indices,
at::Tensor col_indices,
at::Tensor values)
: TensorImpl(key_set, data_type, values.device()),
crow_indices_(std::move(crow_indices)),
col_indices_(std::move(col_indices)),
values_(std::move(values)) {}
void SparseCsrTensorImpl::resize_and_clear_(
const int64_t nnz_size,
IntArrayRef size) {
// call crow_indices().options() here since the struct contructor calls the
// tensor constructor with args for device specific init.
auto empty_crow_indices = at::empty(size[0] + 1, crow_indices().options());
auto empty_col_indices = at::empty(nnz_size, col_indices().options());
auto empty_values = at::empty(nnz_size, values().options());
crow_indices_ = empty_crow_indices;
col_indices_ = empty_col_indices;
values_ = empty_values;
sizes_and_strides_.set_sizes(size);
}
void SparseCsrTensorImpl::resize_as_sparse_csr_tensor_(const Tensor& src) {
crow_indices_ = at::empty_like(
src.crow_indices(),
src.crow_indices().options(),
src.crow_indices().suggest_memory_format());
col_indices_ = at::empty_like(
src.col_indices(),
src.col_indices().options(),
src.col_indices().suggest_memory_format());
values_ = at::empty_like(
src.values(),
src.values().options(),
src.values().suggest_memory_format());
sizes_and_strides_.set_sizes(src.sizes());
}
// Validates the CSR components against each other and against this tensor,
// then stores them as crow_indices_, col_indices_ and values_.
//
// Checks, in order:
//  - crow_indices and col_indices have the same scalar type, kInt or kLong;
//  - values' scalar type equals this tensor's dtype;
//  - all three are strided, and values is contiguous;
//  - values' device type equals this tensor's device type;
//  - unless values is on CUDA, crow_indices and col_indices report the same
//    device index;
//  - col_indices.size(0) == values.size(0).
// Every check runs before any member is assigned, so a failing TORCH_CHECK
// leaves the current components untouched.
void SparseCsrTensorImpl::set_member_tensors(
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& values) {
  const auto crow_indices_type = crow_indices.scalar_type();
  const auto col_indices_type = col_indices.scalar_type();
  TORCH_CHECK(
      crow_indices_type == col_indices_type,
      "both crow_indices and col_indices should have the same type.");
  TORCH_CHECK(
      crow_indices_type == kInt || crow_indices_type == kLong,
      "crow_indices and col_indices must be an int32 or int64 type, but got: ",
      crow_indices_type);

  const auto self_dtype = typeMetaToScalarType(dtype());
  TORCH_CHECK(
      values.scalar_type() == self_dtype,
      "dtype of values (",
      values.scalar_type(),
      ") must match dtype of sparse tensor (",
      self_dtype,
      ")");

  TORCH_CHECK(
      col_indices.layout() == kStrided,
      "expected col_indices to be a strided tensor, but got col_indices of layout ",
      col_indices.layout());
  TORCH_CHECK(
      crow_indices.layout() == kStrided,
      "expected crow_indices to be a strided tensor, but got crow_indices of layout ",
      crow_indices.layout());
  TORCH_CHECK(
      values.layout() == kStrided && values.is_contiguous(),
      "expected values to be a strided and contiguous tensor, but got values of layout ",
      values.layout());

  // Previously the message ran "device().type()" directly into the type name
  // with no separator; report the tensor's device type in parentheses, in the
  // same form as the dtype message above.
  TORCH_CHECK(
      values.device().type() == device().type(),
      "device type of values (",
      values.device().type(),
      ") must match device type of sparse tensor (",
      device().type(),
      ")");

  // NOTE(review): the message says the indices must match the (non-cuda)
  // device of values, but the condition only compares crow_indices' and
  // col_indices' device indices with each other (and is skipped entirely when
  // values is on CUDA); values.get_device() is only printed. Confirm whether
  // the indices should also be compared against values.get_device().
  TORCH_CHECK(
      values.is_cuda() || col_indices.get_device() == crow_indices.get_device(),
      "crow_indices and col_indices devices (",
      crow_indices.get_device(),
      ", ",
      col_indices.get_device(),
      ") must match with the (non-cuda) device of values (",
      values.get_device(),
      ")");

  TORCH_CHECK(
      col_indices.size(0) == values.size(0),
      "col_indices and values must have equal sizes, but got col_indices.size(0): ",
      col_indices.size(0),
      ", values.size(0): ",
      values.size(0));

  crow_indices_ = crow_indices;
  col_indices_ = col_indices;
  values_ = values;
}
} // namespace at