Skip to content

Commit 42b33e4

Browse files
committed
modify code according to reviewer's opinion
1 parent 0013e71 commit 42b33e4

4 files changed

Lines changed: 38 additions & 57 deletions

File tree

paddle/fluid/eager/pylayer/py_layer_node.cc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -267,10 +267,4 @@ GradNodePyLayer::operator()(
267267

268268
return grad_out;
269269
}
270-
271-
void GradNodePyLayer::ClearTensorWrappers() {
272-
VLOG(6) << "Clearing PyLayer tensor wrappers";
273-
SetIsTensorWrappersCleared(true);
274-
}
275-
276270
} // namespace egr

paddle/fluid/eager/pylayer/py_layer_node.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
#include "paddle/fluid/eager/autograd_meta.h"
2020
#include "paddle/fluid/eager/grad_node_info.h"
2121
#include "paddle/fluid/eager/hooks.h"
22-
#include "paddle/fluid/eager/tensor_wrapper.h"
2322
#include "paddle/phi/core/compat/convert_utils.h"
2423
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
2524
#include "paddle/phi/core/tensor_meta.h"
@@ -64,7 +63,7 @@ class GradNodePyLayer : public GradNodeBase {
6463
bool create_graph = false,
6564
bool is_new_grad = false) override;
6665

67-
void ClearTensorWrappers() override;
66+
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
6867

6968
std::string name() override { return name_; }
7069

