Skip to content

Commit 53ff2cf

Browse files
ts version update (dotnet#6419)
1 parent 61b1fa5 commit 53ff2cf

29 files changed

+260
-184
lines changed

docs/samples/Microsoft.ML.AutoML.Samples/Microsoft.ML.AutoML.Samples.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
<OutputType>Exe</OutputType>
55
<TargetFramework>netcoreapp3.1</TargetFramework>
66
<CopyLocalLockFileAssemblies>false</CopyLocalLockFileAssemblies>
7+
<NoWarn>$(NoWarn);MSB3270</NoWarn>
78
</PropertyGroup>
89

910
<ItemGroup>

eng/Versions.props

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
<TensorflowDotNETVersion>0.20.1</TensorflowDotNETVersion>
4949
<TensorFlowMajorVersion>2</TensorFlowMajorVersion>
5050
<TensorFlowVersion>2.3.1</TensorFlowVersion>
51-
<TorchSharpVersion>0.96.7</TorchSharpVersion>
51+
<TorchSharpVersion>0.98.1</TorchSharpVersion>
5252
<LibTorchVersion>1.11.0.1</LibTorchVersion>
5353
<!-- Build/infrastructure Dependencies -->
5454
<CodecovVersion>1.12.4</CodecovVersion>

src/Microsoft.ML.AutoML.Interactive/Microsoft.ML.AutoML.Interactive.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
<PropertyGroup>
44
<TargetFramework>net6.0</TargetFramework>
55
<IsPackable>false</IsPackable>
6+
<NoWarn>$(NoWarn);MSB3270</NoWarn>
67
</PropertyGroup>
78

89
<ItemGroup>

src/Microsoft.ML.AutoML/Microsoft.ML.AutoML.csproj

+4-4
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
<PackageDescription>ML.NET AutoML: Optimizes an ML pipeline for your dataset, by automatically locating the best feature engineering, model, and hyperparameters</PackageDescription>
88
<TargetsForTfmSpecificBuildOutput>$(TargetsForTfmSpecificBuildOutput);CopyProjectReferencesToPackage</TargetsForTfmSpecificBuildOutput>
99

10-
<!--
10+
<!--
1111
1591: Documentation warnings
1212
NU5100: Warning that gets triggered because a .dll is not placed under lib folder on package. This is by design as we want AutoML Interactive to be under interactive-extensions folder.
1313
-->
14-
<NoWarn>$(NoWarn);1591;NU5100</NoWarn>
14+
<NoWarn>$(NoWarn);1591;NU5100;MSB3270</NoWarn>
1515
<TargetsForTfmSpecificContentInPackage>$(TargetsForTfmSpecificContentInPackage);AddAutoMLInteractiveToInteractiveExtensionsFolder</TargetsForTfmSpecificContentInPackage>
16-
16+
1717
</PropertyGroup>
1818

1919
<!-- The following properties are set to package AutoML Interactive with the AutoML nuget package. If AutoML Interactive undergoes TFM or dependency changes, we need to update the TargetFramework passed in below-->
@@ -28,7 +28,7 @@
2828
<TfmSpecificPackageFile Include="@(_ItemsToIncludeForInteractive)" />
2929
</ItemGroup>
3030
</Target>
31-
31+
3232
<ItemGroup>
3333
<ProjectReference Include="..\..\tools-local\Microsoft.ML.AutoML.SourceGenerator\Microsoft.ML.AutoML.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
3434
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj">

src/Microsoft.ML.CodeGenerator/Microsoft.ML.CodeGenerator.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
<TargetFramework>netstandard2.0</TargetFramework>
66
<IncludeInPackage>Microsoft.ML.CodeGenerator</IncludeInPackage>
77
<PackageDescription>ML.NET Code Generator</PackageDescription>
8+
<NoWarn>$(NoWarn);MSB3270</NoWarn>
89
</PropertyGroup>
910

1011
<ItemGroup>

src/Microsoft.ML.TorchSharp/Microsoft.ML.TorchSharp.csproj

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33

44
<PropertyGroup>
55
<TargetFramework>netstandard2.0</TargetFramework>
6-
<NoWarn>$(NoWarn);CS8002</NoWarn>
6+
<NoWarn>$(NoWarn);CS8002;MSB3270</NoWarn>
77
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
88
<PackageDescription>Microsoft.ML.TorchSharp contains ML.NET integration of TorchSharp.</PackageDescription>
9+
<PlatformTarget>AnyCPU</PlatformTarget>
910
</PropertyGroup>
1011

1112
<ItemGroup>

src/Microsoft.ML.TorchSharp/NasBert/BaseModule.cs

-21
This file was deleted.

src/Microsoft.ML.TorchSharp/NasBert/Models/BaseHead.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
namespace Microsoft.ML.TorchSharp.NasBert.Models
1111
{
12-
internal abstract class BaseHead : BaseModule
12+
internal abstract class BaseHead : torch.nn.Module
1313
{
1414
protected BaseHead(string name) : base(name)
1515
{

src/Microsoft.ML.TorchSharp/NasBert/Models/BaseModel.cs

+8-43
Original file line numberDiff line numberDiff line change
@@ -11,68 +11,33 @@
1111

1212
namespace Microsoft.ML.TorchSharp.NasBert.Models
1313
{
14-
internal abstract class BaseModel : BaseModule
14+
internal abstract class BaseModel : torch.nn.Module<torch.Tensor, torch.Tensor, torch.Tensor>
1515
{
1616
protected readonly NasBertTrainer.Options Options;
1717
public BertTaskType HeadType => Options.TaskType;
1818

19-
protected readonly TransformerEncoder Encoder;
19+
//public ModelType EncoderType => Options.ModelType;
2020

2121
#pragma warning disable CA1024 // Use properties where appropriate: Modules should be fields in TorchSharp
22-
public TransformerEncoder GetEncoder() => Encoder;
22+
public abstract TransformerEncoder GetEncoder();
2323

2424
public abstract BaseHead GetHead();
2525
#pragma warning restore CA1024 // Use properties where appropriate
2626

27-
protected BaseModel(NasBertTrainer.Options options, int padIndex, int symbolsCount)
27+
protected BaseModel(TextClassificationTrainer.Options options)
2828
: base(nameof(BaseModel))
2929
{
3030
Options = options ?? throw new ArgumentNullException(nameof(options));
31-
Encoder = new TransformerEncoder(
32-
paddingIdx: padIndex,
33-
vocabSize: symbolsCount,
34-
dropout: Options.Dropout,
35-
attentionDropout: Options.AttentionDropout,
36-
activationDropout: Options.ActivationDropout,
37-
activationFn: Options.ActivationFunction,
38-
dynamicDropout: Options.DynamicDropout,
39-
maxSeqLen: Options.MaxSequenceLength,
40-
embedSize: Options.EmbeddingDim,
41-
arches: Options.Arches?.ToList(),
42-
numSegments: 0,
43-
encoderNormalizeBefore: Options.EncoderNormalizeBefore,
44-
numEncoderLayers: Options.EncoderLayers,
45-
applyBertInit: true,
46-
freezeTransfer: Options.FreezeTransfer);
47-
}
48-
49-
protected void Initialize()
50-
{
51-
if (Options.FreezeEncoder)
52-
{
53-
ModelUtils.FreezeModuleParams(Encoder);
54-
}
5531
}
5632

5733
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
58-
public new abstract torch.Tensor forward(torch.Tensor srcTokens, torch.Tensor tokenMask = null);
59-
60-
/// <summary>
61-
/// Run only Encoder and return features.
62-
/// </summary>
63-
protected torch.Tensor ExtractFeatures(torch.Tensor srcTokens)
64-
{
65-
return Encoder.forward(srcTokens, null, null);
66-
}
34+
public override torch.Tensor forward(torch.Tensor srcTokens, torch.Tensor tokenMask = null)
35+
=> throw new NotImplementedException();
6736

6837
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
69-
public override void train()
38+
public override void train(bool train = true)
7039
{
71-
base.train();
72-
if (!Options.LayerNormTraining)
73-
{
74-
Encoder.CloseLayerNormTraining();
75-
}
40+
base.train(train);
7641
}
7742
}
7843
}

src/Microsoft.ML.TorchSharp/NasBert/Models/NasBertModel.cs

+54-4
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,85 @@
44

55
using System;
66
using System.Collections.Generic;
7+
using System.Linq;
78
using System.Text;
9+
using Microsoft.ML.TorchSharp.Utils;
810
using TorchSharp;
911

1012
namespace Microsoft.ML.TorchSharp.NasBert.Models
1113
{
12-
internal sealed class NasBertModel : BaseModel
14+
internal class NasBertModel : BaseModel
1315
{
1416
private readonly PredictionHead _predictionHead;
1517

1618
public override BaseHead GetHead() => _predictionHead;
19+
public override TransformerEncoder GetEncoder() => Encoder;
1720

18-
public NasBertModel(NasBertTrainer.Options options, int padIndex, int symbolsCount, int numClasses)
19-
: base(options, padIndex, symbolsCount)
21+
protected readonly TransformerEncoder Encoder;
22+
23+
public NasBertModel(TextClassificationTrainer.Options options, int padIndex, int symbolsCount, int numClasses)
24+
: base(options)
2025
{
2126
_predictionHead = new PredictionHead(
2227
inputDim: Options.EncoderOutputDim,
2328
numClasses: numClasses,
2429
dropoutRate: Options.PoolerDropout);
30+
31+
Encoder = new TransformerEncoder(
32+
paddingIdx: padIndex,
33+
vocabSize: symbolsCount,
34+
dropout: Options.Dropout,
35+
attentionDropout: Options.AttentionDropout,
36+
activationDropout: Options.ActivationDropout,
37+
activationFn: Options.ActivationFunction,
38+
dynamicDropout: Options.DynamicDropout,
39+
maxSeqLen: Options.MaxSequenceLength,
40+
embedSize: Options.EmbeddingDim,
41+
arches: Options.Arches?.ToList(),
42+
numSegments: 0,
43+
encoderNormalizeBefore: Options.EncoderNormalizeBefore,
44+
numEncoderLayers: Options.EncoderLayers,
45+
applyBertInit: true,
46+
freezeTransfer: Options.FreezeTransfer);
47+
2548
Initialize();
2649
RegisterComponents();
2750
}
2851

2952
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
30-
public override torch.Tensor forward(torch.Tensor srcTokens, torch.Tensor tokenMask = null)
53+
public override torch.Tensor forward(torch.Tensor srcTokens, torch.Tensor mask = null)
3154
{
3255
using var disposeScope = torch.NewDisposeScope();
3356
var x = ExtractFeatures(srcTokens);
3457
x = _predictionHead.forward(x);
3558
return x.MoveToOuterDisposeScope();
3659
}
60+
61+
protected void Initialize()
62+
{
63+
if (Options.FreezeEncoder)
64+
{
65+
ModelUtils.FreezeModuleParams(Encoder);
66+
}
67+
}
68+
69+
/// <summary>
70+
/// Run only Encoder and return features.
71+
/// </summary>
72+
protected torch.Tensor ExtractFeatures(torch.Tensor srcTokens)
73+
{
74+
return Encoder.forward(srcTokens, null, null);
75+
}
76+
77+
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
78+
public override void train(bool train = true)
79+
{
80+
base.train(train);
81+
if (!Options.LayerNormTraining)
82+
{
83+
Encoder.CloseLayerNormTraining();
84+
}
85+
}
86+
3787
}
3888
}

src/Microsoft.ML.TorchSharp/NasBert/Models/PredictionHead.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
namespace Microsoft.ML.TorchSharp.NasBert.Models
1313
{
14-
internal sealed class PredictionHead : BaseHead
14+
internal sealed class PredictionHead : BaseHead, torch.nn.IModule<torch.Tensor, torch.Tensor>
1515
{
1616
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:Private field name not in: _camelCase format", Justification = "Has to match TorchSharp model.")]
1717
private readonly Sequential Classifier;
@@ -34,7 +34,7 @@ public PredictionHead(int inputDim, int numClasses, double dropoutRate)
3434
}
3535

3636
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp")]
37-
public override torch.Tensor forward(torch.Tensor features)
37+
public torch.Tensor forward(torch.Tensor features)
3838
{
3939
// TODO: try whitening-like techniques
4040
// take <s> token (equiv. to [CLS])

src/Microsoft.ML.TorchSharp/NasBert/Models/TransformerEncoder.cs

+6-6
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
namespace Microsoft.ML.TorchSharp.NasBert.Models
1818
{
19-
internal sealed class TransformerEncoder : BaseModule
19+
internal sealed class TransformerEncoder : torch.nn.Module, torch.nn.IModule<torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor>
2020
{
2121
#pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format Have to match TorchSharp model
2222

@@ -44,8 +44,8 @@ internal sealed class TransformerEncoder : BaseModule
4444
private readonly LayerNorm EmbeddingLayerNorm;
4545
private readonly EmbedTransfer EmbedTransfer;
4646
private readonly Dropout DropoutLayer;
47-
private readonly ModuleList Layers;
48-
private readonly ModuleList HiddenTransferList;
47+
private readonly ModuleList<TransformerCell> Layers;
48+
private readonly ModuleList<torch.nn.Module> HiddenTransferList;
4949
#pragma warning restore MSML_PrivateFieldName // Private field name not in: _camelCase format
5050

5151
public Parameter TokenEmbeddingMatrix => TokenEmbedding.weight;
@@ -116,9 +116,9 @@ public TransformerEncoder(
116116
activationFn,
117117
addBiasKv,
118118
addZeroAttention,
119-
dynamicDropout) as torch.nn.Module)
119+
dynamicDropout))
120120
.ToArray();
121-
Layers = new ModuleList(layers);
121+
Layers = new ModuleList<TransformerCell>(layers);
122122

123123
var blockPerLayer = numEncoderLayers / DistillBlocks;
124124
HiddenSizePerBlock = CheckBlockHiddenSize(blockPerLayer);
@@ -129,7 +129,7 @@ public TransformerEncoder(
129129
.Select(i => new HiddenTransferDiscrete(hiddenSizePerBlockExtend[i],
130130
hiddenSizePerBlockExtend[i + 1]) as torch.nn.Module)
131131
.ToArray();
132-
HiddenTransferList = new ModuleList(hiddenTransferList);
132+
HiddenTransferList = new ModuleList<torch.nn.Module>(hiddenTransferList);
133133

134134
if (freezeEmbeddings)
135135
{

src/Microsoft.ML.TorchSharp/NasBert/Modules/ActivationFunction.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
namespace Microsoft.ML.TorchSharp.NasBert.Modules
1111
{
1212

13-
internal sealed class ActivationFunction : BaseModule
13+
internal sealed class ActivationFunction : torch.nn.Module<torch.Tensor, torch.Tensor>
1414
{
15-
private readonly torch.nn.Module _function;
15+
private readonly torch.nn.Module<torch.Tensor, torch.Tensor> _function;
1616

1717
public ActivationFunction(string name) : base(name)
1818
{
@@ -43,7 +43,7 @@ public override string GetName()
4343
/// See https://arxiv.org/pdf/1606.08415.pdf:
4444
/// y = 0.5x * (1 + tanh(sqrt(2/Pi) * (x + 0.044715 x^3)))
4545
/// </summary>
46-
public class GeLUFast : torch.nn.Module
46+
public class GeLUFast : torch.nn.Module<torch.Tensor, torch.Tensor>
4747
{
4848
private readonly double _alpha = Math.Sqrt(2 / Math.PI);
4949
private readonly double _beta = 0.044715;

src/Microsoft.ML.TorchSharp/NasBert/Modules/ConvSeparable.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
namespace Microsoft.ML.TorchSharp.NasBert.Modules
1313
{
14-
internal sealed class ConvSeparable : BaseModule
14+
internal sealed class ConvSeparable : torch.nn.Module<torch.Tensor, torch.Tensor>
1515
{
1616
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:Private field name not in: _camelCase format", Justification = "Have to match TorchSharp model.")]
1717
private readonly Sequential Conv;

src/Microsoft.ML.TorchSharp/NasBert/Modules/EmbedTransfer.cs

+5-5
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,26 @@
1212

1313
namespace Microsoft.ML.TorchSharp.NasBert.Modules
1414
{
15-
internal abstract class EmbedTransfer : BaseModule
15+
internal abstract class EmbedTransfer : torch.nn.Module<torch.Tensor, int, torch.Tensor>
1616
{
1717
protected EmbedTransfer(string name) : base(name) { }
1818

1919
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
20-
public abstract torch.Tensor forward(torch.Tensor x, int hiddenSize);
20+
public override torch.Tensor forward(torch.Tensor x, int hiddenSize) => throw new NotImplementedException();
2121
}
2222

2323
internal sealed class EmbedTransferNonDiscrete : EmbedTransfer
2424
{
2525
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:not in _camelCase format", Justification = "Need to match TorchSharp.")]
26-
private readonly ModuleList HiddenTransfer;
26+
private readonly ModuleList<Linear> HiddenTransfer;
2727

2828
public EmbedTransferNonDiscrete() : base(nameof(EmbedTransferNonDiscrete))
2929
{
3030
//var hiddenTransfer = SearchSpace.HiddenSizeChoices[Range.EndAt(Index.FromEnd(1))]
3131
var hiddenTransfer = SearchSpace.HiddenSizeChoices.Where((source, index) => index != SearchSpace.HiddenSizeChoices.Length - 1)
32-
.Select(hidden => torch.nn.Linear(hidden, SearchSpace.HiddenSizeChoices[SearchSpace.HiddenSizeChoices.Length - 1]) as torch.nn.Module)
32+
.Select(hidden => torch.nn.Linear(hidden, SearchSpace.HiddenSizeChoices[SearchSpace.HiddenSizeChoices.Length - 1]))
3333
.ToArray();
34-
HiddenTransfer = new ModuleList(hiddenTransfer);
34+
HiddenTransfer = new ModuleList<Linear>(hiddenTransfer);
3535
RegisterComponents();
3636
}
3737

0 commit comments

Comments
 (0)