Commit 6e8ff93

Inference file

Python script that uses the weights of trained models to run inference on other datasets.

1 parent b7d086e commit 6e8ff93

File tree

1 file changed: inference.py (+301, −0 lines)
@@ -0,0 +1,301 @@
""" Authored by: Neel Kanwal ([email protected])"""

# This file provides inference code for the baseline models (CNN + FC) and the DKL models mentioned in the paper.
# Update the paths to the model weights before running this script.

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=DeprecationWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

import matplotlib.pyplot as plt
font = {'family': 'serif',
        'weight': 'normal',
        'size': 24}
plt.rc('font', **font)

fig, ax = plt.subplots(figsize=(12, 12))  # plt.subplots returns a (figure, axes) tuple

import gpytorch
from torch.autograd import Variable

import pandas as pd
import numpy as np
from datetime import datetime
import seaborn as sns
import os
import time
import json
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from torchvision import datasets, models
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import MultiStepLR, LinearLR, ReduceLROnPlateau, ExponentialLR
import torch.nn.functional as F

from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, matthews_corrcoef, roc_auc_score

import scipy.stats as stats
import statistics

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from my_functions import get_class_distribution, infer_dkl_v2, infer_cnn_v2, DenseNetFeatureExtractor, DKLModel
from my_functions import extract_features, custom_classifier, count_flops

# from mmcv.cnn.utils import flops_counter
# from fvcore.nn import FlopCountAnalysis
# from ptflops import get_model_complexity_info

torch.cuda.empty_cache()
cuda_device = 1
os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_device)

NUM_WORKER = 16  # Number of simultaneous compute tasks == number of physical cores
BATCH_SIZE = 64
dropout = 0.2
torch.manual_seed(1700)

#### Selection parameters to define the experiment:
#### choose the dataset, artifact, and model.
architecture = "DKL"  # "CNN", "DKL"
dataset = "emc"  # "focuspath", "emc", "tcgafocus", "suh"
artifact = "fold"  # select the blur or fold artifact
val = False  # run on the validation set of the dataset instead of the test set

repetitions_for_p = 5  # repetitions to calculate mean and std across runs
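
# Optional sanity guard: fail early if the selection parameters above contain a typo.
assert architecture in ("CNN", "DKL"), "architecture must be 'CNN' or 'DKL'"
assert dataset in ("focuspath", "emc", "tcgafocus", "suh"), "unknown dataset"
assert artifact in ("blur", "fold"), "artifact must be 'blur' or 'fold'"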

# Location where all experiments and models are present.
model_weights = "path_to/DKLModels/weights/"
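# This directory is expected to contain the checkpoint files loaded below:
# blur_dkl.dat, fold_dkl.dat, blur_cnn.dat, and fold_cnn.dat.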
# blur_cnn_wts = "/DenseNetConfig/04_27_2022 19:47:45"
# blur_dkl_wts = "/04_24_2022 10:26:30"  # @ 128
# # blur_dkl_wts = "/04_20_2022 10:09:11"  # @ 256
# # blur_dkl_wts = "/04_29_2022 09:27:26"  # (6,6,6) @ 384
# fold_cnn_wts = "/DenseNetConfig/04_27_2022 12:05:05"  # 666
# fold_dkl_wts = "/04_21_2022 03:33:31"

## Paths to the datasets
if dataset == "focuspath":
    path_to_dataset = "path_to/FocusPath/"
elif dataset == "tcgafocus":
    path_to_dataset = "path_to/tcgafocus/"
elif dataset == "suh":
    path_to_dataset = "path_to/Processed/"
else:
    if artifact == "blur":
        path_to_dataset = "path_to/artifact_dataset/blur/test/"
        if val:
            print("Validation blur subset from EMC")
            path_to_dataset = "path_to/artifact_dataset/blur/validation/"
    elif artifact == "fold":
        path_to_dataset = "path_to/artifact_dataset/fold_20x/test/"
        if val:
            print("Validation fold subset from EMC")
            path_to_dataset = "path_to/artifact_dataset/fold_20x/validation/"
    else:
        print("Dataset does not exist")
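
# Note: datasets.ImageFolder (used below) expects one subfolder per class under
# path_to_dataset; the folder names become the class labels (e.g. a hypothetical
# layout with "artifact_free" and "fold" subfolders yields two classes).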
# Transform data.
test_compose = transforms.Compose([
    transforms.CenterCrop((224, 224)),
    # transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# Data loaders to load the data.
t = time.time()
print(f"Loading {dataset} - {artifact} Dataset...........")
test_images = datasets.ImageFolder(root=path_to_dataset, transform=test_compose)
total_patches = len(test_images)
idx2class = {v: k for k, v in test_images.class_to_idx.items()}
num_classes = len(test_images.classes)
test_loader = DataLoader(test_images, batch_size=BATCH_SIZE, shuffle=False, sampler=None, num_workers=NUM_WORKER, pin_memory=True)
classes_list = test_loader.dataset.classes
class_distribution = get_class_distribution(test_images)
print("Class distribution in test set: ", class_distribution)
print(f"Length of {artifact} test set: {len(test_images)} with {num_classes} classes")
print(f"Total data loading time in minutes: {(time.time() - t)/60:.3f}")
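
# shuffle=False keeps the loader in dataset order, so the predictions produced
# below stay aligned with the file names read from test_loader.dataset.imgs.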

now = datetime.now()
date_time = now.strftime("%m_%d_%Y %H:%M:%S")
print(f"It's {date_time}.\n")

# Loading models based on the defined experimental setting.
# The best model for blur was DKL with block config (10, 10, 10).
# The best model for fold was DKL with block config (6, 6, 6).

if architecture == "DKL":
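    # DKL couples the DenseNet feature extractor with a Gaussian-process layer
    # (grid_size=128) and a softmax likelihood over its outputs; both parts are
    # restored from a single checkpoint that stores 'model' and 'likelihood'
    # state dicts.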
    if artifact == "blur":
        print(f"\nInitializing DKL for {artifact}...............")
        feature_extractor = DenseNetFeatureExtractor(block_config=(10, 10, 10), num_classes=num_classes)
        num_features = feature_extractor.classifier.in_features
        print("Number of output features for patch is ", num_features)
        model = DKLModel(feature_extractor, num_dim=num_features, grid_size=128)
        likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=model.num_dim, num_classes=num_classes)
        pytorch_total_params = sum(p.numel() for p in model.parameters())
        print("Total model parameters (M): ", pytorch_total_params/1e6)
        best_model_wts = model_weights + "/blur_dkl.dat"
        # print("Loading model weights ")
        model.load_state_dict(torch.load(best_model_wts, map_location=torch.device('cpu'))['model'])
        # print("Loading likelihood weights ")
        likelihood.load_state_dict(torch.load(best_model_wts, map_location=torch.device('cpu'))['likelihood'])

    else:
        print(f"\nInitializing DKL for {artifact}...............")
        feature_extractor = DenseNetFeatureExtractor(block_config=(6, 6, 6), num_classes=num_classes)
        num_features = feature_extractor.classifier.in_features
        print("Number of output features for patch is ", num_features)
        model = DKLModel(feature_extractor, num_dim=num_features, grid_size=128)
        likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=model.num_dim, num_classes=num_classes)
        pytorch_total_params = sum(p.numel() for p in model.parameters())
        print("Total model parameters (M): ", pytorch_total_params/1e6)
        best_model_wts = model_weights + "/fold_dkl.dat"
        # print("Loading model weights ")
        model.load_state_dict(torch.load(best_model_wts, map_location=torch.device('cpu'))['model'])
        # print("Loading likelihood weights ")
        likelihood.load_state_dict(torch.load(best_model_wts, map_location=torch.device('cpu'))['likelihood'])

    input_size = (1, 3, 224, 224)
    flops = count_flops(model, input_size)
    print("GFLOPs:", flops/1e9)

    if torch.cuda.is_available():
        print("Cuda is available")
        model = model.cuda()
        likelihood = likelihood.cuda()

    path = os.path.join('path_to/emc/', f"{dataset}")
    if not os.path.exists(path):
        os.mkdir(path)

print("\nTesting Starts....................")
176+
y_pred, y_true, probs, mean1, epistemic, pred_var, lower_1c, upper_1c, feature, entropy = infer_dkl_v2(test_loader,
177+
model, likelihood,total_patches=total_patches, n_samples=100)
178+
179+
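    # infer_dkl_v2 (from my_functions) draws n_samples=100 samples from the
    # predictive distribution per patch and returns, besides the predictions,
    # per-patch uncertainty summaries (epistemic uncertainty, predictive
    # variance, lower/upper confidence bounds, and entropy) saved below.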

    file_names = [im[0].split("/")[-1] for im in test_loader.dataset.imgs]
    data = {"files": file_names, "ground_truth": y_true, "prediction": y_pred, "probabilities": probs,
            "mean1": mean1, "epistemic": epistemic, "variance": pred_var, "lower_conf": lower_1c, "upper_conf": upper_1c, "entropy": entropy}
    dframe = pd.DataFrame(data)

    with pd.ExcelWriter(f"{path}/dkl_predictions_on_{dataset}_for_{artifact}.xlsx") as wr:
        dframe.to_excel(wr, index=False)

    accuracy = accuracy_score(y_true, y_pred)
    print("Accuracy: ", accuracy)
    f1 = f1_score(y_true, y_pred)
    print("F1 Score: ", f1)
    roc = roc_auc_score(y_true, y_pred)
    print("ROC AUC Score: ", roc)
    mathew_corr = matthews_corrcoef(y_true, y_pred)
    print("Matthews Correlation Coefficient: ", mathew_corr)

    acc_list, f1_list, roc_list, mcc_list, pred_list = [], [], [], [], []

    for i in range(repetitions_for_p):
        y_pred, y_true, _, _, _, _, _, _, _, _ = infer_dkl_v2(
            test_loader, model, likelihood, total_patches=total_patches, n_samples=100)
        accuracy = accuracy_score(y_true, y_pred)
        acc_list.append(accuracy)
        f1 = f1_score(y_true, y_pred)
        f1_list.append(f1)
        roc = roc_auc_score(y_true, y_pred)
        roc_list.append(roc)
        mathew_corr = matthews_corrcoef(y_true, y_pred)
        mcc_list.append(mathew_corr)
        pred_list.append(y_pred)
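
    # Each repetition re-samples the predictive distribution (n_samples=100),
    # so runs differ slightly; the mean/std printed below summarise that
    # run-to-run variability.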
    print(acc_list)
    # p_value = stats.ttest_rel(pred_list[0], pred_list[1]).pvalue
    # print("\nP-value for DKL is", p_value)

    print("\nAccuracy mean: ", statistics.mean(acc_list), " Accuracy std: ", statistics.stdev(acc_list))
    print("\nF1 mean: ", statistics.mean(f1_list), " F1 std: ", statistics.stdev(f1_list))
    print("\nROC mean: ", statistics.mean(roc_list), " ROC std: ", statistics.stdev(roc_list))
    print("\nMCC mean: ", statistics.mean(mcc_list), " MCC std: ", statistics.stdev(mcc_list))

else:
    if artifact == "blur":
        print(f"\nInitializing CNN baseline of DenseNet (10,10,10) for {artifact}...............")
        model = models.DenseNet(block_config=(10, 10, 10), growth_rate=12, num_init_features=24)
        num_features = model.classifier.in_features  # 2208 --> less than 256
        model.classifier = custom_classifier(num_features, num_classes, dropout=dropout)
        print("Number of out features for patch is ", num_features)
        pytorch_total_params = sum(p.numel() for p in model.parameters())
        print("Total model parameters (M): ", pytorch_total_params/1e6)
        # best_model_wts = base_location + blur_cnn_wts
        model.load_state_dict(torch.load(model_weights + "/blur_cnn.dat", map_location=torch.device('cpu'))['model'])
    else:
        print(f"Initializing CNN baseline of DenseNet (6,6,6) for {artifact}...............")
        model = models.DenseNet(block_config=(6, 6, 6), growth_rate=12, num_init_features=24)
        num_features = model.classifier.in_features  # 2208 --> less than 256
        model.classifier = custom_classifier(num_features, num_classes, dropout=dropout)
        print("Number of out features for patch is ", num_features)
        pytorch_total_params = sum(p.numel() for p in model.parameters())
        print("Total model parameters (M): ", pytorch_total_params/1e6)
        # best_model_wts = base_location + fold_cnn_wts
        model.load_state_dict(torch.load(model_weights + "/fold_cnn.dat", map_location=torch.device('cpu'))['model'])

    input_size = (1, 3, 224, 224)
    flops = count_flops(model, input_size)
    print("GFLOPs:", flops/1e9)

    if torch.cuda.is_available():
        print("Cuda is available")  # the model should be on cuda before the optimizer is selected
        model = model.cuda()

    path = os.path.join('path_to/', f"{dataset}")
    if not os.path.exists(path):
        os.mkdir(path)

print("\nTesting Starts....................")
256+
y_pred, y_true, probs, mean1, epistemic, pred_var, lower_1c, upper_1c = infer_cnn_v2(test_loader, model,total_patches=total_patches, n_samples=100)
257+
258+
259+
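    # infer_cnn_v2 also uses n_samples=100 stochastic forward passes per patch
    # (presumably Monte Carlo dropout, given the dropout in the classifier head)
    # to produce uncertainty estimates comparable to the DKL model's.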

    file_names = [im[0].split("/")[-1] for im in test_loader.dataset.imgs]
    data = {"files": file_names, "ground_truth": y_true, "prediction": y_pred, "probabilities": probs,
            "mean1": mean1, "epistemic": epistemic, "variance": pred_var, "lower_conf": lower_1c, "upper_conf": upper_1c}
    dframe = pd.DataFrame(data)

    with pd.ExcelWriter(f"{path}/cnn_predictions_on_{dataset}_for_{artifact}.xlsx") as wr:
        dframe.to_excel(wr, index=False)

    accuracy = accuracy_score(y_true, y_pred)
    print("Accuracy: ", accuracy)
    f1 = f1_score(y_true, y_pred)
    print("F1 Score: ", f1)
    roc = roc_auc_score(y_true, y_pred)
    print("ROC AUC Score: ", roc)
    mathew_corr = matthews_corrcoef(y_true, y_pred)
    print("Matthews Correlation Coefficient: ", mathew_corr)

    acc_list, f1_list, roc_list, mcc_list, pred_list = [], [], [], [], []

    for i in range(repetitions_for_p):
        y_pred, y_true, _, _, _, _, _, _ = infer_cnn_v2(test_loader, model, total_patches=total_patches, n_samples=100)
        accuracy = accuracy_score(y_true, y_pred)
        acc_list.append(accuracy)
        f1 = f1_score(y_true, y_pred)
        f1_list.append(f1)
        roc = roc_auc_score(y_true, y_pred)
        roc_list.append(roc)
        mathew_corr = matthews_corrcoef(y_true, y_pred)
        mcc_list.append(mathew_corr)
        pred_list.append(y_pred)

    print(acc_list)
    # p_value = stats.ttest_rel(pred_list[0], pred_list[1]).pvalue
    # print("\nP-value for CNN is", p_value)

    print("\nAccuracy mean: ", statistics.mean(acc_list), " Accuracy std: ", statistics.stdev(acc_list))
    print("\nF1 mean: ", statistics.mean(f1_list), " F1 std: ", statistics.stdev(f1_list))
    print("\nROC mean: ", statistics.mean(roc_list), " ROC std: ", statistics.stdev(roc_list))
    print("\nMCC mean: ", statistics.mean(mcc_list), " MCC std: ", statistics.stdev(mcc_list))

print("#####################################################")
print(f"Total time in minutes: {(time.time() - t)/60:.3f}")
