Skip to content

Commit 8b142a9

Browse files
committed
Address some comments
1 parent 456a081 commit 8b142a9

File tree

3 files changed

+43
-5
lines changed

3 files changed

+43
-5
lines changed

mlir/lib/Conversion/LinalgToRock/LinalgToRock.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,27 @@ LogicalResult ExpandStrideConverter::matchAndRewrite(
169169
return failure();
170170
}
171171

172+
// Require statically known slice sizes that exactly match the
173+
// source tensor shape.
174+
auto srcType = dyn_cast<RankedTensorType>(op.getSource().getType());
175+
if (!srcType){
176+
return failure();
177+
}
178+
179+
// Convert the tensor.insert_slice into rock.expand_strides, but only in the exact expand-strides shape:
180+
// - dest is a tensor.empty with a single use
181+
// - all offsets are zero
182+
// - all strides are one
183+
// - all slice sizes are static and match the source tensor shape
184+
bool isExpandStride = llvm::all_of(op.getStaticOffsets(), [](int64_t offset) { return offset == 0; }) &&
185+
llvm::all_of(op.getStaticStrides(), [](int64_t stride) { return stride == 1; }) &&
186+
llvm::none_of(op.getStaticSizes(),
187+
[](int64_t s) { return s == ShapedType::kDynamic; }) && op.getStaticSizes() == srcType.getShape();
188+
189+
if(!isExpandStride){
190+
return failure();
191+
}
192+
172193
Location loc = op.getLoc();
173194
auto alloc = bufferization::AllocTensorOp::create(
174195
rewriter, loc, tensorEmpty.getResult().getType(), {});

mlir/lib/Conversion/LinalgToRock/LinalgToRockPass.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,26 @@ static void populateLinalgToRockDialectConversion(ConversionTarget &target) {
4343
target.addDynamicallyLegalOp<tensor::InsertSliceOp>(
4444
[](tensor::InsertSliceOp op) -> std::optional<bool> {
4545
auto emptyOp = op.getDest().getDefiningOp<tensor::EmptyOp>();
46-
return !(emptyOp && emptyOp->hasOneUse());
46+
if (!emptyOp){
47+
return true;
48+
}
49+
50+
// Require statically known slice sizes that exactly match the
51+
// source tensor shape.
52+
auto srcType = dyn_cast<RankedTensorType>(op.getSource().getType());
53+
if (!srcType)
54+
return true;
55+
56+
// Convert the tensor.insert_slice into rock.expand_strides, but only in the exact expand-strides shape:
57+
// - dest is a tensor.empty with a single use
58+
// - all offsets are zero
59+
// - all strides are one
60+
// - all slice sizes are static and match the source tensor shape
61+
bool isExpandStride = llvm::all_of(op.getStaticOffsets(), [](int64_t offset) { return offset == 0; }) &&
62+
llvm::all_of(op.getStaticStrides(), [](int64_t stride) { return stride == 1; }) &&
63+
llvm::none_of(op.getStaticSizes(),
64+
[](int64_t s) { return s == ShapedType::kDynamic; }) && op.getStaticSizes() == srcType.getShape();
65+
return !isExpandStride;
4766
});
4867

4968
// We only allow Linalg operations that are elementwise. Fusion is supported

mlir/test/Conversion/MIGraphXToLinalg/mixr-to-linalg-ops.mlir

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,12 @@ func.func @clip_i32(%arg0: !migraphx.shaped<64x64xi32, 64x1>, %arg1: !migraphx.s
144144
// Test taken from migraphx-to-tosa
145145
// Tests for non-standard shapes.
146146

147-
148-
149147
// CHECK-LABEL: func.func @transposed(
150148
func.func @transposed(%arg0: !migraphx.shaped<4x3xf32, 1x4>) -> !migraphx.shaped<4x3xf32, 1x4> {
151149
%op = migraphx.floor %arg0 : <4x3xf32, 1x4> -> <4x3xf32, 1x4>
152150
// CHECK: %[[FLOOR:.*]] = linalg.floor ins{{.*}} -> tensor<4x3xf32>
153151
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<3x4xf32>
154-
// CHECK: linalg.transpose ins(%[[FLOOR]] : tensor{{.*}}) outs(%3 : tensor{{.*}}) permutation = [1, 0]
152+
// CHECK: linalg.transpose ins(%[[FLOOR]] : tensor{{.*}}) outs(%[[EMPTY]] : tensor{{.*}}) permutation = [1, 0]
155153
func.return %op : !migraphx.shaped<4x3xf32, 1x4>
156154
}
157155

@@ -163,7 +161,7 @@ func.func @broadcast(%arg0: !migraphx.shaped<4x3xf32, 1x0>, %arg1: !migraphx.sha
163161
// CHECK-DAG: %[[collapsed:.*]] = tensor.collapse_shape %[[expanded_0]]
164162
// CHECK-DAG: %[[expanded_1:.*]] = tensor.expand_shape %[[collapsed]]
165163
// CHECK-DAG: %[[zero:.*]] = tensor.empty() : tensor<4x3xf32>
166-
// CHECK-DAG: %[[broadcasted:.*]] = linalg.broadcast ins(%[[expanded_1]] : tensor<4xf32>) outs(%[[zero]] : tensor<4x3xf32>) dimensions = [1]
164+
// CHECK-DAG: %[[broadcasted:.*]] = linalg.broadcast ins(%[[expanded_1]] : tensor<4xf32>) outs(%[[zero]] : tensor<4x3xf32>) dimensions = [1]
167165
%op = migraphx.sub %arg0, %arg1 : <4x3xf32, 1x0>, <4x3xf32, 3x1> -> <4x3xf32, 3x1>
168166
// CHECK-DAG: %[[one:.*]] = tensor.empty() : tensor<4x3xf32>
169167
// CHECK-DAG: %[[two:.*]] = linalg.sub ins(%[[broadcasted]], %[[expanded]] : tensor<4x3xf32>, tensor<4x3xf32>) outs(%[[one]] : tensor<4x3xf32>) -> tensor<4x3xf32>

0 commit comments

Comments
 (0)