Skip to content

Commit a7e8bc9

Browse files
committed
Update calls to save/load featureset to newer cesium API
1 parent 1436f95 commit a7e8bc9

File tree

3 files changed

+29
-23
lines changed

3 files changed

+29
-23
lines changed

cesium_app/handlers/prediction.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -115,15 +115,15 @@ def post(self):
115115
model_or_gridcv)
116116
preds = executor.submit(lambda fset, model: model.predict(fset),
117117
imputed_fset, model_data)
118-
pred_probs = executor.submit(lambda fset, model: model.predict_proba(fset)
118+
pred_probs = executor.submit(lambda fset, model:
119+
pd.DataFrame(model.predict_proba(fset),
120+
index=fset.index,
121+
columns=model.classes_)
119122
if hasattr(model, 'predict_proba') else [],
120123
imputed_fset, model_data)
121-
all_classes = executor.submit(lambda model: model.classes_
122-
if hasattr(model, 'classes_') else [],
123-
model_data)
124124
future = executor.submit(featurize.save_featureset, imputed_fset,
125125
pred_path, labels=all_labels, preds=preds,
126-
pred_probs=pred_probs, all_classes=all_classes)
126+
pred_probs=pred_probs)
127127

128128
prediction.task_id = future.key
129129
prediction.save()
@@ -182,9 +182,12 @@ def post(self):
182182
features_to_use=features_to_use,
183183
meta_features=meta_feats)
184184
fset = featurize.impute_featureset(fset, **impute_kwargs)
185-
data = {'preds': model_data.predict(fset),
186-
'all_classes': model_data.classes_}
185+
data = {'preds': model_data.predict(fset)}
187186
if hasattr(model_data, 'predict_proba'):
188-
data['pred_probs'] = model_data.predict_proba(fset)
187+
data['pred_probs'] = pd.DataFrame(model_data.predict_proba(fset),
188+
index=fset.index,
189+
columns=model_data.classes_)
190+
else:
191+
data['pred_probs'] = []
189192
pred_info = Prediction.format_pred_data(fset, data)
190193
return self.success(pred_info)

cesium_app/models.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import sys
55
import time
6-
import numpy as np
6+
import pandas as pd
77

88
import peewee as pw
99
from playhouse.postgres_ext import ArrayField, BinaryJSONField
@@ -73,6 +73,7 @@ class File(BaseModel):
7373
name = pw.CharField(null=True)
7474
created = pw.DateTimeField(default=datetime.datetime.now)
7575

76+
7677
@signals.post_delete(sender=File)
7778
def remove_file_after_delete(sender, instance):
7879
try:
@@ -135,6 +136,7 @@ class Meta:
135136
(('dataset', 'file'), True),
136137
)
137138

139+
138140
@signals.pre_delete(sender=Dataset)
139141
def remove_related_files(sender, instance):
140142
for f in instance.files:
@@ -148,7 +150,7 @@ class Featureset(BaseModel):
148150
name = pw.CharField()
149151
created = pw.DateTimeField(default=datetime.datetime.now)
150152
features_list = ArrayField(pw.CharField)
151-
custom_features_script = pw.CharField(null=True) # move to fset file?
153+
custom_features_script = pw.CharField(null=True) # move to fset file?
152154
file = pw.ForeignKeyField(File, on_delete='CASCADE')
153155
task_id = pw.CharField(null=True)
154156
finished = pw.DateTimeField(null=True)
@@ -194,16 +196,15 @@ def is_owned_by(self, username):
194196
def format_pred_data(fset, data):
195197
fset.columns = fset.columns.droplevel('channel')
196198
fset.index = fset.index.astype(str) # can't use ints as JSON keys
197-
result = {}
198-
for i, name in enumerate(fset.index):
199-
result[name] = {'features': fset.loc[name].to_dict()}
200-
if 'labels' in data:
201-
result[name]['label'] = data['labels'][i]
202-
if len(data['pred_probs']) > 0:
203-
result[name]['prediction'] = dict(zip(data['all_classes'],
204-
data['pred_probs'][i]))
205-
else:
206-
result[name]['prediction'] = data['preds'][i]
199+
labels = pd.Series(data.get('labels'), index=fset.index)
200+
if len(data.get('pred_probs', [])) > 0:
201+
preds = pd.DataFrame(data.get('pred_probs', []),
202+
index=fset.index).to_dict(orient='index')
203+
else:
204+
preds = pd.Series(data['preds'], index=fset.index).to_dict()
205+
result = {name: {'features': feats, 'label': labels.loc[name],
206+
'prediction': preds[name]}
207+
for name, feats in fset.to_dict(orient='index').items()}
207208
return result
208209

209210
def display_info(self):
@@ -238,6 +239,7 @@ def create_tables(retry=5):
238239
print('Could not connect to database...sleeping 5')
239240
time.sleep(5)
240241

242+
241243
def drop_tables():
242244
db.drop_tables(models, safe=True, cascade=True)
243245

cesium_app/tests/fixtures.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import peewee
1515
import datetime
1616
import joblib
17+
import pandas as pd
1718

1819

1920
@contextmanager
@@ -160,14 +161,14 @@ def create_test_prediction(dataset, model):
160161
if hasattr(model_data, 'best_estimator_'):
161162
model_data = model_data.best_estimator_
162163
preds = model_data.predict(fset)
163-
pred_probs = (model_data.predict_proba(fset)
164+
pred_probs = (pd.DataFrame(model_data.predict_proba(fset),
165+
index=fset.index, columns=model_data.classes_)
164166
if hasattr(model_data, 'predict_proba') else [])
165167
all_classes = model_data.classes_ if hasattr(model_data, 'classes_') else []
166168
pred_path = pjoin(cfg['paths']['predictions_folder'],
167169
'{}.npz'.format(str(uuid.uuid4())))
168170
featurize.save_featureset(fset, pred_path, labels=data['labels'],
169-
preds=preds, pred_probs=pred_probs,
170-
all_classes=all_classes)
171+
preds=preds, pred_probs=pred_probs)
171172
f, created = m.File.get_or_create(uri=pred_path)
172173
pred = m.Prediction.create(file=f, dataset=dataset, project=dataset.project,
173174
model=model, finished=datetime.datetime.now())

0 commit comments

Comments
 (0)