-
Notifications
You must be signed in to change notification settings - Fork 700
Expand file tree
/
Copy pathdistributed.py
More file actions
2122 lines (1814 loc) · 78.3 KB
/
distributed.py
File metadata and controls
2122 lines (1814 loc) · 78.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Methods needed for distributed training (DP/TP)."""
from __future__ import annotations
from collections.abc import Iterable
from contextlib import contextmanager, AbstractContextManager, ContextDecorator
from functools import lru_cache
from dataclasses import dataclass
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings
import torch
from torch.cuda import _lazy_call, _lazy_init
from torch.utils.checkpoint import detach_variable, noop_context_fn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
try:
import torch.distributed._symmetric_memory as symm_mem
HAS_TORCH_SYMMETRIC = True
except ImportError:
HAS_TORCH_SYMMETRIC = False
import transformer_engine_torch as tex
from transformer_engine.pytorch.triton.pad import pad_columnwise_scale_inv
from .torch_version import torch_version
from .utils import (
is_non_tn_fp8_gemm_supported,
safely_set_viewless_tensor_data,
needs_quantized_gemm,
)
from .constants import dist_group_type
from .quantization import FP8GlobalStateManager, autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
from .tensor.mxfp8_tensor import MXFP8Quantizer
from .tensor.nvfp4_tensor import NVFP4Quantizer
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
from .quantized_tensor import QuantizedTensorStorage, QuantizedTensor, Quantizer
from .tensor.storage.float8_tensor_storage import Float8TensorStorage
from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
# Public API of this module for `from ... import *`.
__all__ = ["checkpoint", "CudaRNGStatesTracker"]

# Tensor-parallel attribute names that `set_tensor_model_parallel_attributes` refuses
# to overwrite. NOTE(review): only the keys are read in the code visible here; the
# default values are not used by this module as far as can be seen.
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
    "tensor_model_parallel": False,
    "partition_dim": -1,
    "partition_stride": 1,
}

# Overwritten by every `checkpoint()` call from its `use_reentrant` kwarg (default True)
# and read back through `use_reentrant_activation_recompute()`.
_USE_REENTRANT_ACTIVATION_RECOMPUTE = True
# Raised by `activation_recompute_forward.__enter__` when activation recompute is
# requested while FP8 is enabled; reset to False by its `__exit__`.
_FP8_ACTIVATION_RECOMPUTE_ENABLED = False
# True while inside an `activation_recompute_forward` entered with recompute_phase=True
# (the recomputation done during backward); reset to False on exit.
_FP8_ACTIVATION_RECOMPUTE_PHASE = False
# Registry of the named CUDA RNG states owned by `CudaRNGStatesTracker`. The tracker
# rebinds this module-level dict (via `set_all_rng_states`) whenever its states change.
_ALL_ACTIVE_RNG_STATES = {}


def get_all_rng_states() -> Dict[str, Any]:
    """Return the generator states registered by `CudaRNGStatesTracker`.

    The live module-level dictionary is returned, not a copy.

    Returns
    -------
    Dict[str, Any]
        Mapping from state name to RNG state. The values are whatever the tracker
        stored: a ``torch.Generator`` on the graph-safe RNG path, otherwise a state
        tensor.
    """
    return _ALL_ACTIVE_RNG_STATES


def set_all_rng_states(states: Dict[str, Any]) -> None:
    """Replace the generator states registered for `CudaRNGStatesTracker`.

    Parameters
    ----------
    states : Dict[str, Any]
        Mapping from state name to RNG state. It is stored by reference, so later
        mutations of ``states`` are visible through `get_all_rng_states`.
    """
    global _ALL_ACTIVE_RNG_STATES
    _ALL_ACTIVE_RNG_STATES = states
def graph_safe_rng_available() -> bool:
    """Check whether this PyTorch build exposes the CUDA-graph-safe RNG API.

    The API counts as available only when ``CUDAGraph.register_generator_state``,
    ``Generator.graphsafe_set_state``, ``Generator.graphsafe_get_state`` and
    ``Generator.clone_state`` all exist.
    """
    required_members = (
        (torch.cuda.CUDAGraph, "register_generator_state"),
        (torch.Generator, "graphsafe_set_state"),
        (torch.Generator, "graphsafe_get_state"),
        (torch.Generator, "clone_state"),
    )
    return all(hasattr(owner, member) for owner, member in required_members)
def is_graph_safe_rng_state(state: Union[torch.Tensor, torch.Generator]) -> bool:
    """Tell whether ``state`` is a graph-safe RNG state.

    That is the case only when the graph-safe RNG API is available and ``state``
    is a ``torch.Generator`` rather than a plain state tensor.
    """
    if not graph_safe_rng_available():
        return False
    return isinstance(state, torch.Generator)
def _get_cuda_rng_state(
    device: Union[int, str, torch.device] = "cuda",
    clone: bool = False,
    graph_safe: bool = True,
) -> Union[torch.Tensor, torch.Generator]:
    """Return the random number generator state of the specified GPU.

    Parameters
    ----------
    device : int, str or torch.device, default = "cuda"
        GPU whose default generator is read. An ``int`` is a CUDA device index; a
        device without an index resolves to ``torch.cuda.current_device()``.
    clone : bool, default = False
        Only relevant on the graph-safe path: return an independent copy of the
        generator state (``clone_state``) instead of a reference to the live one.
    graph_safe : bool, default = True
        Use the graph-safe generator API when this PyTorch build provides it.

    Returns
    -------
    torch.Generator or torch.Tensor
        On the graph-safe path, the generator object from ``clone_state`` or
        ``graphsafe_get_state``. Otherwise, the state tensor from ``get_state``.
    """
    _lazy_init()
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)

    idx = device.index
    if idx is None:
        idx = torch.cuda.current_device()
    default_generator = torch.cuda.default_generators[idx]

    if graph_safe_rng_available() and graph_safe:
        if clone:
            # Independent copy of the generator state.
            return default_generator.clone_state()
        # Reference to the current (live) generator state.
        return default_generator.graphsafe_get_state()
    return default_generator.get_state()
def _set_cuda_rng_state(
    new_state: torch.Tensor,
    device: Union[int, str] = -1,
    graph_safe=True,
) -> None:
    """Install ``new_state`` as the RNG state of a GPU (the current GPU by default).

    The update is scheduled through ``torch.cuda._lazy_call``. It therefore runs
    immediately if CUDA is already initialized, or once it is.

    Parameters
    ----------
    new_state : torch.Tensor or torch.Generator
        State to install. Passed to ``graphsafe_set_state`` when the graph-safe API
        is available and ``graph_safe`` is true, otherwise to ``set_state``.
    device : int, str or torch.device, default = -1
        Target GPU. ``-1`` means the index-less ``"cuda"`` device, which resolves to
        the current device when the update runs.
    graph_safe : bool, default = True
        Prefer the graph-safe generator API when it is available.
    """
    if device == -1:
        target = torch.device("cuda")
    elif isinstance(device, str):
        target = torch.device(device)
    elif isinstance(device, int):
        target = torch.device("cuda", device)
    else:
        target = device

    def _apply_state() -> None:
        index = target.index
        if index is None:
            index = torch.cuda.current_device()
        generator = torch.cuda.default_generators[index]
        if graph_safe_rng_available() and graph_safe:
            generator.graphsafe_set_state(new_state)
        else:
            generator.set_state(new_state)

    _lazy_call(_apply_state)
def set_tensor_model_parallel_attributes(
    tensor: torch.Tensor, is_parallel: bool, dim: int, stride: int
) -> None:
    """Tag ``tensor`` with its tensor-parallel (TP) partitioning attributes.

    Sets ``tensor_model_parallel``, ``partition_dim`` and ``partition_stride``.

    Raises
    ------
    RuntimeError
        If ``tensor`` already carries any attribute named in
        ``_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS``. Nothing is modified in that case.
    """
    clash = next(
        (name for name in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS if hasattr(tensor, name)),
        None,
    )
    if clash is not None:
        raise RuntimeError(
            f"Tensor already has attribute '{clash}' set. Cannot set "
            "tensor model parallel attributes on a tensor that already has them."
        )

    tensor.tensor_model_parallel = is_parallel
    tensor.partition_dim = dim
    tensor.partition_stride = stride
@lru_cache
def _cached_distributed_world_size(group: Optional[dist_group_type]) -> int:
    """Query and memoize the world size of ``group`` in an initialized job."""
    return torch.distributed.get_world_size(group=group)


def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int:
    """Return world size for the distributed group.

    When ``torch.distributed`` is not initialized, 1 is returned. That fallback is
    intentionally not memoized. Only world sizes queried from an initialized job
    are cached, so an early call made before ``init_process_group`` cannot pin the
    result to 1 for the rest of the process.

    Parameters
    ----------
    group : ProcessGroup, optional
        Process group to query; ``None`` selects the default group.
    """
    if not torch.distributed.is_initialized():
        return 1
    return _cached_distributed_world_size(group)


# This function used to be decorated with ``functools.lru_cache``. Keep the cache
# management helpers available for callers that rely on them.
get_distributed_world_size.cache_clear = _cached_distributed_world_size.cache_clear
get_distributed_world_size.cache_info = _cached_distributed_world_size.cache_info
@lru_cache
def get_distributed_rank(group: Optional[dist_group_type] = None) -> int:
    """Return this process's rank within ``group`` (the default group when None).

    Raises
    ------
    RuntimeError
        If ``torch.distributed`` has not been initialized.
    """
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank(group=group)
    raise RuntimeError(
        "torch.distributed is not initialized. Call torch.distributed.init_process_group() "
        "before calling get_distributed_rank()."
    )
def initialize_affine_weight_gpu(
    weight: torch.Tensor,
    init_method: Callable,
    get_rng_state_tracker: Callable,
    partition_dim: int = 0,
    stride: int = 1,
    set_tp_attributes: bool = True,
) -> None:
    """Initialize a (possibly tensor-parallel) affine weight in place.

    Parameters
    ----------
    weight : torch.Tensor
        Parameter to initialize; ``init_method(weight)`` is called on it.
    init_method : Callable
        Initializer applied to ``weight``.
    get_rng_state_tracker : Callable or None
        Returns a tracker whose ``fork()`` context wraps the initializer call. When
        ``None``, ``init_method`` runs under the current RNG state.
    partition_dim : int, default = 0
        Recorded on ``weight`` as ``partition_dim``.
    stride : int, default = 1
        Recorded on ``weight`` as ``partition_stride``.
    set_tp_attributes : bool, default = True
        Whether to tag ``weight`` with its tensor-parallel attributes first.
    """
    if set_tp_attributes:
        set_tensor_model_parallel_attributes(
            tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
        )

    if get_rng_state_tracker is not None:
        # Draw the initial values from the tracker's RNG stream.
        with get_rng_state_tracker().fork():
            init_method(weight)
    else:
        init_method(weight)
def split_tensor_into_1d_equal_chunks(
    tensor: torch.Tensor, tp_group: dist_group_type, new_buffer: bool = False
) -> torch.Tensor:
    """Return this rank's contiguous slice of the flattened ``tensor``.

    The flattened tensor is cut into chunks of ``numel // world_size`` elements, and
    this rank receives chunk number ``rank``. When ``numel`` is not divisible by the
    world size, the trailing remainder belongs to no chunk.

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor to split; it is flattened with ``view(-1)``.
    tp_group : ProcessGroup
        Group whose size and this process's rank select the chunk.
    new_buffer : bool, default = False
        If ``True``, copy the chunk into a newly allocated tensor on the current CUDA
        device; otherwise return a view into ``tensor``.
    """
    chunk_numel = torch.numel(tensor) // get_distributed_world_size(tp_group)
    begin = chunk_numel * get_distributed_rank(tp_group)
    chunk_view = tensor.view(-1)[begin : begin + chunk_numel]
    if not new_buffer:
        return chunk_view

    buffer = torch.empty(
        chunk_numel,
        dtype=tensor.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    buffer.copy_(chunk_view)
    return buffer
def gather_split_1d_tensor(tensor: torch.Tensor, tp_group: dist_group_type) -> torch.Tensor:
    """Inverse of `split_tensor_into_1d_equal_chunks`: all-gather every rank's chunk.

    Returns a new flat tensor on the current CUDA device with the same dtype as
    ``tensor`` and ``numel(tensor) * world_size`` elements, filled by
    ``torch.distributed.all_gather_into_tensor`` over ``tp_group``.
    """
    group_size = get_distributed_world_size(tp_group)
    output = torch.empty(
        torch.numel(tensor) * group_size,
        dtype=tensor.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    torch.distributed.all_gather_into_tensor(output, tensor, group=tp_group)
    return output
class activation_recompute_forward(AbstractContextManager, ContextDecorator):
    """Context manager used to control the forward runtime behavior when executed
    under the `_CheckpointFunction` function (and the non-reentrant `checkpoint`
    path). For running FP8, the forward pass will run without storing intermediate
    activations. Instead, the forward pass saves the inputs tuple and the calling
    function. In the backwards pass, these are retrieved, and the forward pass is
    computed again while tracking the intermediate activations, followed by
    calculation of gradients using these values.

    On entry it sets the module-level flags read by
    `is_fp8_activation_recompute_enabled()` and `in_fp8_activation_recompute_phase()`.
    On exit both flags are reset to ``False``.

    Parameters
    ----------
    activation_recompute : bool, default = False
        Whether activation recompute applies to the wrapped region.
    recompute_phase : bool, default = False
        ``True`` for the recomputation performed during backward, ``False`` for the
        original forward pass.
    """

    # Queue of `FP8GlobalStateManager.IS_FIRST_FP8_MODULE` values, shared by all
    # instances (class attribute). Every original forward pass (recompute_phase=False)
    # appends one value, and every recomputation (recompute_phase=True) pops the oldest
    # one. Recomputations therefore pair with forward passes in FIFO order.
    _is_first_fp8_module: List = []

    def __init__(self, activation_recompute: bool = False, recompute_phase: bool = False):
        super().__init__()
        self.activation_recompute = activation_recompute
        self.recompute_phase = recompute_phase

    def __enter__(self):
        global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
        # The "enabled" flag is only raised when FP8 is active as well.
        _FP8_ACTIVATION_RECOMPUTE_ENABLED = (
            self.activation_recompute and FP8GlobalStateManager.is_fp8_enabled()
        )
        _FP8_ACTIVATION_RECOMPUTE_PHASE = self.recompute_phase

        if self.activation_recompute and not self.recompute_phase:
            # Original forward pass: remember the "first FP8 module" flag so that the
            # recomputation can reproduce it.
            activation_recompute_forward._is_first_fp8_module.append(
                FP8GlobalStateManager.IS_FIRST_FP8_MODULE
            )
        if self.activation_recompute and self.recompute_phase:
            # Recomputation: restore the flag recorded by the oldest pending forward pass.
            FP8GlobalStateManager.IS_FIRST_FP8_MODULE = (
                activation_recompute_forward._is_first_fp8_module.pop(0)
            )

    def __exit__(self, *exc_details):
        global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
        # NOTE(review): the flags are reset to False rather than restored to their
        # pre-__enter__ values. Exiting a nested instance therefore also clears them
        # for the enclosing one; confirm nesting is not expected.
        _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
        _FP8_ACTIVATION_RECOMPUTE_PHASE = False
def is_fp8_activation_recompute_enabled() -> bool:
    """Report whether FP8 activation recompute is active right now.

    The flag is raised on entry to an `activation_recompute_forward` context that
    requests recompute while FP8 is enabled, and cleared when that context exits.
    """
    recompute_enabled = _FP8_ACTIVATION_RECOMPUTE_ENABLED
    return recompute_enabled
def in_fp8_activation_recompute_phase() -> bool:
    """Report whether execution is inside the backward-pass recomputation.

    The flag mirrors the ``recompute_phase`` of the innermost active
    `activation_recompute_forward` context and is cleared when that context exits.
    """
    recompute_phase = _FP8_ACTIVATION_RECOMPUTE_PHASE
    return recompute_phase
def _get_active_autocast_contexts():
    """
    Build fresh ``(gpu_ctx, cpu_ctx)`` autocast context managers that reproduce the
    autocast settings (enabled flag, dtype, cache flag) active at call time.

    PyTorch >= 2.4 uses the device-generic ``torch.amp.autocast`` API. Older versions
    fall back to ``torch.cuda.amp.autocast`` / ``torch.cpu.amp.autocast``.
    """
    cache_enabled = torch.is_autocast_cache_enabled()

    def _capture(device_type):
        # Snapshot the current autocast configuration of one device type.
        return torch.amp.autocast(
            device_type,
            enabled=torch.is_autocast_enabled(device_type),
            dtype=torch.get_autocast_dtype(device_type),
            cache_enabled=cache_enabled,
        )

    if torch_version() >= (2, 4, 0):
        return _capture("cuda"), _capture("cpu")

    # Legacy per-device APIs (positional: enabled, dtype, cache_enabled).
    legacy_gpu_ctx = torch.cuda.amp.autocast(
        torch.is_autocast_enabled(), torch.get_autocast_gpu_dtype(), cache_enabled
    )
    legacy_cpu_ctx = torch.cpu.amp.autocast(
        torch.is_autocast_cpu_enabled(), torch.get_autocast_cpu_dtype(), cache_enabled
    )
    return legacy_gpu_ctx, legacy_cpu_ctx
class _CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with
    two main changes:
    1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
    2) the states in the model parallel tracker are also properly
    tracked/set/reset.

    This is the reentrant checkpoint used by `checkpoint()` when
    ``use_reentrant=True``. ``forward`` runs ``run_function`` under
    ``torch.no_grad()`` and keeps only its inputs plus the RNG, autocast and FP8
    state. ``backward`` restores that state, re-runs ``run_function`` with grad
    enabled, and back-propagates through the recomputed outputs.
    """

    @staticmethod
    def forward(
        ctx,
        run_function: Callable,
        distribute_saved_activations: bool,
        get_rng_state_tracker: Union[Callable, None],
        tp_group: Union[dist_group_type, None],
        context_fn: Union[Callable, None],
        kwargs: Dict[str, Any],
        *args: Tuple[torch.Tensor, ...],
    ) -> Tuple[torch.Tensor, ...]:
        """Call forward function while saving state to be able to
        redo the computation later.

        Parameters
        ----------
        run_function : Callable
            Function being checkpointed; called as ``run_function(*args, **kwargs)``.
        distribute_saved_activations : bool
            If ``True``, only this rank's 1D chunk of ``args[0]`` is kept until
            backward, where the chunks are all-gathered over ``tp_group`` again.
        get_rng_state_tracker : Callable or None
            Returns the `CudaRNGStatesTracker` whose states are saved and restored.
        tp_group : ProcessGroup or None
            Group used to split and re-gather ``args[0]``.
        context_fn : Callable or None
            Returns ``(forward_ctx, recompute_ctx)``. ``noop_context_fn`` is used
            when ``None``.
        kwargs : Dict[str, Any]
            Keyword arguments for ``run_function``.
        *args
            Positional arguments for ``run_function``.

        Returns
        -------
        Whatever ``run_function(*args, **kwargs)`` returns.
        """
        ctx.run_function = run_function
        ctx.distribute_saved_activations = distribute_saved_activations

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        if get_rng_state_tracker is not None:
            ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
            # The type of the tracker's first state decides whether the graph-safe RNG
            # API is used; an empty tracker means "not graph-safe".
            ctx.graph_safe_rng_state = (
                is_graph_safe_rng_state(next(iter(ctx.fwd_cuda_rng_state_tracker.values())))
                if ctx.fwd_cuda_rng_state_tracker
                else False
            )
        else:
            ctx.graph_safe_rng_state = False
        # NOTE(review): on the graph-safe path `_get_cuda_rng_state` returns a reference
        # to the live generator state (graphsafe_get_state), not a clone. Confirm the
        # forward pass below does not advance the state that backward restores.
        ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=ctx.graph_safe_rng_state)

        if context_fn is not None:
            forward_ctx, recompute_ctx = context_fn()
        else:
            forward_ctx, recompute_ctx = noop_context_fn()

        # Preserve torch autocast context for the backward pass
        torch_gpu_amp_ctx, torch_cpu_amp_ctx = _get_active_autocast_contexts()

        # No autograd graph is recorded here; backward rebuilds it by recomputation.
        with torch.no_grad(), forward_ctx:
            with activation_recompute_forward(activation_recompute=True, recompute_phase=False):
                outputs = run_function(*args, **kwargs)

        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
        if distribute_saved_activations:
            ctx.input_0_shape = args[0].data.shape
            safely_set_viewless_tensor_data(
                args[0],
                split_tensor_into_1d_equal_chunks(args[0].data, tp_group, new_buffer=True),
            )

        # Store everything.
        # Non-tensor args live on ctx; tensor args go through save_for_backward. Each
        # list holds None in the positions occupied by the other kind.
        ctx.inputs = [arg if not torch.is_tensor(arg) else None for arg in args]
        tensor_inputs = [arg if torch.is_tensor(arg) else None for arg in args]
        ctx.save_for_backward(*tensor_inputs)

        fp8 = FP8GlobalStateManager.is_fp8_enabled()
        ctx.get_rng_state_tracker = get_rng_state_tracker
        ctx.tp_group = tp_group
        ctx.recompute_ctx = recompute_ctx
        ctx.torch_gpu_amp_ctx = torch_gpu_amp_ctx
        ctx.torch_cpu_amp_ctx = torch_cpu_amp_ctx
        ctx.fp8 = fp8
        ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
        ctx.kwargs = kwargs

        return outputs

    @staticmethod
    def backward(
        ctx, *args: Tuple[Union[torch.Tensor, None], ...]
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Call backward function with activation recomputation.

        ``args`` are the gradients with respect to the forward outputs. The return
        value starts with one ``None`` for each of the six non-``*args`` forward
        parameters. It is followed by the ``.grad`` of each recomputed tensor input,
        or ``None`` for non-tensor inputs.

        Raises
        ------
        RuntimeError
            If checkpointing is not valid for this backward call (the message points
            at ``.grad()``), or if none of the recomputed outputs requires grad.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), please use .backward() if possible"
            )

        # Re-merge tensor inputs (saved_tensors) with non-tensor inputs (ctx.inputs).
        inputs = tuple(
            t if t is not None else arg for (t, arg) in zip(ctx.saved_tensors, ctx.inputs)
        )

        get_rng_state_tracker = ctx.get_rng_state_tracker

        # Undo the forward-pass split of the first input by all-gathering its chunks.
        if ctx.distribute_saved_activations:
            safely_set_viewless_tensor_data(
                inputs[0],
                gather_split_1d_tensor(inputs[0].data, ctx.tp_group).view(ctx.input_0_shape),
            )

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=ctx.graph_safe_rng_state)
        if get_rng_state_tracker is not None:
            bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()

        # Set the states to what it used to be before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=ctx.graph_safe_rng_state)
        if get_rng_state_tracker is not None:
            get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        # Compute the forward pass.
        # The rerun happens under the user's recompute context, the autocast state
        # captured in forward, the recompute-phase marker, and the FP8 settings saved
        # in forward.
        # NOTE(review): if the recompute raises, the backward RNG states below are not
        # restored (there is no try/finally).
        detached_inputs = detach_variable(inputs)
        with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward(
            activation_recompute=True, recompute_phase=True
        ), autocast(
            enabled=ctx.fp8, recipe=ctx.fp8_recipe
        ):
            outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)

        # Set the states back to what it was at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=ctx.graph_safe_rng_state)
        if get_rng_state_tracker is not None:
            get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # Pair every differentiable recomputed output with its incoming gradient.
        outputs_with_grad = []
        args_with_grad = []
        for i, output in enumerate(outputs):
            if torch.is_tensor(output) and output.requires_grad:
                outputs_with_grad.append(output)
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True, this checkpoint() is not necessary"
            )

        # backward does not require entering autocast context because
        # backward implementations already retrieve fp8 recipe and
        # enablement from stored ctx.
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs
        )
        # One None per non-*args forward parameter: run_function,
        # distribute_saved_activations, get_rng_state_tracker, tp_group, context_fn, kwargs.
        return (None, None, None, None, None, None) + grads
