forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathInceptionArchGoogLeNet.cs
133 lines (108 loc) · 4.41 KB
/
InceptionArchGoogLeNet.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
using NumSharp;
using System;
using System.IO;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;
namespace TensorFlowNET.Examples
{
    /// <summary>
    /// Inception Architecture for Computer Vision
    /// Port from tensorflow\examples\label_image\label_image.py
    /// </summary>
    /// <remarks>
    /// Downloads the frozen Inception v3 graph and a sample picture into
    /// <c>label_image_data</c>, feeds the preprocessed picture through the graph
    /// and prints the <see cref="TopK"/> highest-scoring labels.
    /// </remarks>
    public class InceptionArchGoogLeNet : IExample
    {
        public bool Enabled { get; set; } = false;
        public string Name => "Inception Arch GoogLeNet";
        public bool IsImportingGraph { get; set; } = false;

        // Number of highest-scoring predictions printed by Run().
        private const int TopK = 5;

        // Working directory and file names. PrepareData() downloads the model tarball and
        // the picture; the label file is presumably shipped inside the tarball (it is not
        // downloaded separately) -- confirm if the model URL changes.
        private readonly string dir = "label_image_data";
        private readonly string pbFile = "inception_v3_2016_08_28_frozen.pb";
        private readonly string labelFile = "imagenet_slim_labels.txt";
        private readonly string picFile = "grace_hopper.jpg";

        // Preprocessing parameters: images are resized to input_height x input_width and
        // normalized as (pixel - input_mean) / input_std.
        private readonly int input_height = 299;
        private readonly int input_width = 299;
        private readonly int input_mean = 0;
        private readonly int input_std = 255;

        // Node names inside the imported graph (carry the "import/" prefix).
        private readonly string input_name = "import/input";
        private readonly string output_name = "import/InceptionV3/Predictions/Reshape_1";

        /// <summary>
        /// Downloads the model and sample picture, classifies the picture and prints the
        /// <see cref="TopK"/> best-scoring labels, highest score first.
        /// </summary>
        /// <returns>Always <c>true</c> once the predictions have been printed.</returns>
        public bool Run()
        {
            PrepareData();

            var labels = File.ReadAllLines(Path.Join(dir, labelFile));

            var nd = ReadTensorFromImageFile(Path.Join(dir, picFile),
                input_height: input_height,
                input_width: input_width,
                input_mean: input_mean,
                input_std: input_std);

            var graph = new Graph();
            graph.Import(Path.Join(dir, pbFile));

            var input_operation = graph.get_operation_by_name(input_name);
            var output_operation = graph.get_operation_by_name(output_name);

            NDArray results;
            using (var sess = tf.Session(graph))
            {
                results = sess.run(output_operation.outputs[0],
                    new FeedItem(input_operation.outputs[0], nd));
            }

            // Remove size-1 dimensions so results can be indexed by class id
            // (assumes the output is [1, num_classes] -- confirm against the model).
            results = np.squeeze(results);

            // The last TopK entries of the argsort are taken and reversed, i.e. argsort is
            // treated as ascending. Math.Max makes the "fewer than TopK scores" case explicit.
            var argsort = results.argsort<float>();
            var top_k = argsort.Data<float>()
                .Skip(Math.Max(0, results.size - TopK))
                .Reverse()
                .ToArray();

            foreach (float idx in top_k)
            {
                Console.WriteLine($"{picFile}: {idx} {labels[(int)idx]}, {results[(int)idx]}");
            }

            return true;
        }

        /// <summary>
        /// Decodes a JPEG file into 3 channels, adds a leading batch dimension, resizes it
        /// bilinearly to <paramref name="input_height"/> x <paramref name="input_width"/> and
        /// normalizes it as (pixel - <paramref name="input_mean"/>) / <paramref name="input_std"/>.
        /// </summary>
        /// <param name="file_name">Path of the JPEG file to read.</param>
        /// <param name="input_height">Target height after resizing.</param>
        /// <param name="input_width">Target width after resizing.</param>
        /// <param name="input_mean">Value subtracted from every pixel.</param>
        /// <param name="input_std">Divisor applied after the mean subtraction.</param>
        /// <returns>The normalized float32 image with a leading batch dimension of 1.</returns>
        private NDArray ReadTensorFromImageFile(string file_name,
            int input_height = 299,
            int input_width = 299,
            int input_mean = 0,
            int input_std = 255)
        {
            // The preprocessing ops are built in a fresh graph, separate from the model graph.
            // NOTE(review): as_default() is never exited here, so this graph presumably remains
            // the default graph after the method returns -- confirm this is harmless.
            var graph = tf.Graph().as_default();

            var file_reader = tf.read_file(file_name, "file_reader");
            var image_reader = tf.image.decode_jpeg(file_reader, channels: 3, name: "jpeg_reader");
            var caster = tf.cast(image_reader, tf.float32);
            var dims_expander = tf.expand_dims(caster, 0);
            var resize = tf.constant(new int[] { input_height, input_width });
            var bilinear = tf.image.resize_bilinear(dims_expander, resize);
            var sub = tf.subtract(bilinear, new float[] { input_mean });
            var normalized = tf.divide(sub, new float[] { input_std });

            using (var sess = tf.Session(graph))
            {
                return sess.run(normalized);
            }
        }

        /// <summary>
        /// Creates <c>dir</c>, downloads and extracts the frozen Inception v3 model tarball,
        /// and downloads the sample picture named by <c>picFile</c>.
        /// </summary>
        public void PrepareData()
        {
            Directory.CreateDirectory(dir);

            // get model file
            var tarFile = $"{pbFile}.tar.gz";
            string url = $"https://storage.googleapis.com/download.tensorflow.org/models/{tarFile}";
            Utility.Web.Download(url, dir, tarFile);
            Utility.Compress.ExtractTGZ(Path.Join(dir, tarFile), dir);

            // download sample picture -- use the picFile field so Run() reads exactly the
            // file that was downloaded (previously a duplicated literal).
            url = $"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/{picFile}";
            Utility.Web.Download(url, dir, picFile);
        }

        /// <summary>Not implemented for this example; always throws <see cref="NotImplementedException"/>.</summary>
        public Graph ImportGraph()
        {
            throw new NotImplementedException();
        }

        /// <summary>Not implemented for this example; always throws <see cref="NotImplementedException"/>.</summary>
        public Graph BuildGraph()
        {
            throw new NotImplementedException();
        }

        /// <summary>Not implemented for this example; always throws <see cref="NotImplementedException"/>.</summary>
        public void Train(Session sess)
        {
            throw new NotImplementedException();
        }

        /// <summary>Not implemented for this example; always throws <see cref="NotImplementedException"/>.</summary>
        public void Predict(Session sess)
        {
            throw new NotImplementedException();
        }

        /// <summary>Not implemented for this example; always throws <see cref="NotImplementedException"/>.</summary>
        public void Test(Session sess)
        {
            throw new NotImplementedException();
        }
    }
}