11 changes: 10 additions & 1 deletion mostlyai/mock/core.py
@@ -250,6 +250,7 @@ async def _sample_table(
llm_config: LLMConfig,
config: MockConfig,
progress_callback: Callable[..., Awaitable[None]] | None = None,
batch_size = None,

Missing type annotation on batch_size parameter

Low Severity

The batch_size parameter in _sample_table is declared as batch_size = None without a type annotation, while every other function in the call chain (_sample_common, sample, _asample, _create_table_rows_generator) consistently annotates it as batch_size: int | None. This inconsistency reduces readability and breaks the typing contract across the codebase.
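
A minimal sketch of the suggested fix (illustrative only; the parameter list is abbreviated to the lines visible in this hunk, and the stub classes stand in for the real ones defined in core.py):

import pandas as pd
from collections.abc import Awaitable, Callable

class LLMConfig: ...   # stand-in for the real class in core.py
class MockConfig: ...  # stand-in for the real class in core.py

async def _sample_table(
    llm_config: LLMConfig,
    config: MockConfig,
    progress_callback: Callable[..., Awaitable[None]] | None = None,
    batch_size: int | None = None,  # annotated, matching the rest of the call chain
) -> pd.DataFrame:
    ...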


) -> pd.DataFrame:
# provide a default progress callback if none is provided
if progress_callback is None:
@@ -282,6 +283,7 @@ async def default_progress_callback(**kwargs):
n_workers=n_workers,
llm_config=llm_config,
progress_callback=progress_callback,
batch_size=batch_size,
)
table_df = await _convert_table_rows_generator_to_df(
table_rows_generator=table_rows_generator,
@@ -443,6 +445,7 @@ def _create_table_prompt(
if n_rows is not None:
prompt += f"Number of data rows to {verb}: `{n_rows}`.\n\n"


if target_primary_key is not None:
prompt += f"Add prefix to all values of Target Table Primary Key. The prefix is 'B{batch_idx}-'."
prompt += " There is one exception: if primary keys are in existing data, don't add prefix to them."
@@ -823,8 +826,8 @@ async def _create_table_rows_generator(
n_workers: int,
llm_config: LLMConfig,
progress_callback: Callable[..., Awaitable[None]] | None = None,
batch_size: int | None = 20, # generate 20 root table rows at a time

Passing None overrides default batch_size of 20

High Severity

When sample() is called without specifying batch_size, the default None is explicitly passed through the entire call chain (sample → _asample → _sample_common → _sample_table → _create_table_rows_generator), overriding the default value of 20 in _create_table_rows_generator. For root tables without existing data or foreign keys, batch_size remains None, causing a TypeError at the if ideal_batch_size < batch_size: comparison between an int and None. This breaks the default (no batch_size argument) usage path entirely.

Additional Locations (2)
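
A minimal sketch of one possible guard (illustrative, not part of this diff): normalize a None batch_size back to the documented default inside _create_table_rows_generator, so the later ideal_batch_size < batch_size comparison never receives None.

DEFAULT_BATCH_SIZE = 20  # generate 20 root table rows at a time

async def _create_table_rows_generator(
    batch_size: int | None = DEFAULT_BATCH_SIZE,
    # ... remaining parameters elided in this sketch ...
):
    if batch_size is None:
        # an explicit batch_size=None from sample() overrides the parameter
        # default, so fall back to it here before any size comparisons
        batch_size = DEFAULT_BATCH_SIZE
    ...

An alternative with the same effect would be for the callers to forward batch_size only when it is not None.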


) -> AsyncGenerator[dict]:
batch_size = 20 # generate 20 root table rows at a time

def supports_structured_outputs(model: str) -> bool:
model = model.removeprefix("litellm_proxy/")
@@ -1260,6 +1263,7 @@ async def _sample_common(
n_workers: int = 10,
return_type: Literal["auto", "dict"] = "auto",
progress_callback: Callable[..., Awaitable[None]] | None = None,
batch_size: int | None = None
):
tables: dict[str, TableConfig] = _harmonize_tables(tables, existing_data)
config = MockConfig(tables)
@@ -1295,6 +1299,7 @@
llm_config=llm_config,
config=config,
progress_callback=progress_callback,
batch_size=batch_size
)
data[table_name] = df

@@ -1318,6 +1323,7 @@ def sample(
n_workers: int = 10,
return_type: Literal["auto", "dict"] = "auto",
progress_callback: Callable[..., Awaitable[None]] | None = None,
batch_size: int | None = None
) -> pd.DataFrame | dict[str, pd.DataFrame]:
"""
Generate synthetic data from scratch or enrich existing data with new columns.
Expand Down Expand Up @@ -1605,6 +1611,7 @@ def sample_common_sync(*args, **kwargs) -> pd.DataFrame | dict[str, pd.DataFrame
n_workers=n_workers,
return_type=return_type,
progress_callback=progress_callback,
batch_size=batch_size
)
return future.result()

@@ -1621,6 +1628,7 @@ async def _asample(
n_workers: int = 10,
return_type: Literal["auto", "dict"] = "auto",
progress_callback: Callable[..., Awaitable[None]] | None = None,
batch_size: int | None = None
) -> pd.DataFrame | dict[str, pd.DataFrame]:
return await _sample_common(
tables=tables,
@@ -1633,6 +1641,7 @@ async def _asample(
n_workers=n_workers,
return_type=return_type,
progress_callback=progress_callback,
batch_size=batch_size
)

