Commit 36f7ad0

Merge branch 'main' into unit

2 parents b98a88a + 9856487

22 files changed: +989 −249 lines

RELEASENOTES.md

+6

@@ -4,10 +4,16 @@ Releases, starting with 9/2/2021, are listed with the most recent release at the

 # NuGet Version 0.103.1

+__Breaking Changes__:
+#1376 `torch.Tensor.backward`'s function signature has been updated to match PyTorch's implementation. Previously, passing `create_graph` or `retain_graph` by position would work like PyTorch's `torch.Tensor.backward`, but not if passing by name (`create_graph`'s value was swapped with `retain_graph`). This has been corrected; however, this means any code that passes `create_graph` or `retain_graph` by name needs to be updated to reflect the intended functionality.<br/>
+
 __Bug Fixes__:

 #1383 `torch.linalg.vector_norm`: Make `ord`-argument optional, as specified in docs<br/>
 #1385 PackedSequence now participates in the DisposeScope system at the same level as Tensor objects.<br/>
+#1387 Attaching tensor to a DisposeScope no longer makes Statistics.DetachedFromScopeCount go negative.<br/>
+#1390 DisposeScopeManager.Statistics now includes DisposedOutsideScopeCount and AttachedToScopeCount. ThreadTotalLiveCount is now exact instead of approximate. ToString gives a useful debug string, and documentation is added for how to troubleshoot memory leaks. Also DisposeScopeManager.Statistics.TensorStatistics and DisposeScopeManager.Statistics.PackedSequenceStatistics provide separate metrics for these objects.<br/>
+#1392 ToTensor() extension method memory leaks fixed.<br/>

 # NuGet Version 0.103.0
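
To make the #1376 note concrete, here is a minimal, hedged sketch of calling `backward` with named arguments after the fix. The tensor setup is invented for illustration, and it assumes both parameters remain optional booleans that can be passed by name, as the note implies; only the parameter names `create_graph` and `retain_graph` are taken from the release note itself.

```csharp
// Illustrative only: a scalar loss from a small differentiable graph.
using static TorchSharp.torch;

var x = ones(3, requires_grad: true);
var loss = (x * x).sum();

// Before this fix, passing these arguments by name could silently swap their
// values; passing by position always behaved like PyTorch. After the fix,
// named arguments behave as written:
loss.backward(retain_graph: true, create_graph: false);

// A second backward pass is possible because the graph was retained.
loss.backward();
```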

docfx/articles/memory leak troubleshooting.md (new file)

+81

@@ -0,0 +1,81 @@

# Memory Leak Troubleshooting

If you suspect you are leaking memory, this is your guide. First, be sure you are familiar with the [Memory Management Techniques](memory.md).

## Verifying you have a leak

The `DisposeScopeManager.Statistics` property provides thread-level statistics of the objects TorchSharp tracks as they are created and moved between DisposeScopes. Normally you deal directly with only this property.

To see where code may be leaking objects, it is easiest to modify the training loop: use a DisposeScope, reset the global statistics to get a known starting point, then take some action and look at the statistics to see what is still around.

```csharp
//Training Loop, 10 epochs
for (int i = 0; i < 10; i++) {
    //Clear the statistics
    DisposeScopeManager.Statistics.Reset();
    //Take action. In this case it is inside a DisposeScope, so
    //when this code block is done, there should be no new live objects.
    using (NewDisposeScope()) {
        var eval = model.call(x);
        // ... other model execution code
        optimizer.step();
    }
    //Examine what happened
    Console.WriteLine(DisposeScopeManager.Statistics);
}
```

If on every iteration the number of live objects is increasing, there is a leak. In the following example, note that the number of live objects increases by 200 on every iteration. It can also be seen that these objects were created in a DisposeScope but were eventually detached. In this specific case, look for where the code is detaching the tensors, and then determine how to correctly manage the lifetime of those objects.

```
ThreadTotalLiveCount: 548; CreatedOutsideScopeCount: 0; DisposedOutsideScopeCount: 0; CreatedInScopeCount: 200; DisposedInScopeCount: 2; AttachedToScopeCount: 0; DetachedFromScopeCount: 200
ThreadTotalLiveCount: 748; CreatedOutsideScopeCount: 0; DisposedOutsideScopeCount: 0; CreatedInScopeCount: 200; DisposedInScopeCount: 2; AttachedToScopeCount: 0; DetachedFromScopeCount: 200
ThreadTotalLiveCount: 948; CreatedOutsideScopeCount: 0; DisposedOutsideScopeCount: 0; CreatedInScopeCount: 200; DisposedInScopeCount: 2; AttachedToScopeCount: 0; DetachedFromScopeCount: 200
```

It is not necessary to leave this code in place for production once the leak is fixed; it may be removed if you want the training loop to read more like the equivalent PyTorch code.

## Identifying the leak

This is where the legwork is. Look at each line of code where Tensor or PackedSequence objects are created, and ensure they are eventually disposed, either manually or by a DisposeScope. One can also print the statistics in the debugger while stepping through code for an interactive approach.

Be aware that TorchSharp also creates tensors for itself and uses them in various ways. Finding a tensor created by TorchSharp that isn't being disposed does not mean the leak is caused by TorchSharp. A good example is the Adam optimizer: it creates tensors internally to manage its parameters and detaches them from any DisposeScope that is in use. If it didn't, gradients and back-propagation would fail because its tensors would have been disposed. These tensors are eventually cleaned up when the optimizer is properly disposed after training. Failure of the client code to dispose is the most likely cause of memory leaks.
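
As an illustration of that workflow, here is a hedged sketch. The leaky helper, the tensor math, and the fix are all invented for the example; `NewDisposeScope` and `MoveToOuterDisposeScope` are the TorchSharp scope APIs discussed in the memory management article.

```csharp
using static TorchSharp.torch;

// Hypothetical helper that leaks: the intermediate `centered` tensor is created
// outside any scope and is never disposed.
Tensor NormalizeLeaky(Tensor input)
{
    var centered = input - input.mean();
    return centered / centered.std();
}

// One way to fix it: create the temporaries inside a scope and move only the
// result out, so everything else is disposed when the scope ends.
Tensor Normalize(Tensor input)
{
    using var scope = NewDisposeScope();
    var centered = input - input.mean();
    var result = centered / centered.std();
    return result.MoveToOuterDisposeScope();
}
```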

## Working with RNNs

One may want to drill down to `DisposeScopeManager.Statistics.TensorStatistics` or `DisposeScopeManager.Statistics.PackedSequenceStatistics`; these track Tensor and PackedSequence usage independently.

Additionally, a PackedSequence uses some tensors internally. These tensors show up in the creation statistics, are immediately detached from any scope that is in context, and increment the DetachedFromScopeCount property. When a PackedSequence is disposed, it also disposes its tensors. The differences in counts can be seen in the following output from an IDE debug window, where all three levels of statistics were observed at the same point in execution. Note that the first two sum to the totals on the last line.

```
DisposeScopeManager.Statistics.TensorStatistics.ToString()
"ThreadTotalLiveCount: 4; CreatedOutsideScopeCount: 0; DisposedOutsideScopeCount: 0; CreatedInScopeCount: 6; DisposedInScopeCount: 2; AttachedToScopeCount: 0; DetachedFromScopeCount: 4"

DisposeScopeManager.Statistics.PackedSequenceStatistics.ToString()
"ThreadTotalLiveCount: 1; CreatedOutsideScopeCount: 0; DisposedOutsideScopeCount: 0; CreatedInScopeCount: 1; DisposedInScopeCount: 0; AttachedToScopeCount: 0; DetachedFromScopeCount: 0"

DisposeScopeManager.Statistics.ToString()
"ThreadTotalLiveCount: 5; CreatedOutsideScopeCount: 0; DisposedOutsideScopeCount: 0; CreatedInScopeCount: 7; DisposedInScopeCount: 2; AttachedToScopeCount: 0; DetachedFromScopeCount: 4"
```
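
For context, here is a hedged sketch of the kind of code that produces counts like these. The shapes, the lengths, and the `pack_padded_sequence` call with its `batch_first` argument are illustrative assumptions, not taken from the commit.

```csharp
using System;
using TorchSharp;
using static TorchSharp.torch;

using (var scope = NewDisposeScope())
{
    // A padded batch of 2 sequences and their lengths.
    var padded = rand(2, 5, 8);                 // counted in TensorStatistics
    var lengths = tensor(new long[] { 5, 3 });  // counted in TensorStatistics

    // Packing creates a PackedSequence plus a few internal tensors; the internal
    // tensors are detached from the scope and show up in DetachedFromScopeCount.
    var packed = nn.utils.rnn.pack_padded_sequence(padded, lengths, batch_first: true);

    Console.WriteLine(DisposeScopeManager.Statistics.TensorStatistics);
    Console.WriteLine(DisposeScopeManager.Statistics.PackedSequenceStatistics);
}   // disposing the scope also disposes the PackedSequence and its tensors
```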

docfx/articles/memory.md

+11 −9

@@ -10,6 +10,8 @@ In both cases, you may want to experiment with using a smaller batch size -- tem

 Note DiffSharp (which uses TorchSharp) relies on techniques 1.

+Also refer to [Memory Leak Troubleshooting](memory leak troubleshooting.md) for help on fixing any leaks.
+
 > Most of the examples included will use technique #1, doing frequent explicit calls to GC.Collect() in the training code -- if not after each batch in the training loop, at least after each epoch.

 ## Technique 1. Automatic disposal via Garbage Collection
@@ -44,7 +46,7 @@ __Note__: Even with this approach, it is a good idea to place a call to `GC.Coll

 It is important to understand that all TorchSharp "tensors" (type Tensor) are actually "tensor aliases", referring to a C++ tensor. When a C++ tensor is created and returned to .NET as a tensor alias, and the reference count on the C++ tensor is incremented. When you call `Dispose()` on the TorchSharp tensor alias (that is, type Tensor), it is decremented. If the tensor alias is finalized instead, the decrement happens implicitly.

-To enable this technique, all operations that return one or more TorchSharp `Tensor`s should return "fresh" Tensor aliases (though that doesn't always mean freshly copied C++ tensors). This is true even for in-place, destructive operations like `add_()`, which overwrites the underlying native tensor with data, but still returns a fresh tensor alias to that same tensor.
+To enable this technique, all operations that return one or more TorchSharp `Tensor`s should return "fresh" Tensor aliases (though that doesn't always mean freshly copied C++ tensors). This is true even for in-place, destructive operations like `add_()`, which overwrites the underlying native tensor with data, but still returns a fresh tensor alias to that same tensor.

 Thus, when you write methods and functions that take and produce type Tensor, for example in the `forward()` method of a model, you should always make sure to return a fresh alias. Most of the time, this happens automatically, because the last action of your code will normally be to call another tensor function, which itself will be returning a fresh alias, but there are cases when it's not, especially when returning input tensors or tensors stored in some lookaside table.

@@ -55,7 +57,7 @@ Tensor flatten(Tensor input) {
     if (input.shape.Length == 1)
         return input.alias();
     else
-        return input.reshape(input.numel());
+        return input.reshape(input.numel());
 }
 ```

@@ -100,10 +102,10 @@ let myTensorFunction0(input: Tensor) =
     input.alias()

 let myTensorFunction1() =
-    if today then
-        table[4].alias()
+    if today then
+        table[4].alias()
     else
-        table[5].alias()
+        table[5].alias()

 let myTensorFunction2(input: Tensor) =
     input.add(tensor(1))

@@ -124,9 +126,9 @@ let myTensorFunction5(go: bool, input: Tensor) =
         tmp2.add(tensor(1))
     else
         input.alias()
-
+
 let myTensorFunction5(go: bool, input: Tensor) =
-    if go then
+    if go then
         use tmp1 = input.add_(tensor(1)) // NOTE: even for in-place mutations
         use tmp2 = input.add_(tensor(1)) // NOTE: even for in-place mutations
         tmp2.add(tensor(1))

@@ -173,13 +175,13 @@ use d = torch.NewDisposeScope()
     total_acc <- total_acc + (predicted_labels.argmax(1) == labels).sum().cpu().item<long>()
 ```

-If you need to dispose some tensors before the scope is disposed, you can use `DisposeEverything()`, or `DisposeEverythingBut(...)` if you want to exclude a few tensors from disposal. These can be useful when tensor lifetimes aren't cleanly nested in dynamic scopes.
+If you need to dispose some tensors before the scope is disposed, you can use `DisposeEverything()`, or `DisposeEverythingBut(...)` if you want to exclude a few tensors from disposal. These can be useful when tensor lifetimes aren't cleanly nested in dynamic scopes.

 __NOTE: It is absolutely essential for the proper functioning of dynamic dispose scopes that the scope is created with a 'using' statemen (C#) or 'use' expression (F#).__

 It's important to note that these scopes are dynamic -- if any functions are called, the tensors inside them are also registered and disposed, unless there's a nested scope within those functions.

-It is advisable to place a dispose scope around your training and test code, and in any library code that can be called from contexts that do not have dispose scopes.
+It is advisable to place a dispose scope around your training and test code, and in any library code that can be called from contexts that do not have dispose scopes.

 That said, you should use dispose scope very carefully: having _too few_ scope raises the pressure on native memory, which is particularly bad for GPUs. Having too _many_ scopes, managing too few temporaries, will add runtime overhead to computations. For example, it may be better to put a scope outside an inner loop that contains multiple computations than to place it inside the loop. There is no single best answer.
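
To make the `DisposeEverythingBut(...)` remark in the text above concrete, here is a hedged C# sketch. The tensor names and shapes are invented; `NewDisposeScope`, `DisposeEverythingBut`, and `MoveToOuterDisposeScope` are the scope APIs the article refers to, and the enumerable overload used here is the one visible in the DisposeScope.cs diff below.

```csharp
using static TorchSharp.torch;

using var scope = NewDisposeScope();

var a = rand(1024, 1024);
var b = rand(1024, 1024);
var product = a.matmul(b);

// Free the large temporaries now, keeping only the result alive in the scope.
scope.DisposeEverythingBut(new IDisposable[] { product });

// ... more work with `product` ...

// Hand the result to the caller's scope before this one is disposed.
var result = product.MoveToOuterDisposeScope();
```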

src/TorchSharp/DisposeScope.cs

+37 −42

@@ -154,9 +154,13 @@ public void MoveToOther(DisposeScope? scope, IEnumerable<IDisposable> disposable
         {
             if (this._disposeScopeManager is null)
                 throw new ObjectDisposedException(this.GetType().FullName);
-            foreach (var disposable in disposables) {
-                if (Disposables.Remove(disposable)) {
-                    AddToOther(scope, disposable);
+            if (scope == null) {
+                Detach(disposables);
+            } else {
+                foreach (var disposable in disposables) {
+                    if (Disposables.Remove(disposable)) {
+                        AddToOther(scope, disposable);
+                    }
                 }
             }
         }

@@ -209,11 +213,11 @@ public void Detach(IEnumerable<IDisposable> disposables)
                 throw new ObjectDisposedException(this.GetType().FullName);
             foreach (var disposable in disposables) {
                 if (Disposables.Remove(disposable)) {
-                    _disposeScopeManager.StatisticsInstance.DetachedFromScopeCount++;
                     if (disposable is torch.Tensor tensor) {
+                        _disposeScopeManager.StatisticsInstance.TensorStatistics.DetachedFromScopeCount++;
                         tensor.OwningDisposeScope = null;
-                    }
-                    else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
+                    } else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
+                        _disposeScopeManager.StatisticsInstance.PackedSequenceStatistics.DetachedFromScopeCount++;
                         sequence.OwningDisposeScope = null;
                     }
                 }

@@ -237,18 +241,13 @@ public IReadOnlyList<IDisposable> Attach(IEnumerable<IDisposable> disposables)

             var result = new List<IDisposable>();
             foreach (var disposable in disposables) {
-                if (disposable is torch.Tensor tensor) {
-                    if (tensor.OwningDisposeScope == null && !tensor.IsInvalid) {
-                        _disposeScopeManager.StatisticsInstance.DetachedFromScopeCount--;
-                    }
-                }
-                else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
-                    if (sequence.OwningDisposeScope == null && !sequence.IsInvalid) {
-                        _disposeScopeManager.StatisticsInstance.DetachedFromScopeCount--;
+                if (AddToOther(this, disposable)) {
+                    if (disposable is torch.Tensor tensor) {
+                        _disposeScopeManager.StatisticsInstance.TensorStatistics.AttachedToScopeCount++;
+                    } else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
+                        _disposeScopeManager.StatisticsInstance.PackedSequenceStatistics.AttachedToScopeCount++;
                     }
                 }
-
-                AddToOther(this, disposable);
                 result.Add(disposable);
             }

@@ -278,22 +277,6 @@ public void DisposeEverythingBut(IEnumerable<IDisposable> inKeep)
                     continue;
                 }

-                if (disposable is torch.Tensor tensor) {
-                    // No need to have the disposable call back to the scope
-                    tensor.OwningDisposeScope = null;
-                    if (!tensor.IsInvalid) {
-                        _disposeScopeManager.StatisticsInstance.DisposedInScopeCount++;
-                    }
-                } else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
-                    // No need to have the disposable call back to the scope
-                    sequence.OwningDisposeScope = null;
-                    if (!sequence.IsInvalid) {
-                        _disposeScopeManager.StatisticsInstance.DisposedInScopeCount++;
-                    }
-                } else {
-                    _disposeScopeManager.StatisticsInstance.DisposedInScopeCount++;
-                }
-
                 disposable.Dispose();
             }
         }

@@ -369,7 +352,7 @@ public void MarkAsDisposed(IDisposable disposable)
         {
             if (this._disposeScopeManager is null)
                 throw new ObjectDisposedException(this.GetType().FullName);
-            _disposeScopeManager.StatisticsInstance.DisposedInScopeCount++;
+
             Disposables.Remove(disposable);
             if (disposable is torch.Tensor tensor) {
                 tensor.OwningDisposeScope = null;

@@ -386,33 +369,45 @@ public void MarkAsDisposed(IDisposable disposable)
         /// <returns></returns>
         public bool Contains(IDisposable disposable) => Disposables.Contains(disposable);

-        private void AddToOther(DisposeScope? scope, IDisposable disposable)
+        private bool AddToOther(DisposeScope scope, IDisposable disposable)
         {
             if (this._disposeScopeManager is null)
                 throw new ObjectDisposedException(this.GetType().FullName);
-            if (scope != null) {
-                scope.Disposables.Add(disposable);
+
+            DisposeScope? oldScope;
+            if (disposable is torch.Tensor t) {
+                oldScope = t.OwningDisposeScope;
+            } else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
+                oldScope = sequence.OwningDisposeScope;
             } else {
-                _disposeScopeManager.StatisticsInstance.DetachedFromScopeCount++;
+                throw new InvalidOperationException("DisposeScope can only manage Tensor or PackedSequence");
+            }
+
+            if (scope == oldScope) return false;
+
+            scope.Disposables.Add(disposable);
+            if (oldScope != null) {
+                oldScope.Disposables.Remove(disposable);
             }

             if (disposable is torch.Tensor tensor) {
                 tensor.OwningDisposeScope = scope;
-            }
-            else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
+            } else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
                 sequence.OwningDisposeScope = scope;
             }
+
+            return true;
         }

         internal HashSet<IDisposable> DetachAllAndDispose()
         {
             var disposables = this.Disposables;
             foreach (var disposable in this.Disposables) {
-                this._disposeScopeManager!.StatisticsInstance.DetachedFromScopeCount++;
                 if (disposable is torch.Tensor tensor) {
+                    this._disposeScopeManager!.StatisticsInstance.TensorStatistics.DetachedFromScopeCount++;
                     tensor.OwningDisposeScope = null;
-                }
-                else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
+                } else if (disposable is torch.nn.utils.rnn.PackedSequence sequence) {
+                    this._disposeScopeManager!.StatisticsInstance.PackedSequenceStatistics.DetachedFromScopeCount++;
                     sequence.OwningDisposeScope = null;
                 }
             }
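
As a usage-level illustration of the counters this commit reworks, here is a hedged sketch. The tensor shapes are invented and the comments describe expected counter movements rather than values from a real run; the `Attach` overload taking an enumerable is the one visible in the diff above.

```csharp
using System;
using TorchSharp;
using static TorchSharp.torch;

DisposeScopeManager.Statistics.Reset();

var stray = ones(3);                              // created outside any scope
using (var scope = NewDisposeScope())
{
    var inScope = zeros(3);                       // CreatedInScopeCount increments
    scope.Attach(new IDisposable[] { stray });    // TensorStatistics.AttachedToScopeCount increments
}                                                 // both tensors disposed with the scope

Console.WriteLine(DisposeScopeManager.Statistics.TensorStatistics);
```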