diff --git a/containers/kafka/scripts/start-kafka.sh b/containers/kafka/scripts/start-kafka.sh
index b7a4af13..2fb336aa 100755
--- a/containers/kafka/scripts/start-kafka.sh
+++ b/containers/kafka/scripts/start-kafka.sh
@@ -10,4 +10,4 @@ $KAFKA_HOME/bin/kafka-server-start.sh -daemon $KAFKA_HOME/config/server.properti
 $KAFKA_HOME/bin/kafka-topics.sh --create --bootstrap-server localhost:9092 --replication-factor 1 --partitions 1 --topic dq-sparkexpectations-stats
 
 # Keep the container running
-tail -f /dev/null
+tail -f /dev/null
\ No newline at end of file
diff --git a/docs/user_guide/file_based_rules.md b/docs/user_guide/file_based_rules.md
new file mode 100644
index 00000000..abf63771
--- /dev/null
+++ b/docs/user_guide/file_based_rules.md
@@ -0,0 +1,279 @@
+# File-Based Rules (YAML & JSON)
+
+Instead of managing data quality rules in a database table with SQL INSERT statements,
+you can define them in **YAML** or **JSON** files and load them directly into a Spark
+DataFrame. This approach makes rules easy to version-control, review in pull requests,
+and share across environments.
+
+## Loading Rules
+
+```python
+from spark_expectations.rules import load_rules
+
+# Auto-detect format from file extension, selecting the "DEV" environment
+rules_df = load_rules("path/to/rules.yaml", options={"dq_env": "DEV"})
+rules_df = load_rules("path/to/rules.json", options={"dq_env": "DEV"})
+```
+
+You can also be explicit about the format or use the format-specific helpers:
+
+```python
+from spark_expectations.rules import load_rules, load_rules_from_yaml, load_rules_from_json
+
+# Explicit format
+rules_df = load_rules("path/to/rules.yaml", format="yaml", options={"dq_env": "DEV"})
+
+# Convenience helpers
+rules_df = load_rules_from_yaml("path/to/rules.yaml", options={"dq_env": "DEV"})
+rules_df = load_rules_from_json("path/to/rules.json", options={"dq_env": "DEV"})
+```
+
+All loader functions accept an optional `spark` parameter. If omitted, the active
+`SparkSession` is used automatically. You can pass one explicitly when needed:
+
+```python
+from pyspark.sql import SparkSession
+
+spark = SparkSession.builder.getOrCreate()
+rules_df = load_rules_from_yaml("path/to/rules.yaml", spark, options={"dq_env": "DEV"})
+```
+
+The `options={"dq_env": "<env>"}` parameter selects which environment block to use
+when the rules file contains a `dq_env` section. See [Environment-Aware Rules](#environment-aware-rules-dq_env)
+below.
+
+The returned `rules_df` has the same schema as the rules table and can be passed
+directly to `SparkExpectations`:
+
+```python
+from spark_expectations.core.expectations import SparkExpectations, WrappedDataFrameWriter
+
+se = SparkExpectations(
+    product_id="your_product",
+    rules_df=rules_df,  # <-- loaded from YAML/JSON
+    stats_table="catalog.schema.stats",
+    stats_table_writer=WrappedDataFrameWriter().mode("append").format("delta"),
+    target_and_error_table_writer=WrappedDataFrameWriter().mode("append").format("delta"),
+)
+```
+
+---
+
+## Environment-Aware Rules (`dq_env`)
+
+The recommended format uses a `dq_env` section to define per-environment settings
+such as `table_name`, `action_if_failed`, and `priority`. This lets you keep a
+single rules file that works across dev, QA, and production by simply switching
+the `dq_env` option at load time.
+
+=== "YAML"
+
+    ```yaml
+    product_id: your_product
+
+    dq_env:
+      DEV:
+        table_name: catalog_dev.schema.orders
+        action_if_failed: ignore
+        is_active: true
+        priority: medium
+      QA:
+        table_name: catalog_qa.schema.orders
+        action_if_failed: ignore
+        is_active: true
+        priority: medium
+      PROD:
+        table_name: catalog_prod.schema.orders
+        action_if_failed: fail
+        is_active: true
+        priority: high
+
+    rules:
+      - rule: order_id_not_null
+        rule_type: row_dq
+        column_name: order_id
+        expectation: "order_id IS NOT NULL"
+        action_if_failed: drop
+        tag: completeness
+        description: "Order ID must not be null"
+        priority: high
+
+      - rule: total_positive
+        rule_type: row_dq
+        column_name: total
+        expectation: "total > 0"
+        tag: validity
+        description: "Total must be positive"
+
+      - rule: row_count
+        rule_type: agg_dq
+        expectation: "count(*) > 0"
+        action_if_failed: fail
+        tag: completeness
+        description: "Table must have rows"
+    ```
+
+=== "JSON"
+
+    ```json
+    {
+      "product_id": "your_product",
+      "dq_env": {
+        "DEV": {
+          "table_name": "catalog_dev.schema.orders",
+          "action_if_failed": "ignore",
+          "is_active": true,
+          "priority": "medium"
+        },
+        "QA": {
+          "table_name": "catalog_qa.schema.orders",
+          "action_if_failed": "ignore",
+          "is_active": true,
+          "priority": "medium"
+        },
+        "PROD": {
+          "table_name": "catalog_prod.schema.orders",
+          "action_if_failed": "fail",
+          "is_active": true,
+          "priority": "high"
+        }
+      },
+      "rules": [
+        {
+          "rule": "order_id_not_null",
+          "rule_type": "row_dq",
+          "column_name": "order_id",
+          "expectation": "order_id IS NOT NULL",
+          "action_if_failed": "drop",
+          "tag": "completeness",
+          "description": "Order ID must not be null",
+          "priority": "high"
+        },
+        {
+          "rule": "total_positive",
+          "rule_type": "row_dq",
+          "column_name": "total",
+          "expectation": "total > 0",
+          "tag": "validity",
+          "description": "Total must be positive"
+        },
+        {
+          "rule": "row_count",
+          "rule_type": "agg_dq",
+          "expectation": "count(*) > 0",
+          "action_if_failed": "fail",
+          "tag": "completeness",
+          "description": "Table must have rows"
+        }
+      ]
+    }
+    ```
+
+**Structure:**
+
+- Top-level `product_id` identifies the product.
+- `dq_env` is a mapping of environment names (e.g. `DEV`, `QA`, `PROD`) to
+  environment-specific settings. Each environment block can contain:
+    - `table_name` -- the table the rules apply to in that environment.
+    - Any default field (`action_if_failed`, `is_active`, `priority`, etc.) that
+      applies to all rules unless a rule overrides it.
+- `rules` is a flat list of rule definitions. Each rule needs at least `rule`,
+  `expectation`, and a `rule_type` (the `rule_type` may also be supplied through
+  the environment block or `defaults`).
+- When loading, pass `options={"dq_env": "<env>"}` to select the environment:
+
+```python
+# For development
+rules_df = load_rules_from_yaml("rules.yaml", options={"dq_env": "DEV"})
+
+# For production
+rules_df = load_rules_from_yaml("rules.yaml", options={"dq_env": "PROD"})
+```
+
+**Defaults cascade:** built-in defaults --> `dq_env[<env>]` values --> per-rule overrides.
+
+---
+
+## Defaults
+
+The values in each `dq_env` environment block act as defaults that apply to
+every rule unless a rule explicitly overrides them. This avoids repeating
+common settings on every single rule.
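+
+As a minimal sketch of the cascade (the rule names here are illustrative, and a
+complete file would also need a top-level `product_id`), the first rule below
+inherits `action_if_failed: ignore` from the `DEV` block, while the second
+overrides it per rule:
+
+```yaml
+dq_env:
+  DEV:
+    table_name: catalog_dev.schema.orders
+    action_if_failed: ignore      # environment-level default
+
+rules:
+  - rule: order_id_not_null       # inherits action_if_failed: ignore
+    rule_type: row_dq
+    expectation: "order_id IS NOT NULL"
+  - rule: total_positive
+    rule_type: row_dq
+    expectation: "total > 0"
+    action_if_failed: fail        # per-rule override wins
+```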
+ +The built-in defaults (used when neither the file nor the rule specifies a value) are: + +| Field | Default | +|------------------------------------|-----------| +| `action_if_failed` | `ignore` | +| `enable_for_source_dq_validation` | `true` | +| `enable_for_target_dq_validation` | `true` | +| `is_active` | `true` | +| `enable_error_drop_alert` | `false` | +| `error_drop_threshold` | `0` | +| `priority` | `medium` | + +--- + +## Rules Schema Reference + +Every rule, regardless of input format, is normalised into a row with these 17 columns: + +| Column | Required | Description | +|------------------------------------|:--------:|---------------------------------------------------------------------------------------------------| +| `product_id` | Yes | Unique product identifier for DQ execution | +| `table_name` | Yes | The table the rule applies to | +| `rule_type` | Yes | `row_dq`, `agg_dq`, or `query_dq` | +| `rule` | Yes | Short name for the rule | +| `expectation` | Yes | The DQ rule condition (SQL expression) | +| `column_name` | | Column the rule applies to (relevant for `row_dq`) | +| `action_if_failed` | | `ignore`, `drop` (row_dq only), or `fail` | +| `tag` | | Category tag (e.g. `completeness`, `validity`) | +| `description` | | Human-readable description of the rule | +| `enable_for_source_dq_validation` | | Run agg/query rules on the source DataFrame | +| `enable_for_target_dq_validation` | | Run agg/query rules on the post-row_dq DataFrame | +| `is_active` | | Whether the rule is active | +| `enable_error_drop_alert` | | Send alert when rows are dropped | +| `error_drop_threshold` | | Threshold for error drop alerts | +| `query_dq_delimiter` | | Delimiter for custom query_dq alias queries (default `@`) | +| `enable_querydq_custom_output` | | Capture custom query output in a separate table | +| `priority` | | `low`, `medium`, or `high` | + +--- + +## Full Example + +Here is a complete example loading rules from YAML with `dq_env` and running DQ checks: + +```python +from pyspark.sql import DataFrame, SparkSession +from spark_expectations.rules import load_rules_from_yaml +from spark_expectations.core.expectations import SparkExpectations, WrappedDataFrameWriter +from spark_expectations.config.user_config import Constants as user_config + +spark = SparkSession.builder.getOrCreate() + +# Load rules from YAML, selecting the "DEV" environment +rules_df = load_rules_from_yaml("path/to/rules.yaml", spark, options={"dq_env": "DEV"}) + +# Configure writer and streaming +writer = WrappedDataFrameWriter().mode("append").format("delta") +streaming_config = {user_config.se_enable_streaming: False} + +se = SparkExpectations( + product_id="your_product", + rules_df=rules_df, + stats_table="catalog.schema.dq_stats", + stats_table_writer=writer, + target_and_error_table_writer=writer, + stats_streaming_options=streaming_config, +) + +@se.with_expectations( + target_table="catalog.schema.orders", + write_to_table=True, + write_to_temp_table=True, +) +def process_orders() -> DataFrame: + return spark.read.table("catalog.schema.raw_orders") + +process_orders() +``` diff --git a/examples/resources/sample_rules.json b/examples/resources/sample_rules.json new file mode 100644 index 00000000..30e12040 --- /dev/null +++ b/examples/resources/sample_rules.json @@ -0,0 +1,170 @@ +{ + "product_id": "your_product", + "dq_env": { + "DEV": { + "table_name": "dq_spark_dev.customer_order", + "action_if_failed": "ignore", + "enable_for_source_dq_validation": true, + "enable_for_target_dq_validation": false, 
+ "is_active": true, + "enable_error_drop_alert": false, + "error_drop_threshold": 0, + "priority": "medium" + }, + "QA": { + "table_name": "dq_spark_qa.customer_order", + "action_if_failed": "ignore", + "enable_for_source_dq_validation": true, + "enable_for_target_dq_validation": false, + "is_active": true, + "enable_error_drop_alert": false, + "error_drop_threshold": 0, + "priority": "medium" + }, + "PROD": { + "table_name": "dq_spark_prod.customer_order", + "action_if_failed": "ignore", + "enable_for_source_dq_validation": true, + "enable_for_target_dq_validation": false, + "is_active": true, + "enable_error_drop_alert": false, + "error_drop_threshold": 0, + "priority": "medium" + } + }, + "rules": [ + { + "rule": "customer_id_not_null", + "rule_type": "row_dq", + "column_name": "customer_id", + "expectation": "customer_id IS NOT NULL", + "action_if_failed": "drop", + "tag": "completeness", + "description": "Customer ID must not be null", + "priority": "high" + }, + { + "rule": "order_id_not_null", + "rule_type": "row_dq", + "column_name": "order_id", + "expectation": "order_id IS NOT NULL", + "action_if_failed": "drop", + "tag": "completeness", + "description": "Order ID must not be null", + "priority": "high" + }, + { + "rule": "order_date_not_null", + "rule_type": "row_dq", + "column_name": "order_date", + "expectation": "order_date IS NOT NULL", + "tag": "completeness", + "description": "Order date must not be null" + }, + { + "rule": "product_id_not_null", + "rule_type": "row_dq", + "column_name": "product_id", + "expectation": "product_id IS NOT NULL", + "tag": "completeness", + "description": "Product ID must not be null" + }, + { + "rule": "quantity_positive", + "rule_type": "row_dq", + "column_name": "quantity", + "expectation": "quantity > 0", + "action_if_failed": "drop", + "tag": "validity", + "description": "Order quantity must be greater than zero", + "priority": "high" + }, + { + "rule": "sales_greater_than_zero", + "rule_type": "row_dq", + "column_name": "sales", + "expectation": "sales > 2", + "tag": "accuracy", + "description": "Sales value should be greater than 2" + }, + { + "rule": "discount_within_range", + "rule_type": "row_dq", + "column_name": "discount", + "expectation": "discount >= 0 AND discount * 100 < 60", + "action_if_failed": "drop", + "tag": "validity", + "description": "Discount must be between 0 and 0.6" + }, + { + "rule": "ship_mode_in_set", + "rule_type": "row_dq", + "column_name": "ship_mode", + "expectation": "lower(trim(ship_mode)) in ('second class', 'standard class')", + "action_if_failed": "drop", + "tag": "validity", + "description": "Ship mode must belong to the allowed set" + }, + { + "rule": "ship_date_after_order_date", + "rule_type": "row_dq", + "column_name": "ship_date", + "expectation": "ship_date IS NULL OR ship_date >= order_date", + "tag": "consistency", + "description": "Ship date must not be earlier than order date" + }, + { + "rule": "profit_threshold", + "rule_type": "row_dq", + "column_name": "profit", + "expectation": "profit > 0", + "tag": "accuracy", + "description": "Profit should be greater than zero" + }, + { + "rule": "table_row_count", + "rule_type": "agg_dq", + "expectation": "count(*) > 0", + "action_if_failed": "fail", + "tag": "completeness", + "description": "Table must contain at least one row", + "priority": "high" + }, + { + "rule": "customer_id_completeness_ratio", + "rule_type": "agg_dq", + "expectation": "count(customer_id) / count(*) > 0.99", + "action_if_failed": "fail", + "tag": "completeness", + 
"description": "At least 99% of rows must have a customer ID" + }, + { + "rule": "sum_of_sales", + "rule_type": "agg_dq", + "expectation": "sum(sales) > 10000", + "tag": "validity", + "description": "Total sales must exceed 10000" + }, + { + "rule": "sum_of_quantity", + "rule_type": "agg_dq", + "expectation": "sum(quantity) > 10000", + "tag": "validity", + "description": "Total quantity must exceed 10000" + }, + { + "rule": "distinct_of_ship_mode", + "rule_type": "agg_dq", + "expectation": "count(distinct ship_mode) <= 4", + "tag": "validity", + "description": "Ship mode should have at most 4 distinct values" + }, + { + "rule": "max_discount_within_bounds", + "rule_type": "agg_dq", + "expectation": "max(discount) <= 100", + "tag": "validity", + "description": "No discount should exceed 100%" + } + ] +} diff --git a/examples/resources/sample_rules.yaml b/examples/resources/sample_rules.yaml new file mode 100644 index 00000000..faa901d3 --- /dev/null +++ b/examples/resources/sample_rules.yaml @@ -0,0 +1,144 @@ +product_id: your_product + +dq_env: + DEV: + table_name: dq_spark_dev.customer_order + action_if_failed: ignore + enable_for_source_dq_validation: true + enable_for_target_dq_validation: false + is_active: true + enable_error_drop_alert: false + error_drop_threshold: 0 + priority: medium + QA: + table_name: dq_spark_qa.customer_order + action_if_failed: ignore + enable_for_source_dq_validation: true + enable_for_target_dq_validation: false + is_active: true + enable_error_drop_alert: false + error_drop_threshold: 0 + priority: medium + PROD: + table_name: dq_spark_prod.customer_order + action_if_failed: ignore + enable_for_source_dq_validation: true + enable_for_target_dq_validation: false + is_active: true + enable_error_drop_alert: false + error_drop_threshold: 0 + priority: medium + +rules: + # ── row_dq: completeness ── + - rule: customer_id_not_null + rule_type: row_dq + column_name: customer_id + expectation: "customer_id IS NOT NULL" + action_if_failed: drop + tag: completeness + description: "Customer ID must not be null" + priority: high + - rule: order_id_not_null + rule_type: row_dq + column_name: order_id + expectation: "order_id IS NOT NULL" + action_if_failed: drop + tag: completeness + description: "Order ID must not be null" + priority: high + - rule: order_date_not_null + rule_type: row_dq + column_name: order_date + expectation: "order_date IS NOT NULL" + tag: completeness + description: "Order date must not be null" + - rule: product_id_not_null + rule_type: row_dq + column_name: product_id + expectation: "product_id IS NOT NULL" + tag: completeness + description: "Product ID must not be null" + + # ── row_dq: validity ── + - rule: quantity_positive + rule_type: row_dq + column_name: quantity + expectation: "quantity > 0" + action_if_failed: drop + tag: validity + description: "Order quantity must be greater than zero" + priority: high + - rule: sales_greater_than_zero + rule_type: row_dq + column_name: sales + expectation: "sales > 2" + tag: accuracy + description: "Sales value should be greater than 2" + - rule: discount_within_range + rule_type: row_dq + column_name: discount + expectation: "discount >= 0 AND discount * 100 < 60" + action_if_failed: drop + tag: validity + description: "Discount must be between 0 and 0.6" + - rule: ship_mode_in_set + rule_type: row_dq + column_name: ship_mode + expectation: "lower(trim(ship_mode)) in ('second class', 'standard class')" + action_if_failed: drop + tag: validity + description: "Ship mode must belong to the allowed 
set" + + # ── row_dq: consistency ── + - rule: ship_date_after_order_date + rule_type: row_dq + column_name: ship_date + expectation: "ship_date IS NULL OR ship_date >= order_date" + tag: consistency + description: "Ship date must not be earlier than order date" + + # ── row_dq: accuracy ── + - rule: profit_threshold + rule_type: row_dq + column_name: profit + expectation: "profit > 0" + tag: accuracy + description: "Profit should be greater than zero" + + # ── agg_dq: completeness ── + - rule: table_row_count + rule_type: agg_dq + expectation: "count(*) > 0" + action_if_failed: fail + tag: completeness + description: "Table must contain at least one row" + priority: high + - rule: customer_id_completeness_ratio + rule_type: agg_dq + expectation: "count(customer_id) / count(*) > 0.99" + action_if_failed: fail + tag: completeness + description: "At least 99% of rows must have a customer ID" + + # ── agg_dq: validity ── + - rule: sum_of_sales + rule_type: agg_dq + expectation: "sum(sales) > 10000" + tag: validity + description: "Total sales must exceed 10000" + - rule: sum_of_quantity + rule_type: agg_dq + expectation: "sum(quantity) > 10000" + tag: validity + description: "Total quantity must exceed 10000" + - rule: distinct_of_ship_mode + rule_type: agg_dq + expectation: "count(distinct ship_mode) <= 4" + tag: validity + description: "Ship mode should have at most 4 distinct values" + - rule: max_discount_within_bounds + rule_type: agg_dq + expectation: "max(discount) <= 100" + tag: validity + description: "No discount should exceed 100%" diff --git a/examples/scripts/sample_dq_delta.py b/examples/scripts/sample_dq_delta.py index 24407472..c01d6608 100644 --- a/examples/scripts/sample_dq_delta.py +++ b/examples/scripts/sample_dq_delta.py @@ -29,6 +29,7 @@ } job_info = str(dic_job_info) +# --- Option 1: Load rules from a Delta table (default) --- se: SparkExpectations = SparkExpectations( product_id="your_product", rules_df=spark.table("dq_spark_dev.dq_rules"), diff --git a/examples/scripts/sample_dq_yaml_json.py b/examples/scripts/sample_dq_yaml_json.py new file mode 100644 index 00000000..2e13d80b --- /dev/null +++ b/examples/scripts/sample_dq_yaml_json.py @@ -0,0 +1,135 @@ +"""Example: Loading DQ rules from YAML or JSON files. + +This script demonstrates how to use ``spark_expectations.rules`` to load +data-quality rules from a YAML (or JSON) file instead of a Delta table. +The loaded rules are passed as a DataFrame to ``SparkExpectations``. 
+""" + +import os +from typing import Dict, Union + +from pyspark.sql import DataFrame + +from spark_expectations import _log +from spark_expectations.config.user_config import Constants as user_config +from spark_expectations.core.expectations import ( + SparkExpectations, + WrappedDataFrameWriter, +) +from spark_expectations.rules import load_rules_from_yaml, load_rules_from_json + +from examples.scripts.base_setup import set_up_delta + +RESOURCES_DIR = os.path.join(os.path.dirname(__file__), "..", "resources") + +writer = WrappedDataFrameWriter().mode("append").format("delta") +spark = set_up_delta() + +dic_job_info = { + "job": "job_name", + "Region": "NA", + "env": "dev", + "Snapshot": "2024-04-15", + "data_object_name": "customer_order", +} +job_info = str(dic_job_info) + +your_product = 'your_product' +# ── Load rules from YAML ──────────────────────────────────────────────── +rules_df = load_rules_from_yaml( + os.path.join(RESOURCES_DIR, "sample_rules.yaml"), + spark, + options={"dq_env": "DEV"}, +) + +# ── Alternatively, load rules from JSON ───────────────────────────────── +# rules_df = load_rules_from_json( +# os.path.join(RESOURCES_DIR, "sample_rules.json"), +# spark, +# options={"dq_env": "DEV"}, +# ) + +# ── Or auto-detect the format from the file extension ─────────────────── +# from spark_expectations.rules import load_rules +# rules_df = load_rules( +# os.path.join(RESOURCES_DIR, "sample_rules.yaml"), +# spark, +# options={"dq_env": "DEV"}, +# ) + +se: SparkExpectations = SparkExpectations( + product_id="your_product", + rules_df=rules_df, + stats_table="dq_spark_dev.dq_stats", + stats_table_writer=writer, + target_and_error_table_writer=writer, + debugger=False, + stats_streaming_options={ + user_config.se_enable_streaming: False, + user_config.se_streaming_stats_topic_name: "dq-sparkexpectations-stats", + }, +) + +user_conf: Dict[str, Union[str, int, bool, Dict[str, str]]] = { + user_config.se_notifications_enable_email: False, + user_config.se_notifications_enable_slack: False, + user_config.se_notifications_on_start: False, + user_config.se_notifications_on_completion: False, + user_config.se_notifications_on_fail: False, + user_config.se_notifications_on_error_drop_exceeds_threshold_breach: True, + user_config.se_notifications_on_error_drop_threshold: 15, + user_config.se_enable_query_dq_detailed_result: True, + user_config.se_enable_agg_dq_detailed_result: True, + user_config.se_enable_error_table: True, + user_config.se_dq_rules_params: { + "env": "dev", + "table": "product", + "data_object_name": "customer_order", + "data_source": "customer_source", + "data_layer": "Integrated", + }, + user_config.se_job_metadata: job_info, +} + + +@se.with_expectations( + target_table="dq_spark_dev.customer_order", + write_to_table=True, + user_conf=user_conf, + target_table_view="order", +) +def build_new() -> DataFrame: + _df_order: DataFrame = ( + spark.read.option("header", "true") + .option("inferSchema", "true") + .csv(os.path.join(RESOURCES_DIR, "order.csv")) + ) + _df_order.createOrReplaceTempView("order_source") + _df_order.createOrReplaceTempView("order_target") + + _df_product: DataFrame = ( + spark.read.option("header", "true") + .option("inferSchema", "true") + .csv(os.path.join(RESOURCES_DIR, "product.csv")) + ) + _df_product.createOrReplaceTempView("product") + + _df_customer: DataFrame = ( + spark.read.option("header", "true") + .option("inferSchema", "true") + .csv(os.path.join(RESOURCES_DIR, "customer_source.csv")) + ) + 
_df_customer.createOrReplaceTempView("customer_source") + _df_customer.createOrReplaceTempView("customer_target") + + return _df_order + + +if __name__ == "__main__": + build_new() + + spark.sql("use dq_spark_dev") + spark.sql("select * from dq_spark_dev.dq_stats").show(truncate=False) + spark.sql("select * from dq_spark_dev.customer_order").show(truncate=False) + + _log.info("DQ run with YAML rules completed.") diff --git a/mkdocs.yml b/mkdocs.yml index 844efebc..604a1de3 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -53,6 +53,7 @@ nav: - Quickstart: user_guide/quickstart.md - Reference: - Data Quality Rules: user_guide/data_quality_rules.md + - File-Based Rules (YAML & JSON): user_guide/file_based_rules.md - Data Quality Metrics: user_guide/data_quality_metrics.md - User Config: - Example Config: user_guide/user_config/user_config.md diff --git a/pyproject.toml b/pyproject.toml index b606d713..91484277 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ classifiers = [ requires-python = ">=3.9,<=3.13" dependencies = [ "pluggy>=1", + "pyyaml>=6.0", "requests>=2.28.1", "sqlglot>=21.0,<23.0", ] @@ -75,7 +76,6 @@ dependencies = [ "pytest-mock~=3.14.0", "types-requests==2.28.11.16", "types-setuptools==67.7.0.2", - "pyyaml~=6.0.2", "types-PyYAML>=6.0.2", "setuptools==80.9.0" ] diff --git a/spark_expectations/rules/__init__.py b/spark_expectations/rules/__init__.py new file mode 100644 index 00000000..df9195b4 --- /dev/null +++ b/spark_expectations/rules/__init__.py @@ -0,0 +1,115 @@ +"""Rule-loader plugin system and convenience functions. + +Usage:: + + from spark_expectations.rules import load_rules, load_rules_from_yaml, load_rules_from_json + + rules_df = load_rules("path/to/rules.yaml", spark) + rules_df = load_rules_from_yaml("path/to/rules.yaml", spark) + rules_df = load_rules_from_json("path/to/rules.json", spark) +""" + +import functools +from typing import Dict, Optional + +import pluggy + +from pyspark.sql import DataFrame +from pyspark.sql.session import SparkSession + +from spark_expectations import _log +from spark_expectations.core.exceptions import SparkExpectationsUserInputOrConfigInvalidException +from spark_expectations.rules.plugins.base_rule_loader import ( + SPARK_EXPECTATIONS_RULE_LOADER_PLUGIN, + SparkExpectationsRuleLoader, +) +from spark_expectations.rules.plugins.json_loader import SparkExpectationsJsonRuleLoaderImpl +from spark_expectations.rules.plugins.yaml_loader import SparkExpectationsYamlRuleLoaderImpl + + +@functools.lru_cache +def get_rule_loader_hook() -> pluggy.PluginManager: + """Build and cache the rule-loader PluginManager.""" + pm = pluggy.PluginManager(SPARK_EXPECTATIONS_RULE_LOADER_PLUGIN) + pm.add_hookspecs(SparkExpectationsRuleLoader) + + pm.load_setuptools_entrypoints(SPARK_EXPECTATIONS_RULE_LOADER_PLUGIN) + + pm.register(SparkExpectationsYamlRuleLoaderImpl(), "spark_expectations_yaml_rule_loader") + pm.register(SparkExpectationsJsonRuleLoaderImpl(), "spark_expectations_json_rule_loader") + + for name, plugin_instance in pm.list_name_plugin(): + _log.info(f"Loaded rule-loader plugin: {name} ({plugin_instance.__class__.__name__})") + + return pm + + +_rule_loader_hook = get_rule_loader_hook().hook + + +def load_rules( + path: str, + spark: Optional[SparkSession] = None, + format: str = "auto", + options: Optional[Dict[str, str]] = None, +) -> DataFrame: + """Load rules from *path*, auto-detecting the format from the file extension. + + Args: + path: File path readable by Python (local, DBFS fuse, mounted volume). 
+ spark: Optional SparkSession; if ``None`` the active session is used. + format: ``"auto"`` (detect from extension), ``"yaml"``, or ``"json"``. + options: Extra options forwarded to the loader plugin. + + Returns: + A Spark DataFrame with the standard rules schema. + + Raises: + SparkExpectationsUserInputOrConfigInvalidException: When no plugin can + handle the requested format or the file is invalid. + """ + result = _rule_loader_hook.load_rules(path=path, format=format, options=options or {}, spark=spark) + + if result is None: + raise SparkExpectationsUserInputOrConfigInvalidException( + f"No rule-loader plugin could handle format='{format}' for path '{path}'. " + f"Supported formats: yaml, json." + ) + + return result + + +def load_rules_from_yaml( + path: str, + spark: Optional[SparkSession] = None, + options: Optional[Dict[str, str]] = None, +) -> DataFrame: + """Load rules from a YAML file. + + Args: + path: Path to a ``.yaml`` or ``.yml`` file. + spark: Optional SparkSession; if ``None`` the active session is used. + options: Extra options forwarded to the loader plugin. + + Returns: + A Spark DataFrame with the standard rules schema. + """ + return load_rules(path=path, spark=spark, format="yaml", options=options) + + +def load_rules_from_json( + path: str, + spark: Optional[SparkSession] = None, + options: Optional[Dict[str, str]] = None, +) -> DataFrame: + """Load rules from a JSON file. + + Args: + path: Path to a ``.json`` file. + spark: Optional SparkSession; if ``None`` the active session is used. + options: Extra options forwarded to the loader plugin. + + Returns: + A Spark DataFrame with the standard rules schema. + """ + return load_rules(path=path, spark=spark, format="json", options=options) diff --git a/spark_expectations/rules/plugins/__init__.py b/spark_expectations/rules/plugins/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/spark_expectations/rules/plugins/_flatten.py b/spark_expectations/rules/plugins/_flatten.py new file mode 100644 index 00000000..2b08fe57 --- /dev/null +++ b/spark_expectations/rules/plugins/_flatten.py @@ -0,0 +1,265 @@ +"""Shared logic for converting rule definitions to flat row dicts, +and helpers for building the standard rules DataFrame schema. + +The recommended input format uses ``dq_env`` to define per-environment +settings (``table_name``, ``action_if_failed``, etc.) 
alongside a flat +``rules`` list:: + + product_id: my_product + dq_env: + DEV: + table_name: catalog_dev.schema.orders + action_if_failed: ignore + is_active: true + priority: medium + PROD: + table_name: catalog_prod.schema.orders + action_if_failed: fail + is_active: true + priority: high + rules: + - rule: col1_not_null + rule_type: row_dq + column_name: col1 + expectation: "col1 IS NOT NULL" + tag: completeness + +A simpler format without ``dq_env`` is also supported, using a top-level +``table_name`` and optional ``defaults``:: + + product_id: my_product + table_name: db.my_table + defaults: + action_if_failed: ignore + rules: + - rule: col1_not_null + rule_type: row_dq + expectation: "col1 IS NOT NULL" +""" + +from typing import Any, Dict, List, Optional + +from pyspark.sql import DataFrame +from pyspark.sql.session import SparkSession +from pyspark.sql.types import BooleanType, DataType, IntegerType, StringType, StructField, StructType + +from spark_expectations.core.exceptions import SparkExpectationsUserInputOrConfigInvalidException + +VALID_RULE_TYPES = {"row_dq", "agg_dq", "query_dq"} + +RULES_SCHEMA_COLUMNS = [ + "product_id", + "table_name", + "rule_type", + "rule", + "column_name", + "expectation", + "action_if_failed", + "tag", + "description", + "enable_for_source_dq_validation", + "enable_for_target_dq_validation", + "is_active", + "enable_error_drop_alert", + "error_drop_threshold", + "query_dq_delimiter", + "enable_querydq_custom_output", + "priority", +] + +COLUMN_DEFAULTS: Dict[str, Any] = { + "column_name": "", + "expectation": "", + "action_if_failed": "ignore", + "tag": "", + "description": "", + "enable_for_source_dq_validation": True, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": 0, + "query_dq_delimiter": None, + "enable_querydq_custom_output": False, + "priority": "medium", +} + +BOOLEAN_COLUMNS = { + "enable_for_source_dq_validation", + "enable_for_target_dq_validation", + "is_active", + "enable_error_drop_alert", + "enable_querydq_custom_output", +} + +INT_COLUMNS = { + "error_drop_threshold", +} + +REQUIRED_RULE_FIELDS = {"rule", "expectation"} + + +def _col_type(col: str) -> DataType: + """Return the Spark DataType for a rules-schema column.""" + if col in BOOLEAN_COLUMNS: + return BooleanType() + if col in INT_COLUMNS: + return IntegerType() + return StringType() + + +def rules_schema() -> StructType: + """Build the StructType for the standard 17-column rules DataFrame.""" + return StructType([StructField(col, _col_type(col), True) for col in RULES_SCHEMA_COLUMNS]) + + +def rows_to_dataframe(rows: List[Dict[str, Any]], spark: SparkSession) -> DataFrame: + """Convert a list of normalised row dicts into a Spark DataFrame.""" + return spark.createDataFrame(rows, schema=rules_schema()) + + +def flatten_rules_list( + data: Dict[str, Any], env: Optional[str] = None +) -> List[Dict[str, Any]]: + """Convert a rules-list definition into a flat list of row dicts. + + Expected structure (dq_env -- recommended):: + + product_id: ... + dq_env: + DEV: + table_name: ... + action_if_failed: ignore + ... + PROD: + table_name: ... + action_if_failed: fail + ... + rules: + - rule: ... + rule_type: row_dq + expectation: ... + + When ``dq_env`` is present the *env* parameter selects which + environment block supplies the ``table_name`` and default values. + Environment lookup is case-insensitive (``DEV``, ``dev``, ``Dev`` + all match). 
+ + A simpler structure without ``dq_env`` is also supported:: + + product_id: ... + table_name: ... + defaults: + action_if_failed: ignore + rules: + - rule: ... + expectation: ... + + Returns: + List of dicts, each representing one rule row. + """ + product_id = data.get("product_id") + if not product_id: + raise SparkExpectationsUserInputOrConfigInvalidException( + "'product_id' is required at the top level of the rules file." + ) + + dq_env = data.get("dq_env") + if dq_env is not None: + if not isinstance(dq_env, dict) or not dq_env: + raise SparkExpectationsUserInputOrConfigInvalidException( + "'dq_env' must be a non-empty mapping of environment names to config." + ) + if not env: + raise SparkExpectationsUserInputOrConfigInvalidException( + "'dq_env' is present in the rules file but no environment was " + "specified. Pass the environment via options={'dq_env': ''}." + ) + env_lower = env.lower() + env_map = {k.lower(): v for k, v in dq_env.items()} + env_config = env_map.get(env_lower) + if not env_config or not isinstance(env_config, dict): + raise SparkExpectationsUserInputOrConfigInvalidException( + f"Environment '{env}' not found in 'dq_env'. " + f"Available environments: {sorted(dq_env.keys())}." + ) + table_name = env_config.get("table_name", "") + env_defaults = {k: v for k, v in env_config.items() if k != "table_name"} + user_defaults = {**(data.get("defaults") or {}), **env_defaults} + else: + table_name = data.get("table_name", "") + user_defaults = data.get("defaults") or {} + + merged_defaults = {**COLUMN_DEFAULTS, **user_defaults} + + rules_list = data.get("rules") + if not rules_list or not isinstance(rules_list, list): + raise SparkExpectationsUserInputOrConfigInvalidException( + "'rules' must be a non-empty list of rule definitions." + ) + + rows: List[Dict[str, Any]] = [] + for rule_def in rules_list: + if not isinstance(rule_def, dict): + raise SparkExpectationsUserInputOrConfigInvalidException( + "Each entry in 'rules' must be a dict." + ) + + missing = REQUIRED_RULE_FIELDS - set(rule_def.keys()) + if missing: + raise SparkExpectationsUserInputOrConfigInvalidException( + f"Rule '{rule_def.get('rule', '')}' is missing required fields: {sorted(missing)}." + ) + + row = {**merged_defaults, **rule_def} + row["product_id"] = product_id + if "table_name" not in rule_def and table_name: + row["table_name"] = table_name + + rule_type = row.get("rule_type", "") + if not rule_type: + raise SparkExpectationsUserInputOrConfigInvalidException( + f"Rule '{row.get('rule')}' is missing 'rule_type'. " + f"Must be one of {sorted(VALID_RULE_TYPES)}." + ) + if rule_type not in VALID_RULE_TYPES: + raise SparkExpectationsUserInputOrConfigInvalidException( + f"Invalid rule_type '{rule_type}' for rule '{row.get('rule')}'. " + f"Must be one of {sorted(VALID_RULE_TYPES)}." 
+ ) + + rows.append(_normalise_row(row)) + + return rows + + +# ── helpers ────────────────────────────────────────────────────────────── + + +def _normalise_row(row: Dict[str, Any]) -> Dict[str, Any]: + """Ensure all schema columns are present with their correct types.""" + normalised: Dict[str, Any] = {} + for col in RULES_SCHEMA_COLUMNS: + value = row.get(col, COLUMN_DEFAULTS.get(col, "")) + normalised[col] = _cast_value(col, value) + return normalised + + +def _cast_value(col: str, value: Any) -> Any: + """Cast a value to its expected type based on the column name.""" + if value is None: + return COLUMN_DEFAULTS.get(col, False if col in BOOLEAN_COLUMNS else 0 if col in INT_COLUMNS else "") + + if col in BOOLEAN_COLUMNS: + if isinstance(value, bool): + return value + return value.lower() in ("true", "1", "yes") if isinstance(value, str) else bool(value) + + if col in INT_COLUMNS: + try: + return int(value) + except (ValueError, TypeError) as exc: + raise SparkExpectationsUserInputOrConfigInvalidException( + f"Column '{col}' expects an integer value, got: {value!r}" + ) from exc + + return str(value) diff --git a/spark_expectations/rules/plugins/base_rule_loader.py b/spark_expectations/rules/plugins/base_rule_loader.py new file mode 100644 index 00000000..323c8511 --- /dev/null +++ b/spark_expectations/rules/plugins/base_rule_loader.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import Dict, Optional + +import pluggy + +from pyspark.sql import DataFrame +from pyspark.sql.session import SparkSession + +SPARK_EXPECTATIONS_RULE_LOADER_PLUGIN = "spark_expectations_rule_loader_plugins" + +rule_loader_plugin_spec = pluggy.HookspecMarker(SPARK_EXPECTATIONS_RULE_LOADER_PLUGIN) +spark_expectations_rule_loader_impl = pluggy.HookimplMarker(SPARK_EXPECTATIONS_RULE_LOADER_PLUGIN) + + +class SparkExpectationsRuleLoader: + """Base hook spec for rule loader plugins. + + Each plugin should check whether it can handle the given format/path + and return a DataFrame with the standard rules schema, or None if + it cannot handle the request. + """ + + @rule_loader_plugin_spec(firstresult=True) + def load_rules( + self, + path: str, + format: str, + options: Dict[str, str], + spark: Optional[SparkSession] = None, + ) -> Optional[DataFrame]: + """Load rules from *path* and return a Spark DataFrame. + + Args: + path: File path readable by Python (local, DBFS fuse, mounted volume). + format: Requested format (``yaml``, ``json``, or ``auto``). + options: Extra loader-specific options forwarded by the caller. + spark: Optional SparkSession; if ``None`` the active session is used. + + Returns: + A Spark DataFrame with the standard rules schema, or ``None`` + when this plugin cannot handle the requested format. 
+ """ diff --git a/spark_expectations/rules/plugins/json_loader.py b/spark_expectations/rules/plugins/json_loader.py new file mode 100644 index 00000000..637fe480 --- /dev/null +++ b/spark_expectations/rules/plugins/json_loader.py @@ -0,0 +1,72 @@ +"""Pluggy plugin that loads DQ rules from a JSON file.""" + +import json +import os +from typing import Any, Dict, Optional + +from pyspark.sql import DataFrame +from pyspark.sql.session import SparkSession + +from spark_expectations.core.exceptions import SparkExpectationsUserInputOrConfigInvalidException +from spark_expectations.rules.plugins._flatten import ( + flatten_rules_list, + rows_to_dataframe, +) +from spark_expectations.rules.plugins.base_rule_loader import spark_expectations_rule_loader_impl + +JSON_EXTENSIONS = {".json"} + + +def _read_json(path: str) -> Dict[str, Any]: + try: + with open(path, "r", encoding="utf-8") as fh: + data = json.load(fh) + except FileNotFoundError as exc: + raise SparkExpectationsUserInputOrConfigInvalidException( + f"Rules JSON file not found: {path}" + ) from exc + except json.JSONDecodeError as exc: + raise SparkExpectationsUserInputOrConfigInvalidException( + f"Error parsing JSON rules file '{path}': {exc}" + ) from exc + + if data is None: + raise SparkExpectationsUserInputOrConfigInvalidException( + f"Rules JSON file is empty or null: {path}" + ) + if not isinstance(data, dict): + raise SparkExpectationsUserInputOrConfigInvalidException( + f"Rules file must contain an object at the top level, got {type(data).__name__}: {path}" + ) + return data + + +class SparkExpectationsJsonRuleLoaderImpl: + """Loads DQ rules from a JSON file.""" + + @spark_expectations_rule_loader_impl + def load_rules( # pylint: disable=unused-argument + self, + path: str, + format: str, + options: Dict[str, str], + spark: Optional[SparkSession] = None, + ) -> Optional[DataFrame]: + if format == "json": + pass + elif format == "auto" and os.path.splitext(path)[1].lower() in JSON_EXTENSIONS: + pass + else: + return None + + if spark is None: + spark = SparkSession.getActiveSession() + if spark is None: + raise SparkExpectationsUserInputOrConfigInvalidException( + "No active SparkSession found. Please create a SparkSession before loading rules." 
+ ) + + data = _read_json(path) + env = (options or {}).get("dq_env") + rows = flatten_rules_list(data, env=env) + return rows_to_dataframe(rows, spark) diff --git a/spark_expectations/rules/plugins/yaml_loader.py b/spark_expectations/rules/plugins/yaml_loader.py new file mode 100644 index 00000000..06d300fa --- /dev/null +++ b/spark_expectations/rules/plugins/yaml_loader.py @@ -0,0 +1,73 @@ +"""Pluggy plugin that loads DQ rules from a YAML file.""" + +import os +from typing import Any, Dict, Optional + +import yaml + +from pyspark.sql import DataFrame +from pyspark.sql.session import SparkSession + +from spark_expectations.core.exceptions import SparkExpectationsUserInputOrConfigInvalidException +from spark_expectations.rules.plugins._flatten import ( + flatten_rules_list, + rows_to_dataframe, +) +from spark_expectations.rules.plugins.base_rule_loader import spark_expectations_rule_loader_impl + +YAML_EXTENSIONS = {".yaml", ".yml"} + + +def _read_yaml(path: str) -> Dict[str, Any]: + try: + with open(path, "r", encoding="utf-8") as fh: + data = yaml.safe_load(fh) + except FileNotFoundError as exc: + raise SparkExpectationsUserInputOrConfigInvalidException( + f"Rules YAML file not found: {path}" + ) from exc + except yaml.YAMLError as exc: + raise SparkExpectationsUserInputOrConfigInvalidException( + f"Error parsing YAML rules file '{path}': {exc}" + ) from exc + + if data is None: + raise SparkExpectationsUserInputOrConfigInvalidException( + f"Rules YAML file is empty: {path}" + ) + if not isinstance(data, dict): + raise SparkExpectationsUserInputOrConfigInvalidException( + f"Rules file must contain a mapping at the top level, got {type(data).__name__}: {path}" + ) + return data + + +class SparkExpectationsYamlRuleLoaderImpl: + """Loads DQ rules from a YAML file.""" + + @spark_expectations_rule_loader_impl + def load_rules( # pylint: disable=unused-argument + self, + path: str, + format: str, + options: Dict[str, str], + spark: Optional[SparkSession] = None, + ) -> Optional[DataFrame]: + if format == "yaml": + pass + elif format == "auto" and os.path.splitext(path)[1].lower() in YAML_EXTENSIONS: + pass + else: + return None + + if spark is None: + spark = SparkSession.getActiveSession() + if spark is None: + raise SparkExpectationsUserInputOrConfigInvalidException( + "No active SparkSession found. Please create a SparkSession before loading rules." 
+ ) + + data = _read_yaml(path) + env = (options or {}).get("dq_env") + rows = flatten_rules_list(data, env=env) + return rows_to_dataframe(rows, spark) diff --git a/tests/unit/rules/__init__.py b/tests/unit/rules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/rules/plugins/__init__.py b/tests/unit/rules/plugins/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/rules/plugins/test_flatten.py b/tests/unit/rules/plugins/test_flatten.py new file mode 100644 index 00000000..3fb468a9 --- /dev/null +++ b/tests/unit/rules/plugins/test_flatten.py @@ -0,0 +1,452 @@ +"""Tests for spark_expectations.rules.plugins._flatten""" + +import pytest + +from spark_expectations.core.exceptions import SparkExpectationsUserInputOrConfigInvalidException +from spark_expectations.rules.plugins._flatten import ( + COLUMN_DEFAULTS, + RULES_SCHEMA_COLUMNS, + _cast_value, + flatten_rules_list, +) + + +@pytest.fixture +def minimal_rules_list(): + return { + "product_id": "prod1", + "table_name": "db.table1", + "rules": [ + { + "rule": "col1_not_null", + "rule_type": "row_dq", + "expectation": "col1 IS NOT NULL", + "column_name": "col1", + "tag": "completeness", + } + ], + } + + +@pytest.fixture +def minimal_dq_env_rules(): + return { + "product_id": "prod1", + "dq_env": { + "DEV": { + "table_name": "catalog.schema.orders", + "action_if_failed": "ignore", + "is_active": True, + "priority": "medium", + }, + "QA": { + "table_name": "catalog2.schema.orders", + "action_if_failed": "ignore", + "is_active": True, + "priority": "medium", + }, + "PROD": { + "table_name": "catalog3.schema.orders", + "action_if_failed": "fail", + "is_active": True, + "priority": "high", + }, + }, + "rules": [ + { + "rule": "col1_not_null", + "rule_type": "row_dq", + "expectation": "col1 IS NOT NULL", + "column_name": "col1", + "tag": "completeness", + } + ], + } + + +# ── rules-list format (classic with table_name) ──────────────────────── + + +def test_flatten_rules_list_basic(minimal_rules_list): + rows = flatten_rules_list(minimal_rules_list) + assert len(rows) == 1 + row = rows[0] + assert row["product_id"] == "prod1" + assert row["table_name"] == "db.table1" + assert row["rule_type"] == "row_dq" + assert row["rule"] == "col1_not_null" + assert row["expectation"] == "col1 IS NOT NULL" + + +def test_flatten_rules_list_all_schema_columns_present(minimal_rules_list): + rows = flatten_rules_list(minimal_rules_list) + for col in RULES_SCHEMA_COLUMNS: + assert col in rows[0], f"Missing column: {col}" + + +def test_flatten_rules_list_defaults_applied(minimal_rules_list): + rows = flatten_rules_list(minimal_rules_list) + row = rows[0] + assert row["action_if_failed"] == COLUMN_DEFAULTS["action_if_failed"] + assert row["is_active"] == COLUMN_DEFAULTS["is_active"] + assert row["priority"] == COLUMN_DEFAULTS["priority"] + + +def test_flatten_rules_list_user_defaults_override_column_defaults(): + data = { + "product_id": "prod1", + "table_name": "t1", + "defaults": {"action_if_failed": "fail", "priority": "high"}, + "rules": [{"rule": "r1", "rule_type": "row_dq", "expectation": "x > 0"}], + } + rows = flatten_rules_list(data) + assert rows[0]["action_if_failed"] == "fail" + assert rows[0]["priority"] == "high" + + +def test_flatten_rules_list_rule_level_overrides_defaults(): + data = { + "product_id": "prod1", + "table_name": "t1", + "defaults": {"action_if_failed": "ignore"}, + "rules": [ + { + "rule": "r1", + "rule_type": "row_dq", + "expectation": "x > 0", + "action_if_failed": "drop", + 
} + ], + } + rows = flatten_rules_list(data) + assert rows[0]["action_if_failed"] == "drop" + + +def test_flatten_rules_list_rule_type_from_defaults(): + data = { + "product_id": "prod1", + "table_name": "t1", + "defaults": {"rule_type": "row_dq"}, + "rules": [{"rule": "r1", "expectation": "x > 0"}], + } + rows = flatten_rules_list(data) + assert rows[0]["rule_type"] == "row_dq" + + +def test_flatten_rules_list_rule_type_per_rule_overrides_default(): + data = { + "product_id": "prod1", + "table_name": "t1", + "defaults": {"rule_type": "row_dq"}, + "rules": [ + {"rule": "r1", "expectation": "x > 0"}, + {"rule": "r2", "rule_type": "agg_dq", "expectation": "count(*) > 0"}, + ], + } + rows = flatten_rules_list(data) + assert rows[0]["rule_type"] == "row_dq" + assert rows[1]["rule_type"] == "agg_dq" + + +def test_flatten_rules_list_table_name_override_per_rule(): + data = { + "product_id": "prod1", + "table_name": "default_table", + "rules": [ + {"rule": "r1", "rule_type": "row_dq", "expectation": "x > 0"}, + {"rule": "r2", "rule_type": "row_dq", "expectation": "y > 0", "table_name": "other_table"}, + ], + } + rows = flatten_rules_list(data) + assert rows[0]["table_name"] == "default_table" + assert rows[1]["table_name"] == "other_table" + + +def test_flatten_rules_list_mixed_rule_types(): + data = { + "product_id": "prod1", + "table_name": "t1", + "rules": [ + {"rule": "r1", "rule_type": "row_dq", "expectation": "c1 > 0"}, + {"rule": "r2", "rule_type": "agg_dq", "expectation": "count(*) > 10"}, + {"rule": "r3", "rule_type": "query_dq", "expectation": "SELECT 1"}, + ], + } + rows = flatten_rules_list(data) + assert len(rows) == 3 + assert {r["rule_type"] for r in rows} == {"row_dq", "agg_dq", "query_dq"} + + +def test_flatten_rules_list_missing_product_id_raises(): + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="product_id"): + flatten_rules_list({"rules": [{"rule": "r", "expectation": "x"}]}) + + +def test_flatten_rules_list_empty_rules_raises(): + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="non-empty"): + flatten_rules_list({"product_id": "p1", "rules": []}) + + +def test_flatten_rules_list_missing_rule_field_raises(): + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="missing required"): + flatten_rules_list({ + "product_id": "p1", + "rules": [{"expectation": "x > 0"}], + }) + + +def test_flatten_rules_list_missing_expectation_field_raises(): + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="missing required"): + flatten_rules_list({ + "product_id": "p1", + "rules": [{"rule": "r1"}], + }) + + +def test_flatten_rules_list_missing_rule_type_raises(): + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="missing 'rule_type'"): + flatten_rules_list({ + "product_id": "p1", + "table_name": "t1", + "rules": [{"rule": "r1", "expectation": "x > 0"}], + }) + + +def test_flatten_rules_list_invalid_rule_type_raises(): + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="Invalid rule_type"): + flatten_rules_list({ + "product_id": "p1", + "rules": [{"rule": "r1", "rule_type": "bad_type", "expectation": "x > 0"}], + }) + + +def test_flatten_rules_list_boolean_values_are_native(minimal_rules_list): + rows = flatten_rules_list(minimal_rules_list) + row = rows[0] + assert isinstance(row["is_active"], bool) + assert isinstance(row["enable_error_drop_alert"], bool) + + +def test_flatten_rules_list_no_table_name_at_top_or_rule(): + data 
= { + "product_id": "p1", + "rules": [{"rule": "r1", "rule_type": "row_dq", "expectation": "x > 0"}], + } + rows = flatten_rules_list(data) + assert rows[0]["table_name"] == "" + + +# ── rules-list format (dq_env) ───────────────────────────────────────── + + +def test_flatten_rules_list_dq_env_basic(minimal_dq_env_rules): + rows = flatten_rules_list(minimal_dq_env_rules, env="DEV") + assert len(rows) == 1 + row = rows[0] + assert row["product_id"] == "prod1" + assert row["table_name"] == "catalog.schema.orders" + assert row["rule"] == "col1_not_null" + assert row["action_if_failed"] == "ignore" + assert row["priority"] == "medium" + + +def test_flatten_rules_list_dq_env_selects_correct_env(minimal_dq_env_rules): + rows_dev = flatten_rules_list(minimal_dq_env_rules, env="DEV") + rows_prod = flatten_rules_list(minimal_dq_env_rules, env="PROD") + assert rows_dev[0]["table_name"] == "catalog.schema.orders" + assert rows_prod[0]["table_name"] == "catalog3.schema.orders" + assert rows_prod[0]["action_if_failed"] == "fail" + assert rows_prod[0]["priority"] == "high" + + +def test_flatten_rules_list_dq_env_qa_env(minimal_dq_env_rules): + rows = flatten_rules_list(minimal_dq_env_rules, env="QA") + assert rows[0]["table_name"] == "catalog2.schema.orders" + + +def test_flatten_rules_list_dq_env_all_schema_columns(minimal_dq_env_rules): + rows = flatten_rules_list(minimal_dq_env_rules, env="DEV") + for col in RULES_SCHEMA_COLUMNS: + assert col in rows[0], f"Missing column: {col}" + + +def test_flatten_rules_list_dq_env_rule_overrides_env_defaults(): + data = { + "product_id": "prod1", + "dq_env": { + "DEV": { + "table_name": "t1", + "action_if_failed": "ignore", + "priority": "medium", + }, + }, + "rules": [ + { + "rule": "r1", + "rule_type": "row_dq", + "expectation": "x > 0", + "action_if_failed": "drop", + "priority": "high", + } + ], + } + rows = flatten_rules_list(data, env="DEV") + assert rows[0]["action_if_failed"] == "drop" + assert rows[0]["priority"] == "high" + + +def test_flatten_rules_list_dq_env_case_insensitive_lookup(minimal_dq_env_rules): + for env_value in ("DEV", "dev", "Dev", "dEv"): + rows = flatten_rules_list(minimal_dq_env_rules, env=env_value) + assert rows[0]["table_name"] == "catalog.schema.orders" + + +def test_flatten_rules_list_dq_env_lowercase_option_matches_uppercase_key(): + data = { + "product_id": "prod1", + "dq_env": { + "DEV": { + "table_name": "dev.orders", + "priority": "medium", + }, + }, + "rules": [ + {"rule": "r1", "rule_type": "row_dq", "expectation": "x > 0"} + ], + } + rows = flatten_rules_list(data, env="dev") + assert rows[0]["table_name"] == "dev.orders" + + +def test_flatten_rules_list_dq_env_no_env_raises(minimal_dq_env_rules): + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="no environment"): + flatten_rules_list(minimal_dq_env_rules) + + +def test_flatten_rules_list_dq_env_missing_env_raises(minimal_dq_env_rules): + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="not found"): + flatten_rules_list(minimal_dq_env_rules, env="staging") + + +def test_flatten_rules_list_dq_env_empty_raises(): + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="non-empty mapping"): + flatten_rules_list({ + "product_id": "p1", + "dq_env": {}, + "rules": [{"rule": "r1", "expectation": "x > 0"}], + }, env="DEV") + + +def test_flatten_rules_list_dq_env_not_dict_raises(): + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="non-empty mapping"): + 
flatten_rules_list({ + "product_id": "p1", + "dq_env": "not_a_dict", + "rules": [{"rule": "r1", "expectation": "x > 0"}], + }, env="DEV") + + +def test_flatten_rules_list_dq_env_with_multiple_rules(): + data = { + "product_id": "prod1", + "dq_env": { + "DEV": { + "table_name": "dev.orders", + "action_if_failed": "ignore", + "is_active": True, + "priority": "medium", + }, + }, + "rules": [ + { + "rule": "r1", + "rule_type": "row_dq", + "expectation": "col1 IS NOT NULL", + "action_if_failed": "drop", + "priority": "high", + }, + { + "rule": "r2", + "rule_type": "agg_dq", + "expectation": "count(*) > 0", + "action_if_failed": "fail", + }, + ], + } + rows = flatten_rules_list(data, env="DEV") + assert len(rows) == 2 + assert rows[0]["table_name"] == "dev.orders" + assert rows[0]["action_if_failed"] == "drop" + assert rows[0]["priority"] == "high" + assert rows[1]["table_name"] == "dev.orders" + assert rows[1]["action_if_failed"] == "fail" + assert rows[1]["priority"] == "medium" + + +# ── helper coverage ──────────────────────────────────────────────────── + + +def test_flatten_rules_list_non_dict_rule_entry_raises(): + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="must be a dict"): + flatten_rules_list({ + "product_id": "p1", + "rules": ["not_a_dict"], + }) + + +def test_flatten_rules_list_rules_not_a_list_raises(): + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="non-empty"): + flatten_rules_list({ + "product_id": "p1", + "rules": "not_a_list", + }) + + +def test_flatten_rules_list_non_numeric_error_drop_threshold_raises(): + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="expects an integer"): + flatten_rules_list({ + "product_id": "p1", + "table_name": "t1", + "rules": [ + { + "rule": "r1", + "rule_type": "row_dq", + "expectation": "x > 0", + "error_drop_threshold": "abc", + } + ], + }) + + +# ── _cast_value coverage ─────────────────────────────────────────────── + + +def test_cast_value_none_boolean_column(): + assert _cast_value("is_active", None) is True + assert _cast_value("enable_error_drop_alert", None) is False + + +def test_cast_value_none_int_column(): + assert _cast_value("error_drop_threshold", None) == 0 + + +def test_cast_value_none_string_column(): + assert _cast_value("description", None) == "" + assert _cast_value("tag", None) == "" + + +def test_cast_value_boolean_from_string(): + assert _cast_value("is_active", "true") is True + assert _cast_value("is_active", "True") is True + assert _cast_value("is_active", "1") is True + assert _cast_value("is_active", "yes") is True + assert _cast_value("is_active", "false") is False + assert _cast_value("is_active", "0") is False + assert _cast_value("is_active", "no") is False + + +def test_cast_value_boolean_from_non_bool_non_str(): + assert _cast_value("is_active", 1) is True + assert _cast_value("is_active", 0) is False + + diff --git a/tests/unit/rules/plugins/test_json_loader.py b/tests/unit/rules/plugins/test_json_loader.py new file mode 100644 index 00000000..c3059694 --- /dev/null +++ b/tests/unit/rules/plugins/test_json_loader.py @@ -0,0 +1,260 @@ +"""Tests for spark_expectations.rules.plugins.json_loader""" + +import json +import os +from unittest.mock import patch + +import pytest + +from spark_expectations.core import get_spark_session +from spark_expectations.core.exceptions import SparkExpectationsUserInputOrConfigInvalidException +from spark_expectations.rules.plugins._flatten import RULES_SCHEMA_COLUMNS +from 
spark_expectations.rules.plugins.json_loader import SparkExpectationsJsonRuleLoaderImpl + +get_spark_session() + +SAMPLE_RULES_JSON = os.path.join( + os.path.dirname(__file__), "..", "..", "..", "..", "examples", "resources", "sample_rules.json" +) + + +@pytest.fixture +def json_loader(): + return SparkExpectationsJsonRuleLoaderImpl() + + +@pytest.fixture +def dq_env_json_file(tmp_path): + data = { + "product_id": "test_product", + "dq_env": { + "DEV": { + "table_name": "db.test_table", + "action_if_failed": "ignore", + "is_active": True, + "priority": "medium", + }, + "QA": { + "table_name": "db_qa.test_table", + "action_if_failed": "ignore", + "is_active": True, + "priority": "medium", + }, + "PROD": { + "table_name": "db_prod.test_table", + "action_if_failed": "fail", + "is_active": True, + "priority": "high", + }, + }, + "rules": [ + { + "rule": "col1_not_null", + "rule_type": "row_dq", + "column_name": "col1", + "expectation": "col1 IS NOT NULL", + "action_if_failed": "drop", + "tag": "completeness", + "description": "col1 must not be null", + }, + { + "rule": "col2_positive", + "rule_type": "row_dq", + "column_name": "col2", + "expectation": "col2 > 0", + "tag": "validity", + "description": "col2 must be positive", + }, + { + "rule": "row_count", + "rule_type": "agg_dq", + "column_name": "", + "expectation": "count(*) > 0", + "action_if_failed": "fail", + "tag": "completeness", + "description": "Must have rows", + }, + ], + } + path = tmp_path / "rules.json" + path.write_text(json.dumps(data, indent=2)) + return str(path) + + +@pytest.fixture +def rules_list_json_file(tmp_path): + data = { + "product_id": "test_product", + "table_name": "db.test_table", + "defaults": { + "action_if_failed": "ignore", + "is_active": True, + "priority": "medium", + }, + "rules": [ + { + "rule": "col1_not_null", + "rule_type": "row_dq", + "column_name": "col1", + "expectation": "col1 IS NOT NULL", + "action_if_failed": "drop", + "tag": "completeness", + "description": "col1 must not be null", + }, + { + "rule": "col2_positive", + "rule_type": "row_dq", + "column_name": "col2", + "expectation": "col2 > 0", + "tag": "validity", + "description": "col2 must be positive", + }, + { + "rule": "row_count", + "rule_type": "agg_dq", + "column_name": "", + "expectation": "count(*) > 0", + "action_if_failed": "fail", + "tag": "completeness", + "description": "Must have rows", + }, + ], + } + path = tmp_path / "rules.json" + path.write_text(json.dumps(data, indent=2)) + return str(path) + + +def test_returns_none_for_non_json_format(json_loader): + result = json_loader.load_rules(path="rules.yaml", format="yaml", options={}) + assert result is None + + +def test_returns_none_for_auto_non_json_extension(json_loader): + result = json_loader.load_rules(path="rules.yaml", format="auto", options={}) + assert result is None + + +def test_handles_json_format_explicit(json_loader, dq_env_json_file): + df = json_loader.load_rules(path=dq_env_json_file, format="json", options={"dq_env": "DEV"}) + assert df is not None + assert df.count() == 3 + assert set(df.columns) == set(RULES_SCHEMA_COLUMNS) + + +def test_handles_auto_json_extension(json_loader, dq_env_json_file): + df = json_loader.load_rules(path=dq_env_json_file, format="auto", options={"dq_env": "DEV"}) + assert df is not None + assert df.count() == 3 + + +def test_dq_env_values(json_loader, dq_env_json_file): + df = json_loader.load_rules(path=dq_env_json_file, format="json", options={"dq_env": "DEV"}) + rows = [r.asDict() for r in df.collect()] + row_dq_rules = [r 
for r in rows if r["rule_type"] == "row_dq"] + agg_dq_rules = [r for r in rows if r["rule_type"] == "agg_dq"] + + assert len(row_dq_rules) == 2 + assert len(agg_dq_rules) == 1 + + col1_rule = next(r for r in row_dq_rules if r["rule"] == "col1_not_null") + assert col1_rule["product_id"] == "test_product" + assert col1_rule["table_name"] == "db.test_table" + assert col1_rule["action_if_failed"] == "drop" + + col2_rule = next(r for r in row_dq_rules if r["rule"] == "col2_positive") + assert col2_rule["action_if_failed"] == "ignore" + assert col2_rule["priority"] == "medium" + + +def test_dq_env_selects_prod(json_loader, dq_env_json_file): + df = json_loader.load_rules(path=dq_env_json_file, format="json", options={"dq_env": "PROD"}) + rows = [r.asDict() for r in df.collect()] + col2_rule = next(r for r in rows if r["rule"] == "col2_positive") + assert col2_rule["table_name"] == "db_prod.test_table" + assert col2_rule["action_if_failed"] == "fail" + assert col2_rule["priority"] == "high" + + +def test_rules_list_values(json_loader, rules_list_json_file): + df = json_loader.load_rules(path=rules_list_json_file, format="json", options={}) + rows = [r.asDict() for r in df.collect()] + row_dq_rules = [r for r in rows if r["rule_type"] == "row_dq"] + agg_dq_rules = [r for r in rows if r["rule_type"] == "agg_dq"] + + assert len(row_dq_rules) == 2 + assert len(agg_dq_rules) == 1 + + col1_rule = next(r for r in row_dq_rules if r["rule"] == "col1_not_null") + assert col1_rule["product_id"] == "test_product" + assert col1_rule["table_name"] == "db.test_table" + assert col1_rule["action_if_failed"] == "drop" + + col2_rule = next(r for r in row_dq_rules if r["rule"] == "col2_positive") + assert col2_rule["action_if_failed"] == "ignore" + assert col2_rule["priority"] == "medium" + + +def test_file_not_found_raises(json_loader): + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="not found"): + json_loader.load_rules(path="/nonexistent/rules.json", format="json", options={}) + + +def test_invalid_json_raises(json_loader, tmp_path): + bad = tmp_path / "bad.json" + bad.write_text("{invalid json}") + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="Error parsing"): + json_loader.load_rules(path=str(bad), format="json", options={}) + + +def test_empty_json_raises(json_loader, tmp_path): + empty = tmp_path / "empty.json" + empty.write_text("null") + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="empty"): + json_loader.load_rules(path=str(empty), format="json", options={}) + + +def test_sample_rules_json_loads(json_loader): + """Ensure the shipped sample_rules.json example loads correctly.""" + df = json_loader.load_rules(path=SAMPLE_RULES_JSON, format="json", options={"dq_env": "DEV"}) + assert df is not None + assert df.count() == 16 + assert set(df.columns) == set(RULES_SCHEMA_COLUMNS) + + +def test_sample_rules_json_values(json_loader): + """Verify key fields from the sample JSON are parsed correctly.""" + df = json_loader.load_rules(path=SAMPLE_RULES_JSON, format="json", options={"dq_env": "DEV"}) + rows = {r["rule"]: r.asDict() for r in df.collect()} + assert rows["customer_id_not_null"]["action_if_failed"] == "drop" + assert rows["customer_id_not_null"]["rule_type"] == "row_dq" + assert rows["customer_id_not_null"]["priority"] == "high" + assert rows["order_date_not_null"]["action_if_failed"] == "ignore" + assert rows["table_row_count"]["rule_type"] == "agg_dq" + assert rows["table_row_count"]["action_if_failed"] == 
"fail" + + +def test_sample_rules_json_env_table_name(json_loader): + """Verify that table_name changes per environment.""" + df_dev = json_loader.load_rules(path=SAMPLE_RULES_JSON, format="json", options={"dq_env": "DEV"}) + df_qa = json_loader.load_rules(path=SAMPLE_RULES_JSON, format="json", options={"dq_env": "QA"}) + rows_dev = {r["rule"]: r.asDict() for r in df_dev.collect()} + rows_qa = {r["rule"]: r.asDict() for r in df_qa.collect()} + assert rows_dev["customer_id_not_null"]["table_name"] == "dq_spark_dev.customer_order" + assert rows_qa["customer_id_not_null"]["table_name"] == "dq_spark_qa.customer_order" + + +def test_no_active_spark_session_raises(json_loader, dq_env_json_file): + """Verify error when no SparkSession is active.""" + with patch("spark_expectations.rules.plugins.json_loader.SparkSession") as mock_spark: + mock_spark.getActiveSession.return_value = None + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="No active SparkSession"): + json_loader.load_rules(path=dq_env_json_file, format="json", options={"dq_env": "DEV"}) + + +def test_non_dict_json_raises(json_loader, tmp_path): + """Verify error when JSON top-level is not an object.""" + list_json = tmp_path / "list.json" + list_json.write_text('[1, 2, 3]') + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="object at the top level"): + json_loader.load_rules(path=str(list_json), format="json", options={}) diff --git a/tests/unit/rules/plugins/test_yaml_loader.py b/tests/unit/rules/plugins/test_yaml_loader.py new file mode 100644 index 00000000..226887bc --- /dev/null +++ b/tests/unit/rules/plugins/test_yaml_loader.py @@ -0,0 +1,257 @@ +"""Tests for spark_expectations.rules.plugins.yaml_loader""" + +import os +import textwrap +from unittest.mock import patch + +import pytest + +from spark_expectations.core import get_spark_session +from spark_expectations.core.exceptions import SparkExpectationsUserInputOrConfigInvalidException +from spark_expectations.rules.plugins._flatten import RULES_SCHEMA_COLUMNS +from spark_expectations.rules.plugins.yaml_loader import SparkExpectationsYamlRuleLoaderImpl + +get_spark_session() + +SAMPLE_RULES_YAML = os.path.join( + os.path.dirname(__file__), "..", "..", "..", "..", "examples", "resources", "sample_rules.yaml" +) + + +@pytest.fixture +def yaml_loader(): + return SparkExpectationsYamlRuleLoaderImpl() + + +@pytest.fixture +def dq_env_yaml_file(tmp_path): + content = textwrap.dedent("""\ + product_id: test_product + + dq_env: + DEV: + table_name: db.test_table + action_if_failed: ignore + is_active: true + priority: medium + QA: + table_name: db_qa.test_table + action_if_failed: ignore + is_active: true + priority: medium + PROD: + table_name: db_prod.test_table + action_if_failed: fail + is_active: true + priority: high + + rules: + - rule: col1_not_null + rule_type: row_dq + column_name: col1 + expectation: "col1 IS NOT NULL" + action_if_failed: drop + tag: completeness + description: "col1 must not be null" + priority: high + - rule: col2_positive + rule_type: row_dq + column_name: col2 + expectation: "col2 > 0" + tag: validity + description: "col2 must be positive" + - rule: row_count + rule_type: agg_dq + expectation: "count(*) > 0" + action_if_failed: fail + tag: completeness + description: "Must have rows" + """) + path = tmp_path / "rules.yaml" + path.write_text(content) + return str(path) + + +@pytest.fixture +def rules_list_yaml_file(tmp_path): + content = textwrap.dedent("""\ + product_id: test_product + 
table_name: db.test_table + + defaults: + action_if_failed: ignore + is_active: true + priority: medium + + rules: + - rule: col1_not_null + rule_type: row_dq + column_name: col1 + expectation: "col1 IS NOT NULL" + action_if_failed: drop + tag: completeness + description: "col1 must not be null" + priority: high + - rule: col2_positive + rule_type: row_dq + column_name: col2 + expectation: "col2 > 0" + tag: validity + description: "col2 must be positive" + - rule: row_count + rule_type: agg_dq + expectation: "count(*) > 0" + action_if_failed: fail + tag: completeness + description: "Must have rows" + """) + path = tmp_path / "rules.yaml" + path.write_text(content) + return str(path) + + +def test_returns_none_for_non_yaml_format(yaml_loader): + result = yaml_loader.load_rules(path="rules.json", format="json", options={}) + assert result is None + + +def test_returns_none_for_auto_non_yaml_extension(yaml_loader): + result = yaml_loader.load_rules(path="rules.json", format="auto", options={}) + assert result is None + + +def test_handles_yaml_format_explicit(yaml_loader, dq_env_yaml_file): + df = yaml_loader.load_rules(path=dq_env_yaml_file, format="yaml", options={"dq_env": "DEV"}) + assert df is not None + assert df.count() == 3 + assert set(df.columns) == set(RULES_SCHEMA_COLUMNS) + + +def test_handles_auto_yaml_extension(yaml_loader, dq_env_yaml_file): + df = yaml_loader.load_rules(path=dq_env_yaml_file, format="auto", options={"dq_env": "DEV"}) + assert df is not None + assert df.count() == 3 + + +def test_handles_auto_yml_extension(yaml_loader, dq_env_yaml_file, tmp_path): + yml_path = tmp_path / "rules.yml" + import shutil + shutil.copy(dq_env_yaml_file, str(yml_path)) + df = yaml_loader.load_rules(path=str(yml_path), format="auto", options={"dq_env": "DEV"}) + assert df is not None + assert df.count() == 3 + + +def test_dq_env_values(yaml_loader, dq_env_yaml_file): + df = yaml_loader.load_rules(path=dq_env_yaml_file, format="yaml", options={"dq_env": "DEV"}) + rows = [r.asDict() for r in df.collect()] + row_dq_rules = [r for r in rows if r["rule_type"] == "row_dq"] + agg_dq_rules = [r for r in rows if r["rule_type"] == "agg_dq"] + + assert len(row_dq_rules) == 2 + assert len(agg_dq_rules) == 1 + + col1_rule = next(r for r in row_dq_rules if r["rule"] == "col1_not_null") + assert col1_rule["product_id"] == "test_product" + assert col1_rule["table_name"] == "db.test_table" + assert col1_rule["action_if_failed"] == "drop" + assert col1_rule["priority"] == "high" + + col2_rule = next(r for r in row_dq_rules if r["rule"] == "col2_positive") + assert col2_rule["action_if_failed"] == "ignore" + assert col2_rule["priority"] == "medium" + + +def test_dq_env_selects_prod(yaml_loader, dq_env_yaml_file): + df = yaml_loader.load_rules(path=dq_env_yaml_file, format="yaml", options={"dq_env": "PROD"}) + rows = [r.asDict() for r in df.collect()] + col2_rule = next(r for r in rows if r["rule"] == "col2_positive") + assert col2_rule["table_name"] == "db_prod.test_table" + assert col2_rule["action_if_failed"] == "fail" + assert col2_rule["priority"] == "high" + + +def test_rules_list_values(yaml_loader, rules_list_yaml_file): + df = yaml_loader.load_rules(path=rules_list_yaml_file, format="yaml", options={}) + rows = [r.asDict() for r in df.collect()] + row_dq_rules = [r for r in rows if r["rule_type"] == "row_dq"] + agg_dq_rules = [r for r in rows if r["rule_type"] == "agg_dq"] + + assert len(row_dq_rules) == 2 + assert len(agg_dq_rules) == 1 + + col1_rule = next(r for r in row_dq_rules if r["rule"] 
== "col1_not_null") + assert col1_rule["product_id"] == "test_product" + assert col1_rule["table_name"] == "db.test_table" + assert col1_rule["action_if_failed"] == "drop" + assert col1_rule["priority"] == "high" + + col2_rule = next(r for r in row_dq_rules if r["rule"] == "col2_positive") + assert col2_rule["action_if_failed"] == "ignore" + assert col2_rule["priority"] == "medium" + + +def test_file_not_found_raises(yaml_loader): + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="not found"): + yaml_loader.load_rules(path="/nonexistent/rules.yaml", format="yaml", options={}) + + +def test_invalid_yaml_raises(yaml_loader, tmp_path): + bad = tmp_path / "bad.yaml" + bad.write_text("{{invalid: yaml: [}") + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="Error parsing"): + yaml_loader.load_rules(path=str(bad), format="yaml", options={}) + + +def test_empty_yaml_raises(yaml_loader, tmp_path): + empty = tmp_path / "empty.yaml" + empty.write_text("") + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="empty"): + yaml_loader.load_rules(path=str(empty), format="yaml", options={}) + + +def test_sample_rules_yaml_loads(yaml_loader): + """Ensure the shipped sample_rules.yaml example loads correctly.""" + df = yaml_loader.load_rules(path=SAMPLE_RULES_YAML, format="yaml", options={"dq_env": "DEV"}) + assert df is not None + assert df.count() == 16 + assert set(df.columns) == set(RULES_SCHEMA_COLUMNS) + + +def test_sample_rules_yaml_values(yaml_loader): + """Verify key fields from the sample YAML are parsed correctly.""" + df = yaml_loader.load_rules(path=SAMPLE_RULES_YAML, format="yaml", options={"dq_env": "DEV"}) + rows = {r["rule"]: r.asDict() for r in df.collect()} + assert rows["customer_id_not_null"]["action_if_failed"] == "drop" + assert rows["customer_id_not_null"]["rule_type"] == "row_dq" + assert rows["customer_id_not_null"]["priority"] == "high" + assert rows["order_date_not_null"]["action_if_failed"] == "ignore" + assert rows["order_date_not_null"]["priority"] == "medium" + assert rows["table_row_count"]["rule_type"] == "agg_dq" + assert rows["table_row_count"]["action_if_failed"] == "fail" + + +def test_sample_rules_yaml_env_table_name(yaml_loader): + """Verify that table_name changes per environment.""" + df_dev = yaml_loader.load_rules(path=SAMPLE_RULES_YAML, format="yaml", options={"dq_env": "DEV"}) + df_qa = yaml_loader.load_rules(path=SAMPLE_RULES_YAML, format="yaml", options={"dq_env": "QA"}) + rows_dev = {r["rule"]: r.asDict() for r in df_dev.collect()} + rows_qa = {r["rule"]: r.asDict() for r in df_qa.collect()} + assert rows_dev["customer_id_not_null"]["table_name"] == "dq_spark_dev.customer_order" + assert rows_qa["customer_id_not_null"]["table_name"] == "dq_spark_qa.customer_order" + + +def test_no_active_spark_session_raises(yaml_loader, dq_env_yaml_file): + """Verify error when no SparkSession is active.""" + with patch("spark_expectations.rules.plugins.yaml_loader.SparkSession") as mock_spark: + mock_spark.getActiveSession.return_value = None + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="No active SparkSession"): + yaml_loader.load_rules(path=dq_env_yaml_file, format="yaml", options={"dq_env": "DEV"}) + + +def test_non_dict_yaml_raises(yaml_loader, tmp_path): + """Verify error when YAML top-level is not a mapping.""" + list_yaml = tmp_path / "list.yaml" + list_yaml.write_text("- item1\n- item2\n") + with 
pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="mapping at the top level"): + yaml_loader.load_rules(path=str(list_yaml), format="yaml", options={}) diff --git a/tests/unit/rules/test__init__.py b/tests/unit/rules/test__init__.py new file mode 100644 index 00000000..61fd93bb --- /dev/null +++ b/tests/unit/rules/test__init__.py @@ -0,0 +1,139 @@ +"""Tests for spark_expectations.rules (plugin manager and convenience functions)""" + +import json +import textwrap + +import pytest + +from spark_expectations.core import get_spark_session +from spark_expectations.core.exceptions import SparkExpectationsUserInputOrConfigInvalidException +from spark_expectations.rules import ( + get_rule_loader_hook, + load_rules, + load_rules_from_json, + load_rules_from_yaml, +) +from spark_expectations.rules.plugins._flatten import RULES_SCHEMA_COLUMNS +from spark_expectations.rules.plugins.json_loader import SparkExpectationsJsonRuleLoaderImpl +from spark_expectations.rules.plugins.yaml_loader import SparkExpectationsYamlRuleLoaderImpl + +get_spark_session() + + +def test_hook_returns_plugin_manager(): + pm = get_rule_loader_hook() + assert pm is not None + + +def test_yaml_plugin_registered(): + pm = get_rule_loader_hook() + plugin = pm.get_plugin("spark_expectations_yaml_rule_loader") + assert isinstance(plugin, SparkExpectationsYamlRuleLoaderImpl) + + +def test_json_plugin_registered(): + pm = get_rule_loader_hook() + plugin = pm.get_plugin("spark_expectations_json_rule_loader") + assert isinstance(plugin, SparkExpectationsJsonRuleLoaderImpl) + + +def test_two_plugins_registered(): + pm = get_rule_loader_hook() + plugins = pm.list_name_plugin() + names = [name for name, _ in plugins] + assert "spark_expectations_yaml_rule_loader" in names + assert "spark_expectations_json_rule_loader" in names + + +@pytest.fixture +def yaml_file(tmp_path): + content = textwrap.dedent("""\ + product_id: test_product + dq_env: + DEV: + table_name: db.t1 + action_if_failed: ignore + priority: medium + rules: + - rule: r1 + rule_type: row_dq + expectation: "col1 > 0" + column_name: col1 + tag: validity + """) + path = tmp_path / "rules.yaml" + path.write_text(content) + return str(path) + + +@pytest.fixture +def json_file(tmp_path): + data = { + "product_id": "test_product", + "dq_env": { + "DEV": { + "table_name": "db.t1", + "action_if_failed": "ignore", + "priority": "medium", + }, + }, + "rules": [ + { + "rule": "r1", + "rule_type": "row_dq", + "expectation": "col1 > 0", + "column_name": "col1", + "tag": "validity", + } + ], + } + path = tmp_path / "rules.json" + path.write_text(json.dumps(data)) + return str(path) + + +def test_load_rules_auto_yaml(yaml_file): + df = load_rules(yaml_file, options={"dq_env": "DEV"}) + assert df.count() == 1 + assert set(df.columns) == set(RULES_SCHEMA_COLUMNS) + + +def test_load_rules_auto_json(json_file): + df = load_rules(json_file, options={"dq_env": "DEV"}) + assert df.count() == 1 + assert set(df.columns) == set(RULES_SCHEMA_COLUMNS) + + +def test_load_rules_explicit_yaml(yaml_file): + df = load_rules(yaml_file, format="yaml", options={"dq_env": "DEV"}) + assert df.count() == 1 + + +def test_load_rules_explicit_json(json_file): + df = load_rules(json_file, format="json", options={"dq_env": "DEV"}) + assert df.count() == 1 + + +def test_load_rules_unsupported_format_raises(tmp_path): + path = tmp_path / "rules.csv" + path.write_text("a,b,c") + with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match="No rule-loader"): + load_rules(str(path)) + 
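+
+def test_load_rules_auto_yml_extension(yaml_file, tmp_path):
+    """Hedged sketch: the YAML plugin itself accepts the ".yml" extension under
+    format="auto" (see test_handles_auto_yml_extension in test_yaml_loader.py),
+    so load_rules is assumed to dispatch ".yml" files to it the same way."""
+    import shutil
+
+    # Copy the .yaml fixture to a .yml path so only the extension differs.
+    yml_path = tmp_path / "rules.yml"
+    shutil.copy(yaml_file, str(yml_path))
+    df = load_rules(str(yml_path), options={"dq_env": "DEV"})
+    assert df.count() == 1
+    assert set(df.columns) == set(RULES_SCHEMA_COLUMNS)
+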
+ +def test_load_rules_from_yaml(yaml_file): + df = load_rules_from_yaml(yaml_file, options={"dq_env": "DEV"}) + assert df.count() == 1 + + +def test_load_rules_from_json(json_file): + df = load_rules_from_json(json_file, options={"dq_env": "DEV"}) + assert df.count() == 1 + + +def test_load_rules_with_explicit_spark_session(yaml_file): + """Exercise the `spark is not None` branch in load_rules.""" + spark = get_spark_session() + df = load_rules(yaml_file, spark=spark, options={"dq_env": "DEV"}) + assert df.count() == 1 + assert set(df.columns) == set(RULES_SCHEMA_COLUMNS)
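+
+
+def test_yaml_and_json_loaders_agree(yaml_file, json_file):
+    """Hedged sketch: the yaml_file and json_file fixtures above encode the
+    same single rule, and both loaders are assumed to share the flattening
+    logic in plugins._flatten, so the rows should match field-for-field."""
+    # Load the equivalent YAML and JSON definitions and compare row contents.
+    rows_yaml = [r.asDict() for r in load_rules(yaml_file, options={"dq_env": "DEV"}).collect()]
+    rows_json = [r.asDict() for r in load_rules(json_file, options={"dq_env": "DEV"}).collect()]
+    assert rows_yaml == rows_json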