Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from importlib.metadata import version

from openhouse.dataloader.catalog import OpenHouseCatalog, OpenHouseCatalogError
from openhouse.dataloader.data_loader import DataLoaderContext, OpenHouseDataLoader
from openhouse.dataloader.data_loader import DataLoaderContext, JvmConfig, OpenHouseDataLoader
from openhouse.dataloader.filters import always_true, col

__version__ = version("openhouse.dataloader")
__all__ = [
"OpenHouseDataLoader",
"DataLoaderContext",
"JvmConfig",
"OpenHouseCatalog",
"OpenHouseCatalogError",
"always_true",
Expand Down
29 changes: 29 additions & 0 deletions integrations/python/dataloader/src/openhouse/dataloader/_jvm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""JVM configuration utilities for the HDFS client."""

import logging
import os
import threading

logger = logging.getLogger(__name__)

LIBHDFS_OPTS_ENV = "LIBHDFS_OPTS"
"""Environment variable read by libhdfs when starting the JNI JVM."""

_lock = threading.Lock()


def apply_libhdfs_opts(jvm_args: str) -> None:
"""Merge *jvm_args* into the JNI JVM options environment variable.

Appends to any existing value. Must be called before the first
HDFS access in the current process (the JVM is started once and
reads these options only at startup). Thread-safe and idempotent —
duplicate args are not appended.
"""
with _lock:
existing = os.environ.get(LIBHDFS_OPTS_ENV, "")
if jvm_args in existing:
return
merged = f"{existing} {jvm_args}".strip() if existing else jvm_args
os.environ[LIBHDFS_OPTS_ENV] = merged
logger.info("Set %s=%s", LIBHDFS_OPTS_ENV, merged)
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ def _unpickle_scan_context(
projected_schema: Schema,
row_filter: BooleanExpression,
table_id: TableIdentifier,
worker_jvm_args: str | None = None,
) -> TableScanContext:
return TableScanContext(
table_metadata=table_metadata,
io=load_file_io(properties=io_properties, location=table_metadata.location),
projected_schema=projected_schema,
row_filter=row_filter,
table_id=table_id,
worker_jvm_args=worker_jvm_args,
)


Expand All @@ -39,16 +41,25 @@ class TableScanContext:
projected_schema: Subset of columns to read (equals table schema when no projection)
table_id: Identifier for the table being scanned
row_filter: Row-level filter expression pushed down to the scan
worker_jvm_args: JVM arguments applied when the JNI JVM is created in worker processes
"""

table_metadata: TableMetadata
io: FileIO
projected_schema: Schema
table_id: TableIdentifier
row_filter: BooleanExpression = AlwaysTrue()
worker_jvm_args: str | None = None

def __reduce__(self) -> tuple:
return (
_unpickle_scan_context,
(self.table_metadata, dict(self.io.properties), self.projected_schema, self.row_filter, self.table_id),
(
self.table_metadata,
dict(self.io.properties),
self.projected_schema,
self.row_filter,
self.table_id,
self.worker_jvm_args,
),
)
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from requests import HTTPError
from tenacity import Retrying, retry_if_exception, stop_after_attempt, wait_exponential

from openhouse.dataloader._jvm import apply_libhdfs_opts
from openhouse.dataloader._table_scan_context import TableScanContext
from openhouse.dataloader._timer import log_duration
from openhouse.dataloader.data_loader_split import DataLoaderSplit
Expand Down Expand Up @@ -55,6 +56,29 @@ def _retry[T](fn: Callable[[], T], label: str, max_attempts: int) -> T:
raise AssertionError("unreachable") # pragma: no cover


@dataclass(frozen=True)
class JvmConfig:
"""JVM arguments for JNI-based storage access (e.g. HDFS via libhdfs).

The JVM is created once per process. If another library has already
started a JVM these arguments will have no effect.

Args:
planner_args: JVM arguments (e.g. ``-Xmx2g``) applied when the JNI
JVM is created in the planner process — the process that loads
table metadata and plans splits.
worker_args: JVM arguments applied when the JNI JVM is created in
worker processes that materialize splits. Only honored if the
JVM has not already been started in the worker process. When
splits are materialized in the same process as the planner,
only ``planner_args`` takes effect because the JVM is already
running.
"""

planner_args: str | None = None
worker_args: str | None = None


@dataclass
class DataLoaderContext:
"""Context and customization for the DataLoader.
Expand All @@ -66,11 +90,14 @@ class DataLoaderContext:
execution_context: Dictionary of execution context information (e.g. tenant, environment)
table_transformer: Transformation to apply to the table before loading (e.g. column masking)
udf_registry: UDFs required for the table transformation
jvm_config: JVM configuration for JNI-based storage access. Currently only HDFS is supported
via the ``LIBHDFS_OPTS`` environment variable. See :class:`JvmConfig`.
"""

execution_context: Mapping[str, str] | None = None
table_transformer: TableTransformer | None = None
udf_registry: UDFRegistry | None = None
jvm_config: JvmConfig | None = None


class OpenHouseDataLoader:
Expand Down Expand Up @@ -112,6 +139,9 @@ def __init__(
self._context = context or DataLoaderContext()
self._max_attempts = max_attempts

if self._context.jvm_config is not None and self._context.jvm_config.planner_args is not None:
apply_libhdfs_opts(self._context.jvm_config.planner_args)

@cached_property
def _iceberg_table(self) -> Table:
return _retry(
Expand Down Expand Up @@ -215,6 +245,7 @@ def __iter__(self) -> Iterator[DataLoaderSplit]:
projected_schema=scan.projection(),
row_filter=row_filter,
table_id=self._table_id,
worker_jvm_args=self._context.jvm_config.worker_args if self._context.jvm_config else None,
)

# plan_files() materializes all tasks at once (PyIceberg doesn't support streaming)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pyiceberg.io.pyarrow import ArrowScan
from pyiceberg.table import FileScanTask

from openhouse.dataloader._jvm import apply_libhdfs_opts
from openhouse.dataloader._table_scan_context import TableScanContext
from openhouse.dataloader.filters import _quote_identifier
from openhouse.dataloader.table_identifier import TableIdentifier
Expand Down Expand Up @@ -78,6 +79,8 @@ def __iter__(self) -> Iterator[RecordBatch]:
delete files, and partition spec lookups.
"""
ctx = self._scan_context
if ctx.worker_jvm_args is not None:
apply_libhdfs_opts(ctx.worker_jvm_args)
arrow_scan = ArrowScan(
table_metadata=ctx.table_metadata,
io=ctx.io,
Expand All @@ -90,8 +93,16 @@ def __iter__(self) -> Iterator[RecordBatch]:
if self._transform_sql is None:
yield from batches
else:
# Materialize the first batch before creating the transform session
# so that the HDFS JVM starts (and picks up worker_jvm_args) before
# any UDF registration code can trigger JNI.
batch_iter = iter(batches)
first = next(batch_iter, None)
if first is None:
return
session = _create_transform_session(self._scan_context.table_id, self._udf_registry)
for batch in batches:
yield from self._apply_transform(session, first)
for batch in batch_iter:
yield from self._apply_transform(session, batch)

def _apply_transform(self, session: SessionContext, batch: RecordBatch) -> Iterator[RecordBatch]:
Expand Down
123 changes: 109 additions & 14 deletions integrations/python/dataloader/tests/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@
"""

import logging
import multiprocessing
import os
import sys
import tempfile
import time

import pyarrow as pa
import pytest
import requests
from pyiceberg.exceptions import NoSuchTableError

from openhouse.dataloader import OpenHouseDataLoader
from openhouse.dataloader import DataLoaderContext, JvmConfig, OpenHouseDataLoader
from openhouse.dataloader.catalog import OpenHouseCatalog
from openhouse.dataloader.filters import col

Expand Down Expand Up @@ -107,6 +109,50 @@ def close(self) -> None:
requests.delete(self._session_url, headers=HEADERS, timeout=REQUEST_TIMEOUT)


def _parse_max_heap_bytes(jvm_output: str) -> int:
"""Extract MaxHeapSize value in bytes from -XX:+PrintFlagsFinal output."""
for line in jvm_output.splitlines():
parts = line.split()
if len(parts) >= 3 and parts[1] == "MaxHeapSize":
return int(parts[3])
raise ValueError("MaxHeapSize not found in JVM output")


def _assert_jvm_heap(log_path: str, requested_mb: int, upper_bound_mb: int, label: str) -> int:
    """Validate the JVM max heap recorded in *log_path* against *upper_bound_mb*.

    Returns the observed MaxHeapSize in bytes; *label* names the JVM
    (planner/worker) in assertion messages.
    """
    with open(log_path) as log_file:
        flags_output = log_file.read()
    assert "MaxHeapSize" in flags_output, f"{label} JVM did not print flags — jvm_args not honored"
    max_heap = _parse_max_heap_bytes(flags_output)
    limit_bytes = upper_bound_mb * 1024 * 1024
    assert max_heap <= limit_bytes, (
        f"{label} MaxHeapSize {max_heap} exceeds {upper_bound_mb}m — -Xmx{requested_mb}m not honored"
    )
    return max_heap


def _materialize_split_in_child(split, jvm_log_path):
"""Materialize a single split in this process, capturing stdout+stderr to *jvm_log_path*.

Intended to run via multiprocessing so the child gets a fresh JVM that
picks up worker_jvm_args from LIBHDFS_OPTS.
"""
saved_stdout = os.dup(1)
saved_stderr = os.dup(2)
log_fd = os.open(jvm_log_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644)
os.dup2(log_fd, 1)
os.dup2(log_fd, 2)
os.close(log_fd)
try:
batches = list(split)
num_rows = sum(b.num_rows for b in batches)
finally:
os.dup2(saved_stdout, 1)
os.close(saved_stdout)
os.dup2(saved_stderr, 2)
os.close(saved_stderr)
print(f" child process read {num_rows} rows from split")


def _read_all(loader: OpenHouseDataLoader) -> pa.Table:
"""Read all data from a DataLoader and return as a sorted PyArrow table."""
batches = [batch for split in loader for batch in split]
Expand Down Expand Up @@ -142,11 +188,18 @@ def read_token() -> str:
properties={"DEFAULT_SCHEME": "hdfs", "DEFAULT_NETLOC": HDFS_NETLOC},
)

# Set jvm_args before any DataLoader is created so LIBHDFS_OPTS is in
# place when the JVM starts. We capture both stdout and stderr to a
# log file because -XX:+PrintFlagsFinal may write to either fd.
jvm_log_fd, jvm_log = tempfile.mkstemp(suffix=".log")
os.close(jvm_log_fd)
ctx = DataLoaderContext(jvm_config=JvmConfig(planner_args="-Xmx127m -XX:+PrintFlagsFinal"))

livy = LivySession(LIVY_URL, token_str)
try:
# 1. Nonexistent table raises NoSuchTableError
with pytest.raises(NoSuchTableError):
loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table="nonexistent_table")
loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table="nonexistent_table", context=ctx)
_read_all(loader)
print("PASS: nonexistent table raised NoSuchTableError")

