Skip to content

Commit 7bc7f36

Browse files
authored
feat: add DataFrame.ai.forecast() support (#1828)
* feat: add DataFrame.ai.forecast() support * test * fix * constructor * update * comments * fix
1 parent eef158b commit 7bc7f36

File tree

6 files changed

+245
-49
lines changed

6 files changed

+245
-49
lines changed

bigframes/ml/core.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,17 @@ class BaseBqml:
3535

3636
def __init__(self, session: bigframes.session.Session):
3737
self._session = session
38-
self._base_sql_generator = ml_sql.BaseSqlGenerator()
38+
self._sql_generator = ml_sql.BaseSqlGenerator()
39+
40+
def ai_forecast(
41+
self,
42+
input_data: bpd.DataFrame,
43+
options: Mapping[str, Union[str, int, float, Iterable[str]]],
44+
) -> bpd.DataFrame:
45+
result_sql = self._sql_generator.ai_forecast(
46+
source_sql=input_data.sql, options=options
47+
)
48+
return self._session.read_gbq(result_sql)
3949

4050

4151
class BqmlModel(BaseBqml):
@@ -55,8 +65,8 @@ def __init__(self, session: bigframes.Session, model: bigquery.Model):
5565
self._model = model
5666
model_ref = self._model.reference
5767
assert model_ref is not None
58-
self._model_manipulation_sql_generator = ml_sql.ModelManipulationSqlGenerator(
59-
model_ref
68+
self._sql_generator: ml_sql.ModelManipulationSqlGenerator = (
69+
ml_sql.ModelManipulationSqlGenerator(model_ref)
6070
)
6171

6272
def _apply_ml_tvf(
@@ -126,30 +136,28 @@ def model(self) -> bigquery.Model:
126136
def recommend(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
127137
return self._apply_ml_tvf(
128138
input_data,
129-
self._model_manipulation_sql_generator.ml_recommend,
139+
self._sql_generator.ml_recommend,
130140
)
131141

132142
def predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
133143
return self._apply_ml_tvf(
134144
input_data,
135-
self._model_manipulation_sql_generator.ml_predict,
145+
self._sql_generator.ml_predict,
136146
)
137147

138148
def explain_predict(
139149
self, input_data: bpd.DataFrame, options: Mapping[str, int | float]
140150
) -> bpd.DataFrame:
141151
return self._apply_ml_tvf(
142152
input_data,
143-
lambda source_sql: self._model_manipulation_sql_generator.ml_explain_predict(
153+
lambda source_sql: self._sql_generator.ml_explain_predict(
144154
source_sql=source_sql,
145155
struct_options=options,
146156
),
147157
)
148158

149159
def global_explain(self, options: Mapping[str, bool]) -> bpd.DataFrame:
150-
sql = self._model_manipulation_sql_generator.ml_global_explain(
151-
struct_options=options
152-
)
160+
sql = self._sql_generator.ml_global_explain(struct_options=options)
153161
return (
154162
self._session.read_gbq(sql)
155163
.sort_values(by="attribution", ascending=False)
@@ -159,7 +167,7 @@ def global_explain(self, options: Mapping[str, bool]) -> bpd.DataFrame:
159167
def transform(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
160168
return self._apply_ml_tvf(
161169
input_data,
162-
self._model_manipulation_sql_generator.ml_transform,
170+
self._sql_generator.ml_transform,
163171
)
164172

165173
def generate_text(
@@ -170,7 +178,7 @@ def generate_text(
170178
options["flatten_json_output"] = True
171179
return self._apply_ml_tvf(
172180
input_data,
173-
lambda source_sql: self._model_manipulation_sql_generator.ml_generate_text(
181+
lambda source_sql: self._sql_generator.ml_generate_text(
174182
source_sql=source_sql,
175183
struct_options=options,
176184
),
@@ -186,7 +194,7 @@ def generate_embedding(
186194
options["flatten_json_output"] = True
187195
return self._apply_ml_tvf(
188196
input_data,
189-
lambda source_sql: self._model_manipulation_sql_generator.ml_generate_embedding(
197+
lambda source_sql: self._sql_generator.ml_generate_embedding(
190198
source_sql=source_sql,
191199
struct_options=options,
192200
),
@@ -201,7 +209,7 @@ def generate_table(
201209
) -> bpd.DataFrame:
202210
return self._apply_ml_tvf(
203211
input_data,
204-
lambda source_sql: self._model_manipulation_sql_generator.ai_generate_table(
212+
lambda source_sql: self._sql_generator.ai_generate_table(
205213
source_sql=source_sql,
206214
struct_options=options,
207215
),
@@ -216,14 +224,14 @@ def detect_anomalies(
216224

217225
return self._apply_ml_tvf(
218226
input_data,
219-
lambda source_sql: self._model_manipulation_sql_generator.ml_detect_anomalies(
227+
lambda source_sql: self._sql_generator.ml_detect_anomalies(
220228
source_sql=source_sql,
221229
struct_options=options,
222230
),
223231
)
224232

225233
def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
226-
sql = self._model_manipulation_sql_generator.ml_forecast(struct_options=options)
234+
sql = self._sql_generator.ml_forecast(struct_options=options)
227235
timestamp_col_name = "forecast_timestamp"
228236
index_cols = [timestamp_col_name]
229237
first_col_name = self._session.read_gbq(sql).columns.values[0]
@@ -232,9 +240,7 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
232240
return self._session.read_gbq(sql, index_col=index_cols).reset_index()
233241

234242
def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
235-
sql = self._model_manipulation_sql_generator.ml_explain_forecast(
236-
struct_options=options
237-
)
243+
sql = self._sql_generator.ml_explain_forecast(struct_options=options)
238244
timestamp_col_name = "time_series_timestamp"
239245
index_cols = [timestamp_col_name]
240246
first_col_name = self._session.read_gbq(sql).columns.values[0]
@@ -243,7 +249,7 @@ def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
243249
return self._session.read_gbq(sql, index_col=index_cols).reset_index()
244250

245251
def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
246-
sql = self._model_manipulation_sql_generator.ml_evaluate(
252+
sql = self._sql_generator.ml_evaluate(
247253
input_data.sql if (input_data is not None) else None
248254
)
249255

@@ -254,28 +260,24 @@ def llm_evaluate(
254260
input_data: bpd.DataFrame,
255261
task_type: Optional[str] = None,
256262
):
257-
sql = self._model_manipulation_sql_generator.ml_llm_evaluate(
258-
input_data.sql, task_type
259-
)
263+
sql = self._sql_generator.ml_llm_evaluate(input_data.sql, task_type)
260264

261265
return self._session.read_gbq(sql)
262266

263267
def arima_evaluate(self, show_all_candidate_models: bool = False):
264-
sql = self._model_manipulation_sql_generator.ml_arima_evaluate(
265-
show_all_candidate_models
266-
)
268+
sql = self._sql_generator.ml_arima_evaluate(show_all_candidate_models)
267269

268270
return self._session.read_gbq(sql)
269271

270272
def arima_coefficients(self) -> bpd.DataFrame:
271-
sql = self._model_manipulation_sql_generator.ml_arima_coefficients()
273+
sql = self._sql_generator.ml_arima_coefficients()
272274

273275
return self._session.read_gbq(sql)
274276

275277
def centroids(self) -> bpd.DataFrame:
276278
assert self._model.model_type == "KMEANS"
277279

278-
sql = self._model_manipulation_sql_generator.ml_centroids()
280+
sql = self._sql_generator.ml_centroids()
279281

280282
return self._session.read_gbq(
281283
sql, index_col=["centroid_id", "feature"]
@@ -284,7 +286,7 @@ def centroids(self) -> bpd.DataFrame:
284286
def principal_components(self) -> bpd.DataFrame:
285287
assert self._model.model_type == "PCA"
286288

287-
sql = self._model_manipulation_sql_generator.ml_principal_components()
289+
sql = self._sql_generator.ml_principal_components()
288290

289291
return self._session.read_gbq(
290292
sql, index_col=["principal_component_id", "feature"]
@@ -293,7 +295,7 @@ def principal_components(self) -> bpd.DataFrame:
293295
def principal_component_info(self) -> bpd.DataFrame:
294296
assert self._model.model_type == "PCA"
295297

296-
sql = self._model_manipulation_sql_generator.ml_principal_component_info()
298+
sql = self._sql_generator.ml_principal_component_info()
297299

298300
return self._session.read_gbq(sql)
299301

@@ -319,7 +321,7 @@ def register(self, vertex_ai_model_id: Optional[str] = None) -> BqmlModel:
319321
# truncate as Vertex ID only accepts 63 characters, easily exceeding the limit for temp models.
320322
# The possibility of conflicts should be low.
321323
vertex_ai_model_id = vertex_ai_model_id[:63]
322-
sql = self._model_manipulation_sql_generator.alter_model(
324+
sql = self._sql_generator.alter_model(
323325
options={"vertex_ai_model_id": vertex_ai_model_id}
324326
)
325327
# Register the model and wait for it to finish

bigframes/ml/sql.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ def build_parameters(self, **kwargs: Union[str, int, float, Iterable[str]]) -> s
4949
param_strs = [f"{k}={self.encode_value(v)}" for k, v in kwargs.items()]
5050
return "\n" + INDENT_STR + f",\n{INDENT_STR}".join(param_strs)
5151

52+
def build_named_parameters(
53+
self, **kwargs: Union[str, int, float, Iterable[str]]
54+
) -> str:
55+
param_strs = [f"{k} => {self.encode_value(v)}" for k, v in kwargs.items()]
56+
return "\n" + INDENT_STR + f",\n{INDENT_STR}".join(param_strs)
57+
5258
def build_structs(self, **kwargs: Union[int, float, str, Mapping]) -> str:
5359
"""Encode a dict of values into a formatted STRUCT items for SQL"""
5460
param_strs = []
@@ -187,6 +193,17 @@ def ml_distance(
187193
https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-distance"""
188194
return f"""SELECT *, ML.DISTANCE({sql_utils.identifier(col_x)}, {sql_utils.identifier(col_y)}, '{type}') AS {sql_utils.identifier(name)} FROM ({source_sql})"""
189195

196+
def ai_forecast(
197+
self,
198+
source_sql: str,
199+
options: Mapping[str, Union[int, float, bool, Iterable[str]]],
200+
):
201+
"""Encode AI.FORECAST.
202+
https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-ai-forecast"""
203+
named_parameters_sql = self.build_named_parameters(**options)
204+
205+
return f"""SELECT * FROM AI.FORECAST(({source_sql}),{named_parameters_sql})"""
206+
190207

191208
class ModelCreationSqlGenerator(BaseSqlGenerator):
192209
"""Sql generator for creating a model entity. Model id is the standalone id without project id and dataset id."""

bigframes/operations/ai.py

Lines changed: 89 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,24 @@
1616

1717
import re
1818
import typing
19-
from typing import Dict, List, Optional, Sequence
19+
from typing import Dict, Iterable, List, Optional, Sequence, Union
2020
import warnings
2121

2222
import numpy as np
2323

24-
from bigframes import dtypes, exceptions
24+
from bigframes import dtypes, exceptions, options
2525
from bigframes.core import guid, log_adapter
2626

2727

2828
@log_adapter.class_logger
2929
class AIAccessor:
30-
def __init__(self, df) -> None:
30+
def __init__(self, df, base_bqml=None) -> None:
3131
import bigframes # Import in the function body to avoid circular imports.
3232
import bigframes.dataframe
33-
34-
if not bigframes.options.experiments.ai_operators:
35-
raise NotImplementedError()
33+
from bigframes.ml import core as ml_core
3634

3735
self._df: bigframes.dataframe.DataFrame = df
36+
self._base_bqml: ml_core.BaseBqml = base_bqml or ml_core.BaseBqml(df._session)
3837

3938
def filter(
4039
self,
@@ -89,6 +88,8 @@ def filter(
8988
ValueError: when the instruction refers to a non-existing column, or when no
9089
columns are referred to.
9190
"""
91+
if not options.experiments.ai_operators:
92+
raise NotImplementedError()
9293

9394
answer_col = "answer"
9495

@@ -181,6 +182,9 @@ def map(
181182
ValueError: when the instruction refers to a non-existing column, or when no
182183
columns are referred to.
183184
"""
185+
if not options.experiments.ai_operators:
186+
raise NotImplementedError()
187+
184188
import bigframes.dataframe
185189
import bigframes.series
186190

@@ -320,6 +324,8 @@ def classify(
320324
columns are referred to, or when the count of labels does not meet the
321325
requirement.
322326
"""
327+
if not options.experiments.ai_operators:
328+
raise NotImplementedError()
323329

324330
if len(labels) < 2 or len(labels) > 20:
325331
raise ValueError(
@@ -401,6 +407,9 @@ def join(
401407
Raises:
402408
ValueError if the amount of data that will be sent for LLM processing is larger than max_rows.
403409
"""
410+
if not options.experiments.ai_operators:
411+
raise NotImplementedError()
412+
404413
self._validate_model(model)
405414
columns = self._parse_columns(instruction)
406415

@@ -525,6 +534,8 @@ def search(
525534
ValueError: when the search_column is not found in the data frame.
526535
TypeError: when the provided model is not TextEmbeddingGenerator.
527536
"""
537+
if not options.experiments.ai_operators:
538+
raise NotImplementedError()
528539

529540
if search_column not in self._df.columns:
530541
raise ValueError(f"Column `{search_column}` not found")
@@ -640,6 +651,9 @@ def top_k(
640651
ValueError: when the instruction refers to a non-existing column, or when no
641652
columns are referred to.
642653
"""
654+
if not options.experiments.ai_operators:
655+
raise NotImplementedError()
656+
643657
import bigframes.dataframe
644658
import bigframes.series
645659

@@ -834,6 +848,8 @@ def sim_join(
834848
Raises:
835849
ValueError: when the amount of data to be processed exceeds the specified max_rows.
836850
"""
851+
if not options.experiments.ai_operators:
852+
raise NotImplementedError()
837853

838854
if left_on not in self._df.columns:
839855
raise ValueError(f"Left column {left_on} not found")
@@ -883,6 +899,73 @@ def sim_join(
883899

884900
return join_result
885901

902+
def forecast(
903+
self,
904+
timestamp_column: str,
905+
data_column: str,
906+
*,
907+
model: str = "TimesFM 2.0",
908+
id_columns: Optional[Iterable[str]] = None,
909+
horizon: int = 10,
910+
confidence_level: float = 0.95,
911+
):
912+
"""
913+
Forecast a time series over a future horizon, using Google Research's open source TimesFM (https://github.com/google-research/timesfm) model.
914+
915+
.. note::
916+
917+
This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
918+
Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
919+
and might have limited support. For more information, see the launch stage descriptions
920+
(https://cloud.google.com/products#product-launch-stages).
921+
922+
Args:
923+
timestamp_column (str):
924+
A str value that specifies the name of the time points column.
925+
The time points column provides the time points used to generate the forecast.
926+
The time points column must use one of the following data types: TIMESTAMP, DATE and DATETIME
927+
data_column (str):
928+
A str value that specifies the name of the data column. The data column contains the data to forecast.
929+
The data column must use one of the following data types: INT64, NUMERIC and FLOAT64
930+
model (str, default "TimesFM 2.0"):
931+
A str value that specifies the name of the model. TimesFM 2.0 is the only supported value, and is the default value.
932+
id_columns (Iterable[str] or None, default None):
933+
An iterable of str values that specifies the names of one or more ID columns. Each ID identifies a unique time series to forecast.
934+
Specify one or more values for this argument in order to forecast multiple time series using a single query.
935+
The columns that you specify must use one of the following data types: STRING, INT64, ARRAY<STRING> and ARRAY<INT64>
936+
horizon (int, default 10):
937+
An int value that specifies the number of time points to forecast. The default value is 10. The valid input range is [1, 10,000].
938+
confidence_level (float, default 0.95):
939+
A FLOAT64 value that specifies the percentage of the future values that fall in the prediction interval.
940+
The default value is 0.95. The valid input range is [0, 1).
941+
942+
Returns:
943+
DataFrame:
944+
The forecast dataframe matches that of the BigQuery AI.FORECAST function.
945+
See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-ai-forecast
946+
947+
Raises:
948+
ValueError: when referring to a non-existing column.
949+
"""
950+
columns = [timestamp_column, data_column]
951+
if id_columns:
952+
columns += id_columns
953+
for column in columns:
954+
if column not in self._df.columns:
955+
raise ValueError(f"Column `{column}` not found")
956+
957+
options: dict[str, Union[int, float, str, Iterable[str]]] = {
958+
"data_col": data_column,
959+
"timestamp_col": timestamp_column,
960+
"model": model,
961+
"horizon": horizon,
962+
"confidence_level": confidence_level,
963+
}
964+
if id_columns:
965+
options["id_cols"] = id_columns
966+
967+
return self._base_bqml.ai_forecast(input_data=self._df, options=options)
968+
886969
@staticmethod
887970
def _attach_embedding(dataframe, source_column: str, embedding_column: str, model):
888971
result_df = dataframe.copy()

0 commit comments

Comments
 (0)