@@ -28,13 +28,15 @@ __pd = None
28
28
__pyspark = None
29
29
__tf = None
30
30
__K = None
31
+ __torch = None
31
32
__ipywidgets = None
32
33
33
34
34
35
def _check_imported():
35
- global __np, __pd, __pyspark, __tf, __K, __ipywidgets
36
+ global __np, __pd, __pyspark, __tf, __K, __torch, __ipywidgets
36
37
37
- if 'numpy' in sys.modules:
38
+ if '
39
+ ' in sys.modules:
38
40
# don't really need the try
39
41
import numpy as __np
40
42
@@ -55,6 +57,9 @@ def _check_imported():
55
57
except ImportError:
56
58
__K = None
57
59
60
+ if 'torch' in sys.modules:
61
+ import torch as __torch
62
+
58
63
if 'ipywidgets' in sys.modules:
59
64
import ipywidgets as __ipywidgets
60
65
@@ -66,6 +71,8 @@ def _jupyterlab_variableinspector_getsizeof(x):
66
71
return "?"
67
72
elif __tf and isinstance(x, __tf.Variable):
68
73
return "?"
74
+ elif __torch and isinstance(x, __torch.Tensor):
75
+ return x.element_size() * x.nelement()
69
76
elif __pd and type(x).__name__ == 'DataFrame':
70
77
return x.memory_usage().sum()
71
78
else:
@@ -88,6 +95,9 @@ def _jupyterlab_variableinspector_getshapeof(x):
88
95
if __tf and isinstance(x, __tf.Tensor):
89
96
shape = " x ".join([str(int(i)) for i in x.shape])
90
97
return "%s" % shape
98
+ if __torch and isinstance(x, __torch.Tensor):
99
+ shape = " x ".join([str(int(i)) for i in x.shape])
100
+ return "%s" % shape
91
101
if isinstance(x, list):
92
102
return "%s" % len(x)
93
103
if isinstance(x, dict):
@@ -129,6 +139,8 @@ def _jupyterlab_variableinspector_is_matrix(x):
129
139
return True
130
140
if __tf and isinstance(x, __tf.Tensor) and len(x.shape) <= 2:
131
141
return True
142
+ if __torch and isinstance(x, __torch.Tensor) and len(x.shape) <= 2:
143
+ return True
132
144
if isinstance(x, list):
133
145
return True
134
146
return False
@@ -153,7 +165,7 @@ def _jupyterlab_variableinspector_dict_list():
153
165
return True
154
166
if str(obj)[0] == "<":
155
167
return False
156
- if v in ['__np', '__pd', '__pyspark', '__tf', '__K', '__ipywidgets']:
168
+ if v in ['__np', '__pd', '__pyspark', '__tf', '__K', '__torch', ' __ipywidgets']:
157
169
return obj is not None
158
170
if str(obj).startswith("_Feature"):
159
171
# removes tf/keras objects
@@ -199,6 +211,9 @@ def _jupyterlab_variableinspector_getmatrixcontent(x, max_rows=10000):
199
211
elif __tf and (isinstance(x, __tf.Variable) or isinstance(x, __tf.Tensor)):
200
212
df = __K.get_value(x)
201
213
return _jupyterlab_variableinspector_getmatrixcontent(df)
214
+ elif __torch and __pd and isinstance(x, torch.Tensor):
215
+ df = x.cpu().numpy()
216
+ return _jupyterlab_variableinspector_getmatrixcontent(df)
202
217
elif isinstance(x, list):
203
218
s = __pd.Series(x)
204
219
return _jupyterlab_variableinspector_getmatrixcontent(s)
0 commit comments