diff --git a/integrations/python/dataloader/src/openhouse/dataloader/__init__.py b/integrations/python/dataloader/src/openhouse/dataloader/__init__.py index 947de1801..df266c799 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/__init__.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/__init__.py @@ -1,13 +1,14 @@ from importlib.metadata import version from openhouse.dataloader.catalog import OpenHouseCatalog, OpenHouseCatalogError -from openhouse.dataloader.data_loader import DataLoaderContext, OpenHouseDataLoader +from openhouse.dataloader.data_loader import DataLoaderContext, JvmConfig, OpenHouseDataLoader from openhouse.dataloader.filters import always_true, col __version__ = version("openhouse.dataloader") __all__ = [ "OpenHouseDataLoader", "DataLoaderContext", + "JvmConfig", "OpenHouseCatalog", "OpenHouseCatalogError", "always_true", diff --git a/integrations/python/dataloader/src/openhouse/dataloader/_jvm.py b/integrations/python/dataloader/src/openhouse/dataloader/_jvm.py new file mode 100644 index 000000000..550a17dbd --- /dev/null +++ b/integrations/python/dataloader/src/openhouse/dataloader/_jvm.py @@ -0,0 +1,29 @@ +"""JVM configuration utilities for the HDFS client.""" + +import logging +import os +import threading + +logger = logging.getLogger(__name__) + +LIBHDFS_OPTS_ENV = "LIBHDFS_OPTS" +"""Environment variable read by libhdfs when starting the JNI JVM.""" + +_lock = threading.Lock() + + +def apply_libhdfs_opts(jvm_args: str) -> None: + """Merge *jvm_args* into the JNI JVM options environment variable. + + Appends to any existing value. Must be called before the first + HDFS access in the current process (the JVM is started once and + reads these options only at startup). Thread-safe and idempotent — + duplicate args are not appended. + """ + with _lock: + existing = os.environ.get(LIBHDFS_OPTS_ENV, "") + if jvm_args in existing: + return + merged = f"{existing} {jvm_args}".strip() if existing else jvm_args + os.environ[LIBHDFS_OPTS_ENV] = merged + logger.info("Set %s=%s", LIBHDFS_OPTS_ENV, merged) diff --git a/integrations/python/dataloader/src/openhouse/dataloader/_table_scan_context.py b/integrations/python/dataloader/src/openhouse/dataloader/_table_scan_context.py index e9e7e5f5e..ae20bd9c5 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/_table_scan_context.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/_table_scan_context.py @@ -16,6 +16,7 @@ def _unpickle_scan_context( projected_schema: Schema, row_filter: BooleanExpression, table_id: TableIdentifier, + worker_jvm_args: str | None = None, ) -> TableScanContext: return TableScanContext( table_metadata=table_metadata, @@ -23,6 +24,7 @@ def _unpickle_scan_context( projected_schema=projected_schema, row_filter=row_filter, table_id=table_id, + worker_jvm_args=worker_jvm_args, ) @@ -39,6 +41,7 @@ class TableScanContext: projected_schema: Subset of columns to read (equals table schema when no projection) table_id: Identifier for the table being scanned row_filter: Row-level filter expression pushed down to the scan + worker_jvm_args: JVM arguments applied when the JNI JVM is created in worker processes """ table_metadata: TableMetadata @@ -46,9 +49,17 @@ class TableScanContext: projected_schema: Schema table_id: TableIdentifier row_filter: BooleanExpression = AlwaysTrue() + worker_jvm_args: str | None = None def __reduce__(self) -> tuple: return ( _unpickle_scan_context, - (self.table_metadata, dict(self.io.properties), self.projected_schema, self.row_filter, self.table_id), + ( + self.table_metadata, + dict(self.io.properties), + self.projected_schema, + self.row_filter, + self.table_id, + self.worker_jvm_args, + ), ) diff --git a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py index 3c43da7b0..f2118d30b 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py @@ -10,6 +10,7 @@ from requests import HTTPError from tenacity import Retrying, retry_if_exception, stop_after_attempt, wait_exponential +from openhouse.dataloader._jvm import apply_libhdfs_opts from openhouse.dataloader._table_scan_context import TableScanContext from openhouse.dataloader._timer import log_duration from openhouse.dataloader.data_loader_split import DataLoaderSplit @@ -55,6 +56,29 @@ def _retry[T](fn: Callable[[], T], label: str, max_attempts: int) -> T: raise AssertionError("unreachable") # pragma: no cover +@dataclass(frozen=True) +class JvmConfig: + """JVM arguments for JNI-based storage access (e.g. HDFS via libhdfs). + + The JVM is created once per process. If another library has already + started a JVM these arguments will have no effect. + + Args: + planner_args: JVM arguments (e.g. ``-Xmx2g``) applied when the JNI + JVM is created in the planner process — the process that loads + table metadata and plans splits. + worker_args: JVM arguments applied when the JNI JVM is created in + worker processes that materialize splits. Only honored if the + JVM has not already been started in the worker process. When + splits are materialized in the same process as the planner, + only ``planner_args`` takes effect because the JVM is already + running. + """ + + planner_args: str | None = None + worker_args: str | None = None + + @dataclass class DataLoaderContext: """Context and customization for the DataLoader. @@ -66,11 +90,14 @@ class DataLoaderContext: execution_context: Dictionary of execution context information (e.g. tenant, environment) table_transformer: Transformation to apply to the table before loading (e.g. column masking) udf_registry: UDFs required for the table transformation + jvm_config: JVM configuration for JNI-based storage access. Currently only HDFS is supported + via the ``LIBHDFS_OPTS`` environment variable. See :class:`JvmConfig`. """ execution_context: Mapping[str, str] | None = None table_transformer: TableTransformer | None = None udf_registry: UDFRegistry | None = None + jvm_config: JvmConfig | None = None class OpenHouseDataLoader: @@ -112,6 +139,9 @@ def __init__( self._context = context or DataLoaderContext() self._max_attempts = max_attempts + if self._context.jvm_config is not None and self._context.jvm_config.planner_args is not None: + apply_libhdfs_opts(self._context.jvm_config.planner_args) + @cached_property def _iceberg_table(self) -> Table: return _retry( @@ -215,6 +245,7 @@ def __iter__(self) -> Iterator[DataLoaderSplit]: projected_schema=scan.projection(), row_filter=row_filter, table_id=self._table_id, + worker_jvm_args=self._context.jvm_config.worker_args if self._context.jvm_config else None, ) # plan_files() materializes all tasks at once (PyIceberg doesn't support streaming) diff --git a/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py b/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py index cd3630df3..38331bbe4 100644 --- a/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py +++ b/integrations/python/dataloader/src/openhouse/dataloader/data_loader_split.py @@ -9,6 +9,7 @@ from pyiceberg.io.pyarrow import ArrowScan from pyiceberg.table import FileScanTask +from openhouse.dataloader._jvm import apply_libhdfs_opts from openhouse.dataloader._table_scan_context import TableScanContext from openhouse.dataloader.filters import _quote_identifier from openhouse.dataloader.table_identifier import TableIdentifier @@ -78,6 +79,8 @@ def __iter__(self) -> Iterator[RecordBatch]: delete files, and partition spec lookups. """ ctx = self._scan_context + if ctx.worker_jvm_args is not None: + apply_libhdfs_opts(ctx.worker_jvm_args) arrow_scan = ArrowScan( table_metadata=ctx.table_metadata, io=ctx.io, @@ -90,8 +93,16 @@ def __iter__(self) -> Iterator[RecordBatch]: if self._transform_sql is None: yield from batches else: + # Materialize the first batch before creating the transform session + # so that the HDFS JVM starts (and picks up worker_jvm_args) before + # any UDF registration code can trigger JNI. + batch_iter = iter(batches) + first = next(batch_iter, None) + if first is None: + return session = _create_transform_session(self._scan_context.table_id, self._udf_registry) - for batch in batches: + yield from self._apply_transform(session, first) + for batch in batch_iter: yield from self._apply_transform(session, batch) def _apply_transform(self, session: SessionContext, batch: RecordBatch) -> Iterator[RecordBatch]: diff --git a/integrations/python/dataloader/tests/integration_tests.py b/integrations/python/dataloader/tests/integration_tests.py index adfc920c1..3ccfe4f48 100644 --- a/integrations/python/dataloader/tests/integration_tests.py +++ b/integrations/python/dataloader/tests/integration_tests.py @@ -6,8 +6,10 @@ """ import logging +import multiprocessing import os import sys +import tempfile import time import pyarrow as pa @@ -15,7 +17,7 @@ import requests from pyiceberg.exceptions import NoSuchTableError -from openhouse.dataloader import OpenHouseDataLoader +from openhouse.dataloader import DataLoaderContext, JvmConfig, OpenHouseDataLoader from openhouse.dataloader.catalog import OpenHouseCatalog from openhouse.dataloader.filters import col @@ -107,6 +109,50 @@ def close(self) -> None: requests.delete(self._session_url, headers=HEADERS, timeout=REQUEST_TIMEOUT) +def _parse_max_heap_bytes(jvm_output: str) -> int: + """Extract MaxHeapSize value in bytes from -XX:+PrintFlagsFinal output.""" + for line in jvm_output.splitlines(): + parts = line.split() + if len(parts) >= 3 and parts[1] == "MaxHeapSize": + return int(parts[3]) + raise ValueError("MaxHeapSize not found in JVM output") + + +def _assert_jvm_heap(log_path: str, requested_mb: int, upper_bound_mb: int, label: str) -> int: + """Read a JVM flags log file, assert MaxHeapSize <= upper_bound, and return the actual value.""" + with open(log_path) as f: + output = f.read() + assert "MaxHeapSize" in output, f"{label} JVM did not print flags — jvm_args not honored" + heap = _parse_max_heap_bytes(output) + assert heap <= upper_bound_mb * 1024 * 1024, ( + f"{label} MaxHeapSize {heap} exceeds {upper_bound_mb}m — -Xmx{requested_mb}m not honored" + ) + return heap + + +def _materialize_split_in_child(split, jvm_log_path): + """Materialize a single split in this process, capturing stdout+stderr to *jvm_log_path*. + + Intended to run via multiprocessing so the child gets a fresh JVM that + picks up worker_jvm_args from LIBHDFS_OPTS. + """ + saved_stdout = os.dup(1) + saved_stderr = os.dup(2) + log_fd = os.open(jvm_log_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644) + os.dup2(log_fd, 1) + os.dup2(log_fd, 2) + os.close(log_fd) + try: + batches = list(split) + num_rows = sum(b.num_rows for b in batches) + finally: + os.dup2(saved_stdout, 1) + os.close(saved_stdout) + os.dup2(saved_stderr, 2) + os.close(saved_stderr) + print(f" child process read {num_rows} rows from split") + + def _read_all(loader: OpenHouseDataLoader) -> pa.Table: """Read all data from a DataLoader and return as a sorted PyArrow table.""" batches = [batch for split in loader for batch in split] @@ -142,11 +188,18 @@ def read_token() -> str: properties={"DEFAULT_SCHEME": "hdfs", "DEFAULT_NETLOC": HDFS_NETLOC}, ) + # Set jvm_args before any DataLoader is created so LIBHDFS_OPTS is in + # place when the JVM starts. We capture both stdout and stderr to a + # log file because -XX:+PrintFlagsFinal may write to either fd. + jvm_log_fd, jvm_log = tempfile.mkstemp(suffix=".log") + os.close(jvm_log_fd) + ctx = DataLoaderContext(jvm_config=JvmConfig(planner_args="-Xmx127m -XX:+PrintFlagsFinal")) + livy = LivySession(LIVY_URL, token_str) try: # 1. Nonexistent table raises NoSuchTableError with pytest.raises(NoSuchTableError): - loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table="nonexistent_table") + loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table="nonexistent_table", context=ctx) _read_all(loader) print("PASS: nonexistent table raised NoSuchTableError") @@ -155,19 +208,34 @@ def read_token() -> str: f"CREATE TABLE {FQTN} ({CREATE_COLUMNS}) USING iceberg TBLPROPERTIES ('itest.custom-key' = 'custom-value')" ) try: - loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID) - assert list(loader) == [], "Expected no splits for empty table" - assert loader.snapshot_id is None, "Expected no snapshot for empty table" - assert loader.table_properties.get("itest.custom-key") == "custom-value" + # Capture stdout+stderr from here through the first HDFS read + # so we can verify -XX:+PrintFlagsFinal output at the end. + # The JVM starts during the first table load and prints flags then. + saved_stdout = os.dup(1) + saved_stderr = os.dup(2) + log_fd = os.open(jvm_log, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644) + os.dup2(log_fd, 1) + os.dup2(log_fd, 2) + os.close(log_fd) + try: + loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID) + assert list(loader) == [], "Expected no splits for empty table" + assert loader.snapshot_id is None, "Expected no snapshot for empty table" + assert loader.table_properties.get("itest.custom-key") == "custom-value" + + # 3. Write data via Spark + livy.execute(f"INSERT INTO {FQTN} VALUES (1, 'alice', 1.1), (2, 'bob', 2.2), (3, 'charlie', 3.3)") + snap1 = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID).snapshot_id + assert snap1 is not None + + # 4. Read all data + result = _read_all(OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID)) + finally: + os.dup2(saved_stdout, 1) + os.close(saved_stdout) + os.dup2(saved_stderr, 2) + os.close(saved_stderr) print("PASS: empty table returned no splits and custom property is accessible") - - # 3. Write data via Spark - livy.execute(f"INSERT INTO {FQTN} VALUES (1, 'alice', 1.1), (2, 'bob', 2.2), (3, 'charlie', 3.3)") - snap1 = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID).snapshot_id - assert snap1 is not None - - # 4. Read all data - result = _read_all(OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID)) assert result.num_rows == 3 assert result.column(COL_ID).to_pylist() == [1, 2, 3] assert result.column(COL_NAME).to_pylist() == ["alice", "bob", "charlie"] @@ -223,9 +291,36 @@ def read_token() -> str: list(loader) print("PASS: invalid snapshot_id raised ValueError") + # 8. Materialize a split in a child process with worker_jvm_args. + # The child gets a fresh JVM, so -Xmx254m takes effect there + # independently of the planner's -Xmx127m. + worker_ctx = DataLoaderContext(jvm_config=JvmConfig(worker_args="-Xmx254m -XX:+PrintFlagsFinal")) + worker_loader = OpenHouseDataLoader( + catalog=catalog, database=DATABASE_ID, table=TABLE_ID, context=worker_ctx + ) + splits = list(worker_loader) + assert splits, "Expected at least one split" + worker_jvm_log_fd, worker_jvm_log = tempfile.mkstemp(suffix=".log") + os.close(worker_jvm_log_fd) + spawn_ctx = multiprocessing.get_context("spawn") + proc = spawn_ctx.Process(target=_materialize_split_in_child, args=(splits[0], worker_jvm_log)) + proc.start() + proc.join(timeout=120) + assert proc.exitcode == 0, f"Child process failed with exit code {proc.exitcode}" + print("PASS: worker_jvm_args split materialized in child process") + finally: livy.execute(f"DROP TABLE IF EXISTS {FQTN}") + # Verify planner and worker jvm_args were honored by their respective JVMs + planner_heap = _assert_jvm_heap(jvm_log, requested_mb=127, upper_bound_mb=128, label="Planner") + print(f"PASS: planner_jvm_args honored by JVM (MaxHeapSize={planner_heap})") + worker_heap = _assert_jvm_heap(worker_jvm_log, requested_mb=254, upper_bound_mb=256, label="Worker") + assert worker_heap > planner_heap, ( + f"Worker MaxHeapSize ({worker_heap}) should be larger than planner ({planner_heap})" + ) + print(f"PASS: worker_jvm_args honored by child JVM (MaxHeapSize={worker_heap})") + print("All integration tests passed") finally: livy.close() diff --git a/integrations/python/dataloader/tests/test_data_loader.py b/integrations/python/dataloader/tests/test_data_loader.py index cb56147a0..4ec0edee5 100644 --- a/integrations/python/dataloader/tests/test_data_loader.py +++ b/integrations/python/dataloader/tests/test_data_loader.py @@ -15,7 +15,7 @@ from requests import ConnectionError as RequestsConnectionError from requests import HTTPError, Response, Timeout -from openhouse.dataloader import DataLoaderContext, OpenHouseDataLoader, __version__ +from openhouse.dataloader import DataLoaderContext, JvmConfig, OpenHouseDataLoader, __version__ from openhouse.dataloader.data_loader_split import DataLoaderSplit, to_sql_identifier from openhouse.dataloader.filters import col from openhouse.dataloader.table_transformer import TableTransformer @@ -745,3 +745,23 @@ def test_starts_with_wildcard_literals(tmp_path, filter_expr, expected_names): ) result = _materialize(loader) assert sorted(result.column(COL_NAME).to_pylist()) == sorted(expected_names) + + +# --- JVM args tests --- + + +def test_planner_jvm_args_sets_libhdfs_opts(tmp_path, monkeypatch): + """JvmConfig.planner_args is applied to LIBHDFS_OPTS during __init__.""" + from openhouse.dataloader._jvm import LIBHDFS_OPTS_ENV + + monkeypatch.delenv(LIBHDFS_OPTS_ENV, raising=False) + catalog = _make_real_catalog(tmp_path) + + OpenHouseDataLoader( + catalog=catalog, + database="db", + table="tbl", + context=DataLoaderContext(jvm_config=JvmConfig(planner_args="-Xmx256m")), + ) + + assert os.environ[LIBHDFS_OPTS_ENV] == "-Xmx256m" diff --git a/integrations/python/dataloader/tests/test_data_loader_split.py b/integrations/python/dataloader/tests/test_data_loader_split.py index dd382b7cb..306eb3f84 100644 --- a/integrations/python/dataloader/tests/test_data_loader_split.py +++ b/integrations/python/dataloader/tests/test_data_loader_split.py @@ -396,3 +396,29 @@ def test_transform_with_quoted_identifier(tmp_path): assert result.num_rows == 1 assert result.column("name").to_pylist() == ["MASKED"] + + +# --- JVM args tests --- + + +def test_worker_jvm_args_sets_libhdfs_opts(tmp_path, monkeypatch): + """worker_jvm_args is applied to LIBHDFS_OPTS when iterating a split.""" + from openhouse.dataloader._jvm import LIBHDFS_OPTS_ENV + + monkeypatch.delenv(LIBHDFS_OPTS_ENV, raising=False) + + table = pa.table({"x": [1]}) + schema = Schema(NestedField(field_id=1, name="x", field_type=LongType(), required=False)) + + split = _create_test_split(tmp_path, table, FileFormat.PARQUET, schema) + split._scan_context = TableScanContext( + table_metadata=split._scan_context.table_metadata, + io=split._scan_context.io, + projected_schema=split._scan_context.projected_schema, + table_id=split._scan_context.table_id, + worker_jvm_args="-Xmx512m", + ) + + list(split) + + assert os.environ[LIBHDFS_OPTS_ENV] == "-Xmx512m" diff --git a/integrations/python/dataloader/tests/test_jvm.py b/integrations/python/dataloader/tests/test_jvm.py new file mode 100644 index 000000000..eccec1a1f --- /dev/null +++ b/integrations/python/dataloader/tests/test_jvm.py @@ -0,0 +1,28 @@ +import os + +import pytest + +from openhouse.dataloader._jvm import LIBHDFS_OPTS_ENV, apply_libhdfs_opts + + +@pytest.fixture(autouse=True) +def _clean_env(monkeypatch: pytest.MonkeyPatch) -> None: + """Remove LIBHDFS_OPTS before each test so tests are isolated.""" + monkeypatch.delenv(LIBHDFS_OPTS_ENV, raising=False) + + +def test_sets_env_when_unset() -> None: + apply_libhdfs_opts("-Xmx512m") + assert os.environ[LIBHDFS_OPTS_ENV] == "-Xmx512m" + + +def test_appends_to_existing(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv(LIBHDFS_OPTS_ENV, "-Xmx256m") + apply_libhdfs_opts("-verbose:gc") + assert os.environ[LIBHDFS_OPTS_ENV] == "-Xmx256m -verbose:gc" + + +def test_skips_duplicate_args() -> None: + apply_libhdfs_opts("-Xmx512m") + apply_libhdfs_opts("-Xmx512m") + assert os.environ[LIBHDFS_OPTS_ENV] == "-Xmx512m"