Skip to content

Commit

Permalink
Allow to CREATE TEMP TABLE in SQL generation when necessary. Previous…
Browse files Browse the repository at this point in the history
…ly we assume RAND() in the WITH clause behave as if they are evaluated only once, but that's not always the case. In situation when that's not true, we need to CREATE TEMP TABLE to materialize the subqueries that have volatile functions, so that the same result is used in all places.

PiperOrigin-RevId: 710112682
  • Loading branch information
tcya authored and meterstick-copybara committed Jan 3, 2025
1 parent 3a70436 commit 9691afd
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 13 deletions.
71 changes: 66 additions & 5 deletions metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,10 @@ def compute_on_beam(
# pylint: enable=g-long-lambda


def to_sql(table, split_by=None):
return lambda metric: metric.to_sql(table, split_by)
def to_sql(table, split_by=None, create_tmp_table_for_volatile_fn=None):
return lambda metric: metric.to_sql(
table, split_by, create_tmp_table_for_volatile_fn
)


# Classes we built so caching across instances can be enabled with confidence.
Expand Down Expand Up @@ -675,7 +677,25 @@ def to_series_or_number(self, df):

def compute_on_sql_sql_mode(self, table, split_by=None, execute=None):
"""Executes the query from to_sql() and process the result."""
query = self.to_sql(table, split_by)
query = self.to_sql(table, split_by, False)
# We try to avoid using CREATE TEMP TABLE when possible. It's only used when
# - the query contains RAND();
# - the execute doesn't evaluate RAND() only once in the WITH clause;
# - ALLOW_TEMP_TABLE is True.
if sql.ALLOW_TEMP_TABLE and 'RAND' in str(query):
query_with_tmp_table = self.to_sql(table, split_by, True)
if str(query) != str(
query_with_tmp_table
) and not sql.rand_run_only_once_in_with_clause(execute):
try:
execute('CREATE OR REPLACE TEMP TABLE T AS (SELECT 42 AS ans);')
sql.TEMP_TABLE_SUPPORTED = True
query = self.to_sql(table, split_by, True)
except Exception: # pylint: disable=broad-except
sql.TEMP_TABLE_SUPPORTED = False
raise NotImplementedError # pylint: disable=raise-missing-from
finally:
sql.TEMP_TABLE_SUPPORTED = None
res = execute(str(query))
extra_idx = list(utils.get_extra_idx(self, return_superset=True))
indexes = split_by + extra_idx if split_by else extra_idx
Expand All @@ -688,8 +708,38 @@ def compute_on_sql_sql_mode(self, table, split_by=None, execute=None):
res.sort_values(split_by, kind='mergesort', inplace=True)
return res

def to_sql(self, table, split_by: Optional[Union[Text, List[Text]]] = None):
"""Generates SQL query for the metric."""
def to_sql(
self,
table,
split_by: Optional[Union[Text, List[Text]]] = None,
create_tmp_table_for_volatile_fn=None,
):
"""Generates SQL query for the metric.
Args:
table: The table or subquery we want to query from.
split_by: The columns that we use to split the data.
create_tmp_table_for_volatile_fn: When generating the query, we assume
that volatile functions like RAND() in the WITH clause behave as if they
are evaluated only once. Unfortunately, not all engines conform to that.
In that case, we need to CREATE TEMP TABLE to materialize the subqueries
that have volatile functions, so that the same result is used in all
places. An example is
WITH T AS (SELECT RAND() AS r)
SELECT t1.r - t2.r AS d
FROM T t1 CROSS JOIN T t2.
If it doesn't always evaluates to 0, then this arg should be True, and
we will put all subqueries that
1) have volatile functions and
2) are referenced in the same query multiple times,
into CREATE TEMP TABLE statements.
When you use compute_on_sql or compute_on_beam, this arg is
automatically decided based on your `execute` function.
Returns:
The SQL query for the metric as a SQL instance, which is similar to a str.
Calling str() on it will get the query in string.
"""
global_filter = utils.get_global_filter(self)
indexes = sql.Columns(split_by).add(
utils.get_extra_idx(self, return_superset=True)
Expand All @@ -706,6 +756,17 @@ def to_sql(self, table, split_by: Optional[Union[Text, List[Text]]] = None):
global_filter, indexes,
sql.Filters(), with_data)
query.with_data = with_data
create_tmp_table = (
sql.ALLOW_TEMP_TABLE
if create_tmp_table_for_volatile_fn is None
else create_tmp_table_for_volatile_fn
)
if not create_tmp_table:
return query
# None means we don't know yet so we only check for False.
if sql.TEMP_TABLE_SUPPORTED is False: # pylint: disable=g-bool-id-comparison
raise NotImplementedError # to fall back to the mixed mode
with_data.temp_tables = sql.get_temp_tables(with_data)
return query

def get_sql_and_with_clause(self, table: sql.Datasource,
Expand Down
14 changes: 9 additions & 5 deletions operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2052,23 +2052,27 @@ def compute_children_sql(self,
"""The return should be similar to compute_children()."""
raise NotImplementedError

def to_sql(self, table, split_by=None):
def to_sql(self, table, split_by=None, create_tmp_table_for_volatile_fn=None):
if not isinstance(self, (Jackknife, Bootstrap)):
raise NotImplementedError
split_by = [split_by] if isinstance(split_by, str) else list(split_by or [])
# If self is not root, this function won't be called.
self._is_root_node = True
if self.has_been_preaggregated or not self.can_precompute():
if not self.where:
return super(MetricWithCI, self).to_sql(table, split_by)
return super(MetricWithCI, self).to_sql(
table, split_by, create_tmp_table_for_volatile_fn
)
table = sql.Sql(None, table, self.where)
self_no_filter = copy.deepcopy(self)
self_no_filter.where = None
return self_no_filter.to_sql(table, split_by)
return self_no_filter.to_sql(
table, split_by, create_tmp_table_for_volatile_fn
)

expanded, _ = utils.get_fully_expanded_equivalent_metric_tree(self)
if self != expanded:
return expanded.to_sql(table, split_by)
return expanded.to_sql(table, split_by, create_tmp_table_for_volatile_fn)

expanded.where = None # The filter has been taken care of in preaggregation
expanded = utils.push_filters_to_leaf(expanded)
Expand Down Expand Up @@ -2097,7 +2101,7 @@ def to_sql(self, table, split_by=None):
equiv.unit = None
else:
equiv.has_local_filter = any([l.where for l in leaf])
return equiv.to_sql(preagg, split_by)
return equiv.to_sql(preagg, split_by, create_tmp_table_for_volatile_fn)

def get_sql_and_with_clause(
self, table, split_by, global_filter, indexes, local_filter, with_data
Expand Down
82 changes: 79 additions & 3 deletions sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@


SAFE_DIVIDE = 'IF(({denom}) = 0, NULL, ({numer}) / ({denom}))'
# If to use CREATE TEMP TABLE. Setting it to False disables CREATE TEMP TABLE
# even when it's needed.
ALLOW_TEMP_TABLE = True
# If the engine supports CREATE TEMP TABLE
TEMP_TABLE_SUPPORTED = None


def is_compatible(sql0, sql1):
Expand Down Expand Up @@ -67,6 +72,48 @@ def add_suffix(alias):
return alias + '_1'


def rand_run_only_once_in_with_clause(execute):
"""Check if the RAND() is only evaluated once in the WITH clause."""
d = execute(
'''WITH T AS (SELECT RAND() AS r)
SELECT t1.r - t2.r AS d
FROM T t1 CROSS JOIN T t2'''
)
return bool(d.iloc[0, 0] == 0)


def dep_on_rand_table(query, rand_tables):
"""Returns if a SQL query depends on any stochastic table in rand_tables."""
for rand_table in rand_tables:
if re.search(r'\b%s\b' % rand_table, str(query)):
return True
return False


def get_temp_tables(with_data):
"""Gets all the subquery tables that need to be materialized."""
tmp_tables = set()
for rand_table in with_data:
query = with_data[rand_table]
if 'RAND' not in str(query):
continue
dep_on_rand = set([rand_table])
for alias in with_data:
if dep_on_rand_table(with_data[alias].from_data, dep_on_rand):
dep_on_rand.add(alias)
for t in dep_on_rand:
from_data = with_data[t].from_data
if isinstance(from_data, Join) and not t.startswith(
'BootstrapRandomChoices'
):
if dep_on_rand_table(from_data.ds1, dep_on_rand) and dep_on_rand_table(
from_data.ds2, dep_on_rand
):
tmp_tables.add(rand_table)
break
return tmp_tables


def get_alias(c):
return getattr(c, 'alias_raw', c)

Expand Down Expand Up @@ -571,6 +618,7 @@ class Datasources(SqlComponents):
def __init__(self, datasources=None):
super(Datasources, self).__init__()
self.children = collections.OrderedDict()
self.temp_tables = set()
self.add(datasources)

@property
Expand Down Expand Up @@ -663,7 +711,7 @@ def add(self, children: Union[Datasource, Iterable[Datasource]]):
return
if not isinstance(children, Datasource):
raise ValueError('Not a Datasource!')
alias, table = children.alias, children.table
alias, table = children.alias, children.table,
if alias not in self.children:
if table not in self.children.values():
self.children[alias] = table
Expand All @@ -676,6 +724,23 @@ def add(self, children: Union[Datasource, Iterable[Datasource]]):
children.alias = add_suffix(alias)
return self.add(children)

def add_temp_table(self, table: Union[str, 'Sql', Join, Datasource]):
"""Marks alias and all its data dependencies as temp tables."""
if isinstance(table, str):
self.temp_tables.add(table)
if table in self.children:
self.add_temp_table(self.children[table])
return
if isinstance(table, Join):
self.add_temp_table(table.ds1)
self.add_temp_table(table.ds2)
return
if isinstance(table, Datasource):
return self.add_temp_table(table.table)
if isinstance(table, Sql):
return self.add_temp_table(table.from_data)
return self

def extend(self, other: 'Datasources'):
"""Merge other to self. Adjust the query if a new alias is needed."""
datasources = list(other.datasources)
Expand All @@ -691,7 +756,18 @@ def extend(self, other: 'Datasources'):
return self

def __str__(self):
return ',\n'.join((d.get_expression('WITH') for d in self.datasources if d))
temp_tables = []
with_tables = []
for d in self.datasources:
expression = d.get_expression('WITH')
if d.alias in self.temp_tables:
temp_tables.append(f'CREATE OR REPLACE TEMP TABLE {expression};')
else:
with_tables.append(expression)
res = '\n'.join(temp_tables)
if with_tables:
res += '\nWITH\n' + ',\n'.join(with_tables)
return res.strip()


class Sql(SqlComponent):
Expand Down Expand Up @@ -766,7 +842,7 @@ def merge(self, other: 'Sql'):
return True

def __str__(self):
with_clause = 'WITH\n%s' % self.with_data if self.with_data else None
with_clause = str(self.with_data) if self.with_data else None
all_columns = self.all_columns or '*'
select_clause = f'SELECT\n{all_columns}'
from_clause = ('FROM %s'
Expand Down

0 comments on commit 9691afd

Please sign in to comment.