Skip to content
Open
7 changes: 5 additions & 2 deletions packages/bigframes/bigframes/bigquery/_operations/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import bigframes.dataframe as dataframe
import bigframes.ml.base
import bigframes.session
import bigframes.core.expression as ex
from bigframes.bigquery._operations import utils


Expand All @@ -50,7 +51,9 @@ def create_model(
input_schema: Optional[Mapping[str, str]] = None,
output_schema: Optional[Mapping[str, str]] = None,
connection_name: Optional[str] = None,
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
options: Optional[
Mapping[str, Union[str, int, float, bool, list, "ex.Expression"]]
] = None,
training_data: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]] = None,
custom_holiday: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]] = None,
session: Optional[bigframes.session.Session] = None,
Expand Down Expand Up @@ -78,7 +81,7 @@ def create_model(
The OUTPUT clause, which specifies the schema of the output data.
connection_name (str, optional):
The connection to use for the model.
options (Mapping[str, Union[str, int, float, bool, list]], optional):
options (Mapping[str, Union[str, int, float, bool, list, bigframes.core.expression.Expression]], optional):
The OPTIONS clause, which specifies the model options.
training_data (Union[bigframes.pandas.DataFrame, str], optional):
The query or DataFrame to use for training the model.
Expand Down
11 changes: 9 additions & 2 deletions packages/bigframes/bigframes/core/sql/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

from typing import Any, Dict, List, Mapping, Optional, Union

import bigframes.core.expression as ex
from bigframes.core.compile.sqlglot import sql as sg_sql
from bigframes.core.compile.sqlglot.expression_compiler import expression_compiler


def create_model_ddl(
Expand All @@ -28,7 +30,9 @@ def create_model_ddl(
input_schema: Optional[Mapping[str, str]] = None,
output_schema: Optional[Mapping[str, str]] = None,
connection_name: Optional[str] = None,
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
options: Optional[
Mapping[str, Union[str, int, float, bool, list, "ex.Expression"]]
] = None,
training_data: Optional[str] = None,
custom_holiday: Optional[str] = None,
) -> str:
Expand Down Expand Up @@ -70,7 +74,10 @@ def create_model_ddl(
if options:
rendered_options = []
for option_name, option_value in options.items():
if isinstance(option_value, (list, tuple)):
if isinstance(option_value, ex.Expression):
sg_expr = expression_compiler.compile_expression(option_value)
rendered_val = sg_sql.to_sql(sg_expr)
elif isinstance(option_value, (list, tuple)):
# Handle list options like model_registry="vertex_ai"
# wait, usually options are key=value.
# if value is list, it is [val1, val2]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
CREATE MODEL `my_model`
OPTIONS(l2_reg = 0.1, booster_type = 'gbtree')
AS SELECT * FROM t
14 changes: 14 additions & 0 deletions packages/bigframes/tests/unit/core/sql/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,20 @@ def test_create_model_list_option(snapshot):
snapshot.assert_match(sql, "create_model_list_option.sql")


def test_create_model_expression_option(snapshot):
import bigframes.core.expression as ex
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For improved code readability and adherence to the PEP 8 style guide, it's best to place imports at the top of the file. Please move this import to the top-level of the module.

References
  1. PEP 8, the style guide for Python code, recommends that all imports should be at the top of the file. This makes it easy to see what modules the script requires. Imports within functions are generally discouraged, except for cases like avoiding circular dependencies or for optional imports, which does not seem to be the case here. (link)


sql = bigframes.core.sql.ml.create_model_ddl(
model_name="my_model",
options={
"l2_reg": ex.ScalarConstantExpression(0.1, None),
"booster_type": "gbtree",
},
training_data="SELECT * FROM t",
)
snapshot.assert_match(sql, "create_model_expression_option.sql")


def test_evaluate_model_basic(snapshot):
sql = bigframes.core.sql.ml.evaluate(
model_name="my_project.my_dataset.my_model",
Expand Down
Loading