-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtest_get_device.py
31 lines (26 loc) · 1005 Bytes
/
test_get_device.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import tfdlpack
import tensorflow as tf
import torch as th
import numpy as np
from torch.utils.dlpack import from_dlpack, to_dlpack
# Element dtypes exercised by the round-trip test. Each one should appear
# exactly once; previously np.float32 was listed twice and np.float64 was
# never tested.
types = [np.float16, np.float32, np.float64,
         np.int8, np.int16, np.int32, np.int64]

# Maps a DLPack device-type code (per dlpack.h: 1 = kDLCPU, 2 = kDLGPU/kDLCUDA)
# to a callable that moves a torch tensor onto a device of that type.
devices = {
    1: lambda t: t.cpu(),
}
# Only register the GPU case when CUDA is actually usable, so the test still
# runs (CPU-only) on machines without a GPU instead of crashing in t.cuda(0).
if th.cuda.is_available():
    devices[2] = lambda t: t.cuda(0)
def test_get_op():
    """Check that ``tfdlpack.get_device_and_dtype`` decodes a DLPack capsule.

    For every dtype in ``types`` and every entry of ``devices``, a small torch
    tensor ``[1, 2, 3]`` is created and moved with the device callable. It is
    then exported via ``to_dlpack`` and the capsule is handed to
    ``tfdlpack.get_device_and_dtype``. The three returned elements are checked
    against:

    * element 0: the DLPack device-type code (the ``devices`` key),
    * element 1: the torch device index (``None`` is treated as 0),
    * element 2: a TensorFlow dtype enum that must map back to the numpy dtype.
    """
    for np_type in types:
        for expected_type_code, move_to_device in devices.items():
            source = th.tensor(np.array([1, 2, 3], dtype=np_type))
            tensor = move_to_device(source)
            capsule = to_dlpack(tensor)
            info = tfdlpack.get_device_and_dtype(capsule)

            # torch may report no device index (e.g. for CPU tensors);
            # the expected DLPack device id in that case is 0.
            expected_device_id = tensor.device.index
            if expected_device_id is None:
                expected_device_id = 0

            assert info[0].item() == expected_type_code
            assert info[1].item() == expected_device_id
            dtype_code = info[2].item()
            assert tf.DType(dtype_code).as_numpy_dtype == np_type


if __name__ == "__main__":
    test_get_op()