Skip to content

Commit e4e62dd

Browse files
committed
Tensor: Rewrite Creation to fix heap corruption.
TensorTest: Added Unit tests for more coverage, reformatted file. c_api.tensor: Added overloads for TF_NewTensor that does not have deallocator parameters. BaseSession: Removed disposal immediately after TF_SessionRun call. c_api.DeallocatorArgs: Added DeallocatorArgs.Empty
1 parent 1d6f3de commit e4e62dd

File tree

8 files changed

+365
-329
lines changed

8 files changed

+365
-329
lines changed

Diff for: TensorFlow.NET.sln.DotSettings

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
<wpf:ResourceDictionary xml:space="preserve" xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:s="clr-namespace:System;assembly=mscorlib" xmlns:ss="urn:shemas-jetbrains-com:settings-storage-xaml" xmlns:wpf="http://schemas.microsoft.com/winfx/2006/xaml/presentation">
2+
<s:Boolean x:Key="/Default/UserDictionary/Words/=Tensorflow/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary>

Diff for: src/TensorFlowNET.Core/APIs/c_api.cs

+9
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@ public static string StringPiece(IntPtr handle)
5454

5555
public struct DeallocatorArgs
5656
{
57+
internal static unsafe c_api.DeallocatorArgs* EmptyPtr;
58+
internal static unsafe IntPtr Empty;
59+
60+
static unsafe DeallocatorArgs()
61+
{
62+
Empty = new IntPtr(EmptyPtr = (DeallocatorArgs*) Marshal.AllocHGlobal(Marshal.SizeOf<DeallocatorArgs>()));
63+
*EmptyPtr = new DeallocatorArgs() {gc_handle = IntPtr.Zero, deallocator_called = false};
64+
}
65+
5766
public bool deallocator_called;
5867
public IntPtr gc_handle;
5968
}

Diff for: src/TensorFlowNET.Core/Sessions/BaseSession.cs

+3-15
Original file line numberDiff line numberDiff line change
@@ -152,16 +152,16 @@ private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list,
152152
{
153153

154154
var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count];
155-
var ignoreDispose = new bool[feed_dict.Count];
156155
int i = 0;
157156
foreach (var x in feed_dict)
158157
{
159158
if (x.Key is Tensor tensor)
160159
{
161160
switch (x.Value)
162161
{
163-
case Tensor v: ignoreDispose[i] = true; feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); break;
162+
case Tensor v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); break;
164163
case NDArray v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); break;
164+
case IntPtr v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
165165
#if _REGEN
166166
%types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
167167
%foreach types%
@@ -194,7 +194,6 @@ private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list,
194194
#endif
195195
case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); break;
196196
case string v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
197-
case IntPtr v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
198197
default:
199198
throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}");
200199
}
@@ -203,18 +202,7 @@ private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list,
203202

204203
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
205204
//var targets = target_list;
206-
try
207-
{
208-
return _call_tf_sessionrun(feeds, fetches, target_list);
209-
} finally
210-
{
211-
for (var idx = 0; idx < feeds.Length; idx++)
212-
{
213-
if (ignoreDispose[idx])
214-
continue;
215-
feeds[idx].Value.Dispose();
216-
}
217-
}
205+
return _call_tf_sessionrun(feeds, fetches, target_list);
218206
}
219207

220208
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)

Diff for: src/TensorFlowNET.Core/Tensors/AllocationType.cs

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
namespace Tensorflow
2+
{
3+
/// <summary>
4+
/// Used internally to
5+
/// </summary>
6+
public enum AllocationType
7+
{
8+
None = 0,
9+
/// <summary>
10+
/// Allocation was done by passing in a pointer, might be also holding reference to a C# object.
11+
/// </summary>
12+
FromPointer = 1,
13+
/// <summary>
14+
/// Allocation was done by calling c_api.TF_AllocateTensor or TF decided it has to copy data during c_api.TF_NewTensor. <br></br>
15+
/// Deallocation is handled solely by Tensorflow.
16+
/// </summary>
17+
Tensorflow = 2,
18+
/// <summary>
19+
/// Allocation was done by Marshal.AllocateHGlobal
20+
/// </summary>
21+
Marshal = 3,
22+
/// <summary>
23+
/// Allocation was done by GCHandle.Alloc
24+
/// </summary>
25+
GCHandle = 4,
26+
}
27+
}

0 commit comments

Comments
 (0)