Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 12 additions & 25 deletions core/src/main/java/io/substrait/relation/ProtoRelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -719,33 +719,20 @@ protected Aggregate newAggregate(AggregateRel rel) {

List<Aggregate.Grouping> groupings = new ArrayList<>(rel.getGroupingsCount());

// Groupings are set using the AggregateRel grouping_expression mechanism
if (!rel.getGroupingExpressionsList().isEmpty()) {
List<Expression> allGroupingExpressions =
rel.getGroupingExpressionsList().stream()
.map(protoExprConverter::from)
.collect(java.util.stream.Collectors.toList());

for (AggregateRel.Grouping grouping : rel.getGroupingsList()) {
List<Integer> references = grouping.getExpressionReferencesList();
List<Expression> groupExpressions = new ArrayList<>();
for (int ref : references) {
groupExpressions.add(allGroupingExpressions.get(ref));
}
groupings.add(Aggregate.Grouping.builder().addAllExpressions(groupExpressions).build());
}
// Convert grouping expressions from the aggregate-level grouping_expressions list
// Each grouping references expressions by index into this list
List<Expression> allGroupingExpressions =
rel.getGroupingExpressionsList().stream()
.map(protoExprConverter::from)
.collect(java.util.stream.Collectors.toList());

} else {
// Groupings are set using the deprecated Grouping grouping_expressions mechanism
for (AggregateRel.Grouping grouping : rel.getGroupingsList()) {
groupings.add(
Aggregate.Grouping.builder()
.expressions(
grouping.getGroupingExpressionsList().stream()
.map(protoExprConverter::from)
.collect(java.util.stream.Collectors.toList()))
.build());
for (AggregateRel.Grouping grouping : rel.getGroupingsList()) {
List<Integer> references = grouping.getExpressionReferencesList();
List<Expression> groupExpressions = new ArrayList<>();
for (int ref : references) {
groupExpressions.add(allGroupingExpressions.get(ref));
}
groupings.add(Aggregate.Grouping.builder().addAllExpressions(groupExpressions).build());
}

List<Aggregate.Measure> measures = new ArrayList<>(rel.getMeasuresCount());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,6 @@ private AggregateRel.Measure toProto(Aggregate.Measure measure) {
private AggregateRel.Grouping toProto(
Aggregate.Grouping grouping, List<Expression> uniqueGroupingExpressions) {
return AggregateRel.Grouping.newBuilder()
.addAllGroupingExpressions(
grouping.getExpressions().stream().map(this::toProto).collect(Collectors.toList()))
.addAllExpressionReferences(
grouping.getExpressions().stream()
.map(e -> uniqueGroupingExpressions.indexOf(e))
Expand Down
39 changes: 11 additions & 28 deletions core/src/test/java/io/substrait/relation/AggregateRelTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,27 +73,15 @@ private static List<List<Integer>> getExpressionReferences(AggregateRel aggregat
.collect(Collectors.toList());
}

/**
* Helper method to extract deprecated grouping expressions from an AggregateRel.
*
* @param aggregateRel the AggregateRel to extract grouping expressions from
* @return a list of lists, where each inner list contains the grouping expressions for a grouping
*/
private static List<List<Expression>> getGroupingExpressions(AggregateRel aggregateRel) {
return aggregateRel.getGroupingsList().stream()
.map(grouping -> grouping.getGroupingExpressionsList())
.collect(Collectors.toList());
}

@Test
void testDeprecatedGroupingExpressionConversion() {
void testGroupingExpressionConversion() {
Expression col1Ref = createFieldReference(0);
Expression col2Ref = createFieldReference(1);

AggregateRel.Grouping grouping =
AggregateRel.Grouping.newBuilder()
.addGroupingExpressions(col1Ref) // deprecated proto form
.addGroupingExpressions(col2Ref)
.addExpressionReferences(0)
.addExpressionReferences(1)
.build();

// Build an input ReadRel
Expand All @@ -103,10 +91,12 @@ void testDeprecatedGroupingExpressionConversion() {
.setBaseSchema(namedStruct)
.build();

// Build the AggregateRel with the new grouping_expressions field
// Build the AggregateRel with grouping_expressions field
AggregateRel aggrProto =
AggregateRel.newBuilder()
.setInput(Rel.newBuilder().setRead(readProto))
.addGroupingExpressions(col1Ref)
.addGroupingExpressions(col2Ref)
.addGroupings(grouping)
.build();

Expand All @@ -123,10 +113,8 @@ void testDeprecatedGroupingExpressionConversion() {
assertTrue(roundtripRel.hasAggregate());

AggregateRel roundtripAgg = roundtripRel.getAggregate();
// Verify new expression_references structure
// Verify expression_references structure
assertEquals(List.of(List.of(0, 1)), getExpressionReferences(roundtripAgg));
// Verify backward compatibility: deprecated grouping_expressions field is also populated
assertEquals(List.of(List.of(col1Ref, col2Ref)), getGroupingExpressions(roundtripAgg));
// Verify aggregate-level grouping_expressions field is populated
assertEquals(List.of(col1Ref, col2Ref), roundtripAgg.getGroupingExpressionsList());
}
Expand All @@ -149,7 +137,7 @@ void testAggregateWithSingleGrouping() {
.setBaseSchema(namedStruct)
.build();

// Build the AggregateRel with the new grouping_expressions field
// Build the AggregateRel with grouping_expressions field
AggregateRel aggrProto =
AggregateRel.newBuilder()
.setInput(Rel.newBuilder().setRead(readProto))
Expand All @@ -171,10 +159,8 @@ void testAggregateWithSingleGrouping() {
assertTrue(roundtripRel.hasAggregate());

AggregateRel roundtripAgg = roundtripRel.getAggregate();
// Verify new expression_references structure
// Verify expression_references structure
assertEquals(List.of(List.of(0, 1)), getExpressionReferences(roundtripAgg));
// Verify backward compatibility: deprecated grouping_expressions field is also populated
assertEquals(List.of(List.of(col1Ref, col2Ref)), getGroupingExpressions(roundtripAgg));
// Verify aggregate-level grouping_expressions field is populated
assertEquals(List.of(col1Ref, col2Ref), roundtripAgg.getGroupingExpressionsList());
}
Expand All @@ -200,7 +186,7 @@ void testAggregateWithMultipleGroupings() {
.setBaseSchema(namedStruct)
.build();

// Build the AggregateRel with the new grouping_expressions field
// Build the AggregateRel with grouping_expressions field
AggregateRel aggrProto =
AggregateRel.newBuilder()
.setInput(Rel.newBuilder().setRead(readProto))
Expand All @@ -224,11 +210,8 @@ void testAggregateWithMultipleGroupings() {
assertTrue(roundtripRel.hasAggregate());

AggregateRel roundtripAgg = roundtripRel.getAggregate();
// Verify new expression_references structure
// Verify expression_references structure
assertEquals(List.of(List.of(0, 1), List.of(1)), getExpressionReferences(roundtripAgg));
// Verify backward compatibility: deprecated grouping_expressions field is also populated
assertEquals(
List.of(List.of(col1Ref, col2Ref), List.of(col2Ref)), getGroupingExpressions(roundtripAgg));
// Verify aggregate-level grouping_expressions field is populated
assertEquals(List.of(col1Ref, col2Ref), roundtripAgg.getGroupingExpressionsList());
}
Expand Down
Loading