Skip to content

Commit 220c1a9

Browse files
committed
Merge branch 'master' of https://github.com/cesium-ml/cesium_web into add_project
2 parents 3a5ec67 + 993cb7e commit 220c1a9

File tree

13 files changed

+109
-121
lines changed

13 files changed

+109
-121
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ ghostdriver.log
99
*.swo
1010
__pycache__/
1111
node_modules/
12+
*.pyc

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ webpack = ./node_modules/.bin/webpack
88

99

1010
dependencies:
11-
@./tools/silent_monitor.py ./tools/install_deps.py requirements.txt
11+
@./tools/silent_monitor.py pip install -r requirements.txt
1212
@./tools/silent_monitor.py ./tools/check_js_deps.sh
1313

1414
db_init:

cesium_app/handlers/plot_features.py

+6-17
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,9 @@
44

55

66
class PlotFeaturesHandler(BaseHandler):
7-
def _get_featureset(self, featureset_id):
8-
try:
9-
f = Featureset.get(Featureset.id == featureset_id)
10-
except Featureset.DoesNotExist:
11-
raise AccessError('No such feature set')
12-
13-
if not f.is_owned_by(self.get_username()):
14-
raise AccessError('No such feature set')
15-
16-
return f
17-
18-
def get(self, featureset_id=None):
19-
fset = self._get_featureset(featureset_id)
20-
features_to_plot = sorted(fset.features_list)[0:4]
21-
data, layout = plot.feature_scatterplot(fset.file.uri, features_to_plot)
22-
23-
self.success({'data': data, 'layout': layout})
7+
def get(self, featureset_id):
8+
fset = Featureset.get_if_owned(featureset_id, self.get_username())
9+
features_to_plot = sorted(fset.features_list)[0:4] # TODO from form
10+
docs_json, render_items = plot.feature_scatterplot(fset.file.uri,
11+
features_to_plot)
12+
self.success({'docs_json': docs_json, 'render_items': render_items})

cesium_app/handlers/prediction.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ def get(self, prediction_id=None, action=None):
141141
'label': data['labels'],
142142
'prediction': data['preds']},
143143
columns=['ts_name', 'label', 'prediction'])
144-
if data.get('pred_probs'):
145-
result['probability'] = np.max(data['pred_probs'], axis=1)
144+
if len(data.get('pred_probs', [])) > 0:
145+
result['probability'] = data['pred_probs'].max(axis=1).values
146146
self.set_header("Content-Type", 'text/csv; charset="utf-8"')
147147
self.set_header("Content-Disposition", "attachment; "
148148
"filename=cesium_prediction_results.csv")

cesium_app/models.py

+12
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,18 @@ class Featureset(BaseModel):
158158
def is_owned_by(self, username):
159159
return self.project.is_owned_by(username)
160160

161+
@staticmethod
162+
def get_if_owned(fset_id, username):
163+
try:
164+
f = Featureset.get(Featureset.id == fset_id)
165+
except Featureset.DoesNotExist:
166+
raise AccessError('No such feature set')
167+
168+
if not f.is_owned_by(username):
169+
raise AccessError('No such feature set')
170+
171+
return f
172+
161173

162174
class Model(BaseModel):
163175
"""ORM model of the Model table"""

cesium_app/plot.py

+40-43
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1+
from itertools import cycle
12
import numpy as np
2-
import pandas as pd
3-
from sklearn.metrics import confusion_matrix
4-
import plotly
5-
import plotly.offline as py
6-
from plotly.tools import FigureFactory as FF
7-
83
from cesium import featurize
9-
from .config import cfg
4+
from bokeh.plotting import figure
5+
from bokeh.layouts import gridplot
6+
from bokeh.palettes import Viridis as palette
7+
from bokeh.core.json_encoder import serialize_json
8+
from bokeh.document import Document
9+
from bokeh.util.serialization import make_id
1010

1111

