|
1 | 1 | // Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
|
2 | 2 |
|
| 3 | +using System; |
3 | 4 | using System.Collections.Generic;
|
| 5 | +using System.Linq.Expressions; |
| 6 | +using System.Reflection; |
4 | 7 | using Xunit;
|
5 | 8 |
|
6 | 9 | using static TorchSharp.torch;
|
@@ -233,6 +236,110 @@ public void UtilsVtoP()
|
233 | 236 | Assert.Equal(data, data1);
|
234 | 237 | }
|
235 | 238 |
|
| 239 | + [Fact] |
| 240 | + public void UtilsFusion() |
| 241 | + { |
| 242 | + static void SetRandomParameter<T>( |
| 243 | + T module, |
| 244 | + Expression<Func<T, Modules.Parameter>> parameterProperty) |
| 245 | + { |
| 246 | + var propertyExpression = (MemberExpression)parameterProperty.Body; |
| 247 | + var property = (PropertyInfo)propertyExpression.Member; |
| 248 | + var parameter = (Modules.Parameter)property.GetValue(module)!; |
| 249 | + var randomTensor = rand_like( |
| 250 | + parameter, |
| 251 | + parameter.dtype, |
| 252 | + parameter.device) * 100; |
| 253 | + var newParameter = new Modules.Parameter(randomTensor, parameter.requires_grad); |
| 254 | + property.SetValue(module, newParameter); |
| 255 | + } |
| 256 | + |
| 257 | + static void SetRandomTensor<T>( |
| 258 | + T module, |
| 259 | + Expression<Func<T, Tensor>> tensorProperty) |
| 260 | + { |
| 261 | + var propertyExpression = (MemberExpression)tensorProperty.Body; |
| 262 | + var property = (PropertyInfo)propertyExpression.Member; |
| 263 | + var tensor = (Tensor)property.GetValue(module)!; |
| 264 | + var newTensor = rand_like( |
| 265 | + tensor, |
| 266 | + tensor.dtype, |
| 267 | + tensor.device, |
| 268 | + tensor.requires_grad) * 100; |
| 269 | + property.SetValue(module, newTensor); |
| 270 | + } |
| 271 | + |
| 272 | + static void AssertRelativelyEqual( |
| 273 | + Tensor expected, Tensor actual, double tolerance = 1e-5) |
| 274 | + { |
| 275 | + Assert.Equal(expected.size(), actual.size()); |
| 276 | + var difference = (expected - actual) / expected; |
| 277 | + var maxDifference = (double)difference.abs().max(); |
| 278 | + Assert.InRange(maxDifference, -tolerance, tolerance); |
| 279 | + } |
| 280 | + |
| 281 | + { |
| 282 | + // linear |
| 283 | + var x = rand(new long[] { 20, 20 }) * 100; |
| 284 | + |
| 285 | + var linear = nn.Linear(20, 5); |
| 286 | + linear.eval(); |
| 287 | + SetRandomParameter(linear, x => x.weight!); |
| 288 | + SetRandomParameter(linear, x => x.bias!); |
| 289 | + |
| 290 | + var batchNorm1d = nn.BatchNorm1d(5, eps: 1); |
| 291 | + batchNorm1d.eval(); |
| 292 | + SetRandomParameter(batchNorm1d, x => x.weight!); |
| 293 | + SetRandomParameter(batchNorm1d, x => x.bias!); |
| 294 | + SetRandomTensor(batchNorm1d, x => x.running_mean!); |
| 295 | + SetRandomTensor(batchNorm1d, x => x.running_var!); |
| 296 | + |
| 297 | + (var weight, var bias) = nn.utils.fuse_linear_bn_weights( |
| 298 | + linear.weight!, linear.bias, |
| 299 | + batchNorm1d.running_mean!, batchNorm1d.running_var!, |
| 300 | + bn_eps: 1, batchNorm1d.weight!, batchNorm1d.bias!); |
| 301 | + |
| 302 | + var newLinear = nn.Linear(20, 5); |
| 303 | + newLinear.eval(); |
| 304 | + newLinear.weight = weight; |
| 305 | + newLinear.bias = bias; |
| 306 | + |
| 307 | + AssertRelativelyEqual( |
| 308 | + batchNorm1d.call(linear.call(x)), |
| 309 | + newLinear.call(x)); |
| 310 | + } |
| 311 | + |
| 312 | + { |
| 313 | + // conv |
| 314 | + var x = rand(new long[] { 20, 20, 20, 20 }) * 100; |
| 315 | + var conv = nn.Conv2d(20, 5, 3); |
| 316 | + conv.eval(); |
| 317 | + SetRandomParameter(conv, x => x.weight!); |
| 318 | + SetRandomParameter(conv, x => x.bias!); |
| 319 | + |
| 320 | + var batchNorm2d = nn.BatchNorm2d(5, eps: 13); |
| 321 | + batchNorm2d.eval(); |
| 322 | + SetRandomParameter(batchNorm2d, x => x.weight!); |
| 323 | + SetRandomParameter(batchNorm2d, x => x.bias!); |
| 324 | + SetRandomTensor(batchNorm2d, x => x.running_mean!); |
| 325 | + SetRandomTensor(batchNorm2d, x => x.running_var!); |
| 326 | + |
| 327 | + (var weight, var bias) = nn.utils.fuse_conv_bn_weights( |
| 328 | + conv.weight!, conv.bias, |
| 329 | + batchNorm2d.running_mean!, batchNorm2d.running_var!, |
| 330 | + bn_eps: 13, batchNorm2d.weight!, batchNorm2d.bias!); |
| 331 | + |
| 332 | + var newConv = nn.Conv2d(20, 5, 3); |
| 333 | + newConv.eval(); |
| 334 | + newConv.weight = weight; |
| 335 | + newConv.bias = bias; |
| 336 | + |
| 337 | + AssertRelativelyEqual( |
| 338 | + batchNorm2d.call(conv.call(x)), |
| 339 | + newConv.call(x)); |
| 340 | + } |
| 341 | + } |
| 342 | + |
236 | 343 | [Fact(Skip = "Intermittently fails")]
|
237 | 344 | public void AllowTF32()
|
238 | 345 | {
|
|
0 commit comments