@@ -39,42 +39,6 @@ inline Tensor expand(const Tensor& self,
3939 auto input_dims = pd_tensor.dims ();
4040 auto input_rank = static_cast <size_t >(input_dims.size ());
4141
42- auto tile_and_slice_to_target =
43- [&](const paddle::Tensor& input,
44- const std::vector<int64_t >& input_shape,
45- const std::vector<int64_t >& target_shape) -> Tensor {
46- size_t rank = target_shape.size ();
47- std::vector<int64_t > repeat_times (rank, 1 );
48- for (size_t i = 0 ; i < rank; ++i) {
49- auto in_size = input_shape[i];
50- auto target_size = target_shape[i];
51-
52- if (in_size == 0 || target_size == 0 ) {
53- repeat_times[i] = 0 ;
54- } else if (target_size <= in_size) {
55- repeat_times[i] = 1 ;
56- } else {
57- repeat_times[i] = (target_size + in_size - 1 ) / in_size;
58- }
59- }
60-
61- paddle::Tensor tiled =
62- paddle::experimental::tile (input, phi::IntArray (repeat_times));
63-
64- std::vector<int64_t > axes (rank);
65- std::vector<int64_t > starts (rank, 0 );
66- std::vector<int64_t > ends (rank);
67- std::vector<int64_t > strides (rank, 1 );
68- for (size_t i = 0 ; i < rank; ++i) {
69- axes[i] = static_cast <int64_t >(i);
70- ends[i] = target_shape[i];
71- }
72-
73- paddle::Tensor sliced =
74- paddle::experimental::slice (tiled, axes, starts, ends, strides, {});
75- return Tensor (sliced);
76- };
77-
7842 // PyTorch's expand uses right-alignment semantics:
7943 // - For 1D tensor expand to 2D: {3}.expand({3,4}) treats input as {3,1},
8044 // expands to {3,4}
@@ -86,26 +50,24 @@ inline Tensor expand(const Tensor& self,
8650 // then expand: dim 0: 3 stays 3, dim 1: 1 -> 4 -> result {3, 4}
8751
8852 if (input_rank < target_rank) {
89- // Add trailing 1s to right-align with target shape (PyTorch behavior)
90- // Input {3 }, target {3, 4 } -> reshape to {3 , 1}
91- std::vector<int64_t > reshape_vec (input_rank , 1 );
53+ // Add leading 1s to right-align with target shape (PyTorch behavior)
54+ // Input {1, 2}, target {2, 3, 2} -> reshape to {1, 1, 2}
55+ std::vector<int64_t> reshape_vec(target_rank, 1);
9256 for (size_t i = 0 ; i < input_rank; ++i) {
93- reshape_vec[i] = input_dims[i];
94- }
95- // Add trailing 1s
96- while (reshape_vec.size () < target_rank) {
97- reshape_vec.push_back (1 );
57+ reshape_vec[target_rank - input_rank + i] = input_dims[i];
9858 }
9959
10060 // Check if Paddle's expand can handle this right-aligned shape
10161 // Paddle allows: input[i] == 1 (can expand), or input[i] == target[i]
10262 // (match)
10363 bool can_use_paddle_expand = true ;
64+ size_t fail_dim = 0 ;
10465 for (size_t i = 0 ; i < target_rank; ++i) {
10566 bool dim_can_expand = (reshape_vec[i] == 1 );
10667 bool dim_is_matching = (reshape_vec[i] == target_size_vec[i]);
10768 if (!dim_can_expand && !dim_is_matching) {
10869 can_use_paddle_expand = false ;
70+ fail_dim = i;
10971 break ;
11072 }
11173 }
@@ -119,18 +81,23 @@ inline Tensor expand(const Tensor& self,
11981 return Tensor (result);
12082 }
12183
122- // If Paddle's expand can't handle it, use tile + slice as fallback
123- paddle::Tensor reshaped =
124- paddle::experimental::reshape (pd_tensor, phi::IntArray (reshape_vec));
125- return tile_and_slice_to_target (reshaped, reshape_vec, target_size_vec);
84+ PD_THROW("expand(): the expanded size of the tensor (",
85+ target_size_vec[fail_dim],
86+ ") must match the existing size (",
87+ reshape_vec[fail_dim],
88+ ") at non-singleton dimension ",
89+ fail_dim,
90+ ".");
12691 } else if (input_rank == target_rank) {
127- // Same rank - check if we can use expand directly or need tile
92+ // Same rank - check if we can use expand directly
12893 bool can_use_paddle_expand = true ;
94+ size_t fail_dim = 0 ;
12995 for (size_t i = 0 ; i < target_rank; ++i) {
13096 auto in_size = input_dims[i];
13197 auto target_size = target_size_vec[i];
13298 if (in_size != 1 && in_size != target_size) {
13399 can_use_paddle_expand = false ;
100+ fail_dim = i;
134101 break ;
135102 }
136103 }
@@ -141,33 +108,20 @@ inline Tensor expand(const Tensor& self,
141108 return Tensor (result);
142109 }
143110
144- // Need tile + slice fallback
145- std::vector<int64_t > input_shape (target_rank);
146- for (size_t i = 0 ; i < target_rank; ++i) {
147- input_shape[i] = input_dims[i];
148- }
149- return tile_and_slice_to_target (pd_tensor, input_shape, target_size_vec);
111+ PD_THROW("expand(): the expanded size of the tensor (",
112+ target_size_vec[fail_dim],
113+ ") must match the existing size (",
114+ input_dims[fail_dim],
115+ ") at non-singleton dimension ",
116+ fail_dim,
117+ ".");
150118 } else {
151- // Input has more dimensions.
152- // Keep the trailing target_rank dimensions and slice leading dimensions to
153- // 1 before reshape, so total element count remains valid.
154- paddle::Tensor squeezed = pd_tensor;
155- size_t leading_dims = input_rank - target_rank;
156- for (size_t i = 0 ; i < leading_dims; ++i) {
157- squeezed = paddle::experimental::slice (
158- squeezed, {static_cast <int64_t >(i)}, {0 }, {1 }, {1 }, {});
159- }
160-
161- std::vector<int64_t > new_shape (target_rank);
162- for (size_t i = 0 ; i < target_rank; ++i) {
163- new_shape[i] = input_dims[i + (input_rank - target_rank)];
164- }
165-
166- // Reshape to target rank, then reuse the same expand implementation.
167- paddle::Tensor reshaped =
168- paddle::experimental::reshape (squeezed, phi::IntArray (new_shape));
169-
170- return expand (Tensor (reshaped), size, implicit);
119+ PD_THROW("expand(): the number of sizes provided (",
120+ target_rank,
121+ ") must be greater or equal to the number of dimensions in the "
122+ "tensor (",
123+ input_rank,
124+ ").");
171125 }
172126}
173127
0 commit comments