Skip to content

Commit 6364b71

Browse files
committed
initial commit
0 parents  commit 6364b71

11 files changed

+450
-0
lines changed

Eval.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from sklearn.metrics import roc_auc_score, roc_curve, average_precision_score, accuracy_score, precision_recall_curve
2+
import numpy as np
3+
4+
class Eval:
    """Score a vector of predictions against binary gold labels.

    Both sequences are stored exactly as passed in and handed to the
    scikit-learn metric functions when ``Metrics`` is called.
    """

    def __init__(self, pred, gold):
        """Store the prediction scores and their aligned gold labels.

        Args:
            pred: sequence of prediction scores (higher means "more positive").
            gold: sequence of binary labels, aligned with ``pred``.
        """
        self.pred = pred
        self.gold = gold

    def Metrics(self, metics):
        """Compute the requested group of metrics.

        Args:
            metics: ``'all'`` or ``'specificity-sensitivity'``.

        Returns:
            For ``'all'``: ``np.array([auc, aupr, precision, recall, accuracy,
            f1])`` where precision/recall/accuracy/f1 are taken at the
            threshold that maximises F1.
            For ``'specificity-sensitivity'``: a tuple ``(auc,
            fixed_sensitivity, fixed_specificity)``; the two arrays hold the
            specificity at sensitivity 0.1..0.9 and the sensitivity at
            specificity 0.1..0.9 respectively.

        Raises:
            ValueError: if ``metics`` is not one of the two supported names.
        """
        if metics == 'all':
            return self._all_metrics()
        if metics == 'specificity-sensitivity':
            return self._specificity_sensitivity()
        raise ValueError('Unknown metrics option: {!r}'.format(metics))

    def _all_metrics(self):
        """Return [auc, aupr, precision, recall, accuracy, f1] at the best-F1 threshold."""
        auc = roc_auc_score(self.gold, self.pred)
        aupr = average_precision_score(self.gold, self.pred)

        # precision_recall_curve returns (precision, recall, thresholds) in
        # that order; unpacking them the other way round swaps the two values.
        precisions, recalls, thresholds_pr = precision_recall_curve(self.gold, self.pred)
        # The curve ends with an extra (precision=1, recall=0) point that has
        # no threshold. Its F1 is 0, so dropping it cannot change the argmax
        # and keeps every index valid for thresholds_pr.
        precisions, recalls = precisions[:-1], recalls[:-1]

        # 0/0 happens when precision and recall are both 0; treat F1 as 0 there.
        with np.errstate(divide='ignore', invalid='ignore'):
            f1s = (2 * np.multiply(precisions, recalls)) / np.add(precisions, recalls)
        f1s = np.nan_to_num(f1s)

        max_idx = int(np.argmax(f1s))
        precision = precisions[max_idx]
        recall = recalls[max_idx]
        f1 = f1s[max_idx]
        threshold = thresholds_pr[max_idx]

        # The curve point at thresholds[i] counts score >= thresholds[i] as
        # positive, so accuracy must use the same (inclusive) rule.
        y_scores_label = (np.asarray(self.pred) >= threshold).astype(int)
        accuracy = accuracy_score(self.gold, y_scores_label)
        return np.array([auc, aupr, precision, recall, accuracy, f1])

    def _specificity_sensitivity(self):
        """Return (auc, specificity at fixed sensitivities, sensitivity at fixed specificities)."""
        # recall, sensitivity: true positive rate tp/(tp+fn)
        # specificity: true negative rate tn/(tn+fp)
        auc = roc_auc_score(self.gold, self.pred)

        fpr, tpr, _ = roc_curve(self.gold, self.pred)
        sensitivity, specificity = tpr, 1 - fpr

        fixed_sensitivity, fixed_specificity = [], []
        for i in range(1, 10):
            value = i * 0.1
            # ROC point whose sensitivity is closest to the target value.
            sensitivity_idx = np.argmin(np.abs(sensitivity - value))
            fixed_sensitivity.append(specificity[sensitivity_idx])

            # ROC point whose specificity is closest to the target value.
            specificity_idx = np.argmin(np.abs(specificity - value))
            fixed_specificity.append(sensitivity[specificity_idx])

        return auc, np.array(fixed_sensitivity), np.array(fixed_specificity)

