Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions providers/databricks/docs/operators/sql.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,24 @@ The sensor executes the SQL statement supplied by the user. The only required pa

Other parameters are optional and can be found in the class documentation.

Attaching query tags
^^^^^^^^^^^^^^^^^^^^

Both ``DatabricksSqlSensor`` and ``DatabricksPartitionSensor`` support ``query_tags`` and
``include_airflow_query_tags`` to attach metadata to every query sent to Databricks.

.. code-block:: python

    sensor = DatabricksSqlSensor(
        task_id="sensor_with_tags",
        sql_warehouse_name="my_warehouse",
        sql="SELECT 1 FROM my_table WHERE status = 'ready'",
        query_tags={"team": "data-eng", "env": "prod"},  # merged with Airflow context tags
        include_airflow_query_tags=True,  # adds dag_id, task_id, run_id, try_number, map_index
    )

Set ``include_airflow_query_tags=False`` to suppress the automatic Airflow context tags.

Examples
--------
Configuring Databricks connection to be used with the Sensor.
Expand Down
8 changes: 8 additions & 0 deletions providers/databricks/docs/operators/sql_statements.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ but not limited to:
* ``catalog``
* ``schema``
* ``parameters``
* ``query_tags``
* ``include_airflow_query_tags`` - When ``True`` (default), Airflow context metadata
  (``dag_id``, ``task_id``, ``run_id``, ``try_number``, ``map_index``) is merged into
  ``query_tags`` automatically. Set to ``False`` to suppress these automatic tags.

Examples
--------
Expand Down Expand Up @@ -91,6 +95,10 @@ but not limited to:
* ``catalog``
* ``schema``
* ``parameters``
* ``query_tags``
* ``include_airflow_query_tags`` - When ``True`` (default), Airflow context metadata
  (``dag_id``, ``task_id``, ``run_id``, ``try_number``, ``map_index``) is merged into
  ``query_tags`` automatically. Set to ``False`` to suppress these automatic tags.

Examples
--------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
validate_trigger_event,
)
from airflow.providers.databricks.utils.mixins import DatabricksSQLStatementsMixin
from airflow.providers.databricks.utils.query_tags import build_query_tags, dict_to_query_tag_list
from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS

