Skip to content

Commit d7ae36e

Browse files
committed
Adapt to changes in KNIME Deep Learning
1 parent fdad72a commit d7ae36e

File tree

1 file changed

+28
-20
lines changed

1 file changed

+28
-20
lines changed

org.knime.knip.dl/src/org/knime/ip/dl/DLKnipUtil.java

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package org.knime.ip.dl;
22

3-
import org.knime.dl.core.DLDefaultDimensionOrder;
43
import org.knime.dl.core.DLDimension;
54
import org.knime.dl.core.DLDimensionOrder;
65
import org.knime.dl.core.DLTensorSpec;
@@ -17,71 +16,80 @@ final class DLKnipUtil {
1716
private DLKnipUtil() {
1817
// utility class
1918
}
20-
21-
public static <T extends RealType<T>> RandomAccessibleInterval<T> mapImgToDL(ImgPlusValue<T> img, DLTensorSpec tensorSpec) {
19+
20+
public static <T extends RealType<T>> RandomAccessibleInterval<T> mapImgToDL(ImgPlusValue<T> img,
21+
DLTensorSpec tensorSpec) {
2222
int[] mapping = calculateMapping(img, tensorSpec);
2323
return DimSwapper.swap(img.getImgPlus(), mapping);
2424
}
25-
25+
2626
private static DLDimensionOrder extractDimensionOrder(DLTensorSpec tensorSpec) {
27-
if (tensorSpec.getDimensionOrder() == DLDefaultDimensionOrder.Unknown) {
27+
if (tensorSpec.getDimensionOrder() == DLDimensionOrder.Unknown) {
2828
throw new IllegalArgumentException(
2929
"Can't infer shape from image if the dimension order of the input tensor is unknown");
3030
}
3131
return tensorSpec.getDimensionOrder();
3232
}
33-
34-
public static <T extends RealType<T>> long[] getShapeFromImg(final ImgPlusValue<T> img, final DLTensorSpec tensorSpec) {
33+
34+
public static <T extends RealType<T>> long[] getShapeFromImg(final ImgPlusValue<T> img,
35+
final DLTensorSpec tensorSpec) {
3536
int[] mapping = calculateMapping(img, tensorSpec);
3637
return mapShape(img.getDimensions(), mapping);
3738
}
38-
39+
3940
private static <T extends RealType<T>> int[] calculateMapping(final ImgPlusValue<T> img, DLTensorSpec tensorSpec) {
4041
DLDimensionOrder tensorDimensionOrder = extractDimensionOrder(tensorSpec);
4142
DLDimension[] imgDimensionOrder = getDimensionOrder(getAxes(img.getMetadata()));
4243
return tensorDimensionOrder.inferMappingFor(imgDimensionOrder);
4344
}
44-
45+
4546
private static long[] mapShape(final long[] imgShape, final int[] mapping) {
4647
assert imgShape.length == mapping.length;
4748
long[] mappedShape = new long[imgShape.length];
4849
for (int i = 0; i < mappedShape.length; i++) {
4950
// in KNIP the last dimension changes the slowest (e.g. C in XYC) while
50-
// in deep learning (especially TensorFlow) the last dimension changes the fastest.
51+
// in deep learning (especially TensorFlow) the last dimension changes the
52+
// fastest.
5153
mappedShape[i] = imgShape[imgShape.length - mapping[i] - 1];
5254
}
5355
return mappedShape;
5456
}
55-
57+
5658
private static DLDimension[] getDimensionOrder(CalibratedAxis[] axes) {
5759
DLDimension[] dimOrder = new DLDimension[axes.length];
5860
for (int i = 0; i < axes.length; i++) {
5961
// in KNIP the last dimension changes the slowest (e.g. C in XYC) while
60-
// in deep learning (especially TensorFlow) the last dimension changes the fastest.
62+
// in deep learning (especially TensorFlow) the last dimension changes the
63+
// fastest.
6164
dimOrder[i] = axisToDimension(axes[axes.length - i - 1]);
6265
}
6366
return dimOrder;
6467
}
65-
68+
6669
private static CalibratedAxis[] getAxes(final ImgPlusMetadata metaData) {
6770
CalibratedAxis[] axes = new CalibratedAxis[metaData.numDimensions()];
6871
metaData.axes(axes);
6972
return axes;
7073
}
71-
74+
7275
private static DLDimension axisToDimension(CalibratedAxis axis) {
7376
switch (axis.type().getLabel()) {
74-
case "X": return DLDimension.Width;
75-
case "Y": return DLDimension.Height;
76-
case "Z": return DLDimension.Depth;
77-
case "Channel": return DLDimension.Channel;
78-
case "Time": return DLDimension.Time;
77+
case "X":
78+
return DLDimension.Width;
79+
case "Y":
80+
return DLDimension.Height;
81+
case "Z":
82+
return DLDimension.Depth;
83+
case "Channel":
84+
return DLDimension.Channel;
85+
case "Time":
86+
return DLDimension.Time;
7987

8088
default:
8189
throw new IllegalArgumentException("Unknown axis '" + axis.type().getLabel() + "' encountered.");
8290
}
8391
}
84-
92+
8593
static long[] reverseShape(long[] shape) {
8694
long[] reversedShape = new long[shape.length];
8795
for (int i = 0; i < shape.length; i++) {

0 commit comments

Comments
 (0)