Skip to content

Commit f7ef39c

Browse files
committed
tf.flayers: Added flatten
1 parent 78b81a2 commit f7ef39c

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed

src/TensorFlowNET.Core/APIs/tf.layers.cs

+34
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using System.Collections.Generic;
1718
using Tensorflow.Keras.Layers;
1819
using Tensorflow.Operations.Activation;
1920
using static Tensorflow.Binding;
@@ -163,6 +164,39 @@ public Tensor dense(Tensor inputs,
163164

164165
return layer.apply(inputs);
165166
}
167+
168+
/// <summary>
169+
/// Flattens an input tensor while preserving the batch axis (axis 0).
170+
/// </summary>
171+
/// <param name="inputs">Tensor input.</param>
172+
/// <param name="name">The name of the layer.</param>
173+
/// <param name="data_format">
174+
/// A string, one of `channels_last` (default) or `channels_first`. <br></br>
175+
/// The ordering of the dimensions in the inputs. <br></br>
176+
/// `channels_last` corresponds to inputs with shape <br></br>
177+
/// `(batch, height, width, channels)` while `channels_first` corresponds to <br></br>
178+
/// inputs with shape `(batch, channels, height, width)`.
179+
/// </param>
180+
/// <returns></returns>
181+
public Tensor flatten(Tensor inputs,
182+
string name = null,
183+
string data_format = "channels_last")
184+
{
185+
if (inputs.shape.Length == 0)
186+
throw new ValueError($"Input 0 of layer flatten is incompatible with the layer: : expected min_ndim={1}, found ndim={0}. Full shape received: ()");
187+
188+
var premutation = new List<int>() {0};
189+
if (data_format == "channels_first" && inputs.NDims > 1)
190+
{
191+
premutation.AddRange(Binding.range(2, inputs.NDims));
192+
premutation.Add(1);
193+
inputs = array_ops.transpose(inputs, premutation.ToArray());
194+
}
195+
196+
var ret = array_ops.reshape(inputs, new int[] {inputs.shape[0], -1});
197+
ret.set_shape(new int[] {inputs.shape[0], -1});
198+
return ret;
199+
}
166200
}
167201
}
168202
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using System;
2+
using FluentAssertions;
3+
using Microsoft.VisualStudio.TestTools.UnitTesting;
4+
using NumSharp;
5+
using Tensorflow;
6+
using static Tensorflow.Binding;
7+
8+
namespace TensorFlowNET.UnitTest.layers_test
9+
{
10+
[TestClass]
11+
public class flatten
12+
{
13+
[TestMethod]
14+
public void Case1()
15+
{
16+
var sess = tf.Session().as_default();
17+
18+
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(3, 4, 3, 1, 2));
19+
sess.run(tf.layers.flatten(input), (input, np.arange(3 * 4 * 3 * 1 * 2).reshape(3, 4, 3, 1, 2))).Should().BeShaped(3, 24);
20+
}
21+
22+
[TestMethod]
23+
public void Case2()
24+
{
25+
var sess = tf.Session().as_default();
26+
27+
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6));
28+
sess.run(tf.layers.flatten(input), (input, np.arange(6))).Should().BeShaped(6, 1);
29+
}
30+
31+
[TestMethod]
32+
public void Case3()
33+
{
34+
var sess = tf.Session().as_default();
35+
36+
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape());
37+
new Action(() => sess.run(tf.layers.flatten(input), (input, NDArray.Scalar(6)))).Should().Throw<ValueError>();
38+
}
39+
}
40+
}

0 commit comments

Comments
 (0)