Skip to content

Commit 6bb7703

Browse files
Merge branch 'main' into unit
2 parents 2d2fcfc + b2bb7e8 commit 6bb7703

File tree

7 files changed

+72
-7
lines changed

7 files changed

+72
-7
lines changed

RELEASENOTES.md

+4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ __Issues fixed__:
2525
#1400 There may be an error in torchvision.transforms.GaussianBlur<br/>
2626
#1402 diagonal() has incorrect default<br/>
2727

28+
__API Changes__:
29+
30+
#1382: Add support for torch.nn.functional.normalize<br/>
31+
2832
# NuGet Version 0.103.1
2933

3034
__Breaking Changes__:

TorchSharp.sln

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "TorchSharp", "TorchSharp",
3636
EndProject
3737
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LibTorchSharp", "bin\obj\x64.Debug\Native\LibTorchSharp\LibTorchSharp.vcxproj", "{CAD9DB7F-3223-3324-884D-FA2381593DA7}"
3838
EndProject
39-
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LibTorchSharp", "bin\obj\x64.Release\Native\LibTorchSharp\LibTorchSharp.vcxproj", "{E4C0DBEE-0815-311B-9065-137BB50BD793}"
39+
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LibTorchSharp", "bin\obj\x64.Release\Native\LibTorchSharp\LibTorchSharp.vcxproj", "{BB811429-0DF1-3D22-B664-09C2F5A9E0AB}"
4040
EndProject
4141
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Native-Debug", "Native-Debug", "{CF2C1A9E-3A8A-4329-8A6E-7880C15AAC3D}"
4242
ProjectSection(SolutionItems) = preProject
@@ -181,7 +181,7 @@ Global
181181
{42B45168-476D-4BFA-87B8-81A34E6295CD} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
182182
{567456AD-B026-4CB6-B98D-4FC930C90223} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
183183
{CAD9DB7F-3223-3324-884D-FA2381593DA7} = {CF2C1A9E-3A8A-4329-8A6E-7880C15AAC3D}
184-
{E4C0DBEE-0815-311B-9065-137BB50BD793} = {4DB9E84D-324C-408F-87A6-246E86205540}
184+
{BB811429-0DF1-3D22-B664-09C2F5A9E0AB} = {4DB9E84D-324C-408F-87A6-246E86205540}
185185
{CF2C1A9E-3A8A-4329-8A6E-7880C15AAC3D} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
186186
{D8C60CD8-8429-45F2-A755-47B6CD10FDF8} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
187187
{4DB9E84D-324C-408F-87A6-246E86205540} = {CF2C1A9E-3A8A-4329-8A6E-7880C15AAC3D}

