Skip to content

Commit 97f7bbd

Browse files
Add GPT2 tokenizer setup and refactor transformer and tokens code (#56)
* Add gpt-tokenizer dependency. * Add function for using a standalone tokenizer with GPT2 instead of using the local task token rep. There seems to be a memory leak, and a refactoring is necessary. * Refactor code so that common functions among transformers are in a single file. Refactor token gemb file so that embedBatch is adapted to a use case where a tokenizer is available. * Refactored computePrediction and computeDecoder so that it's compatible with next N tokens prediction. * Remove unused test function and clean-up comments. * Clarify some TODOs and comments on the code. * Add node_modules/gpt-tokenizer to package-lock.json. * Clarify testing function and simplify mapToIdx and tokenizeAndMapToIdx.
1 parent ca04a1f commit 97f7bbd

File tree

12 files changed

+485
-398
lines changed

12 files changed

+485
-398
lines changed

animated-transformer/package-lock.json

Lines changed: 8 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

animated-transformer/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
"@codemirror/language": "^6.9.0",
2727
"@tensorflow/tfjs": "^4.20.0",
2828
"@tensorflow/tfjs-vis": "^1.5.1",
29+
"gpt-tokenizer": "2.8.1",
2930
"codemirror": "^6.0.1",
3031
"d3": "^7.8.0",
3132
"d3-color": "^3.1.0",

animated-transformer/src/app/web-colab/tiny-transformer-example/trainer-cell.worker.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,15 @@ import {
2929
} from './ailab';
3030
import {
3131
computeTransformer,
32-
transformerAccuracy,
3332
TransformerConfig,
34-
lastTokenCrossEntropyLoss,
3533
TransformerModel,
3634
VarTransformerParams,
3735
initDecoderParams,
38-
TransformerComputation,
3936
} from 'src/lib/transformer/transformer_gtensor';
37+
import {
38+
transformerAccuracy,
39+
lastTokenCrossEntropyLoss,
40+
} from 'src/lib/transformer/common_transformer';
4041
import {
4142
assignParams,
4243
deserializeParams,

animated-transformer/src/lib/tokens/token_gemb.spec.ts

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import {
2121
strSeqPrepFn,
2222
embed,
2323
prepareBasicTaskTokenRep,
24+
tokenizeAndMapToIdx,
25+
mapToIdx,
2426
embedBatch,
2527
expectedOutputSeqPrepFn,
2628
} from '../tokens/token_gemb';
@@ -56,8 +58,9 @@ describe('token_gemb', () => {
5658
const tokenEmbedding = new GTensor(tf.tensor([aEmb, bEmb, padEmb]), ['tokenId', 'inputRep']);
5759

5860
const seqsToEmbed = [['a', 'b', '[pad]', 'a'], ['a', 'b'], [], ['b'], ['a']];
61+
const seqsIdxs = mapToIdx(tokenRep.tokenToIdx, seqsToEmbed);
5962

60-
const seqEmb = embedBatch(tokenRep.tokenToIdx, tokenEmbedding, seqsToEmbed, {
63+
const seqEmb = embedBatch(tokenEmbedding, seqsIdxs, {
6164
paddingId: 2,
6265
padAt: 'start',
6366
dtype: 'int32',
@@ -82,8 +85,9 @@ describe('token_gemb', () => {
8285
const embeddings = new GTensor(tf.tensor([aEmb, bEmb, padEmb]), ['tokenId', 'inputRep']);
8386

8487
const seqsToEmbed = [['a', 'b', '[pad]', 'a'], ['a', 'b'], [], ['b'], ['a']];
88+
const seqsIdxs = mapToIdx(tokenRep.tokenToIdx, seqsToEmbed);
8589

86-
const seqEmb = embedBatch(tokenRep.tokenToIdx, embeddings, seqsToEmbed, {
90+
const seqEmb = embedBatch(embeddings, seqsIdxs, {
8791
paddingId: 2,
8892
padAt: 'end',
8993
dtype: 'int32',
@@ -160,4 +164,33 @@ describe('token_gemb', () => {
160164
expect(targetTokensOneHot.tensor.arraySync()).toEqual(expectedOutputArr);
161165
expect(targetTokensOneHot.dimNames).toEqual(['batch', 'pos', 'tokenId'])
162166
});
167+
it('Test tokenizeAndMapToIdx', () => {
168+
// Mock a tokenizer for testing tokenizeAndMapToIdx.
169+
function tokenize_fn_test(input: string): number[] {
170+
let output: number[] = [];
171+
for (let i = 0; i < input.length; i++) {
172+
if (input[i] == 'a')
173+
output = output.concat(0);
174+
else
175+
output = output.concat(1);
176+
}
177+
return output;
178+
};
179+
180+
const seqsToEmbed = ['aba', 'ab', '', 'b', 'a'];
181+
const seqsIdxs = tokenizeAndMapToIdx(tokenize_fn_test, seqsToEmbed);
182+
const expectedIdxs =
183+
[[0, 1, 0], [0, 1], [], [1], [0]];
184+
185+
expect(seqsIdxs).toEqual(expectedIdxs);
186+
});
187+
it('Test mapToIdx', () => {
188+
const tokens = ['a', 'b', '[pad]'];
189+
const tokenRep = prepareBasicTaskTokenRep(tokens);
190+
191+
const seqsToEmbed = [['a', 'b', '[pad]', 'a'], ['a', 'b'], [], ['b'], ['a']];
192+
const seqsIdxs = mapToIdx(tokenRep.tokenToIdx, seqsToEmbed);
193+
const expectedIdxs = [[0, 1, 2, 0], [0, 1], [], [1], [0]];
194+
expect(seqsIdxs).toEqual(expectedIdxs);
195+
});
163196
});

animated-transformer/src/lib/tokens/token_gemb.ts

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -57,23 +57,30 @@ export function embed(
5757
return embeddedInput;
5858
}
5959

60-
// TODO: consider supporting padding string[][] ?
61-
// pad(inputs: string[][], config: {
62-
// paddingId: number;
63-
// padAt: 'start' | 'end';
64-
// dtype: tf.NumericDataType,
65-
// }) {
60+
// Maps tokens in string format to indexes.
61+
export function mapToIdx(
62+
tokenToIdx: { [token: string]: number },
63+
examples: string[][]
64+
): number[][] {
65+
return examples.map((example) => example.map((s) => tokenToIdx[s]));
66+
}
6667

67-
// }
68+
// TODO(@aliciafmachado): Merge this function with the one below
69+
// once we create a class to wrap the tokenization.
70+
export function tokenizeAndMapToIdx(
71+
tokenize_fn: (input: string) => number[],
72+
examples: string[]
73+
): number[][] {
74+
return examples.map((example) => tokenize_fn(example));
75+
}
6876

6977
// When batchSize is defined and batchSize > examples.length, then
7078
// padding-filled examples are added to the final output GTensor. When
7179
// batchSize < examples.length, examples is truncated to make the output
7280
// GTensor.
7381
export function embedBatch(
74-
tokenToIdx: { [token: string]: number },
7582
embeddings: GTensor<'tokenId' | 'inputRep'>,
76-
examples: string[][],
83+
examples: number[][],
7784
config: {
7885
paddingId: number;
7986
padAt: 'start' | 'end';
@@ -91,21 +98,17 @@ export function embedBatch(
9198
let maxInputLength = 0;
9299
if (!config.maxInputLength) {
93100
examples.forEach((l) => (maxInputLength = Math.max(l.length, maxInputLength)));
94-
examples.map((l) => l.map((s) => tokenToIdx[s]));
95101
} else {
96102
maxInputLength = config.maxInputLength;
97103
}
98104

99105
examples.forEach((example) => {
100106
if (example.length >= maxInputLength) {
101107
const tensor = tf.tensor1d(
102-
example.slice(0, maxInputLength).map((s) => tokenToIdx[s]),
108+
example.slice(0, maxInputLength),
103109
config.dtype
104110
);
105111
inputEmbList.push(tensor);
106-
// console.log(l)
107-
// console.log(l.map(s => this.tokenToIdx[s]))
108-
// console.log(tensor.dataSync())
109112
} else if (example.length === 0) {
110113
const tensor = tf.fill([maxInputLength], config.paddingId, config.dtype);
111114
inputEmbList.push(tensor);
@@ -116,7 +119,7 @@ export function embedBatch(
116119
: [[0, maxInputLength - example.length]];
117120
const tensor = tf.pad(
118121
tf.tensor1d(
119-
example.map((s) => tokenToIdx[s]),
122+
example,
120123
config.dtype
121124
),
122125
paddingLocation,
@@ -152,10 +155,15 @@ export type BasicTaskTokenRep = {
152155
spaceToken: string;
153156
// tokens lists all tokens, including mask, pad, eos, etc.
154157
tokens: string[];
158+
// TODO(@aliciafmachado): remove tokenToIdx below once tokenization is wrapped in a class.
155159
tokenToIdx: { [token: string]: number };
156-
idxToOneHot : {[tokenIdx: number]: number[]};
160+
idxToOneHot: { [tokenIdx: number]: number[] };
157161
};
158162

163+
// TODO(@aliciafmachado): token wrap class with the tokenize and untokenize fn?
164+
// Make BasicTaskTokenRep minimal, then add a wrapper class that creates the tokenToIdx and idxToOneHot.
165+
// This interface would be compatible with a tokenizer straight out-of-the-box.
166+
159167
// ----------------------------------------------------------------------------
160168
// Prepare the task representation in a vector space.
161169
// TODO: maybe this should be viewed as a task extension: i.e. Task --> Task
@@ -165,7 +173,7 @@ export function prepareBasicTaskTokenRep(baseVocab: string[]): BasicTaskTokenRep
165173
const padToken = '[PAD]';
166174
const eosToken = '[EOS]';
167175
const spaceToken = ' '
168-
const vocab = [ ...baseVocab, maskToken, padToken, eosToken, spaceToken];
176+
const vocab = [...baseVocab, maskToken, padToken, eosToken, spaceToken];
169177
const tokenToIdx: { [token: string]: number } = {};
170178
vocab.forEach((t, i) => (tokenToIdx[t] = i));
171179

@@ -175,7 +183,7 @@ export function prepareBasicTaskTokenRep(baseVocab: string[]): BasicTaskTokenRep
175183
// );
176184

177185
// TODO: Find a better place for the idxToOneHot lookup table
178-
const idxToOneHot : {[tokenIdx: number]: number[] } = {};
186+
const idxToOneHot: { [tokenIdx: number]: number[] } = {};
179187
const oneHotTokens = [tf.oneHot(tf.tensor1d(Object.values(tokenToIdx), 'int32'), baseVocab.length + 4).arraySync() as number[][]];
180188
Object.values(tokenToIdx).forEach((i) => (idxToOneHot[i] = oneHotTokens[0][i]));
181189
return {
@@ -217,10 +225,10 @@ export function strSeqPrepFn(
217225
options: { maxInputLength: number }
218226
): GTensor<'batch' | 'pos' | 'inputRep'> {
219227
const padTokenId = model.config.tokenRep.tokenToIdx[model.config.tokenRep.padToken];
228+
const inputSeqsInIdxs = mapToIdx(model.config.tokenRep.tokenToIdx, inputSeqs);
220229
const batchedInputEmb = embedBatch(
221-
model.config.tokenRep.tokenToIdx,
222230
model.params.tokenEmbedding,
223-
inputSeqs,
231+
inputSeqsInIdxs,
224232
{
225233
paddingId: padTokenId,
226234
padAt: 'start',
@@ -282,21 +290,21 @@ export function singleNextTokenIdxOutputPrepFn(
282290
}
283291

284292
// Returns the one Hot representation for each token of the expected output sequence for the provided input sequence
285-
export function expectedOutputSeqPrepFn(
293+
export function expectedOutputSeqPrepFn(
286294
model: { config: { tokenRep: BasicTaskTokenRep } },
287295
inputSeqs: string[][],
288296
expectedOutputs: string[][],
289297
): GTensor<'batch' | 'pos' | 'tokenId'> {
290298
// Compute Token rep for inputSeq
291299
const batchInputs = inputSeqs.map((inputSeq) => inputSeq.map((token) => model.config.tokenRep.tokenToIdx[token]))
292-
// Compute Token rep for inputSeq
300+
// Compute Token rep for inputSeq
293301
const expectedOutputSeq = expectedOutputs.map((outputToken) => model.config.tokenRep.tokenToIdx[outputToken[0]])
294302
// Shift input sequences to the right and add the corresponding target in "expectedOutputs" at the end of each sequence
295-
let shiftedInputs = batchInputs.map((x) => x.slice(1, ))
303+
let shiftedInputs = batchInputs.map((x) => x.slice(1,))
296304
const expectedOutputSeqIdx = expectedOutputSeq.map((y, index) => shiftedInputs[index].concat(y))
297305
const expectedOutputSeqOneHot = expectedOutputSeqIdx.map((sample) => sample.map((tidx) => model.config.tokenRep.idxToOneHot[tidx]))
298306
// TODO: We should probably be using a lookup function and storing the one-hot for every token in the GPU as a constant.
299-
return new GTensor(tf.tensor(expectedOutputSeqOneHot),['batch', 'pos', 'tokenId']);
307+
return new GTensor(tf.tensor(expectedOutputSeqOneHot), ['batch', 'pos', 'tokenId']);
300308
}
301309

302310

animated-transformer/src/lib/trainer/basic_transformer_trainer.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import {
2525
splitGenerativeTaskTestSet,
2626
} from '../seqtasks/util';
2727
import { BasicTaskTokenRep, StrSeqPrepFn } from '../tokens/token_gemb';
28-
import { transformerAccuracy } from '../transformer/transformer_gtensor';
28+
import { transformerAccuracy, lastTokenCrossEntropyLoss } from '../transformer/common_transformer';
2929
import { TaskDatasetSplit, TrainState, TrainStateConfig } from './train_state';
3030
import { RandomStream, makeRandomStream } from '../random/random';
3131
// import { GTensorTree, GVariableTree } from 'src/lib/gtensor/gtensor_tree';
@@ -74,7 +74,7 @@ export function initTransformerTrainState(
7474
generator: RandomStream
7575
): tf.Scalar {
7676
const decoderComputation = transformer.computeTransformer(model, inputs, generator);
77-
const loss = transformer.lastTokenCrossEntropyLoss(model, decoderComputation, targets);
77+
const loss = lastTokenCrossEntropyLoss(model, decoderComputation, targets);
7878
return loss as tf.Scalar;
7979
}
8080

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/* Copyright 2023 Google LLC. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
import { GTensor, makeTruncNormal } from '../gtensor/gtensor';
17+
import { causalMask } from './common_transformer';
18+
import * as tf from '@tensorflow/tfjs';
19+
import * as abtask from '../seqtasks/ab_task';
20+
import { embedBatch, mapToIdx, prepareBasicTaskTokenRep } from '../tokens/token_gemb';
21+
22+
describe('Common Transformer util types and functions', () => {
23+
it('AB task data prep', async () => {
24+
const inputRep = 2;
25+
const batchSize = 4;
26+
const task = new abtask.AorBisMaxTask({
27+
kind: 'AorBisMaxTask',
28+
id: 'an A or B is Max task',
29+
maxInputLen: 2,
30+
maxOutputLen: 2,
31+
genStateConfig: { seed: 0 },
32+
// Create a tokenEmbedding that also has a [MASK] token & [PAD] token.
33+
// inputRepSize: inputRep,
34+
});
35+
const tokenRep = prepareBasicTaskTokenRep(task.baseVocab);
36+
const padTokenId = tokenRep.tokenToIdx[tokenRep.padToken];
37+
const embeddings = makeTruncNormal({
38+
tokenId: tokenRep.tokens.length,
39+
inputRep,
40+
});
41+
42+
const examples = task.exampleIter.takeOutN(4);
43+
const examplesIdxs = mapToIdx(tokenRep.tokenToIdx, examples.map((example) => example.input));
44+
const maskIdx = tokenRep.tokenToIdx[tokenRep.maskToken];
45+
46+
const batchedInputEmb = embedBatch(
47+
embeddings,
48+
examplesIdxs.map((example) => example.concat(maskIdx)),
49+
{ paddingId: padTokenId, padAt: 'start', dtype: 'int32' },
50+
);
51+
52+
expect(batchedInputEmb.gshape()).toEqual({
53+
batch: batchSize,
54+
// +1 for the appended [MASK] token to be predicted.
55+
pos: task.config.maxInputLen + 1,
56+
inputRep,
57+
});
58+
});
59+
60+
it('Compute masked self attention', () => {
61+
const exampleAffinities = new GTensor(
62+
tf.tensor([
63+
[
64+
[
65+
[0, 0, 0],
66+
[0, 0, 0],
67+
[0, 0, 0],
68+
],
69+
],
70+
]),
71+
['batch', 'heads', 'keyPos', 'queryPos'],
72+
);
73+
const masked = causalMask(exampleAffinities);
74+
75+
expect(masked.dimNames).toEqual(['batch', 'heads', 'keyPos', 'queryPos']);
76+
tf.test_util.expectArraysClose(masked.tensor.arraySync(), [
77+
[
78+
[
79+
[1, 0, 0],
80+
[0.5, 0.5, 0],
81+
[0.33, 0.33, 0.33],
82+
],
83+
],
84+
]);
85+
});
86+
});

0 commit comments

Comments
 (0)