@@ -274,8 +274,9 @@ PyObject* pylayer_method_apply(PyObject* cls,
274274
275275 for (int64_t i = inputs_size - 1 ; i >= 0 ; --i) {
276276 PyObject* obj = nullptr ;
277- if (i >= args_size) {
278- obj = PyList_GetItem (kwargs_value_list, i - args_size); // NOLINT
277+ if (i >= static_cast <int64_t >(args_size)) {
278+ obj = PyList_GetItem (kwargs_value_list,
279+ i - static_cast <int64_t >(args_size)); // NOLINT
279280 } else {
280281 obj = PyTuple_GET_ITEM (args, i);
281282 }
@@ -692,31 +693,50 @@ PyObject* pylayer_method_apply(PyObject* cls,
692693// DenseTensors found (Tensor / Tuple / List, recursively). Used by
693694// tensor_properties_set_container to hold strong references so that
694695// _clear_dataptr() cannot free the underlying allocation before backward.
695- static void CollectDenseTensors (
696- PyObject* obj, std::vector<std::shared_ptr<phi::DenseTensor>>* holder) {
696+ // DFS-walks obj (tuple/list tree) and calls fn(tensor) for every Tensor leaf.
697+ // Both CollectDenseTensors and RestoreDenseTensors are built on top of this.
698+ template <typename Fn>
699+ static void WalkDenseTensors (PyObject* obj, Fn&& fn) {
697700 if (!obj || obj == Py_None) return ;
698701 if (PyCheckTensor (obj)) {
699- const auto & tensor = reinterpret_cast <TensorObject*>(obj)->tensor ;
700- if (tensor.impl () && tensor.is_dense_tensor ()) {
701- holder->push_back (
702- std::static_pointer_cast<phi::DenseTensor>(tensor.impl ()));
703- }
702+ fn (reinterpret_cast <TensorObject*>(obj)->tensor );
704703 return ;
705704 }
706705 if (PyTuple_Check (obj)) {
707706 Py_ssize_t n = PyTuple_GET_SIZE (obj);
708707 for (Py_ssize_t i = 0 ; i < n; ++i)
709- CollectDenseTensors (PyTuple_GET_ITEM (obj, i), holder );
708+ WalkDenseTensors (PyTuple_GET_ITEM (obj, i), fn );
710709 return ;
711710 }
712711 if (PyList_Check (obj)) {
713712 Py_ssize_t n = PyList_GET_SIZE (obj);
714713 for (Py_ssize_t i = 0 ; i < n; ++i)
715- CollectDenseTensors (PyList_GET_ITEM (obj, i), holder );
714+ WalkDenseTensors (PyList_GET_ITEM (obj, i), fn );
716715 return ;
717716 }
718717}
719718
719+ static void CollectDenseTensors (
720+ PyObject* obj, std::vector<std::shared_ptr<phi::TensorBase>>* holder) {
721+ WalkDenseTensors (obj, [holder](const paddle::Tensor& tensor) {
722+ if (tensor.impl ()) holder->push_back (tensor.impl ());
723+ });
724+ }
725+
726+ // Re-installs impl() for tensors cleared by _clear_dataptr(), using the
727+ // shared_ptrs stored in holder (same DFS order as CollectDenseTensors).
728+ static void RestoreDenseTensors (
729+ PyObject* obj,
730+ const std::vector<std::shared_ptr<phi::TensorBase>>& holder) {
731+ size_t idx = 0 ;
732+ WalkDenseTensors (obj, [&holder, &idx](paddle::Tensor& tensor) {
733+ if (idx < holder.size ()) {
734+ if (!tensor.impl ()) tensor.set_impl (holder[idx]);
735+ ++idx;
736+ }
737+ });
738+ }
739+
720740PyObject* call_unpack_hook (PyLayerObject* self) {
721741 auto unpack_hook = self->unpack_hook ;
722742 auto packed_value = self->container ;
@@ -767,45 +787,13 @@ PyObject* tensor_properties_get_container(PyLayerObject* self, void* closure) {
767787 if (self->container_be_packed ) {
768788 return call_unpack_hook (self);
769789 }
770-
771- // If tensor_hold_helper is non-empty, some tensors may have been cleared by
772- // _clear_dataptr(). Iterate the top-level container tuple and restore any
773- // null impl from the corresponding entry in tensor_hold_helper.
774- // tensor_hold_helper is ordered by the DenseTensors found during deep
775- // traversal in set_container; for the common case (flat tuple of tensors)
776- // the k-th tensor in the tuple maps to tensor_hold_helper[k].
790+ // Re-attach any DenseTensor impls that were freed by _clear_dataptr().
791+ // tensor_hold_helper keeps the underlying allocations alive; walk the
792+ // container in the same DFS order as CollectDenseTensors and reinstall
793+ // impls for tensors whose impl() is currently null.
777794 if (!self->tensor_hold_helper .empty ()) {
778- Py_ssize_t size = PyTuple_Size (self->container );
779- PyObject* recovered_container = PyTuple_New (size);
780- Py_ssize_t holder_idx = 0 ;
781- for (Py_ssize_t i = 0 ; i < size; ++i) {
782- PyObject* item = PyTuple_GetItem (self->container , i);
783- if (item && PyCheckTensor (item)) {
784- TensorObject* tensor_obj = reinterpret_cast <TensorObject*>(item);
785- if (!tensor_obj->tensor .impl () &&
786- holder_idx <
787- static_cast <Py_ssize_t>(self->tensor_hold_helper .size ()) &&
788- self->tensor_hold_helper [holder_idx]) {
789- // Tensor was cleared by _clear_dataptr; restore impl from holder.
790- paddle::Tensor recovered;
791- recovered.set_impl (self->tensor_hold_helper [holder_idx]);
792- PyTuple_SET_ITEM (
793- recovered_container, i, paddle::pybind::ToPyObject (recovered));
794- ++holder_idx;
795- continue ;
796- }
797- ++holder_idx;
798- Py_INCREF (item);
799- PyTuple_SET_ITEM (recovered_container, i, item);
800- } else {
801- Py_INCREF (item);
802- PyTuple_SET_ITEM (recovered_container, i, item);
803- }
804- }
805- return recovered_container;
795+ RestoreDenseTensors (self->container , self->tensor_hold_helper );
806796 }
807-
808- // Fallback: return original container as-is.
809797 Py_INCREF (self->container );
810798 return self->container ;
811799 EAGER_CATCH_AND_THROW_RETURN_NULL
0 commit comments