
Commit 3b36fbb

delete shallow copy in recompute
1 parent 42b33e4 commit 3b36fbb

3 files changed: 53 additions & 547 deletions


python/paddle/distributed/fleet/recompute/recompute.py

Lines changed: 3 additions & 88 deletions
@@ -17,6 +17,7 @@
 import contextlib
 import copy
 import ctypes
+import functools
 import inspect
 import random
 import weakref
@@ -118,42 +119,6 @@ def check_recompute_necessary(inputs):
         )


-def _protect_tensors(seq):
-    """For each element in seq (a list or tuple of forward args), create a new
-    tensor Python object that shares the same underlying buffer via
-    _new_shared_tensor(), so that when pipeline-parallel calls
-    _release_input/_release_output (which clears the data pointer of the
-    original tensor), the copies held by recompute for backward are not
-    invalidated. Non-tensor elements are kept as-is.
-    Returns a list with the same length as seq.
-    """
-    result = list(seq)
-    for idx, arg in enumerate(result):
-        if isinstance(arg, core.eager.Tensor):
-            # _new_shared_tensor() creates a new Python-level tensor object
-            # that shares the same C++ storage with arg, without cloning data.
-            shared = arg._new_shared_tensor()
-            assert shared is not arg, (
-                "_protect_tensors() must return a new Python object distinct from the original "
-                "tensor, otherwise the protection against pipeline-parallel tensor "
-                "release is ineffective."
-            )
-            result[idx] = shared
-        elif isinstance(arg, tuple):
-            # For tuple args (e.g., pipeline-parallel passes inputs as tuples),
-            # protect each tensor element inside the tuple individually;
-            # non-tensor elements (e.g., int, bool) are passed through unchanged.
-            protected_tuple = []
-            for t in arg:
-                if isinstance(t, core.eager.Tensor):
-                    shared = t._new_shared_tensor()
-                    protected_tuple.append(shared)
-                else:
-                    protected_tuple.append(t)
-            result[idx] = tuple(protected_tuple)
-    return result
-
-
 class CustomStatesManager:
     """CustomStatesManager"""
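Note (not part of the diff): the deleted helper's idea was to hold a second Python tensor object over the same storage, so that releasing the original object elsewhere does not invalidate what recompute saved for backward. A minimal standalone sketch of that idea, assuming `_new_shared_tensor()` behaves as the removed docstring describes (new Python object, same C++ buffer, no data clone):

import paddle

def protect(args):
    # Sketch of the removed _protect_tensors(): shallow-copy tensor args only.
    protected = []
    for arg in args:
        if isinstance(arg, paddle.Tensor):
            # Assumption: _new_shared_tensor() returns a distinct Python object
            # sharing the same underlying buffer.
            protected.append(arg._new_shared_tensor())
        else:
            protected.append(arg)
    return protected

x = paddle.randn([2, 3])
protected_x, flag = protect((x, True))
assert protected_x is not x  # distinct wrapper, shared storage
assert flag is True          # non-tensor elements pass through unchanged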

@@ -224,22 +189,6 @@ def switch_rng_state_tracker(
         custom_set_state_func(orig_custom_state)


-def _restore_freed_closure_tensors(ctx):
-    """..."""
-    _PyCell_Set = ctypes.pythonapi.PyCell_Set
-    _PyCell_Set.argtypes = [ctypes.py_object, ctypes.py_object]
-    _PyCell_Set.restype = ctypes.c_int
-    for cell, protected in zip(ctx.closure_cells, ctx.closure_protected):
-        if cell is None or protected is None:
-            continue
-        try:
-            val = cell.cell_contents
-        except ValueError:
-            continue
-        if isinstance(val, core.eager.Tensor) and not val._is_initialized():
-            _PyCell_Set(cell, protected)
-
-
 class RecomputeFunction(PyLayer):
     @staticmethod
     def forward(
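Background (not part of the diff): the removed `_restore_freed_closure_tensors` relied on rewriting closure cells through the CPython C API `PyCell_Set`, exposed via ctypes. A small self-contained illustration of that mechanism, in plain Python without Paddle:

import ctypes

PyCell_Set = ctypes.pythonapi.PyCell_Set
PyCell_Set.argtypes = [ctypes.py_object, ctypes.py_object]
PyCell_Set.restype = ctypes.c_int

def make_fn():
    captured = "original"
    def fn():
        return captured
    return fn

fn = make_fn()
cell = fn.__closure__[0]        # the cell object holding `captured`
PyCell_Set(cell, "restored")    # overwrite the cell contents in place
assert fn() == "restored"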
@@ -260,32 +209,6 @@ def forward(
         ctx.offload_indices = offload_indices
         ctx.kwargs = kwargs

-        # Protect tensor-type closure variables of run_function against
-        # pipeline-parallel _release_input/_release_output calling _clear_dataptr().
-        # Explicit args are already protected by _protect_tensors(); here we cover
-        # any tensors captured in the function's __closure__ (e.g. grid_thw).
-        ctx.closure_cells = []
-        ctx.closure_protected = []
-        fn = (
-            run_function.forward
-            if isinstance(run_function, paddle.nn.Layer)
-            else run_function
-        )
-        if hasattr(fn, '__closure__') and fn.__closure__:
-            for cell in fn.__closure__:
-                try:
-                    val = cell.cell_contents
-                except ValueError:  # empty cell
-                    ctx.closure_cells.append(None)
-                    ctx.closure_protected.append(None)
-                    continue
-                if isinstance(val, core.eager.Tensor):
-                    ctx.closure_cells.append(cell)
-                    ctx.closure_protected.append(val._new_shared_tensor())
-                else:
-                    ctx.closure_cells.append(None)
-                    ctx.closure_protected.append(None)
-
         # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input
         # the order of tensors in backward()'s output should be the same as tensors in forward()'s input
         # None tensor inputs will be filtered in backward inputs.
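Background (not part of the diff): the `grid_thw` case mentioned in the removed comment refers to tensors that live in `run_function.__closure__` rather than in the explicit argument list, so protection applied only to `*args` would not cover them. A hypothetical sketch of that shape; the layer, sizes, and names are illustrative only:

import paddle

def build_run_function(layer, grid_thw):
    def run_function(x):
        # grid_thw is captured in run_function.__closure__, not passed as an arg.
        return layer(x) * grid_thw
    return run_function

layer = paddle.nn.Linear(4, 4)
grid_thw = paddle.ones([4])
run_function = build_run_function(layer, grid_thw)

captured = [cell.cell_contents for cell in run_function.__closure__]
assert any(isinstance(v, paddle.Tensor) for v in captured)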
@@ -428,8 +351,6 @@ def backward(ctx, *args):
                     dtype=ctx.amp_dtype,
                 ),
             ):
-                if ctx.closure_cells:
-                    _restore_freed_closure_tensors(ctx)
                 detached_inputs = detach_variable(tuple(inputs))
                 outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
             else:
@@ -440,8 +361,6 @@ def backward(ctx, *args):
                     level=ctx.amp_level,
                     dtype=ctx.amp_dtype,
                 ):
-                    if ctx.closure_cells:
-                        _restore_freed_closure_tensors(ctx)
                     detached_inputs = detach_variable(tuple(inputs))
                     outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)

@@ -795,16 +714,14 @@ def recompute(function, *args, **kwargs):
     if use_reentrant:
         offload_indices = kwargs.pop('offload_indices', [])
         if not kwargs:  # fast path
-            # Make a shallow copy of each Tensor to prevent the release of some Tensors reserved for backward in some special scenarios (such as scheduling logic of parallel pipelines)
-            protected_args = _protect_tensors(args)
             return RecomputeFunction.apply(
                 function,
                 preserve,
                 preserve_external_rng_state,
                 offload_indices,
                 custom_get_state_func,
                 custom_set_state_func,
-                *protected_args,
+                *args,
             )

         # rearrange `position-args + keyword-args` into `position-args`
@@ -845,16 +762,14 @@ def recompute(function, *args, **kwargs):
                 )
             else:
                 raise ValueError("Unknown parameter kind.")
-        # Make a shallow copy of each Tensor to prevent the release of some Tensors reserved for backward in some special scenarios (such as scheduling logic of parallel pipelines)
-        protected_args = _protect_tensors(input_args)
         return RecomputeFunction.apply(
             function,
             preserve,
             preserve_external_rng_state,
             offload_indices,
             custom_get_state_func,
             custom_set_state_func,
-            *protected_args,
+            *input_args,
         )
     else:
         return _recompute_without_reentrant(
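For orientation (not part of the diff), a hypothetical call into the `recompute()` entry point changed above. The `use_reentrant` keyword comes from the surrounding diff context; the import path is an assumption based on the file path in this commit, and the model code is illustrative only:

import paddle
from paddle.distributed.fleet.recompute.recompute import recompute

class Block(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(8, 8)

    def forward(self, x):
        return paddle.nn.functional.relu(self.linear(x))

block = Block()
x = paddle.randn([2, 8])
x.stop_gradient = False
# Activations inside `block` are recomputed during backward instead of stored.
y = recompute(block, x, use_reentrant=True)
y.sum().backward()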

python/paddle/distributed/fleet/recompute/recompute_hybrid.py

Lines changed: 2 additions & 6 deletions
@@ -26,7 +26,6 @@
 from ..meta_parallel.parallel_layers.random import get_rng_state_tracker
 from ..meta_parallel.pp_utils import utils
 from .recompute import (
-    _protect_tensors,
     check_recompute_necessary,
     custom_state_manager,
     detach_variable,
@@ -155,13 +154,10 @@ def forward(
         ctx.amp_dtype = tracer._amp_dtype
         ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

-        # Protect input tensors before saving to prevent release by pipeline parallel
-        protected_args = _protect_tensors(args)
-
         with paddle.no_grad():
-            outputs = run_function(*protected_args, **kwargs)
+            outputs = run_function(*args, **kwargs)

-        for i, arg in enumerate(protected_args):
+        for i, arg in enumerate(args):
             if paddle.is_tensor(arg):
                 state = arg.stop_gradient
                 if partition:
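Note (not part of the diff): the hunk above keeps the existing pattern of running the user function without autograd tracking and remembering each tensor input's `stop_gradient` flag. A minimal simplified sketch of that pattern, not the actual hybrid implementation:

import paddle

def forward_like(run_function, *args, **kwargs):
    # Run the user function without building a graph, as in the hunk above.
    with paddle.no_grad():
        outputs = run_function(*args, **kwargs)
    # Record each tensor input's stop_gradient flag so a later recomputation
    # can restore it; non-tensor args are skipped.
    saved_flags = [
        arg.stop_gradient if paddle.is_tensor(arg) else None for arg in args
    ]
    return outputs, saved_flags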
