Skip to content

Commit 7ccea04

Browse files
DeltaMichaelDilyan Marinovpre-commit-ci[bot]
authored
vdk-oracle: Pass ingestion payload rows in uniform batches (#3194)
## Why? The current ingestion implementation batches payload rows by column keyset. Payloads with the same keyset are batched together and passed to an executemany() call. This is not ideal becauase it can result in a large number of executemany() calls ## What? Make each payload row uniform by identifying missing columns and filling them out with null values. Pass data rows to executemany() in uniform batches. ## How was this tested? Functional tests ## What kind of change is this? Feature/non-breaking --------- Signed-off-by: Dilyan Marinov <mdilyan@vmware.com> Co-authored-by: Dilyan Marinov <mdilyan@vmware.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 54d510a commit 7ccea04

6 files changed

Lines changed: 118 additions & 110 deletions

File tree

projects/vdk-plugins/vdk-oracle/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ pip install vdk-oracle
2929
| oracle_sid | The SID of the Oracle database. Note: This gets overridden if oracle_connection_string is set. | free |
3030
| oracle_service_name | The Service name of the Oracle database. Note: This gets overridden if oracle_connection_string is set. | free |
3131
| oracle_thick_mode | Python-oracledb is said to be in Thick mode when Oracle Client libraries are used. True by default. Set to False to disable Oracle Thick mode. More info: https://python-oracledb.readthedocs.io/en/latest/user_guide/appendix_b.html | True |
32+
| oracle_ingest_batch_size | vdk-oracle splits ingestion payloads into batches. Change this config to control the batch size. Default is set to 100. | 100 |
3233

3334
### Example
3435

projects/vdk-plugins/vdk-oracle/src/vdk/plugin/oracle/ingest_to_oracle.py

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@
66
import re
77
from decimal import Decimal
88
from typing import Any
9-
from typing import Collection
109
from typing import Dict
1110
from typing import List
1211
from typing import Optional
13-
from typing import Set
1412

1513
from vdk.api.plugin.plugin_input import PEP249Connection
1614
from vdk.internal.builtin_plugins.connection.impl.router import ManagedConnectionRouter
@@ -89,10 +87,13 @@ def table_exists(self, table: str) -> bool:
8987

9088

9189
class IngestToOracle(IIngesterPlugin):
92-
def __init__(self, connections: ManagedConnectionRouter):
90+
def __init__(
91+
self, connections: ManagedConnectionRouter, ingest_batch_size: int = 100
92+
):
9393
self.conn: PEP249Connection = connections.open_connection("ORACLE").connect()
9494
self.cursor: ManagedCursor = self.conn.cursor()
9595
self.table_cache: TableCache = TableCache(self.cursor) # New cache for columns
96+
self.ingest_batch_size = ingest_batch_size
9697

9798
@staticmethod
9899
def _get_oracle_type(value: Any) -> str:
@@ -191,40 +192,48 @@ def cast_string_to_type(db_type: str, payload_value: str) -> Any:
191192

192193
return value
193194

194-
# TODO: Look into potential optimizations
195-
# TODO: https://github.com/vmware/versatile-data-kit/issues/2931
196195
def _insert_data(self, table_name: str, payload: List[Dict[str, Any]]) -> None:
197196
if not payload:
198197
return
199198

200-
# group dicts by key set
201-
batches = {}
202-
for p in payload:
203-
batch = frozenset(p.keys())
204-
if batch not in batches:
205-
batches[batch] = []
206-
batches[batch].append(p)
207-
208-
# create queries for groups of dicts with the same key set
209-
queries = []
210-
batch_data = []
211-
for column_names, batch in batches.items():
212-
columns = list(column_names)
213-
query_columns = [_escape_special_chars(col) for col in columns]
214-
insert_sql = f"INSERT INTO {table_name} ({', '.join(query_columns)}) VALUES ({', '.join([':' + str(i + 1) for i in range(len(query_columns))])})"
215-
queries.append(insert_sql)
216-
temp_data = []
217-
for row in batch:
218-
temp = [
219-
self._cast_to_correct_type(table_name, col, row[col])
220-
for col in columns
221-
]
222-
temp_data.append(temp)
223-
batch_data.append(temp_data)
224-
225-
# batch execute queries for dicts with the same key set
226-
for i in range(len(queries)):
227-
self.cursor.executemany(queries[i], batch_data[i])
199+
def split(lst, n):
200+
"""Yield successive n-sized chunks from lst."""
201+
for i in range(0, len(lst), n):
202+
yield lst[i : i + n]
203+
204+
query, params = self._populate_query_parameters_tuple(table_name, payload)
205+
batches = list(split(params, self.ingest_batch_size))
206+
for batch in batches:
207+
self.cursor.executemany(query, batch)
208+
209+
def _populate_query_parameters_tuple(
210+
self, destination_table: str, payload: List[dict]
211+
) -> (str, list):
212+
"""
213+
Prepare the SQL query and parameters for bulk insertion.
214+
215+
Returns insert into destination table tuple of query and parameters;
216+
E.g. for a table dest_table with columns val1, val2 and payload size 2, this method will return:
217+
'INSERT INTO dest_table (val1, val2) VALUES (:0, :1)',
218+
[('val1', 'val2'), ('val1', 'val2')]
219+
"""
220+
columns = self.table_cache.get_columns(destination_table)
221+
query_columns = [_escape_special_chars(col) for col in columns]
222+
223+
placeholders = ", ".join(f":{i}" for i in range(len(columns)))
224+
query = f"INSERT INTO {destination_table} ({', '.join(query_columns)}) VALUES ({placeholders})"
225+
226+
parameters = []
227+
for obj in payload:
228+
row = tuple(
229+
self._cast_to_correct_type(
230+
destination_table, column.lower(), obj.get(column.lower())
231+
)
232+
for column in columns
233+
)
234+
parameters.append(row)
235+
236+
return query, parameters
228237

229238
def ingest_payload(
230239
self,

projects/vdk-plugins/vdk-oracle/src/vdk/plugin/oracle/oracle_configuration.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
ORACLE_PORT = "ORACLE_PORT"
2020
ORACLE_SID = "ORACLE_SID"
2121
ORACLE_SERVICE_NAME = "ORACLE_SERVICE_NAME"
22+
ORACLE_INGEST_BATCH_SIZE = "ORACLE_INGEST_BATCH_SIZE"
2223

2324

2425
class OracleConfiguration:
@@ -55,6 +56,9 @@ def oracle_thick_mode(self) -> bool:
5556
def oracle_thick_mode_lib_dir(self) -> Optional[str]:
5657
return self.__config.get_value(ORACLE_THICK_MODE_LIB_DIR)
5758

59+
def oracle_ingest_batch_size(self) -> Optional[int]:
60+
return int(self.__config.get_value(ORACLE_INGEST_BATCH_SIZE))
61+
5862
@staticmethod
5963
def add_definitions(config_builder: ConfigurationBuilder):
6064
config_builder.add(
@@ -122,3 +126,8 @@ def add_definitions(config_builder: ConfigurationBuilder):
122126
"Before setting this follow instruction in "
123127
"https://python-oracledb.readthedocs.io/en/latest/user_guide/initialization.html#enablingthick ",
124128
)
129+
config_builder.add(
130+
key=ORACLE_INGEST_BATCH_SIZE,
131+
default_value=100,
132+
description="Batch size when ingesting records into Oracle.",
133+
)

projects/vdk-plugins/vdk-oracle/src/vdk/plugin/oracle/oracle_plugin.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ def initialize_job(self, context: JobContext):
5050
),
5151
)
5252
context.ingester.add_ingester_factory_method(
53-
"oracle", (lambda: IngestToOracle(context.connections))
53+
"oracle",
54+
lambda: IngestToOracle(
55+
context.connections, conf.oracle_ingest_batch_size()
56+
),
5457
)
5558

