-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
34 lines (26 loc) · 1.51 KB
/
predict.py
File metadata and controls
34 lines (26 loc) · 1.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from numpy import ndarray
from pandas import DataFrame
from xgboost import XGBClassifier
from .config import COLUMNS, OPERATIONS, WINDOW_SIZES, DATAPOINT_RESOLUTION
from .logger import LOGGER
from .windowing.windowing import window_stay
from .data_model.stay import Stay
def predict(model: XGBClassifier, stay: Stay) -> DataFrame:
    """Predict sepsis probabilities and labels for every windowable datapoint of a stay.

    The stay is first windowed with the module-level config parameters
    (COLUMNS, OPERATIONS, WINDOW_SIZES, DATAPOINT_RESOLUTION); the model is then
    applied to the windowed time series.

    Args:
        model (XGBClassifier): Trained model used for making predictions.
        stay (Stay): Stay to predict.

    Returns:
        DataFrame: DataFrame with a 'prediction' column, where 1 means a positive
        sepsis prediction and 0 a negative one, and a 'probabilities' column with
        the model's probability for the positive class. The index is the
        timestamp of each prediction (the index of the windowed time series).
        If the stay cannot be windowed, an error is logged and an empty
        DataFrame with the same two columns is returned.
    """
    # First window the stay; window_stay signals failure by returning None.
    windowed_stay: Stay | None = window_stay(stay, COLUMNS, OPERATIONS, WINDOW_SIZES, DATAPOINT_RESOLUTION)
    if windowed_stay is None:
        LOGGER.error("Invalid stay for windowing! Is stay long enough?")
        # Best-effort: keep the return type/columns stable instead of crashing.
        return DataFrame({'prediction': [], 'probabilities': []})

    # Convert the windowed features once and reuse them for both model calls.
    features: ndarray = windowed_stay.time_series.to_numpy(dtype='float32')
    # predict_proba returns one column per class; column 1 is the positive class.
    stay_probas: ndarray = model.predict_proba(features)[:, 1]
    # predict returns a 1-D array of class labels, so it must not be column-indexed.
    stay_predictions: ndarray = model.predict(features)

    return DataFrame(
        {'prediction': stay_predictions, 'probabilities': stay_probas},
        index=windowed_stay.time_series.index,
    )