@@ -35,7 +35,17 @@ class BaseBqml:
35
35
36
36
def __init__ (self , session : bigframes .session .Session ):
37
37
self ._session = session
38
- self ._base_sql_generator = ml_sql .BaseSqlGenerator ()
38
+ self ._sql_generator = ml_sql .BaseSqlGenerator ()
39
+
40
+ def ai_forecast (
41
+ self ,
42
+ input_data : bpd .DataFrame ,
43
+ options : Mapping [str , Union [str , int , float , Iterable [str ]]],
44
+ ) -> bpd .DataFrame :
45
+ result_sql = self ._sql_generator .ai_forecast (
46
+ source_sql = input_data .sql , options = options
47
+ )
48
+ return self ._session .read_gbq (result_sql )
39
49
40
50
41
51
class BqmlModel (BaseBqml ):
@@ -55,8 +65,8 @@ def __init__(self, session: bigframes.Session, model: bigquery.Model):
55
65
self ._model = model
56
66
model_ref = self ._model .reference
57
67
assert model_ref is not None
58
- self ._model_manipulation_sql_generator = ml_sql .ModelManipulationSqlGenerator (
59
- model_ref
68
+ self ._sql_generator : ml_sql .ModelManipulationSqlGenerator = (
69
+ ml_sql . ModelManipulationSqlGenerator ( model_ref )
60
70
)
61
71
62
72
def _apply_ml_tvf (
@@ -126,30 +136,28 @@ def model(self) -> bigquery.Model:
126
136
def recommend (self , input_data : bpd .DataFrame ) -> bpd .DataFrame :
127
137
return self ._apply_ml_tvf (
128
138
input_data ,
129
- self ._model_manipulation_sql_generator .ml_recommend ,
139
+ self ._sql_generator .ml_recommend ,
130
140
)
131
141
132
142
def predict (self , input_data : bpd .DataFrame ) -> bpd .DataFrame :
133
143
return self ._apply_ml_tvf (
134
144
input_data ,
135
- self ._model_manipulation_sql_generator .ml_predict ,
145
+ self ._sql_generator .ml_predict ,
136
146
)
137
147
138
148
def explain_predict (
139
149
self , input_data : bpd .DataFrame , options : Mapping [str , int | float ]
140
150
) -> bpd .DataFrame :
141
151
return self ._apply_ml_tvf (
142
152
input_data ,
143
- lambda source_sql : self ._model_manipulation_sql_generator .ml_explain_predict (
153
+ lambda source_sql : self ._sql_generator .ml_explain_predict (
144
154
source_sql = source_sql ,
145
155
struct_options = options ,
146
156
),
147
157
)
148
158
149
159
def global_explain (self , options : Mapping [str , bool ]) -> bpd .DataFrame :
150
- sql = self ._model_manipulation_sql_generator .ml_global_explain (
151
- struct_options = options
152
- )
160
+ sql = self ._sql_generator .ml_global_explain (struct_options = options )
153
161
return (
154
162
self ._session .read_gbq (sql )
155
163
.sort_values (by = "attribution" , ascending = False )
@@ -159,7 +167,7 @@ def global_explain(self, options: Mapping[str, bool]) -> bpd.DataFrame:
159
167
def transform (self , input_data : bpd .DataFrame ) -> bpd .DataFrame :
160
168
return self ._apply_ml_tvf (
161
169
input_data ,
162
- self ._model_manipulation_sql_generator .ml_transform ,
170
+ self ._sql_generator .ml_transform ,
163
171
)
164
172
165
173
def generate_text (
@@ -170,7 +178,7 @@ def generate_text(
170
178
options ["flatten_json_output" ] = True
171
179
return self ._apply_ml_tvf (
172
180
input_data ,
173
- lambda source_sql : self ._model_manipulation_sql_generator .ml_generate_text (
181
+ lambda source_sql : self ._sql_generator .ml_generate_text (
174
182
source_sql = source_sql ,
175
183
struct_options = options ,
176
184
),
@@ -186,7 +194,7 @@ def generate_embedding(
186
194
options ["flatten_json_output" ] = True
187
195
return self ._apply_ml_tvf (
188
196
input_data ,
189
- lambda source_sql : self ._model_manipulation_sql_generator .ml_generate_embedding (
197
+ lambda source_sql : self ._sql_generator .ml_generate_embedding (
190
198
source_sql = source_sql ,
191
199
struct_options = options ,
192
200
),
@@ -201,7 +209,7 @@ def generate_table(
201
209
) -> bpd .DataFrame :
202
210
return self ._apply_ml_tvf (
203
211
input_data ,
204
- lambda source_sql : self ._model_manipulation_sql_generator .ai_generate_table (
212
+ lambda source_sql : self ._sql_generator .ai_generate_table (
205
213
source_sql = source_sql ,
206
214
struct_options = options ,
207
215
),
@@ -216,14 +224,14 @@ def detect_anomalies(
216
224
217
225
return self ._apply_ml_tvf (
218
226
input_data ,
219
- lambda source_sql : self ._model_manipulation_sql_generator .ml_detect_anomalies (
227
+ lambda source_sql : self ._sql_generator .ml_detect_anomalies (
220
228
source_sql = source_sql ,
221
229
struct_options = options ,
222
230
),
223
231
)
224
232
225
233
def forecast (self , options : Mapping [str , int | float ]) -> bpd .DataFrame :
226
- sql = self ._model_manipulation_sql_generator .ml_forecast (struct_options = options )
234
+ sql = self ._sql_generator .ml_forecast (struct_options = options )
227
235
timestamp_col_name = "forecast_timestamp"
228
236
index_cols = [timestamp_col_name ]
229
237
first_col_name = self ._session .read_gbq (sql ).columns .values [0 ]
@@ -232,9 +240,7 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
232
240
return self ._session .read_gbq (sql , index_col = index_cols ).reset_index ()
233
241
234
242
def explain_forecast (self , options : Mapping [str , int | float ]) -> bpd .DataFrame :
235
- sql = self ._model_manipulation_sql_generator .ml_explain_forecast (
236
- struct_options = options
237
- )
243
+ sql = self ._sql_generator .ml_explain_forecast (struct_options = options )
238
244
timestamp_col_name = "time_series_timestamp"
239
245
index_cols = [timestamp_col_name ]
240
246
first_col_name = self ._session .read_gbq (sql ).columns .values [0 ]
@@ -243,7 +249,7 @@ def explain_forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
243
249
return self ._session .read_gbq (sql , index_col = index_cols ).reset_index ()
244
250
245
251
def evaluate (self , input_data : Optional [bpd .DataFrame ] = None ):
246
- sql = self ._model_manipulation_sql_generator .ml_evaluate (
252
+ sql = self ._sql_generator .ml_evaluate (
247
253
input_data .sql if (input_data is not None ) else None
248
254
)
249
255
@@ -254,28 +260,24 @@ def llm_evaluate(
254
260
input_data : bpd .DataFrame ,
255
261
task_type : Optional [str ] = None ,
256
262
):
257
- sql = self ._model_manipulation_sql_generator .ml_llm_evaluate (
258
- input_data .sql , task_type
259
- )
263
+ sql = self ._sql_generator .ml_llm_evaluate (input_data .sql , task_type )
260
264
261
265
return self ._session .read_gbq (sql )
262
266
263
267
def arima_evaluate (self , show_all_candidate_models : bool = False ):
264
- sql = self ._model_manipulation_sql_generator .ml_arima_evaluate (
265
- show_all_candidate_models
266
- )
268
+ sql = self ._sql_generator .ml_arima_evaluate (show_all_candidate_models )
267
269
268
270
return self ._session .read_gbq (sql )
269
271
270
272
def arima_coefficients (self ) -> bpd .DataFrame :
271
- sql = self ._model_manipulation_sql_generator .ml_arima_coefficients ()
273
+ sql = self ._sql_generator .ml_arima_coefficients ()
272
274
273
275
return self ._session .read_gbq (sql )
274
276
275
277
def centroids (self ) -> bpd .DataFrame :
276
278
assert self ._model .model_type == "KMEANS"
277
279
278
- sql = self ._model_manipulation_sql_generator .ml_centroids ()
280
+ sql = self ._sql_generator .ml_centroids ()
279
281
280
282
return self ._session .read_gbq (
281
283
sql , index_col = ["centroid_id" , "feature" ]
@@ -284,7 +286,7 @@ def centroids(self) -> bpd.DataFrame:
284
286
def principal_components (self ) -> bpd .DataFrame :
285
287
assert self ._model .model_type == "PCA"
286
288
287
- sql = self ._model_manipulation_sql_generator .ml_principal_components ()
289
+ sql = self ._sql_generator .ml_principal_components ()
288
290
289
291
return self ._session .read_gbq (
290
292
sql , index_col = ["principal_component_id" , "feature" ]
@@ -293,7 +295,7 @@ def principal_components(self) -> bpd.DataFrame:
293
295
def principal_component_info (self ) -> bpd .DataFrame :
294
296
assert self ._model .model_type == "PCA"
295
297
296
- sql = self ._model_manipulation_sql_generator .ml_principal_component_info ()
298
+ sql = self ._sql_generator .ml_principal_component_info ()
297
299
298
300
return self ._session .read_gbq (sql )
299
301
@@ -319,7 +321,7 @@ def register(self, vertex_ai_model_id: Optional[str] = None) -> BqmlModel:
319
321
# truncate as Vertex ID only accepts 63 characters, easily exceeding the limit for temp models.
320
322
# The possibility of conflicts should be low.
321
323
vertex_ai_model_id = vertex_ai_model_id [:63 ]
322
- sql = self ._model_manipulation_sql_generator .alter_model (
324
+ sql = self ._sql_generator .alter_model (
323
325
options = {"vertex_ai_model_id" : vertex_ai_model_id }
324
326
)
325
327
# Register the model and wait it to finish
0 commit comments