FAERSdata.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import os
2+
import numpy as np
3+
from tqdm import tqdm
4+
5+
from mapping import sider_eval_pairs, drug2id, adr2id, drug_list, adr_list
6+
class FAERSdata:
    """Signal-score matrices read from ``<directory>/<method>/`` plus SIDER labels.

    Attributes (each a dict keyed by the position of the file in the sorted
    file list):
        X: score matrix of shape ``(len(drug_list), len(adr_list))``.
        Y: 0/1 label matrix of the same shape; 1 for every pair in
           ``sider_eval_pairs``.
        Index: list of the row indices ``0 .. len(drug_list) - 1``.
    """

    def __init__(self, directory, method, year):
        """Read the score files and build the matrices.

        Args:
            directory: root folder that holds one sub-folder per method.
            method: name of the sub-folder to read.
            year: ``'all'`` keeps only the last file of the sorted listing;
                any other value loads every file.
        """
        folder = os.path.join(directory, method)
        # os.listdir returns names in arbitrary order; sort so that "the last
        # file" is well defined (presumably the file names sort
        # chronologically -- verify against the data layout).
        files = sorted(os.listdir(folder))

        if year == 'all':
            files = [files[-1]]

        shape = (len(drug_list), len(adr_list))

        # The gold labels do not depend on the score file: build them once.
        labels = np.zeros(shape=shape)
        for drug, adr in sider_eval_pairs:
            # Direct indexing: a missing key should fail loudly, whereas a
            # None from .get() would be interpreted as np.newaxis.
            labels[drug2id[drug], adr2id[adr]] = 1

        X = {}
        Y = {}
        Index = {}
        for i, name in enumerate(tqdm(files)):
            scores = np.zeros(shape=shape)
            with open(os.path.join(folder, name), 'r') as handle:
                next(handle, None)  # skip the header row (tolerate empty files)
                for line in handle:
                    line = line.strip('\n')
                    if not line:
                        continue  # ignore blank lines
                    fields = line.split(',')
                    drug, adr, score = fields[0], fields[1], round(float(fields[2]), 5)
                    drug_id, adr_id = drug2id.get(drug), adr2id.get(adr)
                    # The id maps are built from drug_list / adr_list, so a
                    # None id means "not an evaluated drug/ADR" (O(1) lookup
                    # and not a linear scan of the lists).
                    if drug_id is None or adr_id is None:
                        continue
                    scores[drug_id, adr_id] = score

            X[i] = scores
            Y[i] = labels.copy()  # one independent array per file, as before
            Index[i] = list(range(scores.shape[0]))

        self.X = X
        self.Y = Y
        self.Index = Index
45+
46+
47+
48+
49+
50+
51+
52+
53+
54+
55+
56+
57+

