-
Notifications
You must be signed in to change notification settings - Fork 4
enhance: make initial batch_size configurable #129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -250,6 +250,7 @@ async def _sample_table( | |
| llm_config: LLMConfig, | ||
| config: MockConfig, | ||
| progress_callback: Callable[..., Awaitable[None]] | None = None, | ||
| batch_size = None, | ||
| ) -> pd.DataFrame: | ||
| # provide a default progress callback if none is provided | ||
| if progress_callback is None: | ||
|
|
@@ -282,6 +283,7 @@ async def default_progress_callback(**kwargs): | |
| n_workers=n_workers, | ||
| llm_config=llm_config, | ||
| progress_callback=progress_callback, | ||
| batch_size=batch_size, | ||
| ) | ||
| table_df = await _convert_table_rows_generator_to_df( | ||
| table_rows_generator=table_rows_generator, | ||
|
|
@@ -443,6 +445,7 @@ def _create_table_prompt( | |
| if n_rows is not None: | ||
| prompt += f"Number of data rows to {verb}: `{n_rows}`.\n\n" | ||
|
|
||
|
|
||
| if target_primary_key is not None: | ||
| prompt += f"Add prefix to all values of Target Table Primary Key. The prefix is 'B{batch_idx}-'." | ||
| prompt += " There is one exception: if primary keys are in existing data, don't add prefix to them." | ||
|
|
@@ -823,8 +826,8 @@ async def _create_table_rows_generator( | |
| n_workers: int, | ||
| llm_config: LLMConfig, | ||
| progress_callback: Callable[..., Awaitable[None]] | None = None, | ||
| batch_size: int | None = 20, # generate 20 root table rows at a time | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Passing
|
||
| ) -> AsyncGenerator[dict]: | ||
| batch_size = 20 # generate 20 root table rows at a time | ||
|
|
||
| def supports_structured_outputs(model: str) -> bool: | ||
| model = model.removeprefix("litellm_proxy/") | ||
|
|
@@ -1260,6 +1263,7 @@ async def _sample_common( | |
| n_workers: int = 10, | ||
| return_type: Literal["auto", "dict"] = "auto", | ||
| progress_callback: Callable[..., Awaitable[None]] | None = None, | ||
| batch_size: int | None = None | ||
| ): | ||
| tables: dict[str, TableConfig] = _harmonize_tables(tables, existing_data) | ||
| config = MockConfig(tables) | ||
|
|
@@ -1295,6 +1299,7 @@ async def _sample_common( | |
| llm_config=llm_config, | ||
| config=config, | ||
| progress_callback=progress_callback, | ||
| batch_size=batch_size | ||
| ) | ||
| data[table_name] = df | ||
|
|
||
|
|
@@ -1318,6 +1323,7 @@ def sample( | |
| n_workers: int = 10, | ||
| return_type: Literal["auto", "dict"] = "auto", | ||
| progress_callback: Callable[..., Awaitable[None]] | None = None, | ||
| batch_size: int | None = None | ||
| ) -> pd.DataFrame | dict[str, pd.DataFrame]: | ||
| """ | ||
| Generate synthetic data from scratch or enrich existing data with new columns. | ||
|
|
@@ -1605,6 +1611,7 @@ def sample_common_sync(*args, **kwargs) -> pd.DataFrame | dict[str, pd.DataFrame | |
| n_workers=n_workers, | ||
| return_type=return_type, | ||
| progress_callback=progress_callback, | ||
| batch_size=batch_size | ||
| ) | ||
| return future.result() | ||
|
|
||
|
|
@@ -1621,6 +1628,7 @@ async def _asample( | |
| n_workers: int = 10, | ||
| return_type: Literal["auto", "dict"] = "auto", | ||
| progress_callback: Callable[..., Awaitable[None]] | None = None, | ||
| batch_size: int | None = None | ||
| ) -> pd.DataFrame | dict[str, pd.DataFrame]: | ||
| return await _sample_common( | ||
| tables=tables, | ||
|
|
@@ -1633,6 +1641,7 @@ async def _asample( | |
| n_workers=n_workers, | ||
| return_type=return_type, | ||
| progress_callback=progress_callback, | ||
| batch_size=batch_size | ||
| ) | ||
|
|
||
|
|
||
|
|
||


There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing type annotation on
batch_sizeparameterLow Severity
The
batch_sizeparameter in_sample_tableis declared asbatch_size = Nonewithout a type annotation, while every other function in the call chain (_sample_common,sample,_asample,_create_table_rows_generator) consistently annotates it asbatch_size: int | None. This inconsistency reduces readability and breaks the typing contract across the codebase.