5659

projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-data-frame-schema-inference/10_ingest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55

66

77
def run(job_input: IJobInput):
8-
df = DataFrame.from_dict({"A": [1], "B": [2], "C": [3]})
8+
df = DataFrame.from_dict({"a": [1], "b": [2], "c": [3]})
99

1010
job_input.send_object_for_ingestion(payload=df, destination_table="test_table")

projects/vdk-plugins/vdk-oracle/tests/test_plugin.py

Lines changed: 61 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -178,85 +178,72 @@ def test_oracle_ingest_data_frame_schema_inference(self):
178178

179179
def _verify_query_execution(runner):
180180
check_result = runner.invoke(["sql-query", "--query", "SELECT * FROM todoitem"])
181-
expected = (
182-
" ID DESCRIPTION DONE\n"
183-
"---- ------------- ------\n"
184-
" 1 Task 1 1\n"
185-
)
186-
assert expected in check_result.output
181+
expected = [
182+
" ID DESCRIPTION DONE\n",
183+
"---- ------------- ------\n",
184+
" 1 Task 1 1\n",
185+
]
186+
for row in expected:
187+
assert row in check_result.output
187188

188189

189190
def _verify_ingest_execution(runner):
190191
check_result = runner.invoke(["sql-query", "--query", "SELECT * FROM test_table"])
191-
expected = (
192-
" ID STR_DATA INT_DATA FLOAT_DATA BOOL_DATA "
193-
"TIMESTAMP_DATA DECIMAL_DATA\n"
194-
"---- ---------- ---------- ------------ ----------- "
195-
"------------------- --------------\n"
196-
" 5 string 12 1.2 1 2023-11-21 "
197-
"08:12:53 0.1\n"
198-
)
199-
assert expected in check_result.output
192+
expected = [
193+
" ID STR_DATA INT_DATA FLOAT_DATA BOOL_DATA TIMESTAMP_DATA DECIMAL_DATA\n",
194+
"---- ---------- ---------- ------------ ----------- ------------------- --------------\n",
195+
" 5 string 12 1.2 1 2023-11-21 08:12:53 0.1\n",
196+
]
197+
for row in expected:
198+
assert row in check_result.output
200199

