@@ -53,60 +53,222 @@ struct AsLogicalShapeOpConverter final
5353};
5454} // namespace
5555
56- // / Checking to see if the permutation vector is like (0, 1, 2, 3, 4, 5, ...)
57- static bool isPermutationStandardForm (ArrayRef<int64_t > permutation) {
58- SmallVector<int64_t , 4 > increasingVec (permutation.size (), 0 );
59- std::iota (increasingVec.begin (), increasingVec.end (), 0 );
60- return llvm::equal (permutation, increasingVec);
61- }
62-
6356LogicalResult AsLogicalShapeOpConverter::matchAndRewrite (
6457 migraphx::AsLogicalShapeOp op, OpAdaptor adaptor,
6558 ConversionPatternRewriter &rewriter) const {
6659 Location loc = op.getLoc ();
6760 migraphx::MIXRShapedType inType = op.getIn ().getType ();
6861 RankedTensorType resultType = op.getOut ().getType ();
69- Value in = adaptor. getIn (); // The shape we are casting from
62+ RankedTensorType memoryType = inType. asMemoryLayoutTensor ();
7063
71- SmallVector<int64_t , 4 > permutation;
72- inType.getStridePermutation (permutation);
73- if (isPermutationStandardForm (permutation)) {
64+ // / Expand a flat/underlying value into the N-D memory layout tensor.
65+ auto expandToMemoryLayout = [&](Value input) -> Value {
66+ if (input.getType () == memoryType)
67+ return input;
7468 SmallVector<ReassociationIndices, 4 > reassociationIndex (
75- 1 , ReassociationIndices (resultType .getRank (), 0 ));
69+ 1 , ReassociationIndices (memoryType .getRank (), 0 ));
7670 std::iota (reassociationIndex[0 ].begin (), reassociationIndex[0 ].end (), 0 );
77- auto newShape = tensor::ExpandShapeOp::create (rewriter, loc, resultType, in,
78- reassociationIndex);
79- rewriter.replaceOp (op, newShape);
71+ return tensor::ExpandShapeOp::create (rewriter, loc, memoryType, input,
72+ reassociationIndex);
73+ };
74+
75+ // / Invert the stride permutation to transpose from memory order back to
76+ // / logical order.
77+ auto transposeToLogicalOrder = [&](Value input) -> Value {
78+ SmallVector<int64_t , 4 > inversePermutation;
79+ inType.getStridePermutation (inversePermutation);
80+ size_t nDims = inversePermutation.size ();
81+ bool hasTranspose =
82+ !llvm::equal (llvm::seq<int64_t >(nDims), inversePermutation);
83+ if (!hasTranspose)
84+ return input;
85+
86+ // Calculating the transposed shape and permutation
87+ SmallVector<int64_t , 4 > permutation, transposedShape;
88+ permutation.resize_for_overwrite (nDims);
89+ transposedShape.resize_for_overwrite (nDims);
90+ RankedTensorType inputType = cast<RankedTensorType>(input.getType ());
91+ for (auto [to, from] : llvm::enumerate (inversePermutation)) {
92+ permutation[from] = to;
93+ transposedShape[from] = inputType.getShape ()[to];
94+ }
95+
96+ Value init = tensor::EmptyOp::create (rewriter, loc, transposedShape,
97+ inputType.getElementType ())
98+ .getResult ();
99+ return linalg::TransposeOp::create (rewriter, loc, input, init, permutation)
100+ .getResult ()[0 ];
101+ };
102+
103+ // / Extract the logical slice when the memory layout is larger than the
104+ // / logical shape (broadcast dimensions are collapsed to size 1).
105+ auto tryExtractSlice = [&](Value input) -> Value {
106+ SmallVector<int64_t , 4 > slicingShape (resultType.getShape ());
107+ for (auto [dim, stride] :
108+ llvm::zip_equal (slicingShape, inType.getStrides ())) {
109+ if (stride == 0 )
110+ dim = 1 ;
111+ }
112+ RankedTensorType inputType = cast<RankedTensorType>(input.getType ());
113+ if (inputType.getShape () == ArrayRef (slicingShape)) {
114+ return input;
115+ }
116+
117+ assert (llvm::none_of (llvm::zip_equal (slicingShape, inputType.getShape ()),
118+ [](auto val) {
119+ auto [sliceDim, inputDim] = val;
120+ return sliceDim > inputDim;
121+ }) &&
122+ " this should have been checked by the verifier as the memory layout "
123+ " must be greater than the logical layout" );
124+
125+ RankedTensorType sliceType = resultType.clone (slicingShape);
126+ SmallVector<OpFoldResult, 4 > offset (sliceType.getRank (),
127+ rewriter.getIndexAttr (0 )),
128+ sizes;
129+ llvm::transform (sliceType.getShape (), std::back_inserter (sizes),
130+ [&](int64_t size) { return rewriter.getIndexAttr (size); });
131+ SmallVector<OpFoldResult, 4 > strides (sliceType.getRank (),
132+ rewriter.getIndexAttr (1 ));
133+ tensor::ExtractSliceOp extractOp = tensor::ExtractSliceOp::create (
134+ rewriter, loc, input, offset, sizes, strides);
135+ return extractOp.getResult ();
136+ };
137+
138+ // / Broadcast along dimensions whose stride is 0 to reach the full logical
139+ // / shape.
140+ auto tryBroadcast = [&](Value input) -> Value {
141+ if (input.getType () == resultType)
142+ return input;
143+ SmallVector<int64_t , 4 > linalgInputShape, broadcastDimensions;
144+ for (auto [index, stride, shape] :
145+ llvm::enumerate (inType.getStrides (), inType.getShape ())) {
146+ if (stride != 0 )
147+ linalgInputShape.push_back (shape);
148+ else
149+ broadcastDimensions.push_back (index);
150+ }
151+ SmallVector<ReassociationIndices, 4 > reassociationOne (
152+ 1 , ReassociationIndices (resultType.getRank (), 0 ));
153+ SmallVector<ReassociationIndices, 4 > reassociationTwo (
154+ 1 , ReassociationIndices (linalgInputShape.size (), 0 ));
155+ std::iota (reassociationOne[0 ].begin (), reassociationOne[0 ].end (), 0 );
156+ std::iota (reassociationTwo[0 ].begin (), reassociationTwo[0 ].end (), 0 );
157+ input =
158+ tensor::CollapseShapeOp::create (rewriter, loc, input, reassociationOne);
159+ input = tensor::ExpandShapeOp::create (
160+ rewriter, loc,
161+ RankedTensorType::get (linalgInputShape, resultType.getElementType ()),
162+ input, reassociationTwo);
163+ auto init = tensor::EmptyOp::create (rewriter, loc, resultType.getShape (),
164+ resultType.getElementType ());
165+ return linalg::BroadcastOp::create (rewriter, loc, input, init,
166+ broadcastDimensions)
167+ .getResult ()[0 ];
168+ };
169+
170+ Value result = expandToMemoryLayout (adaptor.getIn ());
171+ result = transposeToLogicalOrder (result);
172+
173+ if (result.getType () == resultType) {
174+ rewriter.replaceOp (op, result);
80175 return success ();
81176 }
82177
83- return op.emitError (
84- " input shape is non standard or broadcast; cannot convert this shape" );
178+ // handle long stride/broadcasting here
179+ result = tryExtractSlice (result);
180+ result = tryBroadcast (result);
181+
182+ rewriter.replaceOp (op, result);
183+ return success ();
85184}
86185
87186LogicalResult AsUnderlyingShapeConverter::matchAndRewrite (
88187 migraphx::AsUnderlyingShapeOp op, OpAdaptor adaptor,
89188 ConversionPatternRewriter &rewriter) const {
90189 Location loc = op.getLoc ();
190+ migraphx::MIXRShapedType resultType = op.getOut ().getType ();
91191 Value in = adaptor.getIn ();
92- migraphx::MIXRShapedType resultType = op.getResult ().getType ();
93- auto resultTensorType =
94- cast<RankedTensorType>(getTypeConverter ()->convertType (resultType));
192+ RankedTensorType memoryLayoutType = resultType.asMemoryLayoutTensor ();
193+ RankedTensorType inTensorType = cast<RankedTensorType>(in.getType ());
95194
96- SmallVector<int64_t , 4 > permutation;
97- resultType.getStridePermutation (permutation);
98- if (isPermutationStandardForm (permutation)) {
195+ RankedTensorType resultTensorType =
196+ dyn_cast<RankedTensorType>(getTypeConverter ()->convertType (resultType));
197+ if (!resultTensorType)
198+ return op.emitOpError (" unsupported conversion to underlying shape" );
199+
200+ if (inTensorType == resultTensorType) {
201+ rewriter.replaceOp (op, in);
202+ return success ();
203+ }
204+
205+ // / Transpose from logical order to memory layout order.
206+ auto transposeToMemoryOrder = [&](Value input) -> Value {
207+ SmallVector<int64_t , 4 > permutation;
208+ resultType.getStridePermutation (permutation);
209+ if (llvm::is_sorted (permutation))
210+ return input;
211+ RankedTensorType inputType = cast<RankedTensorType>(input.getType ());
212+ SmallVector<int64_t , 4 > transposedShape;
213+ llvm::transform (permutation, std::back_inserter (transposedShape),
214+ [&](int64_t p) { return inputType.getShape ()[p]; });
215+ auto init = tensor::EmptyOp::create (rewriter, loc, transposedShape,
216+ inputType.getElementType ())
217+ .getResult ();
218+ return linalg::TransposeOp::create (rewriter, loc, input, init, permutation)
219+ .getResult ()[0 ];
220+ };
221+
222+ // / Pad via insert_slice when the transposed shape is smaller than the
223+ // / memory layout (e.g. due to stride-based padding).
224+ auto tryInsertSlice = [&](Value input) -> FailureOr<Value> {
225+ if (input.getType () == memoryLayoutType)
226+ return input;
227+ if (resultType.hasBroadcast ())
228+ return op.emitOpError (
229+ " writing to tensors with broadcasts is unsupported" );
230+ RankedTensorType inputType = cast<RankedTensorType>(input.getType ());
231+ for (auto [index, memDim, inDim] :
232+ llvm::enumerate (memoryLayoutType.getShape (), inputType.getShape ())) {
233+ if (memDim < inDim) {
234+ return op.emitOpError (" memory layout dimension " )
235+ << memDim << " is smaller than logical dimension " << inDim
236+ << " ; this indicates invalid strides" ;
237+ }
238+ }
239+
240+ auto empty =
241+ tensor::EmptyOp::create (rewriter, loc, memoryLayoutType.getShape (),
242+ memoryLayoutType.getElementType ());
243+ int64_t rank = inputType.getRank ();
244+ SmallVector<OpFoldResult> offsets (rank, rewriter.getIndexAttr (0 ));
245+ SmallVector<OpFoldResult> sizes;
246+ for (int64_t dim : inputType.getShape ())
247+ sizes.push_back (rewriter.getIndexAttr (dim));
248+ SmallVector<OpFoldResult> strides (rank, rewriter.getIndexAttr (1 ));
249+ tensor::InsertSliceOp insertSlice = tensor::InsertSliceOp::create (
250+ rewriter, loc, input, empty, offsets, sizes, strides);
251+ insertSlice->setAttr (" rock.is_expand_strides" , rewriter.getUnitAttr ());
252+ return insertSlice.getResult ();
253+ };
254+
255+ // / Collapse the N-D memory layout tensor into the flat underlying shape.
256+ auto collapseToUnderlying = [&](Value input) -> Value {
257+ assert (input.getType () == memoryLayoutType &&
258+ " expected memory layout type before collapsing" );
99259 SmallVector<ReassociationIndices, 4 > reassociationIndex (
100260 1 , ReassociationIndices (resultType.getRank (), 0 ));
101261 std::iota (reassociationIndex[0 ].begin (), reassociationIndex[0 ].end (), 0 );
102- auto reshape = tensor::CollapseShapeOp::create (
103- rewriter, loc, resultTensorType, in, reassociationIndex);
104- rewriter.replaceOp (op, reshape);
105- return success ();
106- }
262+ return tensor::CollapseShapeOp::create (rewriter, loc, resultTensorType,
263+ input, reassociationIndex);
264+ };
107265
108- return op.emitError (
109- " input shape is non standard or broadcast; cannot convert this shape" );
266+ FailureOr<Value> result = tryInsertSlice (transposeToMemoryOrder (in));
267+ if (failed (result))
268+ return failure ();
269+
270+ rewriter.replaceOp (op, collapseToUnderlying (*result));
271+ return success ();
110272}
111273
112274namespace {
0 commit comments