From 086cebaff5054e935e181c2373a69c06a3c26263 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Thu, 25 Jun 2026 15:25:31 -0600 Subject: [PATCH 01/11] feat: add trainer rank api --- dev/trainer_rank.py | 114 + dev/trainer_rank_fast_check.py | 25 + dev/trainer_rank_parity_probe.py | 539 +++ dev/trainer_rank_perf.py | 2906 +++++++++++++++++ dev/trainer_rank_topology_check.py | 1238 +++++++ src/art/trainer_rank/__init__.py | 2199 +++++++++++++ src/art/trainer_rank/topk.py | 283 ++ .../megatron/lora/test_dynamic_lora_slots.py | 198 ++ tests/unit/test_trainer_rank_validation.py | 448 +++ tests/unit/test_trainer_rank_weird_shapes.py | 501 +++ 10 files changed, 8451 insertions(+) create mode 100644 dev/trainer_rank.py create mode 100644 dev/trainer_rank_fast_check.py create mode 100644 dev/trainer_rank_parity_probe.py create mode 100644 dev/trainer_rank_perf.py create mode 100644 dev/trainer_rank_topology_check.py create mode 100644 src/art/trainer_rank/__init__.py create mode 100644 src/art/trainer_rank/topk.py create mode 100644 tests/integration/megatron/lora/test_dynamic_lora_slots.py create mode 100644 tests/unit/test_trainer_rank_validation.py create mode 100644 tests/unit/test_trainer_rank_weird_shapes.py diff --git a/dev/trainer_rank.py b/dev/trainer_rank.py new file mode 100644 index 000000000..14934d753 --- /dev/null +++ b/dev/trainer_rank.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import os + +import torch +import torch.distributed as dist +from transformers import AutoTokenizer +import typer + +from art.trainer_rank import AdamParams, ForwardInput, TrainerRank + + +def main( + model: str = "Qwen/Qwen3-0.6B", + dataset: str = "roneneldan/TinyStories", + split: str = "train", + text_column: str = "text", + samples: int = 16, + steps: int = 1, + lr: float = 5e-5, + layers: int = 2, + max_seq_length: int = 256, +) -> None: + os.environ.setdefault("ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE", "1") + os.environ.setdefault("ART_MEGATRON_CONTEXT_PARALLEL_SIZE", "1") + os.environ.setdefault("ART_MEGATRON_PIPELINE_MODEL_PARALLEL_SIZE", "1") + + if not torch.cuda.is_available(): + raise RuntimeError("dev/trainer_rank.py requires CUDA") + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + dist.init_process_group(backend="nccl") + + try: + from datasets import load_dataset + + from art.megatron import train as megatron_train + + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + inputs: list[ForwardInput[torch.Tensor, None, None, None]] = [] + for row in load_dataset(dataset, split=split, streaming=True): + text = str(row.get(text_column, "")).strip() # type: ignore[union-attr] + if not text: + continue + token_ids = tokenizer( + text, + add_special_tokens=True, + truncation=True, + max_length=max_seq_length + 1, + return_tensors="pt", + )["input_ids"].reshape(-1) + if int(token_ids.numel()) <= 1: + continue + inputs.append( + ForwardInput( + input_tokens=token_ids[:-1], + target_tokens=token_ids[1:], + ) + ) + if len(inputs) >= samples: + break + if not inputs: + raise RuntimeError("dataset produced no tokenized training examples") + + runtime = megatron_train.build_training_runtime( + model_identifier=model, + provider_configure=lambda provider: setattr( + provider, + "num_layers", + layers, + ), + print_env=dist.get_rank() == 0, + ) + rank = TrainerRank(runtime) + if dist.get_rank() == 0: + print( + "TrainerRank ready: " + f"dp={megatron_train.ps.get_data_parallel_world_size()} " + f"device={rank.device}", + flush=True, + ) + + for step in range(steps): + loss_sum = torch.tensor(0.0, device=rank.device) + token_count = torch.tensor(0.0, device=rank.device) + for micro in rank.forward_micro_batches(inputs): + loss = torch.tensor(0.0, device=rank.device) + for output in micro.outputs: + assert output.target_logprobs is not None + loss = loss - output.target_logprobs.sum() + token_count += output.target_logprobs.numel() + if loss.requires_grad: + loss.backward() + loss_sum += loss.detach() + + rank.dp_reduce(loss_sum) + rank.dp_reduce(token_count) + scale = 1.0 / max(float(token_count.item()), 1.0) + metrics = rank.optim_step( + params=AdamParams(learning_rate=lr), + scale_grads=scale, + ) + metrics["loss"] = float(loss_sum.item() * scale) + metrics["tokens"] = float(token_count.item()) + if dist.get_rank() == 0: + print(f"step={step} {metrics}", flush=True) + + dist.barrier() + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +if __name__ == "__main__": + typer.run(main) diff --git a/dev/trainer_rank_fast_check.py b/dev/trainer_rank_fast_check.py new file mode 100644 index 000000000..51372d7d8 --- /dev/null +++ b/dev/trainer_rank_fast_check.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import subprocess +import sys + +FAST_TESTS = ( + "tests/unit/test_trainer_rank_validation.py", + "tests/unit/test_trainer_rank_weird_shapes.py", + "tests/unit/test_shared_prefix_packing.py", + "tests/unit/test_shared_prefix_tree.py", + "tests/unit/test_shared_prefix_attention_builder.py", + "tests/unit/test_shared_prefix_grad_parity.py", +) + + +def main() -> None: + raise SystemExit( + subprocess.call( + [sys.executable, "-m", "pytest", "--tb=short", *FAST_TESTS, *sys.argv[1:]] + ) + ) + + +if __name__ == "__main__": + main() diff --git a/dev/trainer_rank_parity_probe.py b/dev/trainer_rank_parity_probe.py new file mode 100644 index 000000000..1640512f2 --- /dev/null +++ b/dev/trainer_rank_parity_probe.py @@ -0,0 +1,539 @@ +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +import json +import os +import re +from typing import Any, cast + +import torch +import torch.distributed as dist +import typer + +from art.megatron.shared_prefix_packing import SharedPrefixPack, pack_shared_prefixes +from art.trainer_rank import ( + AnyForwardInput, + TrainerRank, + _batch_seq_logits, + _language_model, +) + + +@dataclass(frozen=True) +class _Capture: + values: dict[str, torch.Tensor] + positions_by_item: tuple[torch.Tensor, ...] + source_positions_by_item: tuple[torch.Tensor, ...] + + +def main( + model: str = "Qwen/Qwen3-0.6B", + layers: int = 1, + sequences: int = 6, + sequence_length: int = 7, + compare_requests: int = 6, + request_shape: str = "varied", + oracle: str = "independent", + max_depth: int = 1, +) -> None: + os.environ.setdefault("ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE", "1") + os.environ.setdefault("ART_MEGATRON_CONTEXT_PARALLEL_SIZE", "1") + os.environ.setdefault("ART_MEGATRON_PIPELINE_MODEL_PARALLEL_SIZE", "1") + + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + dist.init_process_group(backend="nccl") + try: + from megatron.core import parallel_state as ps + + from art.megatron import train as megatron_train + + torch.manual_seed(1234) + runtime = megatron_train.build_training_runtime( + model_identifier=model, + provider_configure=lambda provider: setattr( + provider, + "num_layers", + layers, + ), + print_env=dist.get_rank() == 0, + ) + if int(ps.get_tensor_model_parallel_world_size()) != 1: + raise RuntimeError("trainer_rank_parity_probe currently expects TP=1") + for chunk in runtime.model: + chunk.eval() + + rank = TrainerRank(runtime, shared_prefix_max_depth=max_depth) + requests = _unique_requests( + sequences=sequences, + sequence_length=sequence_length, + request_shape=request_shape, + ) + request_count = min(compare_requests, len(requests)) + + with torch.no_grad(): + packed = _run_capture(rank, requests) + records = _records_from_capture( + kind="packed", + capture=packed, + request_indices=range(len(requests)), + cp_rank=int(ps.get_context_parallel_rank()), + dp_rank=int(ps.get_data_parallel_rank()), + ) + for request_index, request in enumerate(requests): + if oracle == "independent": + oracle_capture = _run_capture(rank, [request]) + oracle_request_indices = (request_index,) + oracle_local_indices = None + elif oracle == "same-layout": + oracle_capture = _run_capture( + rank, + requests, + mutate_except=request_index, + ) + oracle_request_indices = range(len(requests)) + oracle_local_indices = (request_index,) + else: + raise ValueError("oracle must be 'independent' or 'same-layout'") + records.extend( + _records_from_capture( + kind="independent", + capture=oracle_capture, + request_indices=oracle_request_indices, + cp_rank=int(ps.get_context_parallel_rank()), + dp_rank=int(ps.get_data_parallel_rank()), + local_indices=oracle_local_indices, + ) + ) + + gathered: list[list[dict[str, object]] | None] = [None] * dist.get_world_size() + dist.all_gather_object(gathered, records) + if dist.get_rank() == 0: + flat_records = [ + record for rank_records in gathered for record in rank_records or [] + ] + report = _build_report( + records=flat_records, + requests=requests[:request_count], + topology={ + "world": dist.get_world_size(), + "dp": int(ps.get_data_parallel_world_size()), + "tp": int(ps.get_tensor_model_parallel_world_size()), + "cp": int(ps.get_context_parallel_world_size()), + }, + oracle=oracle, + ) + print(json.dumps(report, sort_keys=True), flush=True) + dist.barrier() + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +def _unique_requests( + *, + sequences: int, + sequence_length: int, + request_shape: str, +) -> list[AnyForwardInput]: + from art.trainer_rank import ForwardInput + + if sequences < 1 or sequence_length < 2: + raise ValueError("sequences must be >= 1 and sequence_length must be >= 2") + if request_shape == "varied": + base_rows = ( + (11, 12, 13, 14, 15, 16, 17), + (11, 12, 13, 14, 24, 25), + (11, 12, 13, 14, 24, 26), + (11, 12, 13, 27), + (31, 32, 33, 34), + (31, 32, 33, 35), + (11, 12, 13, 14, 15, 16, 17), + (41, 42, 43), + (41, 42, 44, 45), + (51, 52, 53, 54, 55), + (61, 62, 63), + (61, 62, 64, 65), + (71, 72), + (81, 82, 83, 84), + (91, 92, 93), + (101, 102, 103, 104, 105), + ) + return [ + ForwardInput( + input_tokens=torch.tensor(row, dtype=torch.long) + 1000 * index + ) + for index, row in enumerate(base_rows[:sequences]) + ] + if request_shape == "deep": + base_rows = ( + (11, 12, 13, 14, 15, 16, 17), + (11, 12, 13, 14, 15, 16, 18), + (11, 12, 13, 14, 15, 19), + (11, 12, 13, 14, 20), + (11, 12, 21), + (31, 32, 33, 34, 35), + (31, 32, 33, 34, 36), + (31, 32, 33, 37), + (41, 42, 43), + (41, 42, 44), + (51, 52, 53, 54), + (61, 62), + (71, 72, 73, 74, 75), + (71, 72, 73, 76), + (81,), + (91, 92, 93), + ) + return [ + ForwardInput(input_tokens=torch.tensor(row, dtype=torch.long)) + for row in base_rows[:sequences] + ] + if request_shape != "equal": + raise ValueError("request_shape must be 'equal', 'varied', or 'deep'") + return [ + ForwardInput( + input_tokens=torch.arange( + 1000 * index + 11, + 1000 * index + 11 + sequence_length, + dtype=torch.long, + ) + ) + for index in range(sequences) + ] + + +def _run_capture( + rank: TrainerRank, + requests: Sequence[AnyForwardInput], + *, + mutate_except: int | None = None, +) -> _Capture: + from art.megatron.train import _placeholder_attention_mask + + model = _language_model(rank.runtime.model[0]) + items = [rank._forward_item(request) for request in requests] + batch = pack_shared_prefixes( + (item.input_ids for item in items), + max_depth=rank.shared_prefix_max_depth, + ) + if mutate_except is not None: + batch = _mutated_batch( + batch, keep_positions=batch.positions_by_sequence[mutate_except] + ) + prepared = rank._prepare_packed_forward(batch) + local_seq_len = int(prepared.tokens.shape[1]) + values: dict[str, torch.Tensor] = {} + handles = _register_hooks(model, values, seq_len=local_seq_len) + try: + handler = rank.runtime.model_support_handler + forward_kwargs = handler.get_forward_kwargs( + rank.runtime.model[0], + attention_bias=prepared.attention_state, + ) + extra_block_kwargs = cast( + dict[str, object] | None, + forward_kwargs.pop("extra_block_kwargs", None), + ) + preprocessed = model._preprocess( + input_ids=prepared.tokens, + position_ids=prepared.position_ids, + packed_seq_params=prepared.packed_seq_params, + ) + values["00.preprocess.decoder_input"] = _rows( + cast(torch.Tensor, preprocessed[0]).detach(), + seq_len=local_seq_len, + ) + hidden = cast( + torch.Tensor, + model.decoder( + hidden_states=preprocessed[0], + attention_mask=_placeholder_attention_mask(rank.device), + rotary_pos_emb=preprocessed[1], + rotary_pos_cos=preprocessed[2], + rotary_pos_sin=preprocessed[3], + rotary_pos_cos_sin=preprocessed[6] if len(preprocessed) == 7 else None, + packed_seq_params=prepared.packed_seq_params, + sequence_len_offset=preprocessed[4], + padding_mask=preprocessed[5], + **(extra_block_kwargs or {}), + ), + ) + gathered_hidden = rank._gather_sequence_parallel_hidden(hidden) + values["90.decoder.output"] = gathered_hidden.detach() + values["99.lm_head.logits"] = _logits(rank, gathered_hidden).detach() + return _Capture( + values=values, + positions_by_item=prepared.positions_by_item, + source_positions_by_item=prepared.source_positions_by_item, + ) + finally: + for handle in handles: + handle.remove() + + +def _mutated_batch( + batch: SharedPrefixPack, + *, + keep_positions: torch.Tensor, +) -> SharedPrefixPack: + tokens = batch.tokens.clone() + mask = torch.ones(int(tokens.shape[1]), dtype=torch.bool, device=tokens.device) + mask[keep_positions.to(device=tokens.device)] = False + replacement = ( + torch.arange(int(tokens.shape[1]), dtype=tokens.dtype, device=tokens.device) + + 50_000 + ) + tokens[0, mask] = replacement[mask] % 100_000 + return SharedPrefixPack( + tokens=tokens, + group_ids=batch.group_ids, + parent_ids=batch.parent_ids, + position_ids=batch.position_ids, + positions_by_sequence=batch.positions_by_sequence, + ) + + +def _register_hooks( + model: torch.nn.Module, + values: dict[str, torch.Tensor], + *, + seq_len: int, +) -> list[Any]: + handles: list[Any] = [] + for module_name, module in model.named_modules(): + label = _capture_label(module_name) + if label is None: + continue + + def hook( + _module: torch.nn.Module, + _inputs: tuple[object, ...], + output: object, + *, + label: str = label, + ) -> None: + tensor = _first_tensor(output) + if tensor is not None: + try: + values[label] = _rows(tensor.detach(), seq_len=seq_len) + except RuntimeError: + pass + + handles.append(module.register_forward_hook(hook)) + return handles + + +def _capture_label(module_name: str) -> str | None: + layer_prefix = r"decoder\.layers\.(\d+)(?:\._orig_mod)?" + if re.fullmatch(r"decoder\.layers\.(\d+)\._orig_mod", module_name): + return None + layer_match = re.fullmatch(r"decoder\.layers\.(\d+)", module_name) + if layer_match: + return f"30.layer.{int(layer_match.group(1)):03d}.output" + input_norm_match = re.fullmatch(rf"{layer_prefix}\.input_layernorm", module_name) + if input_norm_match: + return f"05.layer.{int(input_norm_match.group(1)):03d}.input_layernorm" + qkv_match = re.fullmatch( + rf"{layer_prefix}\.self_attention\.linear_qkv", module_name + ) + if qkv_match: + return f"08.layer.{int(qkv_match.group(1)):03d}.self_attention.linear_qkv" + core_attention_match = re.fullmatch( + rf"{layer_prefix}\.self_attention\.core_attention", + module_name, + ) + if core_attention_match: + return f"10.layer.{int(core_attention_match.group(1)):03d}.self_attention.core_attention" + attention_proj_match = re.fullmatch( + rf"{layer_prefix}\.self_attention\.linear_proj", + module_name, + ) + if attention_proj_match: + return f"12.layer.{int(attention_proj_match.group(1)):03d}.self_attention.linear_proj" + attention_match = re.fullmatch( + rf"{layer_prefix}\.self_attention", + module_name, + ) + if attention_match: + return f"15.layer.{int(attention_match.group(1)):03d}.self_attention" + pre_mlp_norm_match = re.fullmatch( + rf"{layer_prefix}\.pre_mlp_layernorm", + module_name, + ) + if pre_mlp_norm_match: + return f"18.layer.{int(pre_mlp_norm_match.group(1)):03d}.pre_mlp_layernorm" + fc1_match = re.fullmatch(rf"{layer_prefix}\.mlp\.linear_fc1", module_name) + if fc1_match: + return f"20.layer.{int(fc1_match.group(1)):03d}.mlp.linear_fc1" + fc2_match = re.fullmatch(rf"{layer_prefix}\.mlp\.linear_fc2", module_name) + if fc2_match: + return f"22.layer.{int(fc2_match.group(1)):03d}.mlp.linear_fc2" + mlp_match = re.fullmatch(rf"{layer_prefix}\.mlp", module_name) + if mlp_match: + return f"25.layer.{int(mlp_match.group(1)):03d}.mlp" + if module_name == "decoder.final_layernorm": + return "80.decoder.final_layernorm" + return None + + +def _first_tensor(value: object) -> torch.Tensor | None: + if isinstance(value, torch.Tensor): + return value + if isinstance(value, (tuple, list)): + for item in value: + tensor = _first_tensor(item) + if tensor is not None: + return tensor + return None + + +def _rows(tensor: torch.Tensor, *, seq_len: int) -> torch.Tensor: + if tensor.ndim >= 2 and int(tensor.shape[0]) == seq_len: + rows = tensor + if rows.ndim >= 3 and int(rows.shape[1]) == 1: + return rows[:, 0].contiguous() + return rows.contiguous() + if tensor.ndim >= 2 and int(tensor.shape[1]) == seq_len: + rows = ( + tensor[:, :, 0] + if tensor.ndim == 4 and int(tensor.shape[2]) == 1 + else tensor + ) + if int(rows.shape[0]) == 1: + return rows[0].contiguous() + raise RuntimeError( + f"Cannot identify sequence axis for tensor shape={tuple(tensor.shape)} " + f"seq_len={seq_len}" + ) + + +def _logits(rank: TrainerRank, hidden_rows: torch.Tensor) -> torch.Tensor: + model = _language_model(rank.runtime.model[0]) + output_weight = ( + model.shared_embedding_or_output_weight() + if bool(model.share_embeddings_and_output_weights) + else None + ) + if int(hidden_rows.shape[0]) == 0: + return hidden_rows.new_empty((0, int(model.vocab_size))) + local_logits = rank._local_logits_from_hidden_rows( + model, + hidden_rows, + output_weight=output_weight, + ) + return _batch_seq_logits( + rank._gather_tensor_parallel_logits(local_logits.unsqueeze(1)), + seq_len=int(hidden_rows.shape[0]), + ).squeeze(0) + + +def _records_from_capture( + *, + kind: str, + capture: _Capture, + request_indices: Sequence[int], + cp_rank: int, + dp_rank: int, + local_indices: Sequence[int] | None = None, +) -> list[dict[str, object]]: + records: list[dict[str, object]] = [] + local_index_set = None if local_indices is None else frozenset(local_indices) + for local_index, request_index in enumerate(request_indices): + if local_index_set is not None and local_index not in local_index_set: + continue + positions = capture.positions_by_item[local_index] + source_positions = capture.source_positions_by_item[local_index] + if int(positions.numel()) == 0: + continue + for name, rows in capture.values.items(): + records.append( + { + "kind": kind, + "name": name, + "request_index": int(request_index), + "source_positions": source_positions.cpu(), + "value": rows.index_select(0, positions.to(rows.device)).cpu(), + "cp": int(cp_rank), + "dp": int(dp_rank), + } + ) + return records + + +def _build_report( + *, + records: list[dict[str, object]], + requests: Sequence[AnyForwardInput], + topology: dict[str, int], + oracle: str, +) -> dict[str, object]: + results = [] + names = sorted( + { + cast(str, record["name"]) + for record in records + if record.get("kind") == "packed" + } + ) + for request_index, request in enumerate(requests): + length = int(request.input_tokens.numel()) + for name in names: + packed = _assemble(records, "packed", name, request_index, length) + independent = _assemble(records, "independent", name, request_index, length) + if packed is None or independent is None: + continue + diff = (packed.float() - independent.float()).abs() + denom = independent.float().abs().max().clamp_min(1e-12) + results.append( + { + "request": request_index, + "site": name, + "shape": list(packed.shape), + "max_abs": float(diff.max().item()) if int(diff.numel()) else 0.0, + "mean_abs": float(diff.mean().item()) if int(diff.numel()) else 0.0, + "rel_max": float((diff.max() / denom).item()) + if int(diff.numel()) + else 0.0, + } + ) + return { + "topology": topology, + "oracle": oracle, + "requests": len(requests), + "results": results, + } + + +def _assemble( + records: list[dict[str, object]], + kind: str, + name: str, + request_index: int, + length: int, +) -> torch.Tensor | None: + matching = [ + record + for record in records + if record["kind"] == kind + and record["name"] == name + and record["request_index"] == request_index + ] + if not matching: + return None + first = cast(torch.Tensor, matching[0]["value"]) + output = torch.empty((length, *first.shape[1:]), dtype=first.dtype) + filled = torch.zeros(length, dtype=torch.bool) + for record in matching: + positions = cast(torch.Tensor, record["source_positions"]) + value = cast(torch.Tensor, record["value"]) + output[positions] = value + filled[positions] = True + if not bool(filled.all().item()): + raise RuntimeError( + f"Missing positions for {kind} {name} request={request_index}" + ) + return output + + +if __name__ == "__main__": + typer.run(main) diff --git a/dev/trainer_rank_perf.py b/dev/trainer_rank_perf.py new file mode 100644 index 000000000..a939e9932 --- /dev/null +++ b/dev/trainer_rank_perf.py @@ -0,0 +1,2906 @@ +from __future__ import annotations + +from collections.abc import Callable, Sequence +from contextlib import contextmanager, suppress +import json +import os +from pathlib import Path +import threading +import time +from typing import Any + +import torch +import torch.distributed as dist +import typer + +from art.megatron.shared_prefix_packing import SharedPrefixPack, pack_shared_prefixes +import art.trainer_rank as trainer_rank_module +from art.trainer_rank import ( + AdamParams, + ForwardInput, + TopK, + TrainerRank, + _batch_seq_logits, + _language_model, + _unflatten, +) + + +def _pack_forward_items(items: Sequence[Any], *, max_depth: int) -> SharedPrefixPack: + return pack_shared_prefixes( + (item.input_ids for item in items), + max_depth=max_depth, + ) + + +def main( + model: str = "Qwen/Qwen3-0.6B", + layers: int = 1, + seq_len: int = 2048, + prefix_families: int = 0, + prefix_len: int = 5000, + mid_prefixes_per_family: int = 1, + mid_prefix_len: int = 0, + branches_per_prefix: int = 16, + completion_len: int = 100, + warmup: int = 2, + repeat: int = 5, + head_chunk_tokens: int = 512, + shared_prefix_max_depth: int = 1, + benchmark: str = "target_builtin_fwd", + target_count: int = 4, + top_k: int = 5, + top_k_values: str = "1,2,5,10,20,50", + max_unpacked_output_gb: float = 0.5, + mask_prefix_targets: bool = True, + workload: str = "regular", + tree_depth: int = 3, + tree_seed: int = 1, + tree_duplicate_factor: int = 1, + adapter_slots: int = 0, + adapter_slot_mode: str = "family", + adapter_slot_rank: int = 1, + learning_rate: float = 1e-5, + full_step_offload_reload: bool = False, + memory_safety_factor: float = 1.10, + memory_reserve_fraction: float = 0.03, + memory_sample_interval_s: float = 0.05, + compare_target_correctness: bool = False, + run_adapter_sanity: bool = False, + progress_jsonl: str = "", + output_jsonl: str = "", +) -> None: + if progress_jsonl: + os.environ["ART_TRAINER_RANK_PROGRESS_JSONL"] = progress_jsonl + + os.environ.setdefault("ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE", "1") + os.environ.setdefault("ART_MEGATRON_CONTEXT_PARALLEL_SIZE", "1") + os.environ.setdefault("ART_MEGATRON_PIPELINE_MODEL_PARALLEL_SIZE", "1") + + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + dist.init_process_group(backend="nccl") + try: + from megatron.core import parallel_state as ps + + from art.megatron import train as megatron_train + + provider_configure = ( + (lambda provider: setattr(provider, "num_layers", layers)) + if layers > 0 + else None + ) + runtime = megatron_train.build_training_runtime( + model_identifier=model, + provider_configure=provider_configure, + print_env=dist.get_rank() == 0, + ) + for chunk in runtime.model: + chunk.eval() + rank = TrainerRank( + runtime, + head_chunk_tokens=head_chunk_tokens, + shared_prefix_max_depth=shared_prefix_max_depth, + memory_safety_factor=memory_safety_factor, + memory_reserve_fraction=memory_reserve_fraction, + ) + if adapter_slots < 0: + raise ValueError("adapter_slots must be >= 0") + if adapter_slot_rank < 1: + raise ValueError("adapter_slot_rank must be >= 1") + if adapter_slots: + loaded_sites = _load_adapter_slots( + rank, + count=adapter_slots, + slot_rank=adapter_slot_rank, + ) + else: + loaded_sites = 0 + hidden_size, vocab_size, dtype_size = _runtime_output_shape(runtime) + model_config = getattr(_language_model(runtime.model[0]), "config", None) + + benchmarks = { + name.strip().replace("-", "_") + for name in benchmark.split(",") + if name.strip() + } + if "all" in benchmarks: + benchmarks = { + "target_builtin_fwd", + "target_trainer_fwd", + "target_hidden_fwd", + "logits_builtin_fwd", + "logits_hidden_fwd", + "target_builtin_fwd_bwd", + "target_builtin_masked_fwd_bwd", + "target_trainer_fwd_bwd", + "target_hidden_fwd_bwd", + "target_builtin_train_step", + "target_trainer_train_step", + "target_trainer_fixed_train_step", + "target_trainer_adaptive_train_step", + "target_trainer_adaptive_profile_train_step", + "target_hidden_train_step", + "trainer_multi_target_fwd_bwd", + "trainer_multi_target_train_step", + "trainer_multi_target_fixed_train_step", + "trainer_multi_target_adaptive_train_step", + "trainer_target", + "trainer_multi_target", + "trainer_topk", + "trainer_topk_head", + "trainer_topk_fwd_bwd", + "trainer_topk_train_step", + "trainer_topk_fixed_train_step", + "trainer_topk_adaptive_train_step", + "trainer_topk_sweep", + "trainer_target_topk", + "trainer_hidden", + "trainer_all_no_logits", + "trainer_logits", + } + if "trainer_all" in benchmarks: + benchmarks.update( + { + "trainer_target", + "trainer_multi_target", + "trainer_multi_target_fwd_bwd", + "trainer_multi_target_train_step", + "trainer_multi_target_fixed_train_step", + "trainer_multi_target_adaptive_train_step", + "trainer_topk", + "trainer_topk_head", + "trainer_topk_fwd_bwd", + "trainer_topk_train_step", + "trainer_topk_fixed_train_step", + "trainer_topk_adaptive_train_step", + "trainer_topk_sweep", + "trainer_target_topk", + "trainer_hidden", + "trainer_all_no_logits", + "trainer_logits", + } + ) + + if target_count < 1: + raise ValueError("target_count must be >= 1") + if top_k < 1: + raise ValueError("top_k must be >= 1") + if memory_sample_interval_s < 0: + raise ValueError("memory_sample_interval_s must be >= 0") + requests, multi_target_requests, request_metadata = _requests( + seq_len=seq_len, + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=mid_prefixes_per_family, + mid_prefix_len=mid_prefix_len, + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + target_count=target_count, + mask_prefix_targets=mask_prefix_targets, + workload=workload, + tree_depth=tree_depth, + tree_seed=tree_seed, + tree_duplicate_factor=tree_duplicate_factor, + ) + requests = _route_adapter_slots( + requests, + adapter_slots=adapter_slots, + mode=adapter_slot_mode, + ) + multi_target_requests = _route_adapter_slots( + multi_target_requests, + adapter_slots=adapter_slots, + mode=adapter_slot_mode, + ) + stats_items = [rank._forward_item(request) for request in requests] + stats_batch = _pack_forward_items( + stats_items, + max_depth=rank.shared_prefix_max_depth, + ) + stats_prepared = rank._prepare_packed_forward(stats_batch) + request_stats = _packed_request_stats( + requests, + stats_items, + stats_batch, + request_metadata=request_metadata, + ) + planner_metadata = _gather_planner_metadata(stats_prepared) + target_items = None + target_prepared = None + if any(name.startswith("target_") for name in benchmarks): + target_items = stats_items + target_prepared = stats_prepared + logits_items = None + logits_prepared = None + if any(name.startswith("logits_") for name in benchmarks): + logits_items = [ + rank._forward_item(_with_outputs(request, logits=True)) + for request in requests + ] + logits_prepared = rank._prepare_packed_forward( + _pack_forward_items( + logits_items, + max_depth=rank.shared_prefix_max_depth, + ) + ) + results: dict[str, float] = {} + metadata: dict[str, object] = {} + rate_units: dict[str, dict[str, int]] = {} + + def register_case( + name: str, + case_requests: Sequence[ + ForwardInput[ + torch.Tensor | None, + TopK | None, + torch.Tensor | None, + torch.Tensor | None, + ] + ], + case_stats: dict[str, int | str], + ) -> None: + units = _rate_units( + case_requests, + case_stats, + hidden_size=hidden_size, + vocab_size=vocab_size, + dtype_size=dtype_size, + ) + rate_units[name] = units + for key, value in units.items(): + metadata[f"{name}_{key}"] = value + + for name in ( + "target_builtin_fwd", + "target_hidden_fwd", + "target_trainer_fwd", + "target_builtin_fwd_bwd", + "target_builtin_masked_fwd_bwd", + "target_trainer_fwd_bwd", + "target_hidden_fwd_bwd", + "target_builtin_train_step", + "target_trainer_train_step", + "target_trainer_fixed_train_step", + "target_trainer_adaptive_train_step", + "target_trainer_adaptive_profile_train_step", + "target_hidden_train_step", + ): + register_case(name, requests, request_stats) + + memory_tracker = _CudaMemoryTracker( + device_index=int(os.environ["LOCAL_RANK"]), + sample_interval_s=memory_sample_interval_s, + ) + memory_tracker.start() + torch.cuda.reset_peak_memory_stats() + with torch.no_grad(): + if "target_builtin_fwd" in benchmarks: + assert target_items is not None and target_prepared is not None + results["target_builtin_fwd_ms"] = _bench( + lambda: _builtin( + rank, + target_prepared, + _packed_labels(target_items, target_prepared), + ), + warmup=warmup, + repeat=repeat, + ) + if "target_hidden_fwd" in benchmarks: + assert target_items is not None and target_prepared is not None + results["target_hidden_fwd_ms"] = _bench( + lambda: rank._project_head( + target_items, + target_prepared, + rank._gather_sequence_parallel_hidden( + rank._decoder_hidden(target_prepared) + ), + ), + warmup=warmup, + repeat=repeat, + ) + if "target_trainer_fwd" in benchmarks: + assert target_items is not None and target_prepared is not None + results["target_trainer_fwd_ms"] = _bench( + lambda: rank._forward_packed(target_items, target_prepared), + warmup=warmup, + repeat=repeat, + ) + if "logits_builtin_fwd" in benchmarks: + assert logits_prepared is not None + register_case( + "logits_builtin_fwd", _logits_requests(requests), request_stats + ) + results["logits_builtin_fwd_ms"] = _bench( + lambda: _full_logits(rank, logits_prepared), + warmup=warmup, + repeat=repeat, + ) + if "logits_hidden_fwd" in benchmarks: + assert logits_items is not None and logits_prepared is not None + register_case( + "logits_hidden_fwd", _logits_requests(requests), request_stats + ) + results["logits_hidden_fwd_ms"] = _bench( + lambda: rank._project_head( + logits_items, + logits_prepared, + rank._gather_sequence_parallel_hidden( + rank._decoder_hidden(logits_prepared) + ), + ), + warmup=warmup, + repeat=repeat, + ) + trainer_cases = { + "trainer_target": requests, + "trainer_multi_target": multi_target_requests, + "trainer_topk": [ + _with_outputs(request, top_k=top_k) for request in requests + ], + "trainer_target_topk": [ + _with_outputs( + request, + target_tokens=request.target_tokens, + top_k=top_k, + ) + for request in requests + ], + "trainer_hidden": [ + _with_outputs(request, hidden_states=True) for request in requests + ], + "trainer_all_no_logits": [ + _with_outputs( + request, + target_tokens=multi_request.target_tokens, + top_k=top_k, + hidden_states=True, + ) + for request, multi_request in zip( + requests, multi_target_requests, strict=True + ) + ], + "trainer_logits": [ + ForwardInput(input_tokens=request.input_tokens, logits=True) + for request in requests + ], + } + if "trainer_topk_sweep" in benchmarks: + for k in _int_values(top_k_values): + trainer_cases[f"trainer_topk_{k}"] = [ + _with_outputs(request, top_k=k) for request in requests + ] + for name, case_requests in trainer_cases.items(): + if name not in benchmarks and not ( + "trainer_topk_sweep" in benchmarks + and name.startswith("trainer_topk_") + ): + continue + output_gb = _request_output_gb( + case_requests, + hidden_size=hidden_size, + vocab_size=vocab_size, + dtype_size=dtype_size, + ) + metadata[f"{name}_output_gb"] = round(output_gb, 3) + if max_unpacked_output_gb > 0 and output_gb > max_unpacked_output_gb: + metadata[f"{name}_skipped"] = "unpacked_output_cap" + continue + items = [rank._forward_item(request) for request in case_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + register_case( + name, + case_requests, + _packed_request_stats( + case_requests, items, batch, request_metadata={} + ), + ) + prepared = rank._prepare_packed_forward(batch) + if adapter_slots: + results[f"{name}_ms"] = _bench( + lambda case_requests=case_requests: rank.dp_rank_forward( + case_requests + ), + warmup=warmup, + repeat=repeat, + ) + else: + results[f"{name}_ms"] = _bench( + lambda items=items, prepared=prepared: rank._forward_packed( + items, + prepared, + ), + warmup=warmup, + repeat=repeat, + ) + if "trainer_topk_head" in benchmarks: + case_requests = [ + _with_outputs(request, top_k=top_k) for request in requests + ] + output_gb = _request_output_gb( + case_requests, + hidden_size=hidden_size, + vocab_size=vocab_size, + dtype_size=dtype_size, + ) + metadata["trainer_topk_head_output_gb"] = round(output_gb, 3) + items = [rank._forward_item(request) for request in case_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + register_case( + "trainer_topk_head", + case_requests, + _packed_request_stats( + case_requests, items, batch, request_metadata={} + ), + ) + prepared = rank._prepare_packed_forward(batch) + hidden = rank._gather_sequence_parallel_hidden( + rank._decoder_hidden(prepared) + ) + results["trainer_topk_head_ms"] = _bench( + lambda: rank._project_head(items, prepared, hidden), + warmup=warmup, + repeat=repeat, + ) + + if "target_builtin_fwd_bwd" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_builtin_fwd_bwd_ms"] = _bench( + lambda: _target_builtin_loss( + rank, + target_items, + target_prepared, + ).backward(), + warmup=warmup, + repeat=repeat, + after=rank.zero_grad, + ) + if "target_builtin_masked_fwd_bwd" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_builtin_masked_fwd_bwd_ms"] = _bench( + lambda: _target_builtin_masked_loss( + rank, + target_items, + target_prepared, + ).backward(), + warmup=warmup, + repeat=repeat, + after=rank.zero_grad, + ) + if "target_trainer_fwd_bwd" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_trainer_fwd_bwd_ms"] = _bench( + lambda: ( + _target_requests_loss(rank, requests) + if adapter_slots + else _target_trainer_loss( + rank, + target_items, + target_prepared, + ) + ).backward(), + warmup=warmup, + repeat=repeat, + after=rank.zero_grad, + ) + if "target_hidden_fwd_bwd" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_hidden_fwd_bwd_ms"] = _bench( + lambda: _target_hidden_loss( + rank, + target_items, + target_prepared, + ).backward(), + warmup=warmup, + repeat=repeat, + after=rank.zero_grad, + ) + train_step_params = AdamParams(learning_rate=learning_rate) + offload_manager = ( + _make_offload_manager(runtime) if full_step_offload_reload else None + ) + if "target_builtin_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_builtin_train_step_ms"] = _bench( + lambda: _training_step( + rank, + lambda: _target_builtin_loss(rank, target_items, target_prepared), + params=train_step_params, + offload_manager=offload_manager, + ), + warmup=warmup, + repeat=repeat, + ) + if "target_trainer_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_trainer_train_step_ms"] = _bench( + lambda: _training_step( + rank, + lambda: ( + _target_requests_loss(rank, requests) + if adapter_slots + else _target_trainer_loss(rank, target_items, target_prepared) + ), + params=train_step_params, + offload_manager=offload_manager, + ), + warmup=warmup, + repeat=repeat, + ) + if "target_trainer_fixed_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + fixed_stats: list[dict[str, int | bool]] = [] + results["target_trainer_fixed_train_step_ms"] = _bench( + lambda: _fixed_micro_batch_training_step( + rank, + requests, + params=train_step_params, + offload_manager=offload_manager, + loss_kind="target", + stats_sink=fixed_stats, + ), + warmup=warmup, + repeat=repeat, + ) + _record_micro_batch_stats( + metadata, "target_trainer_fixed_train_step", fixed_stats + ) + if "target_trainer_adaptive_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + adaptive_stats: list[dict[str, int | bool]] = [] + results["target_trainer_adaptive_train_step_ms"] = _bench( + lambda: _adaptive_micro_batch_training_step( + rank, + requests, + params=train_step_params, + offload_manager=offload_manager, + loss_kind="target", + stats_sink=adaptive_stats, + ), + warmup=warmup, + repeat=repeat, + ) + _record_micro_batch_stats( + metadata, "target_trainer_adaptive_train_step", adaptive_stats + ) + if "target_trainer_adaptive_profile_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + adaptive_stats: list[dict[str, int | bool | float]] = [] + results["target_trainer_adaptive_profile_train_step_ms"] = _bench( + lambda: _profiled_adaptive_micro_batch_training_step( + rank, + requests, + params=train_step_params, + offload_manager=offload_manager, + loss_kind="target", + stats_sink=adaptive_stats, + ), + warmup=warmup, + repeat=repeat, + ) + _record_micro_batch_stats( + metadata, + "target_trainer_adaptive_profile_train_step", + adaptive_stats, + ) + _record_profile_stats( + metadata, + "target_trainer_adaptive_profile_train_step", + adaptive_stats, + ) + if "target_hidden_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_hidden_train_step_ms"] = _bench( + lambda: _training_step( + rank, + lambda: _target_hidden_loss(rank, target_items, target_prepared), + params=train_step_params, + offload_manager=offload_manager, + ), + warmup=warmup, + repeat=repeat, + ) + if "trainer_multi_target_fwd_bwd" in benchmarks: + for chunk in runtime.model: + chunk.train() + items = [rank._forward_item(request) for request in multi_target_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + register_case( + "trainer_multi_target_fwd_bwd", + multi_target_requests, + _packed_request_stats( + multi_target_requests, + items, + batch, + request_metadata={}, + ), + ) + prepared = rank._prepare_packed_forward(batch) + results["trainer_multi_target_fwd_bwd_ms"] = _bench( + lambda: ( + _target_requests_loss(rank, multi_target_requests) + if adapter_slots + else _target_trainer_loss(rank, items, prepared) + ).backward(), + warmup=warmup, + repeat=repeat, + after=rank.zero_grad, + ) + if "trainer_multi_target_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + items = [rank._forward_item(request) for request in multi_target_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + register_case( + "trainer_multi_target_train_step", + multi_target_requests, + _packed_request_stats( + multi_target_requests, + items, + batch, + request_metadata={}, + ), + ) + prepared = rank._prepare_packed_forward(batch) + results["trainer_multi_target_train_step_ms"] = _bench( + lambda: _training_step( + rank, + lambda: ( + _target_requests_loss(rank, multi_target_requests) + if adapter_slots + else _target_trainer_loss(rank, items, prepared) + ), + params=train_step_params, + offload_manager=offload_manager, + ), + warmup=warmup, + repeat=repeat, + ) + if ( + "trainer_multi_target_fixed_train_step" in benchmarks + or "trainer_multi_target_adaptive_train_step" in benchmarks + ): + items = [rank._forward_item(request) for request in multi_target_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + multi_target_stats = _packed_request_stats( + multi_target_requests, + items, + batch, + request_metadata={}, + ) + if "trainer_multi_target_fixed_train_step" in benchmarks: + register_case( + "trainer_multi_target_fixed_train_step", + multi_target_requests, + multi_target_stats, + ) + for chunk in runtime.model: + chunk.train() + fixed_stats = [] + results["trainer_multi_target_fixed_train_step_ms"] = _bench( + lambda: _fixed_micro_batch_training_step( + rank, + multi_target_requests, + params=train_step_params, + offload_manager=offload_manager, + loss_kind="target", + stats_sink=fixed_stats, + ), + warmup=warmup, + repeat=repeat, + ) + _record_micro_batch_stats( + metadata, + "trainer_multi_target_fixed_train_step", + fixed_stats, + ) + if "trainer_multi_target_adaptive_train_step" in benchmarks: + register_case( + "trainer_multi_target_adaptive_train_step", + multi_target_requests, + multi_target_stats, + ) + for chunk in runtime.model: + chunk.train() + adaptive_stats = [] + results["trainer_multi_target_adaptive_train_step_ms"] = _bench( + lambda: _adaptive_micro_batch_training_step( + rank, + multi_target_requests, + params=train_step_params, + offload_manager=offload_manager, + loss_kind="target", + stats_sink=adaptive_stats, + ), + warmup=warmup, + repeat=repeat, + ) + _record_micro_batch_stats( + metadata, + "trainer_multi_target_adaptive_train_step", + adaptive_stats, + ) + if "trainer_topk_fwd_bwd" in benchmarks: + for chunk in runtime.model: + chunk.train() + topk_requests = [ + _with_outputs(request, top_k=top_k) for request in requests + ] + items = [rank._forward_item(request) for request in topk_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + register_case( + "trainer_topk_fwd_bwd", + topk_requests, + _packed_request_stats(topk_requests, items, batch, request_metadata={}), + ) + prepared = rank._prepare_packed_forward(batch) + results["trainer_topk_fwd_bwd_ms"] = _bench( + lambda: ( + _topk_requests_loss(rank, topk_requests) + if adapter_slots + else _trainer_topk_loss(rank, items, prepared) + ).backward(), + warmup=warmup, + repeat=repeat, + after=rank.zero_grad, + ) + if "trainer_topk_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + topk_requests = [ + _with_outputs(request, top_k=top_k) for request in requests + ] + items = [rank._forward_item(request) for request in topk_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + register_case( + "trainer_topk_train_step", + topk_requests, + _packed_request_stats(topk_requests, items, batch, request_metadata={}), + ) + prepared = rank._prepare_packed_forward(batch) + results["trainer_topk_train_step_ms"] = _bench( + lambda: _training_step( + rank, + lambda: ( + _topk_requests_loss(rank, topk_requests) + if adapter_slots + else _trainer_topk_loss(rank, items, prepared) + ), + params=train_step_params, + offload_manager=offload_manager, + ), + warmup=warmup, + repeat=repeat, + ) + if ( + "trainer_topk_fixed_train_step" in benchmarks + or "trainer_topk_adaptive_train_step" in benchmarks + ): + topk_requests = [ + _with_outputs(request, top_k=top_k) for request in requests + ] + items = [rank._forward_item(request) for request in topk_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + topk_stats = _packed_request_stats( + topk_requests, + items, + batch, + request_metadata={}, + ) + if "trainer_topk_fixed_train_step" in benchmarks: + register_case( + "trainer_topk_fixed_train_step", + topk_requests, + topk_stats, + ) + for chunk in runtime.model: + chunk.train() + fixed_stats = [] + results["trainer_topk_fixed_train_step_ms"] = _bench( + lambda: _fixed_micro_batch_training_step( + rank, + topk_requests, + params=train_step_params, + offload_manager=offload_manager, + loss_kind="topk", + stats_sink=fixed_stats, + ), + warmup=warmup, + repeat=repeat, + ) + _record_micro_batch_stats( + metadata, "trainer_topk_fixed_train_step", fixed_stats + ) + if "trainer_topk_adaptive_train_step" in benchmarks: + register_case( + "trainer_topk_adaptive_train_step", + topk_requests, + topk_stats, + ) + for chunk in runtime.model: + chunk.train() + adaptive_stats = [] + results["trainer_topk_adaptive_train_step_ms"] = _bench( + lambda: _adaptive_micro_batch_training_step( + rank, + topk_requests, + params=train_step_params, + offload_manager=offload_manager, + loss_kind="topk", + stats_sink=adaptive_stats, + ), + warmup=warmup, + repeat=repeat, + ) + _record_micro_batch_stats( + metadata, "trainer_topk_adaptive_train_step", adaptive_stats + ) + + if compare_target_correctness and adapter_slots: + metadata["target_correctness_skipped"] = "adapter_slots" + elif compare_target_correctness: + assert target_items is not None and target_prepared is not None + metadata.update( + _target_correctness_metrics(rank, target_items, target_prepared) + ) + if run_adapter_sanity and adapter_slots > 0: + metadata.update( + _adapter_sanity_metrics( + rank, + requests, + params=train_step_params, + adapter_slots=adapter_slots, + ) + ) + + memory_tracker.stop() + memory_metadata = _distributed_memory_metadata(memory_tracker) + model_metadata = _model_metadata(runtime, model, layers=layers) + + if dist.get_rank() == 0: + token_rates = _rate_metrics(results, rate_units) + payload = { + "world": dist.get_world_size(), + "tp": int(ps.get_tensor_model_parallel_world_size()), + "cp": int(ps.get_context_parallel_world_size()), + "seq_len": seq_len, + "prefix_families": prefix_families, + "prefix_len": prefix_len, + "mid_prefixes_per_family": mid_prefixes_per_family, + "mid_prefix_len": mid_prefix_len, + "branches_per_prefix": branches_per_prefix, + "completion_len": completion_len, + "head_chunk_tokens": head_chunk_tokens, + "shared_prefix_max_depth": shared_prefix_max_depth, + "warmup": warmup, + "repeat": repeat, + "target_count": target_count, + "top_k": top_k, + "top_k_values": top_k_values, + "max_unpacked_output_gb": max_unpacked_output_gb, + "mask_prefix_targets": mask_prefix_targets, + "workload": workload, + "tree_depth": tree_depth, + "tree_seed": tree_seed, + "tree_duplicate_factor": tree_duplicate_factor, + "adapter_slots": adapter_slots, + "adapter_slot_mode": adapter_slot_mode, + "adapter_slot_rank": adapter_slot_rank, + "adapter_loaded_sites": loaded_sites, + "learning_rate": learning_rate, + "full_step_offload_reload": full_step_offload_reload, + "memory_safety_factor": memory_safety_factor, + "memory_reserve_fraction": memory_reserve_fraction, + "mtp_num_layers": getattr(model_config, "mtp_num_layers", None), + "cross_entropy_loss_fusion": getattr( + model_config, "cross_entropy_loss_fusion", None + ), + "cross_entropy_fusion_impl": getattr( + model_config, "cross_entropy_fusion_impl", None + ), + **model_metadata, + **request_stats, + **memory_metadata, + **results, + **token_rates, + **metadata, + **planner_metadata, + } + line = json.dumps(payload, sort_keys=True) + print(line, flush=True) + if output_jsonl: + output_path = Path(output_jsonl) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("a", encoding="utf-8") as output_file: + output_file.write(line + "\n") + dist.barrier() + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +def _requests( + *, + seq_len: int, + prefix_families: int, + prefix_len: int, + mid_prefixes_per_family: int, + mid_prefix_len: int, + branches_per_prefix: int, + completion_len: int, + target_count: int, + mask_prefix_targets: bool, + workload: str, + tree_depth: int, + tree_seed: int, + tree_duplicate_factor: int, +) -> tuple[ + list[ForwardInput[torch.Tensor, None, None, None]], + list[ForwardInput[torch.Tensor, None, None, None]], + dict[str, int | str], +]: + if workload == "regular" and prefix_families <= 0: + tokens = torch.arange(seq_len, dtype=torch.long) % 32_000 + 100 + labels = _labels(tokens, target_count=1) + return ( + [ForwardInput(input_tokens=tokens, target_tokens=labels)], + [ + ForwardInput( + input_tokens=tokens, + target_tokens=_labels(tokens, target_count=target_count), + ) + ], + { + "request_count": 1, + "workload_shape": "single", + }, + ) + + if prefix_len < 1 or branches_per_prefix < 1 or completion_len < 1: + raise ValueError( + "prefix_len, branches_per_prefix, and completion_len must be >= 1" + ) + if mid_prefixes_per_family < 1 or mid_prefix_len < 0: + raise ValueError("mid_prefixes_per_family must be >= 1 and mid_prefix_len >= 0") + + sequences, prefix_lengths, workload_shape = _workload_sequences( + workload=workload, + seq_len=seq_len, + prefix_families=max(prefix_families, 1), + prefix_len=prefix_len, + mid_prefixes_per_family=mid_prefixes_per_family, + mid_prefix_len=mid_prefix_len, + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + tree_depth=tree_depth, + tree_seed=tree_seed, + tree_duplicate_factor=tree_duplicate_factor, + ) + requests = [] + multi_requests = [] + for tokens, shared_length in zip(sequences, prefix_lengths, strict=True): + labels = _labels(tokens, target_count=1) + multi_labels = _labels(tokens, target_count=target_count) + if mask_prefix_targets and shared_length: + labels[:shared_length] = -100 + multi_labels[:shared_length] = -100 + requests.append(ForwardInput(input_tokens=tokens, target_tokens=labels)) + multi_requests.append( + ForwardInput(input_tokens=tokens, target_tokens=multi_labels) + ) + + return ( + requests, + multi_requests, + { + "request_count": len(requests), + "workload_shape": workload_shape, + }, + ) + + +def _load_adapter_slots( + rank: TrainerRank, + *, + count: int, + slot_rank: int, +) -> int: + loaded_sites = 0 + for slot_index in range(count): + loaded_sites += rank.load_checkpoint_slot( + f"S{slot_index}", + _synthetic_adapter( + rank.runtime.model, slot_rank=slot_rank, seed=slot_index + ), + ) + return loaded_sites + + +def _synthetic_adapter( + model: Sequence[torch.nn.Module], + *, + slot_rank: int, + seed: int, +) -> dict[str, torch.Tensor]: + from art.megatron.lora import LoRA + + adapter: dict[str, torch.Tensor] = {} + generator = torch.Generator(device="cuda").manual_seed(10_000 + seed) + for chunk in model: + for module in chunk.modules(): + if not isinstance(module, LoRA): + continue + a_keys = module._expected_weight_keys("lora_A") + b_keys = module._expected_weight_keys("lora_B") + for a_key, b_key in zip(a_keys, b_keys, strict=True): + adapter[a_key] = ( + torch.randn( + slot_rank, + module.in_features, + dtype=module.A_T.dtype, + device=module.A_T.device, + generator=generator, + ) + * 0.01 + ) + adapter[b_key] = ( + torch.randn( + module.out_features, + slot_rank, + dtype=module.B_T.dtype, + device=module.B_T.device, + generator=generator, + ) + * 0.01 + ) + if not adapter: + raise RuntimeError("adapter slot stress requested, but model has no LoRA sites") + return adapter + + +def _route_adapter_slots( + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + *, + adapter_slots: int, + mode: str, +) -> list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] +]: + if adapter_slots == 0: + return list(requests) + if mode not in {"family", "round_robin", "single", "skewed_random"}: + raise ValueError( + "adapter_slot_mode must be one of: family, round_robin, single, " + "skewed_random" + ) + return [ + ForwardInput( + input_tokens=request.input_tokens, + target_tokens=request.target_tokens, + top_k=request.top_k, + logits=request.logits, + hidden_states=request.hidden_states, + checkpoint=f"S{_adapter_slot_index(index, request, adapter_slots, mode)}", + ) + for index, request in enumerate(requests) + ] + + +def _adapter_slot_index( + index: int, + request: ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ], + adapter_slots: int, + mode: str, +) -> int: + if mode == "single": + return 0 + if mode == "round_robin": + return index % adapter_slots + if mode == "skewed_random": + bucket = (index * 1103515245 + 12345) & 0x7FFFFFFF + skew = bucket % 100 + if skew < 50: + return 0 + if skew < 75: + return min(1, adapter_slots - 1) + if skew < 90: + return min(2, adapter_slots - 1) + return min(3 + (bucket % max(1, adapter_slots - 3)), adapter_slots - 1) + first_token = ( + int(request.input_tokens[0].item()) if request.input_tokens.numel() else 0 + ) + return (first_token // 10_000_019) % adapter_slots + + +def _with_outputs( + request: ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ], + *, + target_tokens: torch.Tensor | None = None, + top_k: int | None = None, + logits: bool = False, + hidden_states: bool = False, +) -> ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None +]: + return ForwardInput( + input_tokens=request.input_tokens, + target_tokens=target_tokens, + top_k=top_k, + logits=logits, + hidden_states=hidden_states, + checkpoint=request.checkpoint, + lora=request.lora, + ) + + +def _workload_sequences( + *, + workload: str, + seq_len: int, + prefix_families: int, + prefix_len: int, + mid_prefixes_per_family: int, + mid_prefix_len: int, + branches_per_prefix: int, + completion_len: int, + tree_depth: int, + tree_seed: int, + tree_duplicate_factor: int, +) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...], str]: + if workload in {"austin_198k", "austin_5k_16x100"}: + return _regular_tree_sequences( + prefix_families=30, + prefix_len=5000, + mid_prefixes_per_family=1, + mid_prefix_len=0, + branches_per_prefix=16, + completion_len=100, + ) + if workload == "austin_varied": + return _austin_varied_sequences() + if workload == "regular": + return _regular_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=mid_prefixes_per_family, + mid_prefix_len=mid_prefix_len, + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + if workload == "single": + tokens = torch.arange(seq_len, dtype=torch.long) % 32_000 + 100 + return (tokens,), (0,), "single" + if workload == "long_root": + return _regular_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=1, + mid_prefix_len=0, + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + if workload == "long_mid": + return _regular_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=max(2, mid_prefixes_per_family), + mid_prefix_len=max(1, mid_prefix_len), + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + if workload == "many_tiny_leaves": + return _regular_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=max(1, mid_prefixes_per_family), + mid_prefix_len=max(0, mid_prefix_len), + branches_per_prefix=branches_per_prefix, + completion_len=max(1, completion_len), + ) + if workload == "uneven": + return _uneven_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=max(2, mid_prefixes_per_family), + mid_prefix_len=max(1, mid_prefix_len), + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + if workload == "duplicates": + sequences, shared, shape = _regular_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=max(2, mid_prefixes_per_family), + mid_prefix_len=max(1, mid_prefix_len), + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + factor = max(1, tree_duplicate_factor) + return ( + tuple(sequence for sequence in sequences for _ in range(factor)), + tuple(length for length in shared for _ in range(factor)), + f"{shape}:duplicates={factor}", + ) + if workload == "random": + return _random_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + branches_per_prefix=max(2, min(branches_per_prefix, 4)), + completion_len=completion_len, + tree_depth=max(1, tree_depth), + seed=tree_seed, + ) + raise ValueError( + "workload must be one of: regular, single, long_root, long_mid, " + "many_tiny_leaves, uneven, duplicates, random, austin_198k, austin_varied" + ) + + +def _regular_tree_sequences( + *, + prefix_families: int, + prefix_len: int, + mid_prefixes_per_family: int, + mid_prefix_len: int, + branches_per_prefix: int, + completion_len: int, +) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...], str]: + nested = mid_prefixes_per_family > 1 and mid_prefix_len > 0 + sequences: list[torch.Tensor] = [] + shared_lengths: list[int] = [] + for family in range(prefix_families): + family_base = family * 10_000_019 + root = _tokens(family_base, prefix_len) + mid_count = mid_prefixes_per_family if nested else 1 + for mid in range(mid_count): + mid_prefix = ( + _tokens(family_base + 1_000_003 + mid * 100_003, mid_prefix_len) + if nested + else torch.empty(0, dtype=torch.long) + ) + shared = torch.cat((root, mid_prefix)) + for branch in range(branches_per_prefix): + sequences.append( + torch.cat( + ( + shared, + _tokens( + family_base + mid * 100_003 + branch * 1009 + 17, + completion_len, + ), + ) + ) + ) + shared_lengths.append(int(shared.numel())) + shape = ( + f"families={prefix_families}:mid={mid_prefixes_per_family}:" + f"branches={branches_per_prefix}:nested={int(nested)}" + ) + return tuple(sequences), tuple(shared_lengths), shape + + +def _austin_varied_sequences() -> tuple[tuple[torch.Tensor, ...], tuple[int, ...], str]: + sequences: list[torch.Tensor] = [] + shared_lengths: list[int] = [] + for family in range(30): + family_base = family * 10_000_019 + prefix_len = 4500 + ((family * 137) % 1001) + root = _tokens(family_base, prefix_len) + branch_count = 10 + ((family * 7) % 13) + for branch in range(branch_count): + completion_len = 32 + ((family * 19 + branch * 23) % 145) + sequences.append( + torch.cat( + ( + root, + _tokens( + family_base + branch * 1009 + 17, + completion_len, + ), + ) + ) + ) + shared_lengths.append(int(root.numel())) + return tuple(sequences), tuple(shared_lengths), "austin_varied" + + +def _uneven_tree_sequences( + *, + prefix_families: int, + prefix_len: int, + mid_prefixes_per_family: int, + mid_prefix_len: int, + branches_per_prefix: int, + completion_len: int, +) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...], str]: + sequences: list[torch.Tensor] = [] + shared_lengths: list[int] = [] + for family in range(prefix_families): + family_base = family * 10_000_019 + root_len = max(1, prefix_len // (family + 1)) + root = _tokens(family_base, root_len) + for mid in range(mid_prefixes_per_family): + mid_len = max(1, mid_prefix_len // (mid + 1)) + mid_prefix = _tokens(family_base + 1_000_003 + mid * 100_003, mid_len) + branch_count = max(1, branches_per_prefix - mid) + for branch in range(branch_count): + leaf_len = max(1, completion_len * (branch + 1) // branch_count) + shared = torch.cat((root, mid_prefix)) + sequences.append( + torch.cat( + ( + shared, + _tokens( + family_base + mid * 100_003 + branch * 1009 + 17, + leaf_len, + ), + ) + ) + ) + shared_lengths.append(int(shared.numel())) + return tuple(sequences), tuple(shared_lengths), "uneven" + + +def _random_tree_sequences( + *, + prefix_families: int, + prefix_len: int, + branches_per_prefix: int, + completion_len: int, + tree_depth: int, + seed: int, +) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...], str]: + generator = torch.Generator().manual_seed(seed) + next_offset = 1 + sequences: list[torch.Tensor] = [] + shared_lengths: list[int] = [] + + def randint(low: int, high: int) -> int: + return int(torch.randint(low, high + 1, (), generator=generator).item()) + + def segment(length: int) -> torch.Tensor: + nonlocal next_offset + out = _tokens(next_offset, max(1, length)) + next_offset += max(1, length) + 10_000 + return out + + def length_for_depth(depth: int) -> int: + if depth == 0: + return max(1, prefix_len) + choices = (1, 8, 64, max(1, completion_len), max(1, prefix_len // 2)) + return choices[randint(0, len(choices) - 1)] + + def walk(prefix: torch.Tensor, depth: int) -> None: + shared = torch.cat((prefix, segment(length_for_depth(depth)))) + if depth + 1 >= tree_depth: + leaf_count = randint(2, branches_per_prefix) + for _ in range(leaf_count): + leaf = segment(randint(1, max(1, completion_len))) + sequences.append(torch.cat((shared, leaf))) + shared_lengths.append(int(shared.numel())) + return + for _ in range(randint(2, branches_per_prefix)): + walk(shared, depth + 1) + + for _ in range(prefix_families): + walk(torch.empty(0, dtype=torch.long), 0) + return tuple(sequences), tuple(shared_lengths), f"random:depth={tree_depth}" + + +def _packed_request_stats( + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + items: Sequence[object], + batch: object, + *, + request_metadata: dict[str, int | str], +) -> dict[str, int | str]: + from art.megatron.shared_prefix_tree import parse_shared_prefix_tree + + trainable_mask = torch.zeros(int(batch.tokens.numel()), dtype=torch.bool) + trainable_tokens = 0 + for item, positions in zip(items, batch.positions_by_sequence, strict=True): + labels = getattr(item, "labels", None) + if labels is None: + continue + mask = labels != -100 + row_mask = mask.reshape(int(mask.shape[0]), -1).any(dim=1) + trainable_tokens += int(mask.sum().item()) + trainable_mask[positions.reshape(-1).cpu()] |= row_mask.cpu() + group_ids = batch.group_ids + parent_ids = batch.parent_ids + return { + **request_metadata, + "request_count": len(requests), + "packed_tokens": int(batch.tokens.numel()), + "logical_tokens": sum( + int(request.input_tokens.numel()) for request in requests + ), + "trainable_tokens": trainable_tokens, + "packed_trainable_tokens": int(trainable_mask.sum().item()), + "packed_group_count": int(group_ids.max().item()) + if int(group_ids.numel()) + else 0, + "nested_prefix_depth": max( + ( + segment.depth + for row in parse_shared_prefix_tree( + group_ids=group_ids, + parent_ids=parent_ids, + ) + for segment in row.segments + ), + default=0, + ), + } + + +def _gather_planner_metadata(prepared: object) -> dict[str, object]: + local = _local_planner_metadata(prepared) + gathered: list[dict[str, object] | None] = [None] * dist.get_world_size() + dist.all_gather_object(gathered, local) + if dist.get_rank() != 0: + return {} + ranks = [metrics or {} for metrics in gathered] + gdn_tokens = [int(metrics.get("gdn_tokens", 0)) for metrics in ranks] + attention_tokens = [int(metrics.get("attention_tokens", 0)) for metrics in ranks] + keys = ( + "tree_local_bucket_count", + "tree_chain_bucket_count", + "tree_local_segment_count", + "tree_chain_segment_count", + "tree_local_real_tokens", + "tree_chain_real_tokens", + "tree_state_transfer_count", + "tree_state_transfer_rows", + "tree_max_padding_ratio", + ) + merged: dict[str, object] = { + "planner_rank_gdn_tokens": gdn_tokens, + "planner_rank_attention_tokens": attention_tokens, + "planner_gdn_token_imbalance": max(gdn_tokens, default=0) + - min(gdn_tokens, default=0), + } + for key in keys: + values = [metrics[key] for metrics in ranks if key in metrics] + if not values: + continue + if key.endswith("_ratio"): + merged[f"planner_{key}_max"] = round( + max(float(value) for value in values), 3 + ) + else: + merged[f"planner_{key}_sum"] = int(sum(int(value) for value in values)) + merged[f"planner_{key}_max"] = int(max(int(value) for value in values)) + rank0 = ranks[0] if ranks else {} + if "tree_depth_count" in rank0: + merged["planner_tree_depth_count"] = rank0["tree_depth_count"] + return merged + + +def _local_planner_metadata(prepared: object) -> dict[str, object]: + plan = getattr( + getattr(prepared, "attention_state", None), "gdn_execution_plan", None + ) + if plan is None: + return {} + local_buckets = tuple( + bucket + for depth in getattr(plan, "tree_segment_buckets_by_depth", ()) + for bucket in depth + ) + chain_buckets = tuple( + bucket + for depth in getattr(plan, "tree_chain_buckets_by_depth", ()) + for bucket in depth + ) + all_buckets = (*local_buckets, *chain_buckets) + padding_ratios = [ + bucket.length * bucket.segment_count / max(1, bucket.real_token_count) + for bucket in all_buckets + ] + transfers_by_depth = getattr(plan, "tree_state_transfers_by_depth", ()) + return { + "attention_tokens": int(getattr(plan, "attention_token_count", 0)), + "gdn_tokens": int(getattr(plan, "gdn_token_count", 0)), + "tree_depth_count": len(getattr(plan, "tree_segment_buckets_by_depth", ())), + "tree_local_bucket_count": len(local_buckets), + "tree_chain_bucket_count": len(chain_buckets), + "tree_local_segment_count": sum( + bucket.segment_count for bucket in local_buckets + ), + "tree_chain_segment_count": sum( + bucket.segment_count for bucket in chain_buckets + ), + "tree_local_real_tokens": sum( + bucket.real_token_count for bucket in local_buckets + ), + "tree_chain_real_tokens": sum( + bucket.real_token_count for bucket in chain_buckets + ), + "tree_state_transfer_count": sum( + len(transfers) for transfers in transfers_by_depth + ), + "tree_state_transfer_rows": sum( + len(transfer.family_indices) + for transfers in transfers_by_depth + for transfer in transfers + ), + "tree_max_padding_ratio": max(padding_ratios, default=1.0), + } + + +def _tokens(offset: int, length: int) -> torch.Tensor: + return (torch.arange(length, dtype=torch.long) + offset) % 32_000 + 100 + + +def _int_values(value: str) -> list[int]: + values = [int(part) for part in value.split(",") if part.strip()] + if not values or any(item < 1 for item in values): + raise ValueError("top_k_values must contain positive integers") + return values + + +def _labels(tokens: torch.Tensor, *, target_count: int) -> torch.Tensor: + labels = torch.stack( + [((tokens * 7 + 3 + index) % 32_000) for index in range(target_count)], + dim=1, + ) + if target_count > 1: + labels[::17, -1] = -100 + return labels + return labels[:, 0] + + +class _CudaMemoryTracker: + def __init__(self, *, device_index: int, sample_interval_s: float) -> None: + self.device_index = device_index + self.sample_interval_s = sample_interval_s + self.process_peak_bytes = 0 + self.allocated_peak_bytes = 0 + self.reserved_peak_bytes = 0 + self._stop = threading.Event() + self._thread: threading.Thread | None = None + + def start(self) -> None: + if not torch.cuda.is_available(): + return + torch.cuda.reset_peak_memory_stats() + self._sample() + if self.sample_interval_s <= 0: + return + self._thread = threading.Thread(target=self._poll, daemon=True) + self._thread.start() + + def stop(self) -> None: + if not torch.cuda.is_available(): + return + self._stop.set() + if self._thread is not None: + self._thread.join(timeout=1.0) + torch.cuda.synchronize() + self._sample() + self.allocated_peak_bytes = max( + self.allocated_peak_bytes, + int(torch.cuda.max_memory_allocated()), + ) + self.reserved_peak_bytes = max( + self.reserved_peak_bytes, + int(torch.cuda.max_memory_reserved()), + ) + + def _poll(self) -> None: + while not self._stop.wait(self.sample_interval_s): + self._sample() + + def _sample(self) -> None: + self.process_peak_bytes = max( + self.process_peak_bytes, + _current_process_gpu_memory_bytes(self.device_index), + ) + self.allocated_peak_bytes = max( + self.allocated_peak_bytes, + int(torch.cuda.memory_allocated()) if torch.cuda.is_available() else 0, + ) + self.reserved_peak_bytes = max( + self.reserved_peak_bytes, + int(torch.cuda.memory_reserved()) if torch.cuda.is_available() else 0, + ) + + +def _current_process_gpu_memory_bytes(device_index: int) -> int: + try: + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(device_index) + pid = os.getpid() + processes = list(pynvml.nvmlDeviceGetComputeRunningProcesses(handle)) + with suppress(Exception): + processes.extend(pynvml.nvmlDeviceGetGraphicsRunningProcesses(handle)) + for process in processes: + if int(process.pid) == pid: + return int(process.usedGpuMemory) + except Exception: + return 0 + return 0 + + +def _distributed_memory_metadata(tracker: _CudaMemoryTracker) -> dict[str, float]: + values = torch.tensor( + [ + tracker.allocated_peak_bytes, + tracker.reserved_peak_bytes, + tracker.process_peak_bytes, + ], + device="cuda", + dtype=torch.float64, + ) + dist.all_reduce(values, op=dist.ReduceOp.MAX) + return { + "peak_memory_allocated_gb": round(float(values[0].item()) / 1024**3, 3), + "peak_memory_reserved_gb": round(float(values[1].item()) / 1024**3, 3), + "peak_memory_process_gb": round(float(values[2].item()) / 1024**3, 3), + "peak_memory_gb": round(float(values[0].item()) / 1024**3, 3), + } + + +def _mean_abs_pct(reference: torch.Tensor, candidate: torch.Tensor) -> float: + reference_fp32 = reference.detach().float() + candidate_fp32 = candidate.detach().float() + return float( + (candidate_fp32 - reference_fp32).abs().mean().item() + / (reference_fp32.abs().mean().item() + 1e-18) + ) + + +def _model_metadata(runtime: object, model_name: str, *, layers: int) -> dict[str, Any]: + from art.megatron.lora import LoRA + + provider = getattr(runtime, "provider") + model = _language_model(getattr(runtime, "model")[0]) + config = getattr(model, "config", None) + total_params = sum( + int(param.numel()) for chunk in runtime.model for param in chunk.parameters() + ) + trainable_params = sum( + int(param.numel()) + for chunk in runtime.model + for param in chunk.parameters() + if param.requires_grad + ) + lora_sites = sum( + 1 + for chunk in runtime.model + for module in chunk.modules() + if isinstance(module, LoRA) + ) + local = torch.tensor( + [total_params, trainable_params, lora_sites], + device="cuda", + dtype=torch.float64, + ) + dist.all_reduce(local, op=dist.ReduceOp.MAX) + return { + "model": model_name, + "layers_arg": layers, + "provider_num_layers": getattr(provider, "num_layers", None), + "config_num_layers": getattr(config, "num_layers", None), + "rank_local_param_count": int(local[0].item()), + "rank_local_trainable_param_count": int(local[1].item()), + "rank_local_lora_site_count": int(local[2].item()), + } + + +def _bench( + fn: Callable[[], object], + *, + warmup: int, + repeat: int, + after: Callable[[], object] | None = None, +) -> float: + for _ in range(warmup): + fn() + if after is not None: + after() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + stop = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(repeat): + fn() + if after is not None: + after() + stop.record() + torch.cuda.synchronize() + elapsed = torch.tensor(start.elapsed_time(stop) / repeat, device="cuda") + dist.all_reduce(elapsed, op=dist.ReduceOp.MAX) + return round(float(elapsed.item()), 3) + + +def _builtin( + rank: TrainerRank, + prepared: object, + labels: torch.Tensor | None, +) -> torch.Tensor: + from art.megatron.train import _placeholder_attention_mask + + return rank.runtime.model[0]( + input_ids=prepared.tokens, + position_ids=prepared.position_ids, + attention_mask=_placeholder_attention_mask(rank.device), + labels=labels, + packed_seq_params=prepared.packed_seq_params, + **rank.runtime.model_support_handler.get_forward_kwargs( + rank.runtime.model[0], + attention_bias=prepared.attention_state, + ), + ) + + +def _full_logits(rank: TrainerRank, prepared: object) -> torch.Tensor: + logits = rank._gather_tensor_parallel_logits(_builtin(rank, prepared, None)) + return _batch_seq_logits(logits, seq_len=int(prepared.tokens.shape[1])) + + +def _target_builtin_loss( + rank: TrainerRank, + items: object, + prepared: object, +) -> torch.Tensor: + return _builtin(rank, prepared, _packed_labels(items, prepared)).float().sum() + + +def _target_builtin_masked_loss( + rank: TrainerRank, + items: object, + prepared: object, +) -> torch.Tensor: + labels = _packed_labels(items, prepared) + per_token_loss = _builtin(rank, prepared, labels).float().reshape(-1) + valid = labels.reshape(-1) != -100 + return per_token_loss[valid].sum() + per_token_loss.sum() * 0.0 + + +def _target_hidden_loss( + rank: TrainerRank, + items: object, + prepared: object, +) -> torch.Tensor: + hidden = rank._gather_sequence_parallel_hidden(rank._decoder_hidden(prepared)) + outputs = rank._project_head(items, prepared, hidden) + losses = [ + -output.target_logprobs.sum() + for output in outputs + if output.target_logprobs is not None + ] + if not losses: + raise RuntimeError("target logprobs were not produced") + return torch.stack(losses).sum() + + +def _target_trainer_loss( + rank: TrainerRank, + items: object, + prepared: object, +) -> torch.Tensor: + outputs = rank._forward_packed(items, prepared) + losses = [ + -output.target_logprobs.sum() + for output in outputs + if output.target_logprobs is not None + ] + if not losses: + raise RuntimeError("target logprobs were not produced") + return torch.stack(losses).sum() + + +def _target_requests_loss( + rank: TrainerRank, + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> torch.Tensor: + outputs = rank.dp_rank_forward(requests) + losses = [ + -output.target_logprobs.sum() + for output in outputs + if output.target_logprobs is not None + ] + if not losses: + raise RuntimeError("target logprobs were not produced") + return torch.stack(losses).sum() + + +def _trainer_topk_loss( + rank: TrainerRank, + items: object, + prepared: object, +) -> torch.Tensor: + outputs = rank._forward_packed(items, prepared) + losses = [ + -output.top_k.logprobs.sum() for output in outputs if output.top_k is not None + ] + if not losses: + raise RuntimeError("top_k logprobs were not produced") + return torch.stack(losses).sum() + + +def _topk_requests_loss( + rank: TrainerRank, + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> torch.Tensor: + outputs = rank.dp_rank_forward(requests) + losses = [ + -output.top_k.logprobs.sum() for output in outputs if output.top_k is not None + ] + if not losses: + raise RuntimeError("top_k logprobs were not produced") + return torch.stack(losses).sum() + + +def _fixed_micro_batch_training_step( + rank: TrainerRank, + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + *, + params: AdamParams, + offload_manager: object | None, + loss_kind: str, + stats_sink: list[dict[str, int | bool]], +) -> dict[str, float]: + def body() -> dict[str, float]: + return _fixed_micro_batch_training_step_body( + rank, + requests, + params=params, + loss_kind=loss_kind, + stats_sink=stats_sink, + ) + + if offload_manager is None: + return body() + with offload_manager.job(): # type: ignore[attr-defined] + return body() + + +def _fixed_micro_batch_training_step_body( + rank: TrainerRank, + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + *, + params: AdamParams, + loss_kind: str, + stats_sink: list[dict[str, int | bool]], +) -> dict[str, float]: + rank.zero_grad() + dp_rank, dp_size = rank._dp_rank_and_size() + stats: list[dict[str, int | bool]] = [] + for start in range(0, len(requests), dp_size): + stop = min(start + dp_size, len(requests)) + indices = tuple(range(start + dp_rank, stop, dp_size)) + local_requests = [requests[index] for index in indices] + outputs = rank.dp_rank_forward(local_requests) + loss = _micro_batch_loss(rank, outputs, loss_kind=loss_kind) + if loss.requires_grad: + loss.backward() + stats.append( + { + "global_count": stop - start, + "local_count": len(local_requests), + "packed_tokens": _logical_input_tokens(local_requests), + "logical_tokens": _logical_input_tokens(local_requests), + "rejected_candidates": 0, + "cold_start": False, + } + ) + stats_sink[:] = stats + return rank.optim_step(params=params, scale_grads=1.0) + + +def _adaptive_micro_batch_training_step( + rank: TrainerRank, + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + *, + params: AdamParams, + offload_manager: object | None, + loss_kind: str, + stats_sink: list[dict[str, int | bool]], +) -> dict[str, float]: + def body() -> dict[str, float]: + return _adaptive_micro_batch_training_step_body( + rank, + requests, + params=params, + loss_kind=loss_kind, + stats_sink=stats_sink, + ) + + if offload_manager is None: + return body() + with offload_manager.job(): # type: ignore[attr-defined] + return body() + + +def _adaptive_micro_batch_training_step_body( + rank: TrainerRank, + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + *, + params: AdamParams, + loss_kind: str, + stats_sink: list[dict[str, int | bool]], +) -> dict[str, float]: + rank.zero_grad() + stats: list[dict[str, int | bool]] = [] + step_start = time.perf_counter() + for micro_batch in rank.forward_micro_batches(requests): + loss = _micro_batch_loss(rank, micro_batch.outputs, loss_kind=loss_kind) + if loss.requires_grad: + loss.backward() + row = { + "global_count": int(micro_batch.stats.global_count), + "local_count": int(micro_batch.stats.local_count), + "packed_tokens": int(micro_batch.stats.packed_tokens), + "logical_tokens": int(micro_batch.stats.logical_tokens), + "estimated_required_bytes": int(micro_batch.stats.estimated_required_bytes), + "available_bytes": int(micro_batch.stats.available_bytes), + "rejected_candidates": int(micro_batch.stats.rejected_candidates), + "cold_start": bool(micro_batch.stats.cold_start), + } + stats.append(row) + _emit_adaptive_progress( + "target_trainer_adaptive_train_step_window", + { + **row, + "window_index": len(stats) - 1, + "elapsed_ms": (time.perf_counter() - step_start) * 1000.0, + }, + ) + stats_sink[:] = stats + return rank.optim_step(params=params, scale_grads=1.0) + + +def _profiled_adaptive_micro_batch_training_step( + rank: TrainerRank, + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + *, + params: AdamParams, + offload_manager: object | None, + loss_kind: str, + stats_sink: list[dict[str, int | bool | float]], +) -> dict[str, float]: + def body() -> dict[str, float]: + return _profiled_adaptive_micro_batch_training_step_body( + rank, + requests, + params=params, + loss_kind=loss_kind, + stats_sink=stats_sink, + ) + + if offload_manager is None: + return body() + with offload_manager.job(): # type: ignore[attr-defined] + return body() + + +def _profiled_adaptive_micro_batch_training_step_body( + rank: TrainerRank, + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + *, + params: AdamParams, + loss_kind: str, + stats_sink: list[dict[str, int | bool | float]], +) -> dict[str, float]: + rank.zero_grad() + items = list(requests) + rank._validate_replicated_top_level_count(len(items)) + start = 0 + stats: list[dict[str, int | bool | float]] = [] + step_start = time.perf_counter() + while start < len(items): + with _profile_adaptive_selection(rank) as select_profile: + candidate, select_ms = _timed_cuda( + rank, lambda: rank._select_next_micro_batch(items, start) + ) + select_profile["select_plan_residual_ms"] = max( + 0.0, + select_profile["select_plan_ms"] + - select_profile["select_forward_item_ms"] + - select_profile["select_pack_ms"] + - select_profile["select_output_estimate_ms"] + - select_profile["select_signature_ms"], + ) + select_profile["select_memory_check_residual_ms"] = max( + 0.0, + select_profile["select_memory_check_ms"] + - select_profile["select_memory_estimate_ms"] + - select_profile["select_available_memory_ms"], + ) + select_profile["select_residual_ms"] = max( + 0.0, + select_ms + - select_profile["select_estimate_ms"] + - select_profile["select_plan_ms"] + - select_profile["select_memory_check_ms"] + - select_profile["select_profile_check_ms"], + ) + flat_outputs, execute_ms = _timed_cuda( + rank, + lambda: rank._run_flat_plan_with_memory_tracking( + candidate.plan, + context="target_trainer_adaptive_profile_train_step", + ), + ) + + def unflatten_outputs() -> list[object]: + flat_iter = iter(flat_outputs) + return [_unflatten(item, flat_iter) for item in candidate.inputs] + + outputs, unflatten_ms = _timed_cuda( + rank, + unflatten_outputs, + ) + loss, loss_ms = _timed_cuda( + rank, lambda: _micro_batch_loss(rank, outputs, loss_kind=loss_kind) + ) + if loss.requires_grad: + _, backward_ms = _timed_cuda(rank, loss.backward) + else: + backward_ms = 0.0 + row = { + "global_count": int(candidate.stats_global_count), + "local_count": int(len(candidate.inputs)), + "packed_tokens": int(candidate.plan.packed_tokens), + "logical_tokens": int(candidate.plan.logical_tokens), + "estimated_required_bytes": int(candidate.check.estimated_required_bytes), + "available_bytes": int(candidate.check.available_bytes), + "rejected_candidates": int(candidate.rejected_candidates), + "cold_start": bool(candidate.cold_start), + "select_ms": select_ms, + "execute_ms": execute_ms, + "unflatten_ms": unflatten_ms, + "loss_ms": loss_ms, + "backward_ms": backward_ms, + "optim_ms": 0.0, + **select_profile, + } + stats.append(row) + stop = start + candidate.stats_global_count + if stop < len(items): + rank._last_global_micro_batch_size = max( + rank._last_global_micro_batch_size or 0, + candidate.stats_global_count, + ) + _emit_adaptive_progress( + "target_trainer_adaptive_profile_train_step_window", + { + **row, + "window_index": len(stats) - 1, + "global_start": int(start), + "global_stop": int(stop), + "remembered_window": int(rank._last_global_micro_batch_size or 0), + "elapsed_ms": (time.perf_counter() - step_start) * 1000.0, + }, + ) + start = stop + metrics, optim_ms = _timed_cuda( + rank, lambda: rank.optim_step(params=params, scale_grads=1.0) + ) + if stats: + stats[-1]["optim_ms"] = optim_ms + stats_sink[:] = stats + return metrics + + +def _emit_adaptive_progress(event: str, row: dict[str, object]) -> None: + if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0: + return + path = os.environ.get("ART_TRAINER_RANK_PROGRESS_JSONL") + if not path: + return + payload = {"event": event, **row} + line = json.dumps(payload, sort_keys=True) + print(line, flush=True) + progress_path = Path(path) + progress_path.parent.mkdir(parents=True, exist_ok=True) + with progress_path.open("a") as handle: + handle.write(line + "\n") + + +@contextmanager +def _profile_adaptive_selection(rank: TrainerRank) -> Any: + stats = { + "select_plan_ms": 0.0, + "select_plan_calls": 0, + "select_forward_item_ms": 0.0, + "select_forward_item_calls": 0, + "select_pack_ms": 0.0, + "select_pack_calls": 0, + "select_estimate_ms": 0.0, + "select_estimate_calls": 0, + "select_plan_lookup_calls": 0, + "select_plan_cache_hit_calls": 0, + "select_plan_cache_miss_calls": 0, + "select_estimate_lookup_calls": 0, + "select_estimate_cache_hit_calls": 0, + "select_estimate_cache_miss_calls": 0, + "select_output_estimate_ms": 0.0, + "select_output_estimate_calls": 0, + "select_signature_ms": 0.0, + "select_signature_calls": 0, + "select_memory_check_ms": 0.0, + "select_memory_check_calls": 0, + "select_memory_estimate_ms": 0.0, + "select_memory_estimate_calls": 0, + "select_available_memory_ms": 0.0, + "select_available_memory_calls": 0, + "select_profile_check_ms": 0.0, + "select_profile_check_calls": 0, + } + + def timed( + key: str, + calls_key: str, + fn: Callable[..., object], + *args: object, + **kwargs: object, + ) -> object: + start = time.perf_counter() + try: + return fn(*args, **kwargs) + finally: + stats[key] += (time.perf_counter() - start) * 1000.0 + stats[calls_key] += 1 + + original_plan = rank._plan_flat_forward + original_cached_plan = rank._cached_adaptive_plan + original_estimate = rank._estimate_flat_forward + original_cached_estimate = rank._cached_adaptive_estimate + original_forward_item = rank._forward_item + original_pack = trainer_rank_module.pack_shared_prefixes + original_output_estimate = rank._estimate_group_request_output_bytes + original_signature = rank._memory_signature_from_requests + original_memory_check = rank._memory_check + original_memory_estimate = rank._estimate_required_memory_bytes_from_values + original_available = rank._available_memory_bytes + original_profile_check = rank._all_ranks_have_memory_profile + + def plan_wrapper(requests: object) -> object: + return timed("select_plan_ms", "select_plan_calls", original_plan, requests) + + def cached_plan_wrapper(*args: object, **kwargs: object) -> object: + stats["select_plan_lookup_calls"] += 1 + before = stats["select_plan_calls"] + result = original_cached_plan(*args, **kwargs) + if stats["select_plan_calls"] == before: + stats["select_plan_cache_hit_calls"] += 1 + else: + stats["select_plan_cache_miss_calls"] += 1 + return result + + def estimate_wrapper(requests: object) -> object: + return timed( + "select_estimate_ms", + "select_estimate_calls", + original_estimate, + requests, + ) + + def cached_estimate_wrapper(*args: object, **kwargs: object) -> object: + stats["select_estimate_lookup_calls"] += 1 + before = stats["select_estimate_calls"] + result = original_cached_estimate(*args, **kwargs) + if stats["select_estimate_calls"] == before: + stats["select_estimate_cache_hit_calls"] += 1 + else: + stats["select_estimate_cache_miss_calls"] += 1 + return result + + def forward_item_wrapper(request: object) -> object: + return timed( + "select_forward_item_ms", + "select_forward_item_calls", + original_forward_item, + request, + ) + + def pack_wrapper(*args: object, **kwargs: object) -> object: + start = time.perf_counter() + try: + return original_pack(*args, **kwargs) + finally: + stats["select_pack_ms"] += (time.perf_counter() - start) * 1000.0 + stats["select_pack_calls"] += 1 + + def output_estimate_wrapper(items: object) -> object: + return timed( + "select_output_estimate_ms", + "select_output_estimate_calls", + original_output_estimate, + items, + ) + + def signature_wrapper(*args: object, **kwargs: object) -> object: + return timed( + "select_signature_ms", + "select_signature_calls", + original_signature, + *args, + **kwargs, + ) + + def memory_check_wrapper(plan: object) -> object: + return timed( + "select_memory_check_ms", + "select_memory_check_calls", + original_memory_check, + plan, + ) + + def memory_estimate_wrapper(*args: object, **kwargs: object) -> object: + return timed( + "select_memory_estimate_ms", + "select_memory_estimate_calls", + original_memory_estimate, + *args, + **kwargs, + ) + + def available_wrapper() -> object: + return timed( + "select_available_memory_ms", + "select_available_memory_calls", + original_available, + ) + + def profile_check_wrapper(*args: object, **kwargs: object) -> object: + return timed( + "select_profile_check_ms", + "select_profile_check_calls", + original_profile_check, + *args, + **kwargs, + ) + + rank._plan_flat_forward = plan_wrapper # type: ignore[method-assign] + rank._cached_adaptive_plan = cached_plan_wrapper # type: ignore[method-assign] + rank._estimate_flat_forward = estimate_wrapper # type: ignore[method-assign] + rank._cached_adaptive_estimate = cached_estimate_wrapper # type: ignore[method-assign] + rank._forward_item = forward_item_wrapper # type: ignore[method-assign] + trainer_rank_module.pack_shared_prefixes = pack_wrapper # type: ignore[assignment] + rank._estimate_group_request_output_bytes = output_estimate_wrapper # type: ignore[method-assign] + rank._memory_signature_from_requests = signature_wrapper # type: ignore[method-assign] + rank._memory_check = memory_check_wrapper # type: ignore[method-assign] + rank._estimate_required_memory_bytes_from_values = memory_estimate_wrapper # type: ignore[method-assign] + rank._available_memory_bytes = available_wrapper # type: ignore[method-assign] + rank._all_ranks_have_memory_profile = profile_check_wrapper # type: ignore[method-assign] + try: + yield stats + finally: + rank._plan_flat_forward = original_plan # type: ignore[method-assign] + rank._cached_adaptive_plan = original_cached_plan # type: ignore[method-assign] + rank._estimate_flat_forward = original_estimate # type: ignore[method-assign] + rank._cached_adaptive_estimate = original_cached_estimate # type: ignore[method-assign] + rank._forward_item = original_forward_item # type: ignore[method-assign] + trainer_rank_module.pack_shared_prefixes = original_pack # type: ignore[assignment] + rank._estimate_group_request_output_bytes = original_output_estimate # type: ignore[method-assign] + rank._memory_signature_from_requests = original_signature # type: ignore[method-assign] + rank._memory_check = original_memory_check # type: ignore[method-assign] + rank._estimate_required_memory_bytes_from_values = original_memory_estimate # type: ignore[method-assign] + rank._available_memory_bytes = original_available # type: ignore[method-assign] + rank._all_ranks_have_memory_profile = original_profile_check # type: ignore[method-assign] + + +def _timed_cuda( + rank: TrainerRank, + fn: Callable[[], object], +) -> tuple[object, float]: + _sync_cuda(rank) + start = time.perf_counter() + result = fn() + _sync_cuda(rank) + return result, (time.perf_counter() - start) * 1000.0 + + +def _sync_cuda(rank: TrainerRank) -> None: + if torch.cuda.is_available() and rank.device.type == "cuda": + torch.cuda.synchronize(rank.device) + + +def _micro_batch_loss( + rank: TrainerRank, + outputs: object, + *, + loss_kind: str, +) -> torch.Tensor: + losses: list[torch.Tensor] = [] + for output in _iter_outputs(outputs): + if loss_kind == "target": + target_logprobs = getattr(output, "target_logprobs", None) + if target_logprobs is not None: + losses.append(-target_logprobs.sum()) + elif loss_kind == "topk": + top_k = getattr(output, "top_k", None) + if top_k is not None: + losses.append(-top_k.logprobs.sum()) + else: + raise ValueError(f"unknown loss_kind: {loss_kind}") + if not losses: + return torch.tensor(0.0, device=rank.device) + return torch.stack(losses).sum() + + +def _iter_outputs(value: object) -> Sequence[object]: + if hasattr(value, "target_logprobs") and hasattr(value, "top_k"): + return (value,) + if isinstance(value, Sequence): + outputs: list[object] = [] + for item in value: + outputs.extend(_iter_outputs(item)) + return outputs + raise TypeError(f"unexpected TrainerRank output value: {type(value)!r}") + + +def _logical_input_tokens( + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> int: + return sum( + int(request.input_tokens.numel()) + for request in requests + if request.input_tokens is not None + ) + + +def _record_micro_batch_stats( + metadata: dict[str, object], + name: str, + stats: Sequence[dict[str, int | bool | float]], +) -> None: + if not stats: + metadata[f"{name}_micro_window_count"] = 0 + return + global_counts = [int(stat["global_count"]) for stat in stats] + local_counts = [int(stat["local_count"]) for stat in stats] + packed_tokens = [int(stat["packed_tokens"]) for stat in stats] + rejected = [int(stat["rejected_candidates"]) for stat in stats] + estimated_required = [ + int(stat.get("estimated_required_bytes", 0)) for stat in stats + ] + available = [int(stat.get("available_bytes", 0)) for stat in stats] + metadata[f"{name}_micro_window_count"] = len(stats) + metadata[f"{name}_micro_global_count_first"] = global_counts[0] + metadata[f"{name}_micro_global_count_last"] = global_counts[-1] + metadata[f"{name}_micro_global_count_min"] = min(global_counts) + metadata[f"{name}_micro_global_count_max"] = max(global_counts) + metadata[f"{name}_micro_local_count_min"] = min(local_counts) + metadata[f"{name}_micro_local_count_max"] = max(local_counts) + metadata[f"{name}_micro_packed_tokens_min"] = min(packed_tokens) + metadata[f"{name}_micro_packed_tokens_max"] = max(packed_tokens) + metadata[f"{name}_micro_rejected_candidates_total"] = sum(rejected) + metadata[f"{name}_micro_estimated_required_gb_max"] = round( + max(estimated_required) / 1024**3, 3 + ) + metadata[f"{name}_micro_available_gb_min"] = round(min(available) / 1024**3, 3) + metadata[f"{name}_micro_cold_start_count"] = sum( + int(bool(stat["cold_start"])) for stat in stats + ) + metadata[f"{name}_micro_global_counts_head"] = ",".join( + str(count) for count in global_counts[:8] + ) + + +def _record_profile_stats( + metadata: dict[str, object], + name: str, + stats: Sequence[dict[str, int | bool | float]], +) -> None: + fields = sorted( + { + key + for stat in stats + for key, value in stat.items() + if key.endswith("_ms") and isinstance(value, int | float) + } + ) + for field in fields: + total = sum(float(stat.get(field, 0.0)) for stat in stats) + metadata[f"{name}_{field}_sum"] = round(total, 3) + metadata[f"{name}_{field}_max"] = round( + max((float(stat.get(field, 0.0)) for stat in stats), default=0.0), + 3, + ) + call_fields = sorted( + { + key + for stat in stats + for key, value in stat.items() + if key.endswith("_calls") and isinstance(value, int | float) + } + ) + for field in call_fields: + metadata[f"{name}_{field}_sum"] = int( + sum(int(stat.get(field, 0)) for stat in stats) + ) + metadata[f"{name}_{field}_max"] = int( + max((int(stat.get(field, 0)) for stat in stats), default=0) + ) + + +def _training_step( + rank: TrainerRank, + loss_fn: Callable[[], torch.Tensor], + *, + params: AdamParams, + offload_manager: object | None, +) -> dict[str, float]: + if offload_manager is None: + return _training_step_body(rank, loss_fn, params=params) + with offload_manager.job(): # type: ignore[attr-defined] + return _training_step_body(rank, loss_fn, params=params) + + +def _training_step_body( + rank: TrainerRank, + loss_fn: Callable[[], torch.Tensor], + *, + params: AdamParams, +) -> dict[str, float]: + rank.zero_grad() + loss = loss_fn() + loss.backward() + return rank.optim_step(params=params, scale_grads=1.0) + + +def _make_offload_manager(runtime: object) -> object: + from art.megatron.training.streaming_weight_offload import ( + StreamingWeightOffloadConfig, + ) + from art.megatron.training.weight_offload import WeightOffloadManager + + manager = WeightOffloadManager.from_config( + model=getattr(runtime, "model"), + rank=dist.get_rank(), + compile_enabled=bool(getattr(runtime, "transformer_layers_compiled", False)), + offload_between_jobs=True, + streaming_config=StreamingWeightOffloadConfig(enabled=False), + ) + manager.install() + manager.after_job() + return manager + + +def _target_correctness_metrics( + rank: TrainerRank, + items: object, + prepared: object, +) -> dict[str, float]: + for chunk in rank.runtime.model: + chunk.eval() + with torch.no_grad(): + labels = _packed_labels(items, prepared) + native_logprobs = _native_target_logprobs(rank, items, prepared, labels) + hidden = rank._gather_sequence_parallel_hidden(rank._decoder_hidden(prepared)) + head_outputs = rank._project_head(items, prepared, hidden) + abs_diff_sum = torch.tensor(0.0, device=rank.device) + reference_abs_sum = torch.tensor(0.0, device=rank.device) + value_count = torch.tensor(0.0, device=rank.device) + max_abs_diff = torch.tensor(0.0, device=rank.device) + for native, candidate in zip( + native_logprobs, + (output.target_logprobs for output in head_outputs), + strict=True, + ): + if candidate is None: + continue + diff = (candidate.float() - native.float()).abs() + if int(diff.numel()) == 0: + continue + abs_diff_sum += diff.sum() + reference_abs_sum += native.float().abs().sum() + value_count += float(diff.numel()) + max_abs_diff = torch.maximum(max_abs_diff, diff.max()) + sums = torch.stack((abs_diff_sum, reference_abs_sum, value_count)) + dist.all_reduce(sums, op=dist.ReduceOp.SUM) + dist.all_reduce(max_abs_diff, op=dist.ReduceOp.MAX) + mean_abs_pct = float((sums[0] / torch.clamp(sums[1], min=1e-18)).item()) + max_abs = float(max_abs_diff.item()) + return { + "target_hidden_vs_native_mean_abs_pct": mean_abs_pct, + "target_hidden_vs_native_max_abs_diff": max_abs, + "target_hidden_vs_native_value_count": float(sums[2].item()), + } + + +def _native_target_logprobs( + rank: TrainerRank, + items: object, + prepared: object, + labels: torch.Tensor, +) -> list[torch.Tensor]: + from art.megatron.train import _placeholder_attention_mask + + per_token_loss = rank.runtime.model[0]( + input_ids=prepared.tokens, + position_ids=prepared.position_ids, + attention_mask=_placeholder_attention_mask(rank.device), + labels=labels, + packed_seq_params=prepared.packed_seq_params, + **rank.runtime.model_support_handler.get_forward_kwargs( + rank.runtime.model[0], + attention_bias=prepared.attention_state, + ), + ) + flat_logprobs = -per_token_loss.reshape(-1) + outputs: list[torch.Tensor] = [] + for item, positions, source_positions in zip( + items, + prepared.positions_by_item, + prepared.source_positions_by_item, + strict=True, + ): + if item.labels is None: + raise RuntimeError("native target oracle requires labels") + item_labels = item.labels.to(device=rank.device).index_select( + 0, + source_positions.to(device=rank.device), + ) + outputs.append( + flat_logprobs.index_select(0, positions.to(device=rank.device)).masked_fill( + item_labels == -100, + 0.0, + ) + ) + return outputs + + +def _adapter_sanity_metrics( + rank: TrainerRank, + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + *, + params: AdamParams, + adapter_slots: int, +) -> dict[str, float]: + target_request = next( + (request for request in requests if request.target_tokens is not None), + None, + ) + if target_request is None: + return {"adapter_sanity_skipped": 1.0} + base_request = ForwardInput( + input_tokens=target_request.input_tokens, + target_tokens=target_request.target_tokens, + checkpoint=None, + ) + slot_request = ForwardInput( + input_tokens=target_request.input_tokens, + target_tokens=target_request.target_tokens, + checkpoint="S0", + ) + for chunk in rank.runtime.model: + chunk.eval() + with torch.no_grad(): + base_output = rank.dp_rank_forward([base_request])[0] + slot_output = rank.dp_rank_forward([slot_request])[0] + if base_output.target_logprobs is None or slot_output.target_logprobs is None: + raise RuntimeError("adapter sanity target outputs were not produced") + output_diff = _mean_abs_pct( + base_output.target_logprobs, + slot_output.target_logprobs, + ) + output_max = float( + (slot_output.target_logprobs.float() - base_output.target_logprobs.float()) + .abs() + .max() + .item() + ) + + slot_params = list(rank._checkpoint_slot_params_by_name["S0"]) + other_params = ( + list(rank._checkpoint_slot_params_by_name["S1"]) if adapter_slots > 1 else [] + ) + before = [param.detach().clone() for param in slot_params] + other_before = [param.detach().clone() for param in other_params] + for chunk in rank.runtime.model: + chunk.train() + rank.zero_grad() + loss = _target_requests_loss(rank, [slot_request]) + loss.backward() + grad_sq = torch.tensor(0.0, device=rank.device) + for param in slot_params: + if param.grad is not None: + grad_sq = grad_sq + param.grad.detach().float().square().sum() + grad_norm = torch.sqrt(grad_sq) + rank.optim_step(params=params, checkpoints=["S0"]) + slot_delta = sum( + float((param.detach().float() - old.float()).abs().sum().item()) + for param, old in zip(slot_params, before, strict=True) + ) + other_delta = sum( + float((param.detach().float() - old.float()).abs().sum().item()) + for param, old in zip(other_params, other_before, strict=True) + ) + values = torch.tensor( + [output_diff, output_max, float(grad_norm.item()), slot_delta, other_delta], + device=rank.device, + ) + dist.all_reduce(values, op=dist.ReduceOp.MAX) + return { + "adapter_sanity_output_mean_abs_pct": float(values[0].item()), + "adapter_sanity_output_max_abs_diff": float(values[1].item()), + "adapter_sanity_grad_norm": float(values[2].item()), + "adapter_sanity_stepped_slot_delta": float(values[3].item()), + "adapter_sanity_unselected_slot_delta": float(values[4].item()), + } + + +def _runtime_output_shape(runtime: object) -> tuple[int, int, int]: + provider = getattr(runtime, "provider") + model = _language_model(getattr(runtime, "model")[0]) + hidden_size = int( + getattr(provider, "hidden_size", None) + or getattr(getattr(model, "config", None), "hidden_size", 0) + ) + vocab_size = int( + getattr(getattr(model, "config", None), "padded_vocab_size", None) + or getattr(model, "vocab_size", 0) + ) + dtype_size = next(getattr(runtime, "model")[0].parameters()).element_size() + if hidden_size <= 0 or vocab_size <= 0: + raise RuntimeError( + f"could not infer output shape: hidden_size={hidden_size}, " + f"vocab_size={vocab_size}" + ) + return hidden_size, vocab_size, dtype_size + + +def _request_output_gb( + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + *, + hidden_size: int, + vocab_size: int, + dtype_size: int, +) -> float: + return ( + sum( + _request_output_bytes( + request, + hidden_size=hidden_size, + vocab_size=vocab_size, + dtype_size=dtype_size, + ) + for request in requests + ) + / 1024**3 + ) + + +def _request_output_bytes( + request: ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ], + *, + hidden_size: int, + vocab_size: int, + dtype_size: int, +) -> int: + seq_len = int(request.input_tokens.numel()) + bytes_total = 0 + if request.target_tokens is not None: + bytes_total += int(request.target_tokens.numel()) * 4 + if request.top_k is not None: + bytes_total += seq_len * int(request.top_k) * (4 + 8) + if request.logits: + bytes_total += seq_len * vocab_size * dtype_size + if request.hidden_states: + bytes_total += seq_len * hidden_size * dtype_size + return bytes_total + + +def _logits_requests( + requests: Sequence[ForwardInput[torch.Tensor, None, None, None]], +) -> list[ForwardInput[None, None, torch.Tensor, None]]: + return [ + ForwardInput(input_tokens=request.input_tokens, logits=True) + for request in requests + ] + + +def _rate_units( + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + stats: dict[str, int | str], + *, + hidden_size: int, + vocab_size: int, + dtype_size: int, +) -> dict[str, int]: + return { + "packed_tokens": int(stats.get("packed_tokens", 0)), + "logical_tokens": int(stats.get("logical_tokens", 0)), + "target_values": _target_value_count(requests), + "output_bytes": sum( + _request_output_bytes( + request, + hidden_size=hidden_size, + vocab_size=vocab_size, + dtype_size=dtype_size, + ) + for request in requests + ), + } + + +def _target_value_count( + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> int: + count = 0 + for request in requests: + if request.target_tokens is not None: + count += int((request.target_tokens != -100).sum().item()) + return count + + +def _rate_metrics( + results: dict[str, float], + units_by_name: dict[str, dict[str, int]], +) -> dict[str, float]: + suffixes = { + "packed_tokens": "packed_tok_s", + "logical_tokens": "logical_tok_s", + "target_values": "target_logprob_s", + "output_bytes": "output_gb_s", + } + metrics: dict[str, float] = {} + for key, ms in results.items(): + if ms <= 0: + continue + name = key.removesuffix("_ms") + units = units_by_name.get(name, {}) + for unit_key, suffix in suffixes.items(): + value = int(units.get(unit_key, 0)) + if value <= 0: + continue + scale = 1024**3 if unit_key == "output_bytes" else 1 + metrics[f"{name}_{suffix}"] = round(value * 1000.0 / ms / scale, 3) + return metrics + + +def _packed_labels(items: object, prepared: object) -> torch.Tensor: + labels = torch.full_like(prepared.tokens, -100) + for item, positions, source_positions in zip( + items, + prepared.positions_by_item, + prepared.source_positions_by_item, + strict=True, + ): + if item.labels is None: + continue + labels.reshape(-1)[positions.to(device=labels.device)] = item.labels.to( + device=labels.device + ).index_select(0, source_positions.to(device=labels.device)) + return labels + + +if __name__ == "__main__": + typer.run(main) diff --git a/dev/trainer_rank_topology_check.py b/dev/trainer_rank_topology_check.py new file mode 100644 index 000000000..22b8b286a --- /dev/null +++ b/dev/trainer_rank_topology_check.py @@ -0,0 +1,1238 @@ +from __future__ import annotations + +from dataclasses import dataclass +import json +import os +import time + +import torch +import torch.distributed as dist +import typer + +from art.megatron.shared_prefix_packing import SharedPrefixPack, pack_shared_prefixes +from art.trainer_rank import ( + ForwardInput, + ForwardOutput, + TopK, + TrainerRank, + _batch_seq_logits, + _language_model, + _select_positions, +) + + +@dataclass +class CheckOutput: + source_positions: torch.Tensor + target_logprobs: torch.Tensor | None + top_k: TopK | None + logits: torch.Tensor | None + hidden_states: torch.Tensor | None + + +@dataclass(frozen=True) +class DiffStats: + max_abs_diff: float = 0.0 + mean_abs_pct: float = 0.0 + + def merge(self, other: DiffStats) -> DiffStats: + return DiffStats( + max_abs_diff=max(self.max_abs_diff, other.max_abs_diff), + mean_abs_pct=max(self.mean_abs_pct, other.mean_abs_pct), + ) + + +def _gather_target_logprobs( + logprobs: torch.Tensor, + labels: torch.Tensor, +) -> torch.Tensor: + if int(labels.shape[0]) == 0: + return torch.empty(labels.shape, device=logprobs.device, dtype=logprobs.dtype) + flat_labels = labels.clamp_min(0).reshape(int(labels.shape[0]), -1) + selected = logprobs.gather(1, flat_labels).reshape(labels.shape) + return selected.masked_fill(labels == -100, 0.0) + + +def _empty_logits_like_positions( + positions: torch.Tensor, + model: object, + like: torch.Tensor, +) -> torch.Tensor: + vocab_size = getattr( + getattr(model, "config", None), + "padded_vocab_size", + None, + ) or getattr(model, "vocab_size", None) + if vocab_size is None: + raise RuntimeError("could not determine full padded vocabulary size") + return torch.empty( + (int(positions.numel()), int(vocab_size)), + device=like.device, + dtype=like.dtype, + ) + + +def main( + model: str = "Qwen/Qwen3-0.6B", + layers: int = 1, + head_chunk_a: int = 17, + head_chunk_b: int = 512, + max_prefix_depth: int = 1, + request_case: str = "shared", + stress_tokens: int = 0, + max_unpacked_output_gb: float = 0.25, + debug_output: str = "none", + compare_independent: bool = False, + compare_same_layout: bool = False, +) -> None: + os.environ.setdefault("ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE", "1") + os.environ.setdefault("ART_MEGATRON_CONTEXT_PARALLEL_SIZE", "1") + os.environ.setdefault("ART_MEGATRON_PIPELINE_MODEL_PARALLEL_SIZE", "1") + + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + dist.init_process_group(backend="nccl") + try: + from megatron.core import parallel_state as ps + + from art.megatron import train as megatron_train + + torch.manual_seed(1234) + provider_configure = ( + (lambda provider: setattr(provider, "num_layers", layers)) + if layers > 0 + else None + ) + runtime = megatron_train.build_training_runtime( + model_identifier=model, + provider_configure=provider_configure, + print_env=dist.get_rank() == 0, + ) + for chunk in runtime.model: + chunk.eval() + + requests = ( + _stress_requests(stress_tokens) + if stress_tokens > 0 + else _requests(request_case) + ) + requests = _debug_output_requests(requests, debug_output) + unpacked_output_gb = _estimate_unpacked_output_gb(requests, runtime) + if max_unpacked_output_gb > 0 and unpacked_output_gb > max_unpacked_output_gb: + if dist.get_rank() == 0: + print( + json.dumps( + { + "world": dist.get_world_size(), + "dp": int(ps.get_data_parallel_world_size()), + "tp": int(ps.get_tensor_model_parallel_world_size()), + "cp": int(ps.get_context_parallel_world_size()), + "stress_tokens": stress_tokens, + "estimated_unpacked_output_gb": round( + unpacked_output_gb, 3 + ), + "max_unpacked_output_gb": max_unpacked_output_gb, + "skipped": "unpacked_output_cap", + }, + sort_keys=True, + ), + flush=True, + ) + dist.barrier() + return + dp_rank = int(ps.get_data_parallel_rank()) + dp_size = int(ps.get_data_parallel_world_size()) + local_pairs = [ + (index, request) + for index, request in enumerate(requests) + if index % dp_size == dp_rank + ] + local_requests = [request for _, request in local_pairs] + + rank_a = TrainerRank( + runtime, + head_chunk_tokens=head_chunk_a, + shared_prefix_max_depth=max_prefix_depth, + ) + rank_b = TrainerRank( + runtime, + head_chunk_tokens=head_chunk_b, + shared_prefix_max_depth=max_prefix_depth, + ) + independent_outputs: list[CheckOutput] | None = None + same_layout_outputs: list[CheckOutput] | None = None + + torch.cuda.reset_peak_memory_stats() + diff_stats = DiffStats() + with torch.no_grad(): + started_at = time.perf_counter() + if request_case == "target_only": + _debug("forward-target-only") + outputs_a = list(rank_a.dp_rank_forward(local_requests)) + outputs_b = list(rank_b.dp_rank_forward(local_requests)) + oracle_outputs, actual_source_positions = _packed_oracle( + rank_a, local_requests + ) + elif stress_tokens > 0: + _debug("forward-a") + outputs_a = list(rank_a.dp_rank_forward(local_requests)) + outputs_b = outputs_a + actual_source_positions = _source_positions(rank_a, local_requests) + oracle_outputs = [ + _as_check_output(source_positions, output) + for source_positions, output in zip( + actual_source_positions, + outputs_a, + strict=True, + ) + ] + else: + _debug("forward-shared") + ( + outputs_a, + outputs_b, + oracle_outputs, + actual_source_positions, + ) = _shared_hidden_check(rank_a, rank_b, local_requests) + if compare_independent and request_case in {"shared", "unique", "deep"}: + independent_outputs = _independent_check_outputs( + rank_a, local_requests + ) + if int(ps.get_context_parallel_world_size()) <= 1: + for index, (actual, independent) in enumerate( + zip(outputs_a, independent_outputs, strict=True) + ): + diff_stats = diff_stats.merge( + _assert_close( + actual, + independent, + f"independent[{index}]", + ), + ) + if compare_same_layout and request_case in {"shared", "unique", "deep"}: + same_layout_outputs = _same_layout_check_outputs( + rank_a, + local_requests, + ) + for index, (actual, same_layout) in enumerate( + zip(outputs_a, same_layout_outputs, strict=True) + ): + diff_stats = diff_stats.merge( + _assert_close( + actual, + same_layout, + f"same_layout[{index}]", + ), + ) + _debug("compare") + elapsed_s = time.perf_counter() - started_at + + peak_memory_gb = torch.tensor( + torch.cuda.max_memory_allocated() / 1024**3, + device=rank_a.device, + ) + for index, (actual, chunked, oracle) in enumerate( + zip(outputs_a, outputs_b, oracle_outputs, strict=True) + ): + if int(oracle.source_positions.numel()) == 0: + continue + diff_stats = diff_stats.merge( + _assert_close(actual, chunked, f"chunk[{index}]"), + ) + diff_stats = diff_stats.merge( + _assert_close(actual, oracle, f"oracle[{index}]"), + ) + + diff_tensor = torch.tensor( + [diff_stats.max_abs_diff, diff_stats.mean_abs_pct], + device=rank_a.device, + ) + dist.all_reduce(diff_tensor, op=dist.ReduceOp.MAX) + dist.all_reduce(peak_memory_gb, op=dist.ReduceOp.MAX) + max_diff_value = float(diff_tensor[0].item()) + mean_abs_pct_value = float(diff_tensor[1].item()) + records = _records( + local_pairs=local_pairs, + actual_outputs=outputs_a, + actual_source_positions=actual_source_positions, + oracle_outputs=oracle_outputs, + independent_outputs=independent_outputs, + rank=int(dist.get_rank()), + dp=dp_rank, + tp=int(ps.get_tensor_model_parallel_rank()), + cp=int(ps.get_context_parallel_rank()), + ) + gathered: list[list[dict[str, object]] | None] = [None] * dist.get_world_size() + _debug("all-gather") + dist.all_gather_object(gathered, records) + _debug("reconstruct") + reconstruction_error: str | None = None + if dist.get_rank() == 0: + seen = { + record["input_index"] + for rank_records in gathered + for record in rank_records or [] + } + if seen != set(range(len(requests))): + reconstruction_error = f"DP reconstruction missed inputs: {seen}" + else: + try: + reconstructed_stats = _assert_reconstructed(gathered, requests) + max_diff_value = max( + max_diff_value, + reconstructed_stats.max_abs_diff, + ) + mean_abs_pct_value = max( + mean_abs_pct_value, + reconstructed_stats.mean_abs_pct, + ) + except AssertionError as exc: + reconstruction_error = str(exc) + if reconstruction_error is None: + print( + json.dumps( + { + "world": dist.get_world_size(), + "dp": dp_size, + "tp": int(ps.get_tensor_model_parallel_world_size()), + "cp": int(ps.get_context_parallel_world_size()), + "mean_abs_pct": mean_abs_pct_value, + "max_abs_diff": max_diff_value, + "records": sum( + len(rank_records or []) for rank_records in gathered + ), + "same_layout": compare_same_layout, + "stress_tokens": stress_tokens, + "estimated_unpacked_output_gb": round( + unpacked_output_gb, 3 + ), + "elapsed_s": round(elapsed_s, 3), + "peak_memory_gb": round(float(peak_memory_gb.item()), 3), + }, + sort_keys=True, + ), + flush=True, + ) + errors = [reconstruction_error] + dist.broadcast_object_list(errors, src=0) + if errors[0] is not None: + raise AssertionError(errors[0]) + dist.barrier() + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +def _requests( + request_case: str = "shared", +) -> list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] +]: + if request_case not in {"shared", "target_only", "unique", "deep"}: + raise ValueError( + "request_case must be 'shared', 'target_only', 'unique', or 'deep'" + ) + rows = [ + torch.tensor([11, 12, 13, 14, 15, 16, 17]), + torch.tensor([11, 12, 13, 14, 24, 25]), + torch.tensor([11, 12, 13, 14, 24, 26]), + torch.tensor([11, 12, 13, 27]), + torch.tensor([31, 32, 33, 34]), + torch.tensor([31, 32, 33, 35]), + torch.tensor([11, 12, 13, 14, 15, 16, 17]), + torch.tensor([41, 42, 43]), + torch.tensor([41, 42, 44, 45]), + torch.tensor([51, 52, 53, 54, 55]), + torch.tensor([61, 62, 63]), + torch.tensor([61, 62, 64, 65]), + torch.tensor([71, 72]), + torch.tensor([81, 82, 83, 84]), + torch.tensor([91, 92, 93]), + torch.tensor([101, 102, 103, 104, 105]), + ] + if request_case == "deep": + rows = _deep_rows() + if request_case == "unique": + rows = [row + 1000 * index for index, row in enumerate(rows)] + if request_case == "target_only": + target_only_labels = [_labels(row, 0) for row in rows] + target_only_labels[0][2] = -100 + target_only_labels[3][1] = -100 + target_only_labels[10][0] = -100 + return [ + ForwardInput(input_tokens=row, target_tokens=label) + for row, label in zip(rows, target_only_labels, strict=True) + ] + + labels = [_labels(row, offset) for offset, row in enumerate(rows)] + labels[0][2] = -100 + labels[3][1] = -100 + labels[10][0] = -100 + multi_labels = torch.stack((labels[1], (labels[1] + 17) % 1000), dim=1) + multi_labels[2, 1] = -100 + requests = [] + for mask, row in enumerate(rows): + target_tokens = None + if mask & 1: + target_tokens = multi_labels if mask == 1 else labels[mask] + requests.append( + ForwardInput( + input_tokens=row, + target_tokens=target_tokens, + top_k=3 if mask & 2 else None, + logits=bool(mask & 4), + hidden_states=bool(mask & 8), + ) + ) + return requests + + +def _debug_output_requests( + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + debug_output: str, +) -> list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] +]: + if debug_output == "none": + return requests + if debug_output == "hidden": + return [ + ForwardInput(input_tokens=request.input_tokens, hidden_states=True) + for request in requests + ] + if debug_output == "logits": + return [ + ForwardInput(input_tokens=request.input_tokens, logits=True) + for request in requests + ] + if debug_output == "target": + return [ + ForwardInput( + input_tokens=request.input_tokens, + target_tokens=_labels(request.input_tokens, 0), + ) + for request in requests + ] + if debug_output == "topk": + return [ + ForwardInput(input_tokens=request.input_tokens, top_k=3) + for request in requests + ] + if debug_output == "target_topk": + return [ + ForwardInput( + input_tokens=request.input_tokens, + target_tokens=_labels(request.input_tokens, 0), + top_k=3, + ) + for request in requests + ] + if debug_output == "mixed_no_topk": + return [ + ForwardInput( + input_tokens=request.input_tokens, + target_tokens=request.target_tokens, + logits=request.logits, + hidden_states=request.hidden_states, + ) + for request in requests + ] + if debug_output == "mixed_no_logits": + return [ + ForwardInput( + input_tokens=request.input_tokens, + target_tokens=request.target_tokens, + top_k=request.top_k, + hidden_states=request.hidden_states, + ) + for request in requests + ] + if debug_output == "mixed_no_targets": + return [ + ForwardInput( + input_tokens=request.input_tokens, + top_k=request.top_k, + logits=request.logits, + hidden_states=request.hidden_states, + ) + for request in requests + ] + if debug_output == "mixed_targets_only": + return [ + ForwardInput( + input_tokens=request.input_tokens, + target_tokens=request.target_tokens, + ) + for request in requests + ] + if debug_output == "mixed_targets_hidden": + return [ + ForwardInput( + input_tokens=request.input_tokens, + target_tokens=request.target_tokens, + hidden_states=request.hidden_states, + ) + for request in requests + ] + if debug_output == "mixed_targets_logits": + return [ + ForwardInput( + input_tokens=request.input_tokens, + target_tokens=request.target_tokens, + logits=request.logits, + ) + for request in requests + ] + raise ValueError( + "debug_output must be 'none', 'hidden', 'logits', 'target', 'topk', " + "'target_topk', 'mixed_no_topk', 'mixed_no_logits', 'mixed_no_targets', " + "'mixed_targets_only', 'mixed_targets_hidden', or 'mixed_targets_logits'" + ) + + +def _deep_rows() -> list[torch.Tensor]: + return [ + torch.tensor([11, 12, 13, 14, 15, 16, 17]), + torch.tensor([11, 12, 13, 14, 15, 16, 18]), + torch.tensor([11, 12, 13, 14, 15, 19]), + torch.tensor([11, 12, 13, 14, 20]), + torch.tensor([11, 12, 21]), + torch.tensor([31, 32, 33, 34, 35]), + torch.tensor([31, 32, 33, 34, 36]), + torch.tensor([31, 32, 33, 37]), + torch.tensor([41, 42, 43]), + torch.tensor([41, 42, 44]), + torch.tensor([51, 52, 53, 54]), + torch.tensor([61, 62]), + torch.tensor([71, 72, 73, 74, 75]), + torch.tensor([71, 72, 73, 76]), + torch.tensor([81]), + torch.tensor([91, 92, 93]), + ] + + +def _stress_requests( + token_count: int, +) -> list[ForwardInput[None, None, None, torch.Tensor]]: + if token_count < 8: + raise ValueError("stress_tokens must be >= 8") + prefix_len = token_count // 2 + tail_len = max(1, token_count // 4) + prefix = _stress_tokens(0, prefix_len) + return [ + ForwardInput( + input_tokens=torch.cat((prefix, _stress_tokens(10_000, tail_len))), + hidden_states=True, + ), + ForwardInput( + input_tokens=torch.cat((prefix, _stress_tokens(20_000, tail_len))), + hidden_states=True, + ), + ForwardInput(input_tokens=_stress_tokens(30_000, tail_len), hidden_states=True), + ] + + +def _stress_tokens(offset: int, length: int) -> torch.Tensor: + return (torch.arange(length, dtype=torch.long) + offset) % 32_000 + 100 + + +def _estimate_unpacked_output_gb( + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + runtime: object, +) -> float: + provider = getattr(runtime, "provider") + model = _language_model(getattr(runtime, "model")[0]) + hidden_size = int( + getattr(provider, "hidden_size", None) + or getattr(getattr(model, "config", None), "hidden_size", 0) + ) + vocab_size = int( + getattr(getattr(model, "config", None), "padded_vocab_size", None) + or getattr(model, "vocab_size", 0) + ) + dtype_size = next(getattr(runtime, "model")[0].parameters()).element_size() + bytes_total = sum( + _request_output_bytes( + request, + hidden_size=hidden_size, + vocab_size=vocab_size, + dtype_size=dtype_size, + ) + for request in requests + ) + return bytes_total / 1024**3 + + +def _request_output_bytes( + request: ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ], + *, + hidden_size: int, + vocab_size: int, + dtype_size: int, +) -> int: + seq_len = int(request.input_tokens.numel()) + bytes_total = 0 + if request.target_tokens is not None: + bytes_total += int(request.target_tokens.numel()) * 4 + if request.top_k is not None: + bytes_total += seq_len * int(request.top_k) * (4 + 8) + if request.logits: + bytes_total += seq_len * vocab_size * dtype_size + if request.hidden_states: + bytes_total += seq_len * hidden_size * dtype_size + return bytes_total + + +def _debug(label: str) -> None: + if os.environ.get("TRAINER_RANK_CHECK_DEBUG") != "1": + return + print(f"[rank{dist.get_rank()}] {label}", flush=True) + + +def _labels(tokens: torch.Tensor, offset: int) -> torch.Tensor: + return ((tokens * 7 + 3 + offset) % 1000).to(dtype=torch.long) + + +def _packed_oracle( + rank: TrainerRank, + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> tuple[list[CheckOutput], tuple[torch.Tensor, ...]]: + items = [rank._forward_item(request) for request in requests] + prepared = rank._prepare_packed_forward( + pack_shared_prefixes( + (item.input_ids for item in items), + max_depth=rank.shared_prefix_max_depth, + ) + ) + hidden = rank._gather_sequence_parallel_hidden(rank._decoder_hidden(prepared)) + return ( + _packed_oracle_from_hidden(rank, items, prepared, hidden), + prepared.source_positions_by_item, + ) + + +def _shared_hidden_check( + rank_a: TrainerRank, + rank_b: TrainerRank, + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> tuple[ + list[ + ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + list[ + ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + list[CheckOutput], + tuple[torch.Tensor, ...], +]: + items = [rank_a._forward_item(request) for request in requests] + prepared = rank_a._prepare_packed_forward( + pack_shared_prefixes( + (item.input_ids for item in items), + max_depth=rank_a.shared_prefix_max_depth, + ) + ) + hidden = rank_a._gather_sequence_parallel_hidden(rank_a._decoder_hidden(prepared)) + outputs_a = _outputs_from_hidden(rank_a, items, prepared, hidden) + outputs_b = _outputs_from_hidden(rank_b, items, prepared, hidden) + oracle = _packed_oracle_from_hidden(rank_a, items, prepared, hidden) + return ( + outputs_a, + outputs_b, + oracle, + prepared.source_positions_by_item, + ) + + +def _independent_check_outputs( + rank: TrainerRank, + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> list[CheckOutput]: + outputs: list[CheckOutput] = [] + for request in requests: + source_positions = _source_positions(rank, [request])[0] + outputs.append( + _as_check_output(source_positions, rank.dp_rank_forward([request])[0]) + ) + return outputs + + +def _same_layout_check_outputs( + rank: TrainerRank, + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> list[CheckOutput]: + items = [rank._forward_item(request) for request in requests] + batch = pack_shared_prefixes( + (item.input_ids for item in items), + max_depth=rank.shared_prefix_max_depth, + ) + outputs = [] + for index, positions in enumerate(batch.positions_by_sequence): + mutated = _mutated_batch(batch, keep_positions=positions) + prepared = rank._prepare_packed_forward(mutated) + hidden = rank._gather_sequence_parallel_hidden(rank._decoder_hidden(prepared)) + mutated_outputs = _outputs_from_hidden(rank, items, prepared, hidden) + outputs.append( + _as_check_output( + prepared.source_positions_by_item[index], + mutated_outputs[index], + ) + ) + return outputs + + +def _mutated_batch( + batch: SharedPrefixPack, + *, + keep_positions: torch.Tensor, +) -> SharedPrefixPack: + tokens = batch.tokens.clone() + mutate = torch.ones(int(tokens.shape[1]), dtype=torch.bool, device=tokens.device) + mutate[keep_positions.to(device=tokens.device)] = False + replacement = ( + torch.arange(int(tokens.shape[1]), dtype=tokens.dtype, device=tokens.device) + + 50_000 + ) + tokens[0, mutate] = replacement[mutate] % 100_000 + return SharedPrefixPack( + tokens=tokens, + group_ids=batch.group_ids, + parent_ids=batch.parent_ids, + position_ids=batch.position_ids, + positions_by_sequence=batch.positions_by_sequence, + ) + + +def _outputs_from_hidden( + rank: TrainerRank, + items: list[object], + prepared: object, + hidden: torch.Tensor, +) -> list[ + ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] +]: + return rank._project_head(items, prepared, hidden) + + +def _packed_oracle_from_hidden( + rank: TrainerRank, + items: list[object], + prepared: object, + hidden: torch.Tensor, +) -> list[CheckOutput]: + model = _language_model(rank.runtime.model[0]) + output_weight = ( + model.shared_embedding_or_output_weight() + if bool(model.share_embeddings_and_output_weights) + else None + ) + + outputs: list[CheckOutput] = [] + for item, positions, source_positions in zip( + items, + prepared.positions_by_item, + prepared.source_positions_by_item, + strict=True, + ): + needs_projection = ( + item.labels is not None or item.request.logits or item.request.top_k + ) + all_logits = None + if needs_projection: + if int(positions.numel()): + local_logits = rank._local_logits_from_hidden_rows( + model, + _select_positions(hidden, positions), + output_weight=output_weight, + ) + all_logits = _batch_seq_logits( + rank._gather_tensor_parallel_logits(local_logits.unsqueeze(1)), + seq_len=int(positions.numel()), + ).squeeze(0) + else: + all_logits = _empty_logits_like_positions(positions, model, hidden) + logprobs = ( + None + if all_logits is None + else torch.log_softmax(all_logits.float(), dim=-1) + ) + + target_logprobs = None + if item.labels is not None: + if logprobs is None: + raise RuntimeError("target_logprobs oracle requires logprobs") + labels = item.labels.to(device=logprobs.device).index_select( + 0, source_positions.to(device=logprobs.device) + ) + target_logprobs = _gather_target_logprobs(logprobs, labels) + + top_k = None + if item.request.top_k is not None: + if all_logits is None: + raise RuntimeError("top_k oracle requires logits") + log_z = torch.logsumexp(all_logits.float(), dim=-1) + values, tokens = torch.topk( + all_logits.float(), k=item.request.top_k, dim=-1 + ) + top_k = TopK(logprobs=values - log_z.unsqueeze(1), tokens=tokens) + + hidden_states = None + if item.request.hidden_states: + hidden_states = _select_positions(hidden, positions) + + outputs.append( + CheckOutput( + source_positions=source_positions, + target_logprobs=target_logprobs, + top_k=top_k, + logits=all_logits if item.request.logits else None, + hidden_states=hidden_states, + ) + ) + return outputs + + +def _source_positions( + rank: TrainerRank, + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> tuple[torch.Tensor, ...]: + items = [rank._forward_item(request) for request in requests] + prepared = rank._prepare_packed_forward( + pack_shared_prefixes( + (item.input_ids for item in items), + max_depth=rank.shared_prefix_max_depth, + ) + ) + return prepared.source_positions_by_item + + +def _as_check_output( + source_positions: torch.Tensor, + output: ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ], +) -> CheckOutput: + return CheckOutput( + source_positions=source_positions, + target_logprobs=output.target_logprobs, + top_k=output.top_k, + logits=output.logits, + hidden_states=output.hidden_states, + ) + + +def _records( + *, + local_pairs: list[ + tuple[ + int, + ForwardInput[ + torch.Tensor | None, + TopK | None, + torch.Tensor | None, + torch.Tensor | None, + ], + ] + ], + actual_outputs: list[ + ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + actual_source_positions: tuple[torch.Tensor, ...], + oracle_outputs: list[CheckOutput], + independent_outputs: list[CheckOutput] | None, + rank: int, + dp: int, + tp: int, + cp: int, +) -> list[dict[str, object]]: + records: list[dict[str, object]] = [] + independent_records: list[CheckOutput | None] = ( + independent_outputs + if independent_outputs is not None + else [None] * len(local_pairs) + ) + for local_index, ( + (input_index, _), + actual, + actual_sources, + oracle, + independent, + ) in enumerate( + zip( + local_pairs, + actual_outputs, + actual_source_positions, + oracle_outputs, + independent_records, + strict=True, + ) + ): + records.append( + { + "input_index": input_index, + "local_index": local_index, + "rank": rank, + "dp": dp, + "tp": tp, + "cp": cp, + "actual": _cpu_record(actual_sources, actual), + "oracle": _cpu_record(oracle.source_positions, oracle), + "independent": ( + None + if independent is None + else _cpu_record(independent.source_positions, independent) + ), + } + ) + return records + + +def _cpu_record( + source_positions: torch.Tensor, + output: ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + | CheckOutput, +) -> dict[str, torch.Tensor | None]: + return { + "source_positions": source_positions.cpu(), + "target_logprobs": _cpu(output.target_logprobs), + "logits": _cpu(output.logits), + "hidden_states": _cpu(output.hidden_states), + "top_k_logprobs": None if output.top_k is None else _cpu(output.top_k.logprobs), + "top_k_tokens": None if output.top_k is None else _cpu(output.top_k.tokens), + } + + +def _cpu(tensor: torch.Tensor | None) -> torch.Tensor | None: + return None if tensor is None else tensor.detach().cpu() + + +def _assert_reconstructed( + gathered: list[list[dict[str, object]] | None], + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> DiffStats: + diff_stats = DiffStats() + records = [ + record + for rank_records in gathered + for record in rank_records or [] + if record["tp"] == 0 + ] + for input_index, request in enumerate(requests): + _debug(f"reconstruct-input-{input_index}") + actual = [ + record["actual"] + for record in records + if record["input_index"] == input_index + ] + oracle = [ + record["oracle"] + for record in records + if record["input_index"] == input_index + ] + independent = [ + record["independent"] + for record in records + if record["input_index"] == input_index + and record.get("independent") is not None + ] + length = int(request.input_tokens.numel()) + for key in ("target_logprobs", "logits", "hidden_states", "top_k_logprobs"): + _debug(f"reconstruct-input-{input_index}-{key}") + _debug(f"reconstruct-input-{input_index}-{key}-assemble-actual") + actual_value = _assemble(actual, key, length) + _debug( + f"reconstruct-input-{input_index}-{key}-actual-" + f"{_tensor_summary(actual_value)}" + ) + _debug(f"reconstruct-input-{input_index}-{key}-assemble-oracle") + oracle_value = _assemble(oracle, key, length) + _debug( + f"reconstruct-input-{input_index}-{key}-oracle-" + f"{_tensor_summary(oracle_value)}" + ) + _debug(f"reconstruct-input-{input_index}-{key}-diff-oracle") + diff_stats = diff_stats.merge( + _tensor_diff_value( + actual_value, + oracle_value, + f"reconstructed[{input_index}].{key}", + ), + ) + _debug(f"reconstruct-input-{input_index}-{key}-diff-oracle-done") + if independent: + _debug(f"reconstruct-input-{input_index}-{key}-assemble-independent") + independent_value = _assemble(independent, key, length) + _debug( + f"reconstruct-input-{input_index}-{key}-independent-" + f"{_tensor_summary(independent_value)}" + ) + _debug(f"reconstruct-input-{input_index}-{key}-diff-independent") + diff_stats = diff_stats.merge( + _tensor_diff_value( + actual_value, + independent_value, + f"independent[{input_index}].{key}", + ), + ) + _debug(f"reconstruct-input-{input_index}-{key}-diff-independent-done") + _debug(f"reconstruct-input-{input_index}-{key}-done") + actual_tokens = _assemble(actual, "top_k_tokens", length) + oracle_tokens = _assemble(oracle, "top_k_tokens", length) + if actual_tokens is None or oracle_tokens is None: + if actual_tokens is not oracle_tokens: + raise AssertionError( + f"reconstructed[{input_index}].top_k None mismatch" + ) + elif not torch.equal(actual_tokens, oracle_tokens): + actual_logprobs = _assemble(actual, "top_k_logprobs", length) + oracle_logprobs = _assemble(oracle, "top_k_logprobs", length) + if ( + actual_logprobs is None + or oracle_logprobs is None + or _tensor_diff_value( + actual_logprobs, + oracle_logprobs, + f"reconstructed[{input_index}].top_k.logprobs", + ).max_abs_diff + > 5e-6 + ): + raise AssertionError( + f"reconstructed[{input_index}].top_k.tokens mismatch" + ) + if independent: + independent_tokens = _assemble(independent, "top_k_tokens", length) + if actual_tokens is None or independent_tokens is None: + if actual_tokens is not independent_tokens: + raise AssertionError( + f"independent[{input_index}].top_k None mismatch" + ) + elif not torch.equal(actual_tokens, independent_tokens): + actual_logprobs = _assemble(actual, "top_k_logprobs", length) + independent_logprobs = _assemble( + independent, + "top_k_logprobs", + length, + ) + if ( + actual_logprobs is None + or independent_logprobs is None + or _tensor_diff_value( + actual_logprobs, + independent_logprobs, + f"independent[{input_index}].top_k.logprobs", + ).max_abs_diff + > 5e-6 + ): + raise AssertionError( + f"independent[{input_index}].top_k.tokens mismatch" + ) + return diff_stats + + +def _assemble( + records: list[object], + key: str, + length: int, +) -> torch.Tensor | None: + typed_records = [record for record in records if isinstance(record, dict)] + values = [record[key] for record in typed_records if record[key] is not None] + if not values: + return None + first = values[0] + if not isinstance(first, torch.Tensor): + raise TypeError(key) + output = torch.empty((length, *first.shape[1:]), dtype=first.dtype) + filled = torch.zeros(length, dtype=torch.bool) + for record in typed_records: + value = record[key] + if value is None: + continue + if not isinstance(value, torch.Tensor): + raise TypeError(key) + positions = record["source_positions"] + if not isinstance(positions, torch.Tensor): + raise TypeError("source_positions") + output[positions] = value + filled[positions] = True + if not bool(filled.all().item()): + raise AssertionError(f"{key} reconstruction missed positions") + return output + + +def _tensor_summary(tensor: torch.Tensor | None) -> str: + if tensor is None: + return "None" + return f"shape={tuple(tensor.shape)} device={tensor.device} dtype={tensor.dtype}" + + +def _assert_close( + actual: ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ], + expected: ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + | CheckOutput, + label: str, +) -> DiffStats: + diffs = [ + _tensor_diff( + actual.target_logprobs, expected.target_logprobs, f"{label}.target_logprobs" + ) + ] + diffs.append(_tensor_diff(actual.logits, expected.logits, f"{label}.logits")) + diffs.append( + _tensor_diff( + actual.hidden_states, expected.hidden_states, f"{label}.hidden_states" + ) + ) + if actual.top_k is None or expected.top_k is None: + if actual.top_k is not expected.top_k: + raise AssertionError(f"{label}.top_k None mismatch") + else: + try: + top_k_diff = _tensor_diff( + actual.top_k.logprobs, + expected.top_k.logprobs, + f"{label}.top_k.logprobs", + ) + except AssertionError as exc: + flat_offset = int( + (actual.top_k.logprobs.float() - expected.top_k.logprobs.float()) + .abs() + .flatten() + .argmax() + ) + row, _ = divmod(flat_offset, int(actual.top_k.logprobs.shape[1])) + raise AssertionError( + f"{exc}; actual_row={actual.top_k.logprobs[row].tolist()} " + f"expected_row={expected.top_k.logprobs[row].tolist()} " + f"actual_tokens={actual.top_k.tokens[row].tolist()} " + f"expected_tokens={expected.top_k.tokens[row].tolist()}" + ) from exc + diffs.append(top_k_diff) + if ( + not torch.equal(actual.top_k.tokens, expected.top_k.tokens) + and top_k_diff.max_abs_diff > 5e-6 + ): + mismatch = torch.nonzero( + actual.top_k.tokens != expected.top_k.tokens, + as_tuple=False, + )[0] + row = int(mismatch[0].item()) + col = int(mismatch[1].item()) + raise AssertionError( + f"{label}.top_k.tokens mismatch at ({row}, {col}): " + f"actual={int(actual.top_k.tokens[row, col].item())} " + f"expected={int(expected.top_k.tokens[row, col].item())} " + f"actual_logprob={float(actual.top_k.logprobs[row, col].item())} " + f"expected_logprob={float(expected.top_k.logprobs[row, col].item())}" + ) + return _merge_diff_stats(diffs) + + +def _tensor_diff( + actual: torch.Tensor | None, + expected: torch.Tensor | None, + label: str, +) -> DiffStats: + return _tensor_diff_value(actual, expected, label) + + +def _tensor_diff_value( + actual: torch.Tensor | None, + expected: torch.Tensor | None, + label: str, +) -> DiffStats: + if actual is None or expected is None: + if actual is not expected: + raise AssertionError(f"{label} None mismatch") + return DiffStats() + if actual.shape != expected.shape: + raise AssertionError( + f"{label} shape mismatch: {actual.shape} != {expected.shape}" + ) + actual_for_diff = actual + expected_for_diff = expected + if torch.cuda.is_available(): + actual_for_diff = actual_for_diff.to(device="cuda") + expected_for_diff = expected_for_diff.to(device="cuda") + if actual_for_diff.numel(): + abs_diff = (actual_for_diff.float() - expected_for_diff.float()).abs() + max_abs_diff = float(abs_diff.max().item()) + denominator = float(expected_for_diff.float().abs().mean().item()) + mean_abs_pct = float(abs_diff.mean().item()) / (denominator + 1e-18) + else: + max_abs_diff = 0.0 + mean_abs_pct = 0.0 + mean_abs_pct_tolerance = 5e-3 if label.startswith("independent[") else 2e-5 + max_abs_tolerance = 0.0 + _debug( + f"{label} max_abs_diff={max_abs_diff} " + f"mean_abs_pct={mean_abs_pct} tolerance={mean_abs_pct_tolerance}" + ) + if mean_abs_pct > mean_abs_pct_tolerance: + raise AssertionError( + f"{label} mean_abs_pct {mean_abs_pct} max_abs_diff {max_abs_diff}" + ) + if max_abs_diff > max_abs_tolerance and not actual_for_diff.is_floating_point(): + raise AssertionError(f"{label} max diff {max_abs_diff}") + return DiffStats(max_abs_diff=max_abs_diff, mean_abs_pct=mean_abs_pct) + + +def _merge_diff_stats(stats: list[DiffStats]) -> DiffStats: + merged = DiffStats() + for stat in stats: + merged = merged.merge(stat) + return merged + + +if __name__ == "__main__": + typer.run(main) diff --git a/src/art/trainer_rank/__init__.py b/src/art/trainer_rank/__init__.py new file mode 100644 index 000000000..a3fe29bad --- /dev/null +++ b/src/art/trainer_rank/__init__.py @@ -0,0 +1,2199 @@ +from __future__ import annotations + +from collections.abc import ( + Callable, + Iterable, + Iterator, + Mapping, + MutableMapping, + Sequence, +) +from dataclasses import dataclass +import os +from typing import TYPE_CHECKING, Generic, Literal, ParamSpec, TypeVar, cast, overload + +import torch +import torch.distributed as dist + +from art.megatron.shared_prefix_packing import ( + SharedPrefixPack, + _local_position_pairs, + estimate_shared_prefix_packed_tokens, + pack_shared_prefixes, +) + +if TYPE_CHECKING: + from megatron.core.models.gpt.gpt_model import GPTModel + from megatron.core.optimizer import MegatronOptimizer, OptimizerConfig + from megatron.core.packed_seq_params import PackedSeqParams + + from art.megatron.context_parallel.types import ( + ArtContextParallelState, + ParallelTopology, + ) + from art.megatron.lora import LoRASlotRef + from art.megatron.shared_prefix_state import SharedPrefixAttentionState + from art.megatron.train import TrainingRuntime + + +@dataclass(frozen=True) +class AdamParams: + learning_rate: float + beta1: float = 0.9 + beta2: float = 0.99 + weight_decay: float = 0.1 + grad_clip_norm: float = 0.1 + + +@dataclass(frozen=True) +class TopK: + logprobs: torch.Tensor + tokens: torch.Tensor + + +LogprobsT = TypeVar("LogprobsT", bound=torch.Tensor | None, covariant=True) +TopKT = TypeVar("TopKT", bound=TopK | None, covariant=True) +LogitsT = TypeVar("LogitsT", bound=torch.Tensor | None, covariant=True) +HiddenStatesT = TypeVar("HiddenStatesT", bound=torch.Tensor | None, covariant=True) +T = TypeVar("T") +P = ParamSpec("P") +R = TypeVar("R") + +_COMPILED_FUNCTIONS: dict[Callable[..., object], Callable[..., object]] = {} +_MEMORY_PROFILE_TRUST_GROWTH = 8 + + +class _Unset: + pass + + +Unset = _Unset() +type AdapterSelection = str | None | _Unset + + +@dataclass(frozen=True) +class ForwardOutput(Generic[LogprobsT, TopKT, LogitsT, HiddenStatesT]): + target_logprobs: LogprobsT + top_k: TopKT + logits: LogitsT + hidden_states: HiddenStatesT + + +@dataclass(slots=True) +class ForwardInput(Generic[LogprobsT, TopKT, LogitsT, HiddenStatesT]): + input_tokens: torch.Tensor + target_tokens: torch.Tensor | None = None + top_k: int | None = None + logits: bool = False + hidden_states: bool = False + checkpoint: AdapterSelection = Unset + lora: AdapterSelection = Unset + + def __post_init__(self) -> None: + if self.top_k is not None and self.top_k < 1: + raise ValueError("top_k must be >= 1") + if self.checkpoint is not Unset and self.lora is not Unset: + raise ValueError("ForwardInput cannot set both checkpoint and lora") + + +type AnyForwardInput = ForwardInput[ + torch.Tensor | None, + TopK | None, + torch.Tensor | None, + torch.Tensor | None, +] +type AnyForwardOutput = ForwardOutput[ + torch.Tensor | None, + TopK | None, + torch.Tensor | None, + torch.Tensor | None, +] +type ForwardInputs = AnyForwardInput | Iterable["ForwardInputs"] +type ForwardOutputs = AnyForwardOutput | Sequence["ForwardOutputs"] +ForwardInputsT = TypeVar("ForwardInputsT", bound=ForwardInputs) + + +@dataclass(frozen=True) +class MicroBatch(Generic[ForwardInputsT]): + inputs: Sequence[ForwardInputsT] + outputs: Sequence[ForwardOutputs] + indices: Sequence[int] + stats: "MicroBatchStats" + + def select(self, xs: Sequence[T]) -> Sequence[T]: + return [xs[i] for i in self.indices] + + +@dataclass(frozen=True) +class MicroBatchStats: + global_start: int + global_stop: int + global_count: int + local_count: int + packed_tokens: int + logical_tokens: int + estimated_required_bytes: int + available_bytes: int + rejected_candidates: int + cold_start: bool + + +@dataclass(frozen=True) +class _MemoryCheck: + estimated_required_bytes: int + available_bytes: int + fits: bool + + +@dataclass(frozen=True) +class _MemoryProfile: + bytes_per_token: float + packed_tokens: int + + +@dataclass(frozen=True) +class _CandidateMicroBatch(Generic[ForwardInputsT]): + inputs: Sequence[ForwardInputsT] + indices: tuple[int, ...] + plan: "_FlatForwardPlan" + check: _MemoryCheck + stats_global_count: int + rejected_candidates: int + cold_start: bool + + +class TrainerRankMemoryError(RuntimeError): + pass + + +@dataclass(frozen=True) +class _PushedSlot: + trainer: "TrainerRank" + ref: "LoRASlotRef" + + def __enter__(self) -> "_PushedSlot": + return self + + def __exit__(self, *args: object) -> bool: + if not self.trainer._slot_stack or self.trainer._slot_stack[-1] != self.ref: + raise RuntimeError( + "Pushed LoRA/checkpoint stack changed before context exit" + ) + self.trainer.pop_pushed_lora_or_checkpoint() + return False + + +@dataclass(frozen=True) +class _ForwardItem: + request: AnyForwardInput + input_ids: torch.Tensor + labels: torch.Tensor | None + + +@dataclass(frozen=True) +class _PreparedPackedForward: + tokens: torch.Tensor + position_ids: torch.Tensor + attention_state: "SharedPrefixAttentionState | ArtContextParallelState" + packed_seq_params: "PackedSeqParams | None" + positions_by_item: tuple[torch.Tensor, ...] + source_positions_by_item: tuple[torch.Tensor, ...] + + +@dataclass(frozen=True) +class _RowMatch: + source_offsets: torch.Tensor + row_offsets: torch.Tensor + + +@dataclass(frozen=True) +class _MemorySignature: + topology: tuple[int, int, int, int] + shared_prefix_max_depth: int + slot_group_count: int + request_mix: tuple[str, ...] + + +@dataclass(frozen=True) +class _ForwardGroupPlan: + slot_ref: "LoRASlotRef | None" + request_indices: tuple[int, ...] + items: tuple[_ForwardItem, ...] + packed: SharedPrefixPack + + +@dataclass(frozen=True) +class _FlatForwardPlan: + request_count: int + groups: tuple[_ForwardGroupPlan, ...] + packed_tokens: int + logical_tokens: int + output_bytes: int + signature: _MemorySignature + + +type _AdaptivePlanCacheKey = tuple[tuple[int, ...], object, tuple[object, ...], int] + + +class TrainerRank: + def __init__( + self, + runtime: TrainingRuntime, + *, + head_chunk_tokens: int = 512, + shared_prefix_max_depth: int = 1, + memory_safety_factor: float = 1.10, + memory_reserve_fraction: float = 0.03, + ) -> None: + if head_chunk_tokens < 1: + raise ValueError("head_chunk_tokens must be >= 1") + if shared_prefix_max_depth < 0: + raise ValueError("shared_prefix_max_depth must be >= 0") + if memory_safety_factor < 1.0: + raise ValueError("memory_safety_factor must be >= 1.0") + if not (0.0 <= memory_reserve_fraction < 1.0): + raise ValueError("memory_reserve_fraction must be in [0, 1)") + self.runtime: TrainingRuntime = runtime + self.head_chunk_tokens = head_chunk_tokens + self.shared_prefix_max_depth = shared_prefix_max_depth + self.memory_safety_factor = memory_safety_factor + self.memory_reserve_fraction = memory_reserve_fraction + self.device = next(runtime.model[0].parameters()).device + self._param_dtype_size = _dtype_size(next(runtime.model[0].parameters()).dtype) + try: + metadata_model = _language_model(runtime.model[0]) + except RuntimeError: + metadata_model = None + self._hidden_size = _hidden_size(metadata_model, runtime.provider) + self._padded_vocab_size = ( + None if metadata_model is None else _padded_vocab_size(metadata_model) + ) + self._num_layers = int( + getattr(getattr(metadata_model, "config", None), "num_layers", 0) + or getattr(runtime.provider, "num_layers", 1) + or 1 + ) + self._default_slot_ref: LoRASlotRef | None = None + self._slot_stack: list[LoRASlotRef] = [] + self._dynamic_optimizers: dict[str, torch.optim.Optimizer] = {} + self._checkpoint_slot_params_by_name: dict[ + str, tuple[torch.nn.Parameter, ...] + ] = {} + self._memory_profiles: dict[_MemorySignature, _MemoryProfile] = {} + self._adaptive_plan_cache: dict[_AdaptivePlanCacheKey, _FlatForwardPlan] = {} + self._adaptive_plan_cache_top_level_ids: tuple[int, ...] = () + self._adaptive_estimate_cache: dict[ + _AdaptivePlanCacheKey, tuple[_MemoryCheck, bool] | None + ] = {} + self._last_global_micro_batch_size: int | None = None + self.zero_grad() + + def zero_grad(self) -> None: + for chunk in self.runtime.model: + zero_grad_buffer = getattr(chunk, "zero_grad_buffer", None) + if callable(zero_grad_buffer): + zero_grad_buffer() + optimizer = cast("MegatronOptimizer | None", self.runtime.optimizer) + if optimizer is not None: + optimizer.zero_grad() + for params in self._checkpoint_slot_params_by_name.values(): + for param in params: + param.grad = None + + def _optimizer(self) -> "MegatronOptimizer": + optimizer = cast("MegatronOptimizer | None", self.runtime.optimizer) + if optimizer is None: + raise RuntimeError("TrainerRank requires a runtime with an optimizer") + return optimizer + + def set_checkpoint(self, name: str | None) -> None: + self._set_default_slot(self._slot_ref("checkpoint", name)) + + def set_lora(self, name: str | None) -> None: + self._set_default_slot(self._slot_ref("lora", name)) + + def push_checkpoint(self, name: str | None) -> _PushedSlot: + ref = self._slot_ref("checkpoint", name) + self._slot_stack.append(ref) + return _PushedSlot(self, ref) + + def push_lora(self, name: str | None) -> _PushedSlot: + ref = self._slot_ref("lora", name) + self._slot_stack.append(ref) + return _PushedSlot(self, ref) + + def pop_pushed_lora_or_checkpoint(self) -> None: + if not self._slot_stack: + raise RuntimeError("No pushed LoRA or checkpoint to pop") + self._slot_stack.pop() + + def load_checkpoint_slot( + self, + name: str, + adapter_model: dict[str, torch.Tensor], + *, + alpha: float | None = None, + ) -> int: + loaded = self._load_slot( + "checkpoint", name, adapter_model, trainable=True, alpha=alpha + ) + self._checkpoint_slot_params_by_name[name] = ( + self._validate_dynamic_slot_consistency("checkpoint", name, loaded) + ) + self._dynamic_optimizers.pop(name, None) + return loaded + + def load_lora_slot( + self, + name: str, + adapter_model: dict[str, torch.Tensor], + *, + alpha: float | None = None, + ) -> int: + loaded = self._load_slot( + "lora", name, adapter_model, trainable=False, alpha=alpha + ) + self._validate_dynamic_slot_consistency("lora", name, loaded) + return loaded + + def _load_slot( + self, + kind: Literal["checkpoint", "lora"], + name: str, + adapter_model: dict[str, torch.Tensor], + *, + trainable: bool, + alpha: float | None, + ) -> int: + from art.megatron.lora import LORA_ALPHA, load_lora_slot_into_model + + return load_lora_slot_into_model( + self.runtime.model, + self._slot_ref(kind, name), + adapter_model, + alpha=LORA_ALPHA if alpha is None else alpha, + requires_grad=trainable, + ) + + def _set_default_slot(self, ref: "LoRASlotRef") -> None: + if self._slot_stack: + raise RuntimeError("Cannot set a LoRA/checkpoint while a slot is pushed") + self._default_slot_ref = ref + + @staticmethod + def _slot_ref( + kind: Literal["checkpoint", "lora"], name: str | None + ) -> "LoRASlotRef": + from art.megatron.lora import LoRASlotRef + + return LoRASlotRef(kind=kind, name=name) + + def _validate_dynamic_slot_consistency( + self, + kind: Literal["checkpoint", "lora"], + name: str, + loaded_sites: int, + ) -> tuple[torch.nn.Parameter, ...]: + from art.megatron.lora import iter_lora_slot_parameters + + ref = self._slot_ref(kind, name) + params = tuple(iter_lora_slot_parameters(self.runtime.model, ref)) + if not (dist.is_available() and dist.is_initialized()): + return params + + local = { + "rank": dist.get_rank(), + "loaded_sites": int(loaded_sites), + "param_count": len(params), + "numel": sum(int(param.numel()) for param in params), + "signature": [ + ( + tuple(int(dim) for dim in param.shape), + str(param.dtype), + bool(getattr(param, "allreduce", True)), + str(getattr(param, "grad_sync_domain", "tp_default")), + str(getattr(param, "grad_sync_op", "none")), + ) + for param in params + ], + } + gathered: list[dict[str, object] | None] = [None] * dist.get_world_size() + dist.all_gather_object(gathered, local) + ranks = [rank for rank in gathered if rank is not None] + reference = ranks[0] + if all( + rank["loaded_sites"] == reference["loaded_sites"] + and rank["signature"] == reference["signature"] + for rank in ranks + ): + return params + + summary = [ + {key: rank[key] for key in ("rank", "loaded_sites", "param_count", "numel")} + for rank in ranks + ] + raise RuntimeError( + f"Dynamic LoRA slot {kind}:{name} is not loaded consistently across " + "distributed ranks. This usually means a sharded/exported LoRA state " + "dict was passed directly to TrainerRank; gather or materialize the " + "full adapter state before loading a dynamic slot. " + f"Rank summary: {summary}." + ) + + def _resolve_slot_ref(self, request: AnyForwardInput) -> "LoRASlotRef | None": + if request.checkpoint is not Unset: + return self._slot_ref("checkpoint", cast(str | None, request.checkpoint)) + if request.lora is not Unset: + return self._slot_ref("lora", cast(str | None, request.lora)) + if self._slot_stack: + return self._slot_stack[-1] + return self._default_slot_ref + + def forward_micro_batches( + self, + inputs: Iterable[ForwardInputsT], + ) -> Iterator[MicroBatch[ForwardInputsT]]: + items = list(inputs) + self._validate_replicated_top_level_count(len(items)) + start = 0 + while start < len(items): + candidate = self._select_next_micro_batch(items, start) + flat_outputs = iter( + self._run_flat_plan_with_memory_tracking( + candidate.plan, + context="forward_micro_batches", + ) + ) + outputs = [_unflatten(item, flat_outputs) for item in candidate.inputs] + stop = start + candidate.stats_global_count + if stop < len(items): + self._last_global_micro_batch_size = max( + self._last_global_micro_batch_size or 0, + candidate.stats_global_count, + ) + yield MicroBatch( + inputs=candidate.inputs, + outputs=outputs, + indices=candidate.indices, + stats=MicroBatchStats( + global_start=start, + global_stop=stop, + global_count=candidate.stats_global_count, + local_count=len(candidate.inputs), + packed_tokens=candidate.plan.packed_tokens, + logical_tokens=candidate.plan.logical_tokens, + estimated_required_bytes=candidate.check.estimated_required_bytes, + available_bytes=candidate.check.available_bytes, + rejected_candidates=candidate.rejected_candidates, + cold_start=candidate.cold_start, + ), + ) + start = stop + + @overload + def dp_rank_forward( + self, + inputs: Iterable[ForwardInput[LogprobsT, TopKT, LogitsT, HiddenStatesT]], + ) -> Sequence[ForwardOutput[LogprobsT, TopKT, LogitsT, HiddenStatesT]]: ... + + @overload + def dp_rank_forward( + self, + inputs: Iterable[ + Iterable[ForwardInput[LogprobsT, TopKT, LogitsT, HiddenStatesT]] + ], + ) -> Sequence[ + Sequence[ForwardOutput[LogprobsT, TopKT, LogitsT, HiddenStatesT]] + ]: ... + + def dp_rank_forward(self, inputs: ForwardInputs) -> ForwardOutputs: + materialized = _materialize(inputs) + plan = self._plan_flat_forward(list(_flatten(materialized))) + check = self._memory_check(plan) + if not check.fits: + self._raise_memory_error( + plan, + check, + context="dp_rank_forward", + message="forward is predicted to exceed available memory", + ) + outputs = iter( + self._run_flat_plan_with_memory_tracking( + plan, + context="dp_rank_forward", + ) + ) + return _unflatten(materialized, outputs) + + def dp_reduce( + self, + tensor: torch.Tensor, + *, + op: dist.ReduceOp.RedOpType = dist.ReduceOp.SUM, + ) -> None: + from megatron.core import parallel_state as ps + + dist.all_reduce( + tensor, + op=op, + group=ps.get_data_parallel_group(with_context_parallel=True), + ) + + def optim_step( + self, + *, + params: AdamParams, + scale_grads: float = 1.0, + checkpoints: Sequence[str] | None = None, + ) -> dict[str, float]: + selected_checkpoints = self._selected_dynamic_checkpoints(checkpoints) + if selected_checkpoints: + return self._dynamic_optim_step( + selected_checkpoints, + params=params, + scale_grads=scale_grads, + ) + + from art.megatron.training.finalize_grads import ( + finalize_model_grads_extended, + flush_param_grads_to_main_grads, + ) + from art.megatron.training.model_chunks import as_megatron_api_chunks + + optimizer = self._optimizer() + flush_param_grads_to_main_grads(self.runtime.model) + finalize_model_grads_extended( + as_megatron_api_chunks(self.runtime.model), + num_tokens=None, + ) + self._scale_main_grads(scale_grads) + self._configure_optimizer(params) + update_successful, grad_norm, num_zeros = optimizer.step() + optimizer.zero_grad() + self.zero_grad() + return { + "learning_rate": float(params.learning_rate), + "grad_norm": float(grad_norm), + "update_successful": float(bool(update_successful)), + "num_zeros_in_grad": float(num_zeros or 0), + } + + def _selected_dynamic_checkpoints( + self, + checkpoints: Sequence[str] | None, + ) -> tuple[str, ...]: + if checkpoints is not None: + if ( + unknown := set(checkpoints) + - self._checkpoint_slot_params_by_name.keys() + ): + raise ValueError(f"Unknown checkpoint slots: {sorted(unknown)}") + return tuple(dict.fromkeys(checkpoints)) + slots = tuple(sorted(self._checkpoint_slot_params_by_name.items())) + if not slots: + return () + has_grad = torch.tensor( + [ + int(any(param.grad is not None for param in params)) + for _, params in slots + ], + device=self.device, + dtype=torch.int32, + ) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(has_grad, op=dist.ReduceOp.MAX) + return tuple(name for (name, _), flag in zip(slots, has_grad.tolist()) if flag) + + def _dynamic_optim_step( + self, + checkpoint_names: Sequence[str], + *, + params: AdamParams, + scale_grads: float, + ) -> dict[str, float]: + all_params: list[torch.nn.Parameter] = [] + for name in checkpoint_names: + slot_params = self._checkpoint_slot_params_by_name[name] + for param in slot_params: + if param.grad is None: + param.grad = torch.zeros_like(param) + elif scale_grads != 1.0: + param.grad.mul_(scale_grads) + self._reduce_dynamic_grads(slot_params) + all_params.extend(slot_params) + + grad_norm = torch.nn.utils.clip_grad_norm_( + all_params, + max_norm=params.grad_clip_norm, + ) + for name in checkpoint_names: + optimizer = self._dynamic_optimizer(name, params) + optimizer.step() + optimizer.zero_grad(set_to_none=True) + return { + "learning_rate": float(params.learning_rate), + "grad_norm": float(grad_norm), + "update_successful": 1.0, + "num_zeros_in_grad": 0.0, + } + + def _dynamic_optimizer( + self, + name: str, + params: AdamParams, + ) -> torch.optim.Optimizer: + optimizer = self._dynamic_optimizers.get(name) + if optimizer is None: + optimizer = torch.optim.AdamW( + self._checkpoint_slot_params_by_name[name], + lr=params.learning_rate, + betas=(params.beta1, params.beta2), + weight_decay=params.weight_decay, + ) + self._dynamic_optimizers[name] = optimizer + return optimizer + for group in optimizer.param_groups: + group["lr"] = params.learning_rate + group["betas"] = (params.beta1, params.beta2) + group["weight_decay"] = params.weight_decay + return optimizer + + def _reduce_dynamic_grads(self, params: Sequence[torch.nn.Parameter]) -> None: + from megatron.core import parallel_state as ps + + from art.megatron.training.finalize_grads import ( + coalesced_all_reduce, + tensor_parallel_grad_sync, + ) + + buckets: dict[ + tuple[int, str, torch.dtype, torch.device], + tuple[object, dist.ReduceOp.RedOpType, list[torch.Tensor]], + ] = {} + + def add(group: object, op: dist.ReduceOp.RedOpType, grad: torch.Tensor) -> None: + key = (id(group), str(op), grad.dtype, grad.device) + buckets.setdefault(key, (group, op, []))[2].append(grad) + + for param in params: + grad = param.grad + if grad is None: + continue + if bool(getattr(param, "allreduce", True)): + group = ps.get_data_parallel_group(with_context_parallel=True) + else: + group = ps.get_expert_data_parallel_group() + if group is not None and group.size() > 1: + add(group, dist.ReduceOp.SUM, grad) + + sync = tensor_parallel_grad_sync(param, name="dynamic LoRA") + if sync is not None: + group, reduce_op = sync + add(group, reduce_op, grad) + + for group, op, grads in buckets.values(): + coalesced_all_reduce(grads, group=group, op=op) + + def _select_next_micro_batch( + self, + items: Sequence[ForwardInputsT], + start: int, + ) -> _CandidateMicroBatch[ForwardInputsT]: + dp_rank, dp_size = self._dp_rank_and_size() + remaining = len(items) - start + min_width = min(dp_size, remaining) + if min_width <= 0: + raise RuntimeError("cannot select an empty microbatch window") + top_level_ids = tuple(id(item) for item in items) + if top_level_ids != self._adaptive_plan_cache_top_level_ids: + self._adaptive_plan_cache.clear() + self._adaptive_estimate_cache.clear() + self._adaptive_plan_cache_top_level_ids = top_level_ids + + def clamp_width(width: int) -> int: + return max(min_width, min(width, remaining)) + + base_granularity = 1 if remaining < 64 else 8 if remaining < 256 else 32 + granularity = max( + 1, + ((base_granularity + dp_size - 1) // dp_size) * dp_size, + ) + + def snap_width(width: int) -> int: + width = clamp_width(width) + if width in (min_width, remaining) or granularity <= 1: + return width + if width < granularity: + return width + return max(min_width, (width // granularity) * granularity) + + def local_slice(width: int) -> tuple[tuple[int, ...], list[ForwardInputsT]]: + stop = start + clamp_width(width) + indices = tuple(range(start + dp_rank, stop, dp_size)) + return indices, [items[index] for index in indices] + + def candidate( + width: int, + estimated_check: _MemoryCheck | None = None, + *, + rejected: int, + ) -> _CandidateMicroBatch[ForwardInputsT]: + width = clamp_width(width) + indices, local_inputs = local_slice(width) + plan = self._cached_adaptive_plan(indices, local_inputs) + return _CandidateMicroBatch( + inputs=local_inputs, + indices=indices, + plan=plan, + check=estimated_check or self._memory_check(plan), + stats_global_count=width, + rejected_candidates=rejected, + cold_start=not self._all_ranks_have_memory_profile( + packed_tokens=plan.packed_tokens, + signature=plan.signature, + ), + ) + + def estimate(width: int) -> tuple[_MemoryCheck, bool] | None: + indices, local_inputs = local_slice(width) + return self._cached_adaptive_estimate(indices, local_inputs) + + def probe(width: int) -> tuple[bool, _MemoryCheck | None, bool]: + estimated = estimate(width) + if estimated is not None: + check, trusted = estimated + return trusted and check.fits, check, trusted + item = candidate(width, rejected=0) + return item.check.fits, item.check, not item.cold_start + + rejected = 0 + best_width = min_width + best_check: _MemoryCheck | None = None + + def fit(width: int) -> bool: + nonlocal best_width, best_check, rejected + ok, check, _ = probe(width) + if ok: + best_width = snap_width(width) + best_check = check + else: + rejected += 1 + return ok + + def search_below(failed_width: int) -> None: + low = best_width + 1 + high = failed_width - 1 + while low <= high: + mid = (low + high) // 2 + if fit(mid): + low = mid + 1 + else: + high = mid - 1 + + first_fits, first_check, first_trusted = probe(min_width) + if not first_fits: + first = candidate(min_width, first_check, rejected=rejected) + if not first.check.fits: + self._raise_memory_error( + first.plan, + first.check, + context="forward_micro_batches", + message="smallest DP microbatch is predicted to exceed available memory", + ) + if first.cold_start: + return first + best_check = first.check + else: + best_check = first_check + + stable_width = self._last_global_micro_batch_size + if stable_width is not None and stable_width >= max(64, granularity * 2): + stable_capacity = stable_width + stable_width = clamp_width(stable_capacity) + if fit(stable_width): + grow_multiplier = 4 if stable_capacity < 256 else 2 + grow_capacity = min(remaining, stable_capacity * grow_multiplier) + if remaining > grow_capacity: + grow_width = clamp_width(grow_capacity) + if grow_width > stable_width and not fit(grow_width): + search_below(grow_width) + return candidate(best_width, best_check, rejected=rejected) + search_below(stable_width) + self._last_global_micro_batch_size = best_width + return candidate(best_width, best_check, rejected=rejected) + + high_fail: int | None = None + width = min( + remaining, + max(min_width, (self._last_global_micro_batch_size or min_width) * 2), + ) + while width <= remaining: + if fit(width): + if width == remaining: + break + width = min(remaining, max(width + 1, width * 2)) + continue + high_fail = width + break + + if high_fail is not None: + search_below(high_fail) + + if not first_trusted and best_width == min_width and best_check is None: + return candidate(min_width, first_check, rejected=rejected) + return candidate(best_width, best_check, rejected=rejected) + + def _cached_adaptive_plan( + self, + indices: tuple[int, ...], + local_inputs: Sequence[ForwardInputsT], + ) -> _FlatForwardPlan: + key = self._adaptive_cache_key(indices) + cached = self._adaptive_plan_cache.get(key) + if cached is not None: + return cached + plan = self._plan_flat_forward(list(_flatten(local_inputs))) + self._adaptive_plan_cache[key] = plan + return plan + + def _cached_adaptive_estimate( + self, + indices: tuple[int, ...], + local_inputs: Sequence[ForwardInputsT], + ) -> tuple[_MemoryCheck, bool] | None: + key = self._adaptive_cache_key(indices) + if key in self._adaptive_estimate_cache: + return self._adaptive_estimate_cache[key] + estimate = self._estimate_flat_forward(list(_flatten(local_inputs))) + if estimate is not None: + packed_tokens, output_bytes, signature = estimate + estimate = ( + self._memory_check_required( + self._estimate_required_memory_bytes_from_values( + packed_tokens=packed_tokens, + output_bytes=output_bytes, + signature=signature, + ) + ), + self._all_ranks_have_memory_profile( + packed_tokens=packed_tokens, + signature=signature, + ), + ) + self._adaptive_estimate_cache[key] = estimate + return estimate + + def _adaptive_cache_key( + self, + indices: tuple[int, ...], + ) -> _AdaptivePlanCacheKey: + return ( + indices, + self._default_slot_ref, + tuple(self._slot_stack), + self.shared_prefix_max_depth, + ) + + def _validate_replicated_top_level_count(self, count: int) -> None: + if not (dist.is_available() and dist.is_initialized()): + return + counts = [0 for _ in range(dist.get_world_size())] + dist.all_gather_object(counts, int(count)) + if len(set(counts)) == 1: + return + raise ValueError( + "forward_micro_batches requires the same top-level input count on every " + "distributed rank. Pass already-DP-local inputs to dp_rank_forward instead. " + f"Observed counts by rank: {counts}." + ) + + def _dp_rank_and_size(self) -> tuple[int, int]: + try: + from megatron.core import parallel_state as ps + + return int(ps.get_data_parallel_rank()), int( + ps.get_data_parallel_world_size() + ) + except (AssertionError, ImportError, RuntimeError, ValueError): + return 0, 1 + + def _plan_flat_forward( + self, requests: Sequence[AnyForwardInput] + ) -> _FlatForwardPlan: + plans: list[_ForwardGroupPlan] = [] + output_bytes = self._estimate_group_request_output_bytes(requests) + logical_tokens = sum(int(request.input_tokens.numel()) for request in requests) + groups = self._group_active_request_indices(requests) + for slot_ref, group_indices in groups: + items = tuple( + self._forward_item(requests[index]) for index in group_indices + ) + packed = pack_shared_prefixes( + (item.input_ids for item in items), + max_depth=self.shared_prefix_max_depth, + ) + plans.append( + _ForwardGroupPlan( + slot_ref=slot_ref, + request_indices=tuple(group_indices), + items=items, + packed=packed, + ) + ) + + return _FlatForwardPlan( + request_count=len(requests), + groups=tuple(plans), + packed_tokens=sum(int(plan.packed.tokens.numel()) for plan in plans), + logical_tokens=logical_tokens, + output_bytes=output_bytes, + signature=self._memory_signature_from_requests( + requests, + slot_group_count=len(plans), + ), + ) + + def _estimate_flat_forward( + self, requests: Sequence[AnyForwardInput] + ) -> tuple[int, int, _MemorySignature] | None: + groups = self._group_active_request_indices(requests) + packed_tokens = 0 + for _, group_indices in groups: + group_packed_tokens = estimate_shared_prefix_packed_tokens( + (requests[index].input_tokens for index in group_indices), + max_depth=self.shared_prefix_max_depth, + ) + if group_packed_tokens is None: + return None + packed_tokens += group_packed_tokens + + return ( + packed_tokens, + self._estimate_group_request_output_bytes(requests), + self._memory_signature_from_requests( + requests, + slot_group_count=len(groups), + ), + ) + + def _group_active_request_indices( + self, + requests: Sequence[AnyForwardInput], + ) -> tuple[tuple["LoRASlotRef | None", tuple[int, ...]], ...]: + groups: dict[LoRASlotRef | None, list[int]] = {} + for index, request in enumerate(requests): + if ( + request.target_tokens is not None + or request.logits + or request.top_k is not None + or request.hidden_states + ): + groups.setdefault(self._resolve_slot_ref(request), []).append(index) + return tuple((slot_ref, tuple(indices)) for slot_ref, indices in groups.items()) + + def _run_flat_plan_with_memory_tracking( + self, + plan: _FlatForwardPlan, + *, + context: str, + ) -> list[AnyForwardOutput]: + if torch.cuda.is_available() and self.device.type == "cuda": + torch.cuda.synchronize(self.device) + baseline = int(torch.cuda.memory_allocated(self.device)) + torch.cuda.reset_peak_memory_stats(self.device) + else: + baseline = 0 + try: + outputs = self._execute_flat_plan(plan) + except torch.cuda.OutOfMemoryError as exc: + check = self._memory_check(plan) + self._raise_memory_error( + plan, + check, + context=context, + message="CUDA OOM occurred despite the planner estimate", + ) + raise AssertionError("unreachable") from exc + if torch.cuda.is_available() and self.device.type == "cuda": + torch.cuda.synchronize(self.device) + peak = int(torch.cuda.max_memory_allocated(self.device)) + self._update_memory_profile(plan, max(0, peak - baseline)) + return outputs + + def _execute_flat_plan(self, plan: _FlatForwardPlan) -> list[AnyForwardOutput]: + outputs = [ + ForwardOutput( + target_logprobs=None, + top_k=None, + logits=None, + hidden_states=None, + ) + for _ in range(plan.request_count) + ] + for group in plan.groups: + from art.megatron.lora import use_lora_slot + + with use_lora_slot(group.slot_ref): + prepared = self._prepare_packed_forward(group.packed) + item_outputs = self._forward_packed(group.items, prepared) + for index, output in zip(group.request_indices, item_outputs, strict=True): + outputs[index] = output + return outputs + + def _estimate_group_request_output_bytes( + self, + requests: Sequence[AnyForwardInput], + ) -> int: + total = 0 + for request in requests: + seq_len = int(request.input_tokens.numel()) + if request.target_tokens is not None: + total += int(request.target_tokens.numel()) * _dtype_size(torch.float32) + if request.top_k is not None: + total += ( + seq_len + * int(request.top_k) + * (_dtype_size(torch.float32) + _dtype_size(torch.long)) + ) + if request.logits: + if self._padded_vocab_size is None: + raise RuntimeError("logits output memory requires a GPT model") + total += seq_len * self._padded_vocab_size * self._param_dtype_size + if request.hidden_states: + total += seq_len * self._hidden_size * self._param_dtype_size + return total + + def _memory_signature_from_requests( + self, + requests: Sequence[AnyForwardInput], + *, + slot_group_count: int, + ) -> _MemorySignature: + return _MemorySignature( + topology=self._topology_key(), + shared_prefix_max_depth=self.shared_prefix_max_depth, + slot_group_count=slot_group_count, + request_mix=tuple( + sorted({_request_mix_key(request) for request in requests}) + ), + ) + + def _topology_key(self) -> tuple[int, int, int, int]: + try: + topology = self._topology() + return cast( + tuple[int, int, int, int], + tuple( + int(getattr(topology, name)) for name in ("dp", "tp", "cp", "pp") + ), + ) + except (AssertionError, AttributeError, ImportError, RuntimeError, ValueError): + return (1, 1, 1, 1) + + def _memory_check( + self, + forward: _FlatForwardPlan, + ) -> _MemoryCheck: + return self._memory_check_required( + self._estimate_required_memory_bytes_from_values( + packed_tokens=forward.packed_tokens, + output_bytes=forward.output_bytes, + signature=forward.signature, + ) + ) + + def _memory_check_required(self, required: int) -> _MemoryCheck: + available = self._available_memory_bytes() + if dist.is_available() and dist.is_initialized(): + group = self._forward_memory_group() + values = torch.tensor( + [float(required), float(available)], + device=self.device if self.device.type == "cuda" else "cpu", + dtype=torch.float64, + ) + dist.all_reduce(values[0], op=dist.ReduceOp.MAX, group=group) + dist.all_reduce(values[1], op=dist.ReduceOp.MIN, group=group) + required = int(values[0].item()) + available = int(values[1].item()) + return _MemoryCheck( + estimated_required_bytes=required, + available_bytes=available, + fits=required <= available, + ) + + @staticmethod + def _forward_memory_group() -> object | None: + try: + from megatron.core import parallel_state as ps + + return ps.get_tensor_and_context_parallel_group(check_initialized=False) + except (AssertionError, ImportError, RuntimeError, ValueError): + return None + + def _raise_memory_error( + self, + plan: _FlatForwardPlan, + check: _MemoryCheck, + *, + context: str, + message: str, + ) -> None: + raise TrainerRankMemoryError( + f"{context}: {message}. " + f"packed_tokens={plan.packed_tokens} " + f"logical_tokens={plan.logical_tokens} " + f"output_gb={plan.output_bytes / 1024**3:.3f} " + f"estimated_required_gb={check.estimated_required_bytes / 1024**3:.3f} " + f"available_gb={check.available_bytes / 1024**3:.3f}. " + "Use smaller top-level items, reduce output requests, or call " + "dp_rank_forward with already-DP-local smaller inputs." + ) + + def _estimate_required_memory_bytes_from_values( + self, + *, + packed_tokens: int, + output_bytes: int, + signature: _MemorySignature, + ) -> int: + if packed_tokens <= 0: + return output_bytes + profiled = self._memory_profiles.get(signature) + activation_factor = max(4, min(16, self._num_layers // 4 + 4)) + static_compute = ( + packed_tokens + * self._hidden_size + * self._param_dtype_size + * activation_factor + ) + if ( + profiled is None + or profiled.packed_tokens * _MEMORY_PROFILE_TRUST_GROWTH < packed_tokens + ): + compute = static_compute + else: + compute = max(static_compute, int(profiled.bytes_per_token * packed_tokens)) + return int((output_bytes + compute) * self.memory_safety_factor) + + def _available_memory_bytes(self) -> int: + if not (torch.cuda.is_available() and self.device.type == "cuda"): + return 1 << 60 + free, total = torch.cuda.mem_get_info(self.device) + allocated = int(torch.cuda.memory_allocated(self.device)) + reserved = int(torch.cuda.memory_reserved(self.device)) + reusable_reserved = max(0, reserved - allocated) + reserve = int(total * self.memory_reserve_fraction) + return max(0, int(free) + reusable_reserved - reserve) + + def _all_ranks_have_memory_profile( + self, + *, + packed_tokens: int, + signature: _MemorySignature, + ) -> bool: + profile = self._memory_profiles.get(signature) + local = packed_tokens <= 0 or ( + profile is not None + and profile.packed_tokens * _MEMORY_PROFILE_TRUST_GROWTH >= packed_tokens + ) + if dist.is_available() and dist.is_initialized(): + value = torch.tensor( + int(local), + device=self.device if self.device.type == "cuda" else "cpu", + dtype=torch.int32, + ) + dist.all_reduce(value, op=dist.ReduceOp.MIN) + return bool(value.item()) + return local + + def _update_memory_profile( + self, plan: _FlatForwardPlan, peak_delta_bytes: int + ) -> None: + if plan.packed_tokens <= 0: + return + compute_delta = max(0, peak_delta_bytes - plan.output_bytes) + bytes_per_token = compute_delta / max(1, plan.packed_tokens) + previous = self._memory_profiles.get(plan.signature) + self._memory_profiles[plan.signature] = _MemoryProfile( + bytes_per_token=max( + bytes_per_token, + 0.0 if previous is None else previous.bytes_per_token, + ), + packed_tokens=max( + plan.packed_tokens, + 0 if previous is None else previous.packed_tokens, + ), + ) + + def _forward_item(self, request: AnyForwardInput) -> _ForwardItem: + if request.top_k is not None: + _validate_top_k(request.top_k, _language_model(self.runtime.model[0])) + input_ids = request.input_tokens.reshape(-1).to(dtype=torch.long) + if int(input_ids.numel()) == 0: + raise ValueError("input_tokens must not be empty") + labels = None + if request.target_tokens is not None: + labels = request.target_tokens.to(dtype=torch.long) + if int(labels.numel()) == 0: + raise ValueError("target_tokens must not be empty") + input_shape = tuple(request.input_tokens.shape) + if tuple(labels.shape) == input_shape: + labels = labels.reshape(-1) + elif ( + labels.ndim > request.input_tokens.ndim + and tuple(labels.shape[: request.input_tokens.ndim]) == input_shape + ): + labels = labels.reshape( + int(input_ids.numel()), *labels.shape[request.input_tokens.ndim :] + ) + elif labels.ndim < 1 or int(labels.shape[0]) != int(input_ids.numel()): + raise ValueError( + "target_tokens must match input_tokens or add trailing target " + f"dimensions: input_tokens={input_shape} " + f"target_tokens={tuple(labels.shape)}" + ) + return _ForwardItem(request=request, input_ids=input_ids, labels=labels) + + def _forward_packed( + self, + items: Sequence[_ForwardItem], + prepared: _PreparedPackedForward, + ) -> list[AnyForwardOutput]: + hidden_by_row = self._gather_sequence_parallel_hidden( + self._decoder_hidden(prepared) + ) + return self._project_head(items, prepared, hidden_by_row) + + def _decoder_hidden( + self, + prepared: _PreparedPackedForward, + ) -> torch.Tensor: + from art.megatron.train import _placeholder_attention_mask + + handler = self.runtime.model_support_handler + model = _language_model(self.runtime.model[0]) + attention_mask = _placeholder_attention_mask(self.device) + forward_kwargs = handler.get_forward_kwargs( + self.runtime.model[0], + attention_bias=prepared.attention_state, + ) + extra_block_kwargs = cast( + dict[str, object] | None, + forward_kwargs.pop("extra_block_kwargs", None), + ) + preprocessed = model._preprocess( + input_ids=prepared.tokens, + position_ids=prepared.position_ids, + packed_seq_params=cast("PackedSeqParams", prepared.packed_seq_params), + ) + ( + decoder_input, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + padding_mask, + ) = preprocessed[:6] + rotary_pos_cos_sin = preprocessed[6] if len(preprocessed) == 7 else None + return cast( + torch.Tensor, + model.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + rotary_pos_cos_sin=rotary_pos_cos_sin, + packed_seq_params=prepared.packed_seq_params, + sequence_len_offset=sequence_len_offset, + padding_mask=padding_mask, + **(extra_block_kwargs or {}), + ), + ) + + def _project_head( + self, + items: Sequence[_ForwardItem], + prepared: _PreparedPackedForward, + hidden_by_row: torch.Tensor, + ) -> list[AnyForwardOutput]: + model = _language_model(self.runtime.model[0]) + output_weight = ( + model.shared_embedding_or_output_weight() + if bool(model.share_embeddings_and_output_weights) + else None + ) + device = hidden_by_row.device + target_logprobs = [None for _ in items] + logits: list[torch.Tensor | None] = [None for _ in items] + top_k: list[TopK | None] = [None for _ in items] + label_rows: list[torch.Tensor | None] = [None for _ in items] + projected_rows: list[torch.Tensor] = [] + + for index, (item, positions_cpu) in enumerate( + zip(items, prepared.positions_by_item, strict=True) + ): + positions = positions_cpu.to(device=device) + if item.request.logits or item.request.top_k is not None: + projected_rows.append(positions) + if item.labels is not None: + source_positions = prepared.source_positions_by_item[index].to(device) + labels = item.labels.to(device=device).index_select(0, source_positions) + label_rows[index] = labels + target_logprobs[index] = torch.zeros( + tuple(labels.shape), + device=device, + dtype=torch.float32, + ) + if item.request.top_k is None and not item.request.logits: + valid = labels != -100 + if labels.ndim > 1: + valid = valid.reshape(int(labels.shape[0]), -1).any(dim=1) + valid_offsets = torch.nonzero(valid, as_tuple=False).reshape(-1) + if int(valid_offsets.numel()): + projected_rows.append(positions.index_select(0, valid_offsets)) + if item.request.logits: + logits[index] = torch.empty( + (int(positions.numel()), _padded_vocab_size(model)), + device=hidden_by_row.device, + dtype=hidden_by_row.dtype, + ) + + row_tensor = ( + torch.cat(projected_rows).unique(sorted=True) + if projected_rows + else torch.empty(0, dtype=torch.long, device=device) + ) + if int(row_tensor.numel()): + local_row_matches = tuple( + _row_match(positions.to(device=device), row_tensor) + for positions in prepared.positions_by_item + ) + self._project_vocab_parallel( + items, + hidden_by_row, + row_tensor, + row_matches=local_row_matches, + item_lengths=tuple( + int(positions.numel()) for positions in prepared.positions_by_item + ), + output_weight=output_weight, + target_logprobs=target_logprobs, + top_k=top_k, + logits=logits, + label_rows=label_rows, + ) + + target_logprobs, top_k = _anchor_disconnected_outputs( + target_logprobs, + top_k, + hidden_by_row, + ) + return [ + ForwardOutput( + target_logprobs=target_logprobs[index], + top_k=top_k[index], + logits=logits[index], + hidden_states=( + _select_positions(hidden_by_row, positions) + if item.request.hidden_states + else None + ), + ) + for index, (item, positions) in enumerate( + zip(items, prepared.positions_by_item, strict=True) + ) + ] + + def _project_vocab_parallel( + self, + items: Sequence[_ForwardItem], + hidden_by_row: torch.Tensor, + rows: torch.Tensor, + *, + row_matches: Sequence[_RowMatch], + item_lengths: Sequence[int], + output_weight: torch.Tensor | None, + target_logprobs: list[torch.Tensor | None], + top_k: list[TopK | None], + logits: list[torch.Tensor | None], + label_rows: list[torch.Tensor | None], + ) -> None: + model = _language_model(self.runtime.model[0]) + max_top_k = max((int(item.request.top_k or 0) for item in items), default=0) + need_log_z = any( + item.labels is not None or item.request.top_k is not None for item in items + ) + for start in range(0, int(rows.numel()), self.head_chunk_tokens): + chunk_rows = rows[start : start + self.head_chunk_tokens] + local_logits = self._local_logits_from_hidden_rows( + model, + _select_positions(hidden_by_row, chunk_rows), + output_weight=output_weight, + ) + log_z: torch.Tensor | None = None + local_topk: tuple[torch.Tensor, torch.Tensor] | None = None + if need_log_z: + topk_stats = _try_triton_local_topk_stats(local_logits, k=max_top_k) + logsumexp_stats = ( + _try_triton_local_logsumexp_stats(local_logits) + if topk_stats is None + else None + ) + stats = topk_stats if topk_stats is not None else logsumexp_stats + if stats is not None: + local_max, local_sum = stats[:2] + local_max = local_max.detach() + global_max = _all_reduce_tensor_parallel_max(local_max) + global_sum = _all_reduce_tensor_parallel_sum( + local_sum * torch.exp(local_max - global_max) + ) + log_z = global_max + torch.log(global_sum) + else: + log_z = _vocab_parallel_log_z(local_logits) + + if topk_stats is not None: + _, _, local_values, local_tokens = topk_stats + local_topk = (local_values, local_tokens) + elif logsumexp_stats is not None and max_top_k > 0: + local_k = min(max_top_k, int(local_logits.shape[1])) + local_values, local_tokens = torch.topk( + local_logits, k=local_k, dim=-1 + ) + local_topk = (local_values.float(), local_tokens) + + logit_chunks = [ + chunk_offsets + for item, match in zip(items, row_matches, strict=True) + if item.request.logits + for _, chunk_offsets in ( + _match_chunk_offsets( + match, + start=start, + end=start + int(chunk_rows.numel()), + ), + ) + if int(chunk_offsets.numel()) + ] + logit_chunk_offsets = ( + torch.cat(logit_chunks).unique(sorted=True) + if logit_chunks + else torch.empty(0, dtype=torch.long, device=rows.device) + ) + chunk_logits: torch.Tensor | None = None + if int(logit_chunk_offsets.numel()): + chunk_logits = _batch_seq_logits( + self._gather_tensor_parallel_logits( + local_logits.index_select(0, logit_chunk_offsets).unsqueeze(1) + ), + seq_len=int(logit_chunk_offsets.numel()), + ).squeeze(0) + + for index, item in enumerate(items): + offsets, chunk_offsets = _match_chunk_offsets( + row_matches[index], + start=start, + end=start + int(chunk_rows.numel()), + ) + if int(offsets.numel()) == 0: + continue + item_logits = logits[index] + if item_logits is not None: + if chunk_logits is None: + raise RuntimeError("logits output requires gathered logits") + source_offsets, gathered_offsets = _matching_offsets( + chunk_offsets, + logit_chunk_offsets, + ) + item_logits[offsets.index_select(0, source_offsets)] = ( + chunk_logits.index_select(0, gathered_offsets) + ) + labels = label_rows[index] + item_logprobs = target_logprobs[index] + if item_logprobs is not None and labels is not None: + if log_z is None: + raise RuntimeError("target logprobs require logsumexp") + selected_log_z = log_z.index_select(0, chunk_offsets) + item_logprobs[offsets] = _vocab_parallel_target_logprobs( + local_logits, + labels.index_select(0, offsets), + selected_log_z, + row_offsets=chunk_offsets, + ) + k = item.request.top_k + if k is not None: + if log_z is None: + raise RuntimeError("top_k requires logsumexp") + selected_log_z = log_z.index_select(0, chunk_offsets) + if local_topk is not None: + local_values, local_tokens = local_topk + selected_values = local_values.index_select(0, chunk_offsets) + selected_tokens = local_tokens.index_select(0, chunk_offsets) + else: + selected_logits = local_logits.index_select(0, chunk_offsets) + selected_values, selected_tokens = torch.topk( + selected_logits.float(), + k=min(k, int(selected_logits.shape[1])), + dim=-1, + ) + values = _vocab_parallel_topk_from_local( + selected_values, + selected_tokens, + k=k, + log_z=selected_log_z, + vocab_start=_vocab_range(local_logits)[0], + ) + current = top_k[index] + if current is None: + current = TopK( + logprobs=torch.empty( + (item_lengths[index], int(values.logprobs.shape[1])), + device=values.logprobs.device, + dtype=values.logprobs.dtype, + ), + tokens=torch.empty( + (item_lengths[index], int(values.tokens.shape[1])), + device=values.tokens.device, + dtype=values.tokens.dtype, + ), + ) + top_k[index] = current + current.logprobs[offsets] = values.logprobs + current.tokens[offsets] = values.tokens + + def _local_logits_from_hidden_rows( + self, + model: "GPTModel", + hidden: torch.Tensor, + *, + output_weight: torch.Tensor | None, + ) -> torch.Tensor: + output_layer = model.output_layer + sequence_parallel = bool(getattr(output_layer, "sequence_parallel", False)) + if sequence_parallel: + output_layer.sequence_parallel = False + try: + logits, _ = output_layer( + hidden.unsqueeze(1), + weight=output_weight, + runtime_gather_output=None, + ) + finally: + if sequence_parallel: + output_layer.sequence_parallel = True + return _batch_seq_logits( + model._scale_logits(logits), + seq_len=int(hidden.shape[0]), + ).squeeze(0) + + def _gather_sequence_parallel_hidden(self, hidden: torch.Tensor) -> torch.Tensor: + from megatron.core import parallel_state as ps + + if int(ps.get_tensor_model_parallel_world_size()) <= 1: + return hidden.squeeze(1) + from megatron.core import tensor_parallel + + gathered = tensor_parallel.gather_from_sequence_parallel_region( + hidden, + tensor_parallel_output_grad=True, + group=ps.get_tensor_model_parallel_group(check_initialized=False), + ) + return cast(torch.Tensor, gathered).squeeze(1) + + def _prepare_packed_forward( + self, + batch: SharedPrefixPack, + ) -> _PreparedPackedForward: + topology = self._topology() + batch = _pad_packed_batch(batch, multiple=int(topology.tp)) + if int(topology.cp) > 1: + return self._prepare_context_parallel_forward(batch, topology=topology) + from art.megatron.shared_prefix_state import create_shared_prefix_state + + handler = self.runtime.model_support_handler + provider = self.runtime.provider + return _PreparedPackedForward( + tokens=batch.tokens.to(self.device), + position_ids=batch.position_ids.to(self.device), + attention_state=create_shared_prefix_state( + group_ids=batch.group_ids, + parent_ids=batch.parent_ids, + target_device=self.device, + build_gdn_execution_spec=handler.build_gdn_execution_spec, + attention_head_dim=provider.kv_channels, + attention_value_head_dim=provider.kv_channels, + ), + packed_seq_params=None, + positions_by_item=batch.positions_by_sequence, + source_positions_by_item=tuple( + torch.arange( + int(positions.numel()), + dtype=torch.long, + device=positions.device, + ) + for positions in batch.positions_by_sequence + ), + ) + + def _prepare_context_parallel_forward( + self, + batch: SharedPrefixPack, + *, + topology: "ParallelTopology", + ) -> _PreparedPackedForward: + from megatron.core import parallel_state as ps + + from art.megatron.context_parallel.runtime import ( + _dispatch_tensor, + prepare_cp_micro, + ) + from art.megatron.training.microbatches import ( + _context_parallel_config_for_provider, + ) + from art.preprocessing.pack import PackedTensors + + assistant_mask = torch.ones_like(batch.tokens, dtype=torch.bool) + sparse_micro: PackedTensors = { + "tokens": batch.tokens, + "group_ids": batch.group_ids, + "parent_ids": batch.parent_ids, + "input_pos": batch.position_ids, + "assistant_mask": assistant_mask, + "logprobs": torch.full_like( + batch.tokens, float("nan"), dtype=torch.float32 + ), + "advantages": torch.zeros_like(batch.tokens, dtype=torch.float32), + "weights": assistant_mask.to(dtype=torch.float32), + "pixel_values": [None], + "image_grid_thw": [None], + "moe_routing_replay": None, + } + handler = self.runtime.model_support_handler + prepared = prepare_cp_micro( + micro=sparse_micro, + topology=topology, + config=_context_parallel_config_for_provider( + self.runtime.provider, self.device + ), + cp_group=ps.get_context_parallel_group(check_initialized=False), + cp_rank=ps.get_context_parallel_rank(), + build_gdn_execution_spec=handler.build_gdn_execution_spec, + target_device=self.device, + ) + if prepared.rank_plan is None: + raise RuntimeError("CP forward preparation did not return a rank plan") + local_positions = _dispatch_tensor( + torch.arange( + int(batch.tokens.shape[1]), + dtype=torch.long, + ).unsqueeze(0), + rank_plan=prepared.rank_plan, + pad_value=-1, + pad_multiple=prepared.pad_multiple, + ) + local_position_pairs = tuple( + _local_position_pairs(local_positions, positions) + for positions in batch.positions_by_sequence + ) + return _PreparedPackedForward( + tokens=prepared.tensors.tokens, + position_ids=prepared.tensors.input_pos, + attention_state=cast("ArtContextParallelState", prepared.attention_state), + packed_seq_params=prepared.packed_seq_params, + positions_by_item=tuple(pair[0] for pair in local_position_pairs), + source_positions_by_item=tuple(pair[1] for pair in local_position_pairs), + ) + + def _topology(self) -> "ParallelTopology": + from art.megatron.train import _infer_parallel_topology + + return _infer_parallel_topology(self.runtime.model) + + def _gather_tensor_parallel_logits(self, logits: torch.Tensor) -> torch.Tensor: + from megatron.core import parallel_state as ps + + if int(ps.get_tensor_model_parallel_world_size()) <= 1: + return logits + from megatron.core import tensor_parallel + + return cast( + torch.Tensor, + tensor_parallel.gather_from_tensor_model_parallel_region(logits), + ) + + def _configure_optimizer(self, params: AdamParams) -> None: + optimizer = self._optimizer() + config = cast("OptimizerConfig | None", optimizer.config) + if config is not None: + config.lr = params.learning_rate + config.adam_beta1 = params.beta1 + config.adam_beta2 = params.beta2 + config.weight_decay = params.weight_decay + config.clip_grad = params.grad_clip_norm + for group in optimizer.param_groups: + param_group = cast(MutableMapping[str, object], group) + param_group["lr"] = params.learning_rate + param_group["weight_decay"] = params.weight_decay + if "betas" in param_group: + param_group["betas"] = (params.beta1, params.beta2) + + def _scale_main_grads(self, scale: float) -> None: + if scale == 1.0: + return + for chunk in self.runtime.model: + for param in chunk.parameters(): + grad = getattr(param, "main_grad", None) + if isinstance(grad, torch.Tensor): + grad.mul_(scale) + elif param.grad is not None: + param.grad.mul_(scale) + + +def _validate_top_k(top_k: int, model: "GPTModel") -> None: + vocab_size = _padded_vocab_size(model) + if top_k > vocab_size: + raise ValueError(f"top_k={top_k} exceeds vocabulary size {vocab_size}") + + +def _request_mix_key(request: AnyForwardInput) -> str: + parts = [] + if request.target_tokens is not None: + target = request.target_tokens + tail_shape = tuple(target.shape[request.input_tokens.ndim :]) + parts.append(f"target:{tail_shape or 'single'}") + if request.top_k is not None: + parts.append(f"topk:{int(request.top_k)}") + if request.logits: + parts.append("logits") + if request.hidden_states: + parts.append("hidden") + return "+".join(parts) if parts else "inactive" + + +def _pad_packed_batch( + batch: SharedPrefixPack, + *, + multiple: int, +) -> SharedPrefixPack: + if multiple <= 1: + return batch + seq_len = int(batch.tokens.shape[1]) + pad = -seq_len % multiple + if pad == 0: + return batch + + device = batch.tokens.device + next_group = ( + int(batch.group_ids.max().item()) + 1 if int(batch.group_ids.numel()) else 1 + ) + pad_group_ids = torch.arange( + next_group, + next_group + pad, + dtype=batch.group_ids.dtype, + device=device, + ).unsqueeze(0) + return SharedPrefixPack( + tokens=torch.cat( + ( + batch.tokens, + torch.zeros((1, pad), dtype=batch.tokens.dtype, device=device), + ), + dim=1, + ), + group_ids=torch.cat((batch.group_ids, pad_group_ids), dim=1), + parent_ids=torch.cat((batch.parent_ids, pad_group_ids), dim=1), + position_ids=torch.cat( + ( + batch.position_ids, + torch.zeros((1, pad), dtype=batch.position_ids.dtype, device=device), + ), + dim=1, + ), + positions_by_sequence=batch.positions_by_sequence, + ) + + +def _language_model(model: torch.nn.Module) -> "GPTModel": + module: object = model + while hasattr(module, "module"): + module = getattr(module, "module") + if hasattr(module, "_preprocess") and hasattr(module, "decoder"): + return cast("GPTModel", module) + language_model = getattr(module, "language_model", None) + if language_model is not None: + return cast("GPTModel", language_model) + raise RuntimeError("expected a Megatron GPT model") + + +def _padded_vocab_size(model: "GPTModel") -> int: + vocab_size = getattr(getattr(model, "config", None), "padded_vocab_size", None) + if vocab_size is None: + vocab_size = getattr(model, "vocab_size", None) + if vocab_size is None: + raise RuntimeError("could not determine full padded vocabulary size") + return int(vocab_size) + + +def _hidden_size(model: "GPTModel | None", provider: object) -> int: + for source in (getattr(model, "config", None), model, provider): + if source is None: + continue + hidden_size = getattr(source, "hidden_size", None) + if hidden_size is not None: + return int(hidden_size) + raise RuntimeError("could not determine hidden size") + + +def _dtype_size(dtype: torch.dtype) -> int: + return torch.empty((), dtype=dtype).element_size() + + +def _vocab_parallel_target_logprobs( + local_logits: torch.Tensor, + labels: torch.Tensor, + log_z: torch.Tensor, + *, + row_offsets: torch.Tensor, +) -> torch.Tensor: + start, _ = _vocab_range(local_logits) + target_logits = _call_compiled( + _owned_target_logits_for_rows, + local_logits, + labels, + start, + row_offsets, + ) + target_logits = _all_reduce_tensor_parallel_sum(target_logits) + return _call_compiled(_finish_target_logprobs, target_logits, labels, log_z) + + +def _owned_target_logits_for_rows( + local_logits: torch.Tensor, + labels: torch.Tensor, + vocab_start: int, + row_offsets: torch.Tensor, +) -> torch.Tensor: + flat_labels = labels.reshape(int(labels.shape[0]), -1) + local_labels = flat_labels - vocab_start + owns_label = ( + (flat_labels != -100) + & (local_labels >= 0) + & (local_labels < int(local_logits.shape[1])) + ) + rows = row_offsets.reshape(int(row_offsets.shape[0]), 1).expand_as(flat_labels) + selected = local_logits[ + rows, + local_labels.clamp(0, int(local_logits.shape[1]) - 1), + ].float() + return selected.masked_fill(~owns_label, 0.0).reshape(labels.shape) + + +def _finish_target_logprobs( + target_logits: torch.Tensor, + labels: torch.Tensor, + log_z: torch.Tensor, +) -> torch.Tensor: + log_z = log_z.reshape(int(log_z.shape[0]), *((1,) * (int(labels.ndim) - 1))) + return (target_logits.float() - log_z).masked_fill(labels == -100, 0.0) + + +def _anchor_disconnected_outputs( + target_logprobs: list[torch.Tensor | None], + top_k: list[TopK | None], + hidden_by_row: torch.Tensor, +) -> tuple[list[torch.Tensor | None], list[TopK | None]]: + if not hidden_by_row.requires_grad: + return target_logprobs, top_k + anchor: torch.Tensor | None = None + + def anchor_tensor(tensor: torch.Tensor) -> torch.Tensor: + nonlocal anchor + if tensor.requires_grad: + return tensor + if anchor is None: + anchor = hidden_by_row.reshape(-1)[:1].float().sum() * 0.0 + return tensor + anchor + + return ( + [ + None if item_logprobs is None else anchor_tensor(item_logprobs) + for item_logprobs in target_logprobs + ], + [ + None + if item_top_k is None + else TopK( + logprobs=anchor_tensor(item_top_k.logprobs), + tokens=item_top_k.tokens, + ) + for item_top_k in top_k + ], + ) + + +def _try_triton_local_topk_stats( + local_logits: torch.Tensor, + *, + k: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | None: + if k <= 0 or k > int( + os.environ.get("ART_TRAINER_RANK_TRITON_FUSED_TOPK_MAX", "10") + ): + return None + return cast( + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | None, + _try_triton_stats( + "local_topk_stats", + local_logits, + k=min(k, int(local_logits.shape[1])), + ), + ) + + +def _try_triton_local_logsumexp_stats( + local_logits: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor] | None: + return cast( + tuple[torch.Tensor, torch.Tensor] | None, + _try_triton_stats("local_logsumexp_stats", local_logits), + ) + + +def _try_triton_stats( + name: str, + local_logits: torch.Tensor, + **kwargs: object, +) -> object | None: + if not local_logits.is_cuda: + return None + if os.environ.get("ART_TRAINER_RANK_TRITON_TOPK", "1").lower() in { + "0", + "false", + } or int(local_logits.shape[0]) < int( + os.environ.get("ART_TRAINER_RANK_TRITON_MIN_ROWS", "64") + ): + return None + try: + from art.trainer_rank import topk + + return getattr(topk, name)(local_logits, **kwargs) + except Exception: + if os.environ.get("ART_TRAINER_RANK_TRITON_TOPK", "1").lower() == "strict": + raise + return None + + +def _vocab_parallel_topk_from_local( + local_values: torch.Tensor, + local_tokens: torch.Tensor, + *, + k: int, + log_z: torch.Tensor, + vocab_start: int, +) -> TopK: + local_k = min(k, int(local_values.shape[1])) + local_values = local_values[:, :local_k] - log_z.unsqueeze(1) + local_tokens = local_tokens[:, :local_k] + vocab_start + + from megatron.core import parallel_state as ps + + tp_size = int(ps.get_tensor_model_parallel_world_size()) + if tp_size <= 1: + return TopK(logprobs=local_values, tokens=local_tokens) + + from torch.distributed.nn.functional import all_gather + + group = ps.get_tensor_model_parallel_group(check_initialized=False) + gathered_values = cast(tuple[torch.Tensor, ...], all_gather(local_values, group)) + gathered_tokens = [torch.empty_like(local_tokens) for _ in range(tp_size)] + dist.all_gather(gathered_tokens, local_tokens, group=group) + values = torch.cat(gathered_values, dim=1) + tokens = torch.cat(gathered_tokens, dim=1) + top_values, top_offsets = torch.topk(values, k=k, dim=-1) + return TopK(logprobs=top_values, tokens=tokens.gather(1, top_offsets)) + + +def _vocab_parallel_log_z(local_logits: torch.Tensor) -> torch.Tensor: + local_logits = local_logits.float() + local_max = local_logits.max(dim=-1).values.detach() + global_max = _all_reduce_tensor_parallel_max(local_max) + local_sum = _call_compiled(_local_vocab_exp_sum, local_logits, global_max) + global_sum = _all_reduce_tensor_parallel_sum(local_sum) + return global_max + torch.log(global_sum) + + +def _local_vocab_exp_sum( + local_logits: torch.Tensor, + global_max: torch.Tensor, +) -> torch.Tensor: + return torch.exp(local_logits.float() - global_max.unsqueeze(1)).sum(dim=-1) + + +def _vocab_range(local_logits: torch.Tensor) -> tuple[int, int]: + from megatron.core import parallel_state as ps + + local_size = int(local_logits.shape[1]) + rank = int(ps.get_tensor_model_parallel_rank()) + start = rank * local_size + return start, start + local_size + + +def _all_reduce_tensor_parallel_sum(tensor: torch.Tensor) -> torch.Tensor: + from megatron.core import parallel_state as ps + + if int(ps.get_tensor_model_parallel_world_size()) <= 1: + return tensor + from torch.distributed.nn.functional import all_reduce + + return cast( + torch.Tensor, + all_reduce( + tensor, + op=dist.ReduceOp.SUM, + group=ps.get_tensor_model_parallel_group(check_initialized=False), + ), + ) + + +def _all_reduce_tensor_parallel_max(tensor: torch.Tensor) -> torch.Tensor: + from megatron.core import parallel_state as ps + + if int(ps.get_tensor_model_parallel_world_size()) <= 1: + return tensor + output = tensor.clone() + dist.all_reduce( + output, + op=dist.ReduceOp.MAX, + group=ps.get_tensor_model_parallel_group(check_initialized=False), + ) + return output + + +def _call_compiled(fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: + if os.environ.get("ART_TRAINER_RANK_COMPILE", "0").lower() in {"0", "false"}: + return fn(*args, **kwargs) + compiled = _COMPILED_FUNCTIONS.get(fn) + if compiled is None: + compiled = cast(Callable[..., object], torch.compile(fn, dynamic=True)) + _COMPILED_FUNCTIONS[fn] = compiled + try: + return cast(Callable[P, R], compiled)(*args, **kwargs) + except Exception: + return fn(*args, **kwargs) + + +def _matching_offsets( + positions: torch.Tensor, + chunk_rows: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + if int(positions.numel()) == 0 or int(chunk_rows.numel()) == 0: + empty = torch.empty(0, dtype=torch.long, device=positions.device) + return empty, empty + sorted_rows, order = chunk_rows.sort() + indices = torch.searchsorted(sorted_rows, positions) + in_bounds = indices < int(sorted_rows.numel()) + source_offsets = torch.arange( + int(positions.numel()), + device=positions.device, + dtype=torch.long, + )[in_bounds] + found = indices[in_bounds] + keep = sorted_rows.index_select(0, found) == positions.index_select( + 0, + source_offsets, + ) + return source_offsets[keep], order.index_select(0, found[keep]) + + +def _row_match(positions: torch.Tensor, rows: torch.Tensor) -> _RowMatch: + source_offsets, row_offsets = _matching_offsets(positions, rows) + if int(row_offsets.numel()) > 1: + order = row_offsets.argsort() + source_offsets = source_offsets.index_select(0, order) + row_offsets = row_offsets.index_select(0, order) + return _RowMatch(source_offsets=source_offsets, row_offsets=row_offsets) + + +def _match_chunk_offsets( + match: _RowMatch, + *, + start: int, + end: int, +) -> tuple[torch.Tensor, torch.Tensor]: + keep = (match.row_offsets >= start) & (match.row_offsets < end) + source_offsets = match.source_offsets[keep] + return source_offsets, match.row_offsets[keep] - start + + +def _select_positions(values: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + if int(positions.numel()) == 0: + return values[:0] + return values.index_select(0, positions.to(device=values.device)) + + +def _batch_seq_logits(logits: torch.Tensor, *, seq_len: int) -> torch.Tensor: + if int(logits.ndim) != 3: + raise RuntimeError( + f"expected logits with shape [B, S, V] or [S, B, V], got {tuple(logits.shape)}" + ) + if int(logits.shape[0]) == 1 and int(logits.shape[1]) == seq_len: + return logits + if int(logits.shape[0]) == seq_len and int(logits.shape[1]) == 1: + return logits.transpose(0, 1).contiguous() + raise RuntimeError( + f"logits do not match sequence length {seq_len}: {tuple(logits.shape)}" + ) + + +def _materialize(inputs: ForwardInputs) -> ForwardInputs: + if isinstance(inputs, ForwardInput): + return inputs + return [_materialize(item) for item in _nested_forward_children(inputs)] + + +def _flatten(inputs: ForwardInputs) -> Iterator[AnyForwardInput]: + if isinstance(inputs, ForwardInput): + yield inputs + return + for item in _nested_forward_children(inputs): + yield from _flatten(item) + + +def _unflatten( + template: ForwardInputs, outputs: Iterator[AnyForwardOutput] +) -> ForwardOutputs: + if isinstance(template, ForwardInput): + return next(outputs) + return [_unflatten(item, outputs) for item in _nested_forward_children(template)] + + +def _nested_forward_children(inputs: ForwardInputs) -> Iterator[ForwardInputs]: + if isinstance(inputs, Mapping): + raise TypeError( + "dict was passed directly to TrainerRank; gather or materialize the " + "values into a list/tuple so nested forward output ordering is explicit" + ) + if isinstance(inputs, str | bytes): + raise TypeError( + "TrainerRank forward inputs must be ForwardInput objects or nested " + "iterables of ForwardInput objects, not strings" + ) + try: + return iter(cast(Iterable[ForwardInputs], inputs)) + except TypeError as exc: + raise TypeError( + "TrainerRank forward inputs must be ForwardInput objects or nested " + "iterables of ForwardInput objects" + ) from exc + + +__all__ = [ + "AdamParams", + "ForwardInput", + "ForwardOutput", + "MicroBatch", + "MicroBatchStats", + "TopK", + "TrainerRank", + "TrainerRankMemoryError", +] diff --git a/src/art/trainer_rank/topk.py b/src/art/trainer_rank/topk.py new file mode 100644 index 000000000..e0a84722f --- /dev/null +++ b/src/art/trainer_rank/topk.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +from typing import Any + +import torch +import triton +import triton.language as tl + +type LocalTopKStats = tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] +type LocalLogSumExpStats = tuple[torch.Tensor, torch.Tensor] + + +@triton.jit +def _stats_stage1_kernel( + logits_ptr, + partial_max_ptr, + partial_sum_ptr, + partial_values_ptr, + partial_tokens_ptr, + stride_row: tl.constexpr, + vocab_size: tl.constexpr, + n_blocks: tl.constexpr, + k: tl.constexpr, + block_v: tl.constexpr, +): + row = tl.program_id(0) + block = tl.program_id(1) + offsets = block * block_v + tl.arange(0, block_v) + mask = offsets < vocab_size + values = tl.load( + logits_ptr + row * stride_row + offsets, + mask=mask, + other=-float("inf"), + ).to(tl.float32) + + block_max = tl.max(values, axis=0) + block_sum = tl.sum(tl.exp(values - block_max), axis=0) + partial_offset = row * n_blocks + block + tl.store(partial_max_ptr + partial_offset, block_max) + tl.store(partial_sum_ptr + partial_offset, block_sum) + + work = values + arange = tl.arange(0, block_v) + for slot in tl.static_range(0, k): + top_value, top_index = tl.max( + work, + axis=0, + return_indices=True, + return_indices_tie_break_left=True, + ) + output_offset = (partial_offset * k) + slot + tl.store(partial_values_ptr + output_offset, top_value) + tl.store( + partial_tokens_ptr + output_offset, + (block * block_v + top_index).to(tl.int64), + ) + work = tl.where(arange == top_index, -float("inf"), work) + + +@triton.jit +def _stats_stage2_kernel( + partial_max_ptr, + partial_sum_ptr, + partial_values_ptr, + partial_tokens_ptr, + local_max_ptr, + local_sum_ptr, + values_ptr, + tokens_ptr, + n_blocks: tl.constexpr, + k: tl.constexpr, + block_b: tl.constexpr, + block_candidates: tl.constexpr, +): + row = tl.program_id(0) + + block_offsets = tl.arange(0, block_b) + block_mask = block_offsets < n_blocks + partial_base = row * n_blocks + block_max = tl.load( + partial_max_ptr + partial_base + block_offsets, + mask=block_mask, + other=-float("inf"), + ) + row_max = tl.max(block_max, axis=0) + block_sum = tl.load( + partial_sum_ptr + partial_base + block_offsets, + mask=block_mask, + other=0.0, + ) + row_sum = tl.sum(block_sum * tl.exp(block_max - row_max), axis=0) + tl.store(local_max_ptr + row, row_max) + tl.store(local_sum_ptr + row, row_sum) + + if k > 0: + candidate_offsets = tl.arange(0, block_candidates) + candidate_mask = candidate_offsets < n_blocks * k + candidate_base = row * n_blocks * k + candidates = tl.load( + partial_values_ptr + candidate_base + candidate_offsets, + mask=candidate_mask, + other=-float("inf"), + ) + work = candidates + for slot in tl.static_range(0, k): + top_value, top_index = tl.max( + work, + axis=0, + return_indices=True, + return_indices_tie_break_left=True, + ) + output_offset = row * k + slot + tl.store(values_ptr + output_offset, top_value) + tl.store( + tokens_ptr + output_offset, + tl.load(partial_tokens_ptr + candidate_base + top_index), + ) + work = tl.where(candidate_offsets == top_index, -float("inf"), work) + + +@triton.jit +def _stats_backward_kernel( + logits_ptr, + local_max_ptr, + tokens_ptr, + grad_sum_ptr, + grad_values_ptr, + grad_logits_ptr, + stride_row: tl.constexpr, + vocab_size: tl.constexpr, + k: tl.constexpr, + block_v: tl.constexpr, +): + row = tl.program_id(0) + block = tl.program_id(1) + offsets = block * block_v + tl.arange(0, block_v) + mask = offsets < vocab_size + + logits = tl.load( + logits_ptr + row * stride_row + offsets, + mask=mask, + other=-float("inf"), + ).to(tl.float32) + local_max = tl.load(local_max_ptr + row) + grad = tl.load(grad_sum_ptr + row).to(tl.float32) * tl.exp(logits - local_max) + + for slot in tl.static_range(0, k): + token = tl.load(tokens_ptr + row * k + slot) + value_grad = tl.load(grad_values_ptr + row * k + slot).to(tl.float32) + grad += tl.where(offsets == token, value_grad, 0.0) + + tl.store(grad_logits_ptr + row * stride_row + offsets, grad, mask=mask) + + +class _LocalStatsFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, local_logits: torch.Tensor, k: int): + local_max, local_sum, values, tokens = _local_stats_forward(local_logits, k=k) + ctx.save_for_backward(local_logits, local_max, tokens) + ctx.k = k + return local_max, local_sum, values, tokens + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any) -> Any: + grad_local_max, grad_local_sum, grad_values, grad_tokens = grad_outputs + del grad_local_max, grad_tokens + logits, local_max, tokens = ctx.saved_tensors + k = int(ctx.k) + rows = int(logits.shape[0]) + vocab_size = int(logits.shape[1]) + block_v = 4096 + n_blocks = int(triton.cdiv(vocab_size, block_v)) + + if grad_local_sum is None: + grad_local_sum = torch.zeros_like(local_max) + if grad_values is None: + grad_values = torch.zeros( + (rows, k), + device=logits.device, + dtype=torch.float32, + ) + + grad_logits = torch.empty_like(logits) + _stats_backward_kernel[(rows, n_blocks)]( + logits, + local_max, + tokens, + grad_local_sum.contiguous(), + grad_values.contiguous(), + grad_logits, + logits.stride(0), + vocab_size=vocab_size, # ty: ignore[invalid-argument-type] + k=k, # ty: ignore[invalid-argument-type] + block_v=block_v, # ty: ignore[invalid-argument-type] + num_warps=8, # ty: ignore[unknown-argument] + ) + return grad_logits, None + + +def _check_local_logits(local_logits: torch.Tensor) -> torch.Tensor: + if local_logits.ndim != 2: + raise ValueError( + f"expected [rows, vocab] logits, got {tuple(local_logits.shape)}" + ) + if not local_logits.is_cuda: + raise ValueError("local top-k helpers require CUDA logits") + return local_logits.contiguous() + + +def _local_stats_forward(local_logits: torch.Tensor, *, k: int) -> LocalTopKStats: + logits = _check_local_logits(local_logits) + if k < 0 or k > int(local_logits.shape[1]): + raise ValueError( + f"k={k} is outside local vocab size {int(local_logits.shape[1])}" + ) + + rows = int(logits.shape[0]) + vocab_size = int(logits.shape[1]) + block_v = 4096 + n_blocks = int(triton.cdiv(vocab_size, block_v)) + block_b = int(triton.next_power_of_2(n_blocks)) + block_candidates = int(triton.next_power_of_2(n_blocks * k)) if k else 1 + + partial_shape = (rows, n_blocks) + partial_max = torch.empty(partial_shape, device=logits.device, dtype=torch.float32) + partial_sum = torch.empty_like(partial_max) + partial_topk_shape = (rows, n_blocks, k) if k else (1,) + partial_values = torch.empty( + partial_topk_shape, device=logits.device, dtype=torch.float32 + ) + partial_tokens = torch.empty( + partial_topk_shape, device=logits.device, dtype=torch.long + ) + local_max = torch.empty((rows,), device=logits.device, dtype=torch.float32) + local_sum = torch.empty_like(local_max) + values = torch.empty((rows, k), device=logits.device, dtype=torch.float32) + tokens = torch.empty((rows, k), device=logits.device, dtype=torch.long) + + _stats_stage1_kernel[(rows, n_blocks)]( + logits, + partial_max, + partial_sum, + partial_values, + partial_tokens, + stride_row=logits.stride(0), # ty: ignore[invalid-argument-type] + vocab_size=vocab_size, # ty: ignore[invalid-argument-type] + n_blocks=n_blocks, # ty: ignore[invalid-argument-type] + k=k, # ty: ignore[invalid-argument-type] + block_v=block_v, # ty: ignore[invalid-argument-type] + num_warps=8, # ty: ignore[unknown-argument] + ) + _stats_stage2_kernel[(rows,)]( + partial_max, + partial_sum, + partial_values, + partial_tokens, + local_max, + local_sum, + values, + tokens, + n_blocks=n_blocks, # ty: ignore[invalid-argument-type] + k=k, # ty: ignore[invalid-argument-type] + block_b=block_b, # ty: ignore[invalid-argument-type] + block_candidates=block_candidates, # ty: ignore[invalid-argument-type] + num_warps=8, # ty: ignore[unknown-argument] + ) + return local_max, local_sum, values, tokens + + +def local_topk_stats(local_logits: torch.Tensor, *, k: int) -> LocalTopKStats: + logits = local_logits.contiguous() + if not logits.requires_grad: + return _local_stats_forward(logits, k=k) + return _LocalStatsFunction.apply(logits, k) + + +def local_logsumexp_stats(local_logits: torch.Tensor) -> LocalLogSumExpStats: + logits = local_logits.contiguous() + if not logits.requires_grad: + local_max, local_sum, _, _ = _local_stats_forward(logits, k=0) + return local_max, local_sum + local_max, local_sum, _, _ = _LocalStatsFunction.apply(logits, 0) + return local_max, local_sum diff --git a/tests/integration/megatron/lora/test_dynamic_lora_slots.py b/tests/integration/megatron/lora/test_dynamic_lora_slots.py new file mode 100644 index 000000000..253be55f7 --- /dev/null +++ b/tests/integration/megatron/lora/test_dynamic_lora_slots.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +from contextlib import contextmanager +import os +import socket +from types import SimpleNamespace + +import pytest + +torch = pytest.importorskip("torch") +pytest.importorskip("megatron.core") + +from megatron.core import parallel_state as ps # noqa: E402 +from torch.distributed import destroy_process_group, init_process_group # noqa: E402 + +from art.megatron.lora import LoRA, LoRASlotRef, use_lora_slot # noqa: E402 +from art.trainer_rank import AdamParams, TrainerRank # noqa: E402 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") +def test_dynamic_lora_slots_capture_recompute_context_and_step_independently() -> None: + with _single_rank_model_parallel(): + device = torch.device("cuda") + lora = LoRA( + "dense", + in_features=4, + out_features=5, + rank=2, + alpha=32, + dtype=torch.float32, + device=device, + ) + ref_a = LoRASlotRef("checkpoint", "A") + ref_b = LoRASlotRef("checkpoint", "B") + lora.load_lora_slot( + ref_a, _adapter("dense", rank=1, seed=1), requires_grad=True + ) + lora.load_lora_slot( + ref_b, _adapter("dense", rank=4, seed=2), requires_grad=True + ) + + x = torch.randn(7, 4, device=device) + with use_lora_slot(LoRASlotRef("checkpoint", None)): + assert torch.equal(lora(x), torch.zeros(7, 5, device=device)) + with use_lora_slot(LoRASlotRef("lora", "missing")): + assert torch.equal(lora(x), torch.zeros(7, 5, device=device)) + + slot_a = lora._slot(ref_a) + assert slot_a is not None + with use_lora_slot(ref_a): + actual = lora(x) + expected = (x @ slot_a.A_T) @ slot_a.B_T * slot_a.scale + assert torch.allclose(actual, expected, atol=0, rtol=0) + assert slot_a.rank == 1 + assert slot_a.scale == 32.0 + assert lora._slot(ref_b).scale == 8.0 # type: ignore[union-attr] + + trainer = _trainer_for(lora, device) + with trainer.push_checkpoint("A"): + assert trainer._slot_stack[-1] == ref_a + with trainer.push_lora(None): + assert trainer._slot_stack[-1].name is None + assert trainer._slot_stack[-1] == ref_a + assert trainer._slot_stack == [] + + from megatron.core.tensor_parallel.random import ( + checkpoint as megatron_checkpoint, + ) + from torch.utils.checkpoint import checkpoint as torch_checkpoint + + _assert_checkpoint_recomputes_with(ref_a, ref_b, lora, torch_checkpoint) + _assert_checkpoint_recomputes_with( + ref_a, ref_b, lora, megatron_checkpoint, False + ) + _assert_step_updates_only(ref_a, ref_b, lora, trainer) + _assert_reload_replaces_slot_optimizer(ref_a, lora, trainer) + + +def _adapter(prefix: str, *, rank: int, seed: int) -> dict[str, torch.Tensor]: + device = torch.device("cuda") + generator = torch.Generator(device=device).manual_seed(seed) + return { + f"{prefix}.lora_A.weight": torch.randn( + rank, 4, generator=generator, device=device + ), + f"{prefix}.lora_B.weight": torch.randn( + 5, rank, generator=generator, device=device + ), + } + + +def _assert_checkpoint_recomputes_with( + expected_ref: LoRASlotRef, + ambient_ref: LoRASlotRef, + lora: LoRA, + checkpoint, + *checkpoint_args, +) -> None: + for param in lora.parameters(): + param.grad = None + x = torch.randn(3, 4, device="cuda", requires_grad=True) + with use_lora_slot(expected_ref): + y = checkpoint(lambda t: lora(t), *checkpoint_args, x) + with use_lora_slot(ambient_ref): + y.sum().backward() + assert lora._slot(expected_ref).A_T.grad is not None # type: ignore[union-attr] + assert lora._slot(ambient_ref).A_T.grad is None # type: ignore[union-attr] + + +def _assert_step_updates_only( + stepped_ref: LoRASlotRef, + frozen_ref: LoRASlotRef, + lora: LoRA, + trainer: TrainerRank, +) -> None: + for param in lora.parameters(): + param.grad = None + with use_lora_slot(stepped_ref): + lora(torch.randn(5, 4, device="cuda")).sum().backward() + before_stepped = [p.detach().clone() for p in lora.lora_slot_params(stepped_ref)] + before_frozen = [p.detach().clone() for p in lora.lora_slot_params(frozen_ref)] + trainer.optim_step( + params=AdamParams(learning_rate=1e-3, weight_decay=0.0, grad_clip_norm=1.0), + checkpoints=[stepped_ref.name or ""], + ) + assert any( + not torch.equal(before, after) + for before, after in zip( + before_stepped, lora.lora_slot_params(stepped_ref), strict=True + ) + ) + assert all( + torch.equal(before, after) + for before, after in zip( + before_frozen, lora.lora_slot_params(frozen_ref), strict=True + ) + ) + + +def _assert_reload_replaces_slot_optimizer( + ref: LoRASlotRef, + lora: LoRA, + trainer: TrainerRank, +) -> None: + assert ref.name is not None + old_params = trainer._checkpoint_slot_params_by_name[ref.name] + assert ref.name in trainer._dynamic_optimizers + + trainer.load_checkpoint_slot(ref.name, _adapter("dense", rank=3, seed=9)) + + new_params = trainer._checkpoint_slot_params_by_name[ref.name] + assert ref.name not in trainer._dynamic_optimizers + assert [tuple(param.shape) for param in new_params] == [(4, 3), (3, 5)] + assert all(old is not new for old, new in zip(old_params, new_params, strict=True)) + assert lora._slot(ref).rank == 3 # type: ignore[union-attr] + + +def _trainer_for(lora: LoRA, device: torch.device) -> TrainerRank: + trainer = TrainerRank.__new__(TrainerRank) + trainer.runtime = SimpleNamespace(model=[lora], optimizer=None) + trainer.device = device + trainer._slot_stack = [] + trainer._default_slot_ref = None + trainer._dynamic_optimizers = {} + trainer._checkpoint_slot_params_by_name = { + "A": tuple(lora.lora_slot_params(LoRASlotRef("checkpoint", "A"))), + "B": tuple(lora.lora_slot_params(LoRASlotRef("checkpoint", "B"))), + } + return trainer + + +@contextmanager +def _single_rank_model_parallel(): + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ["MASTER_PORT"] = str(_free_port()) + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + os.environ.setdefault("LOCAL_RANK", "0") + torch.cuda.set_device(0) + init_process_group("nccl", rank=0, world_size=1) + try: + ps.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=1, + expert_model_parallel_size=1, + ) + yield + finally: + if getattr(ps, "model_parallel_is_initialized", lambda: False)(): + ps.destroy_model_parallel() + destroy_process_group() + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) diff --git a/tests/unit/test_trainer_rank_validation.py b/tests/unit/test_trainer_rank_validation.py new file mode 100644 index 000000000..a4bcae03e --- /dev/null +++ b/tests/unit/test_trainer_rank_validation.py @@ -0,0 +1,448 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +import torch + +from art.trainer_rank import ( + ForwardInput, + ForwardOutput, + TopK, + TrainerRank, + TrainerRankMemoryError, + Unset, + _anchor_disconnected_outputs, + _MemoryCheck, + _MemoryProfile, + _validate_top_k, +) + + +class _Model: + vocab_size = 8 + + +def _runtime(model: torch.nn.Module | None = None) -> SimpleNamespace: + return SimpleNamespace( + model=[model or torch.nn.Linear(1, 1)], + optimizer=None, + provider=SimpleNamespace(hidden_size=4, num_layers=1), + model_support_handler=SimpleNamespace(build_gdn_execution_spec=True), + ) + + +def _target_request(token: int) -> ForwardInput[torch.Tensor, None, None, None]: + tokens = torch.tensor([token, token + 1], dtype=torch.long) + return ForwardInput(input_tokens=tokens, target_tokens=tokens) + + +def test_forward_input_rejects_non_positive_top_k() -> None: + with pytest.raises(ValueError, match="top_k must be >= 1"): + ForwardInput(input_tokens=torch.tensor([1]), top_k=0) + + +def test_forward_input_adapter_selection_defaults_to_unset() -> None: + request = ForwardInput(input_tokens=torch.tensor([1])) + + assert request.checkpoint is Unset + assert request.lora is Unset + + +def test_forward_input_accepts_explicit_base_checkpoint() -> None: + request = ForwardInput(input_tokens=torch.tensor([1]), checkpoint=None) + + assert request.checkpoint is None + assert request.lora is Unset + + +def test_forward_input_rejects_checkpoint_and_lora_together() -> None: + with pytest.raises(ValueError, match="cannot set both checkpoint and lora"): + ForwardInput(input_tokens=torch.tensor([1]), checkpoint="a", lora="b") + + +def test_validate_top_k_rejects_values_above_vocab_size() -> None: + with pytest.raises(ValueError, match="top_k=9 exceeds vocabulary size 8"): + _validate_top_k(9, _Model()) # type: ignore[arg-type] + + +def test_trainer_rank_accepts_nested_shared_prefix_for_gdn_runtime() -> None: + trainer = TrainerRank(_runtime(), shared_prefix_max_depth=2) # type: ignore[arg-type] + + assert trainer.shared_prefix_max_depth == 2 + + +def test_trainer_rank_accepts_zero_depth_shared_prefix_for_gdn_runtime() -> None: + trainer = TrainerRank(_runtime(), shared_prefix_max_depth=0) # type: ignore[arg-type] + + assert trainer.shared_prefix_max_depth == 0 + + +def test_trainer_rank_pop_rejects_empty_adapter_stack() -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + + with pytest.raises(RuntimeError, match="No pushed LoRA or checkpoint"): + trainer.pop_pushed_lora_or_checkpoint() + + +def test_dp_rank_forward_preserves_nested_shape_for_inactive_requests() -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + request_a = ForwardInput(input_tokens=torch.tensor([1])) + request_b = ForwardInput(input_tokens=torch.tensor([2])) + + outputs = trainer.dp_rank_forward([[request_a], [request_b]]) + + assert len(outputs) == 2 + assert len(outputs[0]) == 1 + assert outputs[0][0].target_logprobs is None + assert outputs[1][0].target_logprobs is None + assert not hasattr(trainer, "forward") + assert not hasattr(trainer, "micro_batches") + + +def test_forward_micro_batches_uses_deterministic_dp_windows( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (1, 2)) + monkeypatch.setattr( + trainer, + "_run_flat_plan_with_memory_tracking", + lambda plan, **_kwargs: [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ], + ) + + batches = list( + trainer.forward_micro_batches([_target_request(i) for i in range(5)]) + ) + + assert [batch.indices for batch in batches] == [(1,), (3,), ()] + assert [len(batch.outputs) for batch in batches] == [1, 1, 0] + + +def test_forward_micro_batches_outputs_match_top_level_nested_inputs( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr( + trainer, + "_run_flat_plan_with_memory_tracking", + lambda plan, **_kwargs: [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ], + ) + + nested = [[_target_request(1), _target_request(3)]] + batch = next(iter(trainer.forward_micro_batches(nested))) + + assert batch.inputs == nested + assert len(batch.outputs) == 1 + assert len(batch.outputs[0]) == 2 + + +def test_forward_micro_batches_ramps_after_first_success( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (0, 1)) + + def run(plan, **_kwargs): + trainer._memory_profiles[plan.signature] = _MemoryProfile( + bytes_per_token=0.0, + packed_tokens=plan.packed_tokens, + ) + return [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ] + + monkeypatch.setattr(trainer, "_run_flat_plan_with_memory_tracking", run) + + batches = list( + trainer.forward_micro_batches([_target_request(i) for i in range(8)]) + ) + + assert batches[0].stats.global_count == 1 + assert batches[0].stats.cold_start + assert batches[1].stats.global_count > 1 + assert not batches[1].stats.cold_start + + +def test_forward_micro_batches_does_not_overtrust_tiny_memory_profile( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (0, 1)) + inputs = [_target_request(i) for i in range(64)] + tiny_plan = trainer._plan_flat_forward([inputs[0]]) + trainer._memory_profiles[tiny_plan.signature] = _MemoryProfile( + bytes_per_token=0.0, + packed_tokens=tiny_plan.packed_tokens, + ) + + candidate = trainer._select_next_micro_batch(inputs, 0) + + assert candidate.stats_global_count == 8 + assert candidate.plan.packed_tokens == 16 + + +def test_forward_micro_batches_shrinks_to_largest_fitting_window( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + trainer._last_global_micro_batch_size = 4 + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr( + trainer, "_all_ranks_have_memory_profile", lambda **_kwargs: True + ) + + def required_memory(**kwargs): + return kwargs["packed_tokens"] + + def memory_check(required): + return _MemoryCheck( + estimated_required_bytes=required, + available_bytes=6, + fits=required <= 6, + ) + + monkeypatch.setattr( + trainer, "_estimate_required_memory_bytes_from_values", required_memory + ) + monkeypatch.setattr(trainer, "_memory_check_required", memory_check) + monkeypatch.setattr( + trainer, + "_run_flat_plan_with_memory_tracking", + lambda plan, **_kwargs: [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ], + ) + + batch = next( + iter(trainer.forward_micro_batches([_target_request(i) for i in range(8)])) + ) + + assert batch.stats.global_count == 3 + assert batch.stats.rejected_candidates >= 1 + + +def test_forward_micro_batches_tail_does_not_reset_stable_window( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + trainer._last_global_micro_batch_size = 64 + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr( + trainer, "_all_ranks_have_memory_profile", lambda **_kwargs: True + ) + monkeypatch.setattr( + trainer, + "_estimate_required_memory_bytes_from_values", + lambda **kwargs: kwargs["packed_tokens"], + ) + monkeypatch.setattr( + trainer, + "_memory_check_required", + lambda required: _MemoryCheck( + estimated_required_bytes=required, + available_bytes=128, + fits=required <= 128, + ), + ) + monkeypatch.setattr( + trainer, + "_run_flat_plan_with_memory_tracking", + lambda plan, **_kwargs: [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ], + ) + + batches = list( + trainer.forward_micro_batches([_target_request(i) for i in range(130)]) + ) + + assert [batch.stats.global_count for batch in batches] == [64, 64, 2] + assert trainer._last_global_micro_batch_size == 64 + + +def test_forward_micro_batches_grows_small_stable_window_when_work_remains( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + trainer._last_global_micro_batch_size = 64 + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr( + trainer, "_all_ranks_have_memory_profile", lambda **_kwargs: True + ) + monkeypatch.setattr( + trainer, + "_estimate_required_memory_bytes_from_values", + lambda **kwargs: kwargs["packed_tokens"], + ) + monkeypatch.setattr( + trainer, + "_memory_check_required", + lambda required: _MemoryCheck( + estimated_required_bytes=required, + available_bytes=512, + fits=required <= 512, + ), + ) + + candidate = trainer._select_next_micro_batch( + [_target_request(i) for i in range(512)], + 0, + ) + + assert candidate.stats_global_count == 256 + + +def test_forward_micro_batches_reuses_cached_candidate_plans( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr( + trainer, "_all_ranks_have_memory_profile", lambda **_kwargs: True + ) + monkeypatch.setattr( + trainer, + "_run_flat_plan_with_memory_tracking", + lambda plan, **_kwargs: [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ], + ) + original_plan = trainer._plan_flat_forward + plan_calls = 0 + memory_checks = 0 + + def plan(requests): + nonlocal plan_calls + plan_calls += 1 + return original_plan(requests) + + def memory_check(plan): + nonlocal memory_checks + memory_checks += 1 + return _MemoryCheck( + estimated_required_bytes=plan.packed_tokens, + available_bytes=10, + fits=True, + ) + + monkeypatch.setattr(trainer, "_plan_flat_forward", plan) + monkeypatch.setattr(trainer, "_memory_check", memory_check) + inputs = [_target_request(i) for i in range(8)] + + list(trainer.forward_micro_batches(inputs)) + first_plan_calls = plan_calls + first_memory_checks = memory_checks + list(trainer.forward_micro_batches(inputs)) + + assert first_plan_calls > 0 + assert first_plan_calls == 1 + assert plan_calls == first_plan_calls + assert memory_checks == first_memory_checks == 0 + + +def test_forward_micro_batches_raises_when_smallest_batch_will_not_fit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr( + trainer, + "_estimate_required_memory_bytes_from_values", + lambda **_kwargs: 4, + ) + monkeypatch.setattr( + trainer, + "_memory_check_required", + lambda required: _MemoryCheck( + estimated_required_bytes=required, + available_bytes=3, + fits=False, + ), + ) + with pytest.raises(TrainerRankMemoryError, match="smallest DP microbatch"): + next(iter(trainer.forward_micro_batches([_target_request(1)]))) + + +def test_forward_micro_batches_rejects_mismatched_replicated_counts( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + import art.trainer_rank as trainer_rank + + monkeypatch.setattr(trainer_rank.dist, "is_available", lambda: True) + monkeypatch.setattr(trainer_rank.dist, "is_initialized", lambda: True) + monkeypatch.setattr(trainer_rank.dist, "get_world_size", lambda: 2) + + def gather(output, value): + output[:] = [value, value + 1] + + monkeypatch.setattr(trainer_rank.dist, "all_gather_object", gather) + + with pytest.raises(ValueError, match="same top-level input count"): + list(trainer.forward_micro_batches([_target_request(1)])) + + +def test_forward_plan_estimates_output_memory_for_request_combo() -> None: + class FakeGPT(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.zeros(())) + self.config = SimpleNamespace( + hidden_size=4, + num_layers=1, + padded_vocab_size=10, + ) + self.decoder = object() + + def _preprocess(self, *args: object, **kwargs: object) -> None: + return None + + trainer = TrainerRank(_runtime(FakeGPT())) # type: ignore[arg-type] + tokens = torch.tensor([1, 2, 3], dtype=torch.long) + labels = torch.stack((tokens, tokens + 1), dim=1) + + plan = trainer._plan_flat_forward( + [ + ForwardInput( + input_tokens=tokens, + target_tokens=labels, + top_k=5, + logits=True, + hidden_states=True, + ) + ] + ) + + target_bytes = 3 * 2 * 4 + topk_bytes = 3 * 5 * (4 + 8) + logits_bytes = 3 * 10 * 4 + hidden_bytes = 3 * 4 * 4 + assert plan.output_bytes == target_bytes + topk_bytes + logits_bytes + hidden_bytes + + +def test_disconnected_outputs_keep_zero_graph_anchor() -> None: + hidden = torch.randn(2, 3, requires_grad=True) + disconnected = torch.zeros(4) + top_k = TopK(logprobs=torch.zeros(4, 2), tokens=torch.ones(4, 2, dtype=torch.long)) + + (anchored,), (anchored_top_k,) = _anchor_disconnected_outputs( + [disconnected], + [top_k], + hidden, + ) + + assert anchored is not None + assert anchored.requires_grad + assert anchored_top_k is not None + assert anchored_top_k.logprobs.requires_grad + torch.testing.assert_close(anchored, disconnected) + torch.testing.assert_close(anchored_top_k.logprobs, top_k.logprobs) + (anchored.sum() + anchored_top_k.logprobs.sum()).backward() + assert hidden.grad is not None + torch.testing.assert_close(hidden.grad, torch.zeros_like(hidden)) diff --git a/tests/unit/test_trainer_rank_weird_shapes.py b/tests/unit/test_trainer_rank_weird_shapes.py new file mode 100644 index 000000000..02831c1c0 --- /dev/null +++ b/tests/unit/test_trainer_rank_weird_shapes.py @@ -0,0 +1,501 @@ +from __future__ import annotations + +from collections.abc import Iterable +from types import SimpleNamespace + +import pytest +import torch + +from art.megatron.shared_prefix_packing import ( + estimate_shared_prefix_packed_tokens, + pack_shared_prefixes, +) +from art.trainer_rank import ( + ForwardInput, + ForwardOutput, + TopK, + TrainerRank, + TrainerRankMemoryError, + Unset, + _flatten, + _MemoryCheck, +) + + +class _FakeGPT(torch.nn.Module): + def __init__(self, *, hidden_size: int = 8, vocab_size: int = 32) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.zeros((), dtype=torch.float16)) + self.config = SimpleNamespace( + hidden_size=hidden_size, + num_layers=4, + padded_vocab_size=vocab_size, + ) + self.decoder = object() + + def _preprocess(self, *args: object, **kwargs: object) -> None: + return None + + +def _runtime() -> SimpleNamespace: + return SimpleNamespace( + model=[_FakeGPT()], + optimizer=None, + provider=SimpleNamespace(hidden_size=8, num_layers=4), + model_support_handler=SimpleNamespace(build_gdn_execution_spec=True), + ) + + +def _tokens(*values: int) -> torch.Tensor: + return torch.tensor(values, dtype=torch.long) + + +def _target_request( + tokens: torch.Tensor, + *, + target_count: int = 1, + top_k: int | None = None, + logits: bool = False, + hidden_states: bool = False, + checkpoint: object = Unset, + lora: object = Unset, +) -> ForwardInput: + labels = ( + tokens + if target_count == 1 + else torch.stack( + tuple(tokens + offset for offset in range(target_count)), + dim=-1, + ) + ) + return ForwardInput( + input_tokens=tokens, + target_tokens=labels, + top_k=top_k, + logits=logits, + hidden_states=hidden_states, + checkpoint=checkpoint, # type: ignore[arg-type] + lora=lora, # type: ignore[arg-type] + ) + + +def _ternary_tree_sequences() -> tuple[torch.Tensor, ...]: + # Shape: shared root, two continuation branches, and terminal nodes at + # several depths. This mirrors prompt -> continuation A/B -> terminal data. + root = [10, 11, 12] + left = root + [20, 21] + right = root + [30, 31, 32] + return ( + _tokens(*(root + [1])), + _tokens(*(left + [2])), + _tokens(*(left + [3, 4])), + _tokens(*(right + [5])), + _tokens(*(right + [6, 7])), + _tokens(80, 81), + ) + + +def _vineppo_like_inputs() -> list[list[ForwardInput]]: + groups: list[list[ForwardInput]] = [] + for prompt_index in range(4): + prompt = [100 + prompt_index, 200 + prompt_index, 201 + prompt_index] + trajectories = [] + for branch_index, completion_len in enumerate((1, 2, 4)): + completion = [300 + branch_index] * completion_len + tokens = _tokens(*(prompt + completion)) + trajectories.append( + _target_request( + tokens, + target_count=2 if branch_index == 2 else 1, + top_k=5 if branch_index == 1 else None, + hidden_states=branch_index == 0, + ) + ) + groups.append(trajectories) + return groups + + +def _random_tree_sequences(seed: int, *, max_depth: int) -> tuple[torch.Tensor, ...]: + generator = torch.Generator().manual_seed(seed) + out: list[torch.Tensor] = [] + + def randint(low: int, high: int) -> int: + return int(torch.randint(low, high + 1, (), generator=generator).item()) + + def segment(depth: int) -> list[int]: + return [depth * 100 + randint(1, 40) for _ in range(randint(1, 4))] + + def walk(prefix: list[int], depth: int) -> None: + if depth >= max_depth or randint(0, 2) == 0: + out.append(_tokens(*(prefix + segment(depth)))) + return + shared = prefix + segment(depth) + out.append(_tokens(*shared)) + walk(shared + [10 + depth], depth + 1) + walk(shared + [20 + depth], depth + 1) + + walk([], 0) + return tuple(out) + + +@pytest.mark.parametrize("max_depth", (0, 1, 2, 4)) +def test_pack_estimator_matches_ternary_and_random_trees(max_depth: int) -> None: + cases = [ + _ternary_tree_sequences(), + _random_tree_sequences(3, max_depth=4), + _random_tree_sequences(99, max_depth=5), + ] + + for sequences in cases: + pack = pack_shared_prefixes(sequences, max_depth=max_depth) + + assert estimate_shared_prefix_packed_tokens( + sequences, max_depth=max_depth + ) == int(pack.tokens.numel()) + for sequence, positions in zip( + sequences, pack.positions_by_sequence, strict=True + ): + torch.testing.assert_close(pack.tokens.reshape(-1)[positions], sequence) + + +def test_planner_handles_vineppo_nested_shape_and_request_mix() -> None: + rank = TrainerRank(_runtime(), shared_prefix_max_depth=3) # type: ignore[arg-type] + inputs = _vineppo_like_inputs() + flat = list(_flatten(inputs)) + + plan = rank._plan_flat_forward(flat) + estimate = rank._estimate_flat_forward(flat) + + assert estimate is not None + packed_tokens, output_bytes, signature = estimate + assert packed_tokens == plan.packed_tokens + assert output_bytes == plan.output_bytes + assert signature == plan.signature + assert plan.request_count == 12 + assert plan.signature.request_mix == ( + "target:(2,)", + "target:single+hidden", + "target:single+topk:5", + ) + + +def test_forward_micro_batches_preserves_nested_vineppo_groups( + monkeypatch: pytest.MonkeyPatch, +) -> None: + rank = TrainerRank(_runtime(), shared_prefix_max_depth=2) # type: ignore[arg-type] + monkeypatch.setattr(rank, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr(rank, "_all_ranks_have_memory_profile", lambda **_kwargs: True) + monkeypatch.setattr( + rank, + "_memory_check", + lambda plan: _MemoryCheck(plan.packed_tokens, 10_000, True), + ) + monkeypatch.setattr( + rank, + "_run_flat_plan_with_memory_tracking", + lambda plan, **_kwargs: [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ], + ) + groups = _vineppo_like_inputs() + + micro_batches = list(rank.forward_micro_batches(groups)) + + assert [batch.indices for batch in micro_batches] == [(0, 1, 2, 3)] + assert micro_batches[0].select(groups) == groups + assert len(micro_batches[0].outputs) == 4 + assert all( + isinstance(group_outputs, list) and len(group_outputs) == 3 + for group_outputs in micro_batches[0].outputs + ) + + +def test_adaptive_planner_materializes_only_final_large_candidate( + monkeypatch: pytest.MonkeyPatch, +) -> None: + rank = TrainerRank(_runtime(), shared_prefix_max_depth=3) # type: ignore[arg-type] + rank._last_global_micro_batch_size = 32 + monkeypatch.setattr(rank, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr(rank, "_all_ranks_have_memory_profile", lambda **_kwargs: True) + plan_calls = 0 + estimate_calls = 0 + original_plan = rank._plan_flat_forward + original_estimate = rank._estimate_flat_forward + inputs = [ + _target_request( + _tokens(1, 2, 3, index % 7, index), + target_count=2 if index % 5 == 0 else 1, + top_k=3 if index % 4 == 0 else None, + hidden_states=index % 9 == 0, + ) + for index in range(96) + ] + limit = rank._estimate_flat_forward(inputs[:40]) + assert limit is not None + limit_packed_tokens = limit[0] + + def plan(requests): + nonlocal plan_calls + plan_calls += 1 + return original_plan(requests) + + def estimate(requests): + nonlocal estimate_calls + estimate_calls += 1 + return original_estimate(requests) + + def required_memory(**kwargs): + return kwargs["packed_tokens"] + + def check(required): + return _MemoryCheck( + estimated_required_bytes=required, + available_bytes=limit_packed_tokens, + fits=required <= limit_packed_tokens, + ) + + monkeypatch.setattr(rank, "_plan_flat_forward", plan) + monkeypatch.setattr(rank, "_estimate_flat_forward", estimate) + monkeypatch.setattr( + rank, "_estimate_required_memory_bytes_from_values", required_memory + ) + monkeypatch.setattr(rank, "_memory_check_required", check) + + candidate = rank._select_next_micro_batch(inputs, 0) + + assert candidate.stats_global_count == 40 + assert plan_calls == 1 + assert estimate_calls <= 10 + assert candidate.rejected_candidates <= 8 + + +def test_adaptive_planner_reuses_large_stable_window( + monkeypatch: pytest.MonkeyPatch, +) -> None: + rank = TrainerRank(_runtime(), shared_prefix_max_depth=1) # type: ignore[arg-type] + rank._last_global_micro_batch_size = 512 + monkeypatch.setattr(rank, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr(rank, "_all_ranks_have_memory_profile", lambda **_kwargs: True) + monkeypatch.setattr( + rank, + "_estimate_required_memory_bytes_from_values", + lambda **kwargs: kwargs["packed_tokens"], + ) + monkeypatch.setattr( + rank, + "_memory_check_required", + lambda required: _MemoryCheck( + estimated_required_bytes=required, + available_bytes=700, + fits=required <= 700, + ), + ) + + candidate = rank._select_next_micro_batch( + [_target_request(_tokens(index)) for index in range(900)], + 0, + ) + + assert candidate.stats_global_count == 512 + assert candidate.rejected_candidates == 0 + + +def test_forward_micro_batches_shrinks_when_memory_budget_drops( + monkeypatch: pytest.MonkeyPatch, +) -> None: + rank = TrainerRank(_runtime(), shared_prefix_max_depth=2) # type: ignore[arg-type] + monkeypatch.setattr(rank, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr(rank, "_all_ranks_have_memory_profile", lambda **_kwargs: True) + inputs = [_target_request(_tokens(1, 2, 3, index)) for index in range(14)] + first_limit = rank._estimate_flat_forward(inputs[:8]) + tail_limit = rank._estimate_flat_forward(inputs[8:11]) + assert first_limit is not None + assert tail_limit is not None + first_limit_packed_tokens = first_limit[0] + tail_limit_packed_tokens = tail_limit[0] + available = {"packed_tokens": first_limit_packed_tokens} + plan_calls = 0 + original_plan = rank._plan_flat_forward + + def plan(requests): + nonlocal plan_calls + plan_calls += 1 + return original_plan(requests) + + def required_memory(**kwargs): + return kwargs["packed_tokens"] + + def check(required): + limit = available["packed_tokens"] + return _MemoryCheck( + estimated_required_bytes=required, + available_bytes=limit, + fits=required <= limit, + ) + + def run(plan, **_kwargs): + if available["packed_tokens"] == first_limit_packed_tokens: + available["packed_tokens"] = tail_limit_packed_tokens + return [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ] + + monkeypatch.setattr(rank, "_plan_flat_forward", plan) + monkeypatch.setattr( + rank, "_estimate_required_memory_bytes_from_values", required_memory + ) + monkeypatch.setattr(rank, "_memory_check_required", check) + monkeypatch.setattr(rank, "_run_flat_plan_with_memory_tracking", run) + + batches = list(rank.forward_micro_batches(inputs)) + + assert [batch.stats.global_count for batch in batches] == [8, 3, 3] + assert [batch.stats.available_bytes for batch in batches] == [ + first_limit_packed_tokens, + tail_limit_packed_tokens, + tail_limit_packed_tokens, + ] + assert [batch.indices for batch in batches] == [ + tuple(range(8)), + (8, 9, 10), + (11, 12, 13), + ] + assert plan_calls == len(batches) + + +def test_heterogeneous_slots_split_packing_without_losing_output_estimates( + monkeypatch: pytest.MonkeyPatch, +) -> None: + rank = TrainerRank(_runtime(), shared_prefix_max_depth=4) # type: ignore[arg-type] + monkeypatch.setattr( + TrainerRank, + "_slot_ref", + staticmethod(lambda kind, name: (kind, name)), + ) + rank.set_checkpoint("student") + requests = [ + _target_request(_tokens(1, 2, 3), top_k=3), + _target_request(_tokens(1, 2, 4), checkpoint=None, logits=True), + _target_request(_tokens(1, 2, 5), lora="teacher", hidden_states=True), + _target_request(_tokens(1, 2, 6), checkpoint="critic", target_count=4), + ] + + plan = rank._plan_flat_forward(requests) + estimate = rank._estimate_flat_forward(requests) + + assert estimate is not None + packed_tokens, output_bytes, signature = estimate + assert packed_tokens == plan.packed_tokens + assert output_bytes == plan.output_bytes + assert signature == plan.signature + assert plan.signature.slot_group_count == 4 + assert {group.slot_ref for group in plan.groups} == { + ("checkpoint", "student"), + ("checkpoint", None), + ("lora", "teacher"), + ("checkpoint", "critic"), + } + + +def test_dp_uneven_tail_yields_empty_rank_batch( + monkeypatch: pytest.MonkeyPatch, +) -> None: + rank = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(rank, "_dp_rank_and_size", lambda: (3, 4)) + monkeypatch.setattr( + rank, + "_run_flat_plan_with_memory_tracking", + lambda plan, **_kwargs: [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ], + ) + + batches = list( + rank.forward_micro_batches( + [_target_request(_tokens(i, i + 1)) for i in range(5)] + ) + ) + + assert [batch.indices for batch in batches] == [(3,), ()] + assert [batch.stats.local_count for batch in batches] == [1, 0] + + +def test_dp_rank_forward_raises_before_expected_oom( + monkeypatch: pytest.MonkeyPatch, +) -> None: + rank = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr( + rank, + "_memory_check", + lambda plan: _MemoryCheck( + estimated_required_bytes=plan.output_bytes + 1, + available_bytes=plan.output_bytes, + fits=False, + ), + ) + + with pytest.raises(TrainerRankMemoryError, match="dp_rank_forward"): + rank.dp_rank_forward( + [_target_request(_tokens(1, 2, 3), logits=True, hidden_states=True)] + ) + + +def test_memory_error_includes_actionable_shape_context( + monkeypatch: pytest.MonkeyPatch, +) -> None: + rank = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(rank, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr( + rank, + "_estimate_required_memory_bytes_from_values", + lambda **_kwargs: 99, + ) + monkeypatch.setattr( + rank, + "_memory_check_required", + lambda required: _MemoryCheck(required, 1, False), + ) + + with pytest.raises(TrainerRankMemoryError) as exc_info: + next( + iter( + rank.forward_micro_batches( + [_target_request(_tokens(1, 2, 3), logits=True)] + ) + ) + ) + + message = str(exc_info.value) + assert "packed_tokens=" in message + assert "logical_tokens=" in message + assert "output_gb=" in message + assert "Use smaller top-level items" in message + + +def test_topk_output_memory_scales_with_requested_k() -> None: + rank = TrainerRank(_runtime()) # type: ignore[arg-type] + tokens = _tokens(1, 2, 3, 4) + + small = rank._plan_flat_forward([_target_request(tokens, top_k=1)]) + large = rank._plan_flat_forward([_target_request(tokens, top_k=7)]) + + assert large.output_bytes - small.output_bytes == 4 * 6 * (4 + 8) + + +def test_flatten_rejects_dicts_to_avoid_silent_top_level_shape_changes() -> None: + with pytest.raises(TypeError, match="dict was passed directly"): + list(_flatten({"bad": _target_request(_tokens(1, 2))})) # type: ignore[arg-type] + + +def test_no_output_requests_do_not_pack_or_consume_compute_memory() -> None: + rank = TrainerRank(_runtime()) # type: ignore[arg-type] + requests: Iterable[ForwardInput] = [ + ForwardInput(input_tokens=_tokens(1, 2, 3)), + ForwardInput(input_tokens=_tokens(1, 2, 4)), + ] + + plan = rank._plan_flat_forward(list(requests)) + + assert plan.groups == () + assert plan.packed_tokens == 0 + assert rank._memory_check(plan).estimated_required_bytes == 0 From dc7116ec1a6d60ff968b569a8cbd0d710646f8db Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Thu, 25 Jun 2026 20:08:46 -0600 Subject: [PATCH 02/11] fix trainer rank dp microbatch memory sync --- src/art/trainer_rank/__init__.py | 19 ++++++--- tests/unit/test_trainer_rank_validation.py | 45 +++++++++++++++++--- tests/unit/test_trainer_rank_weird_shapes.py | 14 +++--- 3 files changed, 62 insertions(+), 16 deletions(-) diff --git a/src/art/trainer_rank/__init__.py b/src/art/trainer_rank/__init__.py index a3fe29bad..8cc93e621 100644 --- a/src/art/trainer_rank/__init__.py +++ b/src/art/trainer_rank/__init__.py @@ -745,7 +745,7 @@ def candidate( inputs=local_inputs, indices=indices, plan=plan, - check=estimated_check or self._memory_check(plan), + check=estimated_check or self._memory_check(plan, sync_across_dp=True), stats_global_count=width, rejected_candidates=rejected, cold_start=not self._all_ranks_have_memory_profile( @@ -873,7 +873,8 @@ def _cached_adaptive_estimate( packed_tokens=packed_tokens, output_bytes=output_bytes, signature=signature, - ) + ), + sync_across_dp=True, ), self._all_ranks_have_memory_profile( packed_tokens=packed_tokens, @@ -1093,19 +1094,27 @@ def _topology_key(self) -> tuple[int, int, int, int]: def _memory_check( self, forward: _FlatForwardPlan, + *, + sync_across_dp: bool = False, ) -> _MemoryCheck: return self._memory_check_required( self._estimate_required_memory_bytes_from_values( packed_tokens=forward.packed_tokens, output_bytes=forward.output_bytes, signature=forward.signature, - ) + ), + sync_across_dp=sync_across_dp, ) - def _memory_check_required(self, required: int) -> _MemoryCheck: + def _memory_check_required( + self, + required: int, + *, + sync_across_dp: bool = False, + ) -> _MemoryCheck: available = self._available_memory_bytes() if dist.is_available() and dist.is_initialized(): - group = self._forward_memory_group() + group = None if sync_across_dp else self._forward_memory_group() values = torch.tensor( [float(required), float(available)], device=self.device if self.device.type == "cuda" else "cpu", diff --git a/tests/unit/test_trainer_rank_validation.py b/tests/unit/test_trainer_rank_validation.py index a4bcae03e..2f9dc6641 100644 --- a/tests/unit/test_trainer_rank_validation.py +++ b/tests/unit/test_trainer_rank_validation.py @@ -121,6 +121,39 @@ def test_forward_micro_batches_uses_deterministic_dp_windows( assert [len(batch.outputs) for batch in batches] == [1, 1, 0] +def test_forward_micro_batches_syncs_fit_decision_across_dp( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (1, 2)) + monkeypatch.setattr( + trainer, "_all_ranks_have_memory_profile", lambda **_kwargs: True + ) + sync_flags: list[bool] = [] + + def memory_check(required: int, *, sync_across_dp: bool = False) -> _MemoryCheck: + sync_flags.append(sync_across_dp) + return _MemoryCheck( + estimated_required_bytes=required, + available_bytes=1 << 30, + fits=True, + ) + + monkeypatch.setattr(trainer, "_memory_check_required", memory_check) + monkeypatch.setattr( + trainer, + "_run_flat_plan_with_memory_tracking", + lambda plan, **_kwargs: [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ], + ) + + next(iter(trainer.forward_micro_batches([_target_request(i) for i in range(6)]))) + + assert sync_flags + assert all(sync_flags) + + def test_forward_micro_batches_outputs_match_top_level_nested_inputs( monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -200,7 +233,8 @@ def test_forward_micro_batches_shrinks_to_largest_fitting_window( def required_memory(**kwargs): return kwargs["packed_tokens"] - def memory_check(required): + def memory_check(required, *, sync_across_dp=False): + assert sync_across_dp return _MemoryCheck( estimated_required_bytes=required, available_bytes=6, @@ -244,7 +278,7 @@ def test_forward_micro_batches_tail_does_not_reset_stable_window( monkeypatch.setattr( trainer, "_memory_check_required", - lambda required: _MemoryCheck( + lambda required, *, sync_across_dp=False: _MemoryCheck( estimated_required_bytes=required, available_bytes=128, fits=required <= 128, @@ -283,7 +317,7 @@ def test_forward_micro_batches_grows_small_stable_window_when_work_remains( monkeypatch.setattr( trainer, "_memory_check_required", - lambda required: _MemoryCheck( + lambda required, *, sync_across_dp=False: _MemoryCheck( estimated_required_bytes=required, available_bytes=512, fits=required <= 512, @@ -322,7 +356,8 @@ def plan(requests): plan_calls += 1 return original_plan(requests) - def memory_check(plan): + def memory_check(plan, *, sync_across_dp=False): + assert sync_across_dp nonlocal memory_checks memory_checks += 1 return _MemoryCheck( @@ -359,7 +394,7 @@ def test_forward_micro_batches_raises_when_smallest_batch_will_not_fit( monkeypatch.setattr( trainer, "_memory_check_required", - lambda required: _MemoryCheck( + lambda required, *, sync_across_dp=False: _MemoryCheck( estimated_required_bytes=required, available_bytes=3, fits=False, diff --git a/tests/unit/test_trainer_rank_weird_shapes.py b/tests/unit/test_trainer_rank_weird_shapes.py index 02831c1c0..541d2de9a 100644 --- a/tests/unit/test_trainer_rank_weird_shapes.py +++ b/tests/unit/test_trainer_rank_weird_shapes.py @@ -188,7 +188,9 @@ def test_forward_micro_batches_preserves_nested_vineppo_groups( monkeypatch.setattr( rank, "_memory_check", - lambda plan: _MemoryCheck(plan.packed_tokens, 10_000, True), + lambda plan, *, sync_across_dp=False: _MemoryCheck( + plan.packed_tokens, 10_000, True + ), ) monkeypatch.setattr( rank, @@ -247,7 +249,7 @@ def estimate(requests): def required_memory(**kwargs): return kwargs["packed_tokens"] - def check(required): + def check(required, *, sync_across_dp=False): return _MemoryCheck( estimated_required_bytes=required, available_bytes=limit_packed_tokens, @@ -284,7 +286,7 @@ def test_adaptive_planner_reuses_large_stable_window( monkeypatch.setattr( rank, "_memory_check_required", - lambda required: _MemoryCheck( + lambda required, *, sync_across_dp=False: _MemoryCheck( estimated_required_bytes=required, available_bytes=700, fits=required <= 700, @@ -325,7 +327,7 @@ def plan(requests): def required_memory(**kwargs): return kwargs["packed_tokens"] - def check(required): + def check(required, *, sync_across_dp=False): limit = available["packed_tokens"] return _MemoryCheck( estimated_required_bytes=required, @@ -427,7 +429,7 @@ def test_dp_rank_forward_raises_before_expected_oom( monkeypatch.setattr( rank, "_memory_check", - lambda plan: _MemoryCheck( + lambda plan, *, sync_across_dp=False: _MemoryCheck( estimated_required_bytes=plan.output_bytes + 1, available_bytes=plan.output_bytes, fits=False, @@ -453,7 +455,7 @@ def test_memory_error_includes_actionable_shape_context( monkeypatch.setattr( rank, "_memory_check_required", - lambda required: _MemoryCheck(required, 1, False), + lambda required, *, sync_across_dp=False: _MemoryCheck(required, 1, False), ) with pytest.raises(TrainerRankMemoryError) as exc_info: From f96f77df477278afa107d9e98d735023085307f2 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Thu, 25 Jun 2026 20:16:08 -0600 Subject: [PATCH 03/11] fix trainer rank adaptive memory estimate cache --- src/art/trainer_rank/__init__.py | 37 +++++++++---------- tests/unit/test_trainer_rank_validation.py | 42 ++++++++++++++++++++++ 2 files changed, 61 insertions(+), 18 deletions(-) diff --git a/src/art/trainer_rank/__init__.py b/src/art/trainer_rank/__init__.py index 8cc93e621..57ecc83e1 100644 --- a/src/art/trainer_rank/__init__.py +++ b/src/art/trainer_rank/__init__.py @@ -283,7 +283,7 @@ def __init__( self._adaptive_plan_cache: dict[_AdaptivePlanCacheKey, _FlatForwardPlan] = {} self._adaptive_plan_cache_top_level_ids: tuple[int, ...] = () self._adaptive_estimate_cache: dict[ - _AdaptivePlanCacheKey, tuple[_MemoryCheck, bool] | None + _AdaptivePlanCacheKey, tuple[int, int, _MemorySignature] | None ] = {} self._last_global_micro_batch_size: int | None = None self.zero_grad() @@ -863,26 +863,27 @@ def _cached_adaptive_estimate( ) -> tuple[_MemoryCheck, bool] | None: key = self._adaptive_cache_key(indices) if key in self._adaptive_estimate_cache: - return self._adaptive_estimate_cache[key] - estimate = self._estimate_flat_forward(list(_flatten(local_inputs))) - if estimate is not None: - packed_tokens, output_bytes, signature = estimate - estimate = ( - self._memory_check_required( - self._estimate_required_memory_bytes_from_values( - packed_tokens=packed_tokens, - output_bytes=output_bytes, - signature=signature, - ), - sync_across_dp=True, - ), - self._all_ranks_have_memory_profile( + estimate = self._adaptive_estimate_cache[key] + else: + estimate = self._estimate_flat_forward(list(_flatten(local_inputs))) + self._adaptive_estimate_cache[key] = estimate + if estimate is None: + return None + packed_tokens, output_bytes, signature = estimate + return ( + self._memory_check_required( + self._estimate_required_memory_bytes_from_values( packed_tokens=packed_tokens, + output_bytes=output_bytes, signature=signature, ), - ) - self._adaptive_estimate_cache[key] = estimate - return estimate + sync_across_dp=True, + ), + self._all_ranks_have_memory_profile( + packed_tokens=packed_tokens, + signature=signature, + ), + ) def _adaptive_cache_key( self, diff --git a/tests/unit/test_trainer_rank_validation.py b/tests/unit/test_trainer_rank_validation.py index 2f9dc6641..252ac85f6 100644 --- a/tests/unit/test_trainer_rank_validation.py +++ b/tests/unit/test_trainer_rank_validation.py @@ -381,6 +381,48 @@ def memory_check(plan, *, sync_across_dp=False): assert memory_checks == first_memory_checks == 0 +def test_cached_adaptive_estimate_rechecks_current_memory( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr( + trainer, + "_estimate_required_memory_bytes_from_values", + lambda **kwargs: kwargs["packed_tokens"], + ) + monkeypatch.setattr( + trainer, "_all_ranks_have_memory_profile", lambda **_kwargs: True + ) + original_estimate = trainer._estimate_flat_forward + estimate_calls = 0 + available = [1 << 30, 1] + + def estimate(requests): + nonlocal estimate_calls + estimate_calls += 1 + return original_estimate(requests) + + def memory_check(required: int, *, sync_across_dp: bool = False) -> _MemoryCheck: + assert sync_across_dp + current = available.pop(0) + return _MemoryCheck( + estimated_required_bytes=required, + available_bytes=current, + fits=required <= current, + ) + + monkeypatch.setattr(trainer, "_estimate_flat_forward", estimate) + monkeypatch.setattr(trainer, "_memory_check_required", memory_check) + inputs = [_target_request(1), _target_request(2)] + + first = trainer._cached_adaptive_estimate((0, 1), inputs) + second = trainer._cached_adaptive_estimate((0, 1), inputs) + + assert first is not None and first[0].fits + assert second is not None and not second[0].fits + assert estimate_calls == 1 + + def test_forward_micro_batches_raises_when_smallest_batch_will_not_fit( monkeypatch: pytest.MonkeyPatch, ) -> None: From 91c9dc7e62d84c002cbd2d7bd97d73eaee4dfa36 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Thu, 25 Jun 2026 20:21:58 -0600 Subject: [PATCH 04/11] simplify trainer rank adaptive plan cache scope --- src/art/trainer_rank/__init__.py | 8 +--- tests/unit/test_trainer_rank_validation.py | 43 ++++++++++++++++++---- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/src/art/trainer_rank/__init__.py b/src/art/trainer_rank/__init__.py index 57ecc83e1..4ffb176b2 100644 --- a/src/art/trainer_rank/__init__.py +++ b/src/art/trainer_rank/__init__.py @@ -281,7 +281,6 @@ def __init__( ] = {} self._memory_profiles: dict[_MemorySignature, _MemoryProfile] = {} self._adaptive_plan_cache: dict[_AdaptivePlanCacheKey, _FlatForwardPlan] = {} - self._adaptive_plan_cache_top_level_ids: tuple[int, ...] = () self._adaptive_estimate_cache: dict[ _AdaptivePlanCacheKey, tuple[int, int, _MemorySignature] | None ] = {} @@ -454,6 +453,8 @@ def forward_micro_batches( inputs: Iterable[ForwardInputsT], ) -> Iterator[MicroBatch[ForwardInputsT]]: items = list(inputs) + self._adaptive_plan_cache.clear() + self._adaptive_estimate_cache.clear() self._validate_replicated_top_level_count(len(items)) start = 0 while start < len(items): @@ -704,11 +705,6 @@ def _select_next_micro_batch( min_width = min(dp_size, remaining) if min_width <= 0: raise RuntimeError("cannot select an empty microbatch window") - top_level_ids = tuple(id(item) for item in items) - if top_level_ids != self._adaptive_plan_cache_top_level_ids: - self._adaptive_plan_cache.clear() - self._adaptive_estimate_cache.clear() - self._adaptive_plan_cache_top_level_ids = top_level_ids def clamp_width(width: int) -> int: return max(min_width, min(width, remaining)) diff --git a/tests/unit/test_trainer_rank_validation.py b/tests/unit/test_trainer_rank_validation.py index 252ac85f6..23fce9b0a 100644 --- a/tests/unit/test_trainer_rank_validation.py +++ b/tests/unit/test_trainer_rank_validation.py @@ -332,7 +332,7 @@ def test_forward_micro_batches_grows_small_stable_window_when_work_remains( assert candidate.stats_global_count == 256 -def test_forward_micro_batches_reuses_cached_candidate_plans( +def test_forward_micro_batches_avoids_packing_rejected_candidates( monkeypatch: pytest.MonkeyPatch, ) -> None: trainer = TrainerRank(_runtime()) # type: ignore[arg-type] @@ -370,15 +370,44 @@ def memory_check(plan, *, sync_across_dp=False): monkeypatch.setattr(trainer, "_memory_check", memory_check) inputs = [_target_request(i) for i in range(8)] + batches = list(trainer.forward_micro_batches(inputs)) + + assert [batch.stats.global_count for batch in batches] == [8] + assert plan_calls == 1 + assert memory_checks == 0 + + +def test_forward_micro_batches_replans_reused_input_list( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr( + trainer, "_all_ranks_have_memory_profile", lambda **_kwargs: True + ) + original_plan = trainer._plan_flat_forward + plan_calls = 0 + + def plan(requests): + nonlocal plan_calls + plan_calls += 1 + return original_plan(requests) + + monkeypatch.setattr(trainer, "_plan_flat_forward", plan) + monkeypatch.setattr( + trainer, + "_run_flat_plan_with_memory_tracking", + lambda plan, **_kwargs: [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ], + ) + inputs = [_target_request(1)] + list(trainer.forward_micro_batches(inputs)) - first_plan_calls = plan_calls - first_memory_checks = memory_checks + inputs[0] = _target_request(10) list(trainer.forward_micro_batches(inputs)) - assert first_plan_calls > 0 - assert first_plan_calls == 1 - assert plan_calls == first_plan_calls - assert memory_checks == first_memory_checks == 0 + assert plan_calls == 2 def test_cached_adaptive_estimate_rechecks_current_memory( From 1da2e8c9ff7948a747f64daabe0dbd985e824793 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Thu, 25 Jun 2026 22:05:58 -0600 Subject: [PATCH 05/11] fix: harden trainer rank edge cases --- src/art/trainer_rank/__init__.py | 17 +++++++++++------ tests/unit/test_trainer_rank_validation.py | 10 ++++++++++ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/art/trainer_rank/__init__.py b/src/art/trainer_rank/__init__.py index 4ffb176b2..e131d6c13 100644 --- a/src/art/trainer_rank/__init__.py +++ b/src/art/trainer_rank/__init__.py @@ -364,6 +364,8 @@ def _load_slot( trainable: bool, alpha: float | None, ) -> int: + if self._slot_stack: + raise RuntimeError("Cannot load a LoRA/checkpoint while a slot is pushed") from art.megatron.lora import LORA_ALPHA, load_lora_slot_into_model return load_lora_slot_into_model( @@ -1352,12 +1354,15 @@ def _project_head( dtype=torch.float32, ) if item.request.top_k is None and not item.request.logits: - valid = labels != -100 - if labels.ndim > 1: - valid = valid.reshape(int(labels.shape[0]), -1).any(dim=1) - valid_offsets = torch.nonzero(valid, as_tuple=False).reshape(-1) - if int(valid_offsets.numel()): - projected_rows.append(positions.index_select(0, valid_offsets)) + if int(labels.shape[0]): + valid = labels != -100 + if labels.ndim > 1: + valid = valid.reshape(int(labels.shape[0]), -1).any(dim=1) + valid_offsets = torch.nonzero(valid, as_tuple=False).reshape(-1) + if int(valid_offsets.numel()): + projected_rows.append( + positions.index_select(0, valid_offsets) + ) if item.request.logits: logits[index] = torch.empty( (int(positions.numel()), _padded_vocab_size(model)), diff --git a/tests/unit/test_trainer_rank_validation.py b/tests/unit/test_trainer_rank_validation.py index 23fce9b0a..db08b3a28 100644 --- a/tests/unit/test_trainer_rank_validation.py +++ b/tests/unit/test_trainer_rank_validation.py @@ -85,6 +85,16 @@ def test_trainer_rank_pop_rejects_empty_adapter_stack() -> None: trainer.pop_pushed_lora_or_checkpoint() +def test_trainer_rank_load_rejects_active_adapter_stack() -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + trainer._slot_stack.append(object()) # type: ignore[arg-type] + + with pytest.raises(RuntimeError, match="Cannot load a LoRA/checkpoint"): + trainer.load_checkpoint_slot("teacher", {}) + with pytest.raises(RuntimeError, match="Cannot load a LoRA/checkpoint"): + trainer.load_lora_slot("teacher", {}) + + def test_dp_rank_forward_preserves_nested_shape_for_inactive_requests() -> None: trainer = TrainerRank(_runtime()) # type: ignore[arg-type] request_a = ForwardInput(input_tokens=torch.tensor([1])) From 62d3e9a2457b2055fbb0258a395f945d5e96ea27 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 10:54:09 -0600 Subject: [PATCH 06/11] fix: guard trainer rank live slot graphs --- src/art/trainer_rank/__init__.py | 94 +++++++++++++++++++++- tests/unit/test_trainer_rank_validation.py | 60 ++++++++++++++ 2 files changed, 153 insertions(+), 1 deletion(-) diff --git a/src/art/trainer_rank/__init__.py b/src/art/trainer_rank/__init__.py index e131d6c13..71089b1c4 100644 --- a/src/art/trainer_rank/__init__.py +++ b/src/art/trainer_rank/__init__.py @@ -166,6 +166,23 @@ class TrainerRankMemoryError(RuntimeError): pass +class TrainerRankSlotStateError(RuntimeError): + pass + + +@dataclass +class _SlotGraphLease: + trainer: "TrainerRank" + ref: "LoRASlotRef" + active: bool = True + + def release(self) -> None: + if not self.active: + return + self.active = False + self.trainer._release_slot_graph(self.ref) + + @dataclass(frozen=True) class _PushedSlot: trainer: "TrainerRank" @@ -279,6 +296,7 @@ def __init__( self._checkpoint_slot_params_by_name: dict[ str, tuple[torch.nn.Parameter, ...] ] = {} + self._pending_slot_graphs: dict[LoRASlotRef, int] = {} self._memory_profiles: dict[_MemorySignature, _MemoryProfile] = {} self._adaptive_plan_cache: dict[_AdaptivePlanCacheKey, _FlatForwardPlan] = {} self._adaptive_estimate_cache: dict[ @@ -298,6 +316,7 @@ def zero_grad(self) -> None: for params in self._checkpoint_slot_params_by_name.values(): for param in params: param.grad = None + self._pending_slot_graphs.clear() def _optimizer(self) -> "MegatronOptimizer": optimizer = cast("MegatronOptimizer | None", self.runtime.optimizer) @@ -368,9 +387,11 @@ def _load_slot( raise RuntimeError("Cannot load a LoRA/checkpoint while a slot is pushed") from art.megatron.lora import LORA_ALPHA, load_lora_slot_into_model + ref = self._slot_ref(kind, name) + self._guard_slot_can_load(ref) return load_lora_slot_into_model( self.runtime.model, - self._slot_ref(kind, name), + ref, adapter_model, alpha=LORA_ALPHA if alpha is None else alpha, requires_grad=trainable, @@ -616,6 +637,7 @@ def _dynamic_optim_step( ) -> dict[str, float]: all_params: list[torch.nn.Parameter] = [] for name in checkpoint_names: + self._guard_checkpoint_can_step(name) slot_params = self._checkpoint_slot_params_by_name[name] for param in slot_params: if param.grad is None: @@ -633,6 +655,7 @@ def _dynamic_optim_step( optimizer = self._dynamic_optimizer(name, params) optimizer.step() optimizer.zero_grad(set_to_none=True) + self._pending_slot_graphs.pop(self._slot_ref("checkpoint", name), None) return { "learning_rate": float(params.learning_rate), "grad_norm": float(grad_norm), @@ -1036,10 +1059,67 @@ def _execute_flat_plan(self, plan: _FlatForwardPlan) -> list[AnyForwardOutput]: with use_lora_slot(group.slot_ref): prepared = self._prepare_packed_forward(group.packed) item_outputs = self._forward_packed(group.items, prepared) + self._track_slot_graph_outputs(group.slot_ref, item_outputs) for index, output in zip(group.request_indices, item_outputs, strict=True): outputs[index] = output return outputs + def _track_slot_graph_outputs( + self, + ref: "LoRASlotRef | None", + outputs: Sequence[AnyForwardOutput], + ) -> None: + if ref is None or ref.name is None: + return + tensors = [ + tensor + for output in outputs + for tensor in _forward_output_grad_tensors(output) + ] + if not tensors: + return + + self._pending_slot_graphs[ref] = self._pending_slot_graphs.get(ref, 0) + 1 + lease = _SlotGraphLease(self, ref) + + def release(grad: torch.Tensor) -> torch.Tensor: + lease.release() + return grad + + for tensor in tensors: + tensor.register_hook(release) + + def _release_slot_graph(self, ref: "LoRASlotRef") -> None: + count = self._pending_slot_graphs.get(ref, 0) + if count <= 1: + self._pending_slot_graphs.pop(ref, None) + else: + self._pending_slot_graphs[ref] = count - 1 + + def _guard_slot_can_load(self, ref: "LoRASlotRef") -> None: + if self._pending_slot_graphs.get(ref, 0) <= 0: + return + raise TrainerRankSlotStateError( + f"Cannot load {ref.kind} slot {ref.name!r} while outputs from an " + "earlier forward using that slot still have a live backward graph. " + "Activation checkpoint recompute resolves slots by name, so replacing " + "the slot before backward can compute gradients with different LoRA " + "weights than the original forward. Call loss.backward() first, or " + "call zero_grad() if the forward was abandoned, or load the new " + "weights under a different slot name." + ) + + def _guard_checkpoint_can_step(self, name: str) -> None: + ref = self._slot_ref("checkpoint", name) + if self._pending_slot_graphs.get(ref, 0) <= 0: + return + raise TrainerRankSlotStateError( + f"Cannot optim_step checkpoint slot {name!r} while outputs from an " + "earlier forward using that slot have not been backpropagated. Call " + "loss.backward() before optim_step(), or call zero_grad() if that " + "forward was abandoned." + ) + def _estimate_group_request_output_bytes( self, requests: Sequence[AnyForwardInput], @@ -2156,6 +2236,17 @@ def _batch_seq_logits(logits: torch.Tensor, *, seq_len: int) -> torch.Tensor: ) +def _forward_output_grad_tensors(output: AnyForwardOutput) -> Iterator[torch.Tensor]: + if output.target_logprobs is not None and output.target_logprobs.requires_grad: + yield output.target_logprobs + if output.top_k is not None and output.top_k.logprobs.requires_grad: + yield output.top_k.logprobs + if output.logits is not None and output.logits.requires_grad: + yield output.logits + if output.hidden_states is not None and output.hidden_states.requires_grad: + yield output.hidden_states + + def _materialize(inputs: ForwardInputs) -> ForwardInputs: if isinstance(inputs, ForwardInput): return inputs @@ -2207,4 +2298,5 @@ def _nested_forward_children(inputs: ForwardInputs) -> Iterator[ForwardInputs]: "TopK", "TrainerRank", "TrainerRankMemoryError", + "TrainerRankSlotStateError", ] diff --git a/tests/unit/test_trainer_rank_validation.py b/tests/unit/test_trainer_rank_validation.py index db08b3a28..658ca8f39 100644 --- a/tests/unit/test_trainer_rank_validation.py +++ b/tests/unit/test_trainer_rank_validation.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from types import SimpleNamespace import pytest @@ -11,6 +12,7 @@ TopK, TrainerRank, TrainerRankMemoryError, + TrainerRankSlotStateError, Unset, _anchor_disconnected_outputs, _MemoryCheck, @@ -23,6 +25,12 @@ class _Model: vocab_size = 8 +@dataclass(frozen=True) +class _SlotRef: + kind: str + name: str | None + + def _runtime(model: torch.nn.Module | None = None) -> SimpleNamespace: return SimpleNamespace( model=[model or torch.nn.Linear(1, 1)], @@ -95,6 +103,58 @@ def test_trainer_rank_load_rejects_active_adapter_stack() -> None: trainer.load_lora_slot("teacher", {}) +def test_trainer_rank_load_rejects_pending_checkpoint_graph() -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + ref = _SlotRef("checkpoint", "teacher") + output = ForwardOutput(torch.ones(1, requires_grad=True) * 2, None, None, None) + + trainer._track_slot_graph_outputs(ref, [output]) # type: ignore[arg-type] + + with pytest.raises(TrainerRankSlotStateError, match="Cannot load checkpoint slot"): + trainer._guard_slot_can_load(ref) # type: ignore[arg-type] + + output.target_logprobs.sum().backward() + + trainer._guard_slot_can_load(ref) # type: ignore[arg-type] + + +def test_trainer_rank_step_rejects_pending_checkpoint_graph( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(trainer, "_slot_ref", lambda kind, name: _SlotRef(kind, name)) + ref = _SlotRef("checkpoint", "student") + output = ForwardOutput(torch.ones(1, requires_grad=True) * 2, None, None, None) + + trainer._track_slot_graph_outputs(ref, [output]) # type: ignore[arg-type] + + with pytest.raises(TrainerRankSlotStateError, match="Cannot optim_step"): + trainer._guard_checkpoint_can_step("student") + + output.target_logprobs.sum().backward() + + trainer._guard_checkpoint_can_step("student") + + +def test_trainer_rank_zero_grad_clears_abandoned_slot_graphs() -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + ref = _SlotRef("lora", "teacher") + output = ForwardOutput( + None, + TopK( + torch.ones(1, requires_grad=True) * 2, + torch.ones(1, dtype=torch.long), + ), + None, + None, + ) + + trainer._track_slot_graph_outputs(ref, [output]) # type: ignore[arg-type] + trainer.zero_grad() + + trainer._guard_slot_can_load(ref) # type: ignore[arg-type] + + def test_dp_rank_forward_preserves_nested_shape_for_inactive_requests() -> None: trainer = TrainerRank(_runtime()) # type: ignore[arg-type] request_a = ForwardInput(input_tokens=torch.tensor([1])) From 228885f60d511769845775cfbeb212823d32af69 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 10:55:53 -0600 Subject: [PATCH 07/11] fix: lazily initialize trainer rank slot graph guard --- src/art/trainer_rank/__init__.py | 23 +++++++++++++++------- tests/unit/test_trainer_rank_validation.py | 9 +++++++++ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/src/art/trainer_rank/__init__.py b/src/art/trainer_rank/__init__.py index 71089b1c4..557521ada 100644 --- a/src/art/trainer_rank/__init__.py +++ b/src/art/trainer_rank/__init__.py @@ -655,7 +655,7 @@ def _dynamic_optim_step( optimizer = self._dynamic_optimizer(name, params) optimizer.step() optimizer.zero_grad(set_to_none=True) - self._pending_slot_graphs.pop(self._slot_ref("checkpoint", name), None) + self._slot_graphs().pop(self._slot_ref("checkpoint", name), None) return { "learning_rate": float(params.learning_rate), "grad_norm": float(grad_norm), @@ -1079,7 +1079,8 @@ def _track_slot_graph_outputs( if not tensors: return - self._pending_slot_graphs[ref] = self._pending_slot_graphs.get(ref, 0) + 1 + graphs = self._slot_graphs() + graphs[ref] = graphs.get(ref, 0) + 1 lease = _SlotGraphLease(self, ref) def release(grad: torch.Tensor) -> torch.Tensor: @@ -1090,14 +1091,22 @@ def release(grad: torch.Tensor) -> torch.Tensor: tensor.register_hook(release) def _release_slot_graph(self, ref: "LoRASlotRef") -> None: - count = self._pending_slot_graphs.get(ref, 0) + graphs = self._slot_graphs() + count = graphs.get(ref, 0) if count <= 1: - self._pending_slot_graphs.pop(ref, None) + graphs.pop(ref, None) else: - self._pending_slot_graphs[ref] = count - 1 + graphs[ref] = count - 1 + + def _slot_graphs(self) -> dict["LoRASlotRef", int]: + graphs = getattr(self, "_pending_slot_graphs", None) + if graphs is None: + graphs = {} + self._pending_slot_graphs = graphs + return graphs def _guard_slot_can_load(self, ref: "LoRASlotRef") -> None: - if self._pending_slot_graphs.get(ref, 0) <= 0: + if self._slot_graphs().get(ref, 0) <= 0: return raise TrainerRankSlotStateError( f"Cannot load {ref.kind} slot {ref.name!r} while outputs from an " @@ -1111,7 +1120,7 @@ def _guard_slot_can_load(self, ref: "LoRASlotRef") -> None: def _guard_checkpoint_can_step(self, name: str) -> None: ref = self._slot_ref("checkpoint", name) - if self._pending_slot_graphs.get(ref, 0) <= 0: + if self._slot_graphs().get(ref, 0) <= 0: return raise TrainerRankSlotStateError( f"Cannot optim_step checkpoint slot {name!r} while outputs from an " diff --git a/tests/unit/test_trainer_rank_validation.py b/tests/unit/test_trainer_rank_validation.py index 658ca8f39..0a79895d8 100644 --- a/tests/unit/test_trainer_rank_validation.py +++ b/tests/unit/test_trainer_rank_validation.py @@ -136,6 +136,15 @@ def test_trainer_rank_step_rejects_pending_checkpoint_graph( trainer._guard_checkpoint_can_step("student") +def test_trainer_rank_step_allows_missing_slot_graph_bookkeeping( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank.__new__(TrainerRank) + monkeypatch.setattr(trainer, "_slot_ref", lambda kind, name: _SlotRef(kind, name)) + + trainer._guard_checkpoint_can_step("student") + + def test_trainer_rank_zero_grad_clears_abandoned_slot_graphs() -> None: trainer = TrainerRank(_runtime()) # type: ignore[arg-type] ref = _SlotRef("lora", "teacher") From 702e00efa76d3ccaafb42ee3779fd58fa50f2778 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 11:07:34 -0600 Subject: [PATCH 08/11] refactor: surface trainer rank public api --- src/art/trainer_rank/__init__.py | 222 +++++++++++++++---------------- 1 file changed, 111 insertions(+), 111 deletions(-) diff --git a/src/art/trainer_rank/__init__.py b/src/art/trainer_rank/__init__.py index 557521ada..b828c2951 100644 --- a/src/art/trainer_rank/__init__.py +++ b/src/art/trainer_rank/__init__.py @@ -138,6 +138,14 @@ class MicroBatchStats: cold_start: bool +class TrainerRankMemoryError(RuntimeError): + pass + + +class TrainerRankSlotStateError(RuntimeError): + pass + + @dataclass(frozen=True) class _MemoryCheck: estimated_required_bytes: int @@ -162,14 +170,6 @@ class _CandidateMicroBatch(Generic[ForwardInputsT]): cold_start: bool -class TrainerRankMemoryError(RuntimeError): - pass - - -class TrainerRankSlotStateError(RuntimeError): - pass - - @dataclass class _SlotGraphLease: trainer: "TrainerRank" @@ -318,12 +318,6 @@ def zero_grad(self) -> None: param.grad = None self._pending_slot_graphs.clear() - def _optimizer(self) -> "MegatronOptimizer": - optimizer = cast("MegatronOptimizer | None", self.runtime.optimizer) - if optimizer is None: - raise RuntimeError("TrainerRank requires a runtime with an optimizer") - return optimizer - def set_checkpoint(self, name: str | None) -> None: self._set_default_slot(self._slot_ref("checkpoint", name)) @@ -374,103 +368,6 @@ def load_lora_slot( self._validate_dynamic_slot_consistency("lora", name, loaded) return loaded - def _load_slot( - self, - kind: Literal["checkpoint", "lora"], - name: str, - adapter_model: dict[str, torch.Tensor], - *, - trainable: bool, - alpha: float | None, - ) -> int: - if self._slot_stack: - raise RuntimeError("Cannot load a LoRA/checkpoint while a slot is pushed") - from art.megatron.lora import LORA_ALPHA, load_lora_slot_into_model - - ref = self._slot_ref(kind, name) - self._guard_slot_can_load(ref) - return load_lora_slot_into_model( - self.runtime.model, - ref, - adapter_model, - alpha=LORA_ALPHA if alpha is None else alpha, - requires_grad=trainable, - ) - - def _set_default_slot(self, ref: "LoRASlotRef") -> None: - if self._slot_stack: - raise RuntimeError("Cannot set a LoRA/checkpoint while a slot is pushed") - self._default_slot_ref = ref - - @staticmethod - def _slot_ref( - kind: Literal["checkpoint", "lora"], name: str | None - ) -> "LoRASlotRef": - from art.megatron.lora import LoRASlotRef - - return LoRASlotRef(kind=kind, name=name) - - def _validate_dynamic_slot_consistency( - self, - kind: Literal["checkpoint", "lora"], - name: str, - loaded_sites: int, - ) -> tuple[torch.nn.Parameter, ...]: - from art.megatron.lora import iter_lora_slot_parameters - - ref = self._slot_ref(kind, name) - params = tuple(iter_lora_slot_parameters(self.runtime.model, ref)) - if not (dist.is_available() and dist.is_initialized()): - return params - - local = { - "rank": dist.get_rank(), - "loaded_sites": int(loaded_sites), - "param_count": len(params), - "numel": sum(int(param.numel()) for param in params), - "signature": [ - ( - tuple(int(dim) for dim in param.shape), - str(param.dtype), - bool(getattr(param, "allreduce", True)), - str(getattr(param, "grad_sync_domain", "tp_default")), - str(getattr(param, "grad_sync_op", "none")), - ) - for param in params - ], - } - gathered: list[dict[str, object] | None] = [None] * dist.get_world_size() - dist.all_gather_object(gathered, local) - ranks = [rank for rank in gathered if rank is not None] - reference = ranks[0] - if all( - rank["loaded_sites"] == reference["loaded_sites"] - and rank["signature"] == reference["signature"] - for rank in ranks - ): - return params - - summary = [ - {key: rank[key] for key in ("rank", "loaded_sites", "param_count", "numel")} - for rank in ranks - ] - raise RuntimeError( - f"Dynamic LoRA slot {kind}:{name} is not loaded consistently across " - "distributed ranks. This usually means a sharded/exported LoRA state " - "dict was passed directly to TrainerRank; gather or materialize the " - "full adapter state before loading a dynamic slot. " - f"Rank summary: {summary}." - ) - - def _resolve_slot_ref(self, request: AnyForwardInput) -> "LoRASlotRef | None": - if request.checkpoint is not Unset: - return self._slot_ref("checkpoint", cast(str | None, request.checkpoint)) - if request.lora is not Unset: - return self._slot_ref("lora", cast(str | None, request.lora)) - if self._slot_stack: - return self._slot_stack[-1] - return self._default_slot_ref - def forward_micro_batches( self, inputs: Iterable[ForwardInputsT], @@ -602,6 +499,109 @@ def optim_step( "num_zeros_in_grad": float(num_zeros or 0), } + def _optimizer(self) -> "MegatronOptimizer": + optimizer = cast("MegatronOptimizer | None", self.runtime.optimizer) + if optimizer is None: + raise RuntimeError("TrainerRank requires a runtime with an optimizer") + return optimizer + + def _load_slot( + self, + kind: Literal["checkpoint", "lora"], + name: str, + adapter_model: dict[str, torch.Tensor], + *, + trainable: bool, + alpha: float | None, + ) -> int: + if self._slot_stack: + raise RuntimeError("Cannot load a LoRA/checkpoint while a slot is pushed") + from art.megatron.lora import LORA_ALPHA, load_lora_slot_into_model + + ref = self._slot_ref(kind, name) + self._guard_slot_can_load(ref) + return load_lora_slot_into_model( + self.runtime.model, + ref, + adapter_model, + alpha=LORA_ALPHA if alpha is None else alpha, + requires_grad=trainable, + ) + + def _set_default_slot(self, ref: "LoRASlotRef") -> None: + if self._slot_stack: + raise RuntimeError("Cannot set a LoRA/checkpoint while a slot is pushed") + self._default_slot_ref = ref + + @staticmethod + def _slot_ref( + kind: Literal["checkpoint", "lora"], name: str | None + ) -> "LoRASlotRef": + from art.megatron.lora import LoRASlotRef + + return LoRASlotRef(kind=kind, name=name) + + def _validate_dynamic_slot_consistency( + self, + kind: Literal["checkpoint", "lora"], + name: str, + loaded_sites: int, + ) -> tuple[torch.nn.Parameter, ...]: + from art.megatron.lora import iter_lora_slot_parameters + + ref = self._slot_ref(kind, name) + params = tuple(iter_lora_slot_parameters(self.runtime.model, ref)) + if not (dist.is_available() and dist.is_initialized()): + return params + + local = { + "rank": dist.get_rank(), + "loaded_sites": int(loaded_sites), + "param_count": len(params), + "numel": sum(int(param.numel()) for param in params), + "signature": [ + ( + tuple(int(dim) for dim in param.shape), + str(param.dtype), + bool(getattr(param, "allreduce", True)), + str(getattr(param, "grad_sync_domain", "tp_default")), + str(getattr(param, "grad_sync_op", "none")), + ) + for param in params + ], + } + gathered: list[dict[str, object] | None] = [None] * dist.get_world_size() + dist.all_gather_object(gathered, local) + ranks = [rank for rank in gathered if rank is not None] + reference = ranks[0] + if all( + rank["loaded_sites"] == reference["loaded_sites"] + and rank["signature"] == reference["signature"] + for rank in ranks + ): + return params + + summary = [ + {key: rank[key] for key in ("rank", "loaded_sites", "param_count", "numel")} + for rank in ranks + ] + raise RuntimeError( + f"Dynamic LoRA slot {kind}:{name} is not loaded consistently across " + "distributed ranks. This usually means a sharded/exported LoRA state " + "dict was passed directly to TrainerRank; gather or materialize the " + "full adapter state before loading a dynamic slot. " + f"Rank summary: {summary}." + ) + + def _resolve_slot_ref(self, request: AnyForwardInput) -> "LoRASlotRef | None": + if request.checkpoint is not Unset: + return self._slot_ref("checkpoint", cast(str | None, request.checkpoint)) + if request.lora is not Unset: + return self._slot_ref("lora", cast(str | None, request.lora)) + if self._slot_stack: + return self._slot_stack[-1] + return self._default_slot_ref + def _selected_dynamic_checkpoints( self, checkpoints: Sequence[str] | None, From e4d4886594f156dd48de3b7fa3724516857f88d7 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 11:37:19 -0600 Subject: [PATCH 09/11] fix: type trainer rank nested forwards --- src/art/trainer_rank/__init__.py | 112 +++++++++++++++++++-- tests/unit/test_trainer_rank_validation.py | 74 ++++++++++++++ 2 files changed, 176 insertions(+), 10 deletions(-) diff --git a/src/art/trainer_rank/__init__.py b/src/art/trainer_rank/__init__.py index b828c2951..78a58a509 100644 --- a/src/art/trainer_rank/__init__.py +++ b/src/art/trainer_rank/__init__.py @@ -108,15 +108,17 @@ def __post_init__(self) -> None: torch.Tensor | None, torch.Tensor | None, ] +type AnyMicroBatch = MicroBatch[AnyForwardInput, AnyForwardOutput] type ForwardInputs = AnyForwardInput | Iterable["ForwardInputs"] type ForwardOutputs = AnyForwardOutput | Sequence["ForwardOutputs"] ForwardInputsT = TypeVar("ForwardInputsT", bound=ForwardInputs) +ForwardOutputsT = TypeVar("ForwardOutputsT", bound=ForwardOutputs) @dataclass(frozen=True) -class MicroBatch(Generic[ForwardInputsT]): +class MicroBatch(Generic[ForwardInputsT, ForwardOutputsT]): inputs: Sequence[ForwardInputsT] - outputs: Sequence[ForwardOutputs] + outputs: Sequence[ForwardOutputsT] indices: Sequence[int] stats: "MicroBatchStats" @@ -368,11 +370,73 @@ def load_lora_slot( self._validate_dynamic_slot_consistency("lora", name, loaded) return loaded + @overload + def forward_micro_batches( + self, + inputs: Iterable[ForwardInput[LogprobsT, TopKT, LogitsT, HiddenStatesT]], + ) -> Iterator[ + MicroBatch[ + ForwardInput[LogprobsT, TopKT, LogitsT, HiddenStatesT], + ForwardOutput[LogprobsT, TopKT, LogitsT, HiddenStatesT], + ] + ]: ... + + @overload + def forward_micro_batches( + self, + inputs: Iterable[ + Iterable[ForwardInput[LogprobsT, TopKT, LogitsT, HiddenStatesT]] + ], + ) -> Iterator[ + MicroBatch[ + Sequence[ForwardInput[LogprobsT, TopKT, LogitsT, HiddenStatesT]], + Sequence[ForwardOutput[LogprobsT, TopKT, LogitsT, HiddenStatesT]], + ] + ]: ... + + @overload + def forward_micro_batches( + self, + inputs: Iterable[ + Iterable[Iterable[ForwardInput[LogprobsT, TopKT, LogitsT, HiddenStatesT]]] + ], + ) -> Iterator[ + MicroBatch[ + Sequence[Sequence[ForwardInput[LogprobsT, TopKT, LogitsT, HiddenStatesT]]], + Sequence[Sequence[ForwardOutput[LogprobsT, TopKT, LogitsT, HiddenStatesT]]], + ] + ]: ... + + @overload def forward_micro_batches( self, - inputs: Iterable[ForwardInputsT], - ) -> Iterator[MicroBatch[ForwardInputsT]]: - items = list(inputs) + inputs: Iterable[ + Iterable[ + Iterable[ + Iterable[ForwardInput[LogprobsT, TopKT, LogitsT, HiddenStatesT]] + ] + ] + ], + ) -> Iterator[ + MicroBatch[ + Sequence[ + Sequence[ + Sequence[ForwardInput[LogprobsT, TopKT, LogitsT, HiddenStatesT]] + ] + ], + Sequence[ + Sequence[ + Sequence[ForwardOutput[LogprobsT, TopKT, LogitsT, HiddenStatesT]] + ] + ], + ] + ]: ... + + def forward_micro_batches( + self, + inputs: Iterable[ForwardInputs], + ) -> Iterator[MicroBatch[ForwardInputs, ForwardOutputs]]: + items: list[ForwardInputs] = list(inputs) self._adaptive_plan_cache.clear() self._adaptive_estimate_cache.clear() self._validate_replicated_top_level_count(len(items)) @@ -427,6 +491,32 @@ def dp_rank_forward( Sequence[ForwardOutput[LogprobsT, TopKT, LogitsT, HiddenStatesT]] ]: ... + @overload + def dp_rank_forward( + self, + inputs: Iterable[ + Iterable[Iterable[ForwardInput[LogprobsT, TopKT, LogitsT, HiddenStatesT]]] + ], + ) -> Sequence[ + Sequence[Sequence[ForwardOutput[LogprobsT, TopKT, LogitsT, HiddenStatesT]]] + ]: ... + + @overload + def dp_rank_forward( + self, + inputs: Iterable[ + Iterable[ + Iterable[ + Iterable[ForwardInput[LogprobsT, TopKT, LogitsT, HiddenStatesT]] + ] + ] + ], + ) -> Sequence[ + Sequence[ + Sequence[Sequence[ForwardOutput[LogprobsT, TopKT, LogitsT, HiddenStatesT]]] + ] + ]: ... + def dp_rank_forward(self, inputs: ForwardInputs) -> ForwardOutputs: materialized = _materialize(inputs) plan = self._plan_flat_forward(list(_flatten(materialized))) @@ -2020,11 +2110,13 @@ def anchor_tensor(tensor: torch.Tensor) -> torch.Tensor: for item_logprobs in target_logprobs ], [ - None - if item_top_k is None - else TopK( - logprobs=anchor_tensor(item_top_k.logprobs), - tokens=item_top_k.tokens, + ( + None + if item_top_k is None + else TopK( + logprobs=anchor_tensor(item_top_k.logprobs), + tokens=item_top_k.tokens, + ) ) for item_top_k in top_k ], diff --git a/tests/unit/test_trainer_rank_validation.py b/tests/unit/test_trainer_rank_validation.py index 0a79895d8..7d76b9256 100644 --- a/tests/unit/test_trainer_rank_validation.py +++ b/tests/unit/test_trainer_rank_validation.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from types import SimpleNamespace +from typing import Any, cast import pytest import torch @@ -45,6 +46,30 @@ def _target_request(token: int) -> ForwardInput[torch.Tensor, None, None, None]: return ForwardInput(input_tokens=tokens, target_tokens=tokens) +def _indexed_outputs(plan: object, **_kwargs: object) -> list[ForwardOutput]: + return [ + ForwardOutput(torch.tensor([index], dtype=torch.float32), None, None, None) + for index in range(int(getattr(plan, "request_count"))) + ] + + +def _output_values(outputs: object) -> list[int]: + if isinstance(outputs, ForwardOutput): + target_logprobs = outputs.target_logprobs + assert isinstance(target_logprobs, torch.Tensor) + return [int(target_logprobs.item())] + values: list[int] = [] + for item in outputs: # type: ignore[union-attr] + values.extend(_output_values(item)) + return values + + +def _output_shape(outputs: object) -> object: + if isinstance(outputs, ForwardOutput): + return "output" + return [_output_shape(item) for item in outputs] # type: ignore[union-attr] + + def test_forward_input_rejects_non_positive_top_k() -> None: with pytest.raises(ValueError, match="top_k must be >= 1"): ForwardInput(input_tokens=torch.tensor([1]), top_k=0) @@ -179,6 +204,27 @@ def test_dp_rank_forward_preserves_nested_shape_for_inactive_requests() -> None: assert not hasattr(trainer, "micro_batches") +def test_dp_rank_forward_supports_arbitrary_nested_depth( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr( + trainer, "_run_flat_plan_with_memory_tracking", _indexed_outputs + ) + nested = [ + [[[[[_target_request(1)]]]]], + [[[[[_target_request(3), _target_request(5)]]]]], + ] + + outputs = cast(Any, trainer).dp_rank_forward(nested) + + assert _output_shape(outputs) == [ + [[[[["output"]]]]], + [[[[["output", "output"]]]]], + ] + assert _output_values(outputs) == [0, 1, 2] + + def test_forward_micro_batches_uses_deterministic_dp_windows( monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -254,6 +300,34 @@ def test_forward_micro_batches_outputs_match_top_level_nested_inputs( assert len(batch.outputs[0]) == 2 +def test_forward_micro_batches_supports_arbitrary_nested_depth( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr( + trainer, "_all_ranks_have_memory_profile", lambda **_kwargs: True + ) + monkeypatch.setattr( + trainer, "_run_flat_plan_with_memory_tracking", _indexed_outputs + ) + nested = [ + [[[[[_target_request(1)]]]]], + [[[[[_target_request(3), _target_request(5)]]]]], + ] + + batches = list(cast(Any, trainer).forward_micro_batches(nested)) + + assert len(batches) == 1 + assert batches[0].inputs == nested + assert batches[0].select(nested) == nested + assert _output_shape(batches[0].outputs) == [ + [[[[["output"]]]]], + [[[[["output", "output"]]]]], + ] + assert _output_values(batches[0].outputs) == [0, 1, 2] + + def test_forward_micro_batches_ramps_after_first_success( monkeypatch: pytest.MonkeyPatch, ) -> None: From 02423d8c9eeaf22fa3c1ec5823430c2bbff1282a Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 11:48:27 -0600 Subject: [PATCH 10/11] fix: make trainer microbatch typing covariant --- src/art/trainer_rank/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/art/trainer_rank/__init__.py b/src/art/trainer_rank/__init__.py index 78a58a509..3397f07c5 100644 --- a/src/art/trainer_rank/__init__.py +++ b/src/art/trainer_rank/__init__.py @@ -113,12 +113,14 @@ def __post_init__(self) -> None: type ForwardOutputs = AnyForwardOutput | Sequence["ForwardOutputs"] ForwardInputsT = TypeVar("ForwardInputsT", bound=ForwardInputs) ForwardOutputsT = TypeVar("ForwardOutputsT", bound=ForwardOutputs) +MicroBatchInputsT = TypeVar("MicroBatchInputsT", bound=ForwardInputs, covariant=True) +MicroBatchOutputsT = TypeVar("MicroBatchOutputsT", bound=ForwardOutputs, covariant=True) @dataclass(frozen=True) -class MicroBatch(Generic[ForwardInputsT, ForwardOutputsT]): - inputs: Sequence[ForwardInputsT] - outputs: Sequence[ForwardOutputsT] +class MicroBatch(Generic[MicroBatchInputsT, MicroBatchOutputsT]): + inputs: Sequence[MicroBatchInputsT] + outputs: Sequence[MicroBatchOutputsT] indices: Sequence[int] stats: "MicroBatchStats" From b3bbecf3a98d4702562666e48664b373ba6afb09 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Sat, 27 Jun 2026 15:24:25 -0600 Subject: [PATCH 11/11] fix: type trainer rank request outputs --- pyproject.toml | 5 +- src/art/trainer_rank/__init__.py | 267 ++++++++++++++++++- tests/unit/test_trainer_rank_weird_shapes.py | 9 +- uv.lock | 201 +++++--------- 4 files changed, 339 insertions(+), 143 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6e6261e24..cdc21c50f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ backend = [ "awscli>=1.38.1", "setuptools>=78.1.0", "wandb==0.25.0", - "transformers==5.2.0", + "transformers==5.6.2", "duckdb>=1.0.0", "pyarrow>=15.0.0", "trl==0.20.0", @@ -81,7 +81,7 @@ tinker = [ "tinker-cookbook>=0.4.1,<0.5", "tinker>=0.21.0,<0.22", "torch==2.11.0", - "transformers>=5.2.0,<=5.5.3", + "transformers==5.6.2", "uvicorn>=0.35.0", "datrie>=0.8.3", ] @@ -169,6 +169,7 @@ override-dependencies = [ "quack-kernels==0.3.7", "transformer-engine==2.11.0", "torch==2.11.0", + "transformers==5.6.2", ] exclude-dependencies = ["pynvml", "emerging-optimizers", "causal-conv1d", "mamba-ssm"] no-build-isolation-package = ["apex", "transformer-engine", "transformer-engine-cu12", "transformer-engine-torch", "megatron-bridge", "deep-ep", "nv-grouped-gemm"] diff --git a/src/art/trainer_rank/__init__.py b/src/art/trainer_rank/__init__.py index 3397f07c5..3a71ad7d2 100644 --- a/src/art/trainer_rank/__init__.py +++ b/src/art/trainer_rank/__init__.py @@ -10,7 +10,16 @@ ) from dataclasses import dataclass import os -from typing import TYPE_CHECKING, Generic, Literal, ParamSpec, TypeVar, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Literal, + ParamSpec, + TypeVar, + cast, + overload, +) import torch import torch.distributed as dist @@ -79,7 +88,7 @@ class ForwardOutput(Generic[LogprobsT, TopKT, LogitsT, HiddenStatesT]): hidden_states: HiddenStatesT -@dataclass(slots=True) +@dataclass(slots=True, init=False) class ForwardInput(Generic[LogprobsT, TopKT, LogitsT, HiddenStatesT]): input_tokens: torch.Tensor target_tokens: torch.Tensor | None = None @@ -89,6 +98,260 @@ class ForwardInput(Generic[LogprobsT, TopKT, LogitsT, HiddenStatesT]): checkpoint: AdapterSelection = Unset lora: AdapterSelection = Unset + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: None = None, + top_k: None = None, + logits: Literal[False] = False, + hidden_states: Literal[False] = False, + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> "ForwardInput[None, None, None, None]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor, + top_k: None = None, + logits: Literal[False] = False, + hidden_states: Literal[False] = False, + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> "ForwardInput[torch.Tensor, None, None, None]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: None = None, + top_k: int, + logits: Literal[False] = False, + hidden_states: Literal[False] = False, + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> "ForwardInput[None, TopK, None, None]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: None = None, + top_k: None = None, + logits: Literal[True], + hidden_states: Literal[False] = False, + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> "ForwardInput[None, None, torch.Tensor, None]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: None = None, + top_k: None = None, + logits: Literal[False] = False, + hidden_states: Literal[True], + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> "ForwardInput[None, None, None, torch.Tensor]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor, + top_k: int, + logits: Literal[False] = False, + hidden_states: Literal[False] = False, + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> "ForwardInput[torch.Tensor, TopK, None, None]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor, + top_k: None = None, + logits: Literal[True], + hidden_states: Literal[False] = False, + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> "ForwardInput[torch.Tensor, None, torch.Tensor, None]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor, + top_k: None = None, + logits: Literal[False] = False, + hidden_states: Literal[True], + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> "ForwardInput[torch.Tensor, None, None, torch.Tensor]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: None = None, + top_k: int, + logits: Literal[True], + hidden_states: Literal[False] = False, + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> "ForwardInput[None, TopK, torch.Tensor, None]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: None = None, + top_k: int, + logits: Literal[False] = False, + hidden_states: Literal[True], + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> "ForwardInput[None, TopK, None, torch.Tensor]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: None = None, + top_k: None = None, + logits: Literal[True], + hidden_states: Literal[True], + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> "ForwardInput[None, None, torch.Tensor, torch.Tensor]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor, + top_k: int, + logits: Literal[True], + hidden_states: Literal[False] = False, + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> "ForwardInput[torch.Tensor, TopK, torch.Tensor, None]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor, + top_k: int, + logits: Literal[False] = False, + hidden_states: Literal[True], + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> "ForwardInput[torch.Tensor, TopK, None, torch.Tensor]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor, + top_k: None = None, + logits: Literal[True], + hidden_states: Literal[True], + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> "ForwardInput[torch.Tensor, None, torch.Tensor, torch.Tensor]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: None = None, + top_k: int, + logits: Literal[True], + hidden_states: Literal[True], + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> "ForwardInput[None, TopK, torch.Tensor, torch.Tensor]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor, + top_k: int, + logits: Literal[True], + hidden_states: Literal[True], + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> "ForwardInput[torch.Tensor, TopK, torch.Tensor, torch.Tensor]": ... + + @overload + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor | None = None, + top_k: int | None = None, + logits: bool = False, + hidden_states: bool = False, + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> "ForwardInput[torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None]": ... + + def __new__( + cls, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor | None = None, + top_k: int | None = None, + logits: bool = False, + hidden_states: bool = False, + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> Any: + return object.__new__(cls) + + def __init__( + self, + *, + input_tokens: torch.Tensor, + target_tokens: torch.Tensor | None = None, + top_k: int | None = None, + logits: bool = False, + hidden_states: bool = False, + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, + ) -> None: + self.input_tokens = input_tokens + self.target_tokens = target_tokens + self.top_k = top_k + self.logits = logits + self.hidden_states = hidden_states + self.checkpoint = checkpoint + self.lora = lora + self.__post_init__() + def __post_init__(self) -> None: if self.top_k is not None and self.top_k < 1: raise ValueError("top_k must be >= 1") diff --git a/tests/unit/test_trainer_rank_weird_shapes.py b/tests/unit/test_trainer_rank_weird_shapes.py index 541d2de9a..05bd79076 100644 --- a/tests/unit/test_trainer_rank_weird_shapes.py +++ b/tests/unit/test_trainer_rank_weird_shapes.py @@ -11,6 +11,7 @@ pack_shared_prefixes, ) from art.trainer_rank import ( + AdapterSelection, ForwardInput, ForwardOutput, TopK, @@ -57,8 +58,8 @@ def _target_request( top_k: int | None = None, logits: bool = False, hidden_states: bool = False, - checkpoint: object = Unset, - lora: object = Unset, + checkpoint: AdapterSelection = Unset, + lora: AdapterSelection = Unset, ) -> ForwardInput: labels = ( tokens @@ -74,8 +75,8 @@ def _target_request( top_k=top_k, logits=logits, hidden_states=hidden_states, - checkpoint=checkpoint, # type: ignore[arg-type] - lora=lora, # type: ignore[arg-type] + checkpoint=checkpoint, + lora=lora, ) diff --git a/uv.lock b/uv.lock index 36ad513d8..47786e876 100644 --- a/uv.lock +++ b/uv.lock @@ -54,6 +54,7 @@ overrides = [ { name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'win32'", specifier = "==2.11.0" }, { name = "torch", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = "==2.11.0", index = "https://download.pytorch.org/whl/cu128" }, { name = "transformer-engine", specifier = "==2.11.0" }, + { name = "transformers", specifier = "==5.6.2" }, ] excludes = [ "causal-conv1d", @@ -1220,7 +1221,7 @@ name = "cuda-bindings" version = "12.9.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cuda-pathfinder", marker = "sys_platform == 'linux' or sys_platform == 'win32' or extra == 'extra-12-openpipe-art-megatron'" }, + { name = "cuda-pathfinder", marker = "sys_platform == 'linux' or (sys_platform != 'win32' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/32/45/557d4ed1fa54f0c7db8aee083229f624990d69f7d00f55477eed5c7e169a/cuda_bindings-12.9.7-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0666d3c082ef8f4b2d670950589373550e9f3bf564d635dd883f24a0b40402ff", size = 7071026, upload-time = "2026-05-27T18:44:13.356Z" }, @@ -1291,37 +1292,37 @@ wheels = [ [package.optional-dependencies] cublas = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] cudart = [ - { name = "nvidia-cuda-runtime-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-cuda-runtime-cu12", marker = "sys_platform == 'linux' or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] cufft = [ - { name = "nvidia-cufft-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-cufft-cu12", marker = "sys_platform == 'linux' or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] cufile = [ - { name = "nvidia-cufile-cu12", marker = "sys_platform == 'linux' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-cufile-cu12", marker = "sys_platform == 'linux'" }, ] cupti = [ - { name = "nvidia-cuda-cupti-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-cuda-cupti-cu12", marker = "sys_platform == 'linux' or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] curand = [ - { name = "nvidia-curand-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-curand-cu12", marker = "sys_platform == 'linux' or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] cusolver = [ - { name = "nvidia-cusolver-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-cusolver-cu12", marker = "sys_platform == 'linux' or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] cusparse = [ - { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux' or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] nvjitlink = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] nvrtc = [ - { name = "nvidia-cuda-nvrtc-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "sys_platform == 'linux' or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] nvtx = [ - { name = "nvidia-nvtx-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-nvtx-cu12", marker = "sys_platform == 'linux' or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] [[package]] @@ -1938,7 +1939,7 @@ version = "0.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "fla-core" }, - { name = "transformers", version = "5.6.2", source = { registry = "https://pypi.org/simple" } }, + { name = "transformers" }, ] sdist = { url = "https://files.pythonhosted.org/packages/79/5c/1db76cc829c951117a3112f306d50333bd71399d2e35807fe7c99ffc2007/flash_linear_attention-0.5.0.tar.gz", hash = "sha256:22b789a47f07738b4382ecdf775d7bb40e0d803c467c34f8e2ecd6a1dc780938", size = 160419, upload-time = "2026-04-21T20:25:42.344Z" } wheels = [ @@ -2502,7 +2503,7 @@ name = "gunicorn" version = "25.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "packaging" }, + { name = "packaging", marker = "sys_platform != 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/c4/f4/e78fa054248fab913e2eab0332c6c2cb07421fca1ce56d8fe43b6aef57a4/gunicorn-25.3.0.tar.gz", hash = "sha256:f74e1b2f9f76f6cd1ca01198968bd2dd65830edc24b6e8e4d78de8320e2fe889", size = 634883, upload-time = "2026-03-27T00:00:26.092Z" } wheels = [ @@ -3844,7 +3845,7 @@ dependencies = [ { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "tqdm" }, - { name = "transformers", version = "5.6.2", source = { registry = "https://pypi.org/simple" } }, + { name = "transformers" }, { name = "typing-extensions" }, { name = "wandb" }, ] @@ -4394,7 +4395,7 @@ name = "nvidia-cudnn-cu12" version = "9.19.0.56" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/09/b8/277c51962ee46fa3e5b203ac5f76107c650f781d6891e681e28e6f3e9fe6/nvidia_cudnn_cu12-9.19.0.56-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:08caaf27fe556aca82a3ee3b5aa49a77e7de0cfcb7ff4e5c29da426387a8267e", size = 656910700, upload-time = "2026-02-03T20:40:25.508Z" }, @@ -4423,7 +4424,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" }, @@ -4455,9 +4456,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux' or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" }, @@ -4470,7 +4471,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" }, @@ -4768,7 +4769,7 @@ backend = [ { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "torchao" }, - { name = "transformers", version = "5.2.0", source = { registry = "https://pypi.org/simple" } }, + { name = "transformers" }, { name = "trl" }, { name = "unsloth" }, { name = "unsloth-zoo" }, @@ -4799,7 +4800,7 @@ megatron = [ { name = "transformer-engine" }, { name = "transformer-engine-cu12" }, { name = "transformer-engine-torch" }, - { name = "transformers", version = "5.6.2", source = { registry = "https://pypi.org/simple" } }, + { name = "transformers" }, ] plotting = [ { name = "matplotlib" }, @@ -4818,7 +4819,7 @@ tinker = [ { name = "tinker-cookbook" }, { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'linux' and extra != 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "transformers", version = "5.2.0", source = { registry = "https://pypi.org/simple" } }, + { name = "transformers" }, { name = "uvicorn" }, ] @@ -4904,9 +4905,9 @@ requires-dist = [ { name = "transformer-engine", marker = "extra == 'megatron'", specifier = "==2.11.0" }, { name = "transformer-engine-cu12", marker = "extra == 'megatron'", specifier = "==2.11.0" }, { name = "transformer-engine-torch", marker = "extra == 'megatron'", git = "https://github.com/NVIDIA/TransformerEngine.git?subdirectory=transformer_engine%2Fpytorch&rev=v2.11" }, - { name = "transformers", marker = "extra == 'backend'", specifier = "==5.2.0" }, + { name = "transformers", marker = "extra == 'backend'", specifier = "==5.6.2" }, { name = "transformers", marker = "extra == 'megatron'", specifier = "==5.6.2" }, - { name = "transformers", marker = "extra == 'tinker'", specifier = ">=5.2.0,<=5.5.3" }, + { name = "transformers", marker = "extra == 'tinker'", specifier = "==5.6.2" }, { name = "trl", marker = "extra == 'backend'", specifier = "==0.20.0" }, { name = "typer", specifier = ">=0.15.2" }, { name = "unsloth", marker = "extra == 'backend'", specifier = "==2026.3.3" }, @@ -5235,8 +5236,7 @@ dependencies = [ { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "tqdm" }, - { name = "transformers", version = "5.2.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-12-openpipe-art-backend' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "transformers", version = "5.6.2", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-12-openpipe-art-megatron'" }, + { name = "transformers" }, ] sdist = { url = "https://files.pythonhosted.org/packages/86/cf/037f1e3d5186496c05513a6754639e2dab3038a05f384284d49a9bd06a2d/peft-0.19.1.tar.gz", hash = "sha256:0d97542fe96dcdaa20d3b81c06f26f988618f416a73544ab23c3618ccb674a40", size = 763738, upload-time = "2026-04-16T15:46:45.105Z" } wheels = [ @@ -7696,16 +7696,16 @@ name = "tilelang" version = "0.1.10" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "apache-tvm-ffi", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "cloudpickle", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "ml-dtypes", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "numpy", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "psutil", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "torch-c-dlpack-ext", marker = "(python_full_version < '3.14' and sys_platform == 'linux') or (python_full_version < '3.14' and sys_platform == 'win32')" }, - { name = "tqdm", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, - { name = "z3-solver", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "apache-tvm-ffi", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, + { name = "cloudpickle", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, + { name = "ml-dtypes", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, + { name = "numpy", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, + { name = "psutil", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, + { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, + { name = "torch-c-dlpack-ext", marker = "python_full_version < '3.14' and platform_machine != 's390x' and sys_platform == 'linux'" }, + { name = "tqdm", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, + { name = "z3-solver", marker = "platform_machine != 's390x' and sys_platform == 'linux'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/77/5c/07146b4527656102e48d21c2599aa80477e83ea3f149ac0df3b15a247bd4/tilelang-0.1.10.tar.gz", hash = "sha256:d8813e668fcf75843bc2d68c633c352b419c1e292895a6038a4aadd943e56c2b", size = 93184128, upload-time = "2026-05-25T03:58:57.006Z" } wheels = [ @@ -7748,7 +7748,7 @@ dependencies = [ { name = "pydantic" }, { name = "rich" }, { name = "sniffio" }, - { name = "transformers", version = "5.2.0", source = { registry = "https://pypi.org/simple" } }, + { name = "transformers" }, { name = "typing-extensions" }, ] sdist = { url = "https://files.pythonhosted.org/packages/bc/7a/a72cb2b487a7581cc192f73fd64250d14434702c4e83b2da3d5924d5ecbc/tinker-0.21.0.tar.gz", hash = "sha256:8d72709fb639f74bf90f1d1fd57beec53bfc147a768a8f42e5d6b4404eeccce9", size = 251660, upload-time = "2026-05-19T00:24:02.569Z" } @@ -7779,7 +7779,7 @@ dependencies = [ { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-megatron') or (sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra != 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'linux' and extra != 'extra-12-openpipe-art-megatron') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'win32' and extra != 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "tqdm" }, - { name = "transformers", version = "5.2.0", source = { registry = "https://pypi.org/simple" } }, + { name = "transformers" }, ] sdist = { url = "https://files.pythonhosted.org/packages/e5/9c/37af9804cb3f1d88f5e67512aa1aeafeb49ef9012532d056d92c96194320/tinker_cookbook-0.4.1.tar.gz", hash = "sha256:1f9ad977317529bbf796f40ef13de59b2c93a0a257469bd80a7ffcfed5beb8b2", size = 4517724, upload-time = "2026-05-12T03:49:19.6Z" } wheels = [ @@ -7953,20 +7953,20 @@ resolution-markers = [ "(python_full_version < '3.13' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (python_full_version < '3.13' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron')", ] dependencies = [ - { name = "cuda-bindings", marker = "sys_platform == 'linux' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "cuda-bindings", marker = "sys_platform == 'linux'" }, { name = "cuda-toolkit", extra = ["cublas", "cudart", "cufft", "cufile", "cupti", "curand", "cusolver", "cusparse", "nvjitlink", "nvrtc", "nvtx"], marker = "sys_platform == 'linux' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "filelock", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "fsspec", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "jinja2", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "networkx", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "nvidia-cudnn-cu12", marker = "sys_platform == 'linux' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "nvidia-cusparselt-cu12", marker = "sys_platform == 'linux' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "nvidia-nvshmem-cu12", marker = "sys_platform == 'linux' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "setuptools", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "sympy", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "triton", marker = "sys_platform == 'linux' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 'win32' or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, + { name = "filelock", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "fsspec", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "jinja2", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "networkx", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "nvidia-cudnn-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvshmem-cu12", marker = "sys_platform == 'linux'" }, + { name = "setuptools", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "sympy", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "triton", marker = "sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] wheels = [ { url = "https://download-r2.pytorch.org/whl/cu128/torch-2.11.0%2Bcu128-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9c8f38efee365cb9d334de8a83ce52fc7e5fc9e5a7b0853285efa1b69e00b0f2", upload-time = "2026-04-27T17:41:30Z" }, @@ -8139,77 +8139,20 @@ dependencies = [ { name = "transformer-engine-cu12" }, ] -[[package]] -name = "transformers" -version = "5.2.0" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.14' and sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", - "python_full_version == '3.13.*' and sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", - "python_full_version < '3.13' and sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", - "python_full_version >= '3.14' and sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", - "python_full_version == '3.13.*' and sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", - "python_full_version < '3.13' and sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", - "python_full_version >= '3.14' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", - "python_full_version == '3.13.*' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", - "python_full_version < '3.13' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", - "(python_full_version >= '3.14' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (python_full_version >= '3.14' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron')", - "(python_full_version == '3.13.*' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (python_full_version == '3.13.*' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron')", - "(python_full_version < '3.13' and sys_platform == 'linux' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron') or (python_full_version < '3.13' and sys_platform == 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron')", - "python_full_version >= '3.14' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", - "python_full_version == '3.13.*' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", - "python_full_version < '3.13' and sys_platform != 'linux' and sys_platform != 'win32' and extra != 'extra-12-openpipe-art-backend' and extra != 'extra-12-openpipe-art-megatron'", -] -dependencies = [ - { name = "huggingface-hub", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "numpy", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "packaging", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "pyyaml", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "regex", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "safetensors", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "tokenizers", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "tqdm", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, - { name = "typer-slim", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/bd/7e/8a0c57d562015e5b16c97c1f0b8e0e92ead2c7c20513225dc12c2043ba9f/transformers-5.2.0.tar.gz", hash = "sha256:0088b8b46ccc9eff1a1dca72b5d618a5ee3b1befc3e418c9512b35dea9f9a650", size = 8618176, upload-time = "2026-02-16T18:54:02.867Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4e/93/79754b0ca486e556c2b95d4f5afc66aaf4b260694f3d6e1b51da2d036691/transformers-5.2.0-py3-none-any.whl", hash = "sha256:9ecaf243dc45bee11a7d93f8caf03746accc0cb069181bbf4ad8566c53e854b4", size = 10403304, upload-time = "2026-02-16T18:53:59.699Z" }, -] - [[package]] name = "transformers" version = "5.6.2" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'linux'", - "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'linux'", - "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'win32'", - "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'win32'", - "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'linux'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'linux'", - "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'linux'", - "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'linux'", - "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'win32'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32'", - "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'win32'", - "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", - "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'linux' and sys_platform != 'win32'", -] -dependencies = [ - { name = "huggingface-hub", marker = "extra == 'extra-12-openpipe-art-megatron'" }, - { name = "numpy", marker = "extra == 'extra-12-openpipe-art-megatron'" }, - { name = "packaging", marker = "extra == 'extra-12-openpipe-art-megatron'" }, - { name = "pyyaml", marker = "extra == 'extra-12-openpipe-art-megatron'" }, - { name = "regex", marker = "extra == 'extra-12-openpipe-art-megatron'" }, - { name = "safetensors", marker = "extra == 'extra-12-openpipe-art-megatron'" }, - { name = "tokenizers", marker = "extra == 'extra-12-openpipe-art-megatron'" }, - { name = "tqdm", marker = "extra == 'extra-12-openpipe-art-megatron'" }, - { name = "typer", marker = "extra == 'extra-12-openpipe-art-megatron'" }, +dependencies = [ + { name = "huggingface-hub" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "regex" }, + { name = "safetensors" }, + { name = "tokenizers" }, + { name = "tqdm" }, + { name = "typer" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a4/e9/c6c80a07690142a7d05444271f47b9f3c8aac7dea01d52e1137ee480ad78/transformers-5.6.2.tar.gz", hash = "sha256:e657134c3e5a6bc00a3c35f4e2674bb51adfcd89898495b788a18552bac2b91a", size = 8311867, upload-time = "2026-04-23T18:33:29.332Z" } wheels = [ @@ -8250,7 +8193,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "accelerate" }, { name = "datasets" }, - { name = "transformers", version = "5.2.0", source = { registry = "https://pypi.org/simple" } }, + { name = "transformers" }, ] sdist = { url = "https://files.pythonhosted.org/packages/60/11/95cf1210df9f241b7b1084abe1032e322374f667c4587c09af8d14a1d76f/trl-0.20.0.tar.gz", hash = "sha256:3f949b009b79dc609cd8f5469d67209ab8f71c5cb4d8d979f7b568ef054922fa", size = 461791, upload-time = "2025-07-29T04:10:06.305Z" } wheels = [ @@ -8317,18 +8260,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cd/cc/c6c5dea061e2740355bfeef22ac6a41751bd2f3903e83921295569bdcec4/typer-0.26.3-py3-none-any.whl", hash = "sha256:e70549ec5a403ca8a0bf0802ddd9f3c6ff7a14ccbb859b01b697baa943636f33", size = 122338, upload-time = "2026-05-28T20:30:49.816Z" }, ] -[[package]] -name = "typer-slim" -version = "0.24.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typer", marker = "extra == 'extra-12-openpipe-art-backend' or extra != 'extra-12-openpipe-art-megatron' or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a7/a7/e6aecc4b4eb59598829a3b5076a93aff291b4fdaa2ded25efc4e1f4d219c/typer_slim-0.24.0.tar.gz", hash = "sha256:f0ed36127183f52ae6ced2ecb2521789995992c521a46083bfcdbb652d22ad34", size = 4776, upload-time = "2026-02-16T22:08:51.2Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a7/24/5480c20380dfd18cf33d14784096dca45a24eae6102e91d49a718d3b6855/typer_slim-0.24.0-py3-none-any.whl", hash = "sha256:d5d7ee1ee2834d5020c7c616ed5e0d0f29b9a4b1dd283bdebae198ec09778d0e", size = 3394, upload-time = "2026-02-16T22:08:49.92Z" }, -] - [[package]] name = "types-paramiko" version = "4.0.0.20260518" @@ -8406,7 +8337,7 @@ dependencies = [ { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "torchvision" }, { name = "tqdm" }, - { name = "transformers", version = "5.2.0", source = { registry = "https://pypi.org/simple" } }, + { name = "transformers" }, { name = "triton", marker = "'linux' in sys_platform" }, { name = "triton-windows", marker = "(platform_machine == 'AMD64' and sys_platform == 'win32') or (platform_machine == 'x86_64' and sys_platform == 'win32')" }, { name = "trl" }, @@ -8444,7 +8375,7 @@ dependencies = [ { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(sys_platform == 'linux' and extra == 'extra-12-openpipe-art-backend') or (sys_platform == 'win32' and extra == 'extra-12-openpipe-art-backend') or (extra == 'extra-12-openpipe-art-backend' and extra == 'extra-12-openpipe-art-megatron') or (extra == 'extra-12-openpipe-art-megatron' and extra == 'extra-12-openpipe-art-tinker')" }, { name = "torchao" }, { name = "tqdm" }, - { name = "transformers", version = "5.2.0", source = { registry = "https://pypi.org/simple" } }, + { name = "transformers" }, { name = "triton", marker = "'linux' in sys_platform" }, { name = "trl" }, { name = "typing-extensions" },