Skip to content

Commit 14c3d7c

Browse files
committed
Merge remote-tracking branch 'cesium-ml/master'
2 parents d5e0d9b + 1436f95 commit 14c3d7c

File tree

3 files changed

+45
-4
lines changed

3 files changed

+45
-4
lines changed

cesium_app/handlers/prediction.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ def _get_prediction(self, prediction_id):
2525
try:
2626
d = Prediction.get(Prediction.id == prediction_id)
2727
except Prediction.DoesNotExist:
28-
raise AccessError('No such dataset')
28+
raise AccessError('No such prediction')
2929

3030
if not d.is_owned_by(self.get_username()):
31-
raise AccessError('No such dataset')
31+
raise AccessError('No such prediction')
3232

3333
return d
3434

@@ -67,6 +67,9 @@ def post(self):
6767

6868
dataset_id = data['datasetID']
6969
model_id = data['modelID']
70+
# If only a subset of specified dataset is to be used, a list of the
71+
# corresponding time series file names can be provided
72+
ts_names = data.get('ts_names')
7073

7174
dataset = Dataset.get(Dataset.id == data["datasetID"])
7275
model = Model.get(Model.id == data["modelID"])
@@ -88,7 +91,15 @@ def post(self):
8891

8992
executor = yield self._get_executor()
9093

91-
all_time_series = executor.map(time_series.load, dataset.uris)
94+
# If only a subset of the dataset is to be used, get specified files
95+
if ts_names:
96+
ts_uris = [f.uri for f in dataset.files if os.path.basename(f.name)
97+
in ts_names or os.path.basename(f.name).split('.npz')[0]
98+
in ts_names]
99+
else:
100+
ts_uris = dataset.uris
101+
102+
all_time_series = executor.map(time_series.load, ts_uris)
92103
all_labels = executor.map(lambda ts: ts.label, all_time_series)
93104
all_features = executor.map(featurize.featurize_single_ts,
94105
all_time_series,

cesium_app/tests/frontend/test_predict.py

+31
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from os.path import join as pjoin
88
import numpy as np
99
import numpy.testing as npt
10+
from cesium_app.config import cfg
11+
import json
12+
import requests
1013
from cesium_app.tests.fixtures import (create_test_project, create_test_dataset,
1114
create_test_featureset, create_test_model,
1215
create_test_prediction)
@@ -204,3 +207,31 @@ def test_download_prediction_csv_regr(driver):
204207
[4, 3.1, 3.1]])
205208
finally:
206209
os.remove('/tmp/cesium_prediction_results.csv')
210+
211+
212+
def test_predict_specific_ts_name():
213+
with create_test_project() as p, create_test_dataset(p) as ds,\
214+
create_test_featureset(p) as fs, create_test_model(fs) as m:
215+
ts_data = [[1, 2, 3, 4], [32.2, 53.3, 32.3, 32.52], [0.2, 0.3, 0.6, 0.3]]
216+
impute_kwargs = {'strategy': 'constant', 'value': None}
217+
data = {'datasetID': ds.id,
218+
'ts_names': ['217801'],
219+
'modelID': m.id}
220+
response = requests.post('{}/predictions'.format(cfg['server']['url']),
221+
data=json.dumps(data)).json()
222+
assert response['status'] == 'success'
223+
224+
n_secs = 0
225+
while n_secs < 5:
226+
pred_info = requests.get('{}/predictions/{}'.format(
227+
cfg['server']['url'], response['data']['id'])).json()
228+
if pred_info['status'] == 'success' and pred_info['data']['finished']:
229+
assert isinstance(pred_info['data']['results']['217801']
230+
['features']['total_time'],
231+
float)
232+
assert 'Mira' in pred_info['data']['results']['217801']['prediction']
233+
break
234+
n_secs += 1
235+
time.sleep(1)
236+
else:
237+
raise Exception('test_predict_specific_ts_name timed out')

tools/watch_logs.py

-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ def logs_from_config(supervisor_conf):
9696
with nostdout():
9797
from cesium_app.config import cfg
9898

99-
watched.append(cfg['paths']['err_log_path'])
10099
watched.append('log/error.log')
101100
watched.append('log/nginx-error.log')
102101

0 commit comments

Comments
 (0)