Commit 416eff3

[Cpp API Compatibility] Align some other APIs (#78837)
1 parent 16239ff commit 416eff3

12 files changed

Lines changed: 191 additions & 179 deletions

paddle/phi/api/include/compat/ATen/ops/_values.h

Lines changed: 1 addition & 6 deletions
@@ -34,12 +34,7 @@ inline at::Tensor Tensor::_values() const {
     return paddle::Tensor(
         std::make_shared<phi::DenseTensor>(sparse_coo_tensor->values()));
   } else {
-    auto sparse_csr_tensor =
-        std::dynamic_pointer_cast<phi::SparseCsrTensor>(tensor_.impl());
-    PD_CHECK(sparse_csr_tensor != nullptr,
-             "_values: failed to cast tensor impl to SparseCsrTensor");
-    return paddle::Tensor(
-        std::make_shared<phi::DenseTensor>(sparse_csr_tensor->values()));
+    PD_THROW("_values is not implemented for SparseCsr tensors");
   }
 }
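
A hypothetical call-site sketch of the narrowed contract; the construction overload and tensor contents are assumptions, not part of the diff:

```cpp
#include <ATen/ATen.h>  // assumed umbrella header for the compat layer

void values_example() {
  // Illustrative construction; the (indices, values, options) overload
  // is an assumption.
  at::Tensor indices = at::zeros({2, 3}, at::kLong);
  at::Tensor vals = at::zeros({3}, at::kFloat);
  at::Tensor coo = at::sparse_coo_tensor(indices, vals, at::TensorOptions());

  at::Tensor dense = coo._values();  // sparse COO: returns the values buffer

  // A sparse CSR tensor now throws here instead of returning its values:
  //   csr._values();  // "_values is not implemented for SparseCsr tensors"
}
```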

paddle/phi/api/include/compat/ATen/ops/chunk.h

Lines changed: 33 additions & 23 deletions
@@ -25,38 +25,48 @@ namespace at {
 inline std::vector<Tensor> chunk(const Tensor& self,
                                  int64_t chunks,
                                  int64_t dim = 0) {
+  if (chunks <= 0) {
+    PD_THROW("chunk expects chunks to be greater than 0, got ", chunks);
+  }
+
   std::vector<Tensor> result;
   paddle::Tensor pd_tensor = self._PD_GetInner();
-  int64_t dim_size = pd_tensor.dims().size() > 0 ? pd_tensor.dims()[dim] : 1;
 
-  // PyTorch returns exactly 'chunks' number of tensors, even if some are empty
-  // When chunks > dim_size, it returns dim_size non-empty tensors plus
-  // (chunks - dim_size) empty tensors
-  if (chunks > dim_size) {
-    // First create non-empty chunks for existing elements
-    for (int64_t i = 0; i < dim_size; ++i) {
+  int64_t rank = static_cast<int64_t>(pd_tensor.dims().size());
+  if (rank == 0) {
+    PD_THROW("chunk expects at least a 1-dimensional tensor");
+  }
+
+  int64_t original_dim = dim;
+  if (dim < 0) {
+    dim += rank;
+  }
+  if (dim < 0 || dim >= rank) {
+    PD_THROW("Dimension out of range (expected to be in range of [",
+             -rank,
+             ", ",
+             rank - 1,
+             "], but got ",
+             original_dim,
+             ")");
+  }
+
+  int64_t dim_size = pd_tensor.dims()[dim];
+
+  if (dim_size == 0) {
+    for (int64_t i = 0; i < chunks; ++i) {
       auto chunk_tensor =
-          paddle::experimental::slice(pd_tensor, {dim}, {i}, {i + 1}, {1}, {});
+          paddle::experimental::slice(pd_tensor, {dim}, {0}, {0}, {1}, {});
       result.push_back(Tensor(chunk_tensor));
     }
-    // Then add empty chunks
-    for (int64_t i = dim_size; i < chunks; ++i) {
-      // Create empty tensor with same shape except for the chunk dimension
-      std::vector<int64_t> empty_shape;
-      for (int64_t j = 0; j < pd_tensor.dims().size(); ++j) {
-        if (j == dim) {
-          empty_shape.push_back(0);
-        } else {
-          empty_shape.push_back(pd_tensor.dims()[j]);
-        }
-      }
-      auto empty_tensor = paddle::experimental::empty(
-          phi::IntArray(empty_shape), pd_tensor.dtype(), pd_tensor.place());
-      result.push_back(Tensor(empty_tensor));
-    }
     return result;
   }
 
+  // PyTorch returns at most 'dim_size' non-empty chunks when chunks > dim_size
+  if (chunks > dim_size) {
+    chunks = dim_size;
+  }
+
   int64_t chunk_size = (dim_size + chunks - 1) / chunks;
   int64_t remaining = dim_size;
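
Taken together, the new checks align chunk with PyTorch: invalid chunks/dim values throw, negative dims normalize, and chunks is clamped to dim_size. A usage sketch; the shapes mirror the tests in ATen_chunk_test.cc below:

```cpp
#include <ATen/ATen.h>  // assumed umbrella header for the compat layer

void chunk_example() {
  at::Tensor t = at::arange(12, at::kFloat).reshape({3, 4});

  auto a = t.chunk(5, 0);   // dim 0 has size 3, so only 3 chunks come back
  auto b = t.chunk(2, -1);  // -1 normalizes to dim 1: two {3, 2} chunks

  // Each of these now throws via PD_THROW:
  //   t.chunk(0, 0);                          // chunks must be > 0
  //   t.chunk(2, 2);                          // dim out of range for rank 2
  //   at::empty({}, at::kFloat).chunk(2, 0);  // 0-dim tensor
}
```

Note that the chunk width itself is ceil(dim_size / chunks), computed as (dim_size + chunks - 1) / chunks, so t.chunk(2, -1) above yields two {3, 2} halves.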

paddle/phi/api/include/compat/ATen/ops/expand.h

Lines changed: 29 additions & 75 deletions
@@ -39,42 +39,6 @@ inline Tensor expand(const Tensor& self,
   auto input_dims = pd_tensor.dims();
   auto input_rank = static_cast<size_t>(input_dims.size());
 
-  auto tile_and_slice_to_target =
-      [&](const paddle::Tensor& input,
-          const std::vector<int64_t>& input_shape,
-          const std::vector<int64_t>& target_shape) -> Tensor {
-    size_t rank = target_shape.size();
-    std::vector<int64_t> repeat_times(rank, 1);
-    for (size_t i = 0; i < rank; ++i) {
-      auto in_size = input_shape[i];
-      auto target_size = target_shape[i];
-
-      if (in_size == 0 || target_size == 0) {
-        repeat_times[i] = 0;
-      } else if (target_size <= in_size) {
-        repeat_times[i] = 1;
-      } else {
-        repeat_times[i] = (target_size + in_size - 1) / in_size;
-      }
-    }
-
-    paddle::Tensor tiled =
-        paddle::experimental::tile(input, phi::IntArray(repeat_times));
-
-    std::vector<int64_t> axes(rank);
-    std::vector<int64_t> starts(rank, 0);
-    std::vector<int64_t> ends(rank);
-    std::vector<int64_t> strides(rank, 1);
-    for (size_t i = 0; i < rank; ++i) {
-      axes[i] = static_cast<int64_t>(i);
-      ends[i] = target_shape[i];
-    }
-
-    paddle::Tensor sliced =
-        paddle::experimental::slice(tiled, axes, starts, ends, strides, {});
-    return Tensor(sliced);
-  };
-
   // PyTorch's expand uses right-alignment semantics:
   // - For 1D tensor expand to 2D: {3}.expand({3,4}) treats input as {3,1},
   //   expands to {3,4}
@@ -86,26 +50,24 @@ inline Tensor expand(const Tensor& self,
   //   then expand: dim 0: 3 stays 3, dim 1: 1 -> 4 -> result {3, 4}
 
   if (input_rank < target_rank) {
-    // Add trailing 1s to right-align with target shape (PyTorch behavior)
-    // Input {3}, target {3, 4} -> reshape to {3, 1}
-    std::vector<int64_t> reshape_vec(input_rank, 1);
+    // Add leading 1s to right-align with target shape (PyTorch behavior)
+    // Input {1, 2}, target {2, 3, 2} -> reshape to {1, 1, 2}
+    std::vector<int64_t> reshape_vec(target_rank, 1);
     for (size_t i = 0; i < input_rank; ++i) {
-      reshape_vec[i] = input_dims[i];
-    }
-    // Add trailing 1s
-    while (reshape_vec.size() < target_rank) {
-      reshape_vec.push_back(1);
+      reshape_vec[target_rank - input_rank + i] = input_dims[i];
     }
 
     // Check if Paddle's expand can handle this right-aligned shape
     // Paddle allows: input[i] == 1 (can expand), or input[i] == target[i]
     // (match)
     bool can_use_paddle_expand = true;
+    size_t fail_dim = 0;
     for (size_t i = 0; i < target_rank; ++i) {
       bool dim_can_expand = (reshape_vec[i] == 1);
       bool dim_is_matching = (reshape_vec[i] == target_size_vec[i]);
       if (!dim_can_expand && !dim_is_matching) {
         can_use_paddle_expand = false;
+        fail_dim = i;
         break;
       }
     }
@@ -119,18 +81,23 @@ inline Tensor expand(const Tensor& self,
       return Tensor(result);
     }
 
-    // If Paddle's expand can't handle it, use tile + slice as fallback
-    paddle::Tensor reshaped =
-        paddle::experimental::reshape(pd_tensor, phi::IntArray(reshape_vec));
-    return tile_and_slice_to_target(reshaped, reshape_vec, target_size_vec);
+    PD_THROW("expand(): the expanded size of the tensor (",
+             target_size_vec[fail_dim],
+             ") must match the existing size (",
+             reshape_vec[fail_dim],
+             ") at non-singleton dimension ",
+             fail_dim,
+             ".");
   } else if (input_rank == target_rank) {
-    // Same rank - check if we can use expand directly or need tile
+    // Same rank - check if we can use expand directly
     bool can_use_paddle_expand = true;
+    size_t fail_dim = 0;
     for (size_t i = 0; i < target_rank; ++i) {
       auto in_size = input_dims[i];
       auto target_size = target_size_vec[i];
       if (in_size != 1 && in_size != target_size) {
         can_use_paddle_expand = false;
+        fail_dim = i;
         break;
       }
     }
@@ -141,33 +108,20 @@ inline Tensor expand(const Tensor& self,
       return Tensor(result);
     }
 
-    // Need tile + slice fallback
-    std::vector<int64_t> input_shape(target_rank);
-    for (size_t i = 0; i < target_rank; ++i) {
-      input_shape[i] = input_dims[i];
-    }
-    return tile_and_slice_to_target(pd_tensor, input_shape, target_size_vec);
+    PD_THROW("expand(): the expanded size of the tensor (",
+             target_size_vec[fail_dim],
+             ") must match the existing size (",
+             input_dims[fail_dim],
+             ") at non-singleton dimension ",
+             fail_dim,
+             ".");
   } else {
-    // Input has more dimensions.
-    // Keep the trailing target_rank dimensions and slice leading dimensions to
-    // 1 before reshape, so total element count remains valid.
-    paddle::Tensor squeezed = pd_tensor;
-    size_t leading_dims = input_rank - target_rank;
-    for (size_t i = 0; i < leading_dims; ++i) {
-      squeezed = paddle::experimental::slice(
-          squeezed, {static_cast<int64_t>(i)}, {0}, {1}, {1}, {});
-    }
-
-    std::vector<int64_t> new_shape(target_rank);
-    for (size_t i = 0; i < target_rank; ++i) {
-      new_shape[i] = input_dims[i + (input_rank - target_rank)];
-    }
-
-    // Reshape to target rank, then reuse the same expand implementation.
-    paddle::Tensor reshaped =
-        paddle::experimental::reshape(squeezed, phi::IntArray(new_shape));
-
-    return expand(Tensor(reshaped), size, implicit);
+    PD_THROW("expand(): the number of sizes provided (",
+             target_rank,
+             ") must be greater or equal to the number of dimensions in the "
+             "tensor (",
+             input_rank,
+             ").");
   }
 }
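
With the tile + slice fallback removed, shape mismatches now fail loudly instead of silently materializing data. A sketch of the aligned semantics, using the free function this header defines; the braced size lists assume an IntArrayRef-style parameter:

```cpp
#include <ATen/ATen.h>  // assumed umbrella header for the compat layer

void expand_example() {
  at::Tensor t = at::zeros({3, 1}, at::kFloat);

  auto a = at::expand(t, {3, 4});     // singleton dim 1 broadcasts: {3, 4}
  auto b = at::expand(t, {2, 3, 4});  // right-aligned as {1, 3, 1}: {2, 3, 4}

  // Both of these now throw instead of falling back to tile + slice:
  //   at::expand(t, {4, 4});  // 4 != existing non-singleton size 3 at dim 0
  //   at::expand(t, {4});     // fewer sizes (1) than tensor dims (2)
}
```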

paddle/phi/api/include/compat/ATen/ops/index.h

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ namespace at {
 inline at::Tensor index(const at::Tensor& self,
                         ArrayRef<at::indexing::TensorIndex> indices) {
   if (indices.size() == 0) {
-    return self;
+    PD_THROW("index() cannot be called with an empty index list");
   }
 
   bool has_slice = false;
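
A brief sketch of the consequence for callers, assuming an empty braced list binds to an empty ArrayRef<TensorIndex>:

```cpp
#include <ATen/ATen.h>  // assumed umbrella header for the compat layer

void index_example() {
  at::Tensor t = at::zeros({3, 4}, at::kFloat);
  try {
    at::index(t, {});  // previously a no-op that returned `t`
  } catch (const std::exception& e) {
    // now throws: "index() cannot be called with an empty index list"
  }
}
```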

paddle/phi/api/include/compat/ATen/ops/sparse_coo_tensor.h

Lines changed: 3 additions & 6 deletions
@@ -49,12 +49,9 @@ inline at::Tensor sparse_coo_tensor(const at::Tensor& indices,
   paddle::Tensor idx = indices._PD_GetInner();
   paddle::Tensor vals = values._PD_GetInner();
 
-  if (options.dtype_opt().has_value() &&
-      options.dtype_opt().value() != values.scalar_type()) {
-    vals = paddle::experimental::cast(
-        vals,
-        compat::_PD_AtenScalarTypeToPhiDataType(options.dtype_opt().value()));
-  }
+  // PyTorch ignores dtype mismatch between values and TensorOptions in
+  // sparse_coo_tensor; the resulting sparse tensor uses values' original dtype.
+  // Do not cast or throw here.
 
   if (options.pinned_memory()) {
     phi::Place base_place = options._PD_GetPlace();
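
A sketch of the dtype rule the new comment describes; the constructor overload and TensorOptions spelling are assumptions:

```cpp
#include <ATen/ATen.h>  // assumed umbrella header for the compat layer

void coo_dtype_example() {
  at::Tensor indices = at::zeros({2, 3}, at::kLong);
  at::Tensor vals = at::zeros({3}, at::kFloat);

  at::Tensor sp = at::sparse_coo_tensor(
      indices, vals, at::TensorOptions().dtype(at::kDouble));
  // sp keeps kFloat values: the requested kDouble is ignored rather than
  // cast, matching PyTorch.
}
```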

paddle/phi/api/include/compat/ATen/ops/sparse_csr_tensor.h

Lines changed: 3 additions & 6 deletions
@@ -36,12 +36,9 @@ inline at::Tensor sparse_csr_tensor(const at::Tensor& crow_indices,
   paddle::Tensor cols = col_indices._PD_GetInner();
   paddle::Tensor vals = values._PD_GetInner();
 
-  if (options.dtype_opt().has_value() &&
-      options.dtype_opt().value() != values.scalar_type()) {
-    vals = paddle::experimental::cast(
-        vals,
-        compat::_PD_AtenScalarTypeToPhiDataType(options.dtype_opt().value()));
-  }
+  // PyTorch ignores dtype mismatch between values and TensorOptions in
+  // sparse_csr_tensor; the resulting sparse tensor uses values' original dtype.
+  // Do not cast or throw here.
 
   if (options.pinned_memory()) {
     phi::Place base_place = options._PD_GetPlace();
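
The same rule, sketched for the CSR constructor under the same assumptions; the index construction is illustrative:

```cpp
#include <ATen/ATen.h>  // assumed umbrella header for the compat layer

void csr_dtype_example() {
  at::Tensor crow = at::arange(4, at::kLong);  // row pointers {0, 1, 2, 3}
  at::Tensor col = at::arange(3, at::kLong);   // column indices {0, 1, 2}
  at::Tensor vals = at::zeros({3}, at::kFloat);

  at::Tensor sp = at::sparse_csr_tensor(
      crow, col, vals, at::TensorOptions().dtype(at::kDouble));
  // vals stay kFloat; no cast is performed.
}
```
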
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+// Placeholder header to satisfy PyTorch compatibility checks.
+// Paddle does not use the same CUDA cmake macros as PyTorch,
+// but the presence of this file allows downstream code to use
+// __has_include(<c10/cuda/impl/cuda_cmake_macros.h>) for feature detection.
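
What the feature detection looks like in downstream code; the COMPAT_HAS_CUDA_CMAKE_MACROS flag is a hypothetical name for illustration:

```cpp
// C++17 feature detection enabled by the placeholder header above.
#if __has_include(<c10/cuda/impl/cuda_cmake_macros.h>)
#include <c10/cuda/impl/cuda_cmake_macros.h>
#define COMPAT_HAS_CUDA_CMAKE_MACROS 1  // hypothetical downstream flag
#else
#define COMPAT_HAS_CUDA_CMAKE_MACROS 0
#endif
```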

test/cpp/compat/ATen_chunk_test.cc

Lines changed: 46 additions & 1 deletion
@@ -58,7 +58,8 @@ TEST(TensorChunkTest, ChunkMoreChunksThanSize) {
 
   std::vector<at::Tensor> chunks = t.chunk(5, 0);
 
-  ASSERT_EQ(chunks.size(), 5);
+  // PyTorch returns at most dim_size non-empty chunks when chunks > dim_size
+  ASSERT_EQ(chunks.size(), 2);
 }
 
 TEST(TensorChunkTest, ChunkDefaultDim) {
@@ -78,3 +79,47 @@ TEST(TensorChunkTest, ChunkIntType) {
   ASSERT_EQ(chunks.size(), 3);
   ASSERT_EQ(chunks[0].dtype(), at::kInt);
 }
+
+TEST(TensorChunkTest, ChunkZeroDim) {
+  at::Tensor t = at::zeros({0, 4}, at::kFloat);
+
+  std::vector<at::Tensor> chunks = t.chunk(2, 0);
+
+  // PyTorch returns 'chunks' number of empty tensors when dim_size == 0
+  ASSERT_EQ(chunks.size(), 2);
+  ASSERT_EQ(chunks[0].size(0), 0);
+  ASSERT_EQ(chunks[1].size(0), 0);
+}
+
+TEST(TensorChunkTest, ChunkNegativeDim) {
+  at::Tensor t = at::arange(12, at::kFloat).reshape({3, 4});
+
+  // chunk(-1) should be equivalent to chunk(rank - 1) = chunk(1)
+  std::vector<at::Tensor> chunks_neg = t.chunk(2, -1);
+  std::vector<at::Tensor> chunks_pos = t.chunk(2, 1);
+
+  ASSERT_EQ(chunks_neg.size(), chunks_pos.size());
+  for (size_t i = 0; i < chunks_neg.size(); ++i) {
+    ASSERT_EQ(chunks_neg[i].sizes(), chunks_pos[i].sizes());
+  }
+}
+
+TEST(TensorChunkTest, ChunkOutOfRangeDim) {
+  at::Tensor t = at::arange(12, at::kFloat).reshape({3, 4});
+
+  ASSERT_THROW(t.chunk(2, 2), std::exception);   // dim >= rank
+  ASSERT_THROW(t.chunk(2, -3), std::exception);  // dim < -rank
+}
+
+TEST(TensorChunkTest, ChunkZeroRankTensor) {
+  at::Tensor t = at::empty({}, at::kFloat);  // 0-dim scalar tensor
+
+  ASSERT_THROW(t.chunk(2, 0), std::exception);
+}
+
+TEST(TensorChunkTest, ChunkZeroChunks) {
+  at::Tensor t = at::arange(12, at::kFloat).reshape({3, 4});
+
+  ASSERT_THROW(t.chunk(0, 0), std::exception);
+  ASSERT_THROW(t.chunk(-1, 0), std::exception);
+}
