Skip to content

Commit c95f7a6

Browse files
Author: Timothy Spann
Merge pull request #4 from simonellistonball/master
Added more output to attributes
2 parents 242ba02 + 33dece3 commit c95f7a6

File tree

9 files changed: 68 additions and 51 deletions

build.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
mvn install -DskipTests
1+
#!/bin/sh
2+
3+
mvn clean package

nifi-tensorflow-nar/.gitignore

Lines changed: 0 additions & 1 deletion
This file was deleted.

nifi-tensorflow-nar/pom.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
<parent>
2020
<groupId>com.dataflowdeveloper</groupId>
2121
<artifactId>tensorflow-processor</artifactId>
22-
<version>2.0</version>
22+
<version>2.1</version>
2323
</parent>
2424

2525
<artifactId>nifi-tensorflow-nar</artifactId>
@@ -33,7 +33,7 @@
3333
<dependency>
3434
<groupId>com.dataflowdeveloper</groupId>
3535
<artifactId>nifi-tensorflow-processors</artifactId>
36-
<version>2.0</version>
36+
<version>2.1</version>
3737
</dependency>
3838
</dependencies>
3939

nifi-tensorflow-processors/.gitignore

Lines changed: 0 additions & 1 deletion
This file was deleted.

nifi-tensorflow-processors/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
<parent>
2020
<groupId>com.dataflowdeveloper</groupId>
2121
<artifactId>tensorflow-processor</artifactId>
22-
<version>2.0</version>
22+
<version>2.1</version>
2323
</parent>
2424

2525
<artifactId>nifi-tensorflow-processors</artifactId>

nifi-tensorflow-processors/src/main/java/com/dataflowdeveloper/processors/process/TensorFlowProcessor.java

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,18 @@
2323

2424
import java.util.ArrayList;
2525
import java.util.Collections;
26+
import java.util.HashMap;
2627
import java.util.HashSet;
2728
import java.util.List;
29+
import java.util.Map;
30+
import java.util.Map.Entry;
2831
import java.util.Set;
32+
import java.util.SortedMap;
33+
import java.util.TreeMap;
34+
import java.util.concurrent.ConcurrentHashMap;
35+
import java.util.concurrent.ConcurrentMap;
36+
import java.util.stream.Collector;
37+
import java.util.stream.Collectors;
2938