if TYPE_CHECKING:
Expand Down Expand Up @@ -1258,10 +1259,15 @@ class DatabricksSQLStatementsOperator(DatabricksSQLStatementsMixin, BaseOperator
:param do_xcom_push: Whether we should push statement_id to xcom.
:param timeout: The timeout for the Airflow task executing the SQL statement. By default a value of 3600 seconds is used.
:param deferrable: Run operator in the deferrable mode.
:param query_tags: Optional dictionary of query tags to attach to the SQL statement. Tags are
passed as the ``query_tags`` field in the Databricks Statement Execution REST API request body.
See https://docs.databricks.com/api/workspace/statementexecution/executestatement (templated)
:param include_airflow_query_tags: If True, add Airflow DAG/task/run metadata as query tags.
Defaults to True.
"""

# Used in airflow.models.BaseOperator
template_fields: Sequence[str] = ("databricks_conn_id",)
template_fields: Sequence[str] = ("databricks_conn_id", "query_tags")
template_ext: Sequence[str] = (".json-tpl",)
# Databricks brand color (blue) under white text
ui_color = "#1CB1C2"
Expand All @@ -1284,9 +1290,11 @@ def __init__(
wait_for_termination: bool = True,
timeout: float = 3600,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
query_tags: dict[str, str | None] | None = None,
include_airflow_query_tags: bool = True,
**kwargs,
) -> None:
"""Create a new ``DatabricksSubmitRunOperator``."""
"""Create a new ``DatabricksSQLStatementsOperator``."""
super().__init__(**kwargs)
self.statement = statement
self.warehouse_id = warehouse_id
Expand All @@ -1300,6 +1308,8 @@ def __init__(
self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination
self.deferrable = deferrable
self.query_tags = query_tags or {}
self.include_airflow_query_tags = include_airflow_query_tags

# This variable will be used in case our task gets killed.
self.statement_id: str | None = None
Expand All @@ -1321,6 +1331,7 @@ def _get_hook(self, caller: str) -> DatabricksHook:
)

def execute(self, context: Context):
tags = build_query_tags(context, self.query_tags, self.include_airflow_query_tags)
json = {
"statement": self.statement,
"warehouse_id": self.warehouse_id,
Expand All @@ -1332,6 +1343,8 @@ def execute(self, context: Context):
# execution state.
"wait_timeout": "0s",
}
if tags:
json["query_tags"] = dict_to_query_tag_list(tags)
self.statement_id = self._hook.post_sql_statement(json)
if self.do_xcom_push and context is not None:
context["ti"].xcom_push(key=XCOM_STATEMENT_ID_KEY, value=self.statement_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
from airflow.providers.databricks.utils.query_tags import build_query_tags

if TYPE_CHECKING:
from airflow.providers.common.compat.sdk import Context
Expand All @@ -46,24 +47,6 @@
_DISALLOWED_SQL_TOKENS = (";", "--", "/*", "*/")


def _get_airflow_query_tags(context: Context) -> dict[str, str | None]:
"""Return Airflow context metadata as a query-tags dict."""
task_instance = context.get("ti")
if task_instance is None:
return {}

def _as_str(value: Any) -> str | None:
return None if value is None else str(value)

return {
"airflow_dag_id": _as_str(task_instance.dag_id),
"airflow_task_id": _as_str(task_instance.task_id),
"airflow_run_id": _as_str(task_instance.run_id),
"airflow_try_number": _as_str(task_instance.try_number),
"airflow_map_index": _as_str(task_instance.map_index),
}


class DatabricksSqlOperator(SQLExecuteQueryOperator):
"""
Executes SQL code in a Databricks SQL endpoint or a Databricks cluster.
Expand Down Expand Up @@ -329,14 +312,7 @@ def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequen
return list(zip(descriptions, results))

def _get_query_tags(self, context: Context) -> dict[str, str | None] | None:
query_tags: dict[str, str | None] = {}

if self.include_airflow_query_tags and context is not None:
query_tags.update(_get_airflow_query_tags(context))

query_tags.update(self.query_tags)

return query_tags or None
return build_query_tags(context, self.query_tags, self.include_airflow_query_tags)

def execute(self, context: Context) -> Any:
self.get_db_hook().query_tags = self._get_query_tags(context)
Expand Down Expand Up @@ -561,14 +537,7 @@ def _create_sql_query(self) -> str:
return sql.strip()

def _get_query_tags(self, context: Context) -> dict[str, str | None] | None:
query_tags: dict[str, str | None] = {}

if self.include_airflow_query_tags and context is not None:
query_tags.update(_get_airflow_query_tags(context))

query_tags.update(self.query_tags)

return query_tags or None
return build_query_tags(context, self.query_tags, self.include_airflow_query_tags)

def execute(self, context: Context) -> Any:
self._sql = self._create_sql_query()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from airflow.providers.databricks.hooks.databricks import DatabricksHook, SQLStatementState
from airflow.providers.databricks.operators.databricks import DEFER_METHOD_NAME
from airflow.providers.databricks.utils.mixins import DatabricksSQLStatementsMixin
from airflow.providers.databricks.utils.query_tags import build_query_tags, dict_to_query_tag_list

if TYPE_CHECKING:
from airflow.providers.common.compat.sdk import Context
Expand All @@ -40,6 +41,7 @@ class DatabricksSQLStatementsSensor(DatabricksSQLStatementsMixin, BaseSensorOper
"databricks_conn_id",
"statement",
"statement_id",
"query_tags",
)
template_ext: Sequence[str] = (".json-tpl",)
ui_color = "#1CB1C2"
Expand All @@ -63,6 +65,8 @@ def __init__(
wait_for_termination: bool = True,
timeout: float = 3600,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
query_tags: dict[str, Any] | None = None,
include_airflow_query_tags: bool = True,
**kwargs,
):
# Handle the scenario where statement and statement_id are either both set or both unset
Expand Down Expand Up @@ -91,6 +95,8 @@ def __init__(
self.deferrable = deferrable
self.timeout = timeout
self.do_xcom_push = do_xcom_push
self.query_tags = query_tags or {}
self.include_airflow_query_tags = include_airflow_query_tags

@cached_property
def _hook(self):
Expand All @@ -108,6 +114,7 @@ def _get_hook(self, caller: str) -> DatabricksHook:
def execute(self, context: Context):
if not self.statement_id:
# Otherwise, we'll go ahead and "submit" the statement
tags = build_query_tags(context, self.query_tags, self.include_airflow_query_tags)
json = {
"statement": self.statement,
"warehouse_id": self.warehouse_id,
Expand All @@ -116,6 +123,8 @@ def execute(self, context: Context):
"parameters": self.parameters,
"wait_timeout": "0s",
}
if tags:
json["query_tags"] = dict_to_query_tag_list(tags)

self.statement_id = self._hook.post_sql_statement(json)
self.log.info("SQL Statement submitted with statement_id: %s", self.statement_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from airflow.providers.common.compat.sdk import AirflowException, BaseSensorOperator
from airflow.providers.common.sql.hooks.handlers import fetch_all_handler
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
from airflow.providers.databricks.utils.query_tags import build_query_tags

if TYPE_CHECKING:
from airflow.providers.common.compat.sdk import Context
Expand Down Expand Up @@ -61,6 +62,11 @@ class DatabricksPartitionSensor(BaseSensorOperator):
:param partition_operator: Optional comparison operator for partitions, such as >=.
:param handler: Handler for DbApiHook.run() to return results, defaults to fetch_all_handler
:param client_parameters: Additional parameters internal to Databricks SQL connector parameters.
:param query_tags: Optional dictionary of query tags to attach to Databricks SQL queries.
Tags are injected via the ``QUERY_TAGS`` Databricks session parameter so they appear in
``system.query.history``. (templated)
:param include_airflow_query_tags: If True, add Airflow DAG/task/run metadata as query tags.
Defaults to True.
"""

template_fields: Sequence[str] = (
Expand All @@ -70,6 +76,7 @@ class DatabricksPartitionSensor(BaseSensorOperator):
"table_name",
"partitions",
"http_headers",
"query_tags",
)

template_ext: Sequence[str] = (".sql",)
Expand All @@ -90,6 +97,8 @@ def __init__(
partition_operator: str = "=",
handler: Callable[[Any], Any] = fetch_all_handler,
client_parameters: dict[str, Any] | None = None,
query_tags: dict[str, str | None] | None = None,
include_airflow_query_tags: bool = True,
**kwargs,
) -> None:
self.databricks_conn_id = databricks_conn_id
Expand All @@ -106,6 +115,8 @@ def __init__(
self.client_parameters = client_parameters or {}
self.hook_params = kwargs.pop("hook_params", {})
self.handler = handler
self.query_tags = query_tags or {}
self.include_airflow_query_tags = include_airflow_query_tags
self.escaper = ParamEscaper()
super().__init__(**kwargs)

Expand Down Expand Up @@ -218,6 +229,9 @@ def _generate_partition_query(

def poke(self, context: Context) -> bool:
"""Check the table partitions and return the results."""
self._get_hook.query_tags = build_query_tags(
context, self.query_tags, self.include_airflow_query_tags
)
partition_result = self._check_table_partitions()
self.log.debug("Partition sensor result: %s", partition_result)
if partition_result:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from airflow.providers.common.compat.sdk import AirflowException, BaseSensorOperator
from airflow.providers.common.sql.hooks.handlers import fetch_all_handler
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
from airflow.providers.databricks.utils.query_tags import build_query_tags

if TYPE_CHECKING:
from airflow.providers.common.compat.sdk import Context
Expand Down Expand Up @@ -55,6 +56,11 @@ class DatabricksSqlSensor(BaseSensorOperator):
:param sql: SQL statement to be executed.
:param handler: Handler for DbApiHook.run() to return results, defaults to fetch_all_handler
:param client_parameters: Additional parameters internal to Databricks SQL connector parameters.
:param query_tags: Optional dictionary of query tags to attach to Databricks SQL queries.
Tags are injected via the ``QUERY_TAGS`` Databricks session parameter so they appear in
``system.query.history``. (templated)
:param include_airflow_query_tags: If True, add Airflow DAG/task/run metadata as query tags.
Defaults to True.
"""

template_fields: Sequence[str] = (
Expand All @@ -63,6 +69,7 @@ class DatabricksSqlSensor(BaseSensorOperator):
"catalog",
"schema",
"http_headers",
"query_tags",
)

template_ext: Sequence[str] = (".sql",)
Expand All @@ -81,6 +88,8 @@ def __init__(
sql: str | Iterable[str],
handler: Callable[[Any], Any] = fetch_all_handler,
client_parameters: dict[str, Any] | None = None,
query_tags: dict[str, str | None] | None = None,
include_airflow_query_tags: bool = True,
**kwargs,
) -> None:
"""Create DatabricksSqlSensor object using the specified input arguments."""
Expand All @@ -96,6 +105,8 @@ def __init__(
self.client_parameters = client_parameters or {}
self.hook_params = kwargs.pop("hook_params", {})
self.handler = handler
self.query_tags = query_tags or {}
self.include_airflow_query_tags = include_airflow_query_tags
super().__init__(**kwargs)

@cached_property
Expand Down Expand Up @@ -132,4 +143,5 @@ def _get_results(self) -> bool:

def poke(self, context: Context) -> bool:
"""Sensor poke function to get and return results from the SQL sensor."""
self.hook.query_tags = build_query_tags(context, self.query_tags, self.include_airflow_query_tags)
return self._get_results()
Loading
Loading