Skip to content

Commit eb548f6

Browse files
committed
Move into function and refactor
1 parent dce1627 commit eb548f6

File tree

5 files changed

+151
-165
lines changed

5 files changed

+151
-165
lines changed

mlir/include/mlir/Conversion/LinalgToRock/LinalgToRock.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define MLIR_CONVERSION_LINALGTOROCK_H
1515

1616
#include "mlir/IR/PatternMatch.h"
17+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1718
#include "mlir/Pass/Pass.h"
1819
#include "mlir/Transforms/DialectConversion.h"
1920

@@ -24,6 +25,13 @@ namespace mlir {
2425
namespace rock {
2526
void populateLinalgToRockConversionPattern(RewritePatternSet &pattern,
2627
MLIRContext *context);
28+
29+
/// A tensor.insert_slice is said to be a rock.expand_stride if it satisfies the following:
30+
/// - dest is a tensor.empty with a single use
31+
/// - all offsets are zero
32+
/// - all strides are one
33+
/// - all slice sizes are static and match the source tensor shape
34+
bool isRockExpandStride(tensor::InsertSliceOp op);
2735
}
2836
} // namespace mlir
2937

mlir/lib/Conversion/LinalgToRock/LinalgToRock.cpp

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -156,39 +156,38 @@ struct ExpandStrideConverter final
156156
};
157157
} // namespace
158158

159-
LogicalResult ExpandStrideConverter::matchAndRewrite(
160-
tensor::InsertSliceOp op, OpAdaptor adaptor,
161-
ConversionPatternRewriter &rewriter) const {
162-
/// The linalg-to-rock passes emits the following expression
163-
/// for expanding the strides. We are matching the following IR
164-
/// %empty = tensor.empty() : ....
165-
/// %inserted_slice = tensor.insert_slice %actual_data into %empty ...
166-
auto tensorEmpty =
167-
dyn_cast<tensor::EmptyOp>(op.getOperand(1).getDefiningOp());
168-
if (!tensorEmpty || !tensorEmpty->hasOneUse()) {
169-
return failure();
159+
bool mlir::rock::isRockExpandStride(tensor::InsertSliceOp op){
160+
auto emptyOp = op.getDest().getDefiningOp<tensor::EmptyOp>();
161+
if (!emptyOp){
162+
return false;
170163
}
171164

172165
// Require statically known slice sizes that exactly match the
173166
// source tensor shape.
174167
auto srcType = dyn_cast<RankedTensorType>(op.getSource().getType());
175-
if (!srcType){
176-
return failure();
177-
}
168+
if (!srcType)
169+
return false;
178170

179-
// into rock.expand_strides, but only in the exact expand-strides shape:
180-
// - dest is a tensor.empty with a single use
181-
// - all offsets are zero
182-
// - all strides are one
183-
// - all slice sizes are static and match the source tensor shape
184171
bool isExpandStride = llvm::all_of(op.getStaticOffsets(), [](int64_t offset) { return offset == 0; }) &&
185172
llvm::all_of(op.getStaticStrides(), [](int64_t stride) { return stride == 1; }) &&
186173
llvm::none_of(op.getStaticSizes(),
187174
[](int64_t s) { return s == ShapedType::kDynamic; }) && op.getStaticSizes() == srcType.getShape();
175+
return isExpandStride;
176+
}
188177

189-
if(!isExpandStride){
178+
LogicalResult ExpandStrideConverter::matchAndRewrite(
179+
tensor::InsertSliceOp op, OpAdaptor adaptor,
180+
ConversionPatternRewriter &rewriter) const {
181+
/// The linalg-to-rock passes emits the following expression
182+
/// for expanding the strides. We are matching the following IR
183+
/// %empty = tensor.empty() : ....
184+
/// %inserted_slice = tensor.insert_slice %actual_data into %empty ...
185+
if(!rock::isRockExpandStride(op)){
190186
return failure();
191187
}
188+
auto tensorEmpty =
189+
dyn_cast<tensor::EmptyOp>(op.getOperand(1).getDefiningOp());
190+
assert(tensorEmpty && "Should have been checked by isRockExpandStride");
192191

193192
Location loc = op.getLoc();
194193
auto alloc = bufferization::AllocTensorOp::create(

mlir/lib/Conversion/LinalgToRock/LinalgToRockPass.cpp

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -38,31 +38,11 @@ static void populateLinalgToRockDialectConversion(ConversionTarget &target) {
3838
rock::RockDialect, bufferization::BufferizationDialect,
3939
math::MathDialect>();
4040

41-
// tensor.insert_slice with operand of one tensor.empty op can be expanded
42-
// into rock.expand_strides
41+
// a tensor.insert_slice oculd be a rock expand stride, and in that case
42+
// we expand it into a rock.expand_stride
4343
target.addDynamicallyLegalOp<tensor::InsertSliceOp>(
4444
[](tensor::InsertSliceOp op) -> std::optional<bool> {
45-
auto emptyOp = op.getDest().getDefiningOp<tensor::EmptyOp>();
46-
if (!emptyOp){
47-
return true;
48-
}
49-
50-
// Require statically known slice sizes that exactly match the
51-
// source tensor shape.
52-
auto srcType = dyn_cast<RankedTensorType>(op.getSource().getType());
53-
if (!srcType)
54-
return true;
55-
56-
// into rock.expand_strides, but only in the exact expand-strides shape:
57-
// - dest is a tensor.empty with a single use
58-
// - all offsets are zero
59-
// - all strides are one
60-
// - all slice sizes are static and match the source tensor shape
61-
bool isExpandStride = llvm::all_of(op.getStaticOffsets(), [](int64_t offset) { return offset == 0; }) &&
62-
llvm::all_of(op.getStaticStrides(), [](int64_t stride) { return stride == 1; }) &&
63-
llvm::none_of(op.getStaticSizes(),
64-
[](int64_t s) { return s == ShapedType::kDynamic; }) && op.getStaticSizes() == srcType.getShape();
65-
return !isExpandStride;
45+
return !rock::isRockExpandStride(op);
6646
});
6747

6848
// We only allow Linalg operations that are elementwise. Fusion is supported

0 commit comments

Comments
 (0)