Expand All @@ -155,19 +208,34 @@ def read_token() -> str:
f"CREATE TABLE {FQTN} ({CREATE_COLUMNS}) USING iceberg TBLPROPERTIES ('itest.custom-key' = 'custom-value')"
)
try:
loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID)
assert list(loader) == [], "Expected no splits for empty table"
assert loader.snapshot_id is None, "Expected no snapshot for empty table"
assert loader.table_properties.get("itest.custom-key") == "custom-value"
# Capture stdout+stderr from here through the first HDFS read
# so we can verify -XX:+PrintFlagsFinal output at the end.
# The JVM starts during the first table load and prints flags then.
saved_stdout = os.dup(1)
saved_stderr = os.dup(2)
log_fd = os.open(jvm_log, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o644)
os.dup2(log_fd, 1)
os.dup2(log_fd, 2)
os.close(log_fd)
try:
loader = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID)
assert list(loader) == [], "Expected no splits for empty table"
assert loader.snapshot_id is None, "Expected no snapshot for empty table"
assert loader.table_properties.get("itest.custom-key") == "custom-value"

# 3. Write data via Spark
livy.execute(f"INSERT INTO {FQTN} VALUES (1, 'alice', 1.1), (2, 'bob', 2.2), (3, 'charlie', 3.3)")
snap1 = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID).snapshot_id
assert snap1 is not None

