1
+ """ Authored by: Neel Kanwal ([email protected] )"""
2
+
3
+ # This file provides inference code for baseline models (CNN+FC) and DKL models, mentioned in the paper.
4
+ # Update paths to model weights for running this script
5
+
6
+ import warnings
7
# Silence the warning categories below so the inference log stays readable.
# The order matches the original one-call-per-category sequence.
for _category in (FutureWarning, DeprecationWarning, RuntimeWarning, UserWarning):
    warnings.simplefilter(action='ignore', category=_category)
del _category
11
+
12
+ import matplotlib .pyplot as plt
13
# Global matplotlib text style applied through the 'font' rc group.
font = dict(family='serif', weight='normal', size=24)
plt.rc('font', **font)

# NOTE(review): plt.subplots returns a (figure, axes) tuple, so `fig` holds
# both objects. Nothing in the visible part of this file reads `fig` again;
# confirm that no imported helper relies on this figure being the current
# one before removing the call.
fig = plt.subplots(figsize=(12, 12))
19
+
20
+ import gpytorch
21
+ from torch .autograd import Variable
22
+
23
+ import pandas as pd
24
+ import numpy as np
25
+ from datetime import datetime
26
+ import seaborn as sns
27
+ import os
28
+ import time
29
+ import json
30
+ import torch
31
+ from torch .utils .data import DataLoader , WeightedRandomSampler , random_split
32
+ from torchvision import datasets , models
33
+ import torchvision .transforms as transforms
34
+ from torch .optim .lr_scheduler import MultiStepLR , LinearLR , ReduceLROnPlateau , ExponentialLR
35
+ import torch .nn .functional as F
36
+
37
+ from sklearn .metrics import confusion_matrix , accuracy_score , f1_score , matthews_corrcoef , roc_auc_score
38
+
39
+ import scipy .stats as stats
40
+ import statistics
41
+
42
+ from PIL import ImageFile
43
+ ImageFile .LOAD_TRUNCATED_IMAGES = True
44
+ from my_functions import get_class_distribution , infer_dkl_v2 , infer_cnn_v2 , DenseNetFeatureExtractor , DKLModel
45
+ from my_functions import extract_features , custom_classifier , count_flops
46
+
47
+ # from mmcv.cnn.utils import flops_counter
48
+ # from fvcore.nn import FlopCountAnalysis
49
+ # from ptflops import get_model_complexity_info
50
+
51
# ---- Runtime configuration ------------------------------------------------
torch.cuda.empty_cache()

# Expose a single GPU (by index) to this process through the environment.
cuda_device = 1
os.environ["CUDA_VISIBLE_DEVICES"] = f"{cuda_device}"

# Fixed seed for torch's random number generator.
torch.manual_seed(1700)

NUM_WORKER = 16  # passed as num_workers to the DataLoader
BATCH_SIZE = 64  # passed as batch_size to the DataLoader
dropout = 0.2    # passed to custom_classifier for the CNN baseline head
59
+
60
#### Selection parameters that define the experiment:
#### choose the model architecture, the dataset and the artifact.
architecture = "DKL"  # "CNN", "DKL"
dataset = "emc"  # "focuspath", "emc", "tcgafocus", "suh"
artifact = "fold"  # "blur" or "fold"
val = False  # run on the EMC validation split instead of the test split

repitions_for_p = 5  # repeated inference runs used for the mean / std of each metric

# Location where all experiments and models are present.
model_weights = "path_to/DKLModels/weights/"

# Commented-out weight locations kept from the original file:
# blur_cnn_wts = "/DenseNetConfig/04_27_2022 19:47:45"
# blur_dkl_wts = "/04_24_2022 10:26:30" #@ 128
# # blur_dkl_wts = "/04_20_2022 10:09:11" # @256
# # blur_dkl_wts = "/04_29_2022 09:27:26" #(6,6,6) @ 384
# fold_cnn_wts = "/DenseNetConfig/04_27_2022 12:05:05" # 666
# fold_dkl_wts = "/04_21_2022 03:33:31"


def resolve_dataset_path(dataset, artifact, val=False):
    """Return the image-folder root for the requested experiment.

    Args:
        dataset: "focuspath", "tcgafocus" or "suh" select a fixed folder.
            Any other value is treated as the EMC artifact dataset, as in
            the original if/elif chain.
        artifact: "blur" or "fold"; only consulted for the EMC dataset.
        val: when true, use the EMC validation split instead of the test
            split. Ignored for the other datasets.

    Returns:
        The dataset root as a string.

    Raises:
        ValueError: if the EMC dataset is requested with an artifact other
            than "blur" or "fold". Previously this case only printed a
            message and left the path undefined, which surfaced later as a
            NameError.
    """
    fixed_roots = {
        "focuspath": "path_to/FocusPath/",
        "tcgafocus": "path_to/tcgafocus/",
        "suh": "path_to/Processed/",
    }
    if dataset in fixed_roots:
        return fixed_roots[dataset]

    emc_folders = {"blur": "blur", "fold": "fold_20x"}
    if artifact not in emc_folders:
        raise ValueError(
            f"Dataset does not exist for artifact {artifact!r}; "
            "expected 'blur' or 'fold'"
        )

    if val:
        print(f"Validation {artifact} subset from EMC")
        split = "validation"
    else:
        split = "test"
    return f"path_to/artifact_dataset/{emc_folders[artifact]}/{split}/"


