Skip to content

Commit b98a88a

Browse files
Merge branch 'main' into unit
2 parents 2447cae + 97b9921 commit b98a88a

File tree

6 files changed

+216
-12
lines changed

6 files changed

+216
-12
lines changed

RELEASENOTES.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ Releases, starting with 9/2/2021, are listed with the most recent release at the
66

77
__Bug Fixes__:
88

9-
#1383 `torch.linalg.vector_norm`: Make `ord`-argument optional, as specified in docs
9+
#1383 `torch.linalg.vector_norm`: Make `ord`-argument optional, as specified in docs<br/>
10+
#1385 PackedSequence now participates in the DisposeScope system at the same level as Tensor objects.<br/>
1011

1112
# NuGet Version 0.103.0
1213

build/BranchInfo.props

+3-4
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
<PropertyGroup>
33
<MajorVersion>0</MajorVersion>
44
<MinorVersion>103</MinorVersion>
5-
<PatchVersion>0</PatchVersion>
6-
<PreviousPackageVersion>0.102.8</PreviousPackageVersion>
5+
<PatchVersion>1</PatchVersion>
6+
<PreviousPackageVersion>0.103.0</PreviousPackageVersion>
77
</PropertyGroup>
8-
9-
</Project>
8+
</Project>

src/TorchSharp/DisposeScope.cs

+25
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,9 @@ public void Detach(IEnumerable<IDisposable> disposables)
213213
if (disposable is torch.Tensor tensor) {
214214
tensor.OwningDisposeScope = null;
215215
}
216+
else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
217+
sequence.OwningDisposeScope = null;
218+
}
216219
}
217220
}
218221
}
@@ -239,9 +242,16 @@ public IReadOnlyList<IDisposable> Attach(IEnumerable<IDisposable> disposables)
239242
_disposeScopeManager.StatisticsInstance.DetachedFromScopeCount--;
240243
}
241244
}
245+
else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
246+
if (sequence.OwningDisposeScope == null && !sequence.IsInvalid) {
247+
_disposeScopeManager.StatisticsInstance.DetachedFromScopeCount--;
248+
}
249+
}
250+
242251
AddToOther(this, disposable);
243252
result.Add(disposable);
244253
}
254+
245255
return result;
246256
}
247257

@@ -274,6 +284,12 @@ public void DisposeEverythingBut(IEnumerable<IDisposable> inKeep)
274284
if (!tensor.IsInvalid) {
275285
_disposeScopeManager.StatisticsInstance.DisposedInScopeCount++;
276286
}
287+
} else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
288+
// No need to have the disposable call back to the scope
289+
sequence.OwningDisposeScope = null;
290+
if (!sequence.IsInvalid) {
291+
_disposeScopeManager.StatisticsInstance.DisposedInScopeCount++;
292+
}
277293
} else {
278294
_disposeScopeManager.StatisticsInstance.DisposedInScopeCount++;
279295
}
@@ -358,6 +374,9 @@ public void MarkAsDisposed(IDisposable disposable)
358374
if (disposable is torch.Tensor tensor) {
359375
tensor.OwningDisposeScope = null;
360376
}
377+
else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
378+
sequence.OwningDisposeScope = null;
379+
}
361380
}
362381

363382
/// <summary>
@@ -380,6 +399,9 @@ private void AddToOther(DisposeScope? scope, IDisposable disposable)
380399
if (disposable is torch.Tensor tensor) {
381400
tensor.OwningDisposeScope = scope;
382401
}
402+
else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
403+
sequence.OwningDisposeScope = scope;
404+
}
383405
}
384406

385407
internal HashSet<IDisposable> DetachAllAndDispose()
@@ -390,6 +412,9 @@ internal HashSet<IDisposable> DetachAllAndDispose()
390412
if (disposable is torch.Tensor tensor) {
391413
tensor.OwningDisposeScope = null;
392414
}
415+
else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
416+
sequence.OwningDisposeScope = null;
417+
}
393418
}
394419

395420
this.Disposables = new();

src/TorchSharp/DisposeScopeManager.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public class DisposeScopeManager
1818
internal ThreadDisposeScopeStatistics StatisticsInstance { get; } = new ThreadDisposeScopeStatistics();
1919
internal DisposeScope? CurrentDisposeScope { get; private set; } = null;
2020