class _CheckpointFrame:
    """
    Storage frame for forward RNG states and detached activations from the forward recompute.

    Used by the non-reentrant `checkpoint` path:

    - ``recompute_fn`` re-runs the forward pass.
    - ``recomputed`` collects the activations saved during that rerun.
    - ``count`` numbers the placeholders handed to autograd in the original forward.
    """

    def __init__(self, recompute_fn: Callable, get_rng_state_tracker: Callable):
        self.recompute_fn = recompute_fn
        self.recomputed = []
        self.count = 0
        self.get_rng_state_tracker = get_rng_state_tracker
        self.fwd_rng_states = None
        self.bwd_rng_states = None
        # Whether the most recently cached CUDA RNG state came from the graph-safe
        # API. Initialized here so the attribute exists before any caching.
        self.graph_safe_rng_state = False
        # Per-direction flags so that each cached state is restored with the same API
        # (graph-safe or not) that captured it, even if forward and backward differ.
        self._fwd_graph_safe = False
        self._bwd_graph_safe = False

    def cache_rng_states(self, forward=True):
        """Cache fwd/bwd RNG states in the frame to restore later.

        The cached tuple is ``(cpu_state, cuda_state)``, or
        ``(cpu_state, cuda_state, tracker_states)`` when an RNG tracker is present.
        """
        rng_states = (torch.get_rng_state(),)
        if self.get_rng_state_tracker is not None:
            tracker_states = self.get_rng_state_tracker().get_states()
            # The type of the tracker's first state decides whether the graph-safe
            # API is used; an empty tracker means "not graph-safe".
            graph_safe = (
                is_graph_safe_rng_state(next(iter(tracker_states.values())))
                if tracker_states
                else False
            )
            rng_states += (
                _get_cuda_rng_state(graph_safe=graph_safe),
                tracker_states,
            )
        else:
            graph_safe = False
            rng_states += (_get_cuda_rng_state(graph_safe=graph_safe),)

        self.graph_safe_rng_state = graph_safe
        if forward:
            self.fwd_rng_states = rng_states
            self._fwd_graph_safe = graph_safe
        else:
            self.bwd_rng_states = rng_states
            self._bwd_graph_safe = graph_safe

    def restore_rng_states(self, forward=True):
        """Restore fwd/bwd RNG states that were previously cached into the frame.

        Raises
        ------
        RuntimeError
            If the requested states were never cached with `cache_rng_states`.
        """
        if forward:
            rng_states, graph_safe = self.fwd_rng_states, self._fwd_graph_safe
        else:
            rng_states, graph_safe = self.bwd_rng_states, self._bwd_graph_safe
        if rng_states is None:
            direction = "forward" if forward else "backward"
            raise RuntimeError(
                f"No {direction} RNG states are cached in this checkpoint frame; "
                "call cache_rng_states() first."
            )

        torch.set_rng_state(rng_states[0])
        _set_cuda_rng_state(rng_states[1], graph_safe=graph_safe)
        if self.get_rng_state_tracker is not None:
            self.get_rng_state_tracker().set_states(rng_states[2])
