1717import contextlib
1818import copy
1919import ctypes
20+ import functools
2021import inspect
2122import random
2223import weakref
@@ -118,42 +119,6 @@ def check_recompute_necessary(inputs):
118119 )
119120
120121
121- def _protect_tensors (seq ):
122- """For each element in seq (a list or tuple of forward args), create a new
123- tensor Python object that shares the same underlying buffer via
124- _new_shared_tensor(), so that when pipeline-parallel calls
125- _release_input/_release_output (which clears the data pointer of the
126- original tensor), the copies held by recompute for backward are not
127- invalidated. Non-tensor elements are kept as-is.
128- Returns a list with the same length as seq.
129- """
130- result = list (seq )
131- for idx , arg in enumerate (result ):
132- if isinstance (arg , core .eager .Tensor ):
133- # _new_shared_tensor() creates a new Python-level tensor object
134- # that shares the same C++ storage with arg, without cloning data.
135- shared = arg ._new_shared_tensor ()
136- assert shared is not arg , (
137- "_protect_tensors() must return a new Python object distinct from the original "
138- "tensor, otherwise the protection against pipeline-parallel tensor "
139- "release is ineffective."
140- )
141- result [idx ] = shared
142- elif isinstance (arg , tuple ):
143- # For tuple args (e.g., pipeline-parallel passes inputs as tuples),
144- # protect each tensor element inside the tuple individually;
145- # non-tensor elements (e.g., int, bool) are passed through unchanged.
146- protected_tuple = []
147- for t in arg :
148- if isinstance (t , core .eager .Tensor ):
149- shared = t ._new_shared_tensor ()
150- protected_tuple .append (shared )
151- else :
152- protected_tuple .append (t )
153- result [idx ] = tuple (protected_tuple )
154- return result
155-
156-
157122class CustomStatesManager :
158123 """CustomStatesManager"""
159124
@@ -224,22 +189,6 @@ def switch_rng_state_tracker(
224189 custom_set_state_func (orig_custom_state )
225190
226191
227- def _restore_freed_closure_tensors (ctx ):
228- """..."""
229- _PyCell_Set = ctypes .pythonapi .PyCell_Set
230- _PyCell_Set .argtypes = [ctypes .py_object , ctypes .py_object ]
231- _PyCell_Set .restype = ctypes .c_int
232- for cell , protected in zip (ctx .closure_cells , ctx .closure_protected ):
233- if cell is None or protected is None :
234- continue
235- try :
236- val = cell .cell_contents
237- except ValueError :
238- continue
239- if isinstance (val , core .eager .Tensor ) and not val ._is_initialized ():
240- _PyCell_Set (cell , protected )
241-
242-
243192class RecomputeFunction (PyLayer ):
244193 @staticmethod
245194 def forward (
@@ -260,32 +209,6 @@ def forward(
260209 ctx .offload_indices = offload_indices
261210 ctx .kwargs = kwargs
262211
263- # Protect tensor-type closure variables of run_function against
264- # pipeline-parallel _release_input/_release_output calling _clear_dataptr().
265- # Explicit args are already protected by _protect_tensors(); here we cover
266- # any tensors captured in the function's __closure__ (e.g. grid_thw).
267- ctx .closure_cells = []
268- ctx .closure_protected = []
269- fn = (
270- run_function .forward
271- if isinstance (run_function , paddle .nn .Layer )
272- else run_function
273- )
274- if hasattr (fn , '__closure__' ) and fn .__closure__ :
275- for cell in fn .__closure__ :
276- try :
277- val = cell .cell_contents
278- except ValueError : # empty cell
279- ctx .closure_cells .append (None )
280- ctx .closure_protected .append (None )
281- continue
282- if isinstance (val , core .eager .Tensor ):
283- ctx .closure_cells .append (cell )
284- ctx .closure_protected .append (val ._new_shared_tensor ())
285- else :
286- ctx .closure_cells .append (None )
287- ctx .closure_protected .append (None )
288-
289212 # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input
290213 # the order of tensors in backward()'s output should be the same as tensors in forward()'s input
291214 # None tensor inputs will be filtered in backward inputs.
@@ -428,8 +351,6 @@ def backward(ctx, *args):
428351 dtype = ctx .amp_dtype ,
429352 ),
430353 ):
431- if ctx .closure_cells :
432- _restore_freed_closure_tensors (ctx )
433354 detached_inputs = detach_variable (tuple (inputs ))
434355 outputs = ctx .run_function (* detached_inputs , ** ctx .kwargs )
435356 else :
@@ -440,8 +361,6 @@ def backward(ctx, *args):
440361 level = ctx .amp_level ,
441362 dtype = ctx .amp_dtype ,
442363 ):
443- if ctx .closure_cells :
444- _restore_freed_closure_tensors (ctx )
445364 detached_inputs = detach_variable (tuple (inputs ))
446365 outputs = ctx .run_function (* detached_inputs , ** ctx .kwargs )
447366
@@ -795,16 +714,14 @@ def recompute(function, *args, **kwargs):
795714 if use_reentrant :
796715 offload_indices = kwargs .pop ('offload_indices' , [])
797716 if not kwargs : # fast path
798- # Make a shallow copy of each Tensor to prevent the release of some Tensors reserved for backward in some special scenarios (such as scheduling logic of parallel pipelines)
799- protected_args = _protect_tensors (args )
800717 return RecomputeFunction .apply (
801718 function ,
802719 preserve ,
803720 preserve_external_rng_state ,
804721 offload_indices ,
805722 custom_get_state_func ,
806723 custom_set_state_func ,
807- * protected_args ,
724+ * args ,
808725 )
809726
810727 # rearrange `position-args + keyword-args` into `position-args`
@@ -845,16 +762,14 @@ def recompute(function, *args, **kwargs):
845762 )
846763 else :
847764 raise ValueError ("Unknown parameter kind." )
848- # Make a shallow copy of each Tensor to prevent the release of some Tensors reserved for backward in some special scenarios (such as scheduling logic of parallel pipelines)
849- protected_args = _protect_tensors (input_args )
850765 return RecomputeFunction .apply (
851766 function ,
852767 preserve ,
853768 preserve_external_rng_state ,
854769 offload_indices ,
855770 custom_get_state_func ,
856771 custom_set_state_func ,
857- * protected_args ,
772+ * input_args ,
858773 )
859774 else :
860775 return _recompute_without_reentrant (
0 commit comments