Commit ffcf498: "Add files via upload" (1 parent: 6fe19a0)

4 files changed (+674, -0 lines)

README.md

Lines changed: 74 additions & 0 deletions
# DSW

# Full list of covariates
We show the full list of static demographics and time-varying covariates of sepsis patients obtained from [MIMIC-III](https://mimic.physionet.org/).

| Category | Items | Type |
|--------------|----------------------------------------------------------|--------|
| Demographics | age | Cont. |
| | gender | Binary |
| | race (white, black, hispanic, other) | Binary |
| | metastatic cancer | Binary |
| | diabetes | Binary |
| | height | Cont. |
| | weight | Cont. |
| | BMI | Cont. |
| Vital signs | heart rate; systolic, mean, and diastolic blood pressure | Cont. |
| | respiratory rate, SpO2 | Cont. |
| | temperature | Cont. |
| Lab tests | sodium, chloride, magnesium | Cont. |
| | glucose, BUN, creatinine, urine output, GCS | Cont. |
| | white blood cell count, bands, C-reactive protein | Cont. |
| | hemoglobin, hematocrit, anion gap | Cont. |
| | platelet count, PTT, PT, INR | Cont. |
| | bicarbonate, lactate | Cont. |
# Introduction
This repository contains the source code for the paper ["Estimating Individual Treatment Effects with Time-Varying Confounders"]().

In this paper, we study the problem of estimating individual treatment effects (ITE) with time-varying confounders, as illustrated by the causal graph in the figure below.

<img src="src/Fig1.png" width=40%>

We propose Deep Sequential Weighting (DSW) for estimating ITE with time-varying confounders. DSW consists of three main components: a representation learning module, a balancing module, and a prediction module.

<img src="src/model4.png" width=80%>

To demonstrate the effectiveness of our framework, we conduct comprehensive experiments on synthetic, semi-synthetic, and real-world EMR datasets ([MIMIC-III](https://mimic.physionet.org/)). DSW outperforms state-of-the-art baselines in terms of PEHE and ATE.
# Requirements
Ubuntu 16.04, Python 3.6

Install [PyTorch 1.4](https://pytorch.org/).
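For example, with pip (an illustrative command; choose the wheel matching your CUDA version from the PyTorch site):
```
pip install torch==1.4.0
```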
# Data preprocessing
### Synthetic dataset
Simulate all covariates, treatments, and outcomes (the expected output file layout is sketched after this block):
```
cd simulation
python synthetic.py
```
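The loader `data_loader_syn.py` in this commit reads one set of `.npy` files per sample ID; a minimal sketch of that layout, with illustrative shapes (the real generator is `simulation/synthetic.py`; T and the outcome shape here are assumptions):
```
import os
import numpy as np

# Shapes follow the metadata in data_loader_syn.py: 100 time-varying
# features and 5 static features. The ID, horizon T, and outcome
# shape are illustrative, not taken from the repo.
out_dir = 'data_mymodel_new2_0.1/'
os.makedirs(out_dir, exist_ok=True)
ID, T = 0, 30
np.save(out_dir + '{}.x.npy'.format(ID), np.random.randn(T, 100))       # time-varying covariates
np.save(out_dir + '{}.static.npy'.format(ID), np.random.randn(5))       # static covariates
np.save(out_dir + '{}.a.npy'.format(ID), np.random.binomial(1, 0.5, T).astype(np.float64))  # treatments
np.save(out_dir + '{}.y.npy'.format(ID), np.random.randn(1))            # outcome
```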
### Semi-synthetic dataset
With a similar simulation process, we construct a semi-synthetic dataset based on a real-world dataset: [MIMIC-III](https://mimic.physionet.org/).
```
cd simulation
python synthetic_mimic.py
```
### MIMIC-III dataset
Obtain the patient data for two treatment-outcome pairs: (1) vasopressor-Meanbp; (2) ventilator-SpO2.
```
cd simulation
python pre_mimic.py
```
# DSW
#### Running example
```
python train_synthetic.py --observation_window 30 --epochs 64 --batch-size 128 --lr 1e-3
```

#### Outputs
- ITE estimation metrics: PEHE, ATE (see the sketch below)
- Factual prediction metric: RMSE
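For reference, a minimal sketch of how PEHE and the ATE error are conventionally computed from true and predicted individual treatment effects (the function names and the rooted form of PEHE are our assumptions; the repository's own metric code is not part of this commit):
```
import numpy as np

def pehe(ite_true, ite_pred):
    # Precision in Estimation of Heterogeneous Effect (rooted form).
    return np.sqrt(np.mean((ite_true - ite_pred) ** 2))

def ate_error(ite_true, ite_pred):
    # Absolute error on the Average Treatment Effect.
    return np.abs(np.mean(ite_true) - np.mean(ite_pred))
```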

data_loader_syn.py

Lines changed: 53 additions & 0 deletions
```
import numpy as np
import torch
from torch.utils import data

gamma = 0.1

data_dir = 'simulation/data_mymodel_new2_{}/'.format(gamma)


# dataset metadata
n_X_features = 100
n_X_static_features = 5
n_X_t_types = 1
n_classes = 1


def get_dim():
    return n_X_features, n_X_static_features, n_X_t_types, n_classes


class SyntheticDataset(data.Dataset):
    def __init__(self, list_IDs, obs_w, treatment):
        '''Initialization'''
        self.list_IDs = list_IDs
        self.obs_w = obs_w
        self.treatment = treatment

    def __len__(self):
        '''Denotes the total number of samples'''
        return len(self.list_IDs)

    def __getitem__(self, index):
        '''Generates one sample of data'''
        # Select sample
        ID = self.list_IDs[index]

        # Load labels
        label = np.load(data_dir + '{}.y.npy'.format(ID))

        # Load data
        X_demographic = np.load(data_dir + '{}.static.npy'.format(ID))
        X_all = np.load(data_dir + '{}.x.npy'.format(ID))
        X_treatment_res = np.load(data_dir + '{}.a.npy'.format(ID))

        X = torch.from_numpy(X_all.astype(np.float32))
        X_demo = torch.from_numpy(X_demographic.astype(np.float32))
        X_treatment = torch.from_numpy(X_treatment_res.astype(np.float32))
        y = torch.from_numpy(label.astype(np.float32))

        return X, X_demo, X_treatment, y
```
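A minimal usage sketch, assuming the simulated files exist under `data_dir` (the ID list and hyperparameters here are illustrative, not from the repo):
```
from torch.utils.data import DataLoader

train_ids = list(range(1000))  # hypothetical ID list
dataset = SyntheticDataset(train_ids, obs_w=30, treatment='vasopressor')
loader = DataLoader(dataset, batch_size=128, shuffle=True)

for X, X_demo, X_treatment, y in loader:
    # X: (batch, T, 100), X_demo: (batch, 5), X_treatment: (batch, T)
    break
```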

model_synthetic.py

Lines changed: 211 additions & 0 deletions
```
import torch.nn as nn
import torch
import torch.nn.functional as F


class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat', 'concat2']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = nn.Parameter(torch.FloatTensor(hidden_size))

        elif self.method == 'concat2':
            self.attn = nn.Linear(self.hidden_size * 3, hidden_size)
            self.v = nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def concat_score2(self, hidden, encoder_output):
        h = torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)
        h = torch.cat((h, hidden * encoder_output), 2)
        energy = self.attn(h).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)
        elif self.method == 'concat2':
            attn_energies = self.concat_score2(hidden, encoder_outputs)

        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t()

        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)
```
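A hypothetical shape check for `Attn` (values illustrative): the stacked past states are laid out as `(max_len, batch, hidden)`, and the returned weights have shape `(batch, 1, max_len)`, ready for `bmm` against the batch-first encoder outputs:
```
attn = Attn('dot', hidden_size=64)
hidden = torch.randn(32, 64)               # current hidden state: (batch, hidden)
encoder_outputs = torch.randn(10, 32, 64)  # past states: (max_len, batch, hidden)
weights = attn(hidden, encoder_outputs)    # (batch, 1, max_len)
context = weights.bmm(encoder_outputs.transpose(0, 1))  # (batch, 1, hidden)
```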
```
class LSTMModel(nn.Module):
    def __init__(self, n_X_features, n_X_static_features, n_X_fr_types, n_Z_confounders,
                 attn_model, n_classes, obs_w,
                 batch_size, hidden_size,
                 num_layers=2, bidirectional=True, dropout=0.2):
        super().__init__()

        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.n_X_features = n_X_features
        self.n_X_static_features = n_X_static_features
        self.n_classes = n_classes
        self.obs_w = obs_w
        self.num_layers = num_layers
        self.x_emb_size = 32
        self.x_static_emb_size = 16
        self.z_dim = n_Z_confounders

        if bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1

        self.n_t_classes = 1

        # Factual and counterfactual GRU cells. Each step consumes the covariate
        # embedding, the previous hidden representation (so n_Z_confounders must
        # equal hidden_size), and a scalar treatment indicator.
        self.rnn_f = nn.GRUCell(input_size=self.x_emb_size + 1 + n_Z_confounders, hidden_size=hidden_size)
        self.rnn_cf = nn.GRUCell(input_size=self.x_emb_size + 1 + n_Z_confounders, hidden_size=hidden_size)

        self.attn_f = Attn(attn_model, hidden_size)
        self.concat_f = nn.Linear(hidden_size * 2, hidden_size)

        self.attn_cf = Attn(attn_model, hidden_size)
        self.concat_cf = nn.Linear(hidden_size * 2, hidden_size)

        self.x2emb = nn.Linear(n_X_features, self.x_emb_size)
        self.x_static2emb = nn.Linear(n_X_static_features, self.x_static_emb_size)

        # IPW (treatment propensity) head
        self.hidden2hidden_ipw = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.x_emb_size + hidden_size + self.x_static_emb_size, hidden_size),
            nn.Dropout(0.3),
            nn.ReLU(),
        )
        self.hidden2out_ipw = nn.Linear(hidden_size, self.n_t_classes, bias=False)

        # Outcome heads (factual and counterfactual)
        self.hidden2hidden_outcome_f = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.x_emb_size + hidden_size + self.x_static_emb_size + 1, hidden_size),
            nn.Dropout(0.3),
            nn.ReLU(),
        )
        self.hidden2out_outcome_f = nn.Linear(hidden_size, self.n_classes, bias=False)

        self.hidden2hidden_outcome_cf = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.x_emb_size + hidden_size + self.x_static_emb_size + 1, hidden_size),
            nn.Dropout(0.3),
            nn.ReLU(),
        )
        self.hidden2out_outcome_cf = nn.Linear(hidden_size, self.n_classes, bias=False)

    def feature_encode(self, x, x_fr):
        # Random initial hidden states for the factual and counterfactual RNNs
        f_hx = torch.randn(x.size(0), self.hidden_size)
        cf_hx = torch.randn(x.size(0), self.hidden_size)
        f_old = f_hx
        cf_old = cf_hx
        f_outputs = []
        f_zxs = []
        cf_outputs = []
        cf_zxs = []
        for i in range(x.size(1)):
            x_emb = self.x2emb(x[:, i, :])
            f_zx = torch.cat((x_emb, f_old), -1)
            f_zxs.append(f_zx)

            cf_zx = torch.cat((x_emb, cf_old), -1)
            cf_zxs.append(cf_zx)

            f_inputs = torch.cat((f_zx, x_fr[:, i].unsqueeze(1)), -1)

            # Counterfactual treatment: flip the sample's treated/control status
            cf_treatment = torch.where(x_fr.sum(1) == 0, torch.Tensor([1]), torch.Tensor([0])).unsqueeze(1)
            cf_inputs = torch.cat((cf_zx, cf_treatment), -1)

            f_hx = self.rnn_f(f_inputs, f_hx)
            cf_hx = self.rnn_cf(cf_inputs, cf_hx)

            if i == 0:
                # No history yet: attend to the current state itself
                f_concat_input = torch.cat((f_hx, f_hx), 1)
                cf_concat_input = torch.cat((cf_hx, cf_hx), 1)
            else:
                # Attend over all previous hidden states to build a context vector
                f_attn_weights = self.attn_f(f_hx, torch.stack(f_outputs))
                f_context = f_attn_weights.bmm(torch.stack(f_outputs).transpose(0, 1))
                f_context = f_context.squeeze(1)
                f_concat_input = torch.cat((f_hx, f_context), 1)

                cf_attn_weights = self.attn_cf(cf_hx, torch.stack(cf_outputs))
                cf_context = cf_attn_weights.bmm(torch.stack(cf_outputs).transpose(0, 1))
                cf_context = cf_context.squeeze(1)
                cf_concat_input = torch.cat((cf_hx, cf_context), 1)

            f_concat_output = torch.tanh(self.concat_f(f_concat_input))
            f_old = f_concat_output

            cf_concat_output = torch.tanh(self.concat_cf(cf_concat_input))
            cf_old = cf_concat_output

            f_outputs.append(f_hx)
            cf_outputs.append(cf_hx)

        return f_zxs, cf_zxs

    def forward(self, x, x_demo, x_fr):
        f_zxs, cf_zxs = self.feature_encode(x, x_fr)

        # IPW: one treatment logit per time step
        ipw_outputs = []
        x_demo_emd = self.x_static2emb(x_demo)
        for i in range(len(f_zxs)):
            h = torch.cat((f_zxs[i], x_demo_emd), -1)
            h = self.hidden2hidden_ipw(h)
            ipw_out = self.hidden2out_ipw(h)
            ipw_outputs.append(ipw_out)

        # Outcome: observed treatment indicator and its flipped counterpart
        f_treatment = torch.where(x_fr.sum(1) > 0, torch.Tensor([1]), torch.Tensor([0])).unsqueeze(1)
        cf_treatment = torch.where(x_fr.sum(1) > 0, torch.Tensor([0]), torch.Tensor([1])).unsqueeze(1)

        # Factual prediction (max-pool the per-step representations over time)
        f_zx_maxpool = torch.max(torch.stack(f_zxs), 0)

        f_hidden = torch.cat((f_zx_maxpool[0], x_demo_emd, f_treatment), -1)
        f_h = self.hidden2hidden_outcome_f(f_hidden)

        f_outcome_out = self.hidden2out_outcome_f(f_h)

        # Counterfactual prediction
        cf_zx_maxpool = torch.max(torch.stack(cf_zxs), 0)

        cf_hidden = torch.cat((cf_zx_maxpool[0], x_demo_emd, cf_treatment), -1)
        cf_h = self.hidden2hidden_outcome_cf(cf_hidden)

        cf_outcome_out = self.hidden2out_outcome_cf(cf_h)

        return ipw_outputs, f_outcome_out, cf_outcome_out, f_h
```
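A hypothetical smoke test, assuming it runs alongside `model_synthetic.py` and `data_loader_syn.py` (hyperparameters are illustrative, not from the repo; note that `hidden_size` must equal `n_Z_confounders` for the GRUCell input sizes to line up):
```
import torch
from data_loader_syn import get_dim

n_X, n_static, n_t, n_cls = get_dim()
model = LSTMModel(n_X, n_static, n_t, n_Z_confounders=32,
                  attn_model='dot', n_classes=n_cls, obs_w=30,
                  batch_size=8, hidden_size=32)

x = torch.randn(8, 30, n_X)               # (batch, time, features)
x_demo = torch.randn(8, n_static)         # static covariates
x_fr = (torch.rand(8, 30) > 0.5).float()  # binary treatment sequence

ipw_outputs, f_out, cf_out, f_h = model(x, x_demo, x_fr)
print(len(ipw_outputs), f_out.shape, cf_out.shape)  # 30, (8, 1), (8, 1)
```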