class _recomputation_hook(
    torch.autograd.graph.saved_tensors_hooks
):  # pylint: disable=too-few-public-methods
    """torch.autograd hook for packing/unpacking tensors during the activation recompute phase.

    While active, every tensor that autograd saves is detached and appended to
    ``frame.recomputed``, in the order in which the tensors are saved.
    """

    def __init__(self, frame):
        def _record_detached(saved):
            """Append a detached alias of each saved tensor to the frame's storage."""
            frame.recomputed.append(saved.detach())
            return saved.detach()

        def _identity(packed):
            """Identity unpack; backward is never run through the recomputed graph."""
            return packed

        super().__init__(_record_detached, _identity)
class _checkpoint_hook(
    torch.autograd.graph.saved_tensors_hooks
):  # pylint: disable=too-few-public-methods
    """torch.autograd hook for packing/unpacking tensors during the checkpointed forward pass.

    Parameters
    ----------
    frame : _CheckpointFrame
        Holds the recompute function, the cached RNG states and the recomputed tensors.
    args, kwargs
        Inputs passed again to ``frame.recompute_fn`` when recomputation is triggered.
    """

    def __init__(self, frame, args, kwargs):
        def pack_hook(x):
            """
            Packing hook for each tensor passed into `ctx.save_for_backward()` call in the
            forward pass. Since this is the first forward pass, we discard the tensor and instead
            pack a placeholder tensor index into the autograd engine context.
            """
            del x
            # The placeholder is this tensor's position in the sequence of saved tensors.
            idx = frame.count
            frame.count += 1
            return idx

        def unpack_hook(idx):
            """
            Unpacking hook for each tensor that comes out of the `ctx.saved_tensors` call in the
            backward pass. The first time this is called, the _recomputation_hook will save all the
            activation tensors from `ctx.save_for_backward()` in the forward recomputation into the
            _CheckpointFrame. Subsequent calls will simply return the already recomputed activation
            tensor at the given index of the _CheckpointFrame storage.

            This relies on the recomputation saving tensors in the same order (and number) as
            the original forward pass, so that ``idx`` addresses the matching tensor.
            """
            if not frame.recomputed:
                # Store current RNG states in the backward pass
                frame.cache_rng_states(forward=False)

                # Set RNG states to what we saved before the forward pass
                frame.restore_rng_states(forward=True)

                # Recompute the forward pass
                with _recomputation_hook(frame):
                    frame.recompute_fn(*args, **kwargs)

                # Restore RNG states back to the backward pass
                frame.restore_rng_states(forward=False)

            # Return the already recomputed activation tensor at the given index
            activation = frame.recomputed[idx]
            # Drop the frame's reference to the tensor (presumably so it can be freed once
            # autograd no longer needs it).
            # NOTE(review): a second unpack of the same index, e.g. when backward runs twice
            # with retain_graph=True, would return None. Confirm that case is unsupported.
            frame.recomputed[idx] = None
            return activation

        super().__init__(pack_hook, unpack_hook)
