|
4 | 4 |
|
5 | 5 | using System;
|
6 | 6 | using System.Collections.Generic;
|
| 7 | +using System.Linq; |
7 | 8 | using System.Text;
|
| 9 | +using Microsoft.ML.TorchSharp.Utils; |
8 | 10 | using TorchSharp;
|
9 | 11 |
|
10 | 12 | namespace Microsoft.ML.TorchSharp.NasBert.Models
|
11 | 13 | {
|
12 |
| - internal sealed class NasBertModel : BaseModel |
| 14 | + internal class NasBertModel : BaseModel |
13 | 15 | {
|
14 | 16 | private readonly PredictionHead _predictionHead;
|
15 | 17 |
|
16 | 18 | public override BaseHead GetHead() => _predictionHead;
|
| 19 | + public override TransformerEncoder GetEncoder() => Encoder; |
17 | 20 |
|
18 |
| - public NasBertModel(NasBertTrainer.Options options, int padIndex, int symbolsCount, int numClasses) |
19 |
| - : base(options, padIndex, symbolsCount) |
| 21 | + protected readonly TransformerEncoder Encoder; |
| 22 | + |
| 23 | + public NasBertModel(TextClassificationTrainer.Options options, int padIndex, int symbolsCount, int numClasses) |
| 24 | + : base(options) |
20 | 25 | {
|
21 | 26 | _predictionHead = new PredictionHead(
|
22 | 27 | inputDim: Options.EncoderOutputDim,
|
23 | 28 | numClasses: numClasses,
|
24 | 29 | dropoutRate: Options.PoolerDropout);
|
| 30 | + |
| 31 | + Encoder = new TransformerEncoder( |
| 32 | + paddingIdx: padIndex, |
| 33 | + vocabSize: symbolsCount, |
| 34 | + dropout: Options.Dropout, |
| 35 | + attentionDropout: Options.AttentionDropout, |
| 36 | + activationDropout: Options.ActivationDropout, |
| 37 | + activationFn: Options.ActivationFunction, |
| 38 | + dynamicDropout: Options.DynamicDropout, |
| 39 | + maxSeqLen: Options.MaxSequenceLength, |
| 40 | + embedSize: Options.EmbeddingDim, |
| 41 | + arches: Options.Arches?.ToList(), |
| 42 | + numSegments: 0, |
| 43 | + encoderNormalizeBefore: Options.EncoderNormalizeBefore, |
| 44 | + numEncoderLayers: Options.EncoderLayers, |
| 45 | + applyBertInit: true, |
| 46 | + freezeTransfer: Options.FreezeTransfer); |
| 47 | + |
25 | 48 | Initialize();
|
26 | 49 | RegisterComponents();
|
27 | 50 | }
|
28 | 51 |
|
29 | 52 | [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
|
30 |
| - public override torch.Tensor forward(torch.Tensor srcTokens, torch.Tensor tokenMask = null) |
| 53 | + public override torch.Tensor forward(torch.Tensor srcTokens, torch.Tensor mask = null) |
31 | 54 | {
|
32 | 55 | using var disposeScope = torch.NewDisposeScope();
|
33 | 56 | var x = ExtractFeatures(srcTokens);
|
34 | 57 | x = _predictionHead.forward(x);
|
35 | 58 | return x.MoveToOuterDisposeScope();
|
36 | 59 | }
|
| 60 | + |
| 61 | + protected void Initialize() |
| 62 | + { |
| 63 | + if (Options.FreezeEncoder) |
| 64 | + { |
| 65 | + ModelUtils.FreezeModuleParams(Encoder); |
| 66 | + } |
| 67 | + } |
| 68 | + |
| 69 | + /// <summary> |
| 70 | + /// Run only Encoder and return features. |
| 71 | + /// </summary> |
| 72 | + protected torch.Tensor ExtractFeatures(torch.Tensor srcTokens) |
| 73 | + { |
| 74 | + return Encoder.forward(srcTokens, null, null); |
| 75 | + } |
| 76 | + |
| 77 | + [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")] |
| 78 | + public override void train(bool train = true) |
| 79 | + { |
| 80 | + base.train(train); |
| 81 | + if (!Options.LayerNormTraining) |
| 82 | + { |
| 83 | + Encoder.CloseLayerNormTraining(); |
| 84 | + } |
| 85 | + } |
| 86 | + |
37 | 87 | }
|
38 | 88 | }
|
0 commit comments