
Commit 6d3e2bb (parent: 81eca03)

Add necessary BigQuery dependencies

File tree: 13 files changed (+894, -74 lines)

benchmarks/benchmark_db_utils.py (46 additions, 63 deletions)
@@ -25,15 +25,12 @@
 import dataclasses
 import getpass
 import os
-import sys
 import uuid
 
 from argparse import Namespace
 
-BQ_WRITER_PATH = "/benchmark-automation/benchmark_db_writer/src"
 temp_dir = gettempdir()
 DEFAULT_LOCAL_DIR = os.path.join(temp_dir, "")
-# bq_writer_repo_root = get_bq_writer_path(DEFAULT_LOCAL_DIR)
 
 DEFAULT_TUNING_PARAMS_FILE = os.path.join(temp_dir, "tuning_params.json")
 
@@ -114,7 +111,6 @@ def write_run(
     dataset: The dataset used in the run.
     num_of_superblock: The number of superblocks in the hardware. (valid for GPUs)
     update_person_ldap: The LDAP ID of the person updating the record (default: current user).
-    is_test: Whether to use the testing project or the production project.
     metrics: Metrics object containing:
       median_step_time: The median step time of the run.
       e2e_step_time: The end-to-end time of the run.
@@ -134,25 +130,20 @@ def write_run(
   Raises:
     ValueError: If any of the IDs are invalid.
   """
-  bq_writer_repo_root = BQ_WRITER_PATH
-  sys.path.append(bq_writer_repo_root)
-
   # pylint: disable=import-outside-toplevel
 
-  from benchmark_db_writer import bq_writer_utils
-  from benchmark_db_writer import dataclass_bigquery_writer
-  from benchmark_db_writer.run_summary_writer import sample_run_summary_writer
-  from benchmark_db_writer.schema.workload_benchmark_v2 import workload_benchmark_v2_schema
+  from benchmarks.benchmark_db_writer import bq_writer_utils
+  from benchmarks.benchmark_db_writer import dataclass_bigquery_writer
+  from benchmarks.benchmark_db_writer.schema.workload_benchmark_v2 import workload_benchmark_v2_schema
 
   def get_db_client(
-      project: str, dataset: str, table: str, dataclass_type: Type, is_test: bool = False
+      project: str, dataset: str, table: str, dataclass_type: Type
   ) -> dataclass_bigquery_writer.DataclassBigQueryWriter:
     """Creates a BigQuery client object.
 
     Args:
       table: The name of the BigQuery table.
       dataclass_type: The dataclass type corresponding to the table schema.
-      is_test: Whether to use the testing project or the production project.
 
     Returns:
       A BigQuery client object.
@@ -167,53 +158,45 @@ def get_db_client(
 
   print(options.model_id)
 
-  if (
-      sample_run_summary_writer.validate_model_id(options.model_id, options.is_test)
-      and sample_run_summary_writer.validate_hardware_id(options.hardware_id, options.is_test)
-      and sample_run_summary_writer.validate_software_id(options.software_id, options.is_test)
-  ):
-    summary = workload_benchmark_v2_schema.WorkloadBenchmarkV2Schema(
-        run_id=f"run-{uuid.uuid4()}",
-        model_id=options.model_id,
-        software_id=options.software_id,
-        hardware_id=options.hardware_id,
-        hardware_num_chips=number_of_chips,
-        hardware_num_nodes=number_of_nodes,
-        result_success=run_success,
-        configs_framework=framework_config_in_json,
-        configs_env=env_variables,
-        configs_container_version=options.container_image_name,
-        configs_xla_flags=options.xla_flags.replace(",", " "),
-        configs_dataset=options.dataset,
-        logs_artifact_directory="",
-        update_person_ldap=getpass.getuser(),
-        run_source="automation",
-        run_type=options.run_type,
-        run_release_status=run_release_status,
-        workload_precision=options.precision,
-        workload_gbs=int(options.global_batch_size),
-        workload_optimizer=options.optimizer,
-        workload_sequence_length=int(options.seq_length),
-        metrics_e2e_time=metrics.e2e_step_time,
-        metrics_mfu=mfu,
-        metrics_step_time=metrics.median_step_time,
-        metrics_tokens_per_second=metrics.avg_tokens_per_sec,
-        metrics_num_steps=number_of_steps,
-        metrics_other=other_metrics_in_json,
-        hardware_nccl_driver_nickname=nccl_driver_nickname,
-        hardware_topology=options.topology,
-        hardware_num_superblocks=0,
-        logs_comments=comment,
-    )
-
-    client = get_db_client(
-        options.db_project,
-        options.db_dataset,
-        "run_summary",
-        workload_benchmark_v2_schema.WorkloadBenchmarkV2Schema,
-        options.is_test,
-    )
-    client.write([summary])
-
-  else:
-    raise ValueError("Could not upload data in run summary table")
+  summary = workload_benchmark_v2_schema.WorkloadBenchmarkV2Schema(
+      run_id=f"run-{uuid.uuid4()}",
+      model_id=options.model_id,
+      software_id=options.software_id,
+      hardware_id=options.hardware_id,
+      hardware_num_chips=number_of_chips,
+      hardware_num_nodes=number_of_nodes,
+      hardware_num_slices=options.hardware_num_slices,
+      result_success=run_success,
+      configs_framework=framework_config_in_json,
+      configs_env=env_variables,
+      configs_container_version=options.container_image_name,
+      configs_xla_flags=options.xla_flags.replace(",", " "),
+      configs_dataset=options.dataset,
+      logs_artifact_directory="",
+      update_person_ldap=getpass.getuser(),
+      run_source="automation",
+      run_type=options.run_type,
+      run_release_status=run_release_status,
+      workload_precision=options.precision,
+      workload_gbs=int(options.global_batch_size),
+      workload_optimizer=options.optimizer,
+      workload_sequence_length=int(options.seq_length),
+      metrics_e2e_time=metrics.e2e_step_time,
+      metrics_mfu=mfu,
+      metrics_step_time=metrics.median_step_time,
+      metrics_tokens_per_second=metrics.avg_tokens_per_sec,
+      metrics_num_steps=number_of_steps,
+      metrics_other=other_metrics_in_json,
+      hardware_nccl_driver_nickname=nccl_driver_nickname,
+      hardware_topology=options.topology,
+      hardware_num_superblocks=0,
+      logs_comments=comment,
+  )
+
+  client = get_db_client(
+      options.db_project,
+      options.db_dataset,
+      "run_summary",
+      workload_benchmark_v2_schema.WorkloadBenchmarkV2Schema,
+  )
+  client.write([summary])
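
In short, this diff drops the sys.path hack and the validate-then-write gate: write_run now imports the vendored benchmarks.benchmark_db_writer package directly, records the new hardware_num_slices field, and writes unconditionally to whatever project and dataset the options carry. A minimal sketch of the resulting upload flow, for orientation only; it is not code from the commit, upload_run_summary is a hypothetical wrapper, and it assumes the imports and the nested get_db_client helper from the hunks above are in scope:

    # Hypothetical sketch of the new write path (not from the commit).
    def upload_run_summary(options, summary):
      # No is_test switch and no validate_*_id pre-checks any more: rows go
      # straight to the project/dataset named in options.
      client = get_db_client(
          options.db_project,
          options.db_dataset,
          "run_summary",
          workload_benchmark_v2_schema.WorkloadBenchmarkV2Schema,
      )
      client.write([summary])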
benchmarks/benchmark_db_writer/bigquery_types.py (new file, 86 additions)
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module defines enumerations for BigQuery data types (e.g., `STRING`,
`INT64`) and field modes (e.g., `NULLABLE`, `REQUIRED`).

It also defines a primary mapping, `TypeMapping`, which translates these
BigQuery types into their corresponding standard Python types (like `str`,
`int`, `datetime.datetime`). Custom types (`TimeStamp`, `Geography`) are
included for specific BQ types not perfectly represented by Python built-ins.

Copied & Modified from
https://github.com/AI-Hypercomputer/aotc/blob/main/src/aotc/benchmark_db_writer/src/benchmark_db_writer/bigquery_types.py
"""
import datetime
import decimal
import enum
from typing import Dict, NewType, Type


class BigQueryFieldModes(str, enum.Enum):
  """Enums for BigQueryFieldModes"""

  NULLABLE = "NULLABLE"
  REQUIRED = "REQUIRED"
  REPEATED = "REPEATED"


class BigQueryTypes(str, enum.Enum):
  """Enums for BigQueryTypes"""

  STRING = "STRING"
  BYTES = "BYTES"
  INTEGER = "INT64"
  INT64 = "INT64"
  FLOAT64 = "FLOAT64"
  FLOAT = "FLOAT64"
  NUMERIC = "NUMERIC"
  BOOL = "BOOL"
  BOOLEAN = "BOOL"
  STRUCT = "STRUCT"
  RECORD = "STRUCT"
  TIMESTAMP = "TIMESTAMP"
  DATE = "DATE"
  TIME = "TIME"
  DATETIME = "DATETIME"
  GEOGRAPHY = "GEOGRAPHY"
  JSON = "JSON"


Geography = NewType("Geography", str)


class TimeStamp(datetime.datetime):
  pass


TypeMapping: Dict[BigQueryTypes, Type] = {
    BigQueryTypes.STRING: str,
    BigQueryTypes.BYTES: bytes,
    BigQueryTypes.INT64: int,
    BigQueryTypes.FLOAT64: float,
    BigQueryTypes.NUMERIC: decimal.Decimal,
    BigQueryTypes.BOOL: bool,
    BigQueryTypes.TIMESTAMP: TimeStamp,
    BigQueryTypes.DATE: datetime.date,
    BigQueryTypes.TIME: datetime.time,
    BigQueryTypes.DATETIME: datetime.datetime,
    BigQueryTypes.GEOGRAPHY: Geography,
    BigQueryTypes.JSON: dict,
}
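
A quick illustrative check of how these definitions behave (not part of the commit; the import path is an assumption inferred from this commit's other imports):

    # Hypothetical usage sketch; the module path is assumed.
    from benchmarks.benchmark_db_writer.bigquery_types import (
        BigQueryTypes,
        TimeStamp,
        TypeMapping,
    )

    # Duplicate enum values make INT64 an alias of INTEGER, so both names
    # resolve to the same member and therefore to the same Python type.
    assert BigQueryTypes.INT64 is BigQueryTypes.INTEGER
    assert TypeMapping[BigQueryTypes.INTEGER] is int

    # The str mixin lets members compare (and hash) like their raw string
    # values, which is convenient when a schema arrives as plain JSON text.
    assert BigQueryTypes("TIMESTAMP") == "TIMESTAMP"
    assert TypeMapping[BigQueryTypes("TIMESTAMP")] is TimeStamp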
benchmarks/benchmark_db_writer/bq_writer_utils.py (new file, 57 additions)
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Utilities and factory functions for creating BigQuery writer clients.

This module provides helper functions to simplify the instantiation of the
`DataclassBigQueryWriter`. It centralizes the configuration, such as
project and dataset IDs, making it easier to create database clients
for specific tables.

Copied & Modified from
https://github.com/AI-Hypercomputer/aotc/blob/main/src/aotc/benchmark_db_writer/src/benchmark_db_writer/bigquery_types.py
"""
from typing import Type

from benchmarks.benchmark_db_writer import dataclass_bigquery_writer


def create_bq_writer_object(project, dataset, table, dataclass_type):
  """Creates a BQ writer config and uses it to create a BQ writer object."""
  config = dataclass_bigquery_writer.BigqueryWriterConfig(project, dataset, table)
  writer = dataclass_bigquery_writer.DataclassBigQueryWriter(dataclass_type, config)
  return writer


def get_db_client(
    table: str, dataclass_type: Type
) -> dataclass_bigquery_writer.DataclassBigQueryWriter:
  """Creates a BigQuery client object.

  Args:
    table: The name of the BigQuery table.
    dataclass_type: The dataclass type corresponding to the table schema.

  Returns:
    A BigQuery client object.
  """
  project = "ml-workload-benchmarks"
  dataset = "benchmark_dataset_v2"
  return create_bq_writer_object(
      project=project,
      dataset=dataset,
      table=table,
      dataclass_type=dataclass_type,
  )
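
For reference, a hedged usage sketch (not part of the commit; the module path is inferred from this commit's import style, and the schema dataclass comes from the first diff above):

    # Hypothetical call site; import paths are assumptions.
    from benchmarks.benchmark_db_writer import bq_writer_utils
    from benchmarks.benchmark_db_writer.schema.workload_benchmark_v2 import (
        workload_benchmark_v2_schema,
    )

    # Project and dataset are hard-coded inside get_db_client
    # ("ml-workload-benchmarks" / "benchmark_dataset_v2"), so callers only
    # choose the table and its schema dataclass.
    client = bq_writer_utils.get_db_client(
        table="run_summary",
        dataclass_type=workload_benchmark_v2_schema.WorkloadBenchmarkV2Schema,
    )
    rows = []  # list of WorkloadBenchmarkV2Schema instances to upload
    client.write(rows)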