3039
import org.apache.commons.io.IOUtils;
3140
import org.apache.nifi.annotation.behavior.EventDriven;
@@ -64,7 +73,7 @@
6473
*/
6574
public class TensorFlowProcessor extends AbstractProcessor {
6675

67-
public static final String ATTRIBUTE_OUTPUT_NAME = "probabilities";
76+
public static final String ATTRIBUTE_OUTPUT_NAME = "tf.probabilities";
6877
public static final String MODEL_DIR_NAME = "modeldir";
6978
public static final String PROPERTY_NAME_EXTRA = "Extra Resources";
7079

@@ -134,21 +143,28 @@ public void onTrigger(final ProcessContext context, final ProcessSession session
134143
service = new TensorFlowService();
135144
// read all bytes of the flowfile (tensor requires whole image)
136145
InputStream is = session.read(flowFile);
137-
String value;
146+
List<Entry<Float, String>> results;
138147
try {
139148
byte[] byteArray = IOUtils.toByteArray(is);
140-
value = service.getInception(byteArray, modelDir);
141-
149+
results = service.getInception(byteArray, modelDir).limit(10).collect(Collectors.toList());
142150
} catch(Exception e) {
143151
throw new ProcessException(e);
144152
} finally {
145153
is.close();
146154
}
147155

148-
if (value == null) {
156+
if (results == null) {
149157
session.transfer(flowFile, REL_UNMATCHED);
150158
} else {
151-
flowFile = session.putAttribute(flowFile, ATTRIBUTE_OUTPUT_NAME, value);
159+
HashMap<String,String> attributes = new HashMap<String,String>(results.size() * 2);
160+
for(int i = 0; i < results.size(); i++) {
161+
Object[] key = new Object[] { ATTRIBUTE_OUTPUT_NAME, i };
162+
163+
Entry<Float, String> entry = results.get(i);
164+
attributes.put(String.format("%s.%d.label", key), entry.getValue());
165+
attributes.put(String.format("%s.%d.probability", key), entry.getKey().toString());
166+
}
167+
flowFile = session.putAllAttributes(flowFile, attributes);
152168
session.transfer(flowFile, REL_SUCCESS);
153169
}
154170
session.commit();

nifi-tensorflow-processors/src/main/java/com/dataflowdeveloper/processors/process/TensorFlowService.java

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,15 @@
88
import java.nio.file.Path;
99
import java.nio.file.Paths;
1010
import java.util.Arrays;
11+
import java.util.Collections;
1112
import java.util.HashMap;
1213
import java.util.List;
1314
import java.util.Map;
15+
import java.util.Map.Entry;
16+
import java.util.stream.Collectors;
1417

18+
import org.slf4j.Logger;
19+
import org.slf4j.LoggerFactory;
1520
import org.tensorflow.DataType;
1621
import org.tensorflow.Graph;
1722
import org.tensorflow.Output;
@@ -26,31 +31,44 @@
2631
*/
2732
public class TensorFlowService {
2833

34+
private final Logger logger = LoggerFactory.getLogger(TensorFlowService.class);
35+
2936
private Map<Path, Graph> modelCache = new HashMap<Path, Graph>();
3037
private Map<Path, List<String>> labelCache = new HashMap<Path, List<String>>();
3138

32-
public String getInception(byte[] imageBytes, String modelDir) {
39+
public List<Entry<Float, String>> getInception(byte[] imageBytes, String modelDir) {
40+
logger.info(String.format("getInception: %d bytes %s", new Object[] { imageBytes.length, Paths.get(modelDir, "graph.pb") }));
3341
Graph g = getOrCreate(Paths.get(modelDir, "graph.pb"));
3442
try (Session s = new Session(g)) {
3543
List<String> labels = getOrCreateLabels(Paths.get(modelDir, "label.txt"));
36-
try {
37-
Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes);
38-
Tensor result = s.runner().feed("input", image).fetch("output").run().get(0);
39-
final long[] rshape = result.shape();
40-
if (result.numDimensions() != 2 || rshape[0] != 1) {
41-
throw new RuntimeException(String.format(
42-
"Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
43-
Arrays.toString(rshape)));
44-
}
45-
int nlabels = (int) rshape[1];
46-
float[] labelProbabilities = result.copyTo(new float[1][nlabels])[0];
47-
int bestLabelIdx = maxIndex(labelProbabilities);
48-
return String.format("BEST MATCH: %s (%.2f%% likely)", labels.get(bestLabelIdx),
49-
labelProbabilities[bestLabelIdx] * 100f);
50-
} catch (Exception x) {
51-
x.printStackTrace();
44+
Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes);
45+
Tensor result = s.runner().feed("input", image).fetch("output").run().get(0);
46+
logger.debug("found results");
47+
48+
final long[] rshape = result.shape();
49+
if (result.numDimensions() != 2 || rshape[0] != 1) {
50+
throw new RuntimeException(String.format(
51+
"Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
52+
Arrays.toString(rshape)));
53+
}
54+
int nlabels = (int) rshape[1];
55+
56+
logger.debug(String.format("number of labels %d, %d", new Object[] { labels.size(), nlabels }));
57+
int mLabeled = Math.min(labels.size(), nlabels);
58+
59+
float[] labelProbabilities = result.copyTo(new float[1][nlabels])[0];
60+
61+
HashMap<Float, String> results = new HashMap<Float, String>();
62+
for (int i = 0; i < mLabeled; i++) {
63+
results.put(labelProbabilities[i], labels.get(i));
5264
}
53-
return "Unknown";
65+
66+
return Collections.synchronizedList(
67+
results.entrySet().stream().sorted(Collections.reverseOrder(Map.Entry.comparingByKey())).limit(10)
68+
.collect(Collectors.toList()));
69+
} catch (Exception e) {
70+
logger.error("Failed in tensorflow", e);
71+
throw(e);
5472
}
5573
}
5674

@@ -104,16 +122,6 @@ private static Tensor constructAndExecuteGraphToNormalizeImage(byte[] imageBytes
104122
}
105123
}
106124

107-
private static int maxIndex(float[] probabilities) {
108-
int best = 0;
109-
for (int i = 1; i < probabilities.length; ++i) {
110-
if (probabilities[i] > probabilities[best]) {
111-
best = i;
112-
}
113-
}
114-
return best;
115-
}
116-
117125
private static byte[] readAllBytesOrExit(Path path) {
118126
try {
119127
return Files.readAllBytes(path);

nifi-tensorflow-processors/src/test/java/com/dataflowdeveloper/processors/process/TensorFlowProcessorTest.java

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,14 @@
1616
*/
1717
package com.dataflowdeveloper.processors.process;
1818

19-
import static org.junit.Assert.assertNotNull;
19+
import static org.junit.Assert.assertEquals;
2020

21-
import java.io.File;
22-
import java.io.FileInputStream;
23-
import java.io.FileNotFoundException;
24-
import java.io.UnsupportedEncodingException;
2521
import java.net.URI;
2622
import java.net.URISyntaxException;
2723
import java.net.URL;
2824
import java.nio.file.Paths;
2925
import java.util.List;
3026

31-
import org.apache.nifi.components.PropertyDescriptor;
32-
import org.apache.nifi.processor.util.StandardValidators;
3327
import org.apache.nifi.util.MockFlowFile;
3428
import org.apache.nifi.util.TestRunner;
3529
import org.apache.nifi.util.TestRunners;
@@ -76,8 +70,7 @@ private void runAndAssertHappy() {
7670
List<MockFlowFile> successFiles = testRunner.getFlowFilesForRelationship(TensorFlowProcessor.REL_SUCCESS);
7771

7872
for (MockFlowFile mockFile : successFiles) {
79-
System.out.println("Attribute: " + mockFile.getAttribute(TensorFlowProcessor.ATTRIBUTE_OUTPUT_NAME));
80-
assertNotNull(mockFile.getAttribute(TensorFlowProcessor.ATTRIBUTE_OUTPUT_NAME));
73+
assertEquals("giant panda", mockFile.getAttribute("tf.probabilities.0.label"));
8174
}
8275
}
8376
}

pom.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919
<parent>
2020
<groupId>org.apache.nifi</groupId>
2121
<artifactId>nifi-nar-bundles</artifactId>
22-
<version>1.0.0</version>
22+
<version>1.3.0</version>
2323
</parent>
2424

2525
<groupId>com.dataflowdeveloper</groupId>
2626
<artifactId>tensorflow-processor</artifactId>
27-
<version>2.0</version>
27+
<version>2.1</version>
2828
<packaging>pom</packaging>
2929

3030
<modules>

0 commit comments

Comments
 (0)