1212
def feature_scatterplot(fset_path, features_to_plot):
@@ -21,42 +21,39 @@ def feature_scatterplot(fset_path, features_to_plot):
2121
2222
Returns
2323
-------
24-
(fig.data, fig.layout)
25-
Returns (fig.data, fig.layout) where `fig` is an instance of
26-
`plotly.tools.FigureFactory`.
24+
(str, str)
25+
Returns (docs_json, render_items) json for the desired plot.
2726
"""
2827
fset, data = featurize.load_featureset(fset_path)
2928
fset = fset[features_to_plot]
30-
31-
if 'label' in data:
32-
fset['label'] = data['label']
33-
index = 'label'
34-
else:
35-
index = None
36-
37-
# TODO replace 'trace {i}' with class labels
38-
fig = FF.create_scatterplotmatrix(fset, diag='box', index=index,
39-
height=800, width=800)
40-
41-
py.plot(fig, auto_open=False, output_type='div')
42-
43-
return fig.data, fig.layout
44-
45-
46-
#def prediction_heatmap(pred_path):
47-
# with xr.open_dataset(pred_path) as pset:
48-
# pred_df = pd.DataFrame(pset.prediction.values, index=pset.name,
49-
# columns=pset.class_label.values)
50-
# pred_labels = pred_df.idxmax(axis=1)
51-
# C = confusion_matrix(pset.label, pred_labels)
52-
# row_sums = C.sum(axis=1)
53-
# C = C / row_sums[:, np.newaxis]
54-
# fig = FF.create_annotated_heatmap(C, x=[str(el) for el in
55-
# pset.class_label.values],
56-
# y=[str(el) for el in
57-
# pset.class_label.values],
58-
# colorscale='Viridis')
59-
#
60-
# py.plot(fig, auto_open=False, output_type='div')
61-
#
62-
# return fig.data, fig.layout
29+
colors = cycle(palette[5])
30+
plots = np.array([[figure(width=300, height=200)
31+
for j in range(len(features_to_plot))]
32+
for i in range(len(features_to_plot))])
33+
34+
for (j, i), p in np.ndenumerate(plots):
35+
if (j == i == 0):
36+
p.title.text = "Scatterplot matrix"
37+
p.circle(fset.values[:,i], fset.values[:,j], color=next(colors))
38+
p.xaxis.minor_tick_line_color = None
39+
p.yaxis.minor_tick_line_color = None
40+
p.ygrid[0].ticker.desired_num_ticks = 2
41+
p.xgrid[0].ticker.desired_num_ticks = 4
42+
p.outline_line_color = None
43+
p.axis.visible = None
44+
45+
plot = gridplot(plots.tolist(), ncol=len(features_to_plot), mergetools=True, responsive=True, title="Test")
46+
47+
# Convert plot to json objects necessary for rendering with bokeh on the
48+
# frontend
49+
render_items = [{'docid': plot._id, 'elementid': make_id()}]
50+
51+
doc = Document()
52+
doc.add_root(plot)
53+
docs_json_inner = doc.to_json()
54+
docs_json = {render_items[0]['docid']: docs_json_inner}
55+
56+
docs_json = serialize_json(docs_json)
57+
render_items = serialize_json(render_items)
58+
59+
return docs_json, render_items

cesium_app/tests/frontend/test_features.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def test_plot_features(driver):
167167
driver.find_element_by_xpath("//b[contains(text(),'Please wait while we load your plotting data...')]")
168168

169169
driver.implicitly_wait(3)
170-
driver.find_element_by_css_selector("[class=svg-container]")
170+
driver.find_element_by_css_selector("[class=bk-plotdiv]")
171171

172172

173173
def test_delete_featureset(driver):

cesium_app/tests/frontend/test_predict.py

+21
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from os.path import join as pjoin
88
import numpy as np
99
import numpy.testing as npt
10+
import pandas as pd
1011
from cesium_app.config import cfg
1112
import json
1213
import requests
@@ -185,6 +186,26 @@ def test_download_prediction_csv_class(driver):
185186
os.remove('/tmp/cesium_prediction_results.csv')
186187

187188

189+
def test_download_prediction_csv_class_prob(driver):
190+
driver.get('/')
191+
with create_test_project() as p, create_test_dataset(p) as ds,\
192+
create_test_featureset(p) as fs,\
193+
create_test_model(fs, model_type='RandomForestClassifier') as m,\
194+
create_test_prediction(ds, m):
195+
_click_download(p.id, driver)
196+
assert os.path.exists('/tmp/cesium_prediction_results.csv')
197+
try:
198+
result = pd.read_csv('/tmp/cesium_prediction_results.csv')
199+
npt.assert_array_equal(result.ts_name, np.arange(5))
200+
npt.assert_array_equal(result.label, ['Mira', 'Classical_Cepheid',
201+
'Mira', 'Classical_Cepheid',
202+
'Mira'])
203+
npt.assert_array_equal(result.label, result.prediction)
204+
assert (result.probability >= 0.0).all()
205+
finally:
206+
os.remove('/tmp/cesium_prediction_results.csv')
207+
208+
188209
def test_download_prediction_csv_regr(driver):
189210
driver.get('/')
190211
with create_test_project() as p, create_test_dataset(p, label_type='regr') as ds,\

package.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
"test": "eslint -c .eslintrc --ext .jsx,.js public/scripts/ && make test"
77
},
88
"dependencies": {
9+
"bokehjs": "^0.12.5",
910
"bootstrap": "^3.3.7",
1011
"bootstrap-css": "^3.0.0",
1112
"css-loader": "^0.26.2",
1213
"exports-loader": "^0.6.4",
1314
"imports-loader": "^0.7.1",
1415
"jquery": "^3.1.1",
15-
"plotly.js": "^1.23.1",
1616
"react": "^15.1.0",
1717
"react-dom": "^15.1.0",
1818
"react-redux": "^5.0.3",
@@ -23,6 +23,7 @@
2323
"redux-logger": "^2.8.1",
2424
"redux-thunk": "^2.2.0",
2525
"style-loader": "^0.13.2",
26+
"typescript": "^2.2.2",
2627
"webpack": "^2.2.1",
2728
"webpack-dev-server": "^2.4.1",
2829
"whatwg-fetch": "^2.0.2"

public/scripts/Plot.jsx

+22-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,24 @@
11
import React, { Component } from 'react';
22
import { connect } from 'react-redux';
3-
import Plotly from './custom-plotly';
43
import { showNotification } from './Notifications';
4+
import "../../node_modules/bokehjs/build/js/bokeh.js";
5+
import "../../node_modules/bokehjs/build/css/bokeh.css";
56

7+
function bokeh_render_plot(node, docs_json, render_items) {
8+
// Create bokeh div element
9+
var bokeh_div = document.createElement("div");
10+
var inner_div = document.createElement("div");
11+
bokeh_div.setAttribute("class", "bk-root" );
12+
inner_div.setAttribute("class", "bk-plotdiv");
13+
inner_div.setAttribute("id", render_items[0].elementid);
14+
bokeh_div.appendChild(inner_div);
15+
node.appendChild(bokeh_div);
16+
17+
// Generate plot
18+
Bokeh.safely(function() {
19+
Bokeh.embed.embed_items(docs_json, render_items);
20+
});
21+
}
622

723
class Plot extends Component {
824
constructor(props) {
@@ -32,16 +48,17 @@ class Plot extends Component {
3248
if (!plotData) {
3349
return <b>Please wait while we load your plotting data...</b>;
3450
}
35-
36-
let { data, layout } = plotData;
51+
var docs_json = JSON.parse(plotData.docs_json);
52+
var render_items = JSON.parse(plotData.render_items);
3753

3854
return (
3955
plotData &&
4056
<div
4157
ref={
4258
(node) => {
43-
node && Plotly.plot(node, data, layout);
44-
}}
59+
node && bokeh_render_plot(node, docs_json, render_items)
60+
}
61+
}
4562
/>
4663
);
4764
}

public/scripts/custom-plotly.js

-13
This file was deleted.

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ pyyaml
66
tornado
77
pyzmq
88
pyjwt
9-
plotly>=2.0.5
109
simplejson
1110
distributed>=1.14.3
1211
selenium
1312
pytest
1413
joblib>=0.11
14+
bokeh==0.12.5

tools/install_deps.py

-37
This file was deleted.

0 commit comments

Comments
 (0)