2525import dataclasses
2626import getpass
2727import os
28- import sys
2928import uuid
3029
3130from argparse import Namespace
3231
33- BQ_WRITER_PATH = "/benchmark-automation/benchmark_db_writer/src"
3432temp_dir = gettempdir ()
3533DEFAULT_LOCAL_DIR = os .path .join (temp_dir , "" )
36- # bq_writer_repo_root = get_bq_writer_path(DEFAULT_LOCAL_DIR)
3734
3835DEFAULT_TUNING_PARAMS_FILE = os .path .join (temp_dir , "tuning_params.json" )
3936
@@ -114,7 +111,6 @@ def write_run(
114111 dataset: The dataset used in the run.
115112 num_of_superblock: The number of superblocks in the hardware. ( valid for GPUs)
116113 update_person_ldap: The LDAP ID of the person updating the record (default: current user).
117- is_test: Whether to use the testing project or the production project.
118114 metrics: Metrics object containing:
119115 median_step_time: The median step time of the run.
120116 e2e_step_time: The end-to-end time of the run.
@@ -134,25 +130,20 @@ def write_run(
134130 Raises:
135131 ValueError: If any of the IDs are invalid.
136132 """
137- bq_writer_repo_root = BQ_WRITER_PATH
138- sys .path .append (bq_writer_repo_root )
139-
140133 # pylint: disable=import-outside-toplevel
141134
142- from benchmark_db_writer import bq_writer_utils
143- from benchmark_db_writer import dataclass_bigquery_writer
144- from benchmark_db_writer .run_summary_writer import sample_run_summary_writer
145- from benchmark_db_writer .schema .workload_benchmark_v2 import workload_benchmark_v2_schema
135+ from benchmarks .benchmark_db_writer import bq_writer_utils
136+ from benchmarks .benchmark_db_writer import dataclass_bigquery_writer
137+ from benchmarks .benchmark_db_writer .schema .workload_benchmark_v2 import workload_benchmark_v2_schema
146138
147139 def get_db_client (
148- project : str , dataset : str , table : str , dataclass_type : Type , is_test : bool = False
140+ project : str , dataset : str , table : str , dataclass_type : Type
149141 ) -> dataclass_bigquery_writer .DataclassBigQueryWriter :
150142 """Creates a BigQuery client object.
151143
152144 Args:
153145 table: The name of the BigQuery table.
154146 dataclass_type: The dataclass type corresponding to the table schema.
155- is_test: Whether to use the testing project or the production project.
156147
157148 Returns:
158149 A BigQuery client object.
@@ -167,53 +158,45 @@ def get_db_client(
167158
168159 print (options .model_id )
169160
170- if (
171- sample_run_summary_writer .validate_model_id (options .model_id , options .is_test )
172- and sample_run_summary_writer .validate_hardware_id (options .hardware_id , options .is_test )
173- and sample_run_summary_writer .validate_software_id (options .software_id , options .is_test )
174- ):
175- summary = workload_benchmark_v2_schema .WorkloadBenchmarkV2Schema (
176- run_id = f"run-{ uuid .uuid4 ()} " ,
177- model_id = options .model_id ,
178- software_id = options .software_id ,
179- hardware_id = options .hardware_id ,
180- hardware_num_chips = number_of_chips ,
181- hardware_num_nodes = number_of_nodes ,
182- result_success = run_success ,
183- configs_framework = framework_config_in_json ,
184- configs_env = env_variables ,
185- configs_container_version = options .container_image_name ,
186- configs_xla_flags = options .xla_flags .replace ("," , " " ),
187- configs_dataset = options .dataset ,
188- logs_artifact_directory = "" ,
189- update_person_ldap = getpass .getuser (),
190- run_source = "automation" ,
191- run_type = options .run_type ,
192- run_release_status = run_release_status ,
193- workload_precision = options .precision ,
194- workload_gbs = int (options .global_batch_size ),
195- workload_optimizer = options .optimizer ,
196- workload_sequence_length = int (options .seq_length ),
197- metrics_e2e_time = metrics .e2e_step_time ,
198- metrics_mfu = mfu ,
199- metrics_step_time = metrics .median_step_time ,
200- metrics_tokens_per_second = metrics .avg_tokens_per_sec ,
201- metrics_num_steps = number_of_steps ,
202- metrics_other = other_metrics_in_json ,
203- hardware_nccl_driver_nickname = nccl_driver_nickname ,
204- hardware_topology = options .topology ,
205- hardware_num_superblocks = 0 ,
206- logs_comments = comment ,
207- )
208-
209- client = get_db_client (
210- options .db_project ,
211- options .db_dataset ,
212- "run_summary" ,
213- workload_benchmark_v2_schema .WorkloadBenchmarkV2Schema ,
214- options .is_test ,
215- )
216- client .write ([summary ])
217-
218- else :
219- raise ValueError ("Could not upload data in run summary table" )
161+ summary = workload_benchmark_v2_schema .WorkloadBenchmarkV2Schema (
162+ run_id = f"run-{ uuid .uuid4 ()} " ,
163+ model_id = options .model_id ,
164+ software_id = options .software_id ,
165+ hardware_id = options .hardware_id ,
166+ hardware_num_chips = number_of_chips ,
167+ hardware_num_nodes = number_of_nodes ,
168+ hardware_num_slices = options .hardware_num_slices ,
169+ result_success = run_success ,
170+ configs_framework = framework_config_in_json ,
171+ configs_env = env_variables ,
172+ configs_container_version = options .container_image_name ,
173+ configs_xla_flags = options .xla_flags .replace ("," , " " ),
174+ configs_dataset = options .dataset ,
175+ logs_artifact_directory = "" ,
176+ update_person_ldap = getpass .getuser (),
177+ run_source = "automation" ,
178+ run_type = options .run_type ,
179+ run_release_status = run_release_status ,
180+ workload_precision = options .precision ,
181+ workload_gbs = int (options .global_batch_size ),
182+ workload_optimizer = options .optimizer ,
183+ workload_sequence_length = int (options .seq_length ),
184+ metrics_e2e_time = metrics .e2e_step_time ,
185+ metrics_mfu = mfu ,
186+ metrics_step_time = metrics .median_step_time ,
187+ metrics_tokens_per_second = metrics .avg_tokens_per_sec ,
188+ metrics_num_steps = number_of_steps ,
189+ metrics_other = other_metrics_in_json ,
190+ hardware_nccl_driver_nickname = nccl_driver_nickname ,
191+ hardware_topology = options .topology ,
192+ hardware_num_superblocks = 0 ,
193+ logs_comments = comment ,
194+ )
195+
196+ client = get_db_client (
197+ options .db_project ,
198+ options .db_dataset ,
199+ "run_summary" ,
200+ workload_benchmark_v2_schema .WorkloadBenchmarkV2Schema ,
201+ )
202+ client .write ([summary ])
0 commit comments