
Commit be61da6

Merge branch 'main' of https://github.com/dotnet/TorchSharp
2 parents: 42c63a6 + cd11cfe

3 files changed: +188 -0

RELEASENOTES.md (+6 lines)

```diff
@@ -2,6 +2,12 @@
 
 Releases, starting with 9/2/2021, are listed with the most recent release at the top.
 
+# NuGet Version 0.102.3
+
+__API Changes__:
+
+#1243 `fuse_conv_bn_weights` and `fuse_linear_bn_weights` are added.<br/>
+
 # NuGet Version 0.102.2
 
 __Bug Fixes__:
```
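Both new helpers fold an eval-mode BatchNorm into the preceding layer. In eval mode, BatchNorm applies a fixed affine transform built from its running mean μ (`bn_rm`), running variance σ² (`bn_rv`), affine weight γ (`bn_w`), bias β (`bn_b`), and epsilon ε (`bn_eps`); applied after a layer with weight W and bias b, the pair collapses into a single weight/bias pair:

```math
W' = W \cdot \frac{\gamma}{\sqrt{\sigma^{2} + \epsilon}}, \qquad
b' = (b - \mu) \cdot \frac{\gamma}{\sqrt{\sigma^{2} + \epsilon}} + \beta
```

In the convolutional case the scale factor is broadcast along the output-channel dimension of `W` (the input-channel dimension for transposed convolutions), which is what the `reshape(shape)` in the implementation below does.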

src/TorchSharp/Torch.cs (+75 lines)

```diff
@@ -9,6 +9,7 @@
 using System.Runtime.InteropServices;
 using System.Text;
 using System.Text.RegularExpressions;
+using TorchSharp.Modules;
 using TorchSharp.PInvoke;
 using static TorchSharp.PInvoke.NativeMethods;
 
@@ -415,6 +416,80 @@ public static void vector_to_parameters(Tensor vec, IEnumerable<Modules.Parameter>
                 CheckForErrors();
             }
         }
+
+        /// <summary>
+        /// Fuse convolutional module parameters and BatchNorm module parameters into new convolutional module parameters.
+        /// </summary>
+        /// <param name="conv_w">Convolutional weight.</param>
+        /// <param name="conv_b">Convolutional bias.</param>
+        /// <param name="bn_rm">BatchNorm running mean.</param>
+        /// <param name="bn_rv">BatchNorm running variance.</param>
+        /// <param name="bn_eps">BatchNorm epsilon.</param>
+        /// <param name="bn_w">BatchNorm weight.</param>
+        /// <param name="bn_b">BatchNorm bias.</param>
+        /// <param name="transpose">If <c>true</c>, transpose the conv weight. Defaults to <c>false</c>.</param>
+        /// <returns>Fused convolutional weight and bias.</returns>
+        public static (Parameter weight, Parameter bias) fuse_conv_bn_weights(
+            Tensor conv_w, Tensor? conv_b,
+            Tensor bn_rm, Tensor bn_rv, double bn_eps,
+            Tensor? bn_w, Tensor? bn_b,
+            bool transpose = false)
+        {
+            using var scope = NewDisposeScope();
+
+            var conv_weight_dtype = conv_w.dtype;
+            var conv_bias_dtype = conv_b?.dtype ?? conv_weight_dtype;
+            conv_b ??= zeros_like(bn_rm);
+            bn_w ??= ones_like(bn_rm);
+            bn_b ??= zeros_like(bn_rm);
+            var shape = conv_w.shape.Select(_ => 1L).ToArray();
+            if (transpose)
+                shape[1] = -1;
+            else
+                shape[0] = -1;
+
+            var bn_var_rsqrt = rsqrt(bn_rv + bn_eps);
+            var fused_conv_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(shape))
+                .to(conv_weight_dtype);
+            var fused_conv_b = ((conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b)
+                .to(conv_bias_dtype);
+
+            var weight = new Parameter(fused_conv_w, conv_w.requires_grad);
+            var bias = new Parameter(fused_conv_b, conv_b.requires_grad);
+
+            return scope.MoveToOuter(weight, bias);
+        }
+
+        /// <summary>
+        /// Fuse linear module parameters and BatchNorm module parameters into new linear module parameters.
+        /// </summary>
+        /// <param name="linear_w">Linear weight.</param>
+        /// <param name="linear_b">Linear bias.</param>
+        /// <param name="bn_rm">BatchNorm running mean.</param>
+        /// <param name="bn_rv">BatchNorm running variance.</param>
+        /// <param name="bn_eps">BatchNorm epsilon.</param>
+        /// <param name="bn_w">BatchNorm weight.</param>
+        /// <param name="bn_b">BatchNorm bias.</param>
+        /// <returns>Fused linear weight and bias.</returns>
+        public static (Parameter weight, Parameter bias) fuse_linear_bn_weights(
+            Tensor linear_w, Tensor? linear_b,
+            Tensor bn_rm, Tensor bn_rv, double bn_eps,
+            Tensor bn_w, Tensor bn_b)
+        {
+            using var scope = NewDisposeScope();
+
+            linear_b ??= zeros_like(bn_rm);
+
+            var bn_scale = bn_w * rsqrt(bn_rv + bn_eps);
+
+            var fused_w = linear_w * bn_scale.unsqueeze(-1);
+            var fused_b = (linear_b - bn_rm) * bn_scale + bn_b;
+
+            var weight = new Parameter(fused_w, linear_w.requires_grad);
+            var bias = new Parameter(fused_b, linear_b.requires_grad);
+
+            return scope.MoveToOuter(weight, bias);
+        }
     }
 }
```
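As a quick orientation, here is a minimal usage sketch (not part of the commit; the channel counts, kernel size, and epsilon are illustrative) that folds an eval-mode `BatchNorm2d` into the preceding `Conv2d`, reaching the helper as `nn.utils.fuse_conv_bn_weights` just as the test below does:

```csharp
using TorchSharp;
using static TorchSharp.torch;

// Build a conv + BN pair and put both in eval mode so the BN
// running statistics (rather than batch statistics) are used.
var conv = nn.Conv2d(3, 16, 3);
var bn = nn.BatchNorm2d(16);
conv.eval();
bn.eval();

// Fold the BatchNorm into the convolution's weight and bias.
var (weight, bias) = nn.utils.fuse_conv_bn_weights(
    conv.weight!, conv.bias,
    bn.running_mean!, bn.running_var!, bn_eps: 1e-5,
    bn.weight, bn.bias);

// A single Conv2d carrying the fused parameters replaces the pair.
var fused = nn.Conv2d(3, 16, 3);
fused.weight = weight;
fused.bias = bias;

// For any eval-mode input x: fused.call(x) ≈ bn.call(conv.call(x)).
```

The same pattern applies to `fuse_linear_bn_weights` for a `Linear` + `BatchNorm1d` pair.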

test/TorchSharpTest/TestTorchSharp.cs (+107 lines)

```diff
@@ -1,6 +1,9 @@
 // Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
 
+using System;
 using System.Collections.Generic;
+using System.Linq.Expressions;
+using System.Reflection;
 using Xunit;
 
 using static TorchSharp.torch;
@@ -233,6 +236,110 @@ public void UtilsVtoP()
             Assert.Equal(data, data1);
         }
 
+        [Fact]
+        public void UtilsFusion()
+        {
+            static void SetRandomParameter<T>(
+                T module,
+                Expression<Func<T, Modules.Parameter>> parameterProperty)
+            {
+                var propertyExpression = (MemberExpression)parameterProperty.Body;
+                var property = (PropertyInfo)propertyExpression.Member;
+                var parameter = (Modules.Parameter)property.GetValue(module)!;
+                var randomTensor = rand_like(
+                    parameter,
+                    parameter.dtype,
+                    parameter.device) * 100;
+                var newParameter = new Modules.Parameter(randomTensor, parameter.requires_grad);
+                property.SetValue(module, newParameter);
+            }
+
+            static void SetRandomTensor<T>(
+                T module,
+                Expression<Func<T, Tensor>> tensorProperty)
+            {
+                var propertyExpression = (MemberExpression)tensorProperty.Body;
+                var property = (PropertyInfo)propertyExpression.Member;
+                var tensor = (Tensor)property.GetValue(module)!;
+                var newTensor = rand_like(
+                    tensor,
+                    tensor.dtype,
+                    tensor.device,
+                    tensor.requires_grad) * 100;
+                property.SetValue(module, newTensor);
+            }
+
+            static void AssertRelativelyEqual(
+                Tensor expected, Tensor actual, double tolerance = 1e-5)
+            {
+                Assert.Equal(expected.size(), actual.size());
+                var difference = (expected - actual) / expected;
+                var maxDifference = (double)difference.abs().max();
+                Assert.InRange(maxDifference, -tolerance, tolerance);
+            }
+
+            {
+                // linear
+                var x = rand(new long[] { 20, 20 }) * 100;
+
+                var linear = nn.Linear(20, 5);
+                linear.eval();
+                SetRandomParameter(linear, x => x.weight!);
+                SetRandomParameter(linear, x => x.bias!);
+
+                var batchNorm1d = nn.BatchNorm1d(5, eps: 1);
+                batchNorm1d.eval();
+                SetRandomParameter(batchNorm1d, x => x.weight!);
+                SetRandomParameter(batchNorm1d, x => x.bias!);
+                SetRandomTensor(batchNorm1d, x => x.running_mean!);
+                SetRandomTensor(batchNorm1d, x => x.running_var!);
+
+                (var weight, var bias) = nn.utils.fuse_linear_bn_weights(
+                    linear.weight!, linear.bias,
+                    batchNorm1d.running_mean!, batchNorm1d.running_var!,
+                    bn_eps: 1, batchNorm1d.weight!, batchNorm1d.bias!);
+
+                var newLinear = nn.Linear(20, 5);
+                newLinear.eval();
+                newLinear.weight = weight;
+                newLinear.bias = bias;
+
+                AssertRelativelyEqual(
+                    batchNorm1d.call(linear.call(x)),
+                    newLinear.call(x));
+            }
+
+            {
+                // conv
+                var x = rand(new long[] { 20, 20, 20, 20 }) * 100;
+                var conv = nn.Conv2d(20, 5, 3);
+                conv.eval();
+                SetRandomParameter(conv, x => x.weight!);
+                SetRandomParameter(conv, x => x.bias!);
+
+                var batchNorm2d = nn.BatchNorm2d(5, eps: 13);
+                batchNorm2d.eval();
+                SetRandomParameter(batchNorm2d, x => x.weight!);
+                SetRandomParameter(batchNorm2d, x => x.bias!);
+                SetRandomTensor(batchNorm2d, x => x.running_mean!);
+                SetRandomTensor(batchNorm2d, x => x.running_var!);
+
+                (var weight, var bias) = nn.utils.fuse_conv_bn_weights(
+                    conv.weight!, conv.bias,
+                    batchNorm2d.running_mean!, batchNorm2d.running_var!,
+                    bn_eps: 13, batchNorm2d.weight!, batchNorm2d.bias!);
+
+                var newConv = nn.Conv2d(20, 5, 3);
+                newConv.eval();
+                newConv.weight = weight;
+                newConv.bias = bias;
+
+                AssertRelativelyEqual(
+                    batchNorm2d.call(conv.call(x)),
+                    newConv.call(x));
+            }
+        }
+
         [Fact(Skip = "Intermittently fails")]
         public void AllowTF32()
         {
```
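Both test blocks run the modules in eval mode, randomize every parameter and running statistic with values scaled by 100, and pick deliberately non-default epsilons (1 for the linear case, 13 for the conv case), so a fusion that dropped or misplaced the epsilon or variance term would miss the 1e-5 relative tolerance by a wide margin.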
