Skip to content

Commit 231f010

Browse files
authored
feat: batch prediction samples (#3884)
1 parent 4b67f18 commit 231f010

File tree

4 files changed

+367
-0
lines changed

4 files changed

+367
-0
lines changed

Diff for: ai-platform/snippets/batch-code-predict.js

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
'use strict';
18+
19+
async function main(projectId, inputUri, outputUri, jobDisplayName) {
20+
// [START generativeaionvertexai_batch_code_predict]
21+
// Imports the aiplatform library
22+
const aiplatformLib = require('@google-cloud/aiplatform');
23+
const aiplatform = aiplatformLib.protos.google.cloud.aiplatform.v1;
24+
25+
/**
26+
* TODO(developer): Uncomment/update these variables before running the sample.
27+
*/
28+
// projectId = 'YOUR_PROJECT_ID';
29+
30+
// Optional: URI of the input dataset.
31+
// Could be a BigQuery table or a Google Cloud Storage file.
32+
// E.g. "gs://[BUCKET]/[DATASET].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
33+
// inputUri =
34+
// 'gs://cloud-samples-data/batch/prompt_for_batch_code_predict.jsonl';
35+
36+
// Optional: URI where the output will be stored.
37+
// Could be a BigQuery table or a Google Cloud Storage file.
38+
// E.g. "gs://[BUCKET]/[OUTPUT].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
39+
// outputUri = 'gs://batch-bucket-testing/batch_code_predict_output';
40+
41+
// The name of batch prediction job
42+
// jobDisplayName = `Batch code prediction job: ${new Date().getMilliseconds()}`;
43+
44+
// The name of pre-trained model
45+
const codeModel = 'code-bison';
46+
const location = 'us-central1';
47+
48+
// Construct your modelParameters
49+
const parameters = {
50+
maxOutputTokens: '200',
51+
temperature: '0.2',
52+
};
53+
const parametersValue = aiplatformLib.helpers.toValue(parameters);
54+
// Configure the parent resource
55+
const parent = `projects/${projectId}/locations/${location}`;
56+
const modelName = `projects/${projectId}/locations/${location}/publishers/google/models/${codeModel}`;
57+
58+
// Specifies the location of the api endpoint
59+
const clientOptions = {
60+
apiEndpoint: `${location}-aiplatform.googleapis.com`,
61+
};
62+
63+
// Instantiates a client
64+
const jobServiceClient = new aiplatformLib.JobServiceClient(clientOptions);
65+
66+
// Perform batch code prediction using a pre-trained code generation model.
67+
// Example of using Google Cloud Storage bucket as the input and output data source
68+
async function callBatchCodePredicton() {
69+
const gcsSource = new aiplatform.GcsSource({
70+
uris: [inputUri],
71+
});
72+
73+
const inputConfig = new aiplatform.BatchPredictionJob.InputConfig({
74+
gcsSource,
75+
instancesFormat: 'jsonl',
76+
});
77+
78+
const gcsDestination = new aiplatform.GcsDestination({
79+
outputUriPrefix: outputUri,
80+
});
81+
82+
const outputConfig = new aiplatform.BatchPredictionJob.OutputConfig({
83+
gcsDestination,
84+
predictionsFormat: 'jsonl',
85+
});
86+
87+
const batchPredictionJob = new aiplatform.BatchPredictionJob({
88+
displayName: jobDisplayName,
89+
model: modelName,
90+
inputConfig,
91+
outputConfig,
92+
modelParameters: parametersValue,
93+
});
94+
95+
const request = {
96+
parent,
97+
batchPredictionJob,
98+
};
99+
100+
// Create batch prediction job request
101+
const [response] = await jobServiceClient.createBatchPredictionJob(request);
102+
103+
console.log('Raw response: ', JSON.stringify(response, null, 2));
104+
}
105+
106+
await callBatchCodePredicton();
107+
// [END generativeaionvertexai_batch_code_predict]
108+
}
109+
110+
main(...process.argv.slice(2)).catch(err => {
111+
console.error(err.message);
112+
process.exitCode = 1;
113+
});

Diff for: ai-platform/snippets/batch-text-predict.js

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
'use strict';
18+
19+
async function main(projectId, inputUri, outputUri, jobDisplayName) {
20+
// [START generativeaionvertexai_batch_text_predict]
21+
// Imports the aiplatform library
22+
const aiplatformLib = require('@google-cloud/aiplatform');
23+
const aiplatform = aiplatformLib.protos.google.cloud.aiplatform.v1;
24+
25+
/**
26+
* TODO(developer): Uncomment/update these variables before running the sample.
27+
*/
28+
// projectId = 'YOUR_PROJECT_ID';
29+
30+
// Optional: URI of the input dataset.
31+
// Could be a BigQuery table or a Google Cloud Storage file.
32+
// E.g. "gs://[BUCKET]/[DATASET].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
33+
// inputUri =
34+
// 'gs://cloud-samples-data/batch/prompt_for_batch_text_predict.jsonl';
35+
36+
// Optional: URI where the output will be stored.
37+
// Could be a BigQuery table or a Google Cloud Storage file.
38+
// E.g. "gs://[BUCKET]/[OUTPUT].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
39+
// outputUri = 'gs://batch-bucket-testing/batch_text_predict_output';
40+
41+
// The name of batch prediction job
42+
// jobDisplayName = `Batch text prediction job: ${new Date().getMilliseconds()}`;
43+
44+
// The name of pre-trained model
45+
const textModel = 'text-bison';
46+
const location = 'us-central1';
47+
48+
// Construct your modelParameters
49+
const parameters = {
50+
maxOutputTokens: '200',
51+
temperature: '0.2',
52+
topP: '0.95',
53+
topK: '40',
54+
};
55+
const parametersValue = aiplatformLib.helpers.toValue(parameters);
56+
// Configure the parent resource
57+
const parent = `projects/${projectId}/locations/${location}`;
58+
const modelName = `projects/${projectId}/locations/${location}/publishers/google/models/${textModel}`;
59+
60+
// Specifies the location of the api endpoint
61+
const clientOptions = {
62+
apiEndpoint: `${location}-aiplatform.googleapis.com`,
63+
};
64+
65+
// Instantiates a client
66+
const jobServiceClient = new aiplatformLib.JobServiceClient(clientOptions);
67+
68+
// Perform batch text prediction using a pre-trained text generation model.
69+
// Example of using Google Cloud Storage bucket as the input and output data source
70+
async function callBatchTextPredicton() {
71+
const gcsSource = new aiplatform.GcsSource({
72+
uris: [inputUri],
73+
});
74+
75+
const inputConfig = new aiplatform.BatchPredictionJob.InputConfig({
76+
gcsSource,
77+
instancesFormat: 'jsonl',
78+
});
79+
80+
const gcsDestination = new aiplatform.GcsDestination({
81+
outputUriPrefix: outputUri,
82+
});
83+
84+
const outputConfig = new aiplatform.BatchPredictionJob.OutputConfig({
85+
gcsDestination,
86+
predictionsFormat: 'jsonl',
87+
});
88+
89+
const batchPredictionJob = new aiplatform.BatchPredictionJob({
90+
displayName: jobDisplayName,
91+
model: modelName,
92+
inputConfig,
93+
outputConfig,
94+
modelParameters: parametersValue,
95+
});
96+
97+
const request = {
98+
parent,
99+
batchPredictionJob,
100+
};
101+
102+
// Create batch prediction job request
103+
const [response] = await jobServiceClient.createBatchPredictionJob(request);
104+
105+
console.log('Raw response: ', JSON.stringify(response, null, 2));
106+
}
107+
108+
await callBatchTextPredicton();
109+
// [END generativeaionvertexai_batch_text_predict]
110+
}
111+
112+
main(...process.argv.slice(2)).catch(err => {
113+
console.error(err.message);
114+
process.exitCode = 1;
115+
});

Diff for: ai-platform/snippets/test/batch-code-predict.test.js

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
'use strict';
18+
19+
const {assert} = require('chai');
20+
const {after, describe, it} = require('mocha');
21+
const uuid = require('uuid').v4;
22+
const cp = require('child_process');
23+
const {JobServiceClient} = require('@google-cloud/aiplatform');
24+
25+
const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'});
26+
27+
describe('Batch code predict', async () => {
28+
const displayName = `batch-code-predict-job-${uuid()}`;
29+
const location = 'us-central1';
30+
const inputUri =
31+
'gs://ucaip-samples-test-output/inputs/batch_predict_TCN/tcn_inputs.jsonl';
32+
const outputUri = 'gs://ucaip-samples-test-output/';
33+
const jobServiceClient = new JobServiceClient({
34+
apiEndpoint: `${location}-aiplatform.googleapis.com`,
35+
});
36+
const projectId = process.env.CAIP_PROJECT_ID;
37+
let batchPredictionJobId;
38+
39+
after(async () => {
40+
const name = jobServiceClient.batchPredictionJobPath(
41+
projectId,
42+
location,
43+
batchPredictionJobId
44+
);
45+
46+
const cancelRequest = {
47+
name,
48+
};
49+
50+
jobServiceClient.cancelBatchPredictionJob(cancelRequest).then(() => {
51+
const deleteRequest = {
52+
name,
53+
};
54+
55+
return jobServiceClient.deleteBatchPredictionJob(deleteRequest);
56+
});
57+
});
58+
59+
it('should create job with code prediction', async () => {
60+
const response = execSync(
61+
`node ./batch-code-predict.js ${projectId} ${inputUri} ${outputUri} ${displayName}`
62+
);
63+
64+
assert.match(response, new RegExp(displayName));
65+
66+
batchPredictionJobId = response
67+
.split('/locations/us-central1/batchPredictionJobs/')[1]
68+
.split('\n')[0];
69+
});
70+
});

Diff for: ai-platform/snippets/test/batch-text-predict.test.js

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
'use strict';
18+
19+
const {assert} = require('chai');
20+
const {after, describe, it} = require('mocha');
21+
const uuid = require('uuid').v4;
22+
const cp = require('child_process');
23+
const {JobServiceClient} = require('@google-cloud/aiplatform');
24+
25+
const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'});
26+
27+
describe('Batch text predict', async () => {
28+
const displayName = `batch-text-predict-job-${uuid()}`;
29+
const location = 'us-central1';
30+
const inputUri =
31+
'gs://ucaip-samples-test-output/inputs/batch_predict_TCN/tcn_inputs.jsonl';
32+
const outputUri = 'gs://ucaip-samples-test-output/';
33+
const jobServiceClient = new JobServiceClient({
34+
apiEndpoint: `${location}-aiplatform.googleapis.com`,
35+
});
36+
const projectId = process.env.CAIP_PROJECT_ID;
37+
let batchPredictionJobId;
38+
39+
after(async () => {
40+
const name = jobServiceClient.batchPredictionJobPath(
41+
projectId,
42+
location,
43+
batchPredictionJobId
44+
);
45+
46+
const cancelRequest = {
47+
name,
48+
};
49+
50+
jobServiceClient.cancelBatchPredictionJob(cancelRequest).then(() => {
51+
const deleteRequest = {
52+
name,
53+
};
54+
55+
return jobServiceClient.deleteBatchPredictionJob(deleteRequest);
56+
});
57+
});
58+
59+
it('should create job with text prediction', async () => {
60+
const response = execSync(
61+
`node ./batch-text-predict.js ${projectId} ${inputUri} ${outputUri} ${displayName}`
62+
);
63+
64+
assert.match(response, new RegExp(displayName));
65+
batchPredictionJobId = response
66+
.split('/locations/us-central1/batchPredictionJobs/')[1]
67+
.split('\n')[0];
68+
});
69+
});

0 commit comments

Comments
 (0)