1
+ from itertools import cycle
1
2
import numpy as np
2
- import pandas as pd
3
- from sklearn .metrics import confusion_matrix
4
- import plotly
5
- import plotly .offline as py
6
- from plotly .tools import FigureFactory as FF
7
-
8
3
from cesium import featurize
9
- from .config import cfg
4
+ from bokeh .plotting import figure
5
+ from bokeh .layouts import gridplot
6
+ from bokeh .palettes import Viridis as palette
7
+ from bokeh .core .json_encoder import serialize_json
8
+ from bokeh .document import Document
9
+ from bokeh .util .serialization import make_id
10
10
11
11
12
12
def feature_scatterplot (fset_path , features_to_plot ):
@@ -21,42 +21,39 @@ def feature_scatterplot(fset_path, features_to_plot):
21
21
22
22
Returns
23
23
-------
24
- (fig.data, fig.layout)
25
- Returns (fig.data, fig.layout) where `fig` is an instance of
26
- `plotly.tools.FigureFactory`.
24
+ (str, str)
25
+ Returns (docs_json, render_items) json for the desired plot.
27
26
"""
28
27
fset , data = featurize .load_featureset (fset_path )
29
28
fset = fset [features_to_plot ]
30
-
31
- if 'label' in data :
32
- fset ['label' ] = data ['label' ]
33
- index = 'label'
34
- else :
35
- index = None
36
-
37
- # TODO replace 'trace {i}' with class labels
38
- fig = FF .create_scatterplotmatrix (fset , diag = 'box' , index = index ,
39
- height = 800 , width = 800 )
40
-
41
- py .plot (fig , auto_open = False , output_type = 'div' )
42
-
43
- return fig .data , fig .layout
44
-
45
-
46
- #def prediction_heatmap(pred_path):
47
- # with xr.open_dataset(pred_path) as pset:
48
- # pred_df = pd.DataFrame(pset.prediction.values, index=pset.name,
49
- # columns=pset.class_label.values)
50
- # pred_labels = pred_df.idxmax(axis=1)
51
- # C = confusion_matrix(pset.label, pred_labels)
52
- # row_sums = C.sum(axis=1)
53
- # C = C / row_sums[:, np.newaxis]
54
- # fig = FF.create_annotated_heatmap(C, x=[str(el) for el in
55
- # pset.class_label.values],
56
- # y=[str(el) for el in
57
- # pset.class_label.values],
58
- # colorscale='Viridis')
59
- #
60
- # py.plot(fig, auto_open=False, output_type='div')
61
- #
62
- # return fig.data, fig.layout
29
+ colors = cycle (palette [5 ])
30
+ plots = np .array ([[figure (width = 300 , height = 200 )
31
+ for j in range (len (features_to_plot ))]
32
+ for i in range (len (features_to_plot ))])
33
+
34
+ for (j , i ), p in np .ndenumerate (plots ):
35
+ if (j == i == 0 ):
36
+ p .title .text = "Scatterplot matrix"
37
+ p .circle (fset .values [:,i ], fset .values [:,j ], color = next (colors ))
38
+ p .xaxis .minor_tick_line_color = None
39
+ p .yaxis .minor_tick_line_color = None
40
+ p .ygrid [0 ].ticker .desired_num_ticks = 2
41
+ p .xgrid [0 ].ticker .desired_num_ticks = 4
42
+ p .outline_line_color = None
43
+ p .axis .visible = None
44
+
45
+ plot = gridplot (plots .tolist (), ncol = len (features_to_plot ), mergetools = True , responsive = True , title = "Test" )
46
+
47
+ # Convert plot to json objects necessary for rendering with bokeh on the
48
+ # frontend
49
+ render_items = [{'docid' : plot ._id , 'elementid' : make_id ()}]
50
+
51
+ doc = Document ()
52
+ doc .add_root (plot )
53
+ docs_json_inner = doc .to_json ()
54
+ docs_json = {render_items [0 ]['docid' ]: docs_json_inner }
55
+
56
+ docs_json = serialize_json (docs_json )
57
+ render_items = serialize_json (render_items )
58
+
59
+ return docs_json , render_items
0 commit comments