-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathdemo_model.py
More file actions
73 lines (56 loc) · 2.05 KB
/
demo_model.py
File metadata and controls
73 lines (56 loc) · 2.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
from models import PromptMR, count_parameters
def random_sampling_mask(height, width, keep_prob=0.5):
    """Return a random boolean mask of shape ``(1, 1, height, width, 1)``.

    Each entry is independently ``True`` with probability ``keep_prob``.

    Uniform samples are thresholded instead of casting Gaussian samples with
    ``.bool()``: a Gaussian sample is essentially never exactly zero, so that
    cast yields a mask that is ``True`` everywhere.

    Args:
        height: Size of the third mask dimension.
        width: Size of the fourth mask dimension.
        keep_prob: Probability that an entry is ``True``.

    Returns:
        A ``torch.bool`` tensor of shape ``(1, 1, height, width, 1)``.
    """
    return torch.rand(1, 1, height, width, 1) < keep_prob


def main():
    """Build PromptMR and PromptMR+ and run one forward pass on random data.

    Prints the parameter count of both models and the shape of the
    ``img_pred`` entry of each model's output.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    n_adj_slc = 5
    n_coil = 10
    height, width = 218, 170

    # Settings shared by both model variants.
    base_config = dict(
        num_cascades=12,
        num_adj_slices=n_adj_slc,
        n_feat0=48,
        feature_dim=[72, 96, 120],
        prompt_dim=[24, 48, 72],
        sens_n_feat0=24,
        sens_feature_dim=[36, 48, 60],
        sens_prompt_dim=[12, 24, 36],
        len_prompt=[5, 5, 5],
        prompt_size=[64, 32, 16],
        n_enc_cab=[2, 3, 3],
        n_dec_cab=[2, 2, 3],
        n_skip_cab=[1, 1, 1],
        n_bottleneck_cab=3,
        no_use_ca=False,
    )

    model_promptmr = PromptMR(
        **base_config,
        adaptive_input=False,
        n_buffer=0,
        n_history=0,
        use_sens_adj=False,
    )

    model_promptmr_plus = PromptMR(
        **base_config,
        adaptive_input=True,  # adaptive input
        n_buffer=4,  # buffer size in adaptive input, fixed to 4
        n_history=11,  # historical feature aggregation
        use_sens_adj=True,  # adjacent sensitivity map estimation
    )

    # Inference only: switch off any training-only layer behaviour.
    model_promptmr.to(device).eval()
    model_promptmr_plus.to(device).eval()

    # Use random input for demo
    # NOTE(review): the trailing dimension of size 2 presumably holds the
    # real/imaginary parts -- confirm against the model's forward().
    rand_input = torch.randn(1, n_adj_slc * n_coil, height, width, 2).to(device)
    rand_mask = random_sampling_mask(height, width).to(device)
    num_low_frequencies = torch.tensor([18]).to(device)

    with torch.no_grad():
        output = model_promptmr(rand_input, rand_mask, num_low_frequencies)
        output_plus = model_promptmr_plus(rand_input, rand_mask, num_low_frequencies)

    print('model_promptmr param: ', count_parameters(model_promptmr))
    print('model_promptmr+ param: ', count_parameters(model_promptmr_plus))
    print('output_promptmr shape: ', output['img_pred'].shape)
    print('output_promptmr+ shape: ', output_plus['img_pred'].shape)


if __name__ == "__main__":
    main()