Model.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import numpy as np
2+
import pickle
3+
from collections import defaultdict
4+
5+
from Eval import Eval
6+
from mapping import drugid2rxnorm, rxnorm2features, id2drug, id2adr
7+
from utils import split_data
8+
from similarity import get_Jaccard_Similarity
9+
10+
11+
class Model:
    """Label propagation of signal scores over a drug-drug similarity matrix.

    Attributes:
        ALPHA: propagation weight used by callers for the final prediction;
            starts at 0.1 and is overwritten by ``validate``.
        metrics: metric-set name forwarded to ``Eval.Metrics``.
    """

    def __init__(self, metrics):
        """Store the metric-set name and the default propagation weight."""
        self.ALPHA = 0.1
        self.metrics = metrics

    def get_similarity_matrix(self, X):
        """Return the Jaccard similarity between the drugs indexing the rows of ``X``.

        Only ``X.shape[0]`` is used: row ``idx`` is mapped to a drug through
        ``id2drug``, then to its RxNorm id and feature vector.
        """
        features_matrix = np.asarray([
            rxnorm2features[drugid2rxnorm[id2drug.get(idx)]]
            for idx in range(X.shape[0])
        ])
        return get_Jaccard_Similarity(features_matrix)

    def label_propogation(self, X, alpha, similarity_matrix=None):
        """Return ``(1 - alpha) * pinv(I - alpha * S) @ X``.

        Args:
            X: score matrix with one row per drug.
            alpha: propagation weight.
            similarity_matrix: optional precomputed ``S``; computed from ``X``
                when omitted. It does not depend on ``alpha``, so callers that
                sweep ``alpha`` can compute it once and reuse it.
        """
        if similarity_matrix is None:
            similarity_matrix = self.get_similarity_matrix(X)
        n_rows = np.shape(X)[0]
        propagator = np.linalg.pinv(np.eye(n_rows) - alpha * similarity_matrix)
        return (1 - alpha) * np.matmul(propagator, X)

    def validate(self, X, Y, idx):
        """Pick the alpha in 0.1..0.9 with the highest AUC and store it in ``ALPHA``.

        Args:
            X: score matrix.
            Y: gold label matrix.
            idx: pair ``(rows, cols)`` of the cells used for evaluation.
        """
        # The similarity matrix is identical for every alpha: build it once
        # and skip the rebuild for each of the nine candidates.
        similarity_matrix = self.get_similarity_matrix(X)
        alphas = [i * 0.1 for i in range(1, 10)]

        AUC = []
        for alpha in alphas:
            Y_pred = self.predict(X, alpha, similarity_matrix)
            metrics = self.eval(Y_pred, Y, idx)
            AUC.append(metrics[0])  # AUC is the first entry of both metric sets
        print(AUC)

        # Ties resolve to the smallest alpha (list.index returns the first hit).
        self.ALPHA = alphas[AUC.index(max(AUC))]

    def predict(self, X, alpha, similarity_matrix=None):
        """Return the propagated score matrix for ``X`` (see ``label_propogation``)."""
        return self.label_propogation(X, alpha, similarity_matrix)

    def eval(self, Y_pred, Y, idx):
        """Evaluate ``Y_pred`` against ``Y`` on the cells listed in ``idx``.

        Args:
            Y_pred: predicted score matrix.
            Y: gold label matrix.
            idx: pair ``(rows, cols)`` of equally long index sequences.

        Returns:
            Whatever ``Eval.Metrics`` returns for ``self.metrics``.
        """
        y_pred, y_gold = [], []
        for r, c in zip(idx[0], idx[1]):
            y_pred.append(Y_pred[r, c])
            y_gold.append(Y[r, c])
        return Eval(y_pred, y_gold).Metrics(self.metrics)

    def eval_DME(self, Y_pred, Y, idx, DME):
        """Evaluate separately for every ADR that is listed in ``DME``.

        Args:
            Y_pred: predicted score matrix.
            Y: gold label matrix.
            idx: pair ``(rows, cols)`` of the cells used for evaluation.
            DME: collection of ADR ids to report on.

        Returns:
            Dict mapping ADR id to its ``Eval.Metrics`` result; ADRs without
            any evaluated cell are absent.
        """
        y_pred, y_gold = defaultdict(list), defaultdict(list)
        for r, c in zip(idx[0], idx[1]):
            adrid = id2adr.get(c)
            if adrid in DME:
                y_pred[adrid].append(Y_pred[r, c])
                y_gold[adrid].append(Y[r, c])

        return {
            adrid: Eval(scores, y_gold[adrid]).Metrics(self.metrics)
            for adrid, scores in y_pred.items()
        }
72+
73+
74+
75+

