1
+ using System . Reflection ;
2
+ using TorchSharp ;
3
+ using Xunit ;
4
+
5
+ namespace TorchSharpTest ;
6
+
7
+ public class TestDisposeScopesPackedSequence
8
+ {
9
+ [ Fact ]
10
+ public void MoveDisposeScope ( )
11
+ {
12
+ var sequences = CreateTestSequences ( ) ;
13
+ torch . nn . utils . rnn . PackedSequence packed_sequence ;
14
+ var otherScope = torch . NewDisposeScope ( ) ;
15
+ using ( torch . NewDisposeScope ( ) )
16
+ {
17
+ using ( torch . NewDisposeScope ( ) )
18
+ {
19
+ packed_sequence = torch . nn . utils . rnn . pack_sequence ( sequences , enforce_sorted : false ) ;
20
+ AssertPackedSequenceValid ( packed_sequence ) ;
21
+
22
+ packed_sequence . MoveToOuterDisposeScope ( ) ;
23
+ }
24
+ AssertPackedSequenceValid ( packed_sequence ) ;
25
+
26
+ packed_sequence . MoveToOtherDisposeScope ( otherScope ) ;
27
+ }
28
+
29
+ AssertPackedSequenceValid ( packed_sequence ) ;
30
+ otherScope . Dispose ( ) ;
31
+
32
+ Assert . True ( GetPackedSequenceIsInvalid ( packed_sequence ) ) ;
33
+ Assert . True ( packed_sequence . data . IsInvalid ) ;
34
+ }
35
+
36
+ [ Fact ]
37
+ public void DisposablesValidityWhenNotSorted ( )
38
+ {
39
+ var sequences = CreateTestSequences ( ) ;
40
+ using var scope = torch . NewDisposeScope ( ) ;
41
+ var packed = torch . nn . utils . rnn . pack_sequence ( sequences , enforce_sorted : false ) ;
42
+ Assert . Equal ( 1 , scope . DisposablesCount ) ;
43
+ AssertPackedSequenceValid ( packed ) ;
44
+ }
45
+
46
+ [ Fact ]
47
+ public void DisposablesValidityWhenSorted ( )
48
+ {
49
+ var sequences = CreateTestSequences ( ) ;
50
+ using var scope = torch . NewDisposeScope ( ) ;
51
+ var packed = torch . nn . utils . rnn . pack_sequence ( sequences , enforce_sorted : true ) ;
52
+ Assert . Equal ( 1 , scope . DisposablesCount ) ;
53
+ Assert . False ( GetPackedSequenceIsInvalid ( packed ) ) ;
54
+ Assert . False ( packed . batch_sizes . IsInvalid ) ;
55
+ Assert . False ( packed . data . IsInvalid ) ;
56
+ Assert . True ( packed . sorted_indices . IsInvalid ) ;
57
+ Assert . True ( packed . unsorted_indices . IsInvalid ) ;
58
+ }
59
+
60
+ [ Fact ]
61
+ public void DisposeScopeStatistics ( )
62
+ {
63
+ DisposeScopeManager . Statistics . Reset ( ) ;
64
+ AssertStatCounts ( 0 , 0 , 0 , 0 , 0 ) ;
65
+ var sequences = CreateTestSequences ( ) ;
66
+ AssertStatCounts ( 0 , 2 , 0 , 0 , 0 ) ;
67
+ var outOfScope = torch . nn . utils . rnn . pack_sequence ( sequences , enforce_sorted : true ) ;
68
+ AssertStatCounts ( 0 , 7 , 0 , 0 , 0 ) ;
69
+ using var scope = torch . NewDisposeScope ( ) ;
70
+ AssertStatCounts ( 0 , 7 , 0 , 0 , 0 ) ;
71
+
72
+ var inScope = torch . nn . utils . rnn . pack_sequence ( sequences , enforce_sorted : true ) ;
73
+ AssertStatCounts ( 5 , 7 , 4 , 0 , 1 ) ;
74
+
75
+ scope . Attach ( outOfScope ) ;
76
+ //Possible subtle bug. When attaching an object that isn't owned by any scope, the count subtracts.
77
+ AssertStatCounts ( 5 , 7 , 3 , 0 , 2 ) ;
78
+
79
+ scope . Detach ( inScope ) ;
80
+ AssertStatCounts ( 5 , 7 , 4 , 0 , 1 ) ;
81
+
82
+ outOfScope . Dispose ( ) ;
83
+ AssertStatCounts ( 5 , 7 , 4 , 5 , - 4 ) ;
84
+
85
+ }
86
+
87
+ private static void AssertStatCounts ( long createdInScope , long createdOutsideScope , long detachedFrom , long disposedIn , long threadTotalLive )
88
+ {
89
+ var stats = DisposeScopeManager . Statistics ;
90
+ Assert . Equal ( createdInScope , stats . CreatedInScopeCount ) ;
91
+ Assert . Equal ( createdOutsideScope , stats . CreatedOutsideScopeCount ) ;
92
+ Assert . Equal ( detachedFrom , stats . DetachedFromScopeCount ) ;
93
+ Assert . Equal ( disposedIn , stats . DisposedInScopeCount ) ;
94
+ Assert . Equal ( threadTotalLive , stats . ThreadTotalLiveCount ) ;
95
+ }
96
+
97
+ private static torch . Tensor [ ] CreateTestSequences ( )
98
+ {
99
+ return new [ ]
100
+ {
101
+ torch . tensor ( new long [ ] { 1 , 2 , 3 , 4 } ) ,
102
+ torch . tensor ( new long [ ] { 5 , 6 } ) ,
103
+ } ;
104
+ }
105
+
106
+ private static void AssertPackedSequenceValid ( torch . nn . utils . rnn . PackedSequence packed_sequence )
107
+ {
108
+ Assert . False ( GetPackedSequenceIsInvalid ( packed_sequence ) ) ;
109
+ Assert . False ( packed_sequence . batch_sizes . IsInvalid ) ;
110
+ Assert . False ( packed_sequence . data . IsInvalid ) ;
111
+ Assert . False ( packed_sequence . sorted_indices . IsInvalid ) ;
112
+ Assert . False ( packed_sequence . unsorted_indices . IsInvalid ) ;
113
+ }
114
+
115
+ private static bool GetPackedSequenceIsInvalid ( torch . nn . utils . rnn . PackedSequence packed_sequence )
116
+ {
117
+ //HACK: reflection to avoid exposing internal method IsInvalid in API.
118
+ var getter = typeof ( torch . nn . utils . rnn . PackedSequence ) . GetProperty ( "IsInvalid" , BindingFlags . Instance | BindingFlags . NonPublic ) ! ;
119
+ return ( bool ) getter . GetValue ( packed_sequence ) ! ;
120
+ }
121
+ }
0 commit comments