diff --git a/src/TensorFlowNET.Core/Assembly/Properties.cs b/src/TensorFlowNET.Core/Assembly/Properties.cs new file mode 100644 index 000000000..28aee65e2 --- /dev/null +++ b/src/TensorFlowNET.Core/Assembly/Properties.cs @@ -0,0 +1,4 @@ +using System.Runtime.CompilerServices; +#if DEBUG +[assembly: InternalsVisibleTo("TensorFlowNET.UnitTest, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")] +#endif diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index bfbfa4eca..d723283f7 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -178,13 +178,18 @@ public static float time() public static IEnumerable<(TKey, TValue)> enumerate(KeyValuePair[] values) { - foreach (var item in values) + var len = values.Length; + for (var i = 0; i < len; i++) + { + var item = values[i]; yield return (item.Key, item.Value); + } } public static IEnumerable<(int, T)> enumerate(IList values) { - for (int i = 0; i < values.Count; i++) + var len = values.Count; + for (int i = 0; i < len; i++) yield return (i, values[i]); } diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs index 396fb3111..c08d31753 100644 --- a/src/TensorFlowNET.Core/Buffers/Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs @@ -15,58 +15,116 @@ limitations under the License. ******************************************************************************/ using System; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using NumSharp.Backends.Unmanaged; +using static Tensorflow.c_api; namespace Tensorflow { + /// + /// Represents a TF_Buffer that can be passed to Tensorflow. + /// public class Buffer : DisposableObject { - private TF_Buffer buffer => Marshal.PtrToStructure(_handle); + private unsafe TF_Buffer buffer + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => *bufferptr; + } + + private unsafe TF_Buffer* bufferptr + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => (TF_Buffer*) _handle; + } - public byte[] Data + /// + /// The memory block representing this buffer. + /// + /// The deallocator is set to null. + public UnmanagedMemoryBlock MemoryBlock { - get + get { - var data = new byte[buffer.length]; - if (data.Length > 0) - Marshal.Copy(buffer.data, data, 0, data.Length); - return data; + unsafe + { + EnsureNotDisposed(); + var buff = (TF_Buffer*) _handle; + return new UnmanagedMemoryBlock((byte*) buff->data.ToPointer(), (long) buff->length); + } } } - public int Length => (int)buffer.length; - - public Buffer() + /// + /// The bytes length of this buffer. 
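Note (illustrative, not part of the diff): a minimal usage sketch of the `enumerate` overloads above, assuming the usual `using static Tensorflow.Binding;` import; the `names` list is made up. The switch to an indexed loop keeps the lazy `(index, value)` tuples while avoiding an enumerator allocation for array and `IList` inputs.

```csharp
using System;
using System.Collections.Generic;
using static Tensorflow.Binding;

var names = new List<string> { "conv1", "pool1", "dense" };
foreach (var (i, name) in enumerate(names))   // yields (0, "conv1"), (1, "pool1"), (2, "dense")
    Console.WriteLine($"{i}: {name}");
```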
+ /// + public ulong Length { - _handle = c_api.TF_NewBuffer(); + get + { + EnsureNotDisposed(); + return buffer.length; + } } - public Buffer(IntPtr handle) + public Buffer() => _handle = TF_NewBuffer(); + + internal Buffer(IntPtr handle) { + if (handle == IntPtr.Zero) + throw new ArgumentException("Handle (IntPtr) can't be zero.", nameof(handle)); + _handle = handle; } - public Buffer(byte[] data) - { - var dst = Marshal.AllocHGlobal(data.Length); - Marshal.Copy(data, 0, dst, data.Length); + public Buffer(byte[] data) : this(_toBuffer(data)) + { } - _handle = c_api.TF_NewBufferFromString(dst, (ulong)data.Length); + private static IntPtr _toBuffer(byte[] data) + { + if (data == null) + throw new ArgumentNullException(nameof(data)); - Marshal.FreeHGlobal(dst); + unsafe + { + fixed (byte* src = data) + return TF_NewBufferFromString(new IntPtr(src), (ulong) data.LongLength); + } } public static implicit operator IntPtr(Buffer buffer) { + buffer.EnsureNotDisposed(); return buffer._handle; } - public static implicit operator byte[](Buffer buffer) + public static explicit operator byte[](Buffer buffer) => buffer.ToArray(); //has to be explicit, developer will assume it doesn't cost. + + /// + /// Copies this buffer's contents onto a array. + /// + public byte[] ToArray() { - return buffer.Data; + EnsureNotDisposed(); + + unsafe + { + var len = buffer.length; + if (len == 0) + return Array.Empty(); + + byte[] data = new byte[len]; + fixed (byte* dst = data) + System.Buffer.MemoryCopy((void*) bufferptr->data, dst, len, len); + + return data; + } } protected override void DisposeUnmanagedResources(IntPtr handle) - => c_api.TF_DeleteBuffer(handle); + { + TF_DeleteBuffer(handle); + } } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/DisposableObject.cs b/src/TensorFlowNET.Core/DisposableObject.cs index 688ac92c8..53a15abc2 100644 --- a/src/TensorFlowNET.Core/DisposableObject.cs +++ b/src/TensorFlowNET.Core/DisposableObject.cs @@ -16,6 +16,8 @@ limitations under the License. using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Text; namespace Tensorflow @@ -26,27 +28,33 @@ namespace Tensorflow public abstract class DisposableObject : IDisposable { protected IntPtr _handle; + protected bool _disposed; - protected DisposableObject() { } + [SuppressMessage("ReSharper", "UnusedMember.Global")] + protected DisposableObject() + { } - protected DisposableObject(IntPtr handle) + protected DisposableObject(IntPtr handle) => _handle = handle; + [SuppressMessage("ReSharper", "InvertIf")] private void internal_dispose(bool disposing) { - if (disposing) - { - // free unmanaged resources (unmanaged objects) and override a finalizer below. - if (_handle != IntPtr.Zero) - { - // dispose managed state (managed objects). - DisposeManagedResources(); + if (_disposed) + return; + + _disposed = true; - // set large fields to null. - DisposeUnmanagedResources(_handle); + //first handle managed, they might use the unmanaged resources. + if (disposing) + // dispose managed state (managed objects). + DisposeManagedResources(); - _handle = IntPtr.Zero; - } + //free unmanaged memory + if (_handle != IntPtr.Zero) + { + DisposeUnmanagedResources(_handle); + _handle = IntPtr.Zero; } } @@ -55,28 +63,33 @@ private void internal_dispose(bool disposing) /// /// Equivalent to what you would perform inside protected virtual void DisposeManagedResources() - { - } + { } /// /// Dispose any unmanaged resources related to given . 
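Note (illustrative, not part of the diff): a minimal sketch of how the reworked `Buffer` is consumed; the payload bytes are made up. `Length` reads the native `TF_Buffer` without copying, while the now-explicit `byte[]` cast routes through `ToArray()` and copies the unmanaged memory once.

```csharp
using System.Text;
using Tensorflow;

using (var buffer = new Buffer(Encoding.UTF8.GetBytes("example payload")))
{
    ulong length = buffer.Length;    // native length, no copy
    byte[] copy = (byte[]) buffer;   // explicit cast, calls ToArray() and copies the data
}
```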
/// protected abstract void DisposeUnmanagedResources(IntPtr handle); - // override a finalizer only if Dispose(bool disposing) above has code to free unmanaged resources. ~DisposableObject() { - // Do not change this code. Put cleanup code in Dispose(bool disposing) above. internal_dispose(false); } - // This code added to correctly implement the disposable pattern. public void Dispose() { - // Do not change this code. Put cleanup code in Dispose(bool disposing) above. internal_dispose(true); - // uncomment the following line if the finalizer is overridden above. GC.SuppressFinalize(this); } + + /// + /// If is then throws + /// + /// When is + [MethodImpl(MethodImplOptions.AggressiveInlining)] + protected void EnsureNotDisposed() + { + if (_disposed) + throw new ObjectDisposedException($"Unable to access disposed object, Type: {GetType().Name}"); + } } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Eager/Context.cs b/src/TensorFlowNET.Core/Eager/Context.cs index 4ee43d35f..700e12362 100644 --- a/src/TensorFlowNET.Core/Eager/Context.cs +++ b/src/TensorFlowNET.Core/Eager/Context.cs @@ -2,12 +2,10 @@ namespace Tensorflow.Eager { - public class Context : IDisposable + public class Context : DisposableObject { - private IntPtr _handle; - - public static int GRAPH_MODE = 0; - public static int EAGER_MODE = 1; + public const int GRAPH_MODE = 0; + public const int EAGER_MODE = 1; public int default_execution_mode; @@ -17,19 +15,16 @@ public Context(ContextOptions opts, Status status) status.Check(true); } - public void Dispose() - { - c_api.TFE_DeleteContext(_handle); - } + /// + /// Dispose any unmanaged resources related to given . + /// + protected sealed override void DisposeUnmanagedResources(IntPtr handle) + => c_api.TFE_DeleteContext(_handle); - public bool executing_eagerly() - { - return false; - } - public static implicit operator IntPtr(Context ctx) - { - return ctx._handle; - } + public bool executing_eagerly() => false; + + public static implicit operator IntPtr(Context ctx) + => ctx._handle; } } diff --git a/src/TensorFlowNET.Core/Eager/ContextOptions.cs b/src/TensorFlowNET.Core/Eager/ContextOptions.cs index 4bdf04b35..12c4cdfc4 100644 --- a/src/TensorFlowNET.Core/Eager/ContextOptions.cs +++ b/src/TensorFlowNET.Core/Eager/ContextOptions.cs @@ -3,23 +3,20 @@ namespace Tensorflow.Eager { - public class ContextOptions : IDisposable //TODO! Eli: Shouldn't this inherieting DisposableObject? + public class ContextOptions : DisposableObject { - private IntPtr _handle; + public ContextOptions() : base(c_api.TFE_NewContextOptions()) + { } - public ContextOptions() - { - _handle = c_api.TFE_NewContextOptions(); - } + /// + /// Dispose any unmanaged resources related to given . 
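Note (illustrative, not part of the diff): the wrapper shape this change standardizes on, sketched with the real `TF_NewBuffer`/`TF_DeleteBuffer` pair; `NativeBufferHandle` is a made-up name. `DisposableObject` now runs `DisposeUnmanagedResources` exactly once (from `Dispose()` or the finalizer), and `EnsureNotDisposed()` guards members that dereference `_handle`.

```csharp
using System;
using Tensorflow;

public sealed class NativeBufferHandle : DisposableObject
{
    public NativeBufferHandle() : base(c_api.TF_NewBuffer()) { }

    // Called once with the non-zero handle; the base class then zeroes _handle.
    protected override void DisposeUnmanagedResources(IntPtr handle)
        => c_api.TF_DeleteBuffer(handle);

    public IntPtr Handle
    {
        get
        {
            EnsureNotDisposed();   // throws ObjectDisposedException after Dispose()
            return _handle;
        }
    }
}
```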
+ /// + protected sealed override void DisposeUnmanagedResources(IntPtr handle) + => c_api.TFE_DeleteContextOptions(_handle); - public void Dispose() - { - c_api.TFE_DeleteContextOptions(_handle); - } - public static implicit operator IntPtr(ContextOptions opts) - { - return opts._handle; - } + public static implicit operator IntPtr(ContextOptions opts) + => opts._handle; } + } diff --git a/src/TensorFlowNET.Core/Exceptions/KeyError.cs b/src/TensorFlowNET.Core/Exceptions/KeyError.cs index 8cecae76d..949fd3094 100644 --- a/src/TensorFlowNET.Core/Exceptions/KeyError.cs +++ b/src/TensorFlowNET.Core/Exceptions/KeyError.cs @@ -2,7 +2,7 @@ namespace Tensorflow { - public class KeyError : Exception + public class KeyError : TensorflowException { public KeyError() : base() { diff --git a/src/TensorFlowNET.Core/Exceptions/RuntimeError.cs b/src/TensorFlowNET.Core/Exceptions/RuntimeError.cs index 09a02a4a2..6f7e4f485 100644 --- a/src/TensorFlowNET.Core/Exceptions/RuntimeError.cs +++ b/src/TensorFlowNET.Core/Exceptions/RuntimeError.cs @@ -2,7 +2,7 @@ namespace Tensorflow { - public class RuntimeError : Exception + public class RuntimeError : TensorflowException { public RuntimeError() : base() { diff --git a/src/TensorFlowNET.Core/Exceptions/TensorflowException.cs b/src/TensorFlowNET.Core/Exceptions/TensorflowException.cs new file mode 100644 index 000000000..ee9eca696 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/TensorflowException.cs @@ -0,0 +1,36 @@ +using System; +using System.Runtime.Serialization; + +namespace Tensorflow +{ + + /// + /// Serves as a base class to all exceptions of Tensorflow.NET. + /// + [Serializable] + public class TensorflowException : Exception + { + /// Initializes a new instance of the class. + public TensorflowException() + { } + + /// Initializes a new instance of the class with serialized data. + /// The that holds the serialized object data about the exception being thrown. + /// The that contains contextual information about the source or destination. + /// The info parameter is null. + /// The class name is null or is zero (0). + protected TensorflowException(SerializationInfo info, StreamingContext context) : base(info, context) + { } + + /// Initializes a new instance of the class with a specified error message. + /// The message that describes the error. + public TensorflowException(string message) : base(message) + { } + + /// Initializes a new instance of the class with a specified error message and a reference to the inner exception that is the cause of this exception. + /// The error message that explains the reason for the exception. + /// The exception that is the cause of the current exception, or a null reference (Nothing in Visual Basic) if no inner exception is specified. 
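Note (illustrative, not part of the diff): with `KeyError`, `RuntimeError`, `TypeError` and `ValueError` re-based on `TensorflowException` (and `Status.Check` switched to it further down), callers get a single catch-all for library errors; `sess` and `fetch` are assumed to exist.

```csharp
try
{
    var result = sess.run(fetch);      // any TF.NET call that may fail
}
catch (TensorflowException ex)         // also catches ValueError, TypeError, KeyError, ...
{
    Console.WriteLine($"TensorFlow.NET error: {ex.Message}");
}
```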
+ public TensorflowException(string message, Exception innerException) : base(message, innerException) + { } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Exceptions/TypeError.cs b/src/TensorFlowNET.Core/Exceptions/TypeError.cs index a4c379881..42c8e3a02 100644 --- a/src/TensorFlowNET.Core/Exceptions/TypeError.cs +++ b/src/TensorFlowNET.Core/Exceptions/TypeError.cs @@ -2,7 +2,7 @@ namespace Tensorflow { - public class TypeError : Exception + public class TypeError : TensorflowException { public TypeError() : base() { diff --git a/src/TensorFlowNET.Core/Exceptions/ValueError.cs b/src/TensorFlowNET.Core/Exceptions/ValueError.cs index 825d27a16..0d6fb4e39 100644 --- a/src/TensorFlowNET.Core/Exceptions/ValueError.cs +++ b/src/TensorFlowNET.Core/Exceptions/ValueError.cs @@ -2,7 +2,7 @@ namespace Tensorflow { - public class ValueError : Exception + public class ValueError : TensorflowException { public ValueError() : base() { diff --git a/src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefOptions.cs b/src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefOptions.cs index dc3955b28..145a30584 100644 --- a/src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefOptions.cs +++ b/src/TensorFlowNET.Core/Framework/Models/ScopedTFImportGraphDefOptions.cs @@ -6,10 +6,5 @@ public ScopedTFImportGraphDefOptions() : base() { } - - ~ScopedTFImportGraphDefOptions() - { - base.Dispose(); - } } } diff --git a/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs index 9f9b4ad71..8a2bc5c3c 100644 --- a/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs +++ b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs @@ -15,6 +15,8 @@ limitations under the License. ******************************************************************************/ using System.Collections.Generic; +using System.IO; +using Tensorflow.Util; namespace Tensorflow { @@ -27,12 +29,12 @@ public static Dictionary get_registered_ops() if(_registered_ops == null) { _registered_ops = new Dictionary(); - var handle = c_api.TF_GetAllOpList(); - var buffer = new Buffer(handle); - var op_list = OpList.Parser.ParseFrom(buffer); - - foreach (var op_def in op_list.Op) - _registered_ops[op_def.Name] = op_def; + using (var buffer = new Buffer(c_api.TF_GetAllOpList())) + { + var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + foreach (var op_def in op_list.Op) + _registered_ops[op_def.Name] = op_def; + } } return _registered_ops; diff --git a/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs index 6c9f6b187..66419b3e8 100644 --- a/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs +++ b/src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs @@ -14,49 +14,62 @@ You may obtain a copy of the License at limitations under the License. ******************************************************************************/ +using System; using System.Collections.Generic; using System.Linq; using static Tensorflow.Binding; namespace Tensorflow { - public class DefaultGraphStack + + /// + /// Serves as a stack for determining current default graph. 
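Note (illustrative, not part of the diff): the pattern `get_registered_ops` switches to above, wrapping the native buffer and letting protobuf parse straight from unmanaged memory instead of materializing a `byte[]` first; `Stream()` is assumed to be the extension provided by the new `Tensorflow.Util.UnmanagedExtensions` class at the end of this diff.

```csharp
using System;
using Tensorflow;
using Tensorflow.Util;

using (var buffer = new Buffer(c_api.TF_GetAllOpList()))
{
    var opList = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream());  // no intermediate copy
    Console.WriteLine($"registered ops: {opList.Op.Count}");
}
```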
+ /// + public class DefaultGraphStack { - List stack = new List(); + private readonly List _stack = new List(); public void set_controller(Graph @default) { - if (!stack.Exists(x => x.Graph == @default)) - stack.Add(new StackModel { Graph = @default, IsDefault = true }); + if (!_stack.Exists(x => x.Graph == @default)) + _stack.Add(new StackModel {Graph = @default, IsDefault = true}); - foreach (var s in stack) + foreach (var s in _stack) s.IsDefault = s.Graph == @default; } public Graph get_controller() { - if (stack.Count(x => x.IsDefault) == 0) - stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true }); + if (_stack.Count(x => x.IsDefault) == 0) + _stack.Add(new StackModel {Graph = tf.Graph(), IsDefault = true}); + for (var i = _stack.Count - 1; i >= 0; i--) + { + var x = _stack[i]; + if (x.IsDefault) + return x.Graph; + } - return stack.Last(x => x.IsDefault).Graph; + throw new TensorflowException("Unable to find a default graph"); } public bool remove(Graph g) { - var sm = stack.FirstOrDefault(x => x.Graph == g); - if (sm == null) return false; - return stack.Remove(sm); + if (_stack.Count == 0) + return false; + + var sm = _stack.Find(model => model.Graph == g); + return sm != null && _stack.Remove(sm); } public void reset() { - stack.Clear(); + _stack.Clear(); } - } - public class StackModel - { - public Graph Graph { get; set; } - public bool IsDefault { get; set; } + private class StackModel + { + public Graph Graph { get; set; } + public bool IsDefault { get; set; } + } } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs index 4a3ac7937..c97e1b6ff 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs @@ -15,6 +15,7 @@ limitations under the License. ******************************************************************************/ using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; using Tensorflow.Operations; @@ -66,8 +67,9 @@ public ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[ /// within the context should have control dependencies on /// `control_inputs`. /// + [SuppressMessage("ReSharper", "CoVariantArrayConversion")] public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) - => control_dependencies(control_inputs == null ? null : control_inputs.OfType().ToArray()); + => control_dependencies((object[])control_inputs); /// /// Returns a context manager that specifies control dependencies. diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs index 17828c730..4a7e0ed8c 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs @@ -14,6 +14,9 @@ You may obtain a copy of the License at limitations under the License. 
******************************************************************************/ +using System.IO; +using Tensorflow.Util; + namespace Tensorflow { public partial class Graph @@ -23,21 +26,19 @@ public Buffer ToGraphDef(Status s) var buffer = new Buffer(); c_api.TF_GraphToGraphDef(_handle, buffer, s); s.Check(true); - // var def = GraphDef.Parser.ParseFrom(buffer); - // buffer.Dispose(); return buffer; } private GraphDef _as_graph_def(bool add_shapes = false) { - var status = new Status(); - var buffer = ToGraphDef(status); - status.Check(true); - status.Dispose(); - - var def = GraphDef.Parser.ParseFrom(buffer); - buffer.Dispose(); + GraphDef def; + using (var status = new Status()) + using (var buffer = ToGraphDef(status)) + { + status.Check(true); + def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + } // Strip the experimental library field iff it's empty. // if(def.Library.Function.Count == 0) @@ -45,7 +46,7 @@ private GraphDef _as_graph_def(bool add_shapes = false) return def; } - public GraphDef as_graph_def(bool add_shapes = false) + public GraphDef as_graph_def(bool add_shapes = false) => _as_graph_def(add_shapes); } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs index 82695527f..0b2dc0e69 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs @@ -30,11 +30,10 @@ public unsafe TF_Output[] ImportGraphDefWithReturnOutputs(Buffer graph_def, Impo var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); - for (int i = 0; i < num_return_outputs; i++) - { - var handle = return_output_handle + i * size; - return_outputs[i] = Marshal.PtrToStructure(handle); - } + + var tf_output_ptr = (TF_Output*) return_output_handle; + for (int i = 0; i < num_return_outputs; i++) + return_outputs[i] = *(tf_output_ptr + i); Marshal.FreeHGlobal(return_output_handle); diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs index 436afcc90..0e28dd9ac 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -18,6 +18,7 @@ limitations under the License. 
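Note (illustrative, not part of the diff): the recurring pattern in the import and consumer changes here is to cast the native pointer to the blittable struct type and index it, instead of calling `Marshal.PtrToStructure` per element, which routes every read through the interop marshaler. A generic sketch, assuming the C API wrote `count` elements:

```csharp
unsafe TF_Output[] ReadOutputs(IntPtr nativeArray, int count)
{
    var outputs = new TF_Output[count];
    var ptr = (TF_Output*) nativeArray;
    for (int i = 0; i < count; i++)
        outputs[i] = ptr[i];   // plain struct copy, no per-element marshaling
    return outputs;
}
```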
using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; +using Tensorflow.Util; using static Tensorflow.Binding; namespace Tensorflow @@ -30,7 +31,7 @@ public OpDef GetOpDef(string type) using (var status = new Status()) { c_api.TF_GraphGetOpDef(_handle, type, buffer, status); - return OpDef.Parser.ParseFrom(buffer.Data); + return OpDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); } } @@ -39,16 +40,20 @@ public OperationDescription NewOperation(string opType, string opName) return c_api.TF_NewOperation(_handle, opType, opName); } - public unsafe Operation[] ReturnOperations(IntPtr results) + public Operation[] ReturnOperations(IntPtr results) { TF_Operation return_oper_handle = new TF_Operation(); int num_return_opers = 0; c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); Operation[] return_opers = new Operation[num_return_opers]; + var tf_op_size = Marshal.SizeOf(); for (int i = 0; i < num_return_opers; i++) { - var handle = return_oper_handle.node + Marshal.SizeOf() * i; - return_opers[i] = new Operation(*(IntPtr*)handle); + unsafe + { + var handle = return_oper_handle.node + tf_op_size * i; + return_opers[i] = new Operation(*(IntPtr*)handle); + } } return return_opers; @@ -67,7 +72,7 @@ public Operation OperationByName(string operName) public ITensorOrOperation[] get_operations() { - return _nodes_by_name.Values.Select(x => x).ToArray(); + return _nodes_by_name.Values.ToArray(); } /// @@ -81,7 +86,7 @@ public Operation get_operation_by_name(string name) public ITensorOrOperation _get_operation_by_name_unsafe(string name) { - return _nodes_by_name.ContainsKey(name) ? _nodes_by_name[name] : null; + return _nodes_by_name.TryGetValue(name, out var val) ? val : null; } public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 07dc117e8..0dfb68db8 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -369,7 +369,7 @@ public string unique_name(string name, bool mark_as_used = true) var name_key = name.ToLower(); int i = 0; if (_names_in_use.ContainsKey(name_key)) - i = _names_in_use[name_key]; + i = _names_in_use[name_key]; // Increment the number for "name_key". 
if (mark_as_used) _names_in_use[name_key] = i + 1; @@ -399,13 +399,13 @@ public TF_Output[] ReturnOutputs(IntPtr results) int num_return_outputs = 0; c_api.TF_ImportGraphDefResultsReturnOutputs(results, ref num_return_outputs, ref return_output_handle); TF_Output[] return_outputs = new TF_Output[num_return_outputs]; - for (int i = 0; i < num_return_outputs; i++) + unsafe { - var handle = return_output_handle + (Marshal.SizeOf() * i); - return_outputs[i] = Marshal.PtrToStructure(handle); + var tf_output_ptr = (TF_Output*) return_output_handle; + for (int i = 0; i < num_return_outputs; i++) + return_outputs[i] = *(tf_output_ptr + i); + return return_outputs; } - - return return_outputs; } public string[] get_all_collection_keys() @@ -497,11 +497,9 @@ private IEnumerable GetEnumerable() IEnumerator IEnumerable.GetEnumerator() => GetEnumerable().GetEnumerator(); - IEnumerator IEnumerable.GetEnumerator() - { - throw new NotImplementedException(); - } - + IEnumerator IEnumerable.GetEnumerator() + => throw new NotImplementedException(); + public static implicit operator IntPtr(Graph graph) { return graph._handle; diff --git a/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs index bdcaf60ce..708025976 100644 --- a/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs +++ b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs @@ -20,7 +20,8 @@ namespace Tensorflow { public class ImportGraphDefOptions : DisposableObject { - public int NumReturnOutputs => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); + public int NumReturnOutputs + => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); public ImportGraphDefOptions() { diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs index 24348322a..62c8f3788 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -50,14 +50,12 @@ public int OutputListLength(string name) public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) { - int size = Marshal.SizeOf(); - var handle = Marshal.AllocHGlobal(size); + var handle = Marshal.AllocHGlobal(Marshal.SizeOf()); int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); var consumers = new TF_Input[num]; + var inputptr = (TF_Input*) handle; for (int i = 0; i < num; i++) - { - consumers[i] = Marshal.PtrToStructure(handle + i * size); - } + consumers[i] = *(inputptr + i); return consumers; } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 059290f4d..5fff9ade5 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -17,7 +17,9 @@ limitations under the License. 
using Google.Protobuf.Collections; using System; using System.Collections.Generic; +using System.IO; using System.Linq; +using Tensorflow.Util; namespace Tensorflow { @@ -226,9 +228,12 @@ public object get_attr(string name) using (var status = new Status()) using (var buf = new Buffer()) { - c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); - status.Check(true); - x = AttrValue.Parser.ParseFrom(buf); + unsafe + { + c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); + status.Check(true); + x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); + } } string oneof_value = x.ValueCase.ToString(); @@ -259,7 +264,7 @@ private NodeDef GetNodeDef() { c_api.TF_OperationToNodeDef(_handle, buffer, s); s.Check(); - return NodeDef.Parser.ParseFrom(buffer); + return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); } } @@ -299,8 +304,7 @@ private void _assert_same_graph(Tensor tensor) /// public TF_Output _tf_output(int output_idx) { - var tf_output = new TF_Output(op, output_idx); - return tf_output; + return new TF_Output(op, output_idx); } /// @@ -308,8 +312,7 @@ public TF_Output _tf_output(int output_idx) /// public TF_Input _tf_input(int input_idx) { - var tf_input = new TF_Input(op, input_idx); - return tf_input; + return new TF_Input(op, input_idx); } } } diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 4c5f2be3d..58177df28 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -1,413 +1,413 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using NumSharp; -using System; -using System.Collections; -using System.Collections.Generic; -using System.Linq; -using System.Numerics; -using System.Text; - -namespace Tensorflow -{ - public class BaseSession : DisposableObject - { - protected Graph _graph; - protected bool _opened; - protected bool _closed; - protected int _current_version; - protected byte[] _target; - public Graph graph => _graph; - - public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) - { - _graph = g is null ? ops.get_default_graph() : g; - _graph.as_default(); - _target = UTF8Encoding.UTF8.GetBytes(target); - - SessionOptions newOpts = null; - if (opts == null) - newOpts = new SessionOptions(); - - var status = new Status(); - - _handle = c_api.TF_NewSession(_graph, opts ?? 
newOpts, status); - - // dispose newOpts - if (opts == null) - newOpts.Dispose(); - - status.Check(true); - } - - public virtual void run(Operation op, params FeedItem[] feed_dict) - { - _run(op, feed_dict); - } - - public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) - { - return _run(fetche, feed_dict)[0]; - } - - public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) - { - return _run(fetche, feed_dict)[0]; - } - - public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) - { - var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); - return (results[0], results[1], results[2], results[3]); - } - - public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) - { - var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); - return (results[0], results[1], results[2]); - } - - public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) - { - var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); - return (results[0], results[1]); - } - - public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) - { - return _run(fetches, feed_dict); - } - - public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) - { - var feed_items = feed_dict == null ? new FeedItem[0] : - feed_dict.Keys.OfType().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); - return _run(fetches, feed_items); - } - - private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) - { - var feed_dict_tensor = new Dictionary(); - var feed_map = new Dictionary(); - - Func> feed_fn = (item) => - { - return new (object, object)[] { (item.Key, item.Value) }; - }; - - // Validate and process feed_dict. - if (feed_dict != null) - { - foreach (var feed in feed_dict) - { - foreach (var (subfeed, subfeed_val) in feed_fn(feed)) - { - var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); - //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used - feed_dict_tensor[subfeed_t] = subfeed_val; - feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); - } - } - } - - // Create a fetch handler to take care of the structure of fetches. - var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); - - // Run request and get response. - // We need to keep the returned movers alive for the following _do_run(). - // These movers are no longer needed when _do_run() completes, and - // are deleted when `movers` goes out of scope when this _run() ends. - var _ = _update_with_movers(); - var final_fetches = fetch_handler.fetches(); - var final_targets = fetch_handler.targets(); - - // We only want to really perform the run if fetches or targets are provided, - // or if the call is a partial run that specifies feeds. - var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); - - return fetch_handler.build_results(this, results); - } - - /// - /// Runs a step based on the given fetches and feeds. - /// - /// - /// A list of operations to be run, but not fetched. - /// - /// - /// - /// A list of numpy ndarrays, corresponding to the elements of - /// `fetch_list`. 
If the ith element of `fetch_list` contains the - /// name of an operation, the first Tensor output of that operation - /// will be returned for that element. - /// - private NDArray[] _do_run(List target_list, List fetch_list, Dictionary feed_dict) - { - var feeds = feed_dict.Select(x => - { - if (x.Key is Tensor tensor) - { - switch (x.Value) - { -#if _REGEN - %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] - %foreach types% - case #1 v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case #1[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - % -#else - case sbyte v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case sbyte[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case byte v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case byte[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case short v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case short[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case ushort v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case ushort[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case int v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case int[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case uint v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case uint[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case long v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case long[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case ulong v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case ulong[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case float v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case float[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case double v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case double[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case Complex v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case Complex[] v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); -#endif - case bool v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); - case string v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case IntPtr v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); - case Tensor v: - return new KeyValuePair(tensor._as_tf_output(), v); - case NDArray v: - return new KeyValuePair(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); - default: - throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? "")}"); - } - } - throw new NotImplementedException("_do_run.feed_dict"); - }).ToArray(); - var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); - var targets = target_list; - - return _call_tf_sessionrun(feeds, fetches, target_list); - } - - private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] feed_dict, TF_Output[] fetch_list, List target_list) - { - // Ensure any changes to the graph are reflected in the runtime. 
- _extend_graph(); - - var status = new Status(); - - var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); - - c_api.TF_SessionRun(_handle, - run_options: null, - inputs: feed_dict.Select(f => f.Key).ToArray(), - input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), - ninputs: feed_dict.Length, - outputs: fetch_list, - output_values: output_values, - noutputs: fetch_list.Length, - target_opers: target_list.Select(f => (IntPtr)f).ToArray(), - ntargets: target_list.Count, - run_metadata: IntPtr.Zero, - status: status); - - status.Check(true); - - var result = new NDArray[fetch_list.Length]; - - for (int i = 0; i < fetch_list.Length; i++) - result[i] = fetchValue(output_values[i]); - - for (int i = 0; i < feed_dict.Length; i++) - feed_dict[i].Value.Dispose(); - - return result; - } - - private unsafe NDArray fetchValue(IntPtr output) - { - var tensor = new Tensor(output); - NDArray nd = null; - Type type = tensor.dtype.as_numpy_dtype(); - var ndims = tensor.shape; - var offset = c_api.TF_TensorData(output); - - if(ndims.Length == 0) - { - switch (tensor.dtype) - { - case TF_DataType.TF_BOOL: - nd = NDArray.Scalar(*(bool*)offset); - break; - case TF_DataType.TF_STRING: - var bytes = tensor.BufferToArray(); - // wired, don't know why we have to start from offset 9. - // length in the begin - var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); - nd = NDArray.FromString(str); - break; - case TF_DataType.TF_UINT8: - nd = NDArray.Scalar(*(byte*)offset); - break; - case TF_DataType.TF_INT16: - nd = NDArray.Scalar(*(short*)offset); - break; - case TF_DataType.TF_INT32: - nd = NDArray.Scalar(*(int*)offset); - break; - case TF_DataType.TF_INT64: - nd = NDArray.Scalar(*(long*)offset); - break; - case TF_DataType.TF_FLOAT: - nd = NDArray.Scalar(*(float*)offset); - break; - case TF_DataType.TF_DOUBLE: - nd = NDArray.Scalar(*(double*)offset); - break; - default: - throw new NotImplementedException("can't fetch output"); - } - } - else - { - switch (tensor.dtype) - { - case TF_DataType.TF_BOOL: - var bools = new bool[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(bools).reshape(ndims); - break; - case TF_DataType.TF_STRING: - var bytes = tensor.BufferToArray(); - // wired, don't know why we have to start from offset 9. 
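Note (clarification, not part of the diff): the "wired, don't know why we have to start from offset 9" comments in `fetchValue` follow from the TF 1.x `TF_STRING` layout: the tensor data is a table of `uint64` offsets (one per element) followed by the `TF_StringEncode`d strings, each prefixed with a varint length that fits in a single byte while the string is shorter than 128 bytes. With a single element the offset table is 8 bytes, so the length lands at `bytes[8]` and the payload starts at byte 9. A sketch of that decode, assuming a one-element string tensor:

```csharp
// bytes[0..7] : uint64 offset table (one entry)
// bytes[8]    : varint length written by TF_StringEncode (one byte while length < 128)
// bytes[9..]  : the string payload itself
var bytes = tensor.BufferToArray();
var text = Encoding.UTF8.GetString(bytes, 9, bytes[8]);
```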
- // length in the begin - var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); - nd = np.array(str); - break; - case TF_DataType.TF_UINT8: - var _bytes = new byte[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(_bytes).reshape(ndims); - break; - case TF_DataType.TF_INT16: - var shorts = new short[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(shorts).reshape(ndims); - break; - case TF_DataType.TF_INT32: - var ints = new int[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(ints).reshape(ndims); - break; - case TF_DataType.TF_INT64: - var longs = new long[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(longs).reshape(ndims); - break; - case TF_DataType.TF_FLOAT: - var floats = new float[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - floats[i] = *(float*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(floats).reshape(ndims); - break; - case TF_DataType.TF_DOUBLE: - var doubles = new double[tensor.size]; - for (ulong i = 0; i < tensor.size; i++) - doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i)); - nd = np.array(doubles).reshape(ndims); - break; - default: - throw new NotImplementedException("can't fetch output"); - } - } - - tensor.Dispose(); - - return nd; - } - - /// - /// If a tensor handle that is fed to a device incompatible placeholder, - /// we move the tensor to the right device, generate a new tensor handle, - /// and update feed_dict to use the new handle. - /// - private List _update_with_movers() - { - return new List { }; - } - - private void _extend_graph() - { - - } - - public void close() - { - Dispose(); - } - - protected override void DisposeUnmanagedResources(IntPtr handle) - { - using (var status = new Status()) - { - c_api.TF_DeleteSession(handle, status); - status.Check(true); - } - } - } -} +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using NumSharp; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Numerics; +using System.Text; + +namespace Tensorflow +{ + public class BaseSession : DisposableObject + { + protected Graph _graph; + protected bool _opened; + protected bool _closed; + protected int _current_version; + protected byte[] _target; + public Graph graph => _graph; + + public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) + { + _graph = g is null ? 
ops.get_default_graph() : g; + _graph.as_default(); + _target = UTF8Encoding.UTF8.GetBytes(target); + + SessionOptions newOpts = null; + if (opts == null) + newOpts = new SessionOptions(); + + var status = new Status(); + + _handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); + + // dispose newOpts + if (opts == null) + newOpts.Dispose(); + + status.Check(true); + } + + public virtual void run(Operation op, params FeedItem[] feed_dict) + { + _run(op, feed_dict); + } + + public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) + { + return _run(fetche, feed_dict)[0]; + } + + public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) + { + return _run(fetche, feed_dict)[0]; + } + + public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) + { + var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); + return (results[0], results[1], results[2], results[3]); + } + + public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) + { + var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); + return (results[0], results[1], results[2]); + } + + public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) + { + var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); + return (results[0], results[1]); + } + + public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) + { + return _run(fetches, feed_dict); + } + + public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) + { + var feed_items = feed_dict == null ? new FeedItem[0] : + feed_dict.Keys.OfType().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); + return _run(fetches, feed_items); + } + + private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) + { + var feed_dict_tensor = new Dictionary(); + var feed_map = new Dictionary(); + + Func> feed_fn = (item) => + { + return new (object, object)[] { (item.Key, item.Value) }; + }; + + // Validate and process feed_dict. + if (feed_dict != null) + { + foreach (var feed in feed_dict) + { + foreach (var (subfeed, subfeed_val) in feed_fn(feed)) + { + var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false); + //var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype(); // subfeed_dtype was never used + feed_dict_tensor[subfeed_t] = subfeed_val; + feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); + } + } + } + + // Create a fetch handler to take care of the structure of fetches. + var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); + + // Run request and get response. + // We need to keep the returned movers alive for the following _do_run(). + // These movers are no longer needed when _do_run() completes, and + // are deleted when `movers` goes out of scope when this _run() ends. + var _ = _update_with_movers(); + var final_fetches = fetch_handler.fetches(); + var final_targets = fetch_handler.targets(); + + // We only want to really perform the run if fetches or targets are provided, + // or if the call is a partial run that specifies feeds. 
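Note (illustrative, not part of the diff): how feeds reach `_run` from user code. `FeedItem` converts implicitly from a `(key, value)` tuple, and the new `Deconstruct` lets it be unpacked the same way; the placeholder graph below is made up and assumes the usual `tf.placeholder`/`tf.Session` surface.

```csharp
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;

var x = tf.placeholder(tf.float32);
var y = x * 2.0f;
using (var sess = tf.Session())
{
    NDArray result = sess.run(y, (x, 3.0f));   // the (x, 3.0f) tuple becomes a FeedItem
}
```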
+ var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); + + return fetch_handler.build_results(this, results); + } + + /// + /// Runs a step based on the given fetches and feeds. + /// + /// + /// A list of operations to be run, but not fetched. + /// + /// + /// + /// A list of numpy ndarrays, corresponding to the elements of + /// `fetch_list`. If the ith element of `fetch_list` contains the + /// name of an operation, the first Tensor output of that operation + /// will be returned for that element. + /// + private NDArray[] _do_run(List target_list, List fetch_list, Dictionary feed_dict) + { + var feeds = feed_dict.Select(x => + { + if (x.Key is Tensor tensor) + { + switch (x.Value) + { +#if _REGEN + %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] + %foreach types% + case #1 v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case #1[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + % +#else + case sbyte v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case sbyte[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case byte v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case byte[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case short v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case short[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case ushort v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case ushort[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case int v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case int[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case uint v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case uint[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case long v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case long[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case ulong v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case ulong[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case float v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case float[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case double v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case double[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case Complex v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case Complex[] v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); +#endif + case bool v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor((byte)(v?1:0), TF_DataType.TF_BOOL)); + case string v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case IntPtr v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v)); + case Tensor v: + return new KeyValuePair(tensor._as_tf_output(), v); + case NDArray v: + return new KeyValuePair(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); + default: + throw new NotImplementedException($"feed_dict data type {(x.Value?.GetType().Name ?? 
"")}"); + } + } + throw new NotImplementedException("_do_run.feed_dict"); + }).ToArray(); + var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); + var targets = target_list; + + return _call_tf_sessionrun(feeds, fetches, target_list); + } + + private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair[] feed_dict, TF_Output[] fetch_list, List target_list) + { + // Ensure any changes to the graph are reflected in the runtime. + _extend_graph(); + + var status = new Status(); + + var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); + + c_api.TF_SessionRun(_handle, + run_options: null, + inputs: feed_dict.Select(f => f.Key).ToArray(), + input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), + ninputs: feed_dict.Length, + outputs: fetch_list, + output_values: output_values, + noutputs: fetch_list.Length, + target_opers: target_list.Select(f => (IntPtr)f).ToArray(), + ntargets: target_list.Count, + run_metadata: IntPtr.Zero, + status: status); + + status.Check(true); + + var result = new NDArray[fetch_list.Length]; + + for (int i = 0; i < fetch_list.Length; i++) + result[i] = fetchValue(output_values[i]); + + for (int i = 0; i < feed_dict.Length; i++) + feed_dict[i].Value.Dispose(); + + return result; + } + + private unsafe NDArray fetchValue(IntPtr output) + { + var tensor = new Tensor(output); + NDArray nd = null; + Type type = tensor.dtype.as_numpy_dtype(); + var ndims = tensor.shape; + var offset = c_api.TF_TensorData(output); + + if(ndims.Length == 0) + { + switch (tensor.dtype) + { + case TF_DataType.TF_BOOL: + nd = NDArray.Scalar(*(bool*)offset); + break; + case TF_DataType.TF_STRING: + var bytes = tensor.BufferToArray(); + // wired, don't know why we have to start from offset 9. + // length in the begin + var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); + nd = NDArray.FromString(str); + break; + case TF_DataType.TF_UINT8: + nd = NDArray.Scalar(*(byte*)offset); + break; + case TF_DataType.TF_INT16: + nd = NDArray.Scalar(*(short*)offset); + break; + case TF_DataType.TF_INT32: + nd = NDArray.Scalar(*(int*)offset); + break; + case TF_DataType.TF_INT64: + nd = NDArray.Scalar(*(long*)offset); + break; + case TF_DataType.TF_FLOAT: + nd = NDArray.Scalar(*(float*)offset); + break; + case TF_DataType.TF_DOUBLE: + nd = NDArray.Scalar(*(double*)offset); + break; + default: + throw new NotImplementedException("can't fetch output"); + } + } + else + { + switch (tensor.dtype) + { + case TF_DataType.TF_BOOL: + var bools = new bool[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(bools).reshape(ndims); + break; + case TF_DataType.TF_STRING: + var bytes = tensor.BufferToArray(); + // wired, don't know why we have to start from offset 9. 
+ // length in the begin + var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); + nd = np.array(str); + break; + case TF_DataType.TF_UINT8: + var _bytes = new byte[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(_bytes).reshape(ndims); + break; + case TF_DataType.TF_INT16: + var shorts = new short[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + shorts[i] = *(short*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(shorts).reshape(ndims); + break; + case TF_DataType.TF_INT32: + var ints = new int[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + ints[i] = *(int*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(ints).reshape(ndims); + break; + case TF_DataType.TF_INT64: + var longs = new long[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + longs[i] = *(long*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(longs).reshape(ndims); + break; + case TF_DataType.TF_FLOAT: + var floats = new float[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + floats[i] = *(float*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(floats).reshape(ndims); + break; + case TF_DataType.TF_DOUBLE: + var doubles = new double[tensor.size]; + for (ulong i = 0; i < tensor.size; i++) + doubles[i] = *(double*)(offset + (int)(tensor.itemsize * i)); + nd = np.array(doubles).reshape(ndims); + break; + default: + throw new NotImplementedException("can't fetch output"); + } + } + + tensor.Dispose(); + + return nd; + } + + /// + /// If a tensor handle that is fed to a device incompatible placeholder, + /// we move the tensor to the right device, generate a new tensor handle, + /// and update feed_dict to use the new handle. + /// + private List _update_with_movers() + { + return new List { }; + } + + private void _extend_graph() + { + + } + + public void close() + { + Dispose(); + } + + protected override void DisposeUnmanagedResources(IntPtr handle) + { + using (var status = new Status()) + { + c_api.TF_DeleteSession(handle, status); + status.Check(true); + } + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Sessions/FeedItem.cs b/src/TensorFlowNET.Core/Sessions/FeedItem.cs index f87457e72..c3a3dc675 100644 --- a/src/TensorFlowNET.Core/Sessions/FeedItem.cs +++ b/src/TensorFlowNET.Core/Sessions/FeedItem.cs @@ -16,5 +16,11 @@ public FeedItem(object key, object val) public static implicit operator FeedItem((object, object) feed) => new FeedItem(feed.Item1, feed.Item2); + + public void Deconstruct(out object key, out object value) + { + key = Key; + value = Value; + } } } diff --git a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs index ed99b7fe1..112543fe1 100644 --- a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs +++ b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs @@ -37,8 +37,8 @@ protected override void DisposeUnmanagedResources(IntPtr handle) public void SetConfig(ConfigProto config) { - var bytes = config.ToByteArray(); - var proto = Marshal.AllocHGlobal(bytes.Length); + var bytes = config.ToByteArray(); //TODO! we can use WriteTo + var proto = Marshal.AllocHGlobal(bytes.Length); //TODO! 
potential memory leak Marshal.Copy(bytes, 0, proto, bytes.Length); using (var status = new Status()) diff --git a/src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs b/src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs index c40b2a00f..6cbf4eec5 100644 --- a/src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs +++ b/src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs @@ -27,13 +27,17 @@ public static string[] TF_OperationOutputConsumers_wrapper(TF_Output oper_out) var handle = Marshal.AllocHGlobal(size * num_consumers); int num = TF_OperationOutputConsumers(oper_out, handle, num_consumers); var consumers = new string[num_consumers]; - for (int i = 0; i < num; i++) + unsafe { - TF_Input input = Marshal.PtrToStructure(handle + i * size); - consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(input.oper)); + var inputptr = (TF_Input*) handle; + for (int i = 0; i < num; i++) + { + var oper = (inputptr + i)->oper; + consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(oper)); + } } return consumers; } } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Status/Status.cs b/src/TensorFlowNET.Core/Status/Status.cs index 2bdd806ad..ce561f752 100644 --- a/src/TensorFlowNET.Core/Status/Status.cs +++ b/src/TensorFlowNET.Core/Status/Status.cs @@ -15,6 +15,8 @@ limitations under the License. ******************************************************************************/ using System; +using System.Runtime.CompilerServices; +using static Tensorflow.c_api; namespace Tensorflow { @@ -27,36 +29,36 @@ public class Status : DisposableObject /// /// Error message /// - public string Message => c_api.StringPiece(c_api.TF_Message(_handle)); + public string Message => c_api.StringPiece(TF_Message(_handle)); /// /// Error code /// - public TF_Code Code => c_api.TF_GetCode(_handle); + public TF_Code Code => TF_GetCode(_handle); public Status() { - _handle = c_api.TF_NewStatus(); + _handle = TF_NewStatus(); } public void SetStatus(TF_Code code, string msg) { - c_api.TF_SetStatus(_handle, code, msg); + TF_SetStatus(_handle, code, msg); } /// /// Check status /// Throw exception with error message if code != TF_OK /// + /// When the returned check is not TF_Code.TF_OK + [MethodImpl(MethodImplOptions.AggressiveInlining)] public void Check(bool throwException = false) { if (Code != TF_Code.TF_OK) { Console.WriteLine(Message); if (throwException) - { - throw new Exception(Message); - } + throw new TensorflowException(Message); } } @@ -66,6 +68,6 @@ public static implicit operator IntPtr(Status status) } protected override void DisposeUnmanagedResources(IntPtr handle) - => c_api.TF_DeleteStatus(handle); + => TF_DeleteStatus(handle); } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Status/c_api.status.cs b/src/TensorFlowNET.Core/Status/c_api.status.cs index cfac49d1b..ee17e4476 100644 --- a/src/TensorFlowNET.Core/Status/c_api.status.cs +++ b/src/TensorFlowNET.Core/Status/c_api.status.cs @@ -51,7 +51,7 @@ public partial class c_api /// /// [DllImport(TensorFlowLibName)] - public static unsafe extern IntPtr TF_NewStatus(); + public static extern IntPtr TF_NewStatus(); /// /// Record in *s. Any previous information is lost. 
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 73f116ec8..625b424a1 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -52,9 +52,9 @@ public partial class Tensor private DeallocatorArgs _deallocatorArgs = new DeallocatorArgs() { gc_handle = IntPtr.Zero }; // note: they must be assigned to a static variable in order to work as unmanaged callbacks - static Deallocator _hGlobalDeallocator = FreeHGlobalMemory; - static Deallocator _gcHandleDeallocator = FreeGCHandle; - private static Deallocator _nothingDeallocator = FreeNothing; + private static readonly Deallocator _hGlobalDeallocator = FreeHGlobalMemory; + private static readonly Deallocator _gcHandleDeallocator = FreeGCHandle; + private static readonly Deallocator _nothingDeallocator = FreeNothing; /// /// Create a Tensor object from an existing TF handle @@ -528,7 +528,6 @@ public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) } _handle = CreateTensorFromNDArray(nd, tensorDType); - IsMemoryOwner = true; } private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype) @@ -624,7 +623,7 @@ protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Marshal.WriteInt64(tensor, 0); var status = new Status(); - fixed (byte* src = &buffer[0]) + fixed (byte* src = buffer) c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); status.Check(true); @@ -667,8 +666,9 @@ internal static void FreeHGlobalMemory(IntPtr dataPtr, IntPtr len, ref Deallocat { if (args.deallocator_called) return; + // NumSharp will dispose - // Marshal.FreeHGlobal(dataPtr); + Marshal.FreeHGlobal(dataPtr); args.deallocator_called = true; } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 8ac6c73e0..75cba69eb 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -221,15 +221,6 @@ public T[] Data() where T : unmanaged /// When is string public T[] ToArray() where T : unmanaged { - //when T is string - if (typeof(T) == typeof(string)) - { - if (dtype != TF_DataType.TF_STRING) - throw new ArgumentException($"Given <{typeof(T).Name}> can't be converted to string."); - - return (T[]) (object) StringData(); - } - //Are the types matching? if (typeof(T).as_dtype() == dtype) { @@ -246,20 +237,12 @@ public T[] ToArray() where T : unmanaged unsafe { var len = (long) size; - fixed (T* dstRet = ret) + fixed (T* dst = ret) { - T* dst = dstRet; //local stack copy - if (typeof(T).IsPrimitive) - { - var src = (T*) buffer; - len *= ((long) itemsize); - System.Buffer.MemoryCopy(src, dst, len, len); - } else - { - var itemsize = (long) this.itemsize; - var buffer = this.buffer.ToInt64(); - Parallel.For(0L, len, i => dst[i] = Marshal.PtrToStructure(new IntPtr(buffer + i * itemsize))); - } + //T can only be unmanaged, I believe it is safe to say that MemoryCopy is valid for all cases this method can be called. + var src = (T*) buffer; + len *= ((long) itemsize); + System.Buffer.MemoryCopy(src, dst, len, len); } } @@ -384,9 +367,15 @@ public byte[] BufferToArray() } } - /// Used internally in ToArray<T> - private unsafe string[] StringData() + /// + /// Extracts string array from current Tensor. 
+ /// + /// When != TF_DataType.TF_STRING + public unsafe string[] StringData() { + if (dtype != TF_DataType.TF_STRING) + throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})"); + // // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. // [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes] @@ -442,7 +431,7 @@ public NDArray eval(params FeedItem[] feed_dict) /// A dictionary that maps `Tensor` objects to feed values. /// The `Session` to be used to evaluate this tensor. /// A array corresponding to the value of this tensor. - public NDArray eval(Session session, FeedItem[] feed_dict = null) + public NDArray eval(Session session, params FeedItem[] feed_dict) { return ops._eval_using_default_session(this, feed_dict, graph, session); } @@ -568,23 +557,10 @@ public override string ToString() protected override void DisposeUnmanagedResources(IntPtr handle) { - if (handle != IntPtr.Zero) - { - c_api.TF_DeleteTensor(handle); - _handle = IntPtr.Zero; - } + c_api.TF_DeleteTensor(handle); } - public bool IsDisposed - { - get - { - lock (this) - { - return _handle == IntPtr.Zero; - } - } - } + public bool IsDisposed => _disposed; public int tensor_int_val { get; set; } } diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 43848da60..59c107fca 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -83,6 +83,12 @@ public static NDArray MakeNdarray(TensorProto tensor) throw new NotImplementedException("MakeNdarray"); } + private static readonly TF_DataType[] quantized_types = new TF_DataType[] + { + TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, + TF_DataType.TF_QINT32 + }; + /// /// Create a TensorProto. /// @@ -99,15 +105,6 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T if (values is TensorProto tp) return tp; - if (dtype != TF_DataType.DtInvalid) - ; - - bool is_quantized = new TF_DataType[] - { - TF_DataType.TF_QINT8, TF_DataType.TF_QUINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QUINT16, - TF_DataType.TF_QINT32 - }.Contains(dtype); - // We first convert value to a numpy array or scalar. NDArray nparray = null; var np_dt = dtype.as_numpy_dtype(); @@ -227,13 +224,13 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T } } - var numpy_dtype = dtypes.as_dtype(nparray.dtype, dtype: dtype); + var numpy_dtype = nparray.dtype.as_dtype(dtype: dtype); if (numpy_dtype == TF_DataType.DtInvalid) throw new TypeError($"Unrecognized data type: {nparray.dtype}"); // If dtype was specified and is a quantized type, we convert // numpy_dtype back into the quantized version. - if (is_quantized) + if (quantized_types.Contains(dtype)) numpy_dtype = dtype; bool is_same_size = false; diff --git a/src/TensorFlowNET.Core/Util/UnmanagedExtensions.cs b/src/TensorFlowNET.Core/Util/UnmanagedExtensions.cs new file mode 100644 index 000000000..02b8bb739 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/UnmanagedExtensions.cs @@ -0,0 +1,94 @@ +using System; +using System.IO; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using NumSharp.Backends.Unmanaged; + +namespace Tensorflow.Util +{ + public static class UnmanagedExtensions + { + //internally UnmanagedMemoryStream can't construct with null address. 
+ private static readonly unsafe byte* _empty = (byte*) Marshal.AllocHGlobal(1); + + /// + /// Creates a memory stream based on given . + /// + /// The block to stream. Can be default/null. + /// There is no need to dispose the returned + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static UnmanagedMemoryStream Stream(this UnmanagedMemoryBlock block) + { + unsafe + { + if (block.Address == null) + return new UnmanagedMemoryStream(_empty, 0); + return new UnmanagedMemoryStream(block.Address, block.BytesCount); + } + } + + /// + /// Creates a memory stream based on given . + /// + /// The block to stream. Can be default/null. + /// Offset from the start of the block. + /// There is no need to dispose the returned + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static UnmanagedMemoryStream Stream(this UnmanagedMemoryBlock block, long offset) + { + if (block.BytesCount - offset <= 0) + throw new ArgumentOutOfRangeException(nameof(offset)); + + unsafe + { + if (block.Address == null) + return new UnmanagedMemoryStream(_empty, 0); + return new UnmanagedMemoryStream(block.Address + offset, block.BytesCount - offset); + } + } + + /// + /// Creates a memory stream based on given . + /// + /// The block to stream. Can be IntPtr.Zero. + /// The length of the block in bytes. + /// There is no need to dispose the returned + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static UnmanagedMemoryStream Stream(this IntPtr address, long length) + { + if (length <= 0) + throw new ArgumentOutOfRangeException(nameof(length)); + + unsafe + { + if (address == IntPtr.Zero) + return new UnmanagedMemoryStream(_empty, 0); + + // ReSharper disable once AssignNullToNotNullAttribute + return new UnmanagedMemoryStream((byte*) address, length); + } + } + + /// + /// Creates a memory stream based on given . + /// + /// The block to stream. Can be IntPtr.Zero. + /// Offset from the start of the block. + /// The length of the block in bytes. 
+ /// There is no need to dispose the returned + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static UnmanagedMemoryStream Stream(this IntPtr address, long offset, long length) + { + if (length <= 0) + throw new ArgumentOutOfRangeException(nameof(length)); + + unsafe + { + if (address == IntPtr.Zero) + return new UnmanagedMemoryStream(_empty, 0); + + return new UnmanagedMemoryStream((byte*) address + offset, length); + } + } + } +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/globals.regen b/src/TensorFlowNET.Core/globals.regen index 146155b33..86cbee675 100644 --- a/src/TensorFlowNET.Core/globals.regen +++ b/src/TensorFlowNET.Core/globals.regen @@ -8,7 +8,8 @@ %supported_numericals_lowercase = ["byte","short","ushort","int","uint","long","ulong","char","double","float"] %supported_numericals_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] %supported_numericals_onevales = ["1","1","1","1","1u","1L","1UL",1,"1d","1f"] -%supported_numericals_TF_DataType = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_UINT8","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] +%supported_numericals_TF_DataType = ["TF_UINT8","TF_INT16","TF_UINT16","TF_INT32","TF_UINT32","TF_INT64","TF_UINT64","TF_STRING","TF_DOUBLE","TF_FLOAT"] +%supported_numericals_TF_DataType_full = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_STRING","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] //this is the type we use in summerizing/reducting: %supported_numericals_accumulatingType = ["UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"] @@ -25,7 +26,8 @@ %supported_numericals_unsigned_onevales = ["1","1","1U","1UL","'\1'"] %supported_dtypes = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"] -%supported_numericals_TF_DataType = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_UINT8","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] +%supported_dtypes_TF_DataType = ["TF_BOOL","TF_UINT8","TF_INT16","TF_UINT16","TF_INT32","TF_UINT32","TF_INT64","TF_UINT64","TF_STRING","TF_DOUBLE","TF_FLOAT"] +%supported_dtypes_TF_DataType_full = ["TF_DataType.TF_BOOL","TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_STRING","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] %supported_dtypes_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float"] %supported_dtypes_defaultvals = [false,"0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index e485ba6fd..1dc8eb568 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -230,8 +230,8 @@ public static (IntPtr, IntPtr) _create_c_op(Graph graph, NodeDef node_def, T[ // Add attrs foreach (var attr in node_def.Attr) { - var bytes = attr.Value.ToByteArray(); - var proto = Marshal.AllocHGlobal(bytes.Length); + var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. 
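+ // One option for the TODO above (a sketch, not wired up here): attr.Value.WriteTo with a MemoryStream
+ // (or CodedOutputStream) from Google.Protobuf would serialize without this intermediate byte[].
+ // Whichever route is taken, the AllocHGlobal below eventually needs a matching Marshal.FreeHGlobal,
+ // ideally in a finally block (see the TODO on the next line).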
+ var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak Marshal.Copy(bytes, 0, proto, bytes.Length); uint len = (uint)bytes.Length; c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index da8737228..ac6c38dcb 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -64,8 +64,7 @@ public void enable_eager_execution() public Session Session() { - defaultSession = new Session(); - return defaultSession; + return new Session(); } public Session Session(Graph graph) diff --git a/src/TensorFlowNet.Benchmarks/Program.cs b/src/TensorFlowNet.Benchmarks/Program.cs index e17a1d681..ea7c2bde5 100644 --- a/src/TensorFlowNet.Benchmarks/Program.cs +++ b/src/TensorFlowNet.Benchmarks/Program.cs @@ -9,24 +9,18 @@ class Program { static void Main(string[] args) { -#if DEBUG - IConfig config = new DebugInProcessConfig(); -#else - IConfig config = null; -#endif - if (args?.Length > 0) { for (int i = 0; i < args.Length; i++) { string name = $"TensorFlowBenchmark.{args[i]}"; var type = Type.GetType(name); - BenchmarkRunner.Run(type, config); + BenchmarkRunner.Run(type); } } else { - BenchmarkSwitcher.FromAssembly(Assembly.GetExecutingAssembly()).Run(args, config); + BenchmarkSwitcher.FromAssembly(Assembly.GetExecutingAssembly()).Run(args, ManualConfig.Create(DefaultConfig.Instance).With(ConfigOptions.DisableOptimizationsValidator)); } Console.ReadLine(); diff --git a/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj b/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj index bc2a0ff39..4618f06ba 100644 --- a/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj +++ b/src/TensorFlowNet.Benchmarks/TensorFlowBenchmark.csproj @@ -6,6 +6,7 @@ true TensorFlowBenchmark TensorFlowBenchmark + 7.3 diff --git a/src/TensorFlowNet.Benchmarks/Unmanaged/StructCastBenchmark.cs b/src/TensorFlowNet.Benchmarks/Unmanaged/StructCastBenchmark.cs new file mode 100644 index 000000000..5b3a0cd39 --- /dev/null +++ b/src/TensorFlowNet.Benchmarks/Unmanaged/StructCastBenchmark.cs @@ -0,0 +1,76 @@ +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using BenchmarkDotNet.Attributes; +using Google.Protobuf.WellKnownTypes; +using NumSharp; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowBenchmark.Unmanaged +{ + public struct UnmanagedStruct + { + public int a; + public long b; + public UnmanagedStruct(int _) + { + a = 2; + b = 3; + } + } + + [SimpleJob(launchCount: 1, warmupCount: 2, targetCount: 10)] + [MinColumn, MaxColumn, MeanColumn, MedianColumn] + public unsafe class StructCastBenchmark + { + private static void EnsureIsUnmanaged(T _) where T : unmanaged + { } + + static StructCastBenchmark() //if UnmanagedStruct is not unmanaged struct then this will fail to compile. 
+ => EnsureIsUnmanaged(new UnmanagedStruct()); + + private IntPtr data; + private void* dataptr; + + [GlobalSetup] + public void Setup() + { + data = Marshal.AllocHGlobal(Marshal.SizeOf()); + dataptr = data.ToPointer(); + } + + [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] + public void Marshal_PtrToStructure() + { + UnmanagedStruct _; + for (int i = 0; i < 10000; i++) + { + _ = Marshal.PtrToStructure(data); + } + } + + [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] + public void PointerCast() + { + var dptr = dataptr; + UnmanagedStruct _; + for (int i = 0; i < 10000; i++) + { + _ = *(UnmanagedStruct*) dptr; + } + } + + [Benchmark, MethodImpl(MethodImplOptions.NoOptimization)] + public void Unsafe_Read() + { + var dptr = dataptr; + UnmanagedStruct _; + for (int i = 0; i < 10000; i++) + { + _ = Unsafe.Read(dptr); + } + } + + } +} \ No newline at end of file diff --git a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs index 73d40d28f..3116e6f43 100644 --- a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs @@ -102,7 +102,7 @@ public bool Run() // Display logs per epoch step if ((epoch + 1) % display_step == 0) - print($"Epoch: {(epoch + 1).ToString("D4")} Cost: {avg_cost.ToString("G9")} Elapse: {sw.ElapsedMilliseconds}ms"); + print($"Epoch: {(epoch + 1):D4} Cost: {avg_cost:G9} Elapse: {sw.ElapsedMilliseconds}ms"); sw.Reset(); } @@ -114,8 +114,8 @@ public bool Run() var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); // Calculate accuracy var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); - float acc = accuracy.eval((x, mnist.Test.Data), (y, mnist.Test.Labels)); - print($"Accuracy: {acc.ToString("F4")}"); + float acc = accuracy.eval(sess, (x, mnist.Test.Data), (y, mnist.Test.Labels)); + print($"Accuracy: {acc:F4}"); return acc > 0.9; } diff --git a/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs b/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs index 50093f3cb..d0c06704d 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs @@ -84,7 +84,7 @@ public void Predict(Session sess) public void PrepareData() { // get model file - string url = "http://download.tf.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz"; + string url = "http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tar.gz"; Web.Download(url, modelDir, "ssd_mobilenet_v1_coco.tar.gz"); Compress.ExtractTGZ(Path.Join(modelDir, "ssd_mobilenet_v1_coco.tar.gz"), "./"); diff --git a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs index 79cc548fb..7f2d81f4f 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs @@ -21,6 +21,7 @@ limitations under the License. 
using System.Diagnostics; using System.IO; using System.Linq; +using System.Threading.Tasks; using Tensorflow; using TensorFlowNET.Examples.Utility; using static Tensorflow.Binding; @@ -381,10 +382,15 @@ private void cache_bottlenecks(Session sess, Dictionary { - foreach (var category in new string[] { "training", "testing", "validation" }) + var (label_name, label_lists) = kvs[i]; + + Parallel.For(0, categories.Length, j => { + var category = categories[j]; var category_list = label_lists[category]; foreach (var (index, unused_base_name) in enumerate(category_list)) { @@ -395,8 +401,8 @@ private void cache_bottlenecks(Session sess, Dictionary> image_lists, @@ -508,7 +514,7 @@ public void PrepareData() { // get a set of images to teach the network about the new classes string fileName = "flower_photos.tgz"; - string url = $"http://download.tf.org/example_images/{fileName}"; + string url = $"http://download.tensorflow.org/example_images/{fileName}"; Web.Download(url, data_dir, fileName); Compress.ExtractTGZ(Path.Join(data_dir, fileName), data_dir); diff --git a/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs b/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs index 15d9b8193..6c593929c 100644 --- a/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs +++ b/test/TensorFlowNET.UnitTest/Basics/AssignTests.cs @@ -1,4 +1,5 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; using static Tensorflow.Binding; namespace TensorFlowNET.UnitTest.Basics @@ -14,21 +15,22 @@ public void ShouldAssignVariable() var expected = new[] { false, true, false, false, true, false, true }; var spike = tf.Variable(false); - - spike.initializer.run(); - foreach (var i in range(1, 2)) + using (var sess = new Session()) { - if (raw_data[i] - raw_data[i - 1] > 5d) - { - var updater = tf.assign(spike, tf.constant(true)); - updater.eval(); - } - else + spike.initializer.run(session: sess); + foreach (var i in range(1, 2)) { - tf.assign(spike, tf.constant(true)).eval(); - } + if (raw_data[i] - raw_data[i - 1] > 5d) + { + var updater = tf.assign(spike, tf.constant(true)); + updater.eval(sess); + } else + { + tf.assign(spike, tf.constant(true)).eval(sess); + } - Assert.AreEqual((bool)spike.eval(), expected[i - 1]); + Assert.AreEqual((bool) spike.eval(), expected[i - 1]); + } } } } diff --git a/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs b/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs index 33e38870b..58609c172 100644 --- a/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs +++ b/test/TensorFlowNET.UnitTest/CApiGradientsTest.cs @@ -2,6 +2,7 @@ using NumSharp; using System; using Tensorflow; +using Tensorflow.Util; using Buffer = Tensorflow.Buffer; namespace TensorFlowNET.UnitTest @@ -45,15 +46,18 @@ private void TestGradientsSuccess(bool grad_inputs_provided) private bool GetGraphDef(Graph graph, out GraphDef graph_def) { graph_def = null; - var s = new Status(); - var buffer = new Buffer(); - c_api.TF_GraphToGraphDef(graph, buffer, s); - bool ret = TF_GetCode(s) == TF_OK; - EXPECT_EQ(TF_OK, TF_GetCode(s)); - if (ret) graph_def = GraphDef.Parser.ParseFrom(buffer.Data); - buffer.Dispose(); - s.Dispose(); - return ret; + using (var s = new Status()) + { + using (var buffer = new Buffer()) + { + c_api.TF_GraphToGraphDef(graph, buffer, s); + bool ret = TF_GetCode(s) == TF_OK; + EXPECT_EQ(TF_OK, TF_GetCode(s)); + if (ret) + graph_def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + return ret; + } + } } private void RunGraphsAndCompareOutputs(TF_Output[] grad_outputs, TF_Output[] 
expected_grad_outputs) diff --git a/test/TensorFlowNET.UnitTest/CSession.cs b/test/TensorFlowNET.UnitTest/CSession.cs index 33e88286d..ae57b0753 100644 --- a/test/TensorFlowNET.UnitTest/CSession.cs +++ b/test/TensorFlowNET.UnitTest/CSession.cs @@ -40,10 +40,7 @@ public void SetInputs(Dictionary inputs) private void DeleteInputValues() { - for (var i = 0; i < input_values_.Count; ++i) - { - input_values_[i].Dispose(); - } + //clearing is enough as they will be disposed by the GC unless they are referenced else-where. input_values_.Clear(); } @@ -60,11 +57,7 @@ public void SetOutputs(TF_Output[] outputs) private void ResetOutputValues() { - for (var i = 0; i < output_values_.Count; ++i) - { - if (output_values_[i] != IntPtr.Zero) - output_values_[i].Dispose(); - } + //clearing is enough as they will be disposed by the GC unless they are referenced else-where. output_values_.Clear(); } diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index f5431e016..94da6d97e 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -322,7 +322,6 @@ public void ImportGraphDef() EXPECT_EQ(feed2, control_inputs[1]); // Export to a graph def so we can import a graph with control dependencies - graph_def.Dispose(); graph_def = new Buffer(); c_api.TF_GraphToGraphDef(graph, graph_def, s); EXPECT_EQ(TF_Code.TF_OK, s.Code); @@ -346,14 +345,10 @@ public void ImportGraphDef() EXPECT_EQ(feed4, control_inputs[1]); c_api.TF_DeleteImportGraphDefOptions(opts); - c_api.TF_DeleteBuffer(graph_def); // Can add nodes to the imported graph without trouble. c_test_util.Add(feed, scalar, graph, s); ASSERT_EQ(TF_Code.TF_OK, s.Code); - - graph.Dispose(); - s.Dispose(); } /// diff --git a/test/TensorFlowNET.UnitTest/NameScopeTest.cs b/test/TensorFlowNET.UnitTest/NameScopeTest.cs index 4ff50deb0..3d763b383 100644 --- a/test/TensorFlowNET.UnitTest/NameScopeTest.cs +++ b/test/TensorFlowNET.UnitTest/NameScopeTest.cs @@ -1,4 +1,5 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Microsoft.VisualStudio.TestTools.UnitTesting; using Tensorflow; using static Tensorflow.Binding; @@ -42,5 +43,39 @@ public void NestedNameScope() Assert.AreEqual("", g._name_stack); } + + [TestMethod] + public void NestedNameScope_Using() + { + Graph g = tf.Graph().as_default(); + + using (var name = new ops.NameScope("scope1")) + { + Assert.AreEqual("scope1", g._name_stack); + Assert.AreEqual("scope1/", name); + + var const1 = tf.constant(1.0); + Assert.AreEqual("scope1/Const:0", const1.name); + + using (var name2 = new ops.NameScope("scope2")) + { + Assert.AreEqual("scope1/scope2", g._name_stack); + Assert.AreEqual("scope1/scope2/", name); + + var const2 = tf.constant(2.0); + Assert.AreEqual("scope1/scope2/Const:0", const2.name); + } + + Assert.AreEqual("scope1", g._name_stack); + var const3 = tf.constant(2.0); + Assert.AreEqual("scope1/Const_1:0", const3.name); + } + + ; + + g.Dispose(); + + Assert.AreEqual("", g._name_stack); + } } } diff --git a/test/TensorFlowNET.UnitTest/Open.snk b/test/TensorFlowNET.UnitTest/Open.snk new file mode 100644 index 000000000..22a3cbd25 Binary files /dev/null and b/test/TensorFlowNET.UnitTest/Open.snk differ diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 0caa5259b..226a48396 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -4,6 +4,7 @@ using System.Linq; using NumSharp; using 
Tensorflow; +using Tensorflow.Util; using Buffer = Tensorflow.Buffer; using static Tensorflow.Binding; @@ -21,7 +22,7 @@ public void GetAllOpList() { var handle = c_api.TF_GetAllOpList(); var buffer = new Buffer(handle); - var op_list = OpList.Parser.ParseFrom(buffer); + var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream()); var _registered_ops = new Dictionary(); foreach (var op_def in op_list.Op) diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs index 701b4b4b8..d2ae36d76 100644 --- a/test/TensorFlowNET.UnitTest/PythonTest.cs +++ b/test/TensorFlowNET.UnitTest/PythonTest.cs @@ -165,7 +165,7 @@ public T evaluate(Tensor tensor) { using (var sess = tf.Session()) { - var ndarray=tensor.eval(); + var ndarray=tensor.eval(sess); if (typeof(T) == typeof(double)) { double x = ndarray; diff --git a/test/TensorFlowNET.UnitTest/SessionTest.cs b/test/TensorFlowNET.UnitTest/SessionTest.cs index 8fd4dc8a5..62d7c63d8 100644 --- a/test/TensorFlowNET.UnitTest/SessionTest.cs +++ b/test/TensorFlowNET.UnitTest/SessionTest.cs @@ -72,8 +72,6 @@ public void Session() // Clean up csession.CloseAndDelete(s); ASSERT_EQ(TF_Code.TF_OK, s.Code); - graph.Dispose(); - s.Dispose(); } [TestMethod] @@ -84,7 +82,7 @@ public void EvalTensor() var c = math_ops.matmul(a, b, name: "matmul"); using (var sess = tf.Session()) { - var result = c.eval(); + var result = c.eval(sess); Assert.AreEqual(6, result.Data()[0]); } } diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index 848512f06..661d85eac 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -4,6 +4,12 @@ netcoreapp2.2 false + + true + + false + + Open.snk diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index 211d7b65f..4d9d1059d 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -119,7 +119,7 @@ public void Assign2() { sess.run(init_op); // o some work with the model. 
- inc_v1.op.run(); + inc_v1.op.run(session: sess); } } diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs index 1b6909e7b..627d7c2f6 100644 --- a/test/TensorFlowNET.UnitTest/c_test_util.cs +++ b/test/TensorFlowNET.UnitTest/c_test_util.cs @@ -1,4 +1,6 @@ -using Tensorflow; +using System.Diagnostics.CodeAnalysis; +using Tensorflow; +using Tensorflow.Util; using Buffer = Tensorflow.Buffer; namespace TensorFlowNET.UnitTest @@ -26,12 +28,15 @@ public static Operation Add(Operation l, Operation r, Graph graph, Status s, str return op; } + [SuppressMessage("ReSharper", "RedundantAssignment")] public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) { - var buffer = new Buffer(); - c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); - attr_value = AttrValue.Parser.ParseFrom(buffer); - buffer.Dispose(); + using (var buffer = new Buffer()) + { + c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); + attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream()); + } + return s.Code == TF_Code.TF_OK; } @@ -42,7 +47,7 @@ public static GraphDef GetGraphDef(Graph graph) { c_api.TF_GraphToGraphDef(graph, buffer, s); s.Check(); - return GraphDef.Parser.ParseFrom(buffer); + return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); } } diff --git a/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs b/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs index cdbd5f144..310ac6347 100644 --- a/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs +++ b/test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs @@ -24,16 +24,17 @@ public class CreateOpFromTfOperationTest : PythonTest [TestMethod] public void TestShape() { - var g = tf.Graph().as_default(); - - var x = constant_op.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } }); - var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] { x }, new Operation[0]); - var op = g._create_op_from_tf_operation(c_op); - - Assert.AreEqual("myop", op.name); - Assert.AreEqual("Identity", op.type); - Assert.AreEqual(1, len(op.outputs)); - assertItemsEqual(new[] { 2, 3 }, op.outputs[0].shape); + using (var g = tf.Graph().as_default()) + { + var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); + var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); + var op = g._create_op_from_tf_operation(c_op); + + Assert.AreEqual("myop", op.name); + Assert.AreEqual("Identity", op.type); + Assert.AreEqual(1, len(op.outputs)); + assertItemsEqual(new[] {2, 3}, op.outputs[0].shape); + } } [TestMethod]