README.md

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# LP-SDA
2+
3+
## 1. Introduction
4+
This repository contains the source code for the paper ["Towards early detection of adverse drug reactions: combining pre-clinical drug structures and post-market safety reports"]() (accepted by **_BMC Medical Informatics and Decision Making_**).
5+
In this paper, we propose a label propagation framework to enhance drug safety signals by combining pre-clinical drug chemical structures with post-marketing safety reports from [FDA Adverse Event Reporting System (FAERS)](https://open.fda.gov/data/faers/).
6+
7+
We apply the label propagation framework to four popular signal detection algorithms (PRR, ROR, MGPS, BCPNN) and find that our proposed framework generates more accurate drug safety signals than the corresponding baselines.
8+
9+
## 2. Pipeline
10+
![alt text](img/pipeline.jpg "Pipeline")
11+
12+
Fig. 1: The overall framework for label propagation based signal detection algorithms. It consists of three main steps: computing original drug safety signals from FAERS reports, constructing a drug-drug similarity network from pre-clinical drug structures, and generating enhanced drug safety signals through a label propagation process.
13+
14+
## 3. Dataset
15+
Datasets used in the paper:
16+
- [FAERS](https://open.fda.gov/data/faers/): a database that contains information on adverse event and medication error reports submitted to FDA. We use a curated and standardized version of FAERS data from 2004 to 2014 (Banda, Juan M. et al., 2017) [[paper&data]](https://datadryad.org/stash/dataset/doi:10.5061/dryad.8q0s4).
17+
- [PubChem](https://www.ncbi.nlm.nih.gov/pubmed/26400175): a public repository for information on chemical substances and their biological activities. The PubChem Compound database provides unique chemical structure information of drugs.
18+
- [SIDER](http://sideeffects.embl.de/): a database that contains information on marketed medicines and their recorded adverse drug reactions.
19+
20+
## 4. Code
21+
#### Running example
22+
```
23+
python run.py --input SignalScoresSource --method PRR05 --year all --eval_metrics all --split True
24+
```
25+
26+
#### Parameters
27+
- --input, directory containing the original signal score files (one sub-directory per method).
28+
- --method, signal detection algorithm (one of PRR05, ROR05, GPS, BCPNN).
29+
- --year, years of data used for the model: `all` (all data from 2004 to 2014) or `each` (data arranged by ending year).
30+
- --eval_metrics, evaluation metrics: `all` (AUC, AUPR, precision, recall, accuracy, F1) or `specificity-sensitivity`.
31+
- --split, whether to split entire dataset into validation set and testing set.
32+
- --output, output file.
33+
34+
## 5. Citation
35+
Please kindly cite the paper if you use the code, datasets or any results in this repo or in the paper:

img/pipeline.jpg

585 KB
Loading

mapping.py

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import pickle
2+
3+
def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    The file is closed as soon as it has been read.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# (drug, adr) pairs used as gold-standard positives.
sider_eval_pairs = _load_pickle('pickles/sider_eval_pairs_final.pkl')
# drug id -> RxNorm id, and RxNorm id -> feature vector.
drugid2rxnorm = _load_pickle('pickles/drugid2rxnorm_mapping.pkl')
rxnorm2features = _load_pickle('pickles/rxnorm2features_mapping.pkl')

# Unique drugs / ADRs in first-seen order. dict.fromkeys de-duplicates like a
# set but follows the iteration order of sider_eval_pairs instead of hash
# order, so the integer ids are reproducible whenever that order is.
drug_list = list(dict.fromkeys(drug for (drug, adr) in sider_eval_pairs))
adr_list = list(dict.fromkeys(adr for (drug, adr) in sider_eval_pairs))

# Bidirectional maps between a drug / ADR and its matrix row / column index.
id2drug = dict(enumerate(drug_list))
drug2id = {drug: i for i, drug in id2drug.items()}

id2adr = dict(enumerate(adr_list))
adr2id = {adr: i for i, adr in id2adr.items()}

pickles/drugid2rxnorm_mapping.pkl

95 KB
Binary file not shown.

pickles/rxnorm2features_mapping.pkl

4.61 MB
Binary file not shown.

pickles/sider_eval_pairs_final.pkl

960 KB
Binary file not shown.

run.py

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
2+
import numpy as np
3+
4+
from FAERSdata import FAERSdata
5+
from Model import Model
6+
from utils import split_data, sample_zeros
7+
8+
9+
def _str2bool(value):
    """Convert a command-line token such as ``True``, ``false`` or ``0`` to a bool."""
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, matching what bool('') used to give.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    from argparse import ArgumentTypeError
    raise ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def parse_args():
    """Parse the command-line options of the script.

    Returns:
        argparse.Namespace with ``input``, ``method``, ``year``,
        ``eval_metrics``, ``split`` (bool) and ``output``.
    """
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter, conflict_handler='resolve')
    parser.add_argument('--input', required=True, help='Input original signal scores file.')
    parser.add_argument('--method', required=True, choices=['PRR05', 'ROR05', 'GPS', 'BCPNN'],
                        help='Signal detection algorithm')
    parser.add_argument('--year', default='all', choices=['all', 'each'], help='Years of data used for model')
    parser.add_argument('--eval_metrics', required=True, choices=['all', 'specificity-sensitivity'],
                        help='Evaluation metrics')
    # type=bool would turn every non-empty string (including "False") into
    # True, so the text is parsed explicitly.
    parser.add_argument('--split', type=_str2bool, default=False)
    parser.add_argument('--output')

    return parser.parse_args()
21+
22+
23+
def pretty_print_eval(res, metrics):
    """Print an evaluation result rounded to three decimals.

    Args:
        res: for ``metrics == 'all'`` a flat sequence of numbers; otherwise a
            sequence whose items 1 and 2 are the fixed-sensitivity and
            fixed-specificity arrays.
        metrics: name of the metric set that produced ``res``.
    """
    def _joined(values):
        # Comma-separated values, each rounded to 3 decimals.
        return ','.join(np.round(values, 3).astype(str))

    if metrics != 'all':
        print('fixed_sensitivity: ' + _joined(res[1]))
        print('fixed_specificity: ' + _joined(res[2]))
        return

    print('All metrics: ' + _joined(res))
29+
30+
31+
def main(args):
    """Run label propagation and the raw baseline for every loaded data set.

    For each data set the propagation weight is tuned with ``Model.validate``
    and the metrics of the propagated scores and of the original scores are
    printed. With ``args.split`` the tuning uses the validation cells and both
    validation and test results are reported; otherwise all sampled cells are
    used for both tuning and reporting.
    """
    rule = '#' * 50
    print(rule)
    print('Signal Detection Algorithm: {}, Year: {}'.format(args.method, args.year))
    print(rule)

    data = FAERSdata(args.input, args.method, args.year)
    metric_name = args.eval_metrics
    lp_title = 'LP-{}:'.format(args.method)
    baseline_title = 'baseline-{}:'.format(args.method)

    for key in range(len(data.X)):
        X = data.X.get(key)
        Y = data.Y.get(key)
        # Sampled in both modes, before any split, exactly as before.
        all_idx = sample_zeros(Y)

        if not args.split:
            model = Model(metric_name)
            model.validate(X, Y, all_idx)
            Y_pred = model.predict(X, model.ALPHA)
            lp_res = model.eval(Y_pred, Y, all_idx)
            print(lp_title)
            pretty_print_eval(lp_res, metric_name)

            print(baseline_title)
            baseline_res = model.eval(X, Y, all_idx)
            pretty_print_eval(baseline_res, metric_name)
            continue

        valid, test = split_data(Y)
        model = Model(metric_name)
        model.validate(X, Y, valid)
        Y_pred = model.predict(X, model.ALPHA)

        # Propagated scores on the validation and test cells.
        lp_valid = model.eval(Y_pred, Y, valid)
        lp_test = model.eval(Y_pred, Y, test)
        print(lp_title)
        print('alpha: {}'.format(model.ALPHA))
        for label, scores in (('valid:', lp_valid), ('test:', lp_test)):
            print(label)
            pretty_print_eval(scores, metric_name)

        # Original (un-propagated) scores on the same cells.
        base_valid = model.eval(X, Y, valid)
        base_test = model.eval(X, Y, test)
        print(baseline_title)
        for label, scores in (('valid:', base_valid), ('test:', base_test)):
            print(label)
            pretty_print_eval(scores, metric_name)
75+
76+
def main_DME(args):
    """Evaluate label propagation and the baseline per ADR listed in ``DME.txt``.

    ``DME.txt`` is read as comma-separated text whose first two columns are
    the ADR id and the ADR name. For every loaded data set and every listed
    ADR the metrics are printed and one row
    ``adr_id,adr_name,<LP metrics>,<baseline metrics>`` is written to
    ``args.output``. The model uses its default ``ALPHA`` (no validation).
    """
    print('#' * 50)
    print('Signal Detection Algorithm: {}, Year: {}'.format(args.method, args.year))
    print('#' * 50)

    data = FAERSdata(args.input, args.method, args.year)
    # ndmin=2 keeps the column indexing valid when the file has a single row.
    DME = np.loadtxt('DME.txt', dtype=str, delimiter=',', ndmin=2)
    adr_id, adr_name = DME[:, 0], DME[:, 1]

    # The context manager closes the file even if an evaluation step raises.
    with open(args.output, 'w') as out:
        for i in range(len(data.X)):
            X, Y = data.X.get(i), data.Y.get(i)
            eval_idx = sample_zeros(Y)
            model = Model(args.eval_metrics)
            Y_pred = model.predict(X, model.ALPHA)
            LP_res = model.eval_DME(Y_pred, Y, eval_idx, adr_id)
            baseline_res = model.eval_DME(X, Y, eval_idx, adr_id)

            # Separate loop names: the inner loop used to reuse ``i``.
            for adr, name in zip(adr_id, adr_name):
                LP_metric = LP_res.get(adr)
                baseline_metric = baseline_res.get(adr)
                # eval_DME only returns ADRs that had at least one evaluated
                # cell; skip the others so np.round is never handed None.
                if LP_metric is None or baseline_metric is None:
                    continue

                print('LP-{}:'.format(args.method))
                print('ADR:{} '.format(adr))
                pretty_print_eval(LP_metric, args.eval_metrics)

                print('baseline-{}:'.format(args.method))
                pretty_print_eval(baseline_metric, args.eval_metrics)

                # NOTE(review): this row format assumes the flat metric vector
                # returned for eval_metrics == 'all'.
                out.write('{},{},{},{}\n'.format(
                    adr,
                    name,
                    ','.join(np.round(LP_metric, 3).astype(str)),
                    ','.join(np.round(baseline_metric, 3).astype(str)),
                ))
108+
109+
110+
def more_main():
    """Command-line entry point: parse the arguments and hand them to ``main``."""
    main(parse_args())


# Run the pipeline only when the file is executed as a script.
if __name__ == '__main__':
    more_main()

0 commit comments

Comments
 (0)