diff --git a/providers/databricks/docs/operators/sql.rst b/providers/databricks/docs/operators/sql.rst index 448fce7bbf8f0..e0840e60987e6 100644 --- a/providers/databricks/docs/operators/sql.rst +++ b/providers/databricks/docs/operators/sql.rst @@ -103,6 +103,24 @@ The sensor executes the SQL statement supplied by the user. The only required pa Other parameters are optional and could be found in the class documentation. +Attaching query tags +^^^^^^^^^^^^^^^^^^^^ + +Both ``DatabricksSqlSensor`` and ``DatabricksPartitionSensor`` support ``query_tags`` and +``include_airflow_query_tags`` to attach metadata to every query sent to Databricks. + +.. code-block:: python + + sensor = DatabricksSqlSensor( + task_id="sensor_with_tags", + sql_warehouse_name="my_warehouse", + sql="SELECT 1 FROM my_table WHERE status = 'ready'", + query_tags={"team": "data-eng", "env": "prod"}, # merged with Airflow context tags + include_airflow_query_tags=True, # adds dag_id, task_id, run_id, try_number, map_index + ) + +Set ``include_airflow_query_tags=False`` to suppress the automatic Airflow context tags. + Examples -------- Configuring Databricks connection to be used with the Sensor. diff --git a/providers/databricks/docs/operators/sql_statements.rst b/providers/databricks/docs/operators/sql_statements.rst index 9d314a238dab7..58954a06fea2e 100644 --- a/providers/databricks/docs/operators/sql_statements.rst +++ b/providers/databricks/docs/operators/sql_statements.rst @@ -45,6 +45,10 @@ but not limited to: * ``catalog`` * ``schema`` * ``parameters`` +* ``query_tags`` +* ``include_airflow_query_tags`` - When ``True`` (default), Airflow context metadata + (``dag_id``, ``task_id``, ``run_id``, ``try_number``, ``map_index``) is merged into + ``query_tags`` automatically. Set to ``False`` to suppress these automatic tags. 
Examples -------- @@ -91,6 +95,10 @@ but not limited to: * ``catalog`` * ``schema`` * ``parameters`` +* ``query_tags`` +* ``include_airflow_query_tags`` - When ``True`` (default), Airflow context metadata + (``dag_id``, ``task_id``, ``run_id``, ``try_number``, ``map_index``) is merged into + ``query_tags`` automatically. Set to ``False`` to suppress these automatic tags. Examples -------- diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py index bf743aaf5166b..0599c0f52fb4e 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py @@ -54,6 +54,7 @@ validate_trigger_event, ) from airflow.providers.databricks.utils.mixins import DatabricksSQLStatementsMixin +from airflow.providers.databricks.utils.query_tags import build_query_tags, dict_to_query_tag_list from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: @@ -1258,10 +1259,15 @@ class DatabricksSQLStatementsOperator(DatabricksSQLStatementsMixin, BaseOperator :param do_xcom_push: Whether we should push statement_id to xcom.: :param timeout: The timeout for the Airflow task executing the SQL statement. By default a value of 3600 seconds is used. :param deferrable: Run operator in the deferrable mode. + :param query_tags: Optional dictionary of query tags to attach to the SQL statement. Tags are + passed as the ``query_tags`` field in the Databricks Statement Execution REST API request body. + See https://docs.databricks.com/api/workspace/statementexecution/executestatement + :param include_airflow_query_tags: If True, add Airflow DAG/task/run metadata as query tags. + Defaults to True. 
""" # Used in airflow.models.BaseOperator - template_fields: Sequence[str] = ("databricks_conn_id",) + template_fields: Sequence[str] = ("databricks_conn_id", "query_tags") template_ext: Sequence[str] = (".json-tpl",) # Databricks brand color (blue) under white text ui_color = "#1CB1C2" @@ -1284,9 +1290,11 @@ def __init__( wait_for_termination: bool = True, timeout: float = 3600, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + query_tags: dict[str, str | None] | None = None, + include_airflow_query_tags: bool = True, **kwargs, ) -> None: - """Create a new ``DatabricksSubmitRunOperator``.""" + """Create a new ``DatabricksSQLStatementsOperator``.""" super().__init__(**kwargs) self.statement = statement self.warehouse_id = warehouse_id @@ -1300,6 +1308,8 @@ def __init__( self.databricks_retry_args = databricks_retry_args self.wait_for_termination = wait_for_termination self.deferrable = deferrable + self.query_tags = query_tags or {} + self.include_airflow_query_tags = include_airflow_query_tags # This variable will be used in case our task gets killed. self.statement_id: str | None = None @@ -1321,6 +1331,7 @@ def _get_hook(self, caller: str) -> DatabricksHook: ) def execute(self, context: Context): + tags = build_query_tags(context, self.query_tags, self.include_airflow_query_tags) json = { "statement": self.statement, "warehouse_id": self.warehouse_id, @@ -1332,6 +1343,8 @@ def execute(self, context: Context): # execution state. 
"wait_timeout": "0s", } + if tags: + json["query_tags"] = dict_to_query_tag_list(tags) self.statement_id = self._hook.post_sql_statement(json) if self.do_xcom_push and context is not None: context["ti"].xcom_push(key=XCOM_STATEMENT_ID_KEY, value=self.statement_id) diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py index f72514f2488f9..f025bc0962211 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py @@ -38,6 +38,7 @@ ) from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook +from airflow.providers.databricks.utils.query_tags import build_query_tags if TYPE_CHECKING: from airflow.providers.common.compat.sdk import Context @@ -46,24 +47,6 @@ _DISALLOWED_SQL_TOKENS = (";", "--", "/*", "*/") -def _get_airflow_query_tags(context: Context) -> dict[str, str | None]: - """Return Airflow context metadata as a query-tags dict.""" - task_instance = context.get("ti") - if task_instance is None: - return {} - - def _as_str(value: Any) -> str | None: - return None if value is None else str(value) - - return { - "airflow_dag_id": _as_str(task_instance.dag_id), - "airflow_task_id": _as_str(task_instance.task_id), - "airflow_run_id": _as_str(task_instance.run_id), - "airflow_try_number": _as_str(task_instance.try_number), - "airflow_map_index": _as_str(task_instance.map_index), - } - - class DatabricksSqlOperator(SQLExecuteQueryOperator): """ Executes SQL code in a Databricks SQL endpoint or a Databricks cluster. 
@@ -329,14 +312,7 @@ def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequen return list(zip(descriptions, results)) def _get_query_tags(self, context: Context) -> dict[str, str | None] | None: - query_tags: dict[str, str | None] = {} - - if self.include_airflow_query_tags and context is not None: - query_tags.update(_get_airflow_query_tags(context)) - - query_tags.update(self.query_tags) - - return query_tags or None + return build_query_tags(context, self.query_tags, self.include_airflow_query_tags) def execute(self, context: Context) -> Any: self.get_db_hook().query_tags = self._get_query_tags(context) @@ -561,14 +537,7 @@ def _create_sql_query(self) -> str: return sql.strip() def _get_query_tags(self, context: Context) -> dict[str, str | None] | None: - query_tags: dict[str, str | None] = {} - - if self.include_airflow_query_tags and context is not None: - query_tags.update(_get_airflow_query_tags(context)) - - query_tags.update(self.query_tags) - - return query_tags or None + return build_query_tags(context, self.query_tags, self.include_airflow_query_tags) def execute(self, context: Context) -> Any: self._sql = self._create_sql_query() diff --git a/providers/databricks/src/airflow/providers/databricks/sensors/databricks.py b/providers/databricks/src/airflow/providers/databricks/sensors/databricks.py index 74fc0c2ecefb0..5f4e64c011584 100644 --- a/providers/databricks/src/airflow/providers/databricks/sensors/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/sensors/databricks.py @@ -26,6 +26,7 @@ from airflow.providers.databricks.hooks.databricks import DatabricksHook, SQLStatementState from airflow.providers.databricks.operators.databricks import DEFER_METHOD_NAME from airflow.providers.databricks.utils.mixins import DatabricksSQLStatementsMixin +from airflow.providers.databricks.utils.query_tags import build_query_tags, dict_to_query_tag_list if TYPE_CHECKING: from airflow.providers.common.compat.sdk import 
Context @@ -40,6 +41,7 @@ class DatabricksSQLStatementsSensor(DatabricksSQLStatementsMixin, BaseSensorOper "databricks_conn_id", "statement", "statement_id", + "query_tags", ) template_ext: Sequence[str] = (".json-tpl",) ui_color = "#1CB1C2" @@ -63,6 +65,8 @@ def __init__( wait_for_termination: bool = True, timeout: float = 3600, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + query_tags: dict[str, Any] | None = None, + include_airflow_query_tags: bool = True, **kwargs, ): # Handle the scenario where either both statement and statement_id are set/not set @@ -91,6 +95,8 @@ def __init__( self.deferrable = deferrable self.timeout = timeout self.do_xcom_push = do_xcom_push + self.query_tags = query_tags or {} + self.include_airflow_query_tags = include_airflow_query_tags @cached_property def _hook(self): @@ -108,6 +114,7 @@ def _get_hook(self, caller: str) -> DatabricksHook: def execute(self, context: Context): if not self.statement_id: # Otherwise, we'll go ahead and "submit" the statement + tags = build_query_tags(context, self.query_tags, self.include_airflow_query_tags) json = { "statement": self.statement, "warehouse_id": self.warehouse_id, @@ -116,6 +123,8 @@ def execute(self, context: Context): "parameters": self.parameters, "wait_timeout": "0s", } + if tags: + json["query_tags"] = dict_to_query_tag_list(tags) self.statement_id = self._hook.post_sql_statement(json) self.log.info("SQL Statement submitted with statement_id: %s", self.statement_id) diff --git a/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py b/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py index 2036dca97c387..ad8429cea4714 100644 --- a/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py +++ b/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py @@ -30,6 +30,7 @@ from airflow.providers.common.compat.sdk import 
AirflowException, BaseSensorOperator from airflow.providers.common.sql.hooks.handlers import fetch_all_handler from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook +from airflow.providers.databricks.utils.query_tags import build_query_tags if TYPE_CHECKING: from airflow.providers.common.compat.sdk import Context @@ -61,6 +62,11 @@ class DatabricksPartitionSensor(BaseSensorOperator): :param partition_operator: Optional comparison operator for partitions, such as >=. :param handler: Handler for DbApiHook.run() to return results, defaults to fetch_all_handler :param client_parameters: Additional parameters internal to Databricks SQL connector parameters. + :param query_tags: Optional dictionary of query tags to attach to Databricks SQL queries. + Tags are injected via the ``QUERY_TAGS`` Databricks session parameter so they appear in + ``system.query.history``. (templated) + :param include_airflow_query_tags: If True, add Airflow DAG/task/run metadata as query tags. + Defaults to True. 
""" template_fields: Sequence[str] = ( @@ -70,6 +76,7 @@ class DatabricksPartitionSensor(BaseSensorOperator): "table_name", "partitions", "http_headers", + "query_tags", ) template_ext: Sequence[str] = (".sql",) @@ -90,6 +97,8 @@ def __init__( partition_operator: str = "=", handler: Callable[[Any], Any] = fetch_all_handler, client_parameters: dict[str, Any] | None = None, + query_tags: dict[str, str | None] | None = None, + include_airflow_query_tags: bool = True, **kwargs, ) -> None: self.databricks_conn_id = databricks_conn_id @@ -106,6 +115,8 @@ def __init__( self.client_parameters = client_parameters or {} self.hook_params = kwargs.pop("hook_params", {}) self.handler = handler + self.query_tags = query_tags or {} + self.include_airflow_query_tags = include_airflow_query_tags self.escaper = ParamEscaper() super().__init__(**kwargs) @@ -218,6 +229,9 @@ def _generate_partition_query( def poke(self, context: Context) -> bool: """Check the table partitions and return the results.""" + self._get_hook.query_tags = build_query_tags( + context, self.query_tags, self.include_airflow_query_tags + ) partition_result = self._check_table_partitions() self.log.debug("Partition sensor result: %s", partition_result) if partition_result: diff --git a/providers/databricks/src/airflow/providers/databricks/sensors/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/sensors/databricks_sql.py index 6ab52df67b579..90c0b8de5dd3a 100644 --- a/providers/databricks/src/airflow/providers/databricks/sensors/databricks_sql.py +++ b/providers/databricks/src/airflow/providers/databricks/sensors/databricks_sql.py @@ -27,6 +27,7 @@ from airflow.providers.common.compat.sdk import AirflowException, BaseSensorOperator from airflow.providers.common.sql.hooks.handlers import fetch_all_handler from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook +from airflow.providers.databricks.utils.query_tags import build_query_tags if TYPE_CHECKING: from 
airflow.providers.common.compat.sdk import Context @@ -55,6 +56,11 @@ class DatabricksSqlSensor(BaseSensorOperator): :param sql: SQL statement to be executed. :param handler: Handler for DbApiHook.run() to return results, defaults to fetch_all_handler :param client_parameters: Additional parameters internal to Databricks SQL connector parameters. + :param query_tags: Optional dictionary of query tags to attach to Databricks SQL queries. + Tags are injected via the ``QUERY_TAGS`` Databricks session parameter so they appear in + ``system.query.history``. (templated) + :param include_airflow_query_tags: If True, add Airflow DAG/task/run metadata as query tags. + Defaults to True. """ template_fields: Sequence[str] = ( @@ -63,6 +69,7 @@ class DatabricksSqlSensor(BaseSensorOperator): "catalog", "schema", "http_headers", + "query_tags", ) template_ext: Sequence[str] = (".sql",) @@ -81,6 +88,8 @@ def __init__( sql: str | Iterable[str], handler: Callable[[Any], Any] = fetch_all_handler, client_parameters: dict[str, Any] | None = None, + query_tags: dict[str, str | None] | None = None, + include_airflow_query_tags: bool = True, **kwargs, ) -> None: """Create DatabricksSqlSensor object using the specified input arguments.""" @@ -96,6 +105,8 @@ def __init__( self.client_parameters = client_parameters or {} self.hook_params = kwargs.pop("hook_params", {}) self.handler = handler + self.query_tags = query_tags or {} + self.include_airflow_query_tags = include_airflow_query_tags super().__init__(**kwargs) @cached_property @@ -132,4 +143,5 @@ def _get_results(self) -> bool: def poke(self, context: Context) -> bool: """Sensor poke function to get and return results from the SQL sensor.""" + self.hook.query_tags = build_query_tags(context, self.query_tags, self.include_airflow_query_tags) return self._get_results() diff --git a/providers/databricks/src/airflow/providers/databricks/utils/query_tags.py b/providers/databricks/src/airflow/providers/databricks/utils/query_tags.py new 
file mode 100644 index 0000000000000..90b58d7539adb --- /dev/null +++ b/providers/databricks/src/airflow/providers/databricks/utils/query_tags.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""Shared utilities for Databricks query-tag handling.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from airflow.providers.common.compat.sdk import Context + + +def get_airflow_query_tags(context: Context) -> dict[str, str | None]: + """Return Airflow context metadata as a query-tags dict.""" + task_instance = context.get("ti") + if task_instance is None: + return {} + + def _as_str(value: Any) -> str | None: + return None if value is None else str(value) + + return { + "airflow_dag_id": _as_str(task_instance.dag_id), + "airflow_task_id": _as_str(task_instance.task_id), + "airflow_run_id": _as_str(task_instance.run_id), + "airflow_try_number": _as_str(task_instance.try_number), + "airflow_map_index": _as_str(task_instance.map_index), + } + + +def build_query_tags( + context: Context | None, + user_query_tags: dict[str, str | None], + include_airflow_query_tags: bool, +) -> dict[str, str | None] | None: + """ + Merge Airflow context tags with user-supplied 
tags. + + Airflow tags are added first; user-supplied tags override on key collision. + Returns ``None`` when the resulting dict is empty so callers can skip + injection entirely. + """ + tags: dict[str, str | None] = {} + if include_airflow_query_tags and context is not None: + tags.update(get_airflow_query_tags(context)) + tags.update(user_query_tags) + return tags or None + + +def dict_to_query_tag_list(tags: dict[str, str | None]) -> list[dict[str, str]]: + """ + Convert a ``{key: value}`` dict to the ``[{"key": ..., "value": ...}]`` list format. + + Required by the Databricks Statement Execution REST API. + + See: https://docs.databricks.com/api/workspace/statementexecution/executestatement + Entries whose value is ``None`` are omitted. + """ + return [{"key": k, "value": v} for k, v in tags.items() if v is not None] diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks.py b/providers/databricks/tests/unit/databricks/operators/test_databricks.py index 3d613712caa2a..ae026c2d15a62 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks.py @@ -2848,6 +2848,107 @@ def raise_func(): } assert result.job_facets == {"sql": SQLJobFacet(query="normalized" + query)} + def test_query_tags_defaults(self): + op = DatabricksSQLStatementsOperator( + task_id=TASK_ID, statement="select * from test.test;", warehouse_id=WAREHOUSE_ID + ) + assert op.query_tags == {} + assert op.include_airflow_query_tags is True + + def test_query_tags_in_template_fields(self): + assert "query_tags" in DatabricksSQLStatementsOperator.template_fields + + def test_query_tags_stored(self): + op = DatabricksSQLStatementsOperator( + task_id=TASK_ID, + statement="select * from test.test;", + warehouse_id=WAREHOUSE_ID, + query_tags={"env": "prod"}, + ) + assert op.query_tags == {"env": "prod"} + + 
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_execute_passes_query_tags_to_post_sql_statement(self, db_mock_class): + db_mock = db_mock_class.return_value + db_mock.post_sql_statement.return_value = STATEMENT_ID + + op = DatabricksSQLStatementsOperator( + task_id=TASK_ID, + statement="select * from test.test;", + warehouse_id=WAREHOUSE_ID, + query_tags={"team": "data"}, + include_airflow_query_tags=False, + ) + op.execute(None) + + posted_json = db_mock.post_sql_statement.call_args[0][0] + assert posted_json["query_tags"] == [{"key": "team", "value": "data"}] + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_execute_omits_query_tags_when_none(self, db_mock_class): + db_mock = db_mock_class.return_value + db_mock.post_sql_statement.return_value = STATEMENT_ID + + op = DatabricksSQLStatementsOperator( + task_id=TASK_ID, + statement="select * from test.test;", + warehouse_id=WAREHOUSE_ID, + include_airflow_query_tags=False, + ) + op.execute(None) + + posted_json = db_mock.post_sql_statement.call_args[0][0] + assert "query_tags" not in posted_json + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_execute_includes_airflow_query_tags(self, db_mock_class): + db_mock = db_mock_class.return_value + db_mock.post_sql_statement.return_value = STATEMENT_ID + + op = DatabricksSQLStatementsOperator( + task_id=TASK_ID, + statement="select * from test.test;", + warehouse_id=WAREHOUSE_ID, + query_tags={"custom": "value"}, + ) + mock_ti = mock.MagicMock(spec=["dag_id", "task_id", "run_id", "try_number", "map_index", "xcom_push"]) + mock_ti.dag_id = "my_dag" + mock_ti.task_id = "my_task" + mock_ti.run_id = "run_1" + mock_ti.try_number = 1 + mock_ti.map_index = -1 + + op.execute(context={"ti": mock_ti}) + + posted_json = db_mock.post_sql_statement.call_args[0][0] + tag_keys = {t["key"] for t in posted_json["query_tags"]} + assert "airflow_dag_id" in 
tag_keys + assert "custom" in tag_keys + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_custom_query_tags_override_airflow_tags(self, db_mock_class): + db_mock = db_mock_class.return_value + db_mock.post_sql_statement.return_value = STATEMENT_ID + + op = DatabricksSQLStatementsOperator( + task_id=TASK_ID, + statement="select * from test.test;", + warehouse_id=WAREHOUSE_ID, + query_tags={"airflow_dag_id": "overridden"}, + ) + mock_ti = mock.MagicMock(spec=["dag_id", "task_id", "run_id", "try_number", "map_index", "xcom_push"]) + mock_ti.dag_id = "original" + mock_ti.task_id = "t" + mock_ti.run_id = "r" + mock_ti.try_number = 1 + mock_ti.map_index = -1 + + op.execute(context={"ti": mock_ti}) + + posted_json = db_mock.post_sql_statement.call_args[0][0] + tag_map = {t["key"]: t["value"] for t in posted_json["query_tags"]} + assert tag_map["airflow_dag_id"] == "overridden" + class TestDatabricksNotebookOperator: def test_is_instance_of_databricks_task_base_operator(self): diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py index bfd6d89437b15..1f4513b886a2f 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py @@ -29,7 +29,6 @@ from airflow.providers.common.sql.hooks.handlers import fetch_all_handler from airflow.providers.databricks.operators.databricks_sql import ( DatabricksSqlOperator, - _get_airflow_query_tags, ) DATE = "2017-04-20" @@ -462,11 +461,6 @@ def test_parse_gcs_path(): class TestDatabricksSqlOperatorQueryTags: """Tests for query tags support in DatabricksSqlOperator.""" - def test_get_airflow_query_tags_returns_empty_dict_without_task_instance(self): - """_get_airflow_query_tags must return {} when context has no 'ti' key.""" - result = _get_airflow_query_tags({}) - assert result 
== {} - def test_get_query_tags_with_none_context_returns_custom_tags_only(self): """When context is None, only custom tags are returned (no Airflow tags).""" op = DatabricksSqlOperator( diff --git a/providers/databricks/tests/unit/databricks/sensors/test_databricks.py b/providers/databricks/tests/unit/databricks/sensors/test_databricks.py index 8f94274e7a4f6..fe04781f65d1d 100644 --- a/providers/databricks/tests/unit/databricks/sensors/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/sensors/test_databricks.py @@ -235,3 +235,93 @@ def test_execute_complete_failure(self, db_mock_class): with pytest.raises(AirflowException, match="^SQL Statement execution failed with terminal state: .*"): op.execute_complete(context=None, event=event) + + def test_query_tags_defaults(self): + sensor = DatabricksSQLStatementsSensor( + task_id=TASK_ID, statement=STATEMENT, warehouse_id=WAREHOUSE_ID + ) + assert sensor.query_tags == {} + assert sensor.include_airflow_query_tags is True + + def test_query_tags_in_template_fields(self): + assert "query_tags" in DatabricksSQLStatementsSensor.template_fields + + def test_query_tags_stored(self): + sensor = DatabricksSQLStatementsSensor( + task_id=TASK_ID, statement=STATEMENT, warehouse_id=WAREHOUSE_ID, query_tags={"env": "prod"} + ) + assert sensor.query_tags == {"env": "prod"} + + @mock.patch("airflow.providers.databricks.sensors.databricks.DatabricksHook") + def test_execute_passes_query_tags_to_post_sql_statement(self, db_mock_class): + db_mock = db_mock_class.return_value + db_mock.post_sql_statement.return_value = STATEMENT_ID + + sensor = DatabricksSQLStatementsSensor( + task_id=TASK_ID, + statement=STATEMENT, + warehouse_id=WAREHOUSE_ID, + query_tags={"team": "data"}, + include_airflow_query_tags=False, + ) + sensor.execute(context=None) + + posted_json = db_mock.post_sql_statement.call_args[0][0] + assert posted_json["query_tags"] == [{"key": "team", "value": "data"}] + + 
@mock.patch("airflow.providers.databricks.sensors.databricks.DatabricksHook") + def test_execute_omits_query_tags_when_none(self, db_mock_class): + db_mock = db_mock_class.return_value + db_mock.post_sql_statement.return_value = STATEMENT_ID + + sensor = DatabricksSQLStatementsSensor( + task_id=TASK_ID, + statement=STATEMENT, + warehouse_id=WAREHOUSE_ID, + include_airflow_query_tags=False, + ) + sensor.execute(context=None) + + posted_json = db_mock.post_sql_statement.call_args[0][0] + assert "query_tags" not in posted_json + + @mock.patch("airflow.providers.databricks.sensors.databricks.DatabricksHook") + def test_execute_includes_airflow_query_tags(self, db_mock_class): + db_mock = db_mock_class.return_value + db_mock.post_sql_statement.return_value = STATEMENT_ID + + sensor = DatabricksSQLStatementsSensor( + task_id=TASK_ID, + statement=STATEMENT, + warehouse_id=WAREHOUSE_ID, + query_tags={"custom": "value"}, + ) + mock_ti = mock.MagicMock(spec=["dag_id", "task_id", "run_id", "try_number", "map_index", "xcom_push"]) + mock_ti.dag_id = "my_dag" + mock_ti.task_id = "my_task" + mock_ti.run_id = "run_1" + mock_ti.try_number = 1 + mock_ti.map_index = -1 + + sensor.execute(context={"ti": mock_ti}) + + posted_json = db_mock.post_sql_statement.call_args[0][0] + tag_keys = {t["key"] for t in posted_json["query_tags"]} + assert "airflow_dag_id" in tag_keys + assert "custom" in tag_keys + + @mock.patch("airflow.providers.databricks.sensors.databricks.DatabricksHook") + def test_execute_skips_query_tags_when_statement_id_provided(self, db_mock_class): + db_mock = db_mock_class.return_value + + sensor = DatabricksSQLStatementsSensor( + task_id=TASK_ID, + statement_id=STATEMENT_ID, + warehouse_id=WAREHOUSE_ID, + query_tags={"env": "test"}, + include_airflow_query_tags=False, + wait_for_termination=False, + ) + sensor.execute(context=None) + + db_mock.post_sql_statement.assert_not_called() diff --git 
a/providers/databricks/tests/unit/databricks/sensors/test_databricks_partition.py b/providers/databricks/tests/unit/databricks/sensors/test_databricks_partition.py index 23e6d12c3104b..6aef232e4e1c1 100644 --- a/providers/databricks/tests/unit/databricks/sensors/test_databricks_partition.py +++ b/providers/databricks/tests/unit/databricks/sensors/test_databricks_partition.py @@ -19,6 +19,7 @@ from __future__ import annotations from datetime import datetime, timedelta +from unittest import mock from unittest.mock import patch import pytest @@ -137,3 +138,71 @@ def test_fail_poke(self, _check_table_partitions): AirflowException, match=rf"Specified partition\(s\): {partitions} were not found." ): self.partition_sensor.poke(context={}) + + def _make_sensor(self, **kwargs): + return DatabricksPartitionSensor( + task_id=TASK_ID, + databricks_conn_id=DEFAULT_CONN_ID, + sql_warehouse_name=DEFAULT_SQL_WAREHOUSE, + table_name=DEFAULT_TABLE, + schema=DEFAULT_SCHEMA, + catalog=DEFAULT_CATALOG, + partitions=DEFAULT_PARTITION, + **kwargs, + ) + + def test_query_tags_defaults(self): + sensor = self._make_sensor() + assert sensor.query_tags == {} + assert sensor.include_airflow_query_tags is True + + def test_query_tags_in_template_fields(self): + assert "query_tags" in DatabricksPartitionSensor.template_fields + + def test_query_tags_stored(self): + sensor = self._make_sensor(query_tags={"team": "data"}) + assert sensor.query_tags == {"team": "data"} + + def test_poke_sets_query_tags_on_hook(self): + sensor = self._make_sensor(query_tags={"env": "test"}, include_airflow_query_tags=False) + with patch.object(sensor, "_check_table_partitions", return_value=[True]): + sensor.poke(context=None) + assert sensor._get_hook.query_tags == {"env": "test"} + + def test_poke_no_tags_when_disabled_and_no_custom(self): + sensor = self._make_sensor(include_airflow_query_tags=False) + with patch.object(sensor, "_check_table_partitions", return_value=[True]): + sensor.poke(context=None) + assert 
sensor._get_hook.query_tags is None + + def test_poke_includes_airflow_tags_from_context(self): + sensor = self._make_sensor(query_tags={"custom": "value"}) + mock_ti = mock.MagicMock(spec=["dag_id", "task_id", "run_id", "try_number", "map_index"]) + mock_ti.dag_id = "my_dag" + mock_ti.task_id = "my_task" + mock_ti.run_id = "run_1" + mock_ti.try_number = 1 + mock_ti.map_index = -1 + mock_context = {"ti": mock_ti} + + with patch.object(sensor, "_check_table_partitions", return_value=[True]): + sensor.poke(context=mock_context) + + tags = sensor._get_hook.query_tags + assert tags is not None + assert tags["airflow_dag_id"] == "my_dag" + assert tags["custom"] == "value" + + def test_custom_tags_override_airflow_tags(self): + sensor = self._make_sensor(query_tags={"airflow_dag_id": "overridden"}) + mock_ti = mock.MagicMock(spec=["dag_id", "task_id", "run_id", "try_number", "map_index"]) + mock_ti.dag_id = "original" + mock_ti.task_id = "t" + mock_ti.run_id = "r" + mock_ti.try_number = 1 + mock_ti.map_index = -1 + + with patch.object(sensor, "_check_table_partitions", return_value=[True]): + sensor.poke(context={"ti": mock_ti}) + + assert sensor._get_hook.query_tags["airflow_dag_id"] == "overridden" diff --git a/providers/databricks/tests/unit/databricks/sensors/test_databricks_sql.py b/providers/databricks/tests/unit/databricks/sensors/test_databricks_sql.py index 3431722f77061..00f3ff4749aeb 100644 --- a/providers/databricks/tests/unit/databricks/sensors/test_databricks_sql.py +++ b/providers/databricks/tests/unit/databricks/sensors/test_databricks_sql.py @@ -19,6 +19,7 @@ from __future__ import annotations from datetime import datetime, timedelta +from unittest import mock from unittest.mock import patch import pytest @@ -100,3 +101,71 @@ def test_fail__get_results(self): " Please specify either http_path or sql_warehouse_name.", ): self.sensor._get_results() + + def _make_sensor(self, **kwargs): + return DatabricksSqlSensor( + task_id=TASK_ID, + 
databricks_conn_id=DEFAULT_CONN_ID, + sql_warehouse_name=DEFAULT_SQL_WAREHOUSE, + sql=DEFAULT_SQL, + **kwargs, + ) + + def test_query_tags_defaults(self): + sensor = self._make_sensor() + assert sensor.query_tags == {} + assert sensor.include_airflow_query_tags is True + + def test_query_tags_in_template_fields(self): + assert "query_tags" in DatabricksSqlSensor.template_fields + + def test_query_tags_stored(self): + sensor = self._make_sensor(query_tags={"env": "prod"}) + assert sensor.query_tags == {"env": "prod"} + + def test_poke_sets_query_tags_on_hook(self): + sensor = self._make_sensor(query_tags={"env": "test"}, include_airflow_query_tags=False) + with patch.object(sensor, "_get_results", return_value=True) as mock_results: + sensor.poke(context=None) + assert sensor.hook.query_tags == {"env": "test"} + mock_results.assert_called_once() + + def test_poke_no_tags_when_disabled_and_no_custom(self): + sensor = self._make_sensor(include_airflow_query_tags=False) + with patch.object(sensor, "_get_results", return_value=True): + sensor.poke(context=None) + assert sensor.hook.query_tags is None + + def test_poke_includes_airflow_tags_from_context(self): + sensor = self._make_sensor(query_tags={"custom": "value"}) + mock_ti = mock.MagicMock(spec=["dag_id", "task_id", "run_id", "try_number", "map_index"]) + mock_ti.dag_id = "my_dag" + mock_ti.task_id = "my_task" + mock_ti.run_id = "run_1" + mock_ti.try_number = 1 + mock_ti.map_index = -1 + mock_context = {"ti": mock_ti} + + with patch.object(sensor, "_get_results", return_value=True): + sensor.poke(context=mock_context) + + tags = sensor.hook.query_tags + assert tags is not None + assert tags["airflow_dag_id"] == "my_dag" + assert tags["airflow_task_id"] == "my_task" + assert tags["custom"] == "value" + + def test_custom_tags_override_airflow_tags(self): + sensor = self._make_sensor(query_tags={"airflow_dag_id": "overridden"}) + mock_ti = mock.MagicMock(spec=["dag_id", "task_id", "run_id", "try_number", 
"map_index"]) + mock_ti.dag_id = "original" + mock_ti.task_id = "t" + mock_ti.run_id = "r" + mock_ti.try_number = 1 + mock_ti.map_index = -1 + mock_context = {"ti": mock_ti} + + with patch.object(sensor, "_get_results", return_value=True): + sensor.poke(context=mock_context) + + assert sensor.hook.query_tags["airflow_dag_id"] == "overridden" diff --git a/providers/databricks/tests/unit/databricks/utils/test_query_tags.py b/providers/databricks/tests/unit/databricks/utils/test_query_tags.py new file mode 100644 index 0000000000000..4c2ee00a77f90 --- /dev/null +++ b/providers/databricks/tests/unit/databricks/utils/test_query_tags.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations
+
+from unittest import mock
+
+from airflow.providers.databricks.utils.query_tags import (
+    build_query_tags,
+    dict_to_query_tag_list,
+    get_airflow_query_tags,
+)
+
+
+def _mock_ti(**attrs):
+    """Return a MagicMock limited to five task-instance attribute names, with ``attrs`` preset."""
+    return mock.MagicMock(spec=["dag_id", "task_id", "run_id", "try_number", "map_index"], **attrs)
+
+
+def _make_context():
+    """Build a context dict holding a fully populated mock task instance under ``ti``."""
+    ti = _mock_ti(dag_id="test_dag", task_id="test_task", run_id="test_run", try_number=1, map_index=-1)
+    return {"ti": ti}
+
+
+class TestGetAirflowQueryTags:
+    """Tests for ``get_airflow_query_tags``."""
+
+    def test_returns_empty_dict_without_task_instance(self):
+        assert get_airflow_query_tags({}) == {}
+
+    def test_returns_stringified_context_metadata(self):
+        expected = {
+            "airflow_dag_id": "test_dag",
+            "airflow_task_id": "test_task",
+            "airflow_run_id": "test_run",
+            "airflow_try_number": "1",
+            "airflow_map_index": "-1",
+        }
+        assert get_airflow_query_tags(_make_context()) == expected
+
+    def test_none_valued_attributes_become_none(self):
+        # Only dag_id carries a value; the remaining four attributes are set to None.
+        ti = _mock_ti(dag_id="test_dag", task_id=None, run_id=None, try_number=None, map_index=None)
+        tags = get_airflow_query_tags({"ti": ti})
+        assert tags["airflow_dag_id"] == "test_dag"
+        assert tags["airflow_task_id"] is None
+        assert tags["airflow_run_id"] is None
+
+
+class TestBuildQueryTags:
+    """Tests for ``build_query_tags``."""
+
+    def test_none_context_returns_user_tags_only(self):
+        user_tags = {"custom": "value"}
+        assert build_query_tags(None, user_tags, include_airflow_query_tags=True) == {"custom": "value"}
+
+    def test_none_context_and_no_user_tags_returns_none(self):
+        assert build_query_tags(None, {}, include_airflow_query_tags=True) is None
+
+    def test_disabled_airflow_tags_returns_user_tags_only(self):
+        merged = build_query_tags(_make_context(), {"custom": "value"}, include_airflow_query_tags=False)
+        assert merged == {"custom": "value"}
+
+    def test_disabled_airflow_tags_and_no_user_tags_returns_none(self):
+        assert build_query_tags(_make_context(), {}, include_airflow_query_tags=False) is None
+
+    def test_merges_airflow_and_user_tags(self):
+        merged = build_query_tags(_make_context(), {"custom": "value"}, include_airflow_query_tags=True)
+        assert merged is not None
+        assert merged["airflow_dag_id"] == "test_dag"
+        assert merged["custom"] == "value"
+
+    def test_user_tags_override_airflow_tags_on_collision(self):
+        # The user-supplied key matches the ``airflow_dag_id`` tag derived from the context.
+        colliding = {"airflow_dag_id": "overridden"}
+        merged = build_query_tags(_make_context(), colliding, include_airflow_query_tags=True)
+        assert merged is not None
+        assert merged["airflow_dag_id"] == "overridden"
+
+
+class TestDictToQueryTagList:
+    """Tests for ``dict_to_query_tag_list``."""
+
+    def test_converts_dict_to_key_value_list(self):
+        expected = [{"key": "a", "value": "1"}, {"key": "b", "value": "2"}]
+        assert dict_to_query_tag_list({"a": "1", "b": "2"}) == expected
+
+    def test_omits_none_values(self):
+        assert dict_to_query_tag_list({"a": "1", "b": None}) == [{"key": "a", "value": "1"}]
+
+    def test_empty_dict_returns_empty_list(self):
+        assert dict_to_query_tag_list({}) == []