@@ -156,39 +156,38 @@ struct ExpandStrideConverter final
156156};
157157} // namespace
158158
159- LogicalResult ExpandStrideConverter::matchAndRewrite (
160- tensor::InsertSliceOp op, OpAdaptor adaptor,
161- ConversionPatternRewriter &rewriter) const {
162- // / The linalg-to-rock passes emits the following expression
163- // / for expanding the strides. We are matching the following IR
164- // / %empty = tensor.empty() : ....
165- // / %inserted_slice = tensor.insert_slice %actual_data into %empty ...
166- auto tensorEmpty =
167- dyn_cast<tensor::EmptyOp>(op.getOperand (1 ).getDefiningOp ());
168- if (!tensorEmpty || !tensorEmpty->hasOneUse ()) {
169- return failure ();
159+ bool mlir::rock::isRockExpandStride (tensor::InsertSliceOp op){
160+ auto emptyOp = op.getDest ().getDefiningOp <tensor::EmptyOp>();
161+ if (!emptyOp){
162+ return false ;
170163 }
171164
172165 // Require statically known slice sizes that exactly match the
173166 // source tensor shape.
174167 auto srcType = dyn_cast<RankedTensorType>(op.getSource ().getType ());
175- if (!srcType){
176- return failure ();
177- }
168+ if (!srcType)
169+ return false ;
178170
179- // into rock.expand_strides, but only in the exact expand-strides shape:
180- // - dest is a tensor.empty with a single use
181- // - all offsets are zero
182- // - all strides are one
183- // - all slice sizes are static and match the source tensor shape
184171 bool isExpandStride = llvm::all_of (op.getStaticOffsets (), [](int64_t offset) { return offset == 0 ; }) &&
185172 llvm::all_of (op.getStaticStrides (), [](int64_t stride) { return stride == 1 ; }) &&
186173 llvm::none_of (op.getStaticSizes (),
187174 [](int64_t s) { return s == ShapedType::kDynamic ; }) && op.getStaticSizes () == srcType.getShape ();
175+ return isExpandStride;
176+ }
188177
189- if (!isExpandStride){
178+ LogicalResult ExpandStrideConverter::matchAndRewrite (
179+ tensor::InsertSliceOp op, OpAdaptor adaptor,
180+ ConversionPatternRewriter &rewriter) const {
181+ // / The linalg-to-rock passes emits the following expression
182+ // / for expanding the strides. We are matching the following IR
183+ // / %empty = tensor.empty() : ....
184+ // / %inserted_slice = tensor.insert_slice %actual_data into %empty ...
185+ if (!rock::isRockExpandStride (op)){
190186 return failure ();
191187 }
188+ auto tensorEmpty =
189+ dyn_cast<tensor::EmptyOp>(op.getOperand (1 ).getDefiningOp ());
190+ assert (tensorEmpty && " Should have been checked by isRockExpandStride" );
192191
193192 Location loc = op.getLoc ();
194193 auto alloc = bufferization::AllocTensorOp::create (
0 commit comments