## Path to the dataset
path_to_dataset = resolve_dataset_path(dataset, artifact, val)
99
+
100
# Transform data: central 224x224 crop, tensor conversion and channel-wise
# normalisation with the mean/std values listed below.
test_compose = transforms.Compose([
    transforms.CenterCrop((224, 224)),
    # transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Data loader for the evaluation images. `t` is read again at the very end
# of the script to report the total run time.
t = time.time()
print(f"Loading {dataset} - {artifact} Dataset...........")
test_images = datasets.ImageFolder(root=path_to_dataset, transform=test_compose)
total_patches = len(test_images)
idx2class = {index: name for name, index in test_images.class_to_idx.items()}
num_classes = len(test_images.classes)
# shuffle=False: the prediction export later pairs results with
# test_loader.dataset.imgs by position.
test_loader = DataLoader(
    test_images,
    batch_size=BATCH_SIZE,
    shuffle=False,
    sampler=None,
    num_workers=NUM_WORKER,
    pin_memory=True,
)
classes_list = test_loader.dataset.classes
class_distribution = get_class_distribution(test_images)
# The images loaded here are the test/validation split, not training data,
# so the label says so (the original message read "in training").
print("Class distribution in evaluation set: ", class_distribution)
print(f"Length of {artifact} testset {len(test_images)} with {num_classes} classes")
print(f"Total data loading time in minutes: {(time.time() - t) / 60:.3f} ")


now = datetime.now()
date_time = now.strftime("%m_%d_%Y %H:%M:%S")
print(f"Its {date_time} .\n ")
125
+
126
# Loading models based on the defined experimental setting.
# Best model for blur was DKL with (10,10,10)
# Best model for fold was DKL with (6,6,6)


def load_checkpoint(weights_file):
    """Read ``weights_file`` from the ``model_weights`` folder onto the CPU.

    Returns whatever ``torch.load`` yields for that file; the builders below
    index it with the 'model' (and, for DKL, 'likelihood') keys. Loading it
    once here avoids reading the same file twice.
    """
    return torch.load(os.path.join(model_weights, weights_file),
                      map_location=torch.device('cpu'))


def build_dkl(block_config, weights_file):
    """Create the DKL model and its softmax likelihood and restore both.

    Args:
        block_config: DenseNet block configuration handed to
            DenseNetFeatureExtractor.
        weights_file: checkpoint file name inside ``model_weights``.

    Returns:
        (model, likelihood), both with their state dicts loaded.
    """
    feature_extractor = DenseNetFeatureExtractor(block_config=block_config, num_classes=num_classes)
    num_features = feature_extractor.classifier.in_features
    print("Number of output features for patch is ", num_features)
    dkl_model = DKLModel(feature_extractor, num_dim=num_features, grid_size=128)
    dkl_likelihood = gpytorch.likelihoods.SoftmaxLikelihood(
        num_features=dkl_model.num_dim, num_classes=num_classes)
    total_params = sum(p.numel() for p in dkl_model.parameters())
    print("Total model parameters (M): ", total_params / 1e6)
    checkpoint = load_checkpoint(weights_file)
    dkl_model.load_state_dict(checkpoint['model'])
    dkl_likelihood.load_state_dict(checkpoint['likelihood'])
    return dkl_model, dkl_likelihood


def build_cnn(block_config, weights_file):
    """Create the DenseNet CNN baseline with the custom head and restore it.

    Args:
        block_config: DenseNet block configuration.
        weights_file: checkpoint file name inside ``model_weights``.

    Returns:
        The model with its state dict loaded.
    """
    cnn = models.DenseNet(block_config=block_config, growth_rate=12, num_init_features=24)
    num_features = cnn.classifier.in_features  # original note: 2208 --> less than 256
    cnn.classifier = custom_classifier(num_features, num_classes, dropout=dropout)
    print("Number of out features for patch is ", num_features)
    total_params = sum(p.numel() for p in cnn.parameters())
    print("Total model parameters (M): ", total_params / 1e6)
    cnn.load_state_dict(load_checkpoint(weights_file)['model'])
    return cnn


def compute_metrics(y_true, y_pred):
    """Return (accuracy, F1, ROC AUC, MCC) for one set of predictions.

    NOTE(review): roc_auc_score is fed the predicted labels, exactly as in
    the original script; scikit-learn expects scores or probabilities there.
    Switching to `probs` needs its layout confirmed first.
    NOTE(review): f1_score is called with its default averaging, which
    presumably assumes a two-class problem; confirm for other datasets.
    """
    return (accuracy_score(y_true, y_pred),
            f1_score(y_true, y_pred),
            roc_auc_score(y_true, y_pred),
            matthews_corrcoef(y_true, y_pred))


def print_run_summary(label, values):
    """Print mean and sample standard deviation of ``values`` for ``label``.

    statistics.mean needs at least one value and statistics.stdev at least
    two; with fewer, 'nan' is reported instead of raising.
    """
    mean = statistics.mean(values) if values else float("nan")
    spread = statistics.stdev(values) if len(values) > 1 else float("nan")
    print(f"\n {label} mean: ", mean, f" {label} std: ", spread)


use_dkl = architecture == "DKL"
is_blur = artifact == "blur"

if use_dkl:
    print(f"\n Initializing DKL for {artifact} ...............")
    if is_blur:
        model, likelihood = build_dkl((10, 10, 10), "blur_dkl.dat")
    else:
        model, likelihood = build_dkl((6, 6, 6), "fold_dkl.dat")
    output_root, tag = 'path_to/emc/', "dkl"
else:
    likelihood = None  # the CNN baseline has no GP likelihood
    if is_blur:
        print(f"\n Initializing CNN baseline of DenseNet (10,10,10) for {artifact} ...............")
        model = build_cnn((10, 10, 10), "blur_cnn.dat")
    else:
        print(f"Initializing CNN DenseNet baseline of (6,6,6) Model for {artifact} ...............")
        model = build_cnn((6, 6, 6), "fold_cnn.dat")
    output_root, tag = 'path_to/', "cnn"

# FLOPs are counted before the model is moved to the GPU, as in the original.
input_size = (1, 3, 224, 224)
flops = count_flops(model, input_size)
print("GFLOPs:", flops / 1e9)

if torch.cuda.is_available():
    print("Cuda is available")
    model = model.cuda()
    if likelihood is not None:
        likelihood = likelihood.cuda()

# Output folder for the prediction sheet. makedirs also creates missing
# parents and does not fail when the folder already exists.
path = os.path.join(output_root, dataset)
os.makedirs(path, exist_ok=True)


def run_inference():
    """Run one inference pass with the selected architecture.

    Returns:
        (core, extra_columns): ``core`` is the list of the eight values both
        inference helpers return first (y_pred, y_true, probs, mean1,
        epistemic, pred_var, lower_1c, upper_1c); ``extra_columns`` holds
        additional per-patch columns for the export (entropy for DKL).
    """
    if use_dkl:
        *core, feature, entropy = infer_dkl_v2(
            test_loader, model, likelihood, total_patches=total_patches, n_samples=100)
        # `feature` is returned by the helper but was never exported.
        return core, {"entropy": entropy}
    return list(infer_cnn_v2(test_loader, model, total_patches=total_patches, n_samples=100)), {}


print("\n Testing Starts....................")
(y_pred, y_true, probs, mean1, epistemic, pred_var, lower_1c, upper_1c), extra_columns = run_inference()

# One row per image, in dataset order (the loader is not shuffled).
file_names = [os.path.basename(im[0]) for im in test_loader.dataset.imgs]
data = {"files": file_names, "ground_truth": y_true, "prediction": y_pred, "probabilities": probs,
        "mean1": mean1, "epistemic": epistemic, "variance": pred_var,
        "lower_conf": lower_1c, "upper_conf": upper_1c}
data.update(extra_columns)
dframe = pd.DataFrame(data)

with pd.ExcelWriter(os.path.join(path, f"{tag}_predictions_on_{dataset}_for_{artifact}.xlsx")) as wr:
    dframe.to_excel(wr, index=False)

accuracy, f1, roc, mathew_corr = compute_metrics(y_true, y_pred)
print("Accuracy: ", accuracy)
print("F1 Score: ", f1)
print("ROC AUC Score: ", roc)
print("Mathew Correlation Coefficient: ", mathew_corr)

# Repeat inference to report the run-to-run mean and standard deviation.
acc_list, f1_list, roc_list, mcc_list, pred_list = [], [], [], [], []

for _ in range(repitions_for_p):
    run_outputs, _unused_columns = run_inference()
    y_pred, y_true = run_outputs[0], run_outputs[1]
    accuracy, f1, roc, mathew_corr = compute_metrics(y_true, y_pred)
    acc_list.append(accuracy)
    f1_list.append(f1)
    roc_list.append(roc)
    mcc_list.append(mathew_corr)
    pred_list.append(y_pred)

print(acc_list)
# p_value = stats.ttest_rel(pred_list[0], pred_list[1]).pvalue
# print("\nP-value is", p_value)

print_run_summary("Accuracy", acc_list)
print_run_summary("F1", f1_list)
print_run_summary("ROC", roc_list)
print_run_summary("MCC", mcc_list)

print("#####################################################")
print(f"Total time in minutes: {(time.time() - t) / 60:.3f} ")
0 commit comments