Skip to content

Commit

Permalink
Upsert metric for execution
Browse files Browse the repository at this point in the history
  • Loading branch information
RissyRan committed Jan 23, 2024
1 parent 39109e4 commit 5d40402
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 13 deletions.
12 changes: 9 additions & 3 deletions xlml/apis/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ def post_process(self) -> DAGNode:
A DAG node that executes the post process.
"""
with TaskGroup(group_id="post_process") as group:
process_id = metric.generate_process_id.override(retries=0)()
process_id = metric.generate_process_id.override(retries=0)(
self.task_test_config.benchmark_id
)
metric.process_metrics.override(retries=0)(
process_id,
self.task_test_config,
Expand Down Expand Up @@ -243,7 +245,9 @@ def post_process(self) -> DAGNode:
A DAG node that executes the post process.
"""
with TaskGroup(group_id="post_process") as group:
process_id = metric.generate_process_id.override(retries=0)()
process_id = metric.generate_process_id.override(retries=0)(
self.task_test_config.benchmark_id
)
metric.process_metrics.override(retries=0)(
process_id,
self.task_test_config,
Expand Down Expand Up @@ -360,7 +364,9 @@ def post_process(self) -> DAGNode:
A DAG node that executes the post process.
"""
with TaskGroup(group_id="post_process") as group:
process_id = metric.generate_process_id.override(retries=0)()
process_id = metric.generate_process_id.override(retries=0)(
self.task_test_config.benchmark_id
)
metric.process_metrics.override(retries=0)(
process_id,
self.task_test_config,
Expand Down
44 changes: 41 additions & 3 deletions xlml/utils/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ def __init__(
):
self.project = google.auth.default()[1] if project is None else project
self.database = (
metric_config.DatasetOption.BENCHMARK_DATASET.value
if database is None
else database
metric_config.DatasetOption.XLML_DATASET.value if database is None else database
)
self.client = bigquery.Client(
project=project,
Expand All @@ -110,6 +108,46 @@ def is_valid_metric(self, value: float):
invalid_values = [math.inf, -math.inf, math.nan]
return not (value in invalid_values or math.isnan(value))

def delete(self, row_ids: Iterable[str]) -> None:
  """Delete records with the given ids from all history tables.

  Note: BigQuery cannot DELETE or UPDATE rows that are still in the
  streaming buffer, which can last up to 90 minutes. Such statements fail
  with `UPDATE or DELETE statement over table would affect rows in the
  streaming buffer, which is not supported`.

  Args:
    row_ids: A list of ids whose records should be removed.

  Raises:
    RuntimeError: If any delete query fails (e.g. rows are still in the
      streaming buffer).
  """
  # Maps each history table to the column that stores the run id.
  table_index_dict = {
      self.job_history_table_id: "uuid",
      self.metric_history_table_id: "job_uuid",
      self.metadata_history_table_id: "job_uuid",
  }

  for table, index in table_index_dict.items():
    for row_id in row_ids:
      # Match rows directly on the id column. (An uncorrelated
      # `WHERE EXISTS (SELECT ...)` predicate would be true for every
      # row once any row matches, deleting the entire table.)
      # Use a query parameter rather than f-string interpolation so
      # row_id cannot inject SQL.
      query_statement = f"DELETE FROM `{table}` WHERE {index} = @row_id"
      job_config = bigquery.QueryJobConfig(
          query_parameters=[
              bigquery.ScalarQueryParameter("row_id", "STRING", row_id)
          ]
      )
      query_job = self.client.query(query_statement, job_config=job_config)

      try:
        query_job.result()
        logging.info(
            f"No matching records or successfully deleted records in {table} with id {row_id}."
        )
      except Exception as e:
        # Chain the original error so the root cause is preserved.
        raise RuntimeError(
            (
                f"Failed to delete records in {table} with id {row_id} and error: {e}."
                " Please note you cannot delete or update table in the streaming"
                " buffer period, which can last up to 90 min."
            )
        ) from e

def insert(self, test_runs: Iterable[TestRun]) -> None:
"""Insert Benchmark test runs into the table.
Expand Down
18 changes: 18 additions & 0 deletions xlml/utils/bigquery_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def setUp(self):
[self.metadata_history_row],
)
]
self.row_ids = [
"test1_id",
"test2_id",
]

