forked from PaddlePaddle/PaddleHub
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
88 lines (73 loc) · 3.2 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#coding:utf-8
import argparse
import os
import ast
import paddle.fluid as fluid
import paddlehub as hub
import numpy as np
# yapf: disable
# Command-line options for the prediction script.
parser = argparse.ArgumentParser(__doc__)
# ast.literal_eval turns the literal text "True"/"False" into a real bool;
# a plain type=bool would treat every non-empty string as True.
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether use GPU for predict.")
parser.add_argument("--checkpoint_dir", type=str, default="paddlehub_finetune_ckpt", help="Path to the finetune checkpoint directory.")
parser.add_argument("--batch_size", type=int, default=16, help="Number of examples per batch for prediction.")
parser.add_argument("--module", type=str, default="resnet50", help="Module used as a feature extractor.")
parser.add_argument("--dataset", type=str, default="flowers", help="Dataset that defines the label set (flowers, dogcat, indoor67, food101, stanforddogs).")
# yapf: enable.
# Short aliases accepted by --module, mapped to the module names that are
# passed to hub.Module(name=...).
module_map = dict(
    resnet50="resnet_v2_50_imagenet",
    resnet101="resnet_v2_101_imagenet",
    resnet152="resnet_v2_152_imagenet",
    mobilenet="mobilenet_v2_imagenet",
    nasnet="nasnet_imagenet",
    pnasnet="pnasnet_imagenet",
)
def predict(args):
    """Build an image-classification task and print its predictions.

    Loads the PaddleHub module named by ``args.module``, instantiates the
    dataset named by ``args.dataset``, wires both into an
    ``ImageClassifierTask`` and prints the result of ``task.predict`` for
    two hard-coded test image paths.

    Args:
        args: Parsed command-line namespace. Reads ``module``, ``dataset``,
            ``use_gpu``, ``batch_size`` and ``checkpoint_dir``.

    Raises:
        ValueError: If ``args.dataset`` (compared case-insensitively) is not
            one of the supported dataset names.
    """
    # Load the PaddleHub pretrained model and its input/output variables.
    # The third value returned by context() (the program) is not used here.
    module = hub.Module(name=args.module)
    inputs, outputs, _ = module.context(trainable=True)

    # Pick the dataset class by name. The class is looked up on hub.dataset
    # only after the name matched, so just the requested one is touched.
    dataset_class_names = {
        "flowers": "Flowers",
        "dogcat": "DogCat",
        "indoor67": "Indoor67",
        "food101": "Food101",
        "stanforddogs": "StanfordDogs",
    }
    class_name = dataset_class_names.get(args.dataset.lower())
    if class_name is None:
        raise ValueError("%s dataset is not defined" % args.dataset)
    dataset = getattr(hub.dataset, class_name)()

    # The reader is configured with the image size and normalisation values
    # reported by the module itself.
    data_reader = hub.reader.ImageClassificationReader(
        image_width=module.get_expected_image_width(),
        image_height=module.get_expected_image_height(),
        images_mean=module.get_pretrained_images_mean(),
        images_std=module.get_pretrained_images_std(),
        dataset=dataset,
    )

    feature_map = outputs["feature_map"]

    # Feed list for the data feeder: only the image input variable.
    feed_list = [inputs["image"].name]

    # Running config for the PaddleHub Finetune API.
    config = hub.RunConfig(
        use_data_parallel=False,
        use_cuda=args.use_gpu,
        batch_size=args.batch_size,
        checkpoint_dir=args.checkpoint_dir,
        strategy=hub.finetune.strategy.DefaultFinetuneStrategy(),
    )

    # Define the image-classification task through PaddleHub's API.
    task = hub.ImageClassifierTask(
        data_reader=data_reader,
        feed_list=feed_list,
        feature=feature_map,
        num_classes=dataset.num_labels,
        config=config,
    )

    image_paths = ["./test/test_img_daisy.jpg", "./test/test_img_roses.jpg"]
    results = task.predict(data=image_paths, return_result=True)
    print(results)
if __name__ == "__main__":
    args = parser.parse_args()
    # --module takes a short alias; reject anything without a known mapping.
    if args.module not in module_map:
        hub.logger.error("module should be one of: %s" % ", ".join(module_map))
        # The `exit` builtin is injected by the `site` module and can be
        # missing (e.g. `python -S`); raising SystemExit always works and
        # yields the same exit status.
        raise SystemExit(1)
    # Translate the alias into the full module name before predicting.
    args.module = module_map[args.module]
    predict(args)