
Commit dbe4de7

add dpcpp attribute for Tensor class (#107)
Signed-off-by: hongzhen <[email protected]>
1 parent 810b4ac commit dbe4de7

File tree

1 file changed: +61 −1 lines changed


torch_patches/dpcpp-v1.5-rc3.patch

Lines changed: 61 additions & 1 deletion
@@ -42,6 +42,33 @@ index 8ce6045..ba7f79e 100644
   } else {
     dispatch_key = DispatchKey::SparseCPUTensorId;
   }
+diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h
+index 0b40acc..dd1db23 100644
+--- a/aten/src/ATen/templates/TensorBody.h
++++ b/aten/src/ATen/templates/TensorBody.h
+@@ -420,6 +420,7 @@ class CAFFE2_API Tensor {
+   Tensor cpu() const;
+   Tensor cuda() const;
+   Tensor hip() const;
++  Tensor dpcpp() const;
+
+   // ~~~~~ Autograd API ~~~~~
+
+diff --git a/aten/src/ATen/templates/TensorMethods.h b/aten/src/ATen/templates/TensorMethods.h
+index 33983ec..754efd6 100644
+--- a/aten/src/ATen/templates/TensorMethods.h
++++ b/aten/src/ATen/templates/TensorMethods.h
+@@ -42,6 +42,10 @@ inline Tensor Tensor::toType(ScalarType t) const {
+   return to(options().dtype(t), /*non_blocking*/ false, /*copy*/ false);
+ }
+
++inline Tensor Tensor::dpcpp() const {
++  return to(options().device(DeviceType::DPCPP), /*non_blocking*/ false, /*copy*/ false);
++}
++
+ // TODO: Deprecate me
+ inline Tensor Tensor::toBackend(Backend b) const {
+   return to(options().device(backendToDeviceType(b)).layout(layout_from_backend(b)), /*non_blocking*/ false, /*copy*/ false);
 diff --git a/c10/core/Backend.h b/c10/core/Backend.h
 index 5f3d8c7..a47240b 100644
 --- a/c10/core/Backend.h
@@ -394,8 +421,41 @@ index 9a4c9b3..6d02405 100644
   } else if (tid == DispatchKey::SparseCPUTensorId) {
     return DeviceType::CPU;
   } else if (tid == DispatchKey::SparseCUDATensorId) {
+diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp
+index 2a9dc9d..9410392 100644
+--- a/tools/autograd/templates/python_variable_methods.cpp
++++ b/tools/autograd/templates/python_variable_methods.cpp
+@@ -369,6 +369,20 @@ static PyObject * THPVariable_cpu(PyObject* self, PyObject* args, PyObject* kwargs)
+   END_HANDLE_TH_ERRORS
+ }
+
++static PyObject * THPVariable_dpcpp(PyObject* self, PyObject* args, PyObject* kwargs)
++{
++  HANDLE_TH_ERRORS
++  static PythonArgParser parser({
++    "dpcpp(*, MemoryFormat? memory_format=None)"
++  });
++  auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
++  ParsedArgs<1> parsed_args;
++  auto r = parser.parse(args, kwargs, parsed_args);
++  auto opt_memory_format = r.memoryformatOptional(0);
++  return THPVariable_Wrap(dispatch_to(self_, at::Device(at::DeviceType::DPCPP), false, false, opt_memory_format));
++  END_HANDLE_TH_ERRORS
++}
++
+ static Tensor dispatch_nonzero(const Tensor & self) {
+   pybind11::gil_scoped_release no_gil;
+   OptionalDeviceGuard device_guard(device_of(self));
+@@ -871,6 +885,7 @@ PyMethodDef variable_methods[] = {
+   {"copy_", (PyCFunction)(void(*)(void))THPVariable_copy_, METH_VARARGS | METH_KEYWORDS, NULL},
+   {"cpu", (PyCFunction)(void(*)(void))THPVariable_cpu, METH_VARARGS | METH_KEYWORDS, NULL},
+   {"cuda", (PyCFunction)(void(*)(void))THPVariable_cuda, METH_VARARGS | METH_KEYWORDS, NULL},
++  {"dpcpp", (PyCFunction)(void(*)(void))THPVariable_dpcpp, METH_VARARGS | METH_KEYWORDS, NULL},
+   {"data_ptr", (PyCFunction)THPVariable_data_ptr, METH_NOARGS, NULL},
+   {"dim", (PyCFunction)THPVariable_dim, METH_NOARGS, NULL},
+   {"has_names", (PyCFunction)THPVariable_has_names, METH_NOARGS, NULL},
 diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp
-index a6a9fca462..d42a05fd4a 100644
+index a6a9fca..d42a05f 100644
 --- a/torch/csrc/jit/passes/quantization.cpp
 +++ b/torch/csrc/jit/passes/quantization.cpp
 @@ -1718,7 +1718,7 @@ class FoldConvBatchNorm2dHelper {
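
Taken together, the patch makes dpcpp() a sibling of cpu(), cuda(), and hip() on both the C++ Tensor class and the Python tensor methods. A minimal C++ sketch of how the new method would be called (not part of this commit; it assumes a libtorch build with this patch applied and a DPCPP device actually registered at runtime, otherwise the dispatch would fail):

#include <ATen/ATen.h>

int main() {
  // Create a tensor on the CPU, then move it to the DPCPP device via the
  // method added by this patch, analogous to t.cpu() / t.cuda() / t.hip().
  at::Tensor t = at::ones({2, 3});
  at::Tensor t_dpcpp = t.dpcpp();  // dispatches to(options().device(DeviceType::DPCPP))
  return 0;
}

From Python, the same move goes through the new "dpcpp" entry registered in variable_methods, i.e. tensor.dpcpp().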

0 commit comments