# 4. Read all data
result = _read_all(OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID))
finally:
os.dup2(saved_stdout, 1)
os.close(saved_stdout)
os.dup2(saved_stderr, 2)
os.close(saved_stderr)
print("PASS: empty table returned no splits and custom property is accessible")

# 3. Write data via Spark
livy.execute(f"INSERT INTO {FQTN} VALUES (1, 'alice', 1.1), (2, 'bob', 2.2), (3, 'charlie', 3.3)")
snap1 = OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID).snapshot_id
assert snap1 is not None

# 4. Read all data
result = _read_all(OpenHouseDataLoader(catalog=catalog, database=DATABASE_ID, table=TABLE_ID))
assert result.num_rows == 3
assert result.column(COL_ID).to_pylist() == [1, 2, 3]
assert result.column(COL_NAME).to_pylist() == ["alice", "bob", "charlie"]
Expand Down Expand Up @@ -223,9 +291,36 @@ def read_token() -> str:
list(loader)
print("PASS: invalid snapshot_id raised ValueError")

# 8. Materialize a split in a child process with worker_jvm_args.
# The child gets a fresh JVM, so -Xmx254m takes effect there
# independently of the planner's -Xmx127m.
worker_ctx = DataLoaderContext(jvm_config=JvmConfig(worker_args="-Xmx254m -XX:+PrintFlagsFinal"))
worker_loader = OpenHouseDataLoader(
catalog=catalog, database=DATABASE_ID, table=TABLE_ID, context=worker_ctx
)
splits = list(worker_loader)
assert splits, "Expected at least one split"
worker_jvm_log_fd, worker_jvm_log = tempfile.mkstemp(suffix=".log")
os.close(worker_jvm_log_fd)
spawn_ctx = multiprocessing.get_context("spawn")
proc = spawn_ctx.Process(target=_materialize_split_in_child, args=(splits[0], worker_jvm_log))
proc.start()
proc.join(timeout=120)
assert proc.exitcode == 0, f"Child process failed with exit code {proc.exitcode}"
print("PASS: worker_jvm_args split materialized in child process")

finally:
livy.execute(f"DROP TABLE IF EXISTS {FQTN}")

# Verify planner and worker jvm_args were honored by their respective JVMs
planner_heap = _assert_jvm_heap(jvm_log, requested_mb=127, upper_bound_mb=128, label="Planner")
print(f"PASS: planner_jvm_args honored by JVM (MaxHeapSize={planner_heap})")
worker_heap = _assert_jvm_heap(worker_jvm_log, requested_mb=254, upper_bound_mb=256, label="Worker")
assert worker_heap > planner_heap, (
f"Worker MaxHeapSize ({worker_heap}) should be larger than planner ({planner_heap})"
)
print(f"PASS: worker_jvm_args honored by child JVM (MaxHeapSize={worker_heap})")

print("All integration tests passed")
finally:
livy.close()
Loading