Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.VirtualTableScan;
import io.substrait.relation.physical.ComparisonJoinKey;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.MergeJoin;
import io.substrait.relation.physical.NestedLoopJoin;
Expand Down Expand Up @@ -440,10 +441,7 @@ public HashJoin hashJoin(
return HashJoin.builder()
.left(left)
.right(right)
.leftKeys(
this.fieldReferences(left, leftKeys.stream().mapToInt(Integer::intValue).toArray()))
.rightKeys(
this.fieldReferences(right, rightKeys.stream().mapToInt(Integer::intValue).toArray()))
.keys(this.comparisonJoinKeys(left, right, leftKeys, rightKeys))
.joinType(joinType)
.remap(remap)
.build();
Expand Down Expand Up @@ -490,15 +488,38 @@ public MergeJoin mergeJoin(
return MergeJoin.builder()
.left(left)
.right(right)
.leftKeys(
this.fieldReferences(left, leftKeys.stream().mapToInt(Integer::intValue).toArray()))
.rightKeys(
this.fieldReferences(right, rightKeys.stream().mapToInt(Integer::intValue).toArray()))
.keys(this.comparisonJoinKeys(left, right, leftKeys, rightKeys))
.joinType(joinType)
.remap(remap)
.build();
}

/**
 * Pairs up left and right field indexes into equality join keys.
 *
 * <p>The i-th entry of {@code leftKeys} is matched with the i-th entry of {@code rightKeys}. Each
 * pair becomes one {@link ComparisonJoinKey} compared with {@link
 * ComparisonJoinKey.SimpleComparisonType#EQ}.
 *
 * @param left the left input relation, used to resolve {@code leftKeys}
 * @param right the right input relation, used to resolve {@code rightKeys}
 * @param leftKeys field indexes from the left relation
 * @param rightKeys field indexes from the right relation
 * @return one equality join key per index pair, in the order the indexes were given
 * @throws IllegalArgumentException if {@code leftKeys} and {@code rightKeys} differ in size
 */
public List<ComparisonJoinKey> comparisonJoinKeys(
    Rel left, Rel right, List<Integer> leftKeys, List<Integer> rightKeys) {
  final int pairCount = leftKeys.size();
  if (pairCount != rightKeys.size()) {
    throw new IllegalArgumentException("Number of left and right keys must be equal.");
  }

  final List<ComparisonJoinKey> equalityKeys = new java.util.ArrayList<>(pairCount);
  for (int position = 0; position < pairCount; position++) {
    final ComparisonJoinKey equalityKey =
        ComparisonJoinKey.of(
            this.fieldReference(left, leftKeys.get(position)),
            this.fieldReference(right, rightKeys.get(position)),
            ComparisonJoinKey.SimpleComparisonType.EQ);
    equalityKeys.add(equalityKey);
  }
  return equalityKeys;
}

/**
* Creates a nested loop join between two relations.
*
Expand Down
81 changes: 73 additions & 8 deletions core/src/main/java/io/substrait/relation/ProtoRelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import io.substrait.relation.files.FileOrFiles;
import io.substrait.relation.physical.AbstractExchangeRel;
import io.substrait.relation.physical.BroadcastExchange;
import io.substrait.relation.physical.ComparisonJoinKey;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.ImmutableBroadcastExchange;
import io.substrait.relation.physical.ImmutableExchangeTarget;
Expand All @@ -64,6 +65,7 @@
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.jspecify.annotations.NonNull;

/** Converts from {@link io.substrait.proto.Rel} to {@link io.substrait.relation.Rel} */
Expand Down Expand Up @@ -861,8 +863,6 @@ protected Set newSet(SetRel rel) {
protected Rel newHashJoin(HashJoinRel rel) {
Rel left = from(rel.getLeft());
Rel right = from(rel.getRight());
List<io.substrait.proto.Expression.FieldReference> leftKeys = rel.getLeftKeysList();
List<io.substrait.proto.Expression.FieldReference> rightKeys = rel.getRightKeysList();

Type.Struct leftStruct = left.getRecordType();
Type.Struct rightStruct = right.getRecordType();
Expand All @@ -877,8 +877,13 @@ protected Rel newHashJoin(HashJoinRel rel) {
HashJoin.builder()
.left(left)
.right(right)
.leftKeys(leftKeys.stream().map(leftConverter::from).collect(Collectors.toList()))
.rightKeys(rightKeys.stream().map(rightConverter::from).collect(Collectors.toList()))
.keys(
comparisonJoinKeys(
rel.getKeysList(),
rel.getLeftKeysList(),
rel.getRightKeysList(),
leftConverter,
rightConverter))
.joinType(HashJoin.JoinType.fromProto(rel.getType()))
.postJoinFilter(
Optional.ofNullable(
Expand All @@ -896,8 +901,6 @@ protected Rel newHashJoin(HashJoinRel rel) {
protected Rel newMergeJoin(MergeJoinRel rel) {
Rel left = from(rel.getLeft());
Rel right = from(rel.getRight());
List<io.substrait.proto.Expression.FieldReference> leftKeys = rel.getLeftKeysList();
List<io.substrait.proto.Expression.FieldReference> rightKeys = rel.getRightKeysList();

Type.Struct leftStruct = left.getRecordType();
Type.Struct rightStruct = right.getRecordType();
Expand All @@ -912,8 +915,13 @@ protected Rel newMergeJoin(MergeJoinRel rel) {
MergeJoin.builder()
.left(left)
.right(right)
.leftKeys(leftKeys.stream().map(leftConverter::from).collect(Collectors.toList()))
.rightKeys(rightKeys.stream().map(rightConverter::from).collect(Collectors.toList()))
.keys(
comparisonJoinKeys(
rel.getKeysList(),
rel.getLeftKeysList(),
rel.getRightKeysList(),
leftConverter,
rightConverter))
.joinType(MergeJoin.JoinType.fromProto(rel.getType()))
.postJoinFilter(
Optional.ofNullable(
Expand All @@ -929,6 +937,63 @@ protected Rel newMergeJoin(MergeJoinRel rel) {
return builder.build();
}

/**
 * Converts the join keys of a hash/merge join proto into {@link ComparisonJoinKey}s.
 *
 * <p>When {@code keys} has entries, each one is converted and the deprecated {@code
 * left_keys}/{@code right_keys} lists are ignored. Only when {@code keys} is empty are the
 * deprecated lists used: they are paired by position and compared with {@link
 * ComparisonJoinKey.SimpleComparisonType#EQ}.
 *
 * @param keys the proto {@code keys} entries
 * @param leftKeys the deprecated proto {@code left_keys} entries
 * @param rightKeys the deprecated proto {@code right_keys} entries
 * @param leftConverter converter applied to left-side expressions
 * @param rightConverter converter applied to right-side expressions
 * @return the converted join keys, in input order
 * @throws IllegalArgumentException if the deprecated lists are consulted and differ in size
 */
private List<ComparisonJoinKey> comparisonJoinKeys(
    List<io.substrait.proto.ComparisonJoinKey> keys,
    List<io.substrait.proto.Expression.FieldReference> leftKeys,
    List<io.substrait.proto.Expression.FieldReference> rightKeys,
    ProtoExpressionConverter leftConverter,
    ProtoExpressionConverter rightConverter) {
  if (keys.isEmpty()) {
    // No `keys` entries: fall back to the deprecated parallel left/right lists.
    final int legacyCount = leftKeys.size();
    if (legacyCount != rightKeys.size()) {
      throw new IllegalArgumentException("Number of left and right keys must be equal.");
    }
    return IntStream.range(0, legacyCount)
        .mapToObj(
            position ->
                ComparisonJoinKey.of(
                    leftConverter.from(leftKeys.get(position)),
                    rightConverter.from(rightKeys.get(position)),
                    ComparisonJoinKey.SimpleComparisonType.EQ))
        .collect(Collectors.toList());
  }

  return keys.stream()
      .map(protoKey -> comparisonJoinKey(protoKey, leftConverter, rightConverter))
      .collect(Collectors.toList());
}

/**
 * Converts a single proto join key into a {@link ComparisonJoinKey}.
 *
 * <p>The comparison is converted before either side, so an unsupported comparison is reported
 * before any field reference is converted.
 *
 * @param key the proto join key
 * @param leftConverter converter applied to the key's left expression
 * @param rightConverter converter applied to the key's right expression
 * @return the converted join key
 * @throws IllegalArgumentException if the comparison is neither simple nor a custom function
 *     reference
 */
private ComparisonJoinKey comparisonJoinKey(
    io.substrait.proto.ComparisonJoinKey key,
    ProtoExpressionConverter leftConverter,
    ProtoExpressionConverter rightConverter) {
  final ComparisonJoinKey.ComparisonType converted = convertComparisonType(key.getComparison());
  return ComparisonJoinKey.builder()
      .left(leftConverter.from(key.getLeft()))
      .right(rightConverter.from(key.getRight()))
      .comparison(converted)
      .build();
}

/** Maps the proto comparison's populated inner type onto the matching comparison type. */
private ComparisonJoinKey.ComparisonType convertComparisonType(
    io.substrait.proto.ComparisonJoinKey.ComparisonType comparison) {
  switch (comparison.getInnerTypeCase()) {
    case SIMPLE:
      return ComparisonJoinKey.SimpleComparison.of(
          ComparisonJoinKey.SimpleComparisonType.fromProto(comparison.getSimple()));
    case CUSTOM_FUNCTION_REFERENCE:
      return ComparisonJoinKey.CustomComparison.of(comparison.getCustomFunctionReference());
    default:
      throw new IllegalArgumentException(
          "Unsupported comparison type: " + comparison.getInnerTypeCase());
  }
}

protected NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) {
Rel left = from(rel.getLeft());
Rel right = from(rel.getRight());
Expand Down
38 changes: 24 additions & 14 deletions core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.relation.physical.BroadcastExchange;
import io.substrait.relation.physical.ComparisonJoinKey;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.MergeJoin;
import io.substrait.relation.physical.MultiBucketExchange;
Expand Down Expand Up @@ -438,23 +439,20 @@ public Optional<Rel> visit(ExtensionTable extensionTable, EmptyVisitationContext
public Optional<Rel> visit(HashJoin hashJoin, EmptyVisitationContext context) throws E {
Optional<Rel> left = hashJoin.getLeft().accept(this, context);
Optional<Rel> right = hashJoin.getRight().accept(this, context);
Optional<List<FieldReference>> leftKeys =
transformList(hashJoin.getLeftKeys(), context, this::visitFieldReference);
Optional<List<FieldReference>> rightKeys =
transformList(hashJoin.getRightKeys(), context, this::visitFieldReference);
Optional<List<ComparisonJoinKey>> keys =
transformList(hashJoin.getKeys(), context, this::visitComparisonJoinKey);
Optional<Expression> postFilter =
visitOptionalExpression(hashJoin.getPostJoinFilter(), context);

if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) {
if (allEmpty(left, right, keys, postFilter)) {
return Optional.empty();
}
return Optional.of(
HashJoin.builder()
.from(hashJoin)
.left(left.orElse(hashJoin.getLeft()))
.right(right.orElse(hashJoin.getRight()))
.leftKeys(leftKeys.orElse(hashJoin.getLeftKeys()))
.rightKeys(rightKeys.orElse(hashJoin.getRightKeys()))
.keys(keys.orElse(hashJoin.getKeys()))
.postJoinFilter(or(postFilter, hashJoin::getPostJoinFilter))
.build());
}
Expand All @@ -463,23 +461,20 @@ public Optional<Rel> visit(HashJoin hashJoin, EmptyVisitationContext context) th
public Optional<Rel> visit(MergeJoin mergeJoin, EmptyVisitationContext context) throws E {
Optional<Rel> left = mergeJoin.getLeft().accept(this, context);
Optional<Rel> right = mergeJoin.getRight().accept(this, context);
Optional<List<FieldReference>> leftKeys =
transformList(mergeJoin.getLeftKeys(), context, this::visitFieldReference);
Optional<List<FieldReference>> rightKeys =
transformList(mergeJoin.getRightKeys(), context, this::visitFieldReference);
Optional<List<ComparisonJoinKey>> keys =
transformList(mergeJoin.getKeys(), context, this::visitComparisonJoinKey);
Optional<Expression> postFilter =
visitOptionalExpression(mergeJoin.getPostJoinFilter(), context);

if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) {
if (allEmpty(left, right, keys, postFilter)) {
return Optional.empty();
}
return Optional.of(
MergeJoin.builder()
.from(mergeJoin)
.left(left.orElse(mergeJoin.getLeft()))
.right(right.orElse(mergeJoin.getRight()))
.leftKeys(leftKeys.orElse(mergeJoin.getLeftKeys()))
.rightKeys(rightKeys.orElse(mergeJoin.getRightKeys()))
.keys(keys.orElse(mergeJoin.getKeys()))
.postJoinFilter(or(postFilter, mergeJoin::getPostJoinFilter))
.build());
}
Expand Down Expand Up @@ -569,6 +564,21 @@ public Optional<FieldReference> visitFieldReference(
return Optional.of(FieldReference.builder().inputExpression(inputExpression).build());
}

/**
 * Visits both field references of a join key and rebuilds the key if either one was replaced.
 *
 * @param key the join key whose left and right field references are visited
 * @param context the visitation context passed through to {@code visitFieldReference}
 * @return {@link Optional#empty()} when {@code allEmpty} reports no replacement for either side;
 *     otherwise a key built from {@code key} with the replaced side(s) substituted
 * @throws E if visiting a field reference throws
 */
public Optional<ComparisonJoinKey> visitComparisonJoinKey(
    ComparisonJoinKey key, EmptyVisitationContext context) throws E {
  final Optional<FieldReference> replacedLeft = visitFieldReference(key.getLeft(), context);
  final Optional<FieldReference> replacedRight = visitFieldReference(key.getRight(), context);
  if (allEmpty(replacedLeft, replacedRight)) {
    return Optional.empty();
  }

  // Keep the original reference for whichever side the visitor did not replace.
  final FieldReference leftRef = replacedLeft.orElse(key.getLeft());
  final FieldReference rightRef = replacedRight.orElse(key.getRight());
  final ComparisonJoinKey rebuilt =
      ComparisonJoinKey.builder().from(key).left(leftRef).right(rightRef).build();
  return Optional.of(rebuilt);
}

protected Optional<List<FunctionArg>> visitFunctionArguments(
List<FunctionArg> funcArgs, EmptyVisitationContext context) throws E {
return CopyOnWriteUtils.<FunctionArg, EmptyVisitationContext, E>transformList(
Expand Down
Loading
Loading