Skip to content

Commit b927009

Browse files
authored
Feature request: PyTorch Tensors (#186)
1 parent dbdc076 commit b927009

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

src/inspectorscripts.ts

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@ __pd = None
2828
__pyspark = None
2929
__tf = None
3030
__K = None
31+
__torch = None
3132
__ipywidgets = None
3233
3334
3435
def _check_imported():
35-
global __np, __pd, __pyspark, __tf, __K, __ipywidgets
36+
global __np, __pd, __pyspark, __tf, __K, __torch, __ipywidgets
3637
37-
if 'numpy' in sys.modules:
38+
if '
39+
' in sys.modules:
3840
# don't really need the try
3941
import numpy as __np
4042
@@ -55,6 +57,9 @@ def _check_imported():
5557
except ImportError:
5658
__K = None
5759
60+
if 'torch' in sys.modules:
61+
import torch as __torch
62+
5863
if 'ipywidgets' in sys.modules:
5964
import ipywidgets as __ipywidgets
6065
@@ -66,6 +71,8 @@ def _jupyterlab_variableinspector_getsizeof(x):
6671
return "?"
6772
elif __tf and isinstance(x, __tf.Variable):
6873
return "?"
74+
elif __torch and isinstance(x, __torch.Tensor):
75+
return x.element_size() * x.nelement()
6976
elif __pd and type(x).__name__ == 'DataFrame':
7077
return x.memory_usage().sum()
7178
else:
@@ -88,6 +95,9 @@ def _jupyterlab_variableinspector_getshapeof(x):
8895
if __tf and isinstance(x, __tf.Tensor):
8996
shape = " x ".join([str(int(i)) for i in x.shape])
9097
return "%s" % shape
98+
if __torch and isinstance(x, __torch.Tensor):
99+
shape = " x ".join([str(int(i)) for i in x.shape])
100+
return "%s" % shape
91101
if isinstance(x, list):
92102
return "%s" % len(x)
93103
if isinstance(x, dict):
@@ -129,6 +139,8 @@ def _jupyterlab_variableinspector_is_matrix(x):
129139
return True
130140
if __tf and isinstance(x, __tf.Tensor) and len(x.shape) <= 2:
131141
return True
142+
if __torch and isinstance(x, __torch.Tensor) and len(x.shape) <= 2:
143+
return True
132144
if isinstance(x, list):
133145
return True
134146
return False
@@ -153,7 +165,7 @@ def _jupyterlab_variableinspector_dict_list():
153165
return True
154166
if str(obj)[0] == "<":
155167
return False
156-
if v in ['__np', '__pd', '__pyspark', '__tf', '__K', '__ipywidgets']:
168+
if v in ['__np', '__pd', '__pyspark', '__tf', '__K', '__torch', '__ipywidgets']:
157169
return obj is not None
158170
if str(obj).startswith("_Feature"):
159171
# removes tf/keras objects
@@ -199,6 +211,9 @@ def _jupyterlab_variableinspector_getmatrixcontent(x, max_rows=10000):
199211
elif __tf and (isinstance(x, __tf.Variable) or isinstance(x, __tf.Tensor)):
200212
df = __K.get_value(x)
201213
return _jupyterlab_variableinspector_getmatrixcontent(df)
214+
elif __torch and __pd and isinstance(x, torch.Tensor):
215+
df = x.cpu().numpy()
216+
return _jupyterlab_variableinspector_getmatrixcontent(df)
202217
elif isinstance(x, list):
203218
s = __pd.Series(x)
204219
return _jupyterlab_variableinspector_getmatrixcontent(s)

0 commit comments

Comments
 (0)