diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index b8a28f78c..ea68adf70 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -719,33 +719,20 @@ protected Aggregate newAggregate(AggregateRel rel) { List groupings = new ArrayList<>(rel.getGroupingsCount()); - // Groupings are set using the AggregateRel grouping_expression mechanism - if (!rel.getGroupingExpressionsList().isEmpty()) { - List allGroupingExpressions = - rel.getGroupingExpressionsList().stream() - .map(protoExprConverter::from) - .collect(java.util.stream.Collectors.toList()); - - for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { - List references = grouping.getExpressionReferencesList(); - List groupExpressions = new ArrayList<>(); - for (int ref : references) { - groupExpressions.add(allGroupingExpressions.get(ref)); - } - groupings.add(Aggregate.Grouping.builder().addAllExpressions(groupExpressions).build()); - } + // Convert grouping expressions from the aggregate-level grouping_expressions list + // Each grouping references expressions by index into this list + List allGroupingExpressions = + rel.getGroupingExpressionsList().stream() + .map(protoExprConverter::from) + .collect(java.util.stream.Collectors.toList()); - } else { - // Groupings are set using the deprecated Grouping grouping_expressions mechanism - for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { - groupings.add( - Aggregate.Grouping.builder() - .expressions( - grouping.getGroupingExpressionsList().stream() - .map(protoExprConverter::from) - .collect(java.util.stream.Collectors.toList())) - .build()); + for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { + List references = grouping.getExpressionReferencesList(); + List groupExpressions = new ArrayList<>(); + for (int ref : references) { + groupExpressions.add(allGroupingExpressions.get(ref)); } + groupings.add(Aggregate.Grouping.builder().addAllExpressions(groupExpressions).build()); } List measures = new ArrayList<>(rel.getMeasuresCount()); diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 0b9aebada..aa92ee9ea 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -216,8 +216,6 @@ private AggregateRel.Measure toProto(Aggregate.Measure measure) { private AggregateRel.Grouping toProto( Aggregate.Grouping grouping, List uniqueGroupingExpressions) { return AggregateRel.Grouping.newBuilder() - .addAllGroupingExpressions( - grouping.getExpressions().stream().map(this::toProto).collect(Collectors.toList())) .addAllExpressionReferences( grouping.getExpressions().stream() .map(e -> uniqueGroupingExpressions.indexOf(e)) diff --git a/core/src/test/java/io/substrait/relation/AggregateRelTest.java b/core/src/test/java/io/substrait/relation/AggregateRelTest.java index 07f1e25cd..25028ca64 100644 --- a/core/src/test/java/io/substrait/relation/AggregateRelTest.java +++ b/core/src/test/java/io/substrait/relation/AggregateRelTest.java @@ -73,27 +73,15 @@ private static List> getExpressionReferences(AggregateRel aggregat .collect(Collectors.toList()); } - /** - * Helper method to extract deprecated grouping expressions from an AggregateRel. - * - * @param aggregateRel the AggregateRel to extract grouping expressions from - * @return a list of lists, where each inner list contains the grouping expressions for a grouping - */ - private static List> getGroupingExpressions(AggregateRel aggregateRel) { - return aggregateRel.getGroupingsList().stream() - .map(grouping -> grouping.getGroupingExpressionsList()) - .collect(Collectors.toList()); - } - @Test - void testDeprecatedGroupingExpressionConversion() { + void testGroupingExpressionConversion() { Expression col1Ref = createFieldReference(0); Expression col2Ref = createFieldReference(1); AggregateRel.Grouping grouping = AggregateRel.Grouping.newBuilder() - .addGroupingExpressions(col1Ref) // deprecated proto form - .addGroupingExpressions(col2Ref) + .addExpressionReferences(0) + .addExpressionReferences(1) .build(); // Build an input ReadRel @@ -103,10 +91,12 @@ void testDeprecatedGroupingExpressionConversion() { .setBaseSchema(namedStruct) .build(); - // Build the AggregateRel with the new grouping_expressions field + // Build the AggregateRel with grouping_expressions field AggregateRel aggrProto = AggregateRel.newBuilder() .setInput(Rel.newBuilder().setRead(readProto)) + .addGroupingExpressions(col1Ref) + .addGroupingExpressions(col2Ref) .addGroupings(grouping) .build(); @@ -123,10 +113,8 @@ void testDeprecatedGroupingExpressionConversion() { assertTrue(roundtripRel.hasAggregate()); AggregateRel roundtripAgg = roundtripRel.getAggregate(); - // Verify new expression_references structure + // Verify expression_references structure assertEquals(List.of(List.of(0, 1)), getExpressionReferences(roundtripAgg)); - // Verify backward compatibility: deprecated grouping_expressions field is also populated - assertEquals(List.of(List.of(col1Ref, col2Ref)), getGroupingExpressions(roundtripAgg)); // Verify aggregate-level grouping_expressions field is populated assertEquals(List.of(col1Ref, col2Ref), roundtripAgg.getGroupingExpressionsList()); } @@ -149,7 +137,7 @@ void testAggregateWithSingleGrouping() { .setBaseSchema(namedStruct) .build(); - // Build the AggregateRel with the new grouping_expressions field + // Build the AggregateRel with grouping_expressions field AggregateRel aggrProto = AggregateRel.newBuilder() .setInput(Rel.newBuilder().setRead(readProto)) @@ -171,10 +159,8 @@ void testAggregateWithSingleGrouping() { assertTrue(roundtripRel.hasAggregate()); AggregateRel roundtripAgg = roundtripRel.getAggregate(); - // Verify new expression_references structure + // Verify expression_references structure assertEquals(List.of(List.of(0, 1)), getExpressionReferences(roundtripAgg)); - // Verify backward compatibility: deprecated grouping_expressions field is also populated - assertEquals(List.of(List.of(col1Ref, col2Ref)), getGroupingExpressions(roundtripAgg)); // Verify aggregate-level grouping_expressions field is populated assertEquals(List.of(col1Ref, col2Ref), roundtripAgg.getGroupingExpressionsList()); } @@ -200,7 +186,7 @@ void testAggregateWithMultipleGroupings() { .setBaseSchema(namedStruct) .build(); - // Build the AggregateRel with the new grouping_expressions field + // Build the AggregateRel with grouping_expressions field AggregateRel aggrProto = AggregateRel.newBuilder() .setInput(Rel.newBuilder().setRead(readProto)) @@ -224,11 +210,8 @@ void testAggregateWithMultipleGroupings() { assertTrue(roundtripRel.hasAggregate()); AggregateRel roundtripAgg = roundtripRel.getAggregate(); - // Verify new expression_references structure + // Verify expression_references structure assertEquals(List.of(List.of(0, 1), List.of(1)), getExpressionReferences(roundtripAgg)); - // Verify backward compatibility: deprecated grouping_expressions field is also populated - assertEquals( - List.of(List.of(col1Ref, col2Ref), List.of(col2Ref)), getGroupingExpressions(roundtripAgg)); // Verify aggregate-level grouping_expressions field is populated assertEquals(List.of(col1Ref, col2Ref), roundtripAgg.getGroupingExpressionsList()); }