def use_reentrant_activation_recompute():
    """Tell whether activation recompute uses the 'reentrant' method.

    Reflects the ``use_reentrant`` keyword (default ``True``) consumed by the most
    recent `checkpoint()` call; it is ``True`` before any call.
    """
    reentrant = _USE_REENTRANT_ACTIVATION_RECOMPUTE
    return reentrant
def get_activation_recompute_contexts():
    """Create the TE context managers for a checkpointed forward and its recompute.

    Returns
    -------
    tuple
        ``(forward_ctx, recompute_ctx)``. These are two new
        `activation_recompute_forward` instances with ``activation_recompute=True``.
        The first uses ``recompute_phase=False`` (original forward) and the second
        ``recompute_phase=True`` (recomputation in backward).
    """
    forward_ctx, recompute_ctx = (
        activation_recompute_forward(activation_recompute=True, recompute_phase=phase)
        for phase in (False, True)
    )
    return forward_ctx, recompute_ctx
def has_te_modules(network):
    """
    Check if there are any Transformer Engine modules in the network.

    Parameters
    ----------
    network : Any
        Object passed as ``function`` to `checkpoint`.

    Returns
    -------
    bool
        For a ``torch.nn.Module``: whether the module itself or any submodule is an
        instance of one of the Transformer Engine classes imported below. For any
        other object: always ``True``.
    """
    from .module import LayerNorm, RMSNorm
    from .module.base import TransformerEngineBaseModule
    from .attention.dot_product_attention.backends import UnfusedDotProductAttention
    from .attention.dot_product_attention.dot_product_attention import DotProductAttention
    from .attention.multi_head_attention import MultiheadAttention
    from .transformer import TransformerLayer

    te_classes = (
        LayerNorm,
        RMSNorm,
        TransformerEngineBaseModule,
        UnfusedDotProductAttention,
        DotProductAttention,
        MultiheadAttention,
        TransformerLayer,
    )

    if not isinstance(network, torch.nn.Module):
        # Cannot check for TE modules inside a custom class/callable that's not a
        # torch.nn.Module, so just assume that it has TE modules just to be safe.
        return True

    return any(isinstance(submodule, te_classes) for submodule in network.modules())