src/Native/LibTorchSharp/THSNN.h

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ EXPORT_API(NNModule) THSNN_custom_module(const char* name, Tensor(*forward)(Tens
3939

4040
// Normalization
4141

42+
EXPORT_API(Tensor) THSNN_normalize(const Tensor input, const double p, const int64_t dim, const double eps);
4243
EXPORT_API(Tensor) THSNN_batch_norm(const Tensor input, const Tensor running_mean, const Tensor running_var, const Tensor weight, const Tensor bias, const bool training, const double momentum, const double eps);
4344
EXPORT_API(Tensor) THSNN_group_norm(const Tensor input, int64_t num_groups, const Tensor weight, const Tensor bias, const double eps);
4445
EXPORT_API(Tensor) THSNN_instance_norm(const Tensor input, const Tensor running_mean, const Tensor running_var, const Tensor weight, const Tensor bias, const bool use_input_stats, const double momentum, const double eps);

src/Native/LibTorchSharp/THSNormalization.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@ Tensor THSNN_batch_norm(const Tensor input, Tensor running_mean, const Tensor ru
1313
CATCH_TENSOR(torch::batch_norm(*input, w, b, rm, rv, training, momentum, eps, false));
1414
}
1515

16+
Tensor THSNN_normalize(const Tensor input, const double p, const int64_t dim, const double eps)
17+
{
18+
auto opts = torch::nn::functional::NormalizeFuncOptions()
19+
.p(p)
20+
.dim(dim)
21+
.eps(eps);
22+
CATCH_TENSOR(torch::nn::functional::normalize(*input, opts));
23+
}
24+
1625
Tensor THSNN_group_norm(const Tensor input, const int64_t num_groups, const Tensor weight, const Tensor bias, const double eps)
1726
{
1827
auto opts = torch::nn::functional::GroupNormFuncOptions(num_groups)

src/TorchSharp/NN/Normalization/Functional.cs

+22-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System;
33
using static TorchSharp.PInvoke.NativeMethods;
44

5+
#nullable enable
56
namespace TorchSharp
67
{
78
public static partial class torch
@@ -10,10 +11,27 @@ public static partial class nn
1011
{
1112
public static partial class functional
1213
{
14+
/// <summary>
15+
/// Perform normalization of inputs over specified dimension.
16+
/// </summary>
17+
/// <param name="input">Input tensor of any shape.</param>
18+
/// <param name="p">the exponent value in the norm formulation</param>
19+
/// <param name="dim">the dimension to reduce</param>
20+
/// <param name="eps">small value to avoid division by zero</param>
21+
public static Tensor normalize(Tensor input, double p = 2.0, long dim = 1L, double eps = 1e-12)
22+
{
23+
var res = THSNN_normalize(
24+
input.Handle,
25+
p, dim, eps);
26+
if (res == IntPtr.Zero)
27+
torch.CheckForErrors();
28+
return new Tensor(res);
29+
}
30+
1331
/// <summary>
1432
/// Applies Batch Normalization for each channel across a batch of data.
1533
/// </summary>
16-
public static Tensor batch_norm(Tensor input, Tensor running_mean, Tensor running_var, Tensor weight = null, Tensor bias = null, bool training = false, double momentum = 0.1, double eps = 1e-5)
34+
public static Tensor batch_norm(Tensor input, Tensor running_mean, Tensor running_var, Tensor? weight = null, Tensor? bias = null, bool training = false, double momentum = 0.1, double eps = 1e-5)
1735
{
1836
var res = THSNN_batch_norm(
1937
input.Handle,
@@ -31,7 +49,7 @@ public static Tensor batch_norm(Tensor input, Tensor running_mean, Tensor runnin
3149
/// <summary>
3250
/// Applies Group Normalization for last certain number of dimensions.
3351
/// </summary>
34-
public static Tensor group_norm(Tensor input, long num_groups, Tensor weight = null, Tensor bias = null, double eps = 1e-5)
52+
public static Tensor group_norm(Tensor input, long num_groups, Tensor? weight = null, Tensor? bias = null, double eps = 1e-5)
3553
{
3654
var res = THSNN_group_norm(
3755
input.Handle,
@@ -47,7 +65,7 @@ public static Tensor group_norm(Tensor input, long num_groups, Tensor weight = n
4765
/// <summary>
4866
/// Applies Instance Normalization for each channel in each data sample in a batch.
4967
/// </summary>
50-
public static Tensor instance_norm(Tensor input, Tensor running_mean = null, Tensor running_var = null, Tensor weight = null, Tensor bias = null, bool use_input_stats = true, double momentum = 0.1, double eps = 1e-5)
68+
public static Tensor instance_norm(Tensor input, Tensor? running_mean = null, Tensor? running_var = null, Tensor? weight = null, Tensor? bias = null, bool use_input_stats = true, double momentum = 0.1, double eps = 1e-5)
5169
{
5270
var res = THSNN_instance_norm(
5371
input.Handle,
@@ -65,7 +83,7 @@ public static Tensor instance_norm(Tensor input, Tensor running_mean = null, Ten
6583
/// <summary>
6684
/// Applies Layer Normalization for last certain number of dimensions.
6785
/// </summary>
68-
public static Tensor layer_norm(Tensor input, long[] normalized_shape, Tensor weight = null, Tensor bias = null, double eps = 1e-5)
86+
public static Tensor layer_norm(Tensor input, long[] normalized_shape, Tensor? weight = null, Tensor? bias = null, double eps = 1e-5)
6987
{
7088
IntPtr res;
7189
unsafe {

src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs

+3
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,9 @@ internal static extern IntPtr THSNN_custom_module(
512512
[DllImport("LibTorchSharp")]
513513
internal static extern IntPtr THSNN_Unflatten_ctor(long dim, IntPtr shape, long shape_len, out IntPtr pBoxedModule);
514514

515+
[DllImport("LibTorchSharp")]
516+
internal static extern IntPtr THSNN_normalize(IntPtr input, double p, long dim, double eps);
517+
515518
[DllImport("LibTorchSharp")]
516519
internal static extern IntPtr THSNN_batch_norm(IntPtr input, IntPtr running_mean, IntPtr running_var, IntPtr weight, IntPtr bias, [MarshalAs(UnmanagedType.U1)] bool training, double momentum, double eps);
517520

test/TorchSharpTest/NN.cs

+31-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
1+
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
22
using System;
33
using System.Linq;
44
using System.Runtime.InteropServices;
@@ -5024,6 +5024,36 @@ private Tensor NormalizeTensor(Tensor x, Tensor x_mean, Tensor x_var, double eps
50245024
return (x - x_mean) / torch.sqrt(eps + x_var);
50255025
}
50265026

5027+
[Fact]
5028+
public void TestNormalizeFunc()
5029+
{
5030+
foreach (var device in TestUtils.AvailableDevices()) {
5031+
var x = torch.from_array(new double[]
5032+
{ -1.0786, 0.3455, 1.2929, 0.5030,
5033+
-0.2930, 1.0420, -0.1082, -0.2943,
5034+
-0.3989, -0.8311, 0.7103, -1.5878,
5035+
0.6331, 1.0106, 0.5128, -2.2565,
5036+
1.2044, -0.6916, -0.1242, 0.6808,
5037+
0.1672, 0.1105, -1.7364, 0.0669
5038+
}).reshape(3,2,4);
5039+
var y = torch.nn.functional.normalize(x);
5040+
Assert.Equal(x.shape, y.shape);
5041+
Assert.Equal(x.device_type, y.device_type);
5042+
5043+
var expected = torch.from_array(new double[]
5044+
{ -0.9650, 0.3147, 0.9965, 0.8631,
5045+
-0.2621, 0.9492, -0.0834, -0.5050,
5046+
-0.5331, -0.6352, 0.8108, -0.5755,
5047+
0.8460, 0.7724, 0.5853, -0.8178,
5048+
0.9905, -0.9875, -0.0713, 0.9952,
5049+
0.1375, 0.1577, -0.9975, 0.0978
5050+
}).reshape(3, 2, 4);
5051+
5052+
5053+
Assert.True(y.allclose(expected, rtol: 0.005, atol: 0.005));
5054+
}
5055+
}
5056+
50275057
[Fact]
50285058
public void TestBatchNormFunc()
50295059
{

0 commit comments

Comments
 (0)