Skip to content

Commit 190d545

Browse files
gericdongpattishin
andauthored
feat(aiplatform): add sample for Gen AI code model tuning (#3410)
* feat(aiplatform): add sample for Gen AI code model tuning * udpate comments to make it clear * fix: skip this test temporary until the service account permission issue is resolved * fix: skip another pipeline test temporary until the service account permission issue is resolved * fix(aiplatform): try to get the test to run after IAM role changed * fix(aiplatform): remove the skip for another pipeline test * fix: address review comments --------- Co-authored-by: Patti Shin <[email protected]>
1 parent 359506f commit 190d545

File tree

2 files changed

+184
-0
lines changed

2 files changed

+184
-0
lines changed
+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright 2023 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
'use strict';
18+
19+
async function main(
20+
project,
21+
pipelineJobId,
22+
modelDisplayName,
23+
gcsOutputDirectory,
24+
location = 'europe-west4',
25+
datasetUri = 'gs://cloud-samples-data/ai-platform/generative_ai/sql_create_context.jsonl',
26+
trainSteps = 300
27+
) {
28+
// [START aiplatform_genai_code_model_tuning]
29+
/**
30+
* TODO(developer): Uncomment these variables before running the sample.\
31+
* (Not necessary if passing values as arguments)
32+
*/
33+
// const project = 'YOUR_PROJECT_ID';
34+
// const location = 'YOUR_PROJECT_LOCATION';
35+
const aiplatform = require('@google-cloud/aiplatform');
36+
const {PipelineServiceClient} = aiplatform.v1;
37+
38+
// Import the helper module for converting arbitrary protobuf.Value objects.
39+
const {helpers} = aiplatform;
40+
41+
// Specifies the location of the api endpoint
42+
const clientOptions = {
43+
apiEndpoint: `${location}-aiplatform.googleapis.com`,
44+
};
45+
const model = 'code-bison@001';
46+
47+
const pipelineClient = new PipelineServiceClient(clientOptions);
48+
49+
async function tuneLLM() {
50+
// Configure the parent resource
51+
const parent = `projects/${project}/locations/${location}`;
52+
53+
const parameters = {
54+
train_steps: helpers.toValue(trainSteps),
55+
project: helpers.toValue(project),
56+
location: helpers.toValue('us-central1'),
57+
dataset_uri: helpers.toValue(datasetUri),
58+
large_model_reference: helpers.toValue(model),
59+
model_display_name: helpers.toValue(modelDisplayName),
60+
};
61+
62+
const runtimeConfig = {
63+
gcsOutputDirectory,
64+
parameterValues: parameters,
65+
};
66+
67+
const pipelineJob = {
68+
templateUri:
69+
'https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v3.0.0',
70+
displayName: 'my-tuning-job',
71+
runtimeConfig,
72+
};
73+
74+
const createPipelineRequest = {
75+
parent,
76+
pipelineJob,
77+
pipelineJobId,
78+
};
79+
80+
const [response] = await pipelineClient.createPipelineJob(
81+
createPipelineRequest
82+
);
83+
84+
console.log('Tuning pipeline job:');
85+
console.log(`\tName: ${response.name}`);
86+
console.log(
87+
`\tCreate time: ${new Date(1970, 0, 1)
88+
.setSeconds(response.createTime.seconds)
89+
.toLocaleString()}`
90+
);
91+
console.log(`\tStatus: ${response.status}`);
92+
}
93+
94+
await tuneLLM();
95+
// [END aiplatform_genai_code_model_tuning]
96+
}
97+
98+
exports.tuneModel = main;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// Copyright 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
/* eslint-disable */
16+
17+
'use strict';
18+
19+
const {assert} = require('chai');
20+
const {describe, it} = require('mocha');
21+
const uuid = require('uuid');
22+
const sinon = require('sinon');
23+
24+
const projectId = process.env.CAIP_PROJECT_ID;
25+
const location = 'europe-west4';
26+
27+
const aiplatform = require('@google-cloud/aiplatform');
28+
const clientOptions = {
29+
apiEndpoint: `${location}-aiplatform.googleapis.com`,
30+
};
31+
const pipelineClient = new aiplatform.v1.PipelineServiceClient(clientOptions);
32+
33+
const {tuneModel} = require('../code-model-tuning');
34+
35+
const timestampId = `${new Date()
36+
.toISOString()
37+
.replace(/(:|\.)/g, '-')
38+
.toLowerCase()}`;
39+
const pipelineJobName = `my-tuning-pipeline-${timestampId}`;
40+
const modelDisplayName = `my-tuned-model-${timestampId}`;
41+
const bucketName = 'ucaip-samples-europe-west4/training_pipeline_output';
42+
const bucketUri = `gs://${bucketName}/tune-model-nodejs`;
43+
44+
describe('Tune a code model', () => {
45+
const stubConsole = function () {
46+
sinon.stub(console, 'error');
47+
sinon.stub(console, 'log');
48+
};
49+
50+
const restoreConsole = function () {
51+
console.log.restore();
52+
console.error.restore();
53+
};
54+
55+
beforeEach(stubConsole);
56+
afterEach(restoreConsole);
57+
58+
it('should prompt-tune an existing code model', async () => {
59+
// Act
60+
await tuneModel(projectId, pipelineJobName, modelDisplayName, bucketUri);
61+
62+
// Assert
63+
assert.include(console.log.firstCall.args, 'Tuning pipeline job:');
64+
});
65+
66+
after(async () => {
67+
// Cancel and delete the pipeline job
68+
const name = pipelineClient.pipelineJobPath(
69+
projectId,
70+
location,
71+
pipelineJobName
72+
);
73+
74+
const cancelRequest = {
75+
name,
76+
};
77+
78+
pipelineClient.cancelPipelineJob(cancelRequest).then(() => {
79+
const deleteRequest = {
80+
name,
81+
};
82+
83+
return pipelineClient.deletePipeline(deleteRequest);
84+
});
85+
});
86+
});

0 commit comments

Comments
 (0)