@parameterized.named_parameters(
("-math.inf", -math.inf, False),
Expand All @@ -63,6 +67,20 @@ def test_is_valid_metric(self, x: float, expected_value: bool):
actual_value = bq_metric.is_valid_metric(x)
self.assertEqual(actual_value, expected_value)

@mock.patch.object(google.auth, "default", return_value=["mock", "mock_project"])
@mock.patch.object(bigquery.Client, "query")
def test_delete_failure(self, query, default):
  """delete() should surface a failing query as RuntimeError."""
  # Decorators are applied bottom-up, so the innermost patch
  # (bigquery.Client.query) is the FIRST mock argument.
  bq_metric = test_bigquery.BigQueryMetricClient()
  # Make QueryJob.result() raise so delete() takes its error path.
  # (Setting `result.raiseError.side_effect` would only create an inert
  # mock attribute and never raise.)
  query.return_value.result.side_effect = Exception("Test")
  self.assertRaises(RuntimeError, bq_metric.delete, self.row_ids)

@mock.patch.object(google.auth, "default", return_value=["mock", "mock_project"])
@mock.patch.object(bigquery.Client, "query")
def test_delete_success(self, query, default):
  """delete() should complete without raising when queries succeed."""
  # Decorators are applied bottom-up, so the innermost patch
  # (bigquery.Client.query) is the FIRST mock argument.
  bq_metric = test_bigquery.BigQueryMetricClient()
  query.return_value.result.return_value = []
  bq_metric.delete(self.row_ids)

@mock.patch.object(google.auth, "default", return_value=["mock", "mock_project"])
@mock.patch.object(bigquery.Client, "get_table", return_value="mock_table")
@mock.patch.object(bigquery.Client, "insert_rows", return_value=["there is an error"])
Expand Down
25 changes: 18 additions & 7 deletions xlml/utils/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,18 @@ def generate_row_uuid(base_id: str, index: int) -> str:


@task(trigger_rule="all_done")
def generate_process_id(benchmark_id: str) -> str:
  """Generate a process id that will be a base id for single/multiple test run(s).

  The id is deterministic for a given (benchmark_id, Airflow run_id) pair,
  which lets a re-run of the same DAG run upsert (delete + re-insert) its
  own metric rows.

  Args:
    benchmark_id: The unique key for metrics generated by the test.

  Returns:
    An id based on benchmark_id and Airflow run_id.
  """
  context = get_current_context()
  # Avoid shadowing the builtin `id`; f-strings are already str, so no
  # redundant str() wrapper is needed.
  raw_id = f"{benchmark_id}_{context['run_id']}"
  return hashlib.sha256(raw_id.encode("utf-8")).hexdigest()


def is_valid_entry() -> bool:
Expand Down Expand Up @@ -505,9 +510,12 @@ def process_metrics(
else:
test_job_status = get_gce_job_status(task_test_config.benchmark_id)

row_ids = []
for index in range(len(metadata_history_rows_list)):
current_uuid = generate_row_uuid(base_id, index)
row_ids.append(current_uuid)
job_history_row = bigquery.JobHistoryRow(
uuid=generate_row_uuid(base_id, index),
uuid=current_uuid,
timestamp=current_time,
owner=task_test_config.task_owner,
job_name=benchmark_id,
Expand All @@ -522,5 +530,8 @@ def process_metrics(

print("Test run rows:", test_run_rows)

if is_valid_entry():
bigquery_metric.insert(test_run_rows)
# if is_valid_entry():
# delete records from BigQuery tables for the same Airflow run_id (if applies),
# then, insert new records from current run.
bigquery_metric.delete(row_ids)
bigquery_metric.insert(test_run_rows)

0 comments on commit 5d40402

Please sign in to comment.