paddle/fluid/pybind/eager.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ typedef struct {
4040
// preventing _clear_dataptr() from freeing the underlying memory before
4141
// backward runs. Lifecycle: born with container (set_container), dies with
4242
// the PyLayerObject (PyLayerDealloc).
43-
std::vector<std::shared_ptr<phi::DenseTensor>> tensor_hold_helper;
43+
std::vector<std::shared_ptr<phi::TensorBase>> tensor_hold_helper;
4444
#ifdef PADDLE_WITH_CUDA
4545
std::vector<egr::ReloadFunctor> reload_functors;
4646
#endif

paddle/fluid/pybind/eager_py_layer.cc

Lines changed: 36 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,9 @@ PyObject* pylayer_method_apply(PyObject* cls,
274274

275275
for (int64_t i = inputs_size - 1; i >= 0; --i) {
276276
PyObject* obj = nullptr;
277-
if (i >= args_size) {
278-
obj = PyList_GetItem(kwargs_value_list, i - args_size); // NOLINT
277+
if (i >= static_cast<int64_t>(args_size)) {
278+
obj = PyList_GetItem(kwargs_value_list,
279+
i - static_cast<int64_t>(args_size)); // NOLINT
279280
} else {
280281
obj = PyTuple_GET_ITEM(args, i);
281282
}
@@ -692,31 +693,50 @@ PyObject* pylayer_method_apply(PyObject* cls,
692693
// DenseTensors found (Tensor / Tuple / List, recursively). Used by
693694
// tensor_properties_set_container to hold strong references so that
694695
// _clear_dataptr() cannot free the underlying allocation before backward.
695-
static void CollectDenseTensors(
696-
PyObject* obj, std::vector<std::shared_ptr<phi::DenseTensor>>* holder) {
696+
// DFS-walks obj (tuple/list tree) and calls fn(tensor) for every Tensor leaf.
697+
// Both CollectDenseTensors and RestoreDenseTensors are built on top of this.
698+
template <typename Fn>
699+
static void WalkDenseTensors(PyObject* obj, Fn&& fn) {
697700
if (!obj || obj == Py_None) return;
698701
if (PyCheckTensor(obj)) {
699-
const auto& tensor = reinterpret_cast<TensorObject*>(obj)->tensor;
700-
if (tensor.impl() && tensor.is_dense_tensor()) {
701-
holder->push_back(
702-
std::static_pointer_cast<phi::DenseTensor>(tensor.impl()));
703-
}
702+
fn(reinterpret_cast<TensorObject*>(obj)->tensor);
704703
return;
705704
}
706705
if (PyTuple_Check(obj)) {
707706
Py_ssize_t n = PyTuple_GET_SIZE(obj);
708707
for (Py_ssize_t i = 0; i < n; ++i)
709-
CollectDenseTensors(PyTuple_GET_ITEM(obj, i), holder);
708+
WalkDenseTensors(PyTuple_GET_ITEM(obj, i), fn);
710709
return;
711710
}
712711
if (PyList_Check(obj)) {
713712
Py_ssize_t n = PyList_GET_SIZE(obj);
714713
for (Py_ssize_t i = 0; i < n; ++i)
715-
CollectDenseTensors(PyList_GET_ITEM(obj, i), holder);
714+
WalkDenseTensors(PyList_GET_ITEM(obj, i), fn);
716715
return;
717716
}
718717
}
719718

719+
static void CollectDenseTensors(
720+
PyObject* obj, std::vector<std::shared_ptr<phi::TensorBase>>* holder) {
721+
WalkDenseTensors(obj, [holder](const paddle::Tensor& tensor) {
722+
if (tensor.impl()) holder->push_back(tensor.impl());
723+
});
724+
}
725+
726+
// Re-installs impl() for tensors cleared by _clear_dataptr(), using the
727+
// shared_ptrs stored in holder (same DFS order as CollectDenseTensors).
728+
static void RestoreDenseTensors(
729+
PyObject* obj,
730+
const std::vector<std::shared_ptr<phi::TensorBase>>& holder) {
731+
size_t idx = 0;
732+
WalkDenseTensors(obj, [&holder, &idx](paddle::Tensor& tensor) {
733+
if (idx < holder.size()) {
734+
if (!tensor.impl()) tensor.set_impl(holder[idx]);
735+
++idx;
736+
}
737+
});
738+
}
739+
720740
PyObject* call_unpack_hook(PyLayerObject* self) {
721741
auto unpack_hook = self->unpack_hook;
722742
auto packed_value = self->container;
@@ -767,45 +787,13 @@ PyObject* tensor_properties_get_container(PyLayerObject* self, void* closure) {
767787
if (self->container_be_packed) {
768788
return call_unpack_hook(self);
769789
}
770-
771-
// If tensor_hold_helper is non-empty, some tensors may have been cleared by
772-
// _clear_dataptr(). Iterate the top-level container tuple and restore any
773-
// null impl from the corresponding entry in tensor_hold_helper.
774-
// tensor_hold_helper is ordered by the DenseTensors found during deep
775-
// traversal in set_container; for the common case (flat tuple of tensors)
776-
// the k-th tensor in the tuple maps to tensor_hold_helper[k].
790+
// Re-attach any DenseTensor impls that were freed by _clear_dataptr().
791+
// tensor_hold_helper keeps the underlying allocations alive; walk the
792+
// container in the same DFS order as CollectDenseTensors and reinstall
793+
// impls for tensors whose impl() is currently null.
777794
if (!self->tensor_hold_helper.empty()) {
778-
Py_ssize_t size = PyTuple_Size(self->container);
779-
PyObject* recovered_container = PyTuple_New(size);
780-
Py_ssize_t holder_idx = 0;
781-
for (Py_ssize_t i = 0; i < size; ++i) {
782-
PyObject* item = PyTuple_GetItem(self->container, i);
783-
if (item && PyCheckTensor(item)) {
784-
TensorObject* tensor_obj = reinterpret_cast<TensorObject*>(item);
785-
if (!tensor_obj->tensor.impl() &&
786-
holder_idx <
787-
static_cast<Py_ssize_t>(self->tensor_hold_helper.size()) &&
788-
self->tensor_hold_helper[holder_idx]) {
789-
// Tensor was cleared by _clear_dataptr; restore impl from holder.
790-
paddle::Tensor recovered;
791-
recovered.set_impl(self->tensor_hold_helper[holder_idx]);
792-
PyTuple_SET_ITEM(
793-
recovered_container, i, paddle::pybind::ToPyObject(recovered));
794-
++holder_idx;
795-
continue;
796-
}
797-
++holder_idx;
798-
Py_INCREF(item);
799-
PyTuple_SET_ITEM(recovered_container, i, item);
800-
} else {
801-
Py_INCREF(item);
802-
PyTuple_SET_ITEM(recovered_container, i, item);
803-
}
804-
}
805-
return recovered_container;
795+
RestoreDenseTensors(self->container, self->tensor_hold_helper);
806796
}
807-
808-
// Fallback: return original container as-is.
809797
Py_INCREF(self->container);
810798
return self->container;
811799
EAGER_CATCH_AND_THROW_RETURN_NULL

0 commit comments

Comments
 (0)