@torch._disable_dynamo
def checkpoint(
    function: Callable,
    *args: Tuple[torch.Tensor, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    """
    Checkpoint a part of the model by trading compute for memory. This function is based on
    `torch.utils.checkpoint.checkpoint <https://pytorch.org/docs/stable/checkpoint.html>`_.

    .. warning::

        It is the user's responsibility to ensure identical behavior when calling
        :attr:`function` from the forward and backward pass. If different output is
        produced (e.g. due to global state), then the checkpointed version won't
        be numerically equivalent.

    .. warning::

        `use_reentrant=False` does not support early stopping, and will execute the entire forward
        pass for the checkpointed module when recomputing activations in the backward pass.

    If :attr:`function` is a ``torch.nn.Module`` that contains no Transformer Engine
    modules, the call is delegated to ``torch.utils.checkpoint.checkpoint``.

    Parameters
    ----------
    function : Callable
        pytorch module used to run the forward and backward passes using
        the specified :attr:`args` and :attr:`kwargs`.
    distribute_saved_activations : bool, default = False
        if set to ``True`` and ``use_reentrant=True``, first tensor argument is distributed
        across the specified tensor parallel group (``tp_group``) before saving it for the
        backward pass. This has no effect when ``use_reentrant=False``.
    get_rng_state_tracker : Callable, default = None
        python callable which returns an instance of :class:`CudaRNGStatesTracker`.
    tp_group : ProcessGroup, default = None
        tensor parallel process group. Used only when ``distribute_saved_activations=True``
        and ``use_reentrant=True``. If ``None``, it falls back to the default group.
    use_reentrant : bool, default = True
        perform checkpointing in reentrant mode.
    context_fn : Callable, default = torch.utils.checkpoint.noop_context_fn
        callable returning a ``(forward_ctx, recompute_ctx)`` pair of context managers,
        entered around the original forward pass and the recomputation respectively.
    determinism_check : str, default = "default"
        forwarded to ``torch.utils.checkpoint.checkpoint`` on the native PyTorch path;
        ignored otherwise.
    debug : bool, default = False
        forwarded to ``torch.utils.checkpoint.checkpoint`` on the native PyTorch path;
        ignored otherwise.
    args : tuple
        tuple of torch tensors for inputs to :attr:`function`.
    kwargs : dict
        dictionary of string keys for keyword arguments to :attr:`function`.

    Raises
    ------
    RuntimeError
        If ``distribute_saved_activations=True`` is combined with ``use_reentrant=True``
        while ``torch.distributed`` is not initialized.
    """
    # Pop out te.distributed.checkpoint() arguments
    global _USE_REENTRANT_ACTIVATION_RECOMPUTE
    _USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True)
    distribute_saved_activations = kwargs.pop("distribute_saved_activations", False)
    tp_group = kwargs.pop("tp_group", None)
    get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None)

    # Ensure backward compatibility with the legacy positional form
    # checkpoint(function, distribute_saved_activations, get_rng_state_tracker, tp_group, *args).
    if (
        len(args) > 3
        and isinstance(args[0], bool)
        and callable(args[1])
        # Spelled out instead of ``isinstance(args[2], None | dist_group_type)``,
        # which raises TypeError on Python versions older than 3.10.
        and (args[2] is None or isinstance(args[2], dist_group_type))
    ):
        warnings.warn(
            "Passing non-tensor non-keyword arguments is deprecated and support will be removed in "
            "future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and "
            "`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.",
            DeprecationWarning,
            stacklevel=2,
        )
        distribute_saved_activations = args[0]
        get_rng_state_tracker = args[1]
        tp_group = args[2]
        args = args[3:]

    # Trigger the native PyTorch checkpoint if the function is not or does not contain a
    # Transformer Engine module.
    context_fn = kwargs.pop("context_fn", noop_context_fn)
    determinism_check = kwargs.pop("determinism_check", "default")
    debug = kwargs.pop("debug", False)
    if not has_te_modules(function):
        return torch.utils.checkpoint.checkpoint(
            function,
            *args,
            use_reentrant=_USE_REENTRANT_ACTIVATION_RECOMPUTE,
            context_fn=context_fn,
            determinism_check=determinism_check,
            debug=debug,
            **kwargs,
        )

    from .module.base import TransformerEngineBaseModule

    if isinstance(function, TransformerEngineBaseModule):
        # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
        # to scatter/gather activations that we will recompute anyway.
        function.fast_setattr("fsdp_wrapped", False)
        function.fast_setattr("fsdp_group", None)

    # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
    # and execute TE's own checkpointing
    # NOTE: This logic uses the TE checkpoint on all custom callable `function` handles because we
    #       cannot be sure there are no TE modules inside the function. It also means we might run
    #       the TE checkpoint for non-TE modules, so the TE checkpoint has to support a potential
    #       user context function.
    del determinism_check, debug

    if _USE_REENTRANT_ACTIVATION_RECOMPUTE:
        # If saved activations need to be distributed but there is no process group,
        # default to the world group.
        if distribute_saved_activations:
            if not torch.distributed.is_initialized():
                raise RuntimeError(
                    "torch.distributed is not initialized. Call "
                    "torch.distributed.init_process_group() before using "
                    "distribute_saved_activations=True."
                )
            tp_group = torch.distributed.GroupMember.WORLD if tp_group is None else tp_group

        return _CheckpointFunction.apply(
            function,
            distribute_saved_activations,
            get_rng_state_tracker,
            tp_group,
            context_fn,
            kwargs,
            *args,
        )

    if distribute_saved_activations:
        warnings.warn(
            "`distribute_saved_activations=True` has no effect when `use_reentrant=False`. "
            "The non-reentrant checkpoint implementation does not manually store forward "
            "inputs for the activation recompute in the backward pass, and instead leverages "
            "the autograd engine's pack/unpack hooks."
        )

    user_forward_ctx, user_recompute_ctx = context_fn()
    te_forward_ctx, te_recompute_ctx = get_activation_recompute_contexts()

    # Preserve the torch autocast contexts from the forward pass during recompute phase.
    torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx = _get_active_autocast_contexts()

    # Capture the forward pass's FP8 settings so that the recompute uses the same ones.
    fp8 = FP8GlobalStateManager.is_fp8_enabled()
    fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None

    def recompute_fn(*args, **kwargs):
        with torch.autograd.enable_grad(), (
            te_recompute_ctx
        ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx, autocast(
            enabled=fp8, recipe=fp8_recipe
        ):
            function(*args, **kwargs)

    # Initialize a new checkpoint frame for each new forward pass.
    new_frame = _CheckpointFrame(
        recompute_fn,
        get_rng_state_tracker,
    )
    new_frame.cache_rng_states(forward=True)

    with _checkpoint_hook(new_frame, args, kwargs), te_forward_ctx, user_forward_ctx:
        out = function(*args, **kwargs)

    return out
class CudaRNGStatesTracker:
    """
    For model parallelism, multiple RNG states need to simultaneously exist in order
    to execute operations in or out of the model parallel region. This class keeps
    track of the various RNG states and provides utility methods to maintain them and
    execute parts of the model under a given RNG setting. Using the :meth:`add` method, a
    cuda rng state is initialized based on the input ``seed`` and is assigned to ``name``.
    Later, by forking the rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self):
        # Map from a string name to the cuda rng state.
        self.states_ = {}
        # Seeds are just for book keeping and ensure no seed is set twice.
        self.seeds_ = set()

    def reset(self):
        """
        Set to the initial state (no tracker).
        """
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self) -> Dict[str, torch.Tensor]:
        """
        Get rng states. Copy the dictionary so we have direct pointers
        to the states, not just a pointer to the dictionary.

        Returns
        -------
        Dict[str, torch.Tensor]
            Shallow copy of the name -> state mapping. The states themselves are
            shared with the tracker.
        """
        return dict(self.states_)

    def set_states(self, states: Dict[str, torch.Tensor]) -> None:
        """
        Set the rng states. For efficiency purposes, we do not
        check the size of seed for compatibility.

        Parameters
        ----------
        states : Dict[str, torch.Tensor]
            A mapping from string names to RNG states. Stored by reference.
        """
        self.states_ = states
        # Update global states.
        set_all_rng_states(self.states_)

    def add(self, name: str, seed: int) -> None:
        """
        Adds a new RNG state.

        Parameters
        ----------
        name : str
            string identifier for the RNG state.
        seed : int
            PyTorch seed for the RNG state.

        Raises
        ------
        RuntimeError
            If ``seed`` or ``name`` is already registered. The tracker is left
            unchanged in that case.
        """
        # Validate both seed and name before touching any bookkeeping, so a rejected
        # call does not permanently consume the seed.
        if seed in self.seeds_:
            raise RuntimeError(f"seed {seed} already exists")
        if name in self.states_:
            raise RuntimeError(f"cuda rng state {name} already exists")

        if graph_safe_rng_available():
            # Seed an independent clone of the current generator state.
            new_state = _get_cuda_rng_state(clone=True)
            new_state.manual_seed(seed)
        else:
            # Get the current rng state.
            orig_rng_state = _get_cuda_rng_state()
            try:
                # Seed the global CUDA generator and snapshot the resulting state.
                torch.cuda.manual_seed(seed)
                new_state = _get_cuda_rng_state(clone=True)
            finally:
                # Reset rng state to what it was, even if seeding failed.
                _set_cuda_rng_state(orig_rng_state)

        self.seeds_.add(seed)
        self.states_[name] = new_state
        # Update global states.
        set_all_rng_states(self.states_)

    @contextmanager
    def fork(self, name: str = "model-parallel-rng"):
        """
        Fork the cuda rng state, perform operations, and exit with
        the original state.

        Parameters
        ----------
        name : str
            string identifier for the RNG state.

        Raises
        ------
        KeyError
            If no state called ``name`` has been added.
        """
        # Check if we have added the state
        if name not in self.states_:
            raise KeyError(f"cuda rng state {name} is not added")
        # Get the reference to current rng state.
        orig_cuda_rng_state = _get_cuda_rng_state()
        # Set rng state to the desired one
        _set_cuda_rng_state(self.states_[name])
        # Do the stuff we wanted to do.
        try:
            yield
        finally:
            # Save the advanced state back. Per the original note, this is redundant
            # with the graph-safe API, where the tracked state object itself is installed.
            if not graph_safe_rng_available():
                self.states_[name] = _get_cuda_rng_state()
            # And set the state to the original state we started with.
            _set_cuda_rng_state(orig_cuda_rng_state)
def reduce_scatter_along_first_dim(
    inp: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]:
    """Reduce-scatter the input tensor across model parallel group.

    The first dimension of ``inp`` is split evenly across the ranks of
    ``tp_group``. Each rank gets back its reduced shard, allocated on the
    current CUDA device with ``inp``'s dtype.

    Returns a ``(shard, handle)`` pair. ``handle`` is whatever
    ``torch.distributed.reduce_scatter_tensor`` returns for the given
    ``async_op``. With a single rank the input is returned unchanged
    together with ``None``.

    Raises ``ValueError`` if the first dimension is not divisible by the
    group size.
    """
    group_size = get_distributed_world_size(tp_group)
    # Nothing to scatter when the group has only one rank.
    if group_size == 1:
        return inp, None

    shard_shape = list(inp.size())
    leading = shard_shape[0]
    shard_rows, leftover = divmod(leading, group_size)
    if leftover:
        raise ValueError(
            "First dimension of the tensor should be divisible by tensor parallel size, "
            f"but got dim_size[0]={leading} and world_size={group_size} "
            f"(remainder={leftover})."
        )
    shard_shape[0] = shard_rows

    shard = torch.empty(shard_shape, dtype=inp.dtype, device=torch.cuda.current_device())
    work = torch.distributed.reduce_scatter_tensor(
        shard, inp.contiguous(), group=tp_group, async_op=async_op
    )
    return shard, work
@dataclass
class _AsyncHandle:
    """Handle for asynchronous collectives.

    Wraps a ``torch.distributed.Work`` object together with an optional
    callback that runs once the communication has completed.
    """

    # Work object returned by an asynchronous collective.
    async_handle: torch.distributed.Work
    # Optional callback executed after ``async_handle.wait()`` returns.
    post_process_function: Optional[Callable] = None
    # Positional arguments for the callback (``None`` means no arguments).
    post_process_function_args: Optional[Tuple[Any, ...]] = None
    # Keyword arguments for the callback (``None`` means no keyword arguments).
    post_process_function_kwargs: Optional[Dict[str, Any]] = None
    # Becomes True once ``wait`` has completed; later calls are no-ops.
    _synchronized: bool = False

    def wait(self) -> None:
        """Synchronize the asynchronous communication.

        Blocks on the wrapped work object, then runs the post-processing
        callback if one was given. Once a call has completed, later calls
        return immediately.
        """
        if not self._synchronized:
            self.async_handle.wait()
            callback = self.post_process_function
            if callback is not None:
                positional = self.post_process_function_args
                keyword = self.post_process_function_kwargs
                callback(
                    *(() if positional is None else positional),
                    **({} if keyword is None else keyword),
                )
            self._synchronized = True
def _all_gather_fp8(
inp: torch.Tensor,
process_group: dist_group_type,
*,
async_op: bool = False,
quantizer: Optional[Quantizer] = None,
out_shape: Optional[list[int]] = None,
) -> tuple[Float8TensorStorage, Optional[torch.distributed.Work]]:
"""All-gather FP8 tensor along first dimension."""
world_size = get_distributed_world_size(process_group)
# Check that quantizer is valid
if quantizer is not None and not isinstance(
quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
raise ValueError(f"Got non-FP8 quantizer ({quantizer.__class__.__name__})")
# Output tensor dims
if out_shape is None:
out_shape = list(inp.size())
out_shape[0] *= world_size
# Cast input tensor to FP8 if needed
# Note: We cannot directly all-gather the transposed FP8 tensor,
# so temporarily modify quantizer to avoid creating FP8 transpose.