diff --git a/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java b/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java index bbb63b4e..b564e1b2 100644 --- a/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java +++ b/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java @@ -2290,15 +2290,15 @@ void testSetAllFieldTypes() throws Exception { SubDocumentUpdate.of("rating", 4.5f), SubDocumentUpdate.of("weight", 123.456), // Case 2: Top-level arrays - SubDocumentUpdate.of("tags", new String[] {"tag4", "tag5", "tag6"}), - SubDocumentUpdate.of("numbers", new Integer[] {10, 20, 30}), - SubDocumentUpdate.of("scores", new Double[] {1.1, 2.2, 3.3}), - SubDocumentUpdate.of("flags", new Boolean[] {true, false, true}), + SubDocumentUpdate.of("tags", new String[]{"tag4", "tag5", "tag6"}), + SubDocumentUpdate.of("numbers", new Integer[]{10, 20, 30}), + SubDocumentUpdate.of("scores", new Double[]{1.1, 2.2, 3.3}), + SubDocumentUpdate.of("flags", new Boolean[]{true, false, true}), // Case 3 & 4: One nested path in JSONB (props) - tests nested primitive SubDocumentUpdate.of("props.brand", "NewBrand"), // Use 'sales' JSONB column for nested array test SubDocumentUpdate.of( - "sales.regions", SubDocumentValue.of(new String[] {"US", "EU", "APAC"}))); + "sales.regions", SubDocumentValue.of(new String[]{"US", "EU", "APAC"}))); UpdateOptions options = UpdateOptions.builder().returnDocumentType(ReturnDocumentType.AFTER_UPDATE).build(); @@ -2510,7 +2510,7 @@ void testSetMultipleNestedPathsInSameJsonbColumn() throws Exception { SubDocumentUpdate.of("props.size", "XL"), SubDocumentUpdate.of("props.newField", "newValue"), SubDocumentUpdate.of( - "props.owners", SubDocumentValue.of(new String[] {"owner1", "owner2"}))); + "props.owners", SubDocumentValue.of(new String[]{"owner1", "owner2"}))); UpdateOptions options = UpdateOptions.builder().returnDocumentType(ReturnDocumentType.AFTER_UPDATE).build(); @@ -2818,7 +2818,7 @@ void testAddArrayValue() { SubDocumentUpdate.builder() .subDocument("price") .operator(UpdateOperator.ADD) - .subDocumentValue(SubDocumentValue.of(new Integer[] {1, 2, 3})) + .subDocumentValue(SubDocumentValue.of(new Integer[]{1, 2, 3})) .build()); UpdateOptions options = @@ -2865,19 +2865,19 @@ void testAppendToListAllCases() throws Exception { SubDocumentUpdate.builder() .subDocument("tags") .operator(UpdateOperator.APPEND_TO_LIST) - .subDocumentValue(SubDocumentValue.of(new String[] {"newTag1", "newTag2"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"newTag1", "newTag2"})) .build(), // Nested JSONB array: append to existing props.colors SubDocumentUpdate.builder() .subDocument("props.colors") .operator(UpdateOperator.APPEND_TO_LIST) - .subDocumentValue(SubDocumentValue.of(new String[] {"green", "yellow"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"green", "yellow"})) .build(), // Nested JSONB: append to non-existent array (creates it) SubDocumentUpdate.builder() .subDocument("sales.regions") .operator(UpdateOperator.APPEND_TO_LIST) - .subDocumentValue(SubDocumentValue.of(new String[] {"US", "EU"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"US", "EU"})) .build()); UpdateOptions options = @@ -2956,13 +2956,13 @@ void testAddToListIfAbsentAllCases() throws Exception { SubDocumentUpdate.builder() .subDocument("tags") .operator(UpdateOperator.ADD_TO_LIST_IF_ABSENT) - .subDocumentValue(SubDocumentValue.of(new String[] {"existing1", "newTag"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"existing1", "newTag"})) .build(), // Nested JSONB: 'red' exists, 'green' is new → adds only 'green' SubDocumentUpdate.builder() .subDocument("props.colors") .operator(UpdateOperator.ADD_TO_LIST_IF_ABSENT) - .subDocumentValue(SubDocumentValue.of(new String[] {"red", "green"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"red", "green"})) .build()); UpdateOptions options = @@ -3033,13 +3033,13 @@ void testRemoveAllFromListAllCases() throws Exception { SubDocumentUpdate.builder() .subDocument("tags") .operator(UpdateOperator.REMOVE_ALL_FROM_LIST) - .subDocumentValue(SubDocumentValue.of(new String[] {"tag1"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"tag1"})) .build(), // Nested JSONB: remove 'red' and 'blue' → leaves green SubDocumentUpdate.builder() .subDocument("props.colors") .operator(UpdateOperator.REMOVE_ALL_FROM_LIST) - .subDocumentValue(SubDocumentValue.of(new String[] {"red", "blue"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"red", "blue"})) .build()); UpdateOptions options = @@ -3512,7 +3512,7 @@ void testBulkUpdateAllOperatorTypes() throws Exception { SubDocumentUpdate.builder() .subDocument("tags") .operator(UpdateOperator.APPEND_TO_LIST) - .subDocumentValue(SubDocumentValue.of(new String[] {"newTag1", "newTag2"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"newTag1", "newTag2"})) .build())); updates.put( @@ -3521,7 +3521,7 @@ void testBulkUpdateAllOperatorTypes() throws Exception { SubDocumentUpdate.builder() .subDocument("tags") .operator(UpdateOperator.ADD_TO_LIST_IF_ABSENT) - .subDocumentValue(SubDocumentValue.of(new String[] {"hygiene", "uniqueTag"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"hygiene", "uniqueTag"})) .build())); updates.put( @@ -3530,7 +3530,7 @@ void testBulkUpdateAllOperatorTypes() throws Exception { SubDocumentUpdate.builder() .subDocument("tags") .operator(UpdateOperator.REMOVE_ALL_FROM_LIST) - .subDocumentValue(SubDocumentValue.of(new String[] {"plastic"})) + .subDocumentValue(SubDocumentValue.of(new String[]{"plastic"})) .build())); BulkUpdateResult result = flatCollection.bulkUpdate(updates, UpdateOptions.builder().build()); @@ -3576,6 +3576,170 @@ void testBulkUpdateAllOperatorTypes() throws Exception { } } + @Test + @DisplayName( + "Should efficiently batch updates across multiple key groups with complex operations") + void testBulkUpdateMultipleGroupsComplexOperations() throws Exception { + Map> updates = new LinkedHashMap<>(); + + // ===== Group 1: Top-level primitive + top-level array (3 keys: 1, 5, 8) ===== + // All have item="Soap" - these should be batched together + // This tests: SET on primitive field, APPEND_TO_LIST on array field + List group1Updates = + List.of( + SubDocumentUpdate.of("price", 99), // SET operator (top-level primitive) + SubDocumentUpdate.builder() + .subDocument("tags") + .operator(UpdateOperator.APPEND_TO_LIST) + .subDocumentValue(SubDocumentValue.of(new String[]{"updated-tag", "batch-test"})) + .build()); // APPEND_TO_LIST on top-level array + + updates.put(rawKey("1"), group1Updates); + updates.put(rawKey("5"), group1Updates); + updates.put(rawKey("8"), group1Updates); + + // ===== Group 2: Nested JSONB updates (2 keys: 3, 7) ===== + // Both have props - these should be batched together + // This tests: SET on nested JSONB fields + List group2Updates = + List.of( + SubDocumentUpdate.builder() + .subDocument("props.brand") + .operator(UpdateOperator.SET) + .subDocumentValue(SubDocumentValue.of("PremiumBrand")) + .build(), // SET on nested JSONB primitive + SubDocumentUpdate.builder() + .subDocument("props.size") + .operator(UpdateOperator.SET) + .subDocumentValue(SubDocumentValue.of("XL")) + .build()); // SET on another nested field + + updates.put(rawKey("3"), group2Updates); + updates.put(rawKey("7"), group2Updates); + + // ===== Group 3: ADD operator + REMOVE_ALL_FROM_LIST (2 keys: 2, 6) ===== + // Both have quantity and tags - these should be batched together + // This tests: ADD on numeric field, REMOVE_ALL_FROM_LIST on array + List group3Updates = + List.of( + SubDocumentUpdate.builder() + .subDocument("quantity") + .operator(UpdateOperator.ADD) + .subDocumentValue(SubDocumentValue.of(100)) + .build(), // ADD to numeric field + SubDocumentUpdate.builder() + .subDocument("tags") + .operator(UpdateOperator.REMOVE_ALL_FROM_LIST) + .subDocumentValue(SubDocumentValue.of(new String[]{"glass", "plastic"})) + .build()); // REMOVE_ALL_FROM_LIST + + updates.put(rawKey("2"), group3Updates); + updates.put(rawKey("6"), group3Updates); + + // Execute bulk update - should have 3 groups with 2-3 keys each + BulkUpdateResult result = flatCollection.bulkUpdate(updates, UpdateOptions.builder().build()); + + // Total unique keys: 1, 2, 3, 5, 6, 7, 8 = 7 keys + assertEquals(7, result.getUpdatedCount(), "Should update 7 rows"); + + // Verify keys 1, 5, 8 have Group 1 updates (top-level primitive + array) + for (String id : List.of("1", "5", "8")) { + try (CloseableIterator iter = flatCollection.find(queryById(id))) { + assertTrue(iter.hasNext()); + JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson()); + assertEquals(99, json.get("price").asInt(), "Key " + id + " price should be 99"); + JsonNode tags = json.get("tags"); + List tagList = new ArrayList<>(); + tags.forEach(t -> tagList.add(t.asText())); + assertTrue( + tagList.contains("updated-tag"), "Key " + id + " should contain 'updated-tag'"); + assertTrue(tagList.contains("batch-test"), "Key " + id + " should contain 'batch-test'"); + } + } + + // Verify keys 3, 7 have Group 2 updates (nested JSONB) + for (String id : List.of("3", "7")) { + try (CloseableIterator iter = flatCollection.find(queryById(id))) { + assertTrue(iter.hasNext()); + JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson()); + JsonNode props = json.get("props"); + assertNotNull(props, "Key " + id + " should have props"); + assertEquals( + "PremiumBrand", + props.get("brand").asText(), + "Key " + id + " brand should be updated"); + assertEquals("XL", props.get("size").asText(), "Key " + id + " size should be XL"); + } + } + + // Verify keys 2, 6 have Group 3 updates (ADD + REMOVE_ALL_FROM_LIST) + try (CloseableIterator iter = flatCollection.find(queryById("2"))) { + assertTrue(iter.hasNext()); + JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson()); + assertEquals(101, json.get("quantity").asInt()); // 1 + 100 + JsonNode tags = json.get("tags"); + List tagList = new ArrayList<>(); + tags.forEach(t -> tagList.add(t.asText())); + assertFalse(tagList.contains("glass"), "Key 2 should not have 'glass' tag"); + } + + try (CloseableIterator iter = flatCollection.find(queryById("6"))) { + assertTrue(iter.hasNext()); + JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson()); + assertEquals(105, json.get("quantity").asInt()); // 5 + 100 + JsonNode tags = json.get("tags"); + List tagList = new ArrayList<>(); + tags.forEach(t -> tagList.add(t.asText())); + assertFalse(tagList.contains("plastic"), "Key 6 should not have 'plastic' tag"); + } + } + + @Test + @DisplayName( + "Should batch keys whose update shape matches by column:operator:path but whose value " + + "arrays differ in length (nested JSONB REMOVE_ALL_FROM_LIST)") + void testBulkUpdateSameShapeDifferentParamCardinality() throws Exception { + Map> updates = new LinkedHashMap<>(); + + updates.put( + rawKey("1"), + List.of( + SubDocumentUpdate.builder() + .subDocument("props.colors") + .operator(UpdateOperator.REMOVE_ALL_FROM_LIST) + .subDocumentValue(SubDocumentValue.of(new String[]{"Blue"})) + .build())); + + updates.put( + rawKey("5"), + List.of( + SubDocumentUpdate.builder() + .subDocument("props.colors") + .operator(UpdateOperator.REMOVE_ALL_FROM_LIST) + .subDocumentValue(SubDocumentValue.of(new String[]{"Orange", "Blue"})) + .build())); + + BulkUpdateResult result = flatCollection.bulkUpdate(updates, UpdateOptions.builder().build()); + + assertEquals(2, result.getUpdatedCount()); + + try (CloseableIterator iter = flatCollection.find(queryById("1"))) { + assertTrue(iter.hasNext()); + JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson()); + JsonNode colors = json.get("props").get("colors"); + List colorList = new ArrayList<>(); + colors.forEach(c -> colorList.add(c.asText())); + assertEquals(List.of("Green"), colorList); + } + + try (CloseableIterator iter = flatCollection.find(queryById("5"))) { + assertTrue(iter.hasNext()); + JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson()); + JsonNode colors = json.get("props").get("colors"); + assertEquals(0, colors.size()); + } + } + @Test @DisplayName("Should handle edge cases: empty map, null map, non-existent keys") void testBulkUpdateEdgeCases() throws Exception { diff --git a/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java b/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java index 3fb71934..cbbcbfa3 100644 --- a/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java +++ b/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java @@ -97,6 +97,7 @@ public class FlatPostgresCollection extends PostgresCollection { "Write operations are not supported for flat collections yet!"; private static final String MISSING_COLUMN_STRATEGY_CONFIG = "missingColumnStrategy"; private static final String DEFAULT_PRIMARY_KEY_COLUMN = "key"; + private static final String SHAPE_KEY_DELIMITER = "\u0001"; private static final Map UPDATE_PARSER_MAP = Map.ofEntries( @@ -880,59 +881,35 @@ public BulkUpdateResult bulkUpdate( String tableName = tableIdentifier.getTableName(); String quotedPkColumn = PostgresUtils.wrapFieldNamesWithDoubleQuotes(getPKForTable(tableName)); - - Set updatedKeys = new HashSet<>(); - long batchUpdateTimestamp = System.currentTimeMillis(); - try (Connection connection = client.getPooledConnection()) { - for (Map.Entry> entry : updates.entrySet()) { - Key key = entry.getKey(); - Collection keyUpdates = entry.getValue(); + int totalUpdated = 0; - if (keyUpdates == null || keyUpdates.isEmpty()) { - continue; - } + try (Connection connection = client.getPooledConnection()) { + // Group keys by their "SQL shape" (same SET-clause fragments AND param count) + Map keyGroups = + groupKeysByUpdateShape(connection, updates, tableName); + // Execute one multi-row UPDATE per group (or fallback to single-key if group size = 1) + for (Map.Entry entry : keyGroups.entrySet()) { try { - boolean updated = - updateSingleKey( - connection, key, keyUpdates, tableName, quotedPkColumn, batchUpdateTimestamp); - if (updated) { - updatedKeys.add(key); - } + int updated = + executeBatchUpdate( + connection, entry.getValue(), tableName, quotedPkColumn, batchUpdateTimestamp); + totalUpdated += updated; } catch (Exception e) { - LOGGER.warn("Failed to update key {}: {}", key, e.getMessage()); - // Continue with other keys - no cross-key atomicity + LOGGER.warn( + "Failed to update key group (size: {}): {}", + entry.getValue().getKeys().size(), + e.getMessage()); + // Continue with other groups - no cross-group atomicity } } } catch (SQLException e) { throw new IOException("Failed to get connection for bulk update", e); } - return new BulkUpdateResult(updatedKeys.size()); - } - - private boolean updateSingleKey( - Connection connection, - Key key, - Collection keyUpdates, - String tableName, - String quotedPkColumn, - long keyUpdateTimestamp) - throws IOException, SQLException { - - updateValidator.validate(keyUpdates); - Map resolvedColumns = resolvePathsToColumns(keyUpdates, tableName); - - return executeKeyUpdate( - connection, - key, - keyUpdates, - tableName, - quotedPkColumn, - resolvedColumns, - keyUpdateTimestamp); + return new BulkUpdateResult(totalUpdated); } private boolean executeKeyUpdate( @@ -978,6 +955,151 @@ private boolean executeKeyUpdate( } } + /** + * Groups keys that produce identical SET-clause SQL together. Two keys share a shape only if + * {@code buildSetClauseFragments} renders the exact same fragment list and the same number of + * bind parameters — required because {@code executeBatchUpdate} reuses one PreparedStatement per + * group. Operators whose generated SQL or placeholder count varies with the input value (e.g., + * nested-JSONB REMOVE_ALL_FROM_LIST emitting 1+N placeholders) will land in distinct groups. + */ + private Map groupKeysByUpdateShape( + Connection connection, Map> updates, String tableName) { + + Map groups = new LinkedHashMap<>(); + + for (Map.Entry> entry : updates.entrySet()) { + Key key = entry.getKey(); + Collection keyUpdates = entry.getValue(); + + if (keyUpdates == null || keyUpdates.isEmpty()) { + continue; + } + + try { + updateValidator.validate(keyUpdates); + Map resolvedColumns = resolvePathsToColumns(keyUpdates, tableName); + + List setFragments = new ArrayList<>(); + List params = new ArrayList<>(); + boolean hasUpdates = + buildSetClauseFragments( + connection, keyUpdates, tableName, resolvedColumns, setFragments, params); + if (!hasUpdates) { + continue; + } + + String shapeKey = computeUpdateShapeKey(setFragments, params.size()); + + groups + .computeIfAbsent(shapeKey, k -> new KeyUpdateGroup(resolvedColumns, setFragments)) + .addKeyWithParams(key, params); + + } catch (Exception e) { + LOGGER.warn("Failed to group key {}: {}", key, e.getMessage()); + } + } + + return groups; + } + + private String computeUpdateShapeKey(List setFragments, int paramCount) { + return paramCount + "|" + String.join(SHAPE_KEY_DELIMITER, setFragments); + } + + /** + * Executes a batch UPDATE for all keys in the group using JDBC batching. The group's setFragments + * and per-key params were rendered during grouping, so all keys here are guaranteed to share the + * same SQL and placeholder count. + */ + private int executeBatchUpdate( + Connection connection, + KeyUpdateGroup keyGroup, + String tableName, + String quotedPkColumn, + long epochMillis) + throws SQLException { + + List keys = keyGroup.getKeys(); + List> allKeyParams = keyGroup.getKeyParams(); + + List setFragments = new ArrayList<>(keyGroup.getSetFragments()); + List timestampParam = new ArrayList<>(); + appendLastUpdatedTimestamp(setFragments, timestampParam, tableName, epochMillis); + + String sql = + String.format( + "UPDATE %s SET %s WHERE %s = ?", + tableIdentifier, String.join(", ", setFragments), quotedPkColumn); + + LOGGER.debug("Executing batch update SQL: {} for {} keys", sql, keys.size()); + + try (PreparedStatement ps = connection.prepareStatement(sql)) { + for (int i = 0; i < keys.size(); i++) { + int idx = 1; + for (Object param : allKeyParams.get(i)) { + ps.setObject(idx++, param); + } + for (Object param : timestampParam) { + ps.setObject(idx++, param); + } + ps.setObject(idx, keys.get(i).toString()); // WHERE clause parameter + ps.addBatch(); + } + + int[] results = ps.executeBatch(); + int totalUpdated = 0; + for (int result : results) { + if (result > 0) { + totalUpdated++; + } + } + + LOGGER.debug("Batch update affected {} rows out of {} keys", totalUpdated, keys.size()); + return totalUpdated; + } catch (SQLException e) { + LOGGER.warn("Failed to execute batch update. SQL: {}, Error: {}", sql, e.getMessage()); + throw e; + } + } + + /** + * Holds a group of keys that share the same SET-clause SQL. {@code setFragments} is rendered once + * during grouping; {@code keyParams} stores the bind values for each key in lockstep with {@code + * keys}. + */ + private static class KeyUpdateGroup { + private final Map resolvedColumns; + private final List setFragments; + private final List keys = new ArrayList<>(); + private final List> keyParams = new ArrayList<>(); + + KeyUpdateGroup(Map resolvedColumns, List setFragments) { + this.resolvedColumns = resolvedColumns; + this.setFragments = setFragments; + } + + void addKeyWithParams(Key key, List params) { + keys.add(key); + keyParams.add(params); + } + + Map getResolvedColumns() { + return resolvedColumns; + } + + List getSetFragments() { + return setFragments; + } + + List getKeys() { + return keys; + } + + List> getKeyParams() { + return keyParams; + } + } + /** * Validates all updates and resolves column names. *