201200

202201
def _verify_ingest_execution_special_chars(runner):
203202
check_result = runner.invoke(["sql-query", "--query", "SELECT * FROM test_table"])
204-
expected = (
205-
" ID @str_data %int_data *float*data* BOOL_DATA "
206-
"TIMESTAMP_DATA DECIMAL_DATA\n"
207-
"---- ----------- ----------- -------------- ----------- "
208-
"------------------- --------------\n"
209-
" 5 string 12 1.2 1 2023-11-21 "
210-
"08:12:53 0.1\n"
211-
)
212-
assert expected in check_result.output
203+
expected = [
204+
" ID @str_data %int_data *float*data* BOOL_DATA TIMESTAMP_DATA DECIMAL_DATA\n",
205+
"---- ----------- ----------- -------------- ----------- ------------------- --------------\n",
206+
" 5 string 12 1.2 1 2023-11-21 08:12:53 0.1\n",
207+
]
208+
for row in expected:
209+
assert row in check_result.output
213210

214211

215212
def _verify_ingest_execution_type_inference(runner):
216213
check_result = runner.invoke(["sql-query", "--query", "SELECT * FROM test_table"])
217-
expected = (
218-
" ID STR_DATA INT_DATA NAN_INT_DATA FLOAT_DATA BOOL_DATA "
219-
"TIMESTAMP_DATA DECIMAL_DATA\n"
220-
"---- ---------- ---------- -------------- ------------ ----------- "
221-
"------------------- --------------\n"
222-
" 5 string 12 1.2 1 "
223-
"2023-11-21 08:12:53 0.1\n"
224-
)
225-
assert expected in check_result.output
214+
expected = [
215+
" ID STR_DATA INT_DATA NAN_INT_DATA FLOAT_DATA BOOL_DATA TIMESTAMP_DATA DECIMAL_DATA\n",
216+
"---- ---------- ---------- -------------- ------------ ----------- ------------------- --------------\n",
217+
" 5 string 12 1.2 1 2023-11-21 08:12:53 0.1\n",
218+
]
219+
for row in expected:
220+
assert row in check_result.output
226221

227222

228223
def _verify_ingest_execution_no_table(runner):
229224
check_result = runner.invoke(["sql-query", "--query", "SELECT * FROM test_table"])
230-
expected = (
231-
" ID STR_DATA INT_DATA FLOAT_DATA BOOL_DATA "
232-
"TIMESTAMP_DATA DECIMAL_DATA\n"
233-
"---- ---------- ---------- ------------ ----------- "
234-
"------------------- --------------\n"
235-
" 0 string 12 1.2 1 "
236-
"2023-11-21T08:12:53 1.1\n"
237-
" 1 string 12 1.2 1 "
238-
"2023-11-21T08:12:53 1.1\n"
239-
" 2 string 12 1.2 1 "
240-
"2023-11-21T08:12:53 1.1\n"
241-
)
242-
assert expected in check_result.output
225+
expected = [
226+
" ID STR_DATA INT_DATA FLOAT_DATA BOOL_DATA TIMESTAMP_DATA DECIMAL_DATA\n",
227+
"---- ---------- ---------- ------------ ----------- ------------------- --------------\n",
228+
" 0 string 12 1.2 1 2023-11-21T08:12:53 1.1\n",
229+
" 1 string 12 1.2 1 2023-11-21T08:12:53 1.1\n",
230+
" 2 string 12 1.2 1 2023-11-21T08:12:53 1.1\n",
231+
]
232+
for row in expected:
233+
assert row in check_result.output
243234

244235

245236
def _verify_ingest_execution_no_table_special_chars(runner):
246237
check_result = runner.invoke(["sql-query", "--query", "SELECT * FROM test_table"])
247-
expected = (
248-
" ID @str_data %int_data *float*data* BOOL_DATA "
249-
"TIMESTAMP_DATA DECIMAL_DATA\n"
250-
"---- ----------- ----------- -------------- ----------- "
251-
"------------------- --------------\n"
252-
" 0 string 12 1.2 1 "
253-
"2023-11-21T08:12:53 1.1\n"
254-
" 1 string 12 1.2 1 "
255-
"2023-11-21T08:12:53 1.1\n"
256-
" 2 string 12 1.2 1 "
257-
"2023-11-21T08:12:53 1.1\n"
258-
)
259-
assert expected in check_result.output
238+
expected = [
239+
" ID @str_data %int_data *float*data* BOOL_DATA TIMESTAMP_DATA DECIMAL_DATA\n",
240+
"---- ----------- ----------- -------------- ----------- ------------------- --------------\n",
241+
" 0 string 12 1.2 1 2023-11-21T08:12:53 1.1\n",
242+
" 1 string 12 1.2 1 2023-11-21T08:12:53 1.1\n",
243+
" 2 string 12 1.2 1 2023-11-21T08:12:53 1.1\n",
244+
]
245+
for row in expected:
246+
assert row in check_result.output
260247

261248

262249
def _verify_ingest_execution_different_payloads_no_table(runner):
@@ -301,21 +288,20 @@ def _verify_ingest_execution_different_payloads_no_table_special_chars(runner):
301288

302289
def _verify_ingest_execution_different_payloads(runner):
303290
check_result = runner.invoke(["sql-query", "--query", "SELECT * FROM test_table"])
304-
expected = (
305-
" ID STR_DATA INT_DATA FLOAT_DATA BOOL_DATA TIMESTAMP_DATA\n"
306-
"---- ---------- ---------- ------------ ----------- "
307-
"-------------------\n"
308-
" 0\n"
309-
" 1 string\n"
310-
" 2 string 12\n"
311-
" 3 string 12 1.2\n"
312-
" 6 string 12 1.2\n"
313-
" 4 string 12 1.2 1\n"
314-
" 7 string 12 1.2 1\n"
315-
" 5 string 12 1.2 1 2023-11-21 "
316-
"08:12:53\n"
317-
)
318-
assert expected in check_result.output
291+
expected = [
292+
" ID STR_DATA INT_DATA FLOAT_DATA BOOL_DATA TIMESTAMP_DATA\n",
293+
"---- ---------- ---------- ------------ ----------- -------------------\n",
294+
" 0\n",
295+
" 1 string\n",
296+
" 2 string 12\n",
297+
" 3 string 12 1.2\n",
298+
" 4 string 12 1.2 1\n",
299+
" 5 string 12 1.2 1 2023-11-21 08:12:53\n"
300+
" 6 string 12 1.2\n",
301+
" 7 string 12 1.2 1\n",
302+
]
303+
for row in expected:
304+
assert row in check_result.output
319305

320306

321307
def _verify_ingest_blob(runner):
@@ -352,5 +338,5 @@ def _verify_ingest_nan_and_none_execution(runner):
352338

353339
def _verify_ingest_data_frame_schema_inference(runner):
354340
check_result = runner.invoke(["sql-query", "--query", "SELECT * FROM test_table"])
355-
expected = " A B C\n--- --- ---\n 1 2 3\n"
341+
expected = "A B C\n--- --- ---\n 1 2 3\n"
356342
assert expected in check_result.output

0 commit comments

Comments
 (0)