21-
internal DisposeScope? RegisterOnCurrentDisposeScope(torch.Tensor tensor)
21+
internal DisposeScope? RegisterOnCurrentDisposeScope(IDisposable tensor)
2222
{
2323
if (this.CurrentDisposeScope is null) {
2424
StatisticsInstance.CreatedOutsideScopeCount++;

src/TorchSharp/NN/Utils/PackedSequence.cs

+64-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
2+
23
using System;
4+
using System.Runtime.CompilerServices;
35
using System.Runtime.InteropServices;
46
using static TorchSharp.PInvoke.NativeMethods;
57

@@ -18,6 +20,17 @@ public static partial class rnn
1820
/// </summary>
1921
public sealed class PackedSequence : IDisposable
2022
{
23+
internal DisposeScope OwningDisposeScope {
24+
get => mOwningDisposeScope;
25+
set {
26+
mOwningDisposeScope = value;
27+
this.batch_sizes.OwningDisposeScope = value;
28+
this.data.OwningDisposeScope = value;
29+
this.sorted_indices.OwningDisposeScope = value;
30+
this.unsorted_indices.OwningDisposeScope = value;
31+
}
32+
}
33+
2134
/// <summary>
2235
/// Class wrapping PyTorch's packedsequence object reference.
2336
/// </summary>
@@ -39,6 +52,7 @@ internal HType() : base(IntPtr.Zero, true)
3952
protected override bool ReleaseHandle()
4053
{
4154
THSNN_PackedSequence_dispose(handle);
55+
handle = IntPtr.Zero;
4256
return true;
4357
}
4458
}
@@ -62,15 +76,21 @@ protected override bool ReleaseHandle()
6276
/// The original indices
6377
/// </summary>
6478
public readonly Tensor unsorted_indices;
79+
/// <summary>
80+
/// Is true if the PackedSequence has been disposed, false otherwise.
81+
/// </summary>
82+
internal bool IsInvalid => handle.IsInvalid;
6583
private HType handle;
84+
private DisposeScope mOwningDisposeScope;
6685

6786
internal PackedSequence(HType handle)
6887
{
6988
this.handle = handle;
70-
this.data = new Tensor(THSNN_PackedSequence_data(handle));
71-
this.batch_sizes = new Tensor(THSNN_PackedSequence_batch_sizes(handle));
72-
this.sorted_indices = new Tensor(THSNN_PackedSequence_sorted_indices(handle));
73-
this.unsorted_indices = new Tensor(THSNN_PackedSequence_unsorted_indices(handle));
89+
this.data = new Tensor(THSNN_PackedSequence_data(handle)).DetachFromDisposeScope();
90+
this.batch_sizes = new Tensor(THSNN_PackedSequence_batch_sizes(handle)).DetachFromDisposeScope();
91+
this.sorted_indices = new Tensor(THSNN_PackedSequence_sorted_indices(handle)).DetachFromDisposeScope();
92+
this.unsorted_indices = new Tensor(THSNN_PackedSequence_unsorted_indices(handle)).DetachFromDisposeScope();
93+
OwningDisposeScope = DisposeScopeManager.ThreadSingleton.RegisterOnCurrentDisposeScope(this);
7494
}
7595

7696
internal HType Handle => handle;
@@ -84,15 +104,53 @@ public void Dispose()
84104
this.batch_sizes.Dispose();
85105
this.sorted_indices.Dispose();
86106
this.unsorted_indices.Dispose();
107+
OwningDisposeScope?.MarkAsDisposed(this);
87108

88109
if (handle != null && !handle.IsInvalid) {
89110
handle.Dispose();
90111
handle.SetHandleAsInvalid();
112+
113+
}
114+
}
115+
/// <summary>
116+
/// Moves PackedSequence to the outer DisposeScope. If there is no outer DisposeScope, it's detached from the
117+
/// DisposeScope system.
118+
/// </summary>
119+
/// <returns>The same PackedSequence that the method was called on</returns>
120+
public PackedSequence MoveToOuterDisposeScope()
121+
{
122+
OwningDisposeScope?.MoveToOuter(this);
123+
return this;
124+
}
125+
126+
/// <summary>
127+
/// Detaches the PackedSequence completely from the DisposeScope system.
128+
/// </summary>
129+
/// <returns>The same PackedSequence that the method was called on</returns>
130+
public PackedSequence DetachFromDisposeScope()
131+
{
132+
OwningDisposeScope?.Detach(this);
133+
return this;
134+
}
135+
136+
public PackedSequence MoveToOtherDisposeScope(PackedSequence other)
137+
{
138+
return MoveToOtherDisposeScope(other.OwningDisposeScope);
139+
}
140+
141+
public PackedSequence MoveToOtherDisposeScope(DisposeScope other)
142+
{
143+
if (OwningDisposeScope == null && other != null) {
144+
other.Attach(this);
145+
}
146+
else {
147+
OwningDisposeScope?.MoveToOther(other, this);
91148
}
149+
return this;
92150
}
93-
}
151+
}
94152
}
95153
}
96154
}
97155
}
98-
}
156+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
using System.Reflection;
2+
using TorchSharp;
3+
using Xunit;
4+
5+
namespace TorchSharpTest;
6+
7+
public class TestDisposeScopesPackedSequence
8+
{
9+
[Fact]
10+
public void MoveDisposeScope()
11+
{
12+
var sequences = CreateTestSequences();
13+
torch.nn.utils.rnn.PackedSequence packed_sequence;
14+
var otherScope = torch.NewDisposeScope();
15+
using (torch.NewDisposeScope())
16+
{
17+
using (torch.NewDisposeScope())
18+
{
19+
packed_sequence = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: false);
20+
AssertPackedSequenceValid(packed_sequence);
21+
22+
packed_sequence.MoveToOuterDisposeScope();
23+
}
24+
AssertPackedSequenceValid(packed_sequence);
25+
26+
packed_sequence.MoveToOtherDisposeScope(otherScope);
27+
}
28+
29+
AssertPackedSequenceValid(packed_sequence);
30+
otherScope.Dispose();
31+
32+
Assert.True(GetPackedSequenceIsInvalid(packed_sequence));
33+
Assert.True(packed_sequence.data.IsInvalid);
34+
}
35+
36+
[Fact]
37+
public void DisposablesValidityWhenNotSorted()
38+
{
39+
var sequences = CreateTestSequences();
40+
using var scope = torch.NewDisposeScope();
41+
var packed = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: false);
42+
Assert.Equal(1, scope.DisposablesCount);
43+
AssertPackedSequenceValid(packed);
44+
}
45+
46+
[Fact]
47+
public void DisposablesValidityWhenSorted()
48+
{
49+
var sequences = CreateTestSequences();
50+
using var scope = torch.NewDisposeScope();
51+
var packed = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: true);
52+
Assert.Equal(1, scope.DisposablesCount);
53+
Assert.False(GetPackedSequenceIsInvalid(packed));
54+
Assert.False(packed.batch_sizes.IsInvalid);
55+
Assert.False(packed.data.IsInvalid);
56+
Assert.True(packed.sorted_indices.IsInvalid);
57+
Assert.True(packed.unsorted_indices.IsInvalid);
58+
}
59+
60+
[Fact]
61+
public void DisposeScopeStatistics()
62+
{
63+
DisposeScopeManager.Statistics.Reset();
64+
AssertStatCounts(0, 0, 0, 0, 0);
65+
var sequences = CreateTestSequences();
66+
AssertStatCounts(0, 2, 0, 0, 0);
67+
var outOfScope = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: true);
68+
AssertStatCounts(0, 7, 0, 0, 0);
69+
using var scope = torch.NewDisposeScope();
70+
AssertStatCounts(0, 7, 0, 0, 0);
71+
72+
var inScope = torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted: true);
73+
AssertStatCounts(5, 7, 4, 0, 1);
74+
75+
scope.Attach(outOfScope);
76+
//Possible subtle bug. When attaching an object that isn't owned by any scope, the count subtracts.
77+
AssertStatCounts( 5, 7, 3, 0, 2);
78+
79+
scope.Detach(inScope);
80+
AssertStatCounts( 5, 7, 4, 0, 1);
81+
82+
outOfScope.Dispose();
83+
AssertStatCounts( 5, 7, 4, 5, -4);
84+
85+
}
86+
87+
private static void AssertStatCounts(long createdInScope, long createdOutsideScope, long detachedFrom, long disposedIn, long threadTotalLive)
88+
{
89+
var stats = DisposeScopeManager.Statistics;
90+
Assert.Equal(createdInScope, stats.CreatedInScopeCount);
91+
Assert.Equal(createdOutsideScope, stats.CreatedOutsideScopeCount);
92+
Assert.Equal(detachedFrom, stats.DetachedFromScopeCount);
93+
Assert.Equal(disposedIn, stats.DisposedInScopeCount);
94+
Assert.Equal(threadTotalLive, stats.ThreadTotalLiveCount);
95+
}
96+
97+
private static torch.Tensor[] CreateTestSequences()
98+
{
99+
return new[]
100+
{
101+
torch.tensor(new long[] { 1, 2, 3, 4 }),
102+
torch.tensor(new long[] { 5, 6 }),
103+
};
104+
}
105+
106+
private static void AssertPackedSequenceValid(torch.nn.utils.rnn.PackedSequence packed_sequence)
107+
{
108+
Assert.False(GetPackedSequenceIsInvalid(packed_sequence));
109+
Assert.False(packed_sequence.batch_sizes.IsInvalid);
110+
Assert.False(packed_sequence.data.IsInvalid);
111+
Assert.False(packed_sequence.sorted_indices.IsInvalid);
112+
Assert.False(packed_sequence.unsorted_indices.IsInvalid);
113+
}
114+
115+
private static bool GetPackedSequenceIsInvalid(torch.nn.utils.rnn.PackedSequence packed_sequence)
116+
{
117+
//HACK: reflection to avoid exposing internal method IsInvalid in API.
118+
var getter = typeof(torch.nn.utils.rnn.PackedSequence).GetProperty("IsInvalid", BindingFlags.Instance | BindingFlags.NonPublic)!;
119+
return (bool)getter.GetValue(packed_sequence)!;
120+
}
121+
}

0 commit comments

Comments
 (0)