diff --git a/flashmask/benchmarks/paddle_ops/__init__.py b/flashmask/benchmarks/paddle_ops/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/flashmask/benchmarks/paddle_ops/registry.py b/flashmask/benchmarks/paddle_ops/registry.py new file mode 100644 index 00000000000..a68dfd69f37 --- /dev/null +++ b/flashmask/benchmarks/paddle_ops/registry.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +import logging +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + +import paddle +import paddle.nn.functional as F + +logger = logging.getLogger(__name__) + + +def shape_BTHD(B, T, H, D, **kw): + return (B, T, H, D) + + +def shape_BTH(B, T, H, D, **kw): + return (B, T, H) + + +logsigmoid = F.log_sigmoid + + +def sigmoid_transform(t): + return t.sigmoid() + + +@dataclass +class TensorSpec: + shape_fn: Callable + requires_grad: bool = True + dtype: Any = 'default' + transform: Callable | None = None + + +@dataclass +class OpConfig: + name: str + import_path: str + inputs: dict[str, TensorSpec] + func_name: str | None = None + extra_kwargs: dict[str, Any] = field(default_factory=dict) + output_is_tuple: bool = True + skip_backward: bool = False + category: str = '' + + +_REGISTRY: dict[str, OpConfig] = {} + + +def register_op(config: OpConfig) -> None: + _REGISTRY[config.name] = config + + +SHAPE_CONFIGS = { + 'B1_T8192_H96_D128': {'B': 1, 'T': 8192, 'H': 96, 'D': 128}, + 'B2_T16384_H16_D128': {'B': 2, 'T': 16384, 'H': 16, 'D': 128}, + 'B4_T2048_H16_D128': {'B': 4, 'T': 2048, 'H': 16, 'D': 128}, + 'B4_T4096_H64_D128': {'B': 4, 'T': 4096, 'H': 64, 'D': 128}, + 'B8_T2048_H32_D256': {'B': 8, 'T': 2048, 'H': 32, 'D': 256}, + 'B8_T1024_H8_D64': {'B': 8, 'T': 1024, 'H': 8, 'D': 64}, +} + + +def get_op(name: str) -> OpConfig: + if name not in _REGISTRY: + raise KeyError(f"Op '{name}' not registered. 
Available: {sorted(_REGISTRY)}") + return _REGISTRY[name] + + +def list_ops() -> list[str]: + return sorted(_REGISTRY.keys()) + + +def _resolve_dtype(dtype): + if dtype == 'default': + return paddle.bfloat16 + if dtype == 'float32': + return paddle.float32 + if dtype == 'int64': + return paddle.int64 + return dtype + + +def _set_device(device: str | None): + if device is None: + return + current = paddle.get_device() + if current != device: + paddle.device.set_device(device) + + +def generate_inputs( + config: OpConfig, + B: int, + T: int, + H: int, + D: int, + dtype=paddle.bfloat16, + device: str | None = None, +) -> dict[str, paddle.Tensor]: + _set_device(device) + inputs: dict[str, paddle.Tensor] = {} + for param_name, spec in config.inputs.items(): + shape = spec.shape_fn(B, T, H, D) + tensor_dtype = dtype if spec.dtype == 'default' else _resolve_dtype(spec.dtype) + if tensor_dtype == paddle.int64: + tensor = paddle.randint(0, 10, shape=shape, dtype=tensor_dtype) + else: + tensor = paddle.randn(shape, dtype=tensor_dtype) + if spec.transform is not None: + tensor = spec.transform(tensor) + if spec.requires_grad and paddle.is_floating_point(tensor): + tensor.stop_gradient = False + inputs[param_name] = tensor + return inputs + + +_simple_qkv = { + 'q': TensorSpec(shape_BTHD), + 'k': TensorSpec(shape_BTHD), + 'v': TensorSpec(shape_BTHD), +} + +register_op(OpConfig( + name='chunk_gdn', + import_path='linear_attn.ops.gated_delta_rule', + func_name='chunk_gated_delta_rule', + inputs={ + **_simple_qkv, + 'g': TensorSpec(shape_BTH, transform=logsigmoid), + 'beta': TensorSpec(shape_BTH, transform=sigmoid_transform), + }, + extra_kwargs={'use_qk_l2norm_in_kernel': True}, + category='gate_beta', +)) + +register_op(OpConfig( + name='chunk_kda', + import_path='linear_attn.ops.kda', + inputs={ + **_simple_qkv, + 'g': TensorSpec(shape_BTHD, transform=logsigmoid), + 'beta': TensorSpec(shape_BTH, transform=sigmoid_transform), + }, + extra_kwargs={'use_qk_l2norm_in_kernel': True, 'safe_gate': True, 'lower_bound': -5}, + category='gate_beta', +)) + +register_op(OpConfig( + name='recurrent_gdn', + import_path='linear_attn.ops.gated_delta_rule', + func_name='fused_recurrent_gated_delta_rule', + inputs={ + **_simple_qkv, + 'g': TensorSpec(shape_BTH, transform=logsigmoid), + 'beta': TensorSpec(shape_BTH, transform=sigmoid_transform), + }, + extra_kwargs={'use_qk_l2norm_in_kernel': True}, + skip_backward=True, + category='gate_beta', +)) + +register_op(OpConfig( + name='recurrent_kda', + import_path='linear_attn.ops.kda', + func_name='fused_recurrent_kda', + inputs={ + **_simple_qkv, + 'g': TensorSpec(shape_BTHD, transform=logsigmoid), + 'beta': TensorSpec(shape_BTH, transform=sigmoid_transform), + }, + extra_kwargs={'use_qk_l2norm_in_kernel': True, 'safe_gate': True, 'lower_bound': -5}, + skip_backward=True, + category='gate_beta', +)) diff --git a/flashmask/benchmarks/paddle_ops/run.py b/flashmask/benchmarks/paddle_ops/run.py new file mode 100644 index 00000000000..a39bcd1c3fa --- /dev/null +++ b/flashmask/benchmarks/paddle_ops/run.py @@ -0,0 +1,290 @@ +from __future__ import annotations + +import argparse +import importlib +import json +import logging +import os +import platform +import socket +import sys +from contextlib import contextmanager + +import paddle + +# Must call paddle.enable_compat(scope={"triton"}) BEFORE import triton. +# In a pure-Paddle env this registers the Paddle triton driver so that +# triton can discover it during initialization. 
In a mixed torch+paddle +# env this is also safe — both drivers are registered and the +# swap_driver_guard / activate_paddle_driver mechanism handles switching. +paddle.enable_compat(scope={"triton"}) + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from registry import SHAPE_CONFIGS, OpConfig, generate_inputs, get_op, list_ops # noqa: E402 + +logger = logging.getLogger(__name__) + + +@contextmanager +def _activate_paddle_driver(): + try: + from linear_attn.triton_utils import paddle_driver + from triton.runtime.driver import driver + except Exception: + yield + return + + if paddle_driver is None: + yield + return + + driver.set_active(paddle_driver) + try: + yield + finally: + driver.reset_active() + + +def _import_op(config: OpConfig): + mod = importlib.import_module(config.import_path) + attr = config.func_name or config.name + fn = getattr(mod, attr, None) + if fn is None: + raise ImportError( + f"Cannot find '{attr}' in module '{config.import_path}'. " + f"Available: {[x for x in dir(mod) if not x.startswith('_')]}" + ) + return fn + + +def _get_machine_info() -> dict: + info = { + 'hostname': socket.gethostname(), + 'platform': platform.platform(), + 'paddle_version': paddle.__version__, + } + try: + import triton + info['triton_version'] = triton.__version__ + except Exception: + info['triton_version'] = 'N/A' + + if paddle.device.is_compiled_with_cuda(): + info['gpu_name'] = paddle.device.cuda.get_device_name() + info['gpu_count'] = paddle.device.cuda.device_count() + else: + info['gpu_name'] = 'N/A' + info['gpu_count'] = 0 + return info + + +def _warmup_iters() -> int: + return max(1, int(os.environ.get('FLA_BENCH_OP_WARMUP_ITERS', '5'))) + + +def _do_bench_kw(): + warmup_ms = int(os.environ.get('FLA_BENCH_WARMUP_MS', '100')) + rep_ms = int(os.environ.get('FLA_BENCH_REP_MS', '500')) + return {'warmup': max(1, warmup_ms), 'rep': max(1, rep_ms)} + + +def _synchronize(): + if paddle.device.is_compiled_with_cuda(): + paddle.device.synchronize() + + +def _clear_gradients(inputs: dict[str, paddle.Tensor]): + for tensor in inputs.values(): + if isinstance(tensor, paddle.Tensor) and not tensor.stop_gradient: + tensor.clear_gradient(False) + + +def _backward(tensor: paddle.Tensor, grad: paddle.Tensor): + paddle.autograd.backward([tensor], [grad]) + + +def _warmup_autotune(fn, n: int | None = None): + if n is None: + n = _warmup_iters() + with _activate_paddle_driver(): + for _ in range(n): + fn() + _synchronize() + + +def benchmark_op( + op_name: str, + shapes: dict[str, dict[str, int]], + modes: list[str] | None = None, +) -> list[dict]: + import triton + + if modes is None: + modes = ['fwd', 'fwdbwd'] + + config = get_op(op_name) + op_fn = _import_op(config) + if config.skip_backward and 'fwdbwd' in modes: + modes = [mode for mode in modes if mode != 'fwdbwd'] + + dtype = paddle.bfloat16 + device = 'gpu' + + print(f"\n [{op_name}] Warming up {len(shapes)} shape(s)...") + failed_shapes = set() + for shape_name, shape_dict in shapes.items(): + B, T, H, D = shape_dict['B'], shape_dict['T'], shape_dict['H'], shape_dict['D'] + try: + inputs = generate_inputs(config, B, T, H, D, dtype=dtype, device=device) + with _activate_paddle_driver(): + out = op_fn(**inputs, **config.extra_kwargs) + out_tensor = out[0] if config.output_is_tuple else out + do = paddle.randn(out_tensor.shape, dtype=out_tensor.dtype) + + def _fwd_fn(inputs=inputs): + return op_fn(**inputs, **config.extra_kwargs) + + def _fwdbwd_fn(inputs=inputs, do=do): + _clear_gradients(inputs) + result = 
op_fn(**inputs, **config.extra_kwargs) + tensor = result[0] if config.output_is_tuple else result + _backward(tensor, do) + + warmup_fn = _fwdbwd_fn if 'fwdbwd' in modes else _fwd_fn + _warmup_autotune(warmup_fn) + except Exception as error: + logger.warning(f"Warmup failed for {op_name} @ {shape_name}: {error}") + failed_shapes.add(shape_name) + + valid_shapes = {name: cfg for name, cfg in shapes.items() if name not in failed_shapes} + print(f" [{op_name}] Warmup done.") + + results = [] + for shape_name, shape_dict in valid_shapes.items(): + B, T, H, D = shape_dict['B'], shape_dict['T'], shape_dict['H'], shape_dict['D'] + try: + inputs = generate_inputs(config, B, T, H, D, dtype=dtype, device=device) + with _activate_paddle_driver(): + out = op_fn(**inputs, **config.extra_kwargs) + out_tensor = out[0] if config.output_is_tuple else out + do = paddle.randn(out_tensor.shape, dtype=out_tensor.dtype) + except Exception as error: + logger.warning(f"Input generation failed for {op_name} @ {shape_name}: {error}") + continue + + for mode in modes: + if mode == 'fwd': + def fn(inputs=inputs): + return op_fn(**inputs, **config.extra_kwargs) + else: + def fn(inputs=inputs, do=do): + _clear_gradients(inputs) + result = op_fn(**inputs, **config.extra_kwargs) + tensor = result[0] if config.output_is_tuple else result + _backward(tensor, do) + + try: + with _activate_paddle_driver(): + ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8], **_do_bench_kw()) + except Exception as error: + logger.warning(f"Bench failed for {op_name} {mode} @ {shape_name}: {error}") + continue + + results.append({ + 'op': op_name, + 'mode': mode, + 'B': B, + 'T': T, + 'H': H, + 'D': D, + 'median_ms': ms[0], + 'p20_ms': ms[1], + 'p80_ms': ms[2], + }) + + return results + + +def print_results_table(results: list[dict], machine_info: dict | None = None): + if not results: + print("\n No results to display.") + return + + width = 92 + print(f"\n{'=' * width}") + if machine_info: + gpu = machine_info.get('gpu_name', 'N/A') + paddle_version = machine_info.get('paddle_version', 'N/A') + print(f" Machine: {gpu} | Paddle {paddle_version}") + print(f"{'=' * width}") + print(f" {'op':<18s} {'mode':<7s} {'B':>4s} {'T':>6s} {'H':>4s} {'D':>4s} {'median(ms)':>12s} {'p20(ms)':>12s} {'p80(ms)':>12s}") + print(f" {'-' * 18} {'-' * 7} {'-' * 4} {'-' * 6} {'-' * 4} {'-' * 4} {'-' * 12} {'-' * 12} {'-' * 12}") + for result in results: + print( + f" {result['op']:<18s} {result['mode']:<7s} {result['B']:>4d} {result['T']:>6d} " + f"{result['H']:>4d} {result['D']:>4d} {result['median_ms']:>12.3f} " + f"{result['p20_ms']:>12.3f} {result['p80_ms']:>12.3f}" + ) + print(f"{'=' * width}") + + +def main(argv: list[str] | None = None): + parser = argparse.ArgumentParser(description='Paddle benchmark runner for flash-linear-attention ops') + parser.add_argument('--op', nargs='+', default=None, help='Op name(s) to benchmark, or "all"') + parser.add_argument( + '--custom-shapes', + default=None, + help='JSON string to override default shapes, e.g. 
\'{"my": {"B":1,"T":2048,"H":16,"D":128}}\'', + ) + parser.add_argument( + '--modes', + nargs='+', + default=['fwd', 'fwdbwd'], + choices=['fwd', 'fwdbwd'], + help='Benchmark modes (default: fwd fwdbwd)', + ) + parser.add_argument('--json', dest='json_file', default=None, help='Output file path for JSON results') + parser.add_argument('--list', action='store_true', help='List all registered ops and exit') + args = parser.parse_args(argv) + + if args.list: + ops = list_ops() + print(f"Registered ops ({len(ops)}):") + for name in ops: + cfg = get_op(name) + print(f" {name:18s} [{cfg.category}] {cfg.import_path}") + return [] + + if args.op is None: + parser.error('--op is required unless --list') + + op_names = list_ops() if args.op == ['all'] else args.op + shape_configs = json.loads(args.custom_shapes) if args.custom_shapes else SHAPE_CONFIGS + + machine_info = _get_machine_info() + print(f"Machine: {machine_info.get('gpu_name', 'N/A')} | Paddle {machine_info.get('paddle_version', 'N/A')}") + print(f"Shapes: {len(shape_configs)} configs") + print(f"Ops: {op_names}") + + all_results = [] + for op_name in op_names: + try: + all_results.extend(benchmark_op(op_name, shape_configs, modes=args.modes)) + except Exception as error: + logger.error(f"Failed to benchmark {op_name}: {error}") + + mode_order = {'fwd': 0, 'fwdbwd': 1} + all_results.sort(key=lambda result: (mode_order.get(result['mode'], 9), result['B'], result['T'], result['H'], result['D'], result['op'])) + print_results_table(all_results, machine_info) + + if args.json_file: + output = {'machine_info': machine_info, 'results': all_results} + with open(args.json_file, 'w') as handle: + json.dump(output, handle, indent=2) + print(f"\nResults saved to {args.json_file}") + + return all_results + + +if __name__ == '__main__': + main() diff --git a/flashmask/flash_mask/__init__.py b/flashmask/flash_mask/__init__.py index 5a807427987..ef5ff1e30d1 100644 --- a/flashmask/flash_mask/__init__.py +++ b/flashmask/flash_mask/__init__.py @@ -72,3 +72,28 @@ if not _fa3_available and not _fa4_available: print("[WARNING] flash_mask: neither FA3 nor FA4 is available. " "Check your installation.") + +# ============================================================ +# Linear Attention: GDN and KDA operators +# ============================================================ +_linear_attn_available = False +try: + from .linear_attn import ( + chunk_gated_delta_rule, + chunk_gdn, + fused_recurrent_gated_delta_rule, + fused_recurrent_gdn, + chunk_kda, + fused_recurrent_kda, + ) + __all__ += [ + "chunk_gated_delta_rule", + "chunk_gdn", + "fused_recurrent_gated_delta_rule", + "fused_recurrent_gdn", + "chunk_kda", + "fused_recurrent_kda", + ] + _linear_attn_available = True +except ImportError: + pass # linear_attn dependencies not installed diff --git a/flashmask/flash_mask/linear_attn/README.md b/flashmask/flash_mask/linear_attn/README.md new file mode 100644 index 00000000000..e2c16b67ab3 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/README.md @@ -0,0 +1,129 @@ +# linear_attn + +Triton-based GDN (Gated Delta Networks) and KDA (Kimi Delta Attention) operators with chunk-wise and fused-recurrent execution modes. 
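+
+A minimal usage sketch, assuming the package is importable as `flash_mask` (see Environment Setup below). Shapes, dtypes, and keyword arguments mirror what the benchmark registry passes to these ops; treat it as illustrative rather than canonical:
+
+```python
+import paddle
+import paddle.nn.functional as F
+
+from flash_mask.linear_attn import chunk_gated_delta_rule
+
+paddle.device.set_device('gpu')
+B, T, H, D = 1, 1024, 4, 64
+q, k, v = (paddle.randn([B, T, H, D], dtype=paddle.bfloat16) for _ in range(3))
+# log-space gate and (0, 1) mixing coefficient, generated the same way as the benchmark inputs
+g = F.log_sigmoid(paddle.randn([B, T, H], dtype=paddle.bfloat16))
+beta = paddle.randn([B, T, H], dtype=paddle.bfloat16).sigmoid()
+
+out = chunk_gated_delta_rule(q=q, k=k, v=v, g=g, beta=beta, use_qk_l2norm_in_kernel=True)
+o = out[0] if isinstance(out, tuple) else out  # the benchmark harness treats the first element as the main output
+```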
+ +## Dependencies + +- PaddlePaddle-GPU (GPU required) +- triton +- pytest (for tests) +- einops (for GDN tests) + +## Environment Setup + +```bash +# Set PYTHONPATH so linear_attn is importable as a top-level package +export PYTHONPATH=/path/to/flash-attention/flashmask:$PYTHONPATH +``` + +## Running Tests + +Test files are located at `flashmask/tests/linear_attn/`: + +| File | Description | +|------|-------------| +| `test_gated_delta.py` | GDN operator correctness tests | +| `test_kda.py` | KDA operator correctness tests | + +Each test compares the Triton-optimized implementation against a naive Python reference, checking both forward output and backward gradients. + +```bash +cd /path/to/flash-attention/flashmask/tests/linear_attn + +# Run all tests +pytest test_gated_delta.py test_kda.py -v + +# Run GDN tests only +pytest test_gated_delta.py -v + +# Run KDA tests only +pytest test_kda.py -v + +# Run a single test function +pytest test_gated_delta.py::test_fused_recurrent -v +pytest test_kda.py::test_chunk -v + +# Filter by parametrized id +pytest test_gated_delta.py -k "test_fused_recurrent and B1-T63" -v +``` + +### Optional Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `SKIP_TEST_CHUNK_VARLEN=1` | unset | Skip varlen (variable-length sequence) tests | +| `FLA_BENCHMARK=1` | `0` | Disable driver probing overhead | + +```bash +SKIP_TEST_CHUNK_VARLEN=1 pytest test_gated_delta.py test_kda.py -v +``` + +## Running Benchmarks + +The benchmark framework is located at `flashmask/benchmarks/paddle_ops/` and supports 4 operators: + +| Operator | Description | Modes | +|----------|-------------|-------| +| `chunk_gdn` | GDN chunk-level | fwd / fwdbwd | +| `chunk_kda` | KDA chunk-level | fwd / fwdbwd | +| `recurrent_gdn` | GDN fused recurrent | fwd only | +| `recurrent_kda` | KDA fused recurrent | fwd only | + +```bash +cd /path/to/flash-attention/flashmask + +# List registered operators +python -m benchmarks.paddle_ops.run --list + +# Run all benchmarks +python -m benchmarks.paddle_ops.run --op all + +# Run specific operators +python -m benchmarks.paddle_ops.run --op chunk_gdn +python -m benchmarks.paddle_ops.run --op chunk_kda recurrent_kda + +# Forward only +python -m benchmarks.paddle_ops.run --op chunk_gdn --modes fwd + +# Custom shapes +python -m benchmarks.paddle_ops.run --op chunk_gdn \ + --custom-shapes '{"smoke":{"B":1,"T":64,"H":2,"D":32}}' + +# Save results as JSON +python -m benchmarks.paddle_ops.run --op all --json results.json +``` + +### Default Shape Configs + +| Config Name | B | T | H | D | +|-------------|---|---|---|---| +| B1_T8192_H96_D128 | 1 | 8192 | 96 | 128 | +| B2_T16384_H16_D128 | 2 | 16384 | 16 | 128 | +| B4_T2048_H16_D128 | 4 | 2048 | 16 | 128 | +| B4_T4096_H64_D128 | 4 | 4096 | 64 | 128 | +| B8_T2048_H32_D256 | 8 | 2048 | 32 | 256 | +| B8_T1024_H8_D64 | 8 | 1024 | 8 | 64 | + +### Benchmark Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `FLA_BENCH_OP_WARMUP_ITERS` | `5` | Number of warmup iterations | +| `FLA_BENCH_WARMUP_MS` | `100` | `do_bench` warmup time (ms) | +| `FLA_BENCH_REP_MS` | `500` | `do_bench` repeat measurement time (ms) | + +## Known Limitations + +- Context Parallel (CP) is NOT supported. +- `fused_recurrent_gdn` / `fused_recurrent_kda` are forward-only. Use `chunk_gdn` / `chunk_kda` for training workloads that require gradients. 
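+
+To make the forward-only limitation concrete, here is a minimal sketch of the two call patterns. Keyword arguments mirror what the benchmark harness passes, and the tuple handling assumes the first returned element is the attention output, as the benchmarks do:
+
+```python
+import paddle
+import paddle.nn.functional as F
+
+from flash_mask.linear_attn import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
+
+B, T, H, D = 1, 256, 2, 64
+q, k, v = (paddle.randn([B, T, H, D], dtype=paddle.bfloat16) for _ in range(3))
+g = F.log_sigmoid(paddle.randn([B, T, H], dtype=paddle.bfloat16))
+beta = paddle.randn([B, T, H], dtype=paddle.bfloat16).sigmoid()
+
+# Training-style call: gradients flow through the chunk kernels.
+for t in (q, k, v, g, beta):
+    t.stop_gradient = False
+out = chunk_gated_delta_rule(q=q, k=k, v=v, g=g, beta=beta, use_qk_l2norm_in_kernel=True)
+o = out[0] if isinstance(out, tuple) else out
+o.sum().backward()
+
+# Inference-style call: the fused recurrent kernels are forward-only;
+# calling backward through them raises NotImplementedError.
+with paddle.no_grad():
+    out = fused_recurrent_gated_delta_rule(q=q, k=k, v=v, g=g, beta=beta, use_qk_l2norm_in_kernel=True)
+```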
+ +## Known Issue: GDN Backward Precision on Hopper GPUs with Triton >= 3.4.0 + +The upstream fla-org/flash-linear-attention project has identified a backward precision issue in the gated `chunk_bwd_dqkwg` kernel when running on Hopper-class GPUs (H20, H100, GB200, etc.) with Triton >= 3.4.0 ([upstream PR #827](https://github.com/fla-org/flash-linear-attention/pull/827)). The upstream fix introduces a TileLang-based kernel as an alternative backend. + +**Current status in this fork:** +- On NVIDIA H800 (Hopper) with the current Triton version, this issue has **not** been observed in practice. +- If you plan to deploy on other Hopper GPUs (H20, GB200, etc.), or upgrade Triton to >= 3.4.0, you may encounter this backward precision regression. +- The TileLang backend has **not** been integrated into this Paddle port yet. + +**Action needed:** When targeting Hopper GPUs other than H800 or upgrading Triton, consider integrating the TileLang backend from the upstream fix (`pip install tilelang` + dispatch logic in `fla/ops/common/chunk_o.py`). diff --git a/flashmask/flash_mask/linear_attn/__init__.py b/flashmask/flash_mask/linear_attn/__init__.py new file mode 100644 index 00000000000..5e64ec1cafc --- /dev/null +++ b/flashmask/flash_mask/linear_attn/__init__.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# flash-linear-attention Paddle migration entry point +# +# Migrated from fla-org/flash-linear-attention (MIT License) for PaddlePaddle. +# Provides Triton-based GDN (Gated Delta Networks) and KDA (Kimi Delta +# Attention) operators with chunk-wise and fused-recurrent execution modes. +# +# Known limitations (Phase 1): +# - Context Parallel (CP) is NOT supported. The cp_context parameter is +# accepted for API compatibility but has no effect; passing a non-None +# value may raise NotImplementedError in backward paths. +# - fused_recurrent_gdn / fused_recurrent_kda are FORWARD-ONLY. Calling +# backward through them will raise NotImplementedError. Use the +# chunk-based variants (chunk_gdn / chunk_kda) for training workloads +# that require gradient computation. + +import paddle +from flash_mask.linear_attn.triton_utils import _is_package_installed + +# No torch environment: enable triton scope compat globally (zero runtime overhead) +if not _is_package_installed("torch"): + paddle.enable_compat(scope={"triton"}) + +from flash_mask.linear_attn.ops.gated_delta_rule import ( + chunk_gated_delta_rule, + chunk_gdn, + fused_recurrent_gated_delta_rule, + fused_recurrent_gdn, +) +from flash_mask.linear_attn.ops.kda import ( + chunk_kda, + fused_recurrent_kda, +) + +__all__ = [ + 'chunk_gated_delta_rule', + 'chunk_gdn', + 'fused_recurrent_gated_delta_rule', + 'fused_recurrent_gdn', + 'chunk_kda', + 'fused_recurrent_kda', +] diff --git a/flashmask/flash_mask/linear_attn/modules/__init__.py b/flashmask/flash_mask/linear_attn/modules/__init__.py new file mode 100644 index 00000000000..b664bd46500 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/modules/__init__.py @@ -0,0 +1 @@ +from .l2norm import L2Norm, l2norm, l2norm_fwd, l2norm_bwd, l2_norm, L2NormFunction diff --git a/flashmask/flash_mask/linear_attn/modules/l2norm.py b/flashmask/flash_mask/linear_attn/modules/l2norm.py new file mode 100644 index 00000000000..d17549d35f7 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/modules/l2norm.py @@ -0,0 +1,288 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +import paddle +import paddle.nn as nn +import triton +import triton.language as tl + +from flash_mask.linear_attn.utils import IS_AMD, autotune_cache_kwargs, input_guard +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + +BT_LIST = [8, 16, 32, 64, 128] +NUM_WARPS_AUTOTUNE = [1, 2, 4, 8, 16] if IS_AMD else [1, 2, 4, 8, 16, 32] + + +@enable_compat_on_triton_kernel +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in NUM_WARPS_AUTOTUNE], + key=["D"], + **autotune_cache_kwargs, +) +@triton.jit +def l2norm_fwd_kernel1( + x, + y, + rstd, + eps, + D, + BD: tl.constexpr, +): + i_t = tl.program_id(0) + x += i_t * D + y += i_t * D + # Compute mean and variance + cols = tl.arange(0, BD) + mask = cols < D + + b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) + b_rstd = 1 / tl.sqrt(tl.sum(b_x * b_x) + eps) + b_y = b_x * b_rstd + tl.store(y + cols, b_y, mask=mask) + tl.store(rstd + i_t, b_rstd) + + +@enable_compat_on_triton_kernel +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in NUM_WARPS_AUTOTUNE], + key=["D"], + **autotune_cache_kwargs, +) +@triton.jit +def l2norm_bwd_kernel1( + y, + rstd, + dy, + dx, + eps, + D, + BD: tl.constexpr, +): + i_t = tl.program_id(0) + y += i_t * D + dx += i_t * D + dy += i_t * D + + cols = tl.arange(0, BD) + mask = cols < D + b_y = tl.load(y + cols, mask=mask, other=0.0).to(tl.float32) + b_rstd = tl.load(rstd + i_t).to(tl.float32) + b_dy = tl.load(dy + cols, mask=mask, other=0.0).to(tl.float32) + b_dx = b_dy * b_rstd - tl.sum(b_dy * b_y) * b_y * b_rstd + tl.store(dx + cols, b_dx, mask=mask) + + +@enable_compat_on_triton_kernel +@triton.autotune( + configs=[triton.Config({"BT": BT}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST], + key=["D", "NB"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def l2norm_fwd_kernel( + x, + y, + rstd, + eps, + T, + D: tl.constexpr, + BD: tl.constexpr, + NB: tl.constexpr, + BT: tl.constexpr, +): + i_t = tl.program_id(0) + p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,)) + + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + b_rstd = 1 / tl.sqrt(tl.sum(b_x * b_x, 1) + eps) + b_y = b_x * b_rstd[:, None] + + tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_rstd, b_rstd.to(p_rstd.dtype.element_ty), boundary_check=(0,)) + + +@enable_compat_on_triton_kernel +@triton.autotune( + configs=[triton.Config({"BT": BT}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST], + key=["D", "NB"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def l2norm_bwd_kernel( + y, + rstd, + dy, + dx, + eps, + T, + D: tl.constexpr, + BD: tl.constexpr, + NB: tl.constexpr, + BT: tl.constexpr, +): + i_t = tl.program_id(0) + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_dy = tl.make_block_ptr(dy, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + p_dx = tl.make_block_ptr(dx, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + + b_y = tl.load(p_y, boundary_check=(0, 1)).to(tl.float32) + b_rstd = tl.load(p_rstd, 
boundary_check=(0,)).to(tl.float32) + b_dy = tl.load(p_dy, boundary_check=(0, 1)).to(tl.float32) + b_dx = b_dy * b_rstd[:, None] - tl.sum(b_dy * b_y, 1)[:, None] * b_y * b_rstd[:, None] + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1)) + + +def l2norm_fwd( + x: paddle.Tensor, + eps: float = 1e-6, + output_dtype=None, +): + x_shape_og = x.shape + x = x.reshape([-1, x.shape[-1]]) + # allocate output + if output_dtype is None: + y = paddle.empty_like(x) + else: + y = paddle.empty_like(x).cast(output_dtype) + assert y.strides[-1] == 1 + T, D = x.shape[0], x.shape[-1] + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer doesn't support feature dim >= 64KB.") + + rstd = paddle.empty([T], dtype=paddle.float32) + if D <= 512: + # NOTE(tylerr): Avoid excessive recompilation and autotuning by tolerating a larger range + # of T before recompiling the kernel. + NB = triton.cdiv(T, 2048 * 32) + + def grid(meta): + return (triton.cdiv(T, meta["BT"]),) + + l2norm_fwd_kernel[grid]( + x=x, + y=y, + rstd=rstd, + eps=eps, + T=T, + D=D, + BD=BD, + NB=NB, + ) + else: + l2norm_fwd_kernel1[(T,)]( + x=x, + y=y, + rstd=rstd, + eps=eps, + D=D, + BD=BD, + ) + return y.reshape(x_shape_og), rstd.reshape(x_shape_og[:-1]) + + +def l2norm_bwd( + y: paddle.Tensor, + rstd: paddle.Tensor, + dy: paddle.Tensor, + eps: float = 1e-6, +): + y_shape_og = y.shape + y = y.reshape([-1, dy.shape[-1]]) + dy = dy.reshape([-1, dy.shape[-1]]) + assert dy.shape == y.shape + # allocate output + dx = paddle.empty_like(y) + T, D = y.shape[0], y.shape[-1] + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // y.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + if D <= 512: + NB = triton.cdiv(T, 2048 * 32) + + def grid(meta): + return (triton.cdiv(T, meta["BT"]),) + + l2norm_bwd_kernel[grid]( + y=y, + rstd=rstd, + dy=dy, + dx=dx, + eps=eps, + T=T, + D=D, + BD=BD, + NB=NB, + ) + else: + l2norm_bwd_kernel1[(T,)]( + y=y, + rstd=rstd, + dy=dy, + dx=dx, + eps=eps, + D=D, + BD=BD, + ) + + return dx.reshape(y_shape_og) + + +class L2NormFunction(paddle.autograd.PyLayer): + @staticmethod + @input_guard + def forward( + ctx, + x, + eps=1e-6, + output_dtype=None, + ): + y, rstd = l2norm_fwd(x, eps, output_dtype) + ctx.eps = eps + ctx.x_dtype = x.dtype + ctx.save_for_backward(y, rstd) + return y + + @staticmethod + @input_guard + def backward(ctx, dy): + y, rstd = ctx.saved_tensor() + dx = l2norm_bwd(y, rstd, dy, ctx.eps) + return dx + + +def l2norm( + x: paddle.Tensor, + eps: float = 1e-6, + output_dtype=None, +) -> paddle.Tensor: + return L2NormFunction.apply(x, eps, output_dtype) + + +l2_norm = l2norm + + +class L2Norm(nn.Layer): + def __init__( + self, + eps: float = 1e-6, + output_dtype=None, + ): + super().__init__() + self.eps = eps + self.output_dtype = output_dtype + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + return l2norm(x, self.eps, self.output_dtype) diff --git a/flashmask/flash_mask/linear_attn/ops/__init__.py b/flashmask/flash_mask/linear_attn/ops/__init__.py new file mode 100644 index 00000000000..40a96afc6ff --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/flashmask/flash_mask/linear_attn/ops/common/__init__.py 
b/flashmask/flash_mask/linear_attn/ops/common/__init__.py new file mode 100644 index 00000000000..3d7af8b0fdf --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/common/__init__.py @@ -0,0 +1 @@ +# linear_attn/ops/common diff --git a/flashmask/flash_mask/linear_attn/ops/common/chunk_delta_h.py b/flashmask/flash_mask/linear_attn/ops/common/chunk_delta_h.py new file mode 100644 index 00000000000..b94f414c3f8 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/common/chunk_delta_h.py @@ -0,0 +1,785 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +import paddle +import triton +import triton.language as tl + +from flash_mask.linear_attn.ops.utils import prepare_chunk_indices, prepare_chunk_offsets +from flash_mask.linear_attn.ops.utils.op import exp, exp2 +from flash_mask.linear_attn.utils import IS_NVIDIA_HOPPER, USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + +NUM_WARPS = [2, 4] if IS_NVIDIA_HOPPER else [2, 4, 8, 16] + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'USE_GK': lambda args: args['gk'] is not None, + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'SAVE_NEW_VALUE': lambda args: args['v_new'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in ([2, 3, 4] if check_shared_mem('ampere') else [2, 1]) + for BV in ([32, 64] if check_shared_mem('ada') else [32]) + ], + key=['H', 'HV', 'K', 'V', 'BT', 'USE_EXP2', 'TRANSPOSE_STATE'], + use_cuda_graph=USE_CUDA_GRAPH, + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + w, + v_new, + g, + gk, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + USE_EXP2: tl.constexpr, + TRANSPOSE_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // HV, i_nh % HV + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + if TRANSPOSE_STATE: + b_h1 = tl.zeros([BV, 64], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([BV, 64], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([BV, 64], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([BV, 64], dtype=tl.float32) + else: + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + h += (boh * HV + i_h).to(tl.int64) * K*V + v += (bos * HV + 
i_h).to(tl.int64) * V + k += (bos * H + i_h // (HV // H)).to(tl.int64) * K + w += (bos * HV + i_h).to(tl.int64) * K + if SAVE_NEW_VALUE: + v_new += (bos * HV + i_h).to(tl.int64) * V + + if USE_INITIAL_STATE: + h0 = h0 + i_nh * K*V + if STORE_FINAL_STATE: + ht = ht + i_nh * K*V + + # load initial state + if USE_INITIAL_STATE: + if TRANSPOSE_STATE: + p_h0_1 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)) + else: + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + if TRANSPOSE_STATE: + p_h0_2 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)) + else: + p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + if TRANSPOSE_STATE: + p_h0_3 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)) + else: + p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: + if TRANSPOSE_STATE: + p_h0_4 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)) + else: + p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + # main recurrence + for i_t in range(NT): + i_t_int64 = i_t.to(tl.int64) + if TRANSPOSE_STATE: + p_h1 = tl.make_block_ptr(h + i_t_int64 * HV*K*V, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)) + else: + p_h1 = tl.make_block_ptr(h + i_t_int64 * HV*K*V, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + if TRANSPOSE_STATE: + p_h2 = tl.make_block_ptr(h + i_t_int64 * HV*K*V, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)) + else: + p_h2 = tl.make_block_ptr(h + i_t_int64 * HV*K*V, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + if TRANSPOSE_STATE: + p_h3 = tl.make_block_ptr(h + i_t_int64 * HV*K*V, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)) + else: + p_h3 = tl.make_block_ptr(h + i_t_int64 * HV*K*V, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + if TRANSPOSE_STATE: + p_h4 = tl.make_block_ptr(h + i_t_int64 * HV*K*V, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)) + else: + p_h4 = tl.make_block_ptr(h + i_t_int64 * HV*K*V, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) + + p_w = tl.make_block_ptr(w, (T, K), (HV*K, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + if TRANSPOSE_STATE: + b_v = tl.dot(b_w, tl.trans(b_h1).to(b_w.dtype)) + else: + b_v = tl.dot(b_w, b_h1.to(b_w.dtype)) + if K > 64: + p_w = tl.make_block_ptr(w, (T, K), (HV*K, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + if TRANSPOSE_STATE: + b_v += tl.dot(b_w, tl.trans(b_h2).to(b_w.dtype)) + else: + b_v += tl.dot(b_w, b_h2.to(b_w.dtype)) + if K > 128: + p_w = tl.make_block_ptr(w, (T, K), (HV*K, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + if TRANSPOSE_STATE: + b_v += tl.dot(b_w, tl.trans(b_h3).to(b_w.dtype)) + else: + b_v += tl.dot(b_w, b_h3.to(b_w.dtype)) + if K > 192: + p_w = 
tl.make_block_ptr(w, (T, K), (HV*K, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + if TRANSPOSE_STATE: + b_v += tl.dot(b_w, tl.trans(b_h4).to(b_w.dtype)) + else: + b_v += tl.dot(b_w, b_h4.to(b_w.dtype)) + p_v = tl.make_block_ptr(v, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v + + if SAVE_NEW_VALUE: + p_v = tl.make_block_ptr(v_new, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + if USE_G: + m_t = (i_t * BT + tl.arange(0, BT)) < T + b_g_last = tl.load(g + (bos * HV + last_idx * HV + i_h).to(tl.int64)).to(tl.float32) + p_g = tl.make_block_ptr(g + (bos * HV + i_h).to(tl.int64), (T,), (HV,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32) + if USE_EXP2: + b_v = b_v * tl.where(m_t, exp2(b_g_last - b_g), 0)[:, None] + b_g_last = exp2(b_g_last) + else: + b_v = b_v * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] + b_g_last = exp(b_g_last) + b_h1 *= b_g_last + if K > 64: + b_h2 *= b_g_last + if K > 128: + b_h3 *= b_g_last + if K > 192: + b_h4 *= b_g_last + + if USE_GK: + o_k1 = tl.arange(0, 64) + b_gk_last1 = tl.load(gk + (bos + last_idx) * HV*K + i_h * K + o_k1, mask=(o_k1 < K), other=0.).to(tl.float32) + if TRANSPOSE_STATE: + if USE_EXP2: + b_h1 *= exp2(b_gk_last1)[None, :] + else: + b_h1 *= exp(b_gk_last1)[None, :] + else: + if USE_EXP2: + b_h1 *= exp2(b_gk_last1)[:, None] + else: + b_h1 *= exp(b_gk_last1)[:, None] + if K > 64: + o_k2 = 64 + o_k1 + b_gk_last2 = tl.load(gk + (bos + last_idx) * HV*K + i_h * K + o_k2, mask=(o_k2 < K), other=0.).to(tl.float32) + if TRANSPOSE_STATE: + if USE_EXP2: + b_h2 *= exp2(b_gk_last2)[None, :] + else: + b_h2 *= exp(b_gk_last2)[None, :] + else: + if USE_EXP2: + b_h2 *= exp2(b_gk_last2)[:, None] + else: + b_h2 *= exp(b_gk_last2)[:, None] + if K > 128: + o_k3 = 128 + o_k1 + b_gk_last3 = tl.load(gk + (bos + last_idx) * HV*K + i_h * K + o_k3, mask=(o_k3 < K), other=0.).to(tl.float32) + if TRANSPOSE_STATE: + if USE_EXP2: + b_h3 *= exp2(b_gk_last3)[None, :] + else: + b_h3 *= exp(b_gk_last3)[None, :] + else: + if USE_EXP2: + b_h3 *= exp2(b_gk_last3)[:, None] + else: + b_h3 *= exp(b_gk_last3)[:, None] + if K > 192: + o_k4 = 192 + o_k1 + b_gk_last4 = tl.load(gk + (bos + last_idx) * HV*K + i_h * K + o_k4, mask=(o_k4 < K), other=0.).to(tl.float32) + if TRANSPOSE_STATE: + if USE_EXP2: + b_h4 *= exp2(b_gk_last4)[None, :] + else: + b_h4 *= exp(b_gk_last4)[None, :] + else: + if USE_EXP2: + b_h4 *= exp2(b_gk_last4)[:, None] + else: + b_h4 *= exp(b_gk_last4)[:, None] + + b_v = b_v.to(k.dtype.element_ty) + + p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if TRANSPOSE_STATE: + b_h1 += tl.trans(tl.dot(b_k, b_v)) + else: + b_h1 += tl.dot(b_k, b_v) + if K > 64: + p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (64, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if TRANSPOSE_STATE: + b_h2 += tl.trans(tl.dot(b_k, b_v)) + else: + b_h2 += tl.dot(b_k, b_v) + if K > 128: + p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (128, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if TRANSPOSE_STATE: + b_h3 += tl.trans(tl.dot(b_k, b_v)) + else: + b_h3 += tl.dot(b_k, b_v) + if K > 192: + p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (192, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if 
TRANSPOSE_STATE: + b_h4 += tl.trans(tl.dot(b_k, b_v)) + else: + b_h4 += tl.dot(b_k, b_v) + + if STORE_FINAL_STATE: + if TRANSPOSE_STATE: + p_ht = tl.make_block_ptr(ht, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)) + else: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + if TRANSPOSE_STATE: + p_ht = tl.make_block_ptr(ht, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)) + else: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + if TRANSPOSE_STATE: + p_ht = tl.make_block_ptr(ht, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)) + else: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + if TRANSPOSE_STATE: + p_ht = tl.make_block_ptr(ht, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)) + else: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'USE_GK': lambda args: args['gk'] is not None, + 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in ([2, 3, 4] if check_shared_mem('ampere') else [1]) + for BV in ([32, 64] if check_shared_mem('ada') else [32]) + ], + key=['H', 'HV', 'K', 'V', 'BT', 'BV', 'USE_G', 'USE_EXP2', 'TRANSPOSE_STATE'], + use_cuda_graph=USE_CUDA_GRAPH, + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( + q, + k, + w, + g, + gk, + dht, + dh0, + do, + dh, + dv, + dv2, + cu_seqlens, + chunk_offsets, + scale, + T, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_EXP2: tl.constexpr, + TRANSPOSE_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // HV, i_nh % HV + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + if TRANSPOSE_STATE: + b_dh1 = tl.zeros([BV, 64], dtype=tl.float32) + if K > 64: + b_dh2 = tl.zeros([BV, 64], dtype=tl.float32) + if K > 128: + b_dh3 = tl.zeros([BV, 64], dtype=tl.float32) + if K > 192: + b_dh4 = tl.zeros([BV, 64], dtype=tl.float32) + else: + b_dh1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_dh2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_dh3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_dh4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + q += (bos * H + i_h // (HV // H)).to(tl.int64) * K + k += (bos * H + i_h // (HV // H)).to(tl.int64) * K + w += (bos * HV + i_h).to(tl.int64) * K + do += (bos * HV + 
i_h).to(tl.int64) * V + dv += (bos * HV + i_h).to(tl.int64) * V + dv2 += (bos * HV + i_h).to(tl.int64) * V + dh += (boh * HV + i_h).to(tl.int64) * K*V + if USE_GK: + gk += (bos * HV + i_h).to(tl.int64) * K + + if USE_INITIAL_STATE: + dh0 += i_nh * K*V + if USE_FINAL_STATE_GRADIENT: + dht += i_nh * K*V + + if USE_FINAL_STATE_GRADIENT: + if TRANSPOSE_STATE: + p_dht1 = tl.make_block_ptr(dht, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)) + else: + p_dht1 = tl.make_block_ptr(dht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_dh1 += tl.load(p_dht1, boundary_check=(0, 1)) + if K > 64: + if TRANSPOSE_STATE: + p_dht2 = tl.make_block_ptr(dht, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)) + else: + p_dht2 = tl.make_block_ptr(dht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + b_dh2 += tl.load(p_dht2, boundary_check=(0, 1)) + if K > 128: + if TRANSPOSE_STATE: + p_dht3 = tl.make_block_ptr(dht, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)) + else: + p_dht3 = tl.make_block_ptr(dht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_dh3 += tl.load(p_dht3, boundary_check=(0, 1)) + if K > 192: + if TRANSPOSE_STATE: + p_dht4 = tl.make_block_ptr(dht, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)) + else: + p_dht4 = tl.make_block_ptr(dht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_dh4 += tl.load(p_dht4, boundary_check=(0, 1)) + + for i_t in range(NT - 1, -1, -1): + i_t_int64 = i_t.to(tl.int64) + if TRANSPOSE_STATE: + p_dh1 = tl.make_block_ptr(dh + i_t_int64*HV*K*V, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)) + else: + p_dh1 = tl.make_block_ptr(dh + i_t_int64*HV*K*V, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh1, b_dh1.to(p_dh1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + if TRANSPOSE_STATE: + p_dh2 = tl.make_block_ptr(dh + i_t_int64*HV*K*V, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)) + else: + p_dh2 = tl.make_block_ptr(dh + i_t_int64*HV*K*V, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh2, b_dh2.to(p_dh2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + if TRANSPOSE_STATE: + p_dh3 = tl.make_block_ptr(dh + i_t_int64*HV*K*V, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)) + else: + p_dh3 = tl.make_block_ptr(dh + i_t_int64*HV*K*V, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh3, b_dh3.to(p_dh3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + if TRANSPOSE_STATE: + p_dh4 = tl.make_block_ptr(dh + i_t_int64*HV*K*V, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)) + else: + p_dh4 = tl.make_block_ptr(dh + i_t_int64*HV*K*V, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh4, b_dh4.to(p_dh4.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + if USE_G: + bg_last = tl.load(g + (bos + last_idx) * HV + i_h).to(tl.float32) + p_g = tl.make_block_ptr(g + bos * HV + i_h, (T,), (HV,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32) + if USE_EXP2: + bg_last_exp = exp2(bg_last) + b_g_exp = exp2(b_g) + else: + bg_last_exp = exp(bg_last) + b_g_exp = exp(b_g) + + p_dv = tl.make_block_ptr(dv, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv2 = tl.make_block_ptr(dv2, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # Update dv + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_k = tl.load(p_k, 
boundary_check=(0, 1)) + if USE_GK: + o_k1 = tl.arange(0, 64) + b_gk_last1 = tl.load(gk + last_idx * HV*K + o_k1, mask=(o_k1 < K), other=0.).to(tl.float32) + if TRANSPOSE_STATE: + b_dv = tl.dot(b_k, tl.trans(b_dh1).to(b_k.dtype)) + else: + b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype)) + + if K > 64: + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k2 = 64 + o_k1 + b_gk_last2 = tl.load(gk + last_idx * HV*K + o_k2, mask=(o_k2 < K), other=0.).to(tl.float32) + if TRANSPOSE_STATE: + b_dv += tl.dot(b_k, tl.trans(b_dh2).to(b_k.dtype)) + else: + b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype)) + + if K > 128: + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k3 = 128 + o_k1 + b_gk_last3 = tl.load(gk + last_idx * HV*K + o_k3, mask=(o_k3 < K), other=0.).to(tl.float32) + if TRANSPOSE_STATE: + b_dv += tl.dot(b_k, tl.trans(b_dh3).to(b_k.dtype)) + else: + b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype)) + + if K > 192: + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k4 = 192 + o_k1 + b_gk_last4 = tl.load(gk + last_idx * HV*K + o_k4, mask=(o_k4 < K), other=0.).to(tl.float32) + if TRANSPOSE_STATE: + b_dv += tl.dot(b_k, tl.trans(b_dh4).to(b_k.dtype)) + else: + b_dv += tl.dot(b_k, b_dh4.to(b_k.dtype)) + + if USE_G: + m_t = (i_t * BT + tl.arange(0, BT)) < T + if USE_EXP2: + b_dv *= tl.where(m_t, exp2(bg_last - b_g), 0)[:, None] + else: + b_dv *= tl.where(m_t, exp(bg_last - b_g), 0)[:, None] + b_dv += tl.load(p_dv, boundary_check=(0, 1)) + + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # Update dh + p_w = tl.make_block_ptr(w, (K, T), (1, HV*K), (0, i_t * BT), (64, BT), (0, 1)) + p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (0, i_t * BT), (64, BT), (0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + if USE_G: + b_dh1 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + if TRANSPOSE_STATE: + if USE_EXP2: + b_dh1 *= exp2(b_gk_last1)[None, :] + else: + b_dh1 *= exp(b_gk_last1)[None, :] + else: + if USE_EXP2: + b_dh1 *= exp2(b_gk_last1[:, None]) + else: + b_dh1 *= exp(b_gk_last1[:, None]) + if TRANSPOSE_STATE: + b_dh1 += tl.trans(tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))) + else: + b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + if K > 64: + p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (64, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, HV*K), (64, i_t * BT), (64, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + if USE_G: + b_dh2 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + if TRANSPOSE_STATE: + if USE_EXP2: + b_dh2 *= exp2(b_gk_last2)[None, :] + else: + b_dh2 *= exp(b_gk_last2)[None, :] + else: + if USE_EXP2: + b_dh2 *= exp2(b_gk_last2[:, None]) + else: + b_dh2 *= exp(b_gk_last2[:, None]) + if TRANSPOSE_STATE: + b_dh2 += tl.trans(tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))) + else: + b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + if K > 128: + p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (128, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, HV*K), (128, i_t * BT), (64, BT), 
(0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + if USE_G: + b_dh3 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + if TRANSPOSE_STATE: + if USE_EXP2: + b_dh3 *= exp2(b_gk_last3)[None, :] + else: + b_dh3 *= exp(b_gk_last3)[None, :] + else: + if USE_EXP2: + b_dh3 *= exp2(b_gk_last3[:, None]) + else: + b_dh3 *= exp(b_gk_last3[:, None]) + if TRANSPOSE_STATE: + b_dh3 += tl.trans(tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))) + else: + b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + if K > 192: + p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (192, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, HV*K), (192, i_t * BT), (64, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + if USE_G: + b_dh4 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + if TRANSPOSE_STATE: + if USE_EXP2: + b_dh4 *= exp2(b_gk_last4)[None, :] + else: + b_dh4 *= exp(b_gk_last4)[None, :] + else: + if USE_EXP2: + b_dh4 *= exp2(b_gk_last4[:, None]) + else: + b_dh4 *= exp(b_gk_last4[:, None]) + if TRANSPOSE_STATE: + b_dh4 += tl.trans(tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype))) + else: + b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + + if USE_INITIAL_STATE: + if TRANSPOSE_STATE: + p_dh0 = tl.make_block_ptr(dh0, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)) + else: + p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh0, b_dh1.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + if TRANSPOSE_STATE: + p_dh1 = tl.make_block_ptr(dh0, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)) + else: + p_dh1 = tl.make_block_ptr(dh0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh1, b_dh2.to(p_dh1.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + if TRANSPOSE_STATE: + p_dh2 = tl.make_block_ptr(dh0, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)) + else: + p_dh2 = tl.make_block_ptr(dh0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh2, b_dh3.to(p_dh2.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + if TRANSPOSE_STATE: + p_dh3 = tl.make_block_ptr(dh0, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)) + else: + p_dh3 = tl.make_block_ptr(dh0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh3, b_dh4.to(p_dh3.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_h( + k: paddle.Tensor, + w: paddle.Tensor, + u: paddle.Tensor, + g: paddle.Tensor | None = None, + gk: paddle.Tensor | None = None, + initial_state: paddle.Tensor | None = None, + output_final_state: bool = False, + chunk_size: int = 64, + save_new_value: bool = True, + cu_seqlens: paddle.Tensor | None = None, + cu_seqlens_cpu: paddle.Tensor | None = None, + chunk_indices: paddle.Tensor | None = None, + use_exp2: bool = False, + transpose_state_layout: bool = False, +) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor | None]: + B, T, H, K, V, HV = *k.shape, u.shape[-1], u.shape[2] + BT = chunk_size + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, 
chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + assert K <= 256, "current kernel does not support head dimension larger than 256." + + if transpose_state_layout: + h = paddle.empty(shape=[B, NT, HV, V, K], dtype=k.dtype) + final_state = paddle.zeros(shape=[N, HV, V, K], dtype=paddle.float32) if output_final_state else None + else: + h = paddle.empty(shape=[B, NT, HV, K, V], dtype=k.dtype) + final_state = paddle.zeros(shape=[N, HV, K, V], dtype=paddle.float32) if output_final_state else None + + v_new = paddle.empty_like(u) if save_new_value else None + def grid(meta): return (triton.cdiv(V, meta['BV']), N*HV) + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + gk=gk, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + HV=HV, + K=K, + V=V, + BT=BT, + USE_EXP2=use_exp2, + TRANSPOSE_STATE=transpose_state_layout, + ) + return h, v_new, final_state + + +def chunk_gated_delta_rule_bwd_dhu( + q: paddle.Tensor, + k: paddle.Tensor, + w: paddle.Tensor, + do: paddle.Tensor, + dv: paddle.Tensor, + g: paddle.Tensor | None = None, + gk: paddle.Tensor | None = None, + h0: paddle.Tensor | None = None, + dht: paddle.Tensor | None = None, + scale: float | None = None, + cu_seqlens: paddle.Tensor | None = None, + chunk_size: int = 64, + chunk_indices: paddle.Tensor | None = None, + use_exp2: bool = False, + transpose_state_layout: bool = False, +) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + B, T, H, K, V, HV = *q.shape, do.shape[-1], do.shape[2] + # N: the actual number of sequences in the batch with either equal or variable lengths + BT = 64 + assert K <= 256, "current kernel does not support head dimension being larger than 256." + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + + if transpose_state_layout: + dh = paddle.empty(shape=[B, NT, HV, V, K], dtype=q.dtype) + else: + dh = paddle.empty(shape=[B, NT, HV, K, V], dtype=q.dtype) + dh0 = paddle.empty_like(h0, dtype=paddle.float32) if h0 is not None else None + dv2 = paddle.empty_like(dv) + + def grid(meta): return (triton.cdiv(V, meta['BV']), N*HV) + chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64[grid]( + q=q, + k=k, + w=w, + g=g, + gk=gk, + dht=dht, + dh0=dh0, + do=do, + dh=dh, + dv=dv, + dv2=dv2, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + scale=scale, + T=T, + H=H, + HV=HV, + K=K, + V=V, + BT=BT, + USE_EXP2=use_exp2, + TRANSPOSE_STATE=transpose_state_layout, + ) + return dh, dh0, dv2 diff --git a/flashmask/flash_mask/linear_attn/ops/common/chunk_o.py b/flashmask/flash_mask/linear_attn/ops/common/chunk_o.py new file mode 100644 index 00000000000..b6b16b6f8fb --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/common/chunk_o.py @@ -0,0 +1,794 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +from functools import lru_cache + +import paddle +import triton +import triton.language as tl + +from flash_mask.linear_attn.ops.utils import prepare_chunk_indices +from flash_mask.linear_attn.ops.utils.op import exp, exp2 +from flash_mask.linear_attn.utils import IS_NVIDIA_HOPPER, autotune_cache_kwargs, check_shared_mem +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + +BKV_LIST = [64, 128] if check_shared_mem() else ([32, 64] if check_shared_mem('ada') else [32]) +NUM_WARPS = [2, 4] if IS_NVIDIA_HOPPER else [2, 4, 8] + + +@lru_cache(maxsize=None) +def _const_tiling(device_idx: int) -> int: + if check_shared_mem('hopper', device_idx): + return 128 + if check_shared_mem('ada', device_idx): + return 64 + return 32 + + +@lru_cache(maxsize=None) +def _chunk_o_launch_meta(device_idx: int, T: int, K: int, V: int, BT: int) -> tuple[int, int, int, int]: + const_tiling = _const_tiling(device_idx) + BK = min(max(triton.next_power_of_2(K), 16), const_tiling) + BV = min(max(triton.next_power_of_2(V), 16), const_tiling) + NT = triton.cdiv(T, BT) + NK = triton.cdiv(K, BK) + return BK, BV, NT, NK + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'USE_G_GAMMA': lambda args: args['g_gamma'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': 128, 'BV': 128}, num_warps=8, num_stages=3), + triton.Config({'BK': 64, 'BV': 64}, num_warps=4, num_stages=3), + triton.Config({'BK': 32, 'BV': 32}, num_warps=2, num_stages=3), + ], + key=['H', 'HV', 'K', 'V', 'BT', 'TRANSPOSE_STATE'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def chunk_fwd_kernel_o( + q, + k, + v, + h, + g, + g_gamma, + o, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_G_GAMMA: tl.constexpr, + USE_EXP2: tl.constexpr, + TRANSPOSE_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // HV, i_bh % HV + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (bos * H + i_h // (HV // H)) * K + k += (bos * H + i_h // (HV // H)) * K + v += (bos * HV + i_h) * V + o += (bos * HV + i_h) * V + h += (i_tg * HV + i_h).to(tl.int64) * K*V + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + if TRANSPOSE_STATE: + p_h = tl.make_block_ptr(h, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0)) + else: + p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h = tl.load(p_h, 
boundary_check=(0, 1)) + + # [BT, BK] @ [BK, BV] -> [BT, BV] + if TRANSPOSE_STATE: + b_o += tl.dot(b_q, tl.trans(b_h)) + else: + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_A += tl.dot(b_q, b_k) + + if USE_G: + g += bos * HV + i_h + p_g = tl.make_block_ptr(g, (T,), (HV,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + if USE_EXP2: + b_o = b_o * exp2(b_g)[:, None] + b_A = b_A * exp2(b_g[:, None] - b_g[None, :]) + else: + b_o = b_o * exp(b_g)[:, None] + b_A = b_A * exp(b_g[:, None] - b_g[None, :]) + + if USE_G_GAMMA: + b_gamma = tl.load(g_gamma + i_h) + b_g = b_gamma * (tl.arange(0, BT) + 1) + if USE_EXP2: + b_o = b_o * exp2(b_g)[:, None] + b_A = b_A * exp2(b_g[:, None] - b_g[None, :]) + else: + b_o = b_o * exp(b_g)[:, None] + b_A = b_A * exp(b_g[:, None] - b_g[None, :]) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + + p_v = tl.make_block_ptr(v, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + b_v = tl.load(p_v, boundary_check=(0, 1)) + # to fix mma -> mma layout conversion + # already solved by triton v3.2 or higher + b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'USE_G_GAMMA': lambda args: args['g_gamma'] is not None, + 'USE_DW': lambda args: args['dw'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=['H', 'HV', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G', 'USE_G_GAMMA', 'USE_DW', 'TRANSPOSE_STATE'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dqkwg( + q, + k, + v, + g, + g_gamma, + h, + do, + dh, + dq, + dk, + dw, + dv, + dg, + cu_seqlens, + chunk_indices, + scale, + B: tl.constexpr, + T, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_G_GAMMA: tl.constexpr, + USE_EXP2: tl.constexpr, + USE_DW: tl.constexpr, + TRANSPOSE_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // HV, i_bh % HV + + all = B * T + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + v += (bos * HV + i_h) * V + do += (bos * HV + i_h) * V + h += (i_tg * HV + i_h).to(tl.int64) * K*V + dh += (i_tg * HV + i_h).to(tl.int64) * K*V + q += (bos * H + i_h // (HV // H)) * K + k += (bos * H + i_h // (HV // H)) * K + dq += (bos * HV + i_h) * K + dk += (bos * HV + i_h) * K + + # for delta rule only + if USE_DW: + dw += (bos * HV + i_h) * K + dv += (bos * HV + i_h) * V + + if USE_G: + dg += i_k * all * HV + b_dg_last = tl.zeros([1], dtype=tl.float32) if USE_G else None + if USE_G_GAMMA: + b_gamma = tl.load(g_gamma + i_h) + b_g = b_gamma * 
(tl.arange(0, BT) + 1) + b_g_last = b_gamma * min(BT, T - i_t * BT) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + b_dw = tl.zeros([BT, BK], dtype=tl.float32) if USE_DW else None + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + if TRANSPOSE_STATE: + p_h = tl.make_block_ptr(h, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0)) + p_dh = tl.make_block_ptr(dh, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0)) + else: + p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + if USE_G: + b_dg_last += (tl.sum(b_h * b_dh)) + # [BT, BV] @ [BV, BT] -> [BT, BT] + b_ds += tl.dot(b_do, tl.trans(b_v)) + # [BT, BV] @ [BV, BK] -> [BT, BK] + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + # [BT, BV] @ [BV, BK] -> [BT, BK] + b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) + if USE_DW: + p_dv = tl.make_block_ptr(dv, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype)) + + if USE_DW: + p_dw = tl.make_block_ptr(dw, (T, K), (HV*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + + p_dq = tl.make_block_ptr(dq, (T, K), (HV*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (HV*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t) + if USE_G: + b_dg = tl.zeros([BT], dtype=tl.float32) + g += bos * HV + i_h + dg += bos * HV + i_h + p_g = tl.make_block_ptr(g, (T,), (HV,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * HV) + if USE_EXP2: + b_dg_last *= exp2(b_g_last) + b_dq = b_dq * exp2(b_g)[:, None] * scale + else: + b_dg_last *= exp(b_g_last) + b_dq = b_dq * exp(b_g)[:, None] * scale + b_dg += tl.sum(b_dq * b_q, axis=1) + + if USE_EXP2: + b_dk = b_dk * tl.where(m_t, exp2(-b_g + b_g_last), 0)[:, None] + else: + b_dk = b_dk * tl.where(m_t, exp(-b_g + b_g_last), 0)[:, None] + b_dg -= tl.sum(b_k * b_dk, axis=1) + b_dg_last += tl.sum(b_dk * b_k) + + if USE_EXP2: + b_ds = tl.where(m_A, b_ds * exp2(b_g[:, None] - b_g[None, :]), 0) * scale + else: + b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0) * scale + b_ds2 = b_ds * tl.dot(b_q, tl.trans(b_k)) + b_dg += tl.sum(b_ds2, axis=1) + b_dg -= tl.sum(b_ds2, axis=0) + + b_ds = b_ds.to(b_k.dtype) + # [BT, BK] + b_dq += tl.dot(b_ds, b_k) + b_dk += tl.dot(tl.trans(b_ds), b_q) + p_dg = tl.make_block_ptr(dg, (T,), (HV,), (i_t * BT,), (BT,), (0,)) + # (SY 09/21) revcumsum in a separate kernel due to strange triton compiler 
issue + # b_dg = tl.dot(tl.where(o_t[:, None] <= o_t[None, :], 1., 0.), b_dg, allow_tf32=False) + b_dg_last) + b_dg = tl.where(o_t < min(i_t * BT + BT, T) - 1, b_dg, b_dg + b_dg_last) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + + elif USE_G_GAMMA: + if USE_EXP2: + b_dq = b_dq * exp2(b_g)[:, None] * scale + b_dk = b_dk * tl.where(m_t, exp2(-b_g + b_g_last), 0)[:, None] + b_ds = tl.where(m_A, b_ds * exp2(b_g[:, None] - b_g[None, :]), 0) * scale + else: + b_dq = b_dq * exp(b_g)[:, None] * scale + b_dk = b_dk * tl.where(m_t, exp(-b_g + b_g_last), 0)[:, None] + b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0) * scale + b_ds = b_ds.to(b_k.dtype) + # [BT, BK] + b_dq += tl.dot(b_ds, b_k) + b_dk += tl.dot(tl.trans(b_ds), b_q) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + else: + b_ds = tl.where(m_A, b_ds, 0) + b_ds = b_ds.to(b_k.dtype) + b_dq += tl.dot(b_ds, b_k) + b_dk += tl.dot(tl.trans(b_ds), b_q) * scale + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'USE_G_GAMMA': lambda args: args['g_gamma'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=['H', 'HV', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G', 'USE_G_GAMMA'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dv( + q, + k, + g, + g_gamma, + do, + dv, + dh, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_G_GAMMA: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // HV, i_bh % HV + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + + # offset calculation + q += (bos * H + i_h // (HV // H)) * K + k += (bos * H + i_h // (HV // H)) * K + do += (bos * HV + i_h) * V + dv += (bos * HV + i_h) * V + dh += (i_tg * HV + i_h).to(tl.int64) * K*V + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_A += tl.dot(b_k, b_q) + p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype)) + + o_t = i_t * BT + tl.arange(0, 
BT) + m_t = o_t < T + if USE_G: + g += bos * HV + i_h + p_g = tl.make_block_ptr(g, (T,), (HV,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * HV) + if USE_G_GAMMA: + b_gamma = tl.load(g_gamma + i_h) + b_g = b_gamma * (tl.arange(0, BT) + 1) + b_g_last = b_gamma * min(BT, T - i_t * BT) + + m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t) + if USE_G or USE_G_GAMMA: + if USE_EXP2: + b_A = tl.where(m_A, b_A * exp2(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty) + b_dv *= tl.where(m_t, exp2(-b_g + b_g_last), 0)[:, None] + else: + b_A = tl.where(m_A, b_A * exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty) + b_dv *= tl.where(m_t, exp(-b_g + b_g_last), 0)[:, None] + else: + b_A = tl.where(m_A, b_A * scale, 0).to(do.dtype.element_ty) + p_do = tl.make_block_ptr(do, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv += tl.dot(b_A.to(b_do.dtype), b_do) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'USE_G_GAMMA': lambda args: args['g_gamma'] is not None, + 'USE_A': lambda args: args['A'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=['H', 'HV', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dv_local( + q, + k, + g, + g_gamma, + A, + do, + dv, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_G_GAMMA: tl.constexpr, + USE_EXP2: tl.constexpr, + USE_A: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // HV, i_bh % HV + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (bos * H + i_h // (HV // H)) * K + k += (bos * H + i_h // (HV // H)) * K + do += (bos * HV + i_h) * V + dv += (bos * HV + i_h) * V + + if USE_A: + p_A = tl.make_block_ptr(A + (bos * HV + i_h) * BT, (BT, T), (1, HV*BT), (0, i_t * BT), (BT, BT), (0, 1)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + else: + if USE_G: + g += bos * HV + i_h + p_g = tl.make_block_ptr(g, (T,), (HV,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + if USE_G_GAMMA: + b_gamma = tl.load(g_gamma + i_h) + b_g = b_gamma * (tl.arange(0, BT) + 1) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_A += tl.dot(b_k, b_q) * scale + if USE_G or USE_G_GAMMA: + if USE_EXP2: + b_A *= exp2(b_g[None, :] - b_g[:, None]) + 
else: + b_A *= exp(b_g[None, :] - b_g[:, None]) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0).to(do.dtype.element_ty) + + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.dot(b_A.to(b_do.dtype), b_do) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_o( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + h: paddle.Tensor, + g: paddle.Tensor | None = None, + g_gamma: paddle.Tensor | None = None, + scale: float | None = None, + cu_seqlens: paddle.Tensor | None = None, + chunk_size: int = 64, + chunk_indices: paddle.Tensor | None = None, + use_exp2: bool = False, + transpose_state_layout: bool = False, +) -> paddle.Tensor: + B, T, H, K, V, HV = *q.shape, v.shape[-1], v.shape[2] + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + if scale is None: + scale = k.shape[-1] ** -0.5 + + o = paddle.empty_like(v) + def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * HV) + chunk_fwd_kernel_o[grid]( + q=q, + k=k, + v=v, + h=h, + g=g, + g_gamma=g_gamma, + o=o, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + HV=HV, + K=K, + V=V, + BT=BT, + USE_EXP2=use_exp2, + TRANSPOSE_STATE=transpose_state_layout, + ) + return o + + +def chunk_bwd_dv( + q: paddle.Tensor, + k: paddle.Tensor, + do: paddle.Tensor, + dh: paddle.Tensor, + g: paddle.Tensor | None = None, + g_gamma: paddle.Tensor | None = None, + scale: float | None = None, + cu_seqlens: paddle.Tensor | None = None, + chunk_size: int = 64, + chunk_indices: paddle.Tensor | None = None, + use_exp2: bool = False, +) -> paddle.Tensor: + B, T, H, K, V, HV = *k.shape, do.shape[-1], do.shape[2] + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + CONST_TILING = _const_tiling(k.place.gpu_device_id()) + BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING) + BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NV = triton.cdiv(V, BV) + if scale is None: + scale = k.shape[-1] ** -0.5 + + dv = paddle.empty_like(do) + grid = (NV, NT, B * HV) + chunk_bwd_kernel_dv[grid]( + q=q, + k=k, + g=g, + g_gamma=g_gamma, + do=do, + dv=dv, + dh=dh, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + HV=HV, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + USE_EXP2=use_exp2, + ) + return dv + + +def chunk_bwd_dv_local( + q: paddle.Tensor, + k: paddle.Tensor, + do: paddle.Tensor, + g: paddle.Tensor | None = None, + g_gamma: paddle.Tensor | None = None, + A: paddle.Tensor | None = None, + scale: float = None, + cu_seqlens: paddle.Tensor | None = None, + chunk_size: int = 64, + chunk_indices: paddle.Tensor | None = None, + use_exp2: bool = False, +) -> paddle.Tensor: + B, T, H, K, V, HV = *k.shape, do.shape[-1], do.shape[2] + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + if cu_seqlens is None: + BK, BV, NT, _ = _chunk_o_launch_meta(k.place.gpu_device_id(), T, 
K, V, BT) + else: + const_tiling = _const_tiling(k.place.gpu_device_id()) + BK = min(max(triton.next_power_of_2(K), 16), const_tiling) + BV = min(max(triton.next_power_of_2(V), 16), const_tiling) + NT = len(chunk_indices) + + dv = paddle.empty_like(do) + grid = (NT, B * HV) + chunk_bwd_kernel_dv_local[grid]( + q=q, + k=k, + g=g, + g_gamma=g_gamma, + A=A, + do=do, + dv=dv, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + HV=HV, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + USE_EXP2=use_exp2, + ) + return dv + + +def chunk_bwd_dqkwg( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + do: paddle.Tensor, + h: paddle.Tensor, + dh: paddle.Tensor, + w: paddle.Tensor | None = None, + g: paddle.Tensor | None = None, + g_gamma: paddle.Tensor | None = None, + dv: paddle.Tensor | None = None, + scale: float | None = None, + cu_seqlens: paddle.Tensor | None = None, + chunk_size: int = 64, + chunk_indices: paddle.Tensor | None = None, + use_exp2: bool = False, + transpose_state_layout: bool = False, +) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + + B, T, H, K, V, HV = *k.shape, v.shape[-1], v.shape[2] + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + if cu_seqlens is None: + BK, BV, NT, NK = _chunk_o_launch_meta(k.place.gpu_device_id(), T, K, V, BT) + else: + NT = len(chunk_indices) + const_tiling = _const_tiling(k.place.gpu_device_id()) + BK = min(max(triton.next_power_of_2(K), 16), const_tiling) + BV = min(max(triton.next_power_of_2(V), 16), const_tiling) + NK = triton.cdiv(K, BK) + dq = paddle.empty(shape=[B, T, HV, K], dtype=q.dtype) + dk = paddle.empty(shape=[B, T, HV, K], dtype=k.dtype) + dg = None + reduce_dg = False + if g is not None: + if NK == 1: + dg = paddle.empty(shape=list(g.shape), dtype=paddle.float32) + else: + dg = paddle.empty(shape=[NK, *g.shape], dtype=paddle.float32) + reduce_dg = True + dw = paddle.empty_like(w) if w is not None else None + + grid = (NK, NT, B * HV) + chunk_bwd_kernel_dqkwg[grid]( + q=q, + k=k, + v=v, + g=g, + g_gamma=g_gamma, + h=h, + do=do, + dh=dh, + dw=dw, + dq=dq, + dk=dk, + dv=dv, + dg=dg, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + B=B, + T=T, + H=H, + HV=HV, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + USE_EXP2=use_exp2, + TRANSPOSE_STATE=transpose_state_layout, + ) + + if H != HV: + dq = dq.reshape([B, T, H, HV // H, K]).sum(axis=3) + dk = dk.reshape([B, T, H, HV // H, K]).sum(axis=3) + if dg is not None and reduce_dg: + dg = dg.sum(axis=0) + return dq, dk, dw, dg diff --git a/flashmask/flash_mask/linear_attn/ops/common/chunk_scaled_dot_kkt.py b/flashmask/flash_mask/linear_attn/ops/common/chunk_scaled_dot_kkt.py new file mode 100644 index 00000000000..01faba3c082 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/common/chunk_scaled_dot_kkt.py @@ -0,0 +1,135 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +import paddle +import triton +import triton.language as tl + +from flash_mask.linear_attn.ops.utils import prepare_chunk_indices +from flash_mask.linear_attn.ops.utils.op import exp +from flash_mask.linear_attn.utils import autotune_cache_kwargs +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'HV', 'K', 'BT', 'IS_VARLEN'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def chunk_scaled_dot_kkt_fwd_kernel( + k, + g, + beta, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // HV, i_bh % HV + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + + p_b = tl.make_block_ptr(beta + bos*HV + i_h, (T,), (HV,), (i_t * BT,), (BT,), (0,)) + b_b = tl.load(p_b, boundary_check=(0,)) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos*H + i_h // (HV // H)) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_A += tl.dot(b_k, tl.trans(b_k)) + + if USE_G: + p_g = tl.make_block_ptr(g + bos*HV + i_h, (T,), (HV,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_diff = b_g[:, None] - b_g[None, :] + b_A *= exp(b_g_diff) + b_A *= b_b[:, None] + + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + p_A = tl.make_block_ptr(A + (bos*HV + i_h) * BT, (T, BT), (BT*HV, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_scaled_dot_kkt_fwd( + k: paddle.Tensor, + g: paddle.Tensor | None = None, + beta: paddle.Tensor | None = None, + cu_seqlens: paddle.Tensor | None = None, + chunk_size: int = 64, + output_dtype: paddle.dtype = paddle.float32, + chunk_indices: paddle.Tensor | None = None, +) -> paddle.Tensor: + r""" + Compute beta * K * K^T. + + Args: + k (paddle.Tensor): + The key tensor of shape `[B, T, H, K]` where `H` is the number of query/key heads. + beta (paddle.Tensor): + The beta tensor of shape `[B, T, HV]` where `HV` is the number of value/output heads. + g (paddle.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, HV]`. Default: `None`. + gk (paddle.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, HV, K]` applied to the key tensor. Default: `None`. + cu_seqlens (paddle.Tensor): + The cumulative sequence lengths of the input tensor. + Default: None + chunk_size (int): + The chunk size. Default: 64. + output_dtype (paddle.dtype): + The dtype of the output tensor. 
Default: `paddle.float32` + + Returns: + beta * K * K^T of shape `[B, T, HV, BT]` where `BT` is the chunk size. + For GVA, H < HV and HV % H == 0. For standard attention, H == HV. + """ + B, T, H, K, HV = *k.shape, beta.shape[2] + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + A = paddle.empty(shape=[B, T, HV, BT], dtype=output_dtype) + chunk_scaled_dot_kkt_fwd_kernel[(NT, B * HV)]( + k=k, + g=g, + beta=beta, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + HV=HV, + K=K, + BT=BT, + ) + return A diff --git a/flashmask/flash_mask/linear_attn/ops/common/fused_recurrent.py b/flashmask/flash_mask/linear_attn/ops/common/fused_recurrent.py new file mode 100644 index 00000000000..fb9134e2ff6 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/common/fused_recurrent.py @@ -0,0 +1,606 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +import paddle +import triton +import triton.language as tl + +from flash_mask.linear_attn.ops.utils.op import exp +from flash_mask.linear_attn.utils import autocast_custom_bwd, autocast_custom_fwd, autotune_cache_kwargs, input_guard +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [4, 8] + ], + key=['BK', 'BV', 'USE_G', 'USE_G_GAMMA', 'USE_GK', 'USE_GV'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['B', 'T']) +def fused_recurrent_fwd_kernel( + q, + k, + v, + g, + g_gamma, + gk, + gv, + o, + h0, + ht, + cu_seqlens, + scale, + B, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + REVERSE: tl.constexpr, + USE_G: tl.constexpr, + USE_G_GAMMA: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) + i_n, i_h = i_nh // H, i_nh % H + + all = B * T + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + p_o = o + ((i_k * all + bos) + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + if USE_G: + p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h + if USE_GK: + p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + if USE_GV: + p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + if USE_G_GAMMA: + b_g_gamma = tl.load(g_gamma + i_h) + + m_k = o_k < K + m_v = o_v < V + m_h = 
m_k[:, None] & m_v[None, :] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=m_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_q = tl.load(p_q, mask=m_k, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=m_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32) + if USE_G: + b_g = tl.load(p_g).to(tl.float32) + b_h = b_h * exp(b_g) + if USE_G_GAMMA: + b_h = b_h * exp(b_g_gamma) + if USE_GK: + b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32) + b_h = b_h * exp(b_gk[:, None]) + if USE_GV: + b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32) + b_h = b_h * exp(b_gv[None, :]) + b_h += b_k[:, None] * b_v[None, :] + b_o = b_h * b_q[:, None] + b_o = tl.sum(b_o, axis=0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=m_v) + p_q += (-1 if REVERSE else 1) * H*K + p_k += (-1 if REVERSE else 1) * H*K + p_v += (-1 if REVERSE else 1) * H*V + p_o += (-1 if REVERSE else 1) * H*V + if USE_G: + p_g += (-1 if REVERSE else 1) * H + if USE_GK: + p_gk += (-1 if REVERSE else 1) * H*K + if USE_GV: + p_gv += (-1 if REVERSE else 1) * H*V + + if STORE_FINAL_STATE: + p_ht = ht + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=m_h) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [4] + ], + key=['BK', 'BV', 'USE_G', 'USE_G_GAMMA', 'USE_GK', 'USE_GV'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['B', 'T']) +def fused_recurrent_bwd_kernel( + q, + k, + v, + g, + g_gamma, + gk, + gv, + o, + h0, + do, + dq, + dk, + dv, + dg, + dgk, + dgv, + dht, + dh0, + cu_seqlens, + scale, + B, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + REVERSE: tl.constexpr, + USE_G: tl.constexpr, + USE_G_GAMMA: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_INITIAL_STATE_GRADIENT: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) + i_n, i_h = i_nh // H, i_nh % H + + all = B * T + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + NV = tl.cdiv(V, BV) + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + m_k = o_k < K + m_v = o_v < V + m_h = m_k[:, None] & m_v[None, :] + + p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + p_do = do + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + p_dq = dq + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + if USE_G: + p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h + if USE_GK: + p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + if USE_GV: + p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + if USE_G_GAMMA: + b_g_gamma = tl.load(g_gamma + i_h) + + b_h = tl.zeros([BK, BV], 
dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=m_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=m_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=m_v, other=0).to(tl.float32) + if USE_G: + b_g = tl.load(p_g).to(tl.float32) + b_h = b_h * exp(b_g) + if USE_G_GAMMA: + b_h = b_h * exp(b_g_gamma) + if USE_GK: + b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32) + b_h = b_h * exp(b_gk[:, None]) + if USE_GV: + b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32) + b_h = b_h * exp(b_gv[None, :]) + b_h += b_k[:, None] * b_v[None, :] + b_dq = b_h * b_do[None, :] + b_dq = tl.sum(b_dq, axis=1) * scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=m_k) + + p_k += (-1 if REVERSE else 1) * H*K + p_v += (-1 if REVERSE else 1) * H*V + p_do += (-1 if REVERSE else 1) * H*V + p_dq += (-1 if REVERSE else 1) * H*K + if USE_G: + p_g += (-1 if REVERSE else 1) * H + if USE_GK: + p_gk += (-1 if REVERSE else 1) * H*K + if USE_GV: + p_gv += (-1 if REVERSE else 1) * H*V + + # sync threads + tl.debug_barrier() + + p_q = q + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + o_k + p_k = k + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + o_k + p_v = v + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + o_v + + p_do = do + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + o_v + p_dq = dq + ((i_v * all + bos) + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + o_k + p_dk = dk + ((i_v * all + bos) + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + o_k + p_dv = dv + ((i_k * all + bos) + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + o_v + if USE_G: + p_g = g + (bos + ((T - 1) if not REVERSE else 0)) * H + i_h + p_dg = dg + ((i_k * NV + i_v) * all + bos + ((T - 1) if not REVERSE else 0)) * H + i_h + if USE_GK: + p_gk = gk + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + o_k + p_dgk = dgk + ((i_v * all + bos) + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + o_k + if USE_GV: + p_o = o + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + o_v + p_gv = gv + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + o_v + p_dgv = dgv + ((i_k * all + bos) + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + o_v + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = dht + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + b_dh += tl.load(p_dht, mask=m_h, other=0).to(tl.float32) + + if USE_G: + b_dg = tl.sum(b_h * b_dh) + if USE_GK: + b_dgk = tl.sum(b_h * b_dh, 1) + if USE_GV: + b_dgv = tl.sum(b_h * b_dh, 0) + + for _ in range(T): + b_q = tl.load(p_q, mask=m_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=m_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=m_v, other=0).to(tl.float32) + b_dh += (b_q * scale)[:, None] * b_do[None, :] + b_dk = tl.sum(b_dh * b_v[None, :], axis=1) + b_dv = tl.sum(b_dh * b_k[:, None], axis=0) + + if USE_G: + b_g = tl.load(p_g).to(tl.float32) + b_dq = tl.load(p_dq, mask=m_k, other=0).to(tl.float32) + b_dg += tl.sum(b_q * b_dq - b_k * b_dk) + b_dh *= exp(b_g) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty)) + if USE_G_GAMMA: + b_dh *= exp(b_g_gamma) + if USE_GK: + b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32) + b_dq = tl.load(p_dq, mask=m_k, other=0).to(tl.float32) + b_dgk += b_q * b_dq - b_k * b_dk + b_dh *= exp(b_gk)[:, None] + tl.store(p_dgk, 
b_dgk.to(p_dgk.dtype.element_ty), mask=m_k) + if USE_GV: + b_o = tl.load(p_o, mask=m_v, other=0).to(tl.float32) + b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32) + if i_k == 0: + b_dgv += b_o * b_do + b_dgv -= b_v * b_dv + b_dh *= exp(b_gv)[None, :] + tl.store(p_dgv, b_dgv.to(p_dgv.dtype.element_ty), mask=m_v) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=m_k) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=m_v) + + p_q += (1 if REVERSE else -1) * H*K + p_k += (1 if REVERSE else -1) * H*K + p_v += (1 if REVERSE else -1) * H*V + + p_do += (1 if REVERSE else -1) * H*V + p_dq += (1 if REVERSE else -1) * H*K + p_dk += (1 if REVERSE else -1) * H*K + p_dv += (1 if REVERSE else -1) * H*V + if USE_G: + p_g += (1 if REVERSE else -1) * H + p_dg += (1 if REVERSE else -1) * H + if USE_GK: + p_gk += (1 if REVERSE else -1) * H*K + p_dgk += (1 if REVERSE else -1) * H*K + if USE_GV: + p_o += (1 if REVERSE else -1) * H*V + p_gv += (1 if REVERSE else -1) * H*V + p_dgv += (1 if REVERSE else -1) * H*V + + if STORE_INITIAL_STATE_GRADIENT: + p_dh0 = dh0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=m_h) + + +def fused_recurrent_fwd( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor | None = None, + g_gamma: paddle.Tensor | None = None, + gk: paddle.Tensor | None = None, + gv: paddle.Tensor | None = None, + scale: float | None = None, + initial_state: paddle.Tensor | None = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: paddle.Tensor | None = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + h0 = initial_state + ht = paddle.empty(shape=[N, H, K, V], dtype=paddle.float32) if output_final_state else None + o = paddle.empty(shape=[NK, *v.shape], dtype=paddle.float32) + + grid = (NV, NK, N * H) + fused_recurrent_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + g_gamma=g_gamma, + gk=gk, + gv=gv, + o=o, + h0=h0, + ht=ht, + cu_seqlens=cu_seqlens, + scale=scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + USE_G=g is not None, + USE_G_GAMMA=g_gamma is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + REVERSE=reverse, + ) + o = o.sum(axis=0) + return o, ht + + +def fused_recurrent_bwd( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor | None = None, + g_gamma: paddle.Tensor | None = None, + gk: paddle.Tensor | None = None, + gv: paddle.Tensor | None = None, + o: paddle.Tensor | None = None, + do: paddle.Tensor | None = None, + dht: paddle.Tensor | None = None, + scale: float | None = None, + initial_state: paddle.Tensor | None = None, + reverse: bool = False, + cu_seqlens: paddle.Tensor | None = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + + BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + h0 = initial_state + dq = paddle.empty(shape=[NV, *q.shape], dtype=paddle.float32) + dk = paddle.empty(shape=[NV, *k.shape], dtype=paddle.float32) + dv = paddle.empty(shape=[NK, *v.shape], dtype=paddle.float32) + dh0 = paddle.empty_like(h0) if h0 is not None else None + + dg, dgk, dgv = None, None, None + if g is not None: + dg = paddle.empty(shape=[NK*NV, *g.shape], dtype=paddle.float32) + if gk is not None: + dgk 
= paddle.empty(shape=[NV, *gk.shape], dtype=paddle.float32) + if gv is not None: + dgv = paddle.empty(shape=[NK, *gv.shape], dtype=paddle.float32) + + grid = (NV, NK, N * H) + fused_recurrent_bwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + g_gamma=g_gamma, + gk=gk, + gv=gv, + o=o, + h0=h0, + do=do, + dq=dq, + dk=dk, + dv=dv, + dg=dg, + dgk=dgk, + dgv=dgv, + dht=dht, + dh0=dh0, + cu_seqlens=cu_seqlens, + scale=scale, + B=B, + T=T, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + USE_G=g is not None, + USE_G_GAMMA=g_gamma is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + REVERSE=reverse, + ) + dq = dq.sum(axis=0) + dk = dk.sum(axis=0) + dv = dv.sum(axis=0) + if g is not None: + dg = dg.sum(axis=0).cast(g.dtype) + if gk is not None: + dgk = dgk.sum(axis=0).cast(gk.dtype) + if gv is not None: + dgv = dgv.sum(axis=0).cast(gv.dtype) + + return dq, dk, dv, dg, dgk, dgv, dh0 + + +class FusedRecurrentFunction(paddle.autograd.PyLayer): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor | None = None, + g_gamma: paddle.Tensor | None = None, + gk: paddle.Tensor | None = None, + gv: paddle.Tensor | None = None, + scale: float | None = None, + initial_state: paddle.Tensor | None = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: paddle.Tensor | None = None, + ): + o, ht = fused_recurrent_fwd( + q=q, + k=k, + v=v, + g=g, + g_gamma=g_gamma, + gk=gk, + gv=gv, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + reverse=reverse, + cu_seqlens=cu_seqlens, + ) + ctx.save_for_backward(q, k, v, g, g_gamma, gk, gv, initial_state, o) + ctx.scale = scale + ctx.reverse = reverse + ctx.cu_seqlens = cu_seqlens + ctx.output_final_state = output_final_state + # Paddle PyLayer backward must return exactly as many values as tensor inputs. + _forward_args = [ + q, k, v, g, g_gamma, gk, gv, scale, initial_state, + output_final_state, reverse, cu_seqlens, + ] + ctx._tensor_mask = tuple(isinstance(a, paddle.Tensor) for a in _forward_args) + ctx._needs_grad = tuple( + isinstance(a, paddle.Tensor) and not a.stop_gradient for a in _forward_args + ) + # Paddle PyLayer forward cannot return None, use dummy tensor as placeholder + if ht is None: + ht = paddle.zeros([1], dtype=q.dtype) + return o.cast(q.dtype), ht + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht=None): + # When output_final_state=False, forward returned a dummy tensor; + # restore dht to None so downstream bwd functions handle it correctly + if not ctx.output_final_state: + dht = None + q, k, v, g, g_gamma, gk, gv, initial_state, o = ctx.saved_tensor() + dq, dk, dv, dg, dgk, dgv, dh0 = fused_recurrent_bwd( + q=q, + k=k, + v=v, + g=g, + g_gamma=g_gamma, + gk=gk, + gv=gv, + o=o, + do=do, + dht=dht, + scale=ctx.scale, + initial_state=initial_state, + reverse=ctx.reverse, + cu_seqlens=ctx.cu_seqlens, + ) + # Build all grads in forward arg order, filter to tensor inputs only. 
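+        # Positions that never receive a gradient here (g_gamma, scale, output_final_state,
+        # reverse, cu_seqlens) are pinned to None; the filter below then keeps only the
+        # tensor slots and nulls out any tensor input with stop_gradient=True.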
+ # Order: q, k, v, g, g_gamma, gk, gv, scale, initial_state, + # output_final_state, reverse, cu_seqlens + all_grads = [ + dq.cast(q.dtype), dk.cast(k.dtype), dv.cast(v.dtype), + dg, None, dgk, dgv, None, dh0, None, None, None, + ] + return tuple( + g if needs_grad else None + for g, is_tensor, needs_grad in zip(all_grads, ctx._tensor_mask, ctx._needs_grad) + if is_tensor + ) + + +def fused_recurrent( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor | None = None, + g_gamma: paddle.Tensor | None = None, + gk: paddle.Tensor | None = None, + gv: paddle.Tensor | None = None, + scale: float | None = None, + initial_state: paddle.Tensor | None = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: paddle.Tensor | None = None, +): + if scale is None: + scale = k.shape[-1] ** -0.5 + o, ht = FusedRecurrentFunction.apply( + q, + k, + v, + g, + g_gamma, + gk, + gv, + scale, + initial_state, + output_final_state, + reverse, + cu_seqlens, + ) + # Convert dummy tensor back to None when output_final_state=False + if not output_final_state: + ht = None + return o, ht diff --git a/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/__init__.py b/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/__init__.py new file mode 100644 index 00000000000..26962572128 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +from .chunk import chunk_gated_delta_rule, chunk_gdn +from .fused_recurrent import fused_recurrent_gated_delta_rule, fused_recurrent_gdn +from .naive import naive_recurrent_gated_delta_rule + +__all__ = [ + "chunk_gated_delta_rule", "chunk_gdn", + "fused_recurrent_gated_delta_rule", "fused_recurrent_gdn", + "naive_recurrent_gated_delta_rule", +] diff --git a/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/chunk.py b/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/chunk.py new file mode 100644 index 00000000000..bd82ce002b8 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/chunk.py @@ -0,0 +1,517 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +import warnings + +import paddle + +from flash_mask.linear_attn.modules.l2norm import l2norm_bwd, l2norm_fwd +from flash_mask.linear_attn.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h +from flash_mask.linear_attn.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o +from flash_mask.linear_attn.ops.gated_delta_rule.chunk_fwd import chunk_gated_delta_rule_fwd_intra +from flash_mask.linear_attn.ops.gated_delta_rule.gate import gdn_gate_bwd, gdn_gate_chunk_cumsum +from flash_mask.linear_attn.ops.gated_delta_rule.wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd +from flash_mask.linear_attn.ops.utils import chunk_local_cumsum +from flash_mask.linear_attn.ops.utils.constant import RCP_LN2 +from flash_mask.linear_attn.ops.utils.index import prepare_chunk_indices +from flash_mask.linear_attn.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard +from flash_mask.linear_attn.triton_utils import activate_paddle_driver, compat_kernel_wrapper_fastpath + + +def _use_saved_intermediates_no_recompute( + output_final_state: bool, + cu_seqlens: paddle.Tensor | None, + cp_context, +) -> bool: + return not output_final_state and cu_seqlens is None and cp_context is None + + +def chunk_gated_delta_rule_fwd( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + scale: float, + initial_state: paddle.Tensor, + output_final_state: bool, + cu_seqlens: paddle.Tensor | None = None, + cp_context=None, + chunk_indices: paddle.Tensor | None = None, + use_exp2: bool = True, + transpose_state_layout: bool = False, + use_gate_in_kernel: bool = False, + A_log: paddle.Tensor | None = None, + dt_bias: paddle.Tensor | None = None, + return_intermediates: bool = False, +): + g_input = g if use_gate_in_kernel else None + if use_gate_in_kernel: + g = gdn_gate_chunk_cumsum( + g=g, + A_log=A_log, + chunk_size=64, + scale=RCP_LN2 if use_exp2 else None, + dt_bias=dt_bias, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + ) + else: + g = chunk_local_cumsum( + g, + chunk_size=64, + scale=RCP_LN2 if use_exp2 else None, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + ) + # obtain WY representation. u is actually the new v. 
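+    # Roughly (a sketch only; the exact per-chunk math lives in
+    # chunk_gated_delta_rule_fwd_intra / wy_fast): A is the per-chunk inverse of
+    # (I + strictly-lower-triangular beta * K @ K^T), and w ~ A @ (beta * k),
+    # u ~ A @ (beta * v), gate factors omitted here, so the delta-rule recurrence
+    # can then be evaluated chunk-wise with dense matmuls.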
+ # fused kkt + solve_tril + recompute_w_u + w, u, A = chunk_gated_delta_rule_fwd_intra( + k=k, + v=v, + g=g, + beta=beta, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + use_exp2=use_exp2, + ) + + # CP (Context Parallel) is skipped in Phase 1 + + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + use_exp2=use_exp2, + transpose_state_layout=transpose_state_layout, + ) + + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + use_exp2=use_exp2, + transpose_state_layout=transpose_state_layout, + ) + if return_intermediates: + return g, o, A, final_state, initial_state, g_input, w, u, h, v_new + return g, o, A, final_state, initial_state, g_input + + +def chunk_gated_delta_rule_bwd( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + A: paddle.Tensor, + scale: float, + initial_state: paddle.Tensor, + do: paddle.Tensor, + dht: paddle.Tensor, + cu_seqlens: paddle.Tensor | None = None, + cp_context=None, + chunk_indices: paddle.Tensor | None = None, + use_exp2: bool = True, + transpose_state_layout: bool = False, + use_gate_in_kernel: bool = False, + g_input: paddle.Tensor | None = None, + A_log: paddle.Tensor | None = None, + dt_bias: paddle.Tensor | None = None, + saved_w: paddle.Tensor | None = None, + saved_u: paddle.Tensor | None = None, + saved_h: paddle.Tensor | None = None, + saved_v_new: paddle.Tensor | None = None, +): + if all(t is not None for t in (saved_w, saved_u, saved_h, saved_v_new)): + w, u, h, v_new = saved_w, saved_u, saved_h, saved_v_new + else: + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g=g, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + use_exp2=use_exp2, + ) + + # CP (Context Parallel) is skipped in Phase 1 + + h, v_new, _ = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=False, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + use_exp2=use_exp2, + transpose_state_layout=transpose_state_layout, + ) + dv = chunk_bwd_dv_local( + q=q, + k=k, + g=g, + do=do, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + use_exp2=use_exp2, + ) + + # CP (Context Parallel) is skipped in Phase 1 + + dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu( + q=q, + k=k, + w=w, + g=g, + h0=initial_state, + dht=dht, + do=do, + dv=dv, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + use_exp2=use_exp2, + transpose_state_layout=transpose_state_layout, + ) + dq, dk, dw, dg = chunk_bwd_dqkwg( + q=q, + k=k, + v=v_new, + w=w, + g=g, + h=h, + dv=dv, + do=do, + dh=dh, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + use_exp2=use_exp2, + transpose_state_layout=transpose_state_layout, + ) + dk2, dv, db, dg2 = prepare_wy_repr_bwd( + k=k, + v=v, + beta=beta, + g=g, + A=A, + dw=dw, + du=dv, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + use_exp2=use_exp2, + ) + dk.add_(dk2) + dg.add_(dg2) + dg = chunk_local_cumsum(dg, chunk_size=64, reverse=True, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices) + dA_log, ddt_bias = None, None + if use_gate_in_kernel: + dg, dA_log, ddt_bias = gdn_gate_bwd(g=g_input, A_log=A_log, dt_bias=dt_bias, dyg=dg) + return dq, dk, dv, db, dg, dh0, dA_log, ddt_bias + + +class 
ChunkGatedDeltaRuleFunction(paddle.autograd.PyLayer): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + scale: float, + initial_state: paddle.Tensor, + output_final_state: bool, + cu_seqlens: paddle.Tensor | None = None, + cu_seqlens_cpu: paddle.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + cp_context=None, + transpose_state_layout: bool = False, + use_gate_in_kernel: bool = False, + A_log: paddle.Tensor | None = None, + dt_bias: paddle.Tensor | None = None, + ): + # Save original input refs before any reassignment, for _tensor_mask/_needs_grad below. + _orig_forward_args = [ + q, k, v, g, beta, scale, initial_state, output_final_state, + cu_seqlens, cu_seqlens_cpu, use_qk_l2norm_in_kernel, cp_context, + transpose_state_layout, use_gate_in_kernel, A_log, dt_bias, + ] + q_rstd, k_rstd = None, None + if use_qk_l2norm_in_kernel: + q, q_rstd = l2norm_fwd(q) + k, k_rstd = l2norm_fwd(k) + + chunk_indices = prepare_chunk_indices( + cu_seqlens, 64, cu_seqlens_cpu=cu_seqlens_cpu) if cu_seqlens is not None else None + use_saved_intermediates = _use_saved_intermediates_no_recompute( + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + cp_context=cp_context, + ) + with activate_paddle_driver(), compat_kernel_wrapper_fastpath(): + gdn_outputs = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + cp_context=cp_context, + chunk_indices=chunk_indices, + transpose_state_layout=transpose_state_layout, + use_gate_in_kernel=use_gate_in_kernel, + A_log=A_log, + dt_bias=dt_bias, + return_intermediates=use_saved_intermediates, + ) + if use_saved_intermediates: + g, o, A, final_state, initial_state, g_input, w, u, h, v_new = gdn_outputs + else: + g, o, A, final_state, initial_state, g_input = gdn_outputs + w = u = h = v_new = None + ctx.save_for_backward( + q, q_rstd, k, k_rstd, v, g, beta, A, + initial_state, cu_seqlens, chunk_indices, + g_input, A_log, dt_bias, + w, u, h, v_new, + ) + # Store non-tensor params as ctx attributes + ctx.scale = scale + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + ctx.cp_context = cp_context + ctx.transpose_state_layout = transpose_state_layout + ctx.use_gate_in_kernel = use_gate_in_kernel + ctx.output_final_state = output_final_state + ctx.use_saved_intermediates = use_saved_intermediates + # Paddle PyLayer backward must return exactly as many values as tensor inputs. + # Record which forward args are tensors so backward can filter its return. + # Also record which tensor inputs need gradients (stop_gradient=False), + # because Paddle requires backward to return None for stop_gradient=True tensors. + # IMPORTANT: Use _orig_forward_args (captured before q/k/g/initial_state were + # reassigned by l2norm_fwd / chunk_gated_delta_rule_fwd) so that stop_gradient + # reflects the *caller's* tensors, not the internal intermediates. 
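+        # For example, if the caller passes initial_state with stop_gradient=True, it still
+        # occupies a tensor slot in _tensor_mask, but _needs_grad is False for it, so
+        # backward returns None in that position instead of dh0.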
+ ctx._tensor_mask = tuple(isinstance(a, paddle.Tensor) for a in _orig_forward_args) + ctx._needs_grad = tuple( + isinstance(a, paddle.Tensor) and not a.stop_gradient for a in _orig_forward_args + ) + # Paddle PyLayer forward cannot return None, use dummy tensor as placeholder + if final_state is None: + final_state = paddle.zeros([1], dtype=q.dtype) + return o.cast(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht=None): + # When output_final_state=False, forward returned a dummy tensor; + # restore dht to None so downstream bwd functions handle it correctly + if not ctx.output_final_state: + dht = None + (q, q_rstd, k, k_rstd, v, g, beta, A, + initial_state, cu_seqlens, chunk_indices, + g_input, A_log, dt_bias, + w, u, h, v_new) = ctx.saved_tensor() + with activate_paddle_driver(), compat_kernel_wrapper_fastpath(): + dq, dk, dv, db, dg, dh0, dA_log, ddt_bias = chunk_gated_delta_rule_bwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + A=A, + scale=ctx.scale, + initial_state=initial_state, + do=do, + dht=dht, + cu_seqlens=cu_seqlens, + cp_context=ctx.cp_context, + chunk_indices=chunk_indices, + transpose_state_layout=ctx.transpose_state_layout, + use_gate_in_kernel=ctx.use_gate_in_kernel, + g_input=g_input, + A_log=A_log, + dt_bias=dt_bias, + saved_w=w if ctx.use_saved_intermediates else None, + saved_u=u if ctx.use_saved_intermediates else None, + saved_h=h if ctx.use_saved_intermediates else None, + saved_v_new=v_new if ctx.use_saved_intermediates else None, + ) + if ctx.use_qk_l2norm_in_kernel: + dq = l2norm_bwd(q, q_rstd, dq) + dk = l2norm_bwd(k, k_rstd, dk) + # Build all grads in forward arg order, filter to tensor inputs only. + # Order: q, k, v, g, beta, scale, initial_state, output_final_state, + # cu_seqlens, cu_seqlens_cpu, use_qk_l2norm_in_kernel, cp_context, + # transpose_state_layout, use_gate_in_kernel, A_log, dt_bias + all_grads = [ + dq.cast(q.dtype), dk.cast(k.dtype), dv.cast(v.dtype), dg.cast(g.dtype), db.cast(beta.dtype), + None, dh0, None, None, None, None, None, None, None, dA_log, ddt_bias, + ] + return tuple( + g if needs_grad else None + for g, is_tensor, needs_grad in zip(all_grads, ctx._tensor_mask, ctx._needs_grad) + if is_tensor + ) + + +def chunk_gated_delta_rule( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + scale: float = None, + initial_state: paddle.Tensor = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + cu_seqlens: paddle.Tensor | None = None, + cu_seqlens_cpu: paddle.Tensor | None = None, + cp_context=None, + transpose_state_layout: bool = False, + **kwargs, +): + r""" + Args: + q (paddle.Tensor): + queries of shape `[B, T, H, K]`. + k (paddle.Tensor): + keys of shape `[B, T, H, K]`. + v (paddle.Tensor): + values of shape `[B, T, HV, V]`. + GVA (Grouped Value Attention) is applied if `HV > H`, where `HV` must be divisible by `H`. + g (paddle.Tensor): + (forget) gating tensor of shape `[B, T, HV]`. + When `use_gate_in_kernel=False` (default), `g` should be in log space (pre-computed decay). + When `use_gate_in_kernel=True`, `g` is the raw input before gate activation; + the kernel fuses `-exp(A_log) * softplus(g + dt_bias)` + chunk cumsum internally. + beta (paddle.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[float]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. 
+ initial_state (Optional[paddle.Tensor]): + Initial state of shape `[N, HV, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, HV, K, V]`. Default: `False`. + use_qk_l2norm_in_kernel (bool): + Whether to apply L2norm to the q/k tensor internally. Default: `False`. + cu_seqlens (paddle.Tensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + cp_context: + Context parallel context (skipped in Phase 1). Default: `None`. + transpose_state_layout (Optional[bool]): + Whether to use the transposed state layout for the hidden state. + Default: `False`. + use_gate_in_kernel (bool): + Whether to compute the log-space GDN decay internally. + When `True`, the passed `g` is the raw input, and `A_log` must be provided. + The kernel fuses gate activation + chunk cumsum in a single pass. + Default: `False`. + A_log (Optional[paddle.Tensor]): + Decay parameter of shape `[HV]`. Required when `use_gate_in_kernel=True`. + dt_bias (Optional[paddle.Tensor]): + Bias added to `g` before activation, of shape `[HV]`. + Only used when `use_gate_in_kernel=True`. + + Returns: + o (paddle.Tensor): + Outputs of shape `[B, T, HV, V]`. + final_state (paddle.Tensor): + Final state of shape `[N, HV, K, V]` if `output_final_state=True` else `None`. + """ + # Validate head dimensions + if q.shape[2] != k.shape[2]: + raise ValueError( + f"q and k must have the same number of heads, " + f"but got q.shape[2]={q.shape[2]} and k.shape[2]={k.shape[2]}" + ) + H, HV = q.shape[2], v.shape[2] + if HV % H != 0: + raise ValueError( + f"For GVA, num_v_heads (HV={HV}) must be evenly divisible by " + f"num_heads (H={H}), but got HV % H = {HV % H}" + ) + + if 'head_first' in kwargs: + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + ) + + # CP (Context Parallel) is skipped in Phase 1 + # cp_context is accepted but ignored + + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.", + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.", + ) + use_gate_in_kernel = kwargs.get('use_gate_in_kernel', False) + A_log = kwargs.get('A_log') + dt_bias = kwargs.get('dt_bias') + if use_gate_in_kernel: + assert A_log is not None, "A_log must be provided when use_gate_in_kernel=True." 
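+
+    # Illustrative equivalence (sketch, assuming the naive_gdn_gate reference in gate.py):
+    # with use_gate_in_kernel=False the caller precomputes the log-space decay first, e.g.
+    #     g_log = naive_gdn_gate(g_raw, A_log, dt_bias)  # -exp(A_log) * softplus(g_raw + dt_bias)
+    #     o, s = chunk_gated_delta_rule(q, k, v, g_log, beta, ...)
+    # with use_gate_in_kernel=True the raw g_raw is passed as `g` and the same transform
+    # is fused into the kernel, so the two paths should agree up to floating-point error.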
+ + if scale is None: + scale = k.shape[-1] ** -0.5 + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + cu_seqlens_cpu, + use_qk_l2norm_in_kernel, + cp_context, + transpose_state_layout, + use_gate_in_kernel, + A_log, + dt_bias, + ) + # Convert dummy tensor back to None when output_final_state=False + if not output_final_state: + final_state = None + return o, final_state + + +chunk_gdn = chunk_gated_delta_rule diff --git a/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/chunk_fwd.py b/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/chunk_fwd.py new file mode 100644 index 00000000000..9cb3ea86f67 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/chunk_fwd.py @@ -0,0 +1,411 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +import paddle +import triton +import triton.language as tl + +from flash_mask.linear_attn.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd +from flash_mask.linear_attn.ops.utils import prepare_chunk_indices +from flash_mask.linear_attn.ops.utils.op import exp, exp2 +from flash_mask.linear_attn.utils import IS_TF32_SUPPORTED, autotune_cache_kwargs +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + +if IS_TF32_SUPPORTED: + SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32') +else: + SOLVE_TRIL_DOT_PRECISION = tl.constexpr('ieee') + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, num_warps=num_warps) + for BK in [32, 64] + for num_warps in [1, 2, 4] + ], + key=['H', 'HV', 'K', 'BC'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def chunk_gated_delta_rule_fwd_kkt_solve_kernel( + k, + g, + beta, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + USE_G: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + """ + Fused kernel: compute beta * K @ K^T (lower triangular) + solve_tril (I+A)^{-1} in one pass. + + This kernel fuses chunk_scaled_dot_kkt_fwd and solve_tril into a single kernel, + avoiding the HBM round-trip for the intermediate A matrix. + + Steps: + 1. Compute all 10 lower-triangular [BC, BC] blocks of beta * K @ K^T in registers + 2. Apply gate and beta scaling + 3. Forward substitution on diagonal blocks + 4. Block merge to get full (I+A)^{-1} + 5. 
Write result to A (output) + """ + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // HV, i_bh % HV + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT >= T: + return + + i_tc0 = i_t * BT + i_tc1 = i_t * BT + BC + i_tc2 = i_t * BT + 2 * BC + i_tc3 = i_t * BT + 3 * BC + + k += (bos * H + i_h // (HV // H)) * K + A += (bos * HV + i_h) * BT + + o_i = tl.arange(0, BC) + m_tc0 = (i_tc0 + o_i) < T + m_tc1 = (i_tc1 + o_i) < T + m_tc2 = (i_tc2 + o_i) < T + m_tc3 = (i_tc3 + o_i) < T + + # load beta for each sub-chunk + p_b0 = tl.make_block_ptr(beta + bos * HV + i_h, (T,), (HV,), (i_tc0,), (BC,), (0,)) + p_b1 = tl.make_block_ptr(beta + bos * HV + i_h, (T,), (HV,), (i_tc1,), (BC,), (0,)) + p_b2 = tl.make_block_ptr(beta + bos * HV + i_h, (T,), (HV,), (i_tc2,), (BC,), (0,)) + p_b3 = tl.make_block_ptr(beta + bos * HV + i_h, (T,), (HV,), (i_tc3,), (BC,), (0,)) + b_b0 = tl.load(p_b0, boundary_check=(0,)).to(tl.float32) + b_b1 = tl.load(p_b1, boundary_check=(0,)).to(tl.float32) + b_b2 = tl.load(p_b2, boundary_check=(0,)).to(tl.float32) + b_b3 = tl.load(p_b3, boundary_check=(0,)).to(tl.float32) + + # load gate if used + if USE_G: + p_g0 = tl.make_block_ptr(g + bos * HV + i_h, (T,), (HV,), (i_tc0,), (BC,), (0,)) + p_g1 = tl.make_block_ptr(g + bos * HV + i_h, (T,), (HV,), (i_tc1,), (BC,), (0,)) + p_g2 = tl.make_block_ptr(g + bos * HV + i_h, (T,), (HV,), (i_tc2,), (BC,), (0,)) + p_g3 = tl.make_block_ptr(g + bos * HV + i_h, (T,), (HV,), (i_tc3,), (BC,), (0,)) + + b_g0 = tl.load(p_g0, boundary_check=(0,)).to(tl.float32) + b_g1 = tl.load(p_g1, boundary_check=(0,)).to(tl.float32) + b_g2 = tl.load(p_g2, boundary_check=(0,)).to(tl.float32) + b_g3 = tl.load(p_g3, boundary_check=(0,)).to(tl.float32) + + ############################################################################ + # Step 1: compute all 10 lower-triangular [BC, BC] blocks of K @ K^T + ############################################################################ + + # 4 diagonal blocks + b_A00 = tl.zeros([BC, BC], dtype=tl.float32) + b_A11 = tl.zeros([BC, BC], dtype=tl.float32) + b_A22 = tl.zeros([BC, BC], dtype=tl.float32) + b_A33 = tl.zeros([BC, BC], dtype=tl.float32) + + # 6 off-diagonal blocks + b_A10 = tl.zeros([BC, BC], dtype=tl.float32) + b_A20 = tl.zeros([BC, BC], dtype=tl.float32) + b_A21 = tl.zeros([BC, BC], dtype=tl.float32) + b_A30 = tl.zeros([BC, BC], dtype=tl.float32) + b_A31 = tl.zeros([BC, BC], dtype=tl.float32) + b_A32 = tl.zeros([BC, BC], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_k0 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc0, i_k * BK), (BC, BK), (1, 0)) + b_k0 = tl.load(p_k0, boundary_check=(0, 1)) + # diagonal block 0 + b_A00 += tl.dot(b_k0, tl.trans(b_k0)) + + if i_tc1 < T: + p_k1 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) + b_k1 = tl.load(p_k1, boundary_check=(0, 1)) + # diagonal block 1 + b_A11 += tl.dot(b_k1, tl.trans(b_k1)) + # off-diagonal (1,0) + b_A10 += tl.dot(b_k1, tl.trans(b_k0)) + + if i_tc2 < T: + p_k2 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) + b_k2 = tl.load(p_k2, boundary_check=(0, 1)) + # diagonal block 2 + b_A22 += tl.dot(b_k2, tl.trans(b_k2)) + # off-diagonal (2,0), (2,1) + b_A20 += tl.dot(b_k2, tl.trans(b_k0)) + b_A21 += tl.dot(b_k2, tl.trans(b_k1)) + + if i_tc3 < 
T: + p_k3 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) + b_k3 = tl.load(p_k3, boundary_check=(0, 1)) + # diagonal block 3 + b_A33 += tl.dot(b_k3, tl.trans(b_k3)) + # off-diagonal (3,0), (3,1), (3,2) + b_A30 += tl.dot(b_k3, tl.trans(b_k0)) + b_A31 += tl.dot(b_k3, tl.trans(b_k1)) + b_A32 += tl.dot(b_k3, tl.trans(b_k2)) + + ############################################################################ + # Step 2: apply gate and beta scaling + ############################################################################ + + # apply gate, beta scaling, and masking + # m_d: strictly lower triangular mask for diagonal blocks + # m_tc: boundary mask to prevent NaN from 0 * inf (IEEE 754) when + # out-of-bounds g loads as 0 via boundary_check and exp(0 - g_inbounds) overflows + m_d = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + + if USE_G: + if USE_EXP2: + b_A00 *= tl.where(m_d & m_tc0[:, None] & m_tc0[None, :], exp2(b_g0[:, None] - b_g0[None, :]), 0.) + b_A11 *= tl.where(m_d & m_tc1[:, None] & m_tc1[None, :], exp2(b_g1[:, None] - b_g1[None, :]), 0.) + b_A22 *= tl.where(m_d & m_tc2[:, None] & m_tc2[None, :], exp2(b_g2[:, None] - b_g2[None, :]), 0.) + b_A33 *= tl.where(m_d & m_tc3[:, None] & m_tc3[None, :], exp2(b_g3[:, None] - b_g3[None, :]), 0.) + + b_A10 *= tl.where(m_tc1[:, None] & m_tc0[None, :], exp2(b_g1[:, None] - b_g0[None, :]), 0.) + b_A20 *= tl.where(m_tc2[:, None] & m_tc0[None, :], exp2(b_g2[:, None] - b_g0[None, :]), 0.) + b_A21 *= tl.where(m_tc2[:, None] & m_tc1[None, :], exp2(b_g2[:, None] - b_g1[None, :]), 0.) + b_A30 *= tl.where(m_tc3[:, None] & m_tc0[None, :], exp2(b_g3[:, None] - b_g0[None, :]), 0.) + b_A31 *= tl.where(m_tc3[:, None] & m_tc1[None, :], exp2(b_g3[:, None] - b_g1[None, :]), 0.) + b_A32 *= tl.where(m_tc3[:, None] & m_tc2[None, :], exp2(b_g3[:, None] - b_g2[None, :]), 0.) + else: + b_A00 *= tl.where(m_d & m_tc0[:, None] & m_tc0[None, :], exp(b_g0[:, None] - b_g0[None, :]), 0.) + b_A11 *= tl.where(m_d & m_tc1[:, None] & m_tc1[None, :], exp(b_g1[:, None] - b_g1[None, :]), 0.) + b_A22 *= tl.where(m_d & m_tc2[:, None] & m_tc2[None, :], exp(b_g2[:, None] - b_g2[None, :]), 0.) + b_A33 *= tl.where(m_d & m_tc3[:, None] & m_tc3[None, :], exp(b_g3[:, None] - b_g3[None, :]), 0.) + + b_A10 *= tl.where(m_tc1[:, None] & m_tc0[None, :], exp(b_g1[:, None] - b_g0[None, :]), 0.) + b_A20 *= tl.where(m_tc2[:, None] & m_tc0[None, :], exp(b_g2[:, None] - b_g0[None, :]), 0.) + b_A21 *= tl.where(m_tc2[:, None] & m_tc1[None, :], exp(b_g2[:, None] - b_g1[None, :]), 0.) + b_A30 *= tl.where(m_tc3[:, None] & m_tc0[None, :], exp(b_g3[:, None] - b_g0[None, :]), 0.) + b_A31 *= tl.where(m_tc3[:, None] & m_tc1[None, :], exp(b_g3[:, None] - b_g1[None, :]), 0.) + b_A32 *= tl.where(m_tc3[:, None] & m_tc2[None, :], exp(b_g3[:, None] - b_g2[None, :]), 0.) + else: + b_A00 = tl.where(m_d, b_A00, 0.) + b_A11 = tl.where(m_d, b_A11, 0.) + b_A22 = tl.where(m_d, b_A22, 0.) + b_A33 = tl.where(m_d, b_A33, 0.) 
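+
+    # Worked form of steps 1-2 (within one chunk, for row i > column j): together with
+    # the beta scaling just below, each stored entry becomes
+    #     A[i, j] = beta[i] * exp(g[i] - g[j]) * <k[i], k[j]>   (exp2 when USE_EXP2)
+    # and A[i, j] = 0 for i <= j, i.e. the same matrix chunk_scaled_dot_kkt_fwd produces.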
+ + # diagonal blocks: scaled by beta + b_A00 = b_A00 * b_b0[:, None] + b_A11 = b_A11 * b_b1[:, None] + b_A22 = b_A22 * b_b2[:, None] + b_A33 = b_A33 * b_b3[:, None] + + # off-diagonal blocks: full block, scaled by beta + b_A10 = b_A10 * b_b1[:, None] + b_A20 = b_A20 * b_b2[:, None] + b_A21 = b_A21 * b_b2[:, None] + b_A30 = b_A30 * b_b3[:, None] + b_A31 = b_A31 * b_b3[:, None] + b_A32 = b_A32 * b_b3[:, None] + + ############################################################################ + # Step 3: forward substitution on diagonal blocks -> (I + A_diag)^{-1} + # + # Same algorithm as solve_tril, but rows are extracted from in-register + # [BC, BC] tensor via tl.sum(tl.where(mask, tensor, 0), 0) instead of + # tl.load from HBM. + ############################################################################ + + b_Ai00 = -b_A00 + b_Ai11 = -b_A11 + b_Ai22 = -b_A22 + b_Ai33 = -b_A33 + + for i in range(2, min(BC, T - i_tc0)): + b_a00 = tl.sum(tl.where((o_i == i)[:, None], -b_A00, 0.), 0) + b_a00 = tl.where(o_i < i, b_a00, 0.) + b_a00 = b_a00 + tl.sum(b_a00[:, None] * b_Ai00, 0) + b_Ai00 = tl.where((o_i == i)[:, None], b_a00, b_Ai00) + for i in range(2, min(BC, T - i_tc1)): + b_a11 = tl.sum(tl.where((o_i == i)[:, None], -b_A11, 0.), 0) + b_a11 = tl.where(o_i < i, b_a11, 0.) + b_a11 = b_a11 + tl.sum(b_a11[:, None] * b_Ai11, 0) + b_Ai11 = tl.where((o_i == i)[:, None], b_a11, b_Ai11) + for i in range(2, min(BC, T - i_tc2)): + b_a22 = tl.sum(tl.where((o_i == i)[:, None], -b_A22, 0.), 0) + b_a22 = tl.where(o_i < i, b_a22, 0.) + b_a22 = b_a22 + tl.sum(b_a22[:, None] * b_Ai22, 0) + b_Ai22 = tl.where((o_i == i)[:, None], b_a22, b_Ai22) + for i in range(2, min(BC, T - i_tc3)): + b_a33 = tl.sum(tl.where((o_i == i)[:, None], -b_A33, 0.), 0) + b_a33 = tl.where(o_i < i, b_a33, 0.) 
+ b_a33 = b_a33 + tl.sum(b_a33[:, None] * b_Ai33, 0) + b_Ai33 = tl.where((o_i == i)[:, None], b_a33, b_Ai33) + + b_Ai00 += m_I + b_Ai11 += m_I + b_Ai22 += m_I + b_Ai33 += m_I + + ############################################################################ + # Step 4: block merge -> full (I + A)^{-1} + ############################################################################ + + b_Ai10 = -tl.dot( + tl.dot(b_Ai11, b_A10, input_precision=SOLVE_TRIL_DOT_PRECISION), + b_Ai00, + input_precision=SOLVE_TRIL_DOT_PRECISION + ) + b_Ai21 = -tl.dot( + tl.dot(b_Ai22, b_A21, input_precision=SOLVE_TRIL_DOT_PRECISION), + b_Ai11, + input_precision=SOLVE_TRIL_DOT_PRECISION + ) + b_Ai32 = -tl.dot( + tl.dot(b_Ai33, b_A32, input_precision=SOLVE_TRIL_DOT_PRECISION), + b_Ai22, + input_precision=SOLVE_TRIL_DOT_PRECISION + ) + + b_Ai20 = -tl.dot( + b_Ai22, + tl.dot(b_A20, b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_A21, b_Ai10, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + b_Ai31 = -tl.dot( + b_Ai33, + tl.dot(b_A31, b_Ai11, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_A32, b_Ai21, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + b_Ai30 = -tl.dot( + b_Ai33, + tl.dot(b_A30, b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_A31, b_Ai10, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_A32, b_Ai20, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + + ############################################################################ + # Step 5: store full (I + A)^{-1} to output A + ############################################################################ + + p_A00 = tl.make_block_ptr(A, (T, BT), (HV*BT, 1), (i_tc0, 0), (BC, BC), (1, 0)) + p_A10 = tl.make_block_ptr(A, (T, BT), (HV*BT, 1), (i_tc1, 0), (BC, BC), (1, 0)) + p_A11 = tl.make_block_ptr(A, (T, BT), (HV*BT, 1), (i_tc1, BC), (BC, BC), (1, 0)) + p_A20 = tl.make_block_ptr(A, (T, BT), (HV*BT, 1), (i_tc2, 0), (BC, BC), (1, 0)) + p_A21 = tl.make_block_ptr(A, (T, BT), (HV*BT, 1), (i_tc2, BC), (BC, BC), (1, 0)) + p_A22 = tl.make_block_ptr(A, (T, BT), (HV*BT, 1), (i_tc2, 2*BC), (BC, BC), (1, 0)) + p_A30 = tl.make_block_ptr(A, (T, BT), (HV*BT, 1), (i_tc3, 0), (BC, BC), (1, 0)) + p_A31 = tl.make_block_ptr(A, (T, BT), (HV*BT, 1), (i_tc3, BC), (BC, BC), (1, 0)) + p_A32 = tl.make_block_ptr(A, (T, BT), (HV*BT, 1), (i_tc3, 2*BC), (BC, BC), (1, 0)) + p_A33 = tl.make_block_ptr(A, (T, BT), (HV*BT, 1), (i_tc3, 3*BC), (BC, BC), (1, 0)) + + tl.store(p_A00, b_Ai00.to(A.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A10, b_Ai10.to(A.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A11, b_Ai11.to(A.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A20, b_Ai20.to(A.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A21, b_Ai21.to(A.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A22, b_Ai22.to(A.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A30, b_Ai30.to(A.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A31, b_Ai31.to(A.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A32, b_Ai32.to(A.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A33, b_Ai33.to(A.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_intra( + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor | None = None, + beta: paddle.Tensor | None = None, + cu_seqlens: paddle.Tensor | None = None, + chunk_size: int = 64, + chunk_indices: paddle.Tensor | None = None, 
+ use_exp2: bool = False, +) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + r""" + GDN intra-chunk forward: fused kkt + solve_tril + recompute_w_u. + + Equivalent to: + A = chunk_scaled_dot_kkt_fwd(k, g, beta, ...) # kernel 1 + A = solve_tril(A, ...) # kernel 2 + w, u = recompute_w_u_fwd(k, v, beta, A, g, ...) # kernel 3 + + Fuses kernels 1+2 into a single kernel, reducing from 3 to 2 kernel launches + and eliminating the HBM round-trip for the intermediate A matrix. + + Args: + k (paddle.Tensor): + The key tensor of shape `[B, T, H, K]`. + v (paddle.Tensor): + The value tensor of shape `[B, T, HV, V]`. + g (paddle.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, HV]`. Default: `None`. + beta (paddle.Tensor): + The beta tensor of shape `[B, T, HV]`. + cu_seqlens (paddle.Tensor): + The cumulative sequence lengths. Default: `None`. + chunk_size (int): + The chunk size. Default: 64. + chunk_indices (paddle.Tensor): + Precomputed chunk indices. Default: `None`. + + Returns: + w (paddle.Tensor): shape `[B, T, HV, K]` + u (paddle.Tensor): shape `[B, T, HV, V]` + A (paddle.Tensor): shape `[B, T, HV, BT]`, the solved (I+A)^{-1} matrix + """ + B, T, H, K, HV = *k.shape, beta.shape[2] + BT = chunk_size + BC = 16 + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + # Step 1: fused kkt + solve_tril + A = paddle.zeros([B, T, HV, BT], dtype=k.dtype) + chunk_gated_delta_rule_fwd_kkt_solve_kernel[(NT, B * HV)]( + k=k, + g=g, + beta=beta, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + HV=HV, + K=K, + BT=BT, + BC=BC, + USE_EXP2=use_exp2, + ) + + # Step 2: recompute_w_u + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g=g, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + use_exp2=use_exp2, + ) + return w, u, A diff --git a/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/fused_recurrent.py b/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/fused_recurrent.py new file mode 100644 index 00000000000..64a1bab5fc5 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/fused_recurrent.py @@ -0,0 +1,397 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +import paddle +import triton +import triton.language as tl + +from flash_mask.linear_attn.ops.utils.op import exp, exp2 +from flash_mask.linear_attn.utils import input_guard +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'USE_GK': lambda args: args['gk'] is not None, + 'USE_GV': lambda args: args['gv'] is not None, + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + gk, + gv, + beta, + o, + h0, + ht, + cu_seqlens, + scale, + T, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_BETA_HEADWISE: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_EXP2: tl.constexpr, + TRANSPOSE_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + o_k = tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + if USE_G: + p_g = g + bos * HV + i_hv + if USE_GK: + p_gk = gk + (bos * HV + i_hv) * K + o_k + if USE_GV: + p_gv = gv + (bos * HV + i_hv) * V + o_v + if IS_BETA_HEADWISE: + p_beta = beta + bos * HV + i_hv + else: + p_beta = beta + (bos * HV + i_hv) * V + o_v + + p_o = o + (bos * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + if TRANSPOSE_STATE: + mask_h = mask_v[:, None] & mask_k[None, :] + else: + mask_h = mask_k[:, None] & mask_v[None, :] + + if TRANSPOSE_STATE: + b_h = tl.zeros([BV, BK], dtype=tl.float32) + else: + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + if TRANSPOSE_STATE: + p_h0 = h0 + i_nh * K*V + o_v[:, None] * K + o_k[None, :] + else: + p_h0 = h0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in tl.range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta).to(tl.float32) + else: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + + if USE_G: + b_g = tl.load(p_g).to(tl.float32) + if USE_EXP2: + b_h *= exp2(b_g) + else: + b_h *= exp(b_g) + + if USE_GK: + b_gk = tl.load(p_gk).to(tl.float32) + if USE_EXP2: + if TRANSPOSE_STATE: + b_h *= exp2(b_gk[None, :]) + else: + b_h *= exp2(b_gk[:, None]) + else: + if TRANSPOSE_STATE: + b_h *= exp(b_gk[None, :]) + else: + b_h *= exp(b_gk[:, None]) + + if USE_GV: + b_gv = tl.load(p_gv).to(tl.float32) + if USE_EXP2: + if 
TRANSPOSE_STATE: + b_h *= exp2(b_gv[:, None]) + else: + b_h *= exp2(b_gv[None, :]) + else: + if TRANSPOSE_STATE: + b_h *= exp(b_gv[:, None]) + else: + b_h *= exp(b_gv[None, :]) + + if TRANSPOSE_STATE: + b_v = b_beta * (b_v - tl.sum(b_h * b_k[None, :], 1)) + b_h += b_v[:, None] * b_k[None, :] + b_o = tl.sum(b_h * b_q[None, :], 1) + else: + b_v = b_beta * (b_v - tl.sum(b_h * b_k[:, None], 0)) + b_h += b_k[:, None] * b_v + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + p_q += H*K + p_k += H*K + p_v += HV*V + if USE_G: + p_g += HV + if USE_GK: + p_gk += HV*K + if USE_GV: + p_gv += HV*V + p_beta += HV * (1 if IS_BETA_HEADWISE else V) + p_o += HV*V + + if STORE_FINAL_STATE: + if TRANSPOSE_STATE: + p_ht = ht + i_nh * K*V + o_v[:, None] * K + o_k[None, :] + else: + p_ht = ht + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +def fused_recurrent_gated_delta_rule_fwd( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor | None = None, + gk: paddle.Tensor | None = None, + gv: paddle.Tensor | None = None, + beta: paddle.Tensor | None = None, + scale: float = None, + initial_state: paddle.Tensor = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + cu_seqlens: paddle.Tensor | None = None, + use_exp2: bool = False, + transpose_state_layout: bool = False, +) -> tuple[paddle.Tensor, paddle.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK = triton.next_power_of_2(K) + BV = min(8, triton.next_power_of_2(V)) if gv is None else triton.next_power_of_2(V) + NV = triton.cdiv(V, BV) + + o = paddle.empty_like(v) + if output_final_state: + if transpose_state_layout: + final_state = paddle.empty([N, HV, V, K], dtype=paddle.float32) + else: + final_state = paddle.empty([N, HV, K, V], dtype=paddle.float32) + else: + final_state = None + + grid = (NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + gk=gk, + gv=gv, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + scale=scale, + T=T, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + IS_BETA_HEADWISE=beta.ndim != v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + USE_EXP2=use_exp2, + TRANSPOSE_STATE=transpose_state_layout, + num_warps=1, + num_stages=3, + ) + return o, final_state + + +class FusedRecurrentFunction(paddle.autograd.PyLayer): + + @staticmethod + @input_guard + def forward( + ctx, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor | None = None, + gk: paddle.Tensor | None = None, + gv: paddle.Tensor | None = None, + beta: paddle.Tensor | None = None, + scale: float = None, + initial_state: paddle.Tensor = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + cu_seqlens: paddle.Tensor | None = None, + use_exp2: bool = False, + transpose_state_layout: bool = False, + ): + # Store non-tensor params as ctx attributes + ctx.scale = scale + ctx.output_final_state = output_final_state + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + ctx.use_exp2 = use_exp2 + ctx.transpose_state_layout = transpose_state_layout + + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + gk=gk, + gv=gv, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + 
cu_seqlens=cu_seqlens, + use_exp2=use_exp2, + transpose_state_layout=transpose_state_layout, + ) + + # Paddle PyLayer forward cannot return None, use dummy tensor as placeholder + if final_state is None: + final_state = paddle.zeros([1], dtype=q.dtype) + return o, final_state + + @staticmethod + @input_guard + def backward(ctx, do, dht=None): + raise NotImplementedError( + "Backward pass is not implemented yet and we do not have plans to implement it " + "because we haven't figured out how to compute dg without materializing the full " + "hidden states for all time steps.", + ) + + +def fused_recurrent_gated_delta_rule( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor | None = None, + gk: paddle.Tensor | None = None, + gv: paddle.Tensor | None = None, + beta: paddle.Tensor | None = None, + scale: float = None, + initial_state: paddle.Tensor = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + cu_seqlens: paddle.Tensor | None = None, + use_exp2: bool = False, + transpose_state_layout: bool = False, +) -> tuple[paddle.Tensor, paddle.Tensor]: + r""" + Args: + q (paddle.Tensor): + queries of shape `[B, T, H, K]`. + k (paddle.Tensor): + keys of shape `[B, T, H, K]`. + v (paddle.Tensor): + values of shape `[B, T, HV, V]`. + GVA (Grouped Value Attention) is applied if `HV > H`, where `HV` must be divisible by `H`. + g (paddle.Tensor): + g (decays) of shape `[B, T, HV]`. Default: `None`. + gk (paddle.Tensor): + gk (decays) of shape `[B, T, HV, K]`. Default: `None`. + gv (paddle.Tensor): + gv (decays) of shape `[B, T, HV, V]`. Default: `None`. + beta (paddle.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[float]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[paddle.Tensor]): + Initial state of shape `[N, HV, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, HV, K, V]`. Default: `False`. + use_qk_l2norm_in_kernel (Optional[bool]): + Whether to use L2 normalization in the kernel. Default: `False`. + cu_seqlens (paddle.Tensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + transpose_state_layout (bool): + Whether to use transposed state layout `[V, K]` instead of `[K, V]`. Default: `False`. + + Returns: + o (paddle.Tensor): + Outputs of shape `[B, T, HV, V]`. + final_state (paddle.Tensor): + Final state of shape `[N, HV, K, V]` if `output_final_state=True` else `None`. + """ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." 
+ f"Please flatten variable-length inputs before processing.", + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.", + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + if beta is None: + beta = paddle.ones_like(q[..., 0]) + + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + g, + gk, + gv, + beta, + scale, + initial_state, + output_final_state, + use_qk_l2norm_in_kernel, + cu_seqlens, + use_exp2, + transpose_state_layout, + ) + # Convert dummy tensor back to None when output_final_state=False + if not output_final_state: + final_state = None + return o, final_state + + +fused_recurrent_gdn = fused_recurrent_gated_delta_rule diff --git a/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/gate.py b/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/gate.py new file mode 100644 index 00000000000..1eeadb76540 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/gate.py @@ -0,0 +1,350 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +import paddle +import paddle.nn.functional as F +import triton +import triton.language as tl + +from flash_mask.linear_attn.ops.utils.index import prepare_chunk_indices +from flash_mask.linear_attn.ops.utils.op import exp +from flash_mask.linear_attn.ops.utils.softplus import softplus +from flash_mask.linear_attn.utils import autocast_custom_bwd, autocast_custom_fwd, autotune_cache_kwargs, input_guard +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + + +def naive_gdn_gate( + g: paddle.Tensor, + A_log: paddle.Tensor, + dt_bias: paddle.Tensor | None = None, + output_dtype: paddle.dtype = paddle.float32, +) -> paddle.Tensor: + """ + Paddle reference implementation for GDN gate computation. + + Computes: ``g = -A_log.exp() * softplus(g + dt_bias)`` + + Args: + g (paddle.Tensor): + Input tensor of shape `[..., HV]`. + A_log (paddle.Tensor): + Decay parameter tensor with `HV` elements. + dt_bias (paddle.Tensor | None): + Optional bias tensor added to `g` before activation, shape `[HV]`. + + Returns: + Output tensor of shape `[..., HV]`. 
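+
+    Example (illustrative; shapes and values are hypothetical, not from a real run):
+        >>> import paddle
+        >>> g = paddle.randn([2, 8, 4])      # [B, T, HV]
+        >>> A_log = paddle.zeros([4])        # exp(A_log) == 1
+        >>> y = naive_gdn_gate(g, A_log)     # equals -softplus(g), cast to float32
+        >>> y.shape
+        [2, 8, 4]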
+ """ + g = g.cast(paddle.float32) + if dt_bias is not None: + g = g + dt_bias.cast(paddle.float32) + return (-A_log.cast(paddle.float32).exp() * F.softplus(g)).cast(output_dtype) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'HAS_BIAS': lambda args: args['dt_bias'] is not None, + 'HAS_SCALE': lambda args: args['scale'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8] + ], + key=['H', 'BT', 'IS_VARLEN', 'REVERSE'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def gdn_gate_chunk_cumsum_scalar_kernel( + g, + A_log, + dt_bias, + o, + scale, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + + b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32) + if HAS_BIAS: + b_g = b_g + tl.load(dt_bias + i_h).to(tl.float32) + b_A = tl.load(A_log + i_h).to(tl.float32) + b_gate = -exp(b_A) * softplus(b_g) + + b_o = tl.cumsum(b_gate, axis=0) + if REVERSE: + b_z = tl.sum(b_gate, axis=0) + b_o = -b_o + b_z[None] + b_gate + if HAS_SCALE: + b_o *= scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'HAS_BIAS': lambda args: args['dt_bias'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8] + ], + key=['H', 'BT'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def gdn_gate_bwd_kernel( + g, + A_log, + dt_bias, + dyg, + dg, + dA, + T, + H: tl.constexpr, + BT: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + i_t, i_h = tl.program_id(0), tl.program_id(1) + + b_A = tl.load(A_log + i_h).to(tl.float32) + + p_g = tl.make_block_ptr(g + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_dg = tl.make_block_ptr(dg + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_dyg = tl.make_block_ptr(dyg + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + + b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32) + b_dyg = tl.load(p_dyg, boundary_check=(0,)).to(tl.float32) + + if HAS_BIAS: + b_g = b_g + tl.load(dt_bias + i_h).to(tl.float32) + + # gate = -exp(A_log) * softplus(g + bias) + # d(gate)/d(g) = -exp(A_log) * sigmoid(g + bias) (softplus' = sigmoid) + # d(gate)/d(A_log) = -exp(A_log) * softplus(g + bias) = gate + b_neg_expA = -exp(b_A) + b_yg = b_neg_expA * softplus(b_g) + b_dg = b_neg_expA * (b_dyg * tl.sigmoid(b_g)) + b_dA = tl.sum(b_dyg * b_yg, 0) + + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + tl.store(dA + i_t * H + i_h, b_dA) + + +@input_guard +def gdn_gate_chunk_cumsum( + g: paddle.Tensor, + A_log: paddle.Tensor, + chunk_size: int, + scale: float = None, + dt_bias: paddle.Tensor | None = None, + cu_seqlens: paddle.Tensor | None = None, + chunk_indices: paddle.Tensor | None = None, + output_dtype: paddle.dtype | None = 
paddle.float32, +) -> paddle.Tensor: + B, T, H = g.shape + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + o = paddle.empty_like(g).cast(output_dtype or g.dtype) + gdn_gate_chunk_cumsum_scalar_kernel[(NT, B * H)]( + g=g, + A_log=A_log, + dt_bias=dt_bias, + o=o, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + REVERSE=False, + ) + return o + + +def gdn_gate_bwd( + g: paddle.Tensor, + A_log: paddle.Tensor, + dt_bias: paddle.Tensor | None, + dyg: paddle.Tensor, +) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor | None]: + H = g.shape[-1] + T = g.numel().item() // H + BT = 32 + NT = triton.cdiv(T, BT) + + dg = paddle.empty_like(g).cast(paddle.float32) + dA = paddle.empty([NT, H], dtype=paddle.float32) + + gdn_gate_bwd_kernel[(NT, H)]( + g=g, + A_log=A_log, + dt_bias=dt_bias, + dyg=dyg, + dg=dg, + dA=dA, + T=T, + H=H, + BT=BT, + ) + + # Compute dbias from dg while still in float32 (before casting to g.dtype which may be float16). + # Paddle's .sum() does not promote float16 to float32 for accumulation like PyTorch does. + dbias = dg.reshape([-1, H]).sum(axis=0).cast(dt_bias.dtype) if dt_bias is not None else None + dg = dg.reshape(g.shape).cast(g.dtype) + dA = dA.sum(axis=0).reshape(A_log.shape).cast(A_log.dtype) + + return dg, dA, dbias + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'HAS_BIAS': lambda args: args['dt_bias'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BT': BT}, num_warps=num_warps, num_stages=num_stages) + for BT in [32, 64, 128] + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3] + ], + key=['H'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def gdn_gate_fwd_kernel( + g, + A_log, + dt_bias, + yg, + T, + H: tl.constexpr, + BT: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + i_t, i_h = tl.program_id(0), tl.program_id(1) + + b_A = tl.load(A_log + i_h).to(tl.float32) + + p_g = tl.make_block_ptr(g + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_yg = tl.make_block_ptr(yg + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32) + if HAS_BIAS: + b_g = b_g + tl.load(dt_bias + i_h).to(tl.float32) + b_yg = -exp(b_A) * softplus(b_g) + tl.store(p_yg, b_yg.to(p_yg.dtype.element_ty), boundary_check=(0,)) + + +def gdn_gate_fwd( + g: paddle.Tensor, + A_log: paddle.Tensor, + dt_bias: paddle.Tensor | None = None, + output_dtype: paddle.dtype = paddle.float32, +) -> paddle.Tensor: + H = g.shape[-1] + T = g.numel().item() // H + + yg = paddle.empty_like(g).cast(output_dtype) + + def grid(meta): + return (triton.cdiv(T, meta['BT']), H) + + gdn_gate_fwd_kernel[grid]( + g=g, + A_log=A_log, + dt_bias=dt_bias, + yg=yg, + T=T, + H=H, + ) + return yg + + +class GDNGateFunction(paddle.autograd.PyLayer): + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + g: paddle.Tensor, + A_log: paddle.Tensor, + dt_bias: paddle.Tensor | None = None, + output_dtype: paddle.dtype = paddle.float32, + ) -> paddle.Tensor: + yg = gdn_gate_fwd(g=g, A_log=A_log, dt_bias=dt_bias, output_dtype=output_dtype) + ctx.save_for_backward(g, A_log, dt_bias) + ctx.output_dtype = output_dtype + # Paddle PyLayer backward must return exactly as many values as tensor inputs. 
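+        # Illustrative example (hypothetical): with dt_bias=None the tensor inputs are
+        # (g, A_log) only, so _tensor_mask == (True, True, False, False) and backward
+        # returns just (dg, dA); when dt_bias is a trainable tensor it returns
+        # (dg, dA, dbias).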
+ _forward_args = [g, A_log, dt_bias, output_dtype] + ctx._tensor_mask = tuple(isinstance(a, paddle.Tensor) for a in _forward_args) + ctx._needs_grad = tuple( + isinstance(a, paddle.Tensor) and not a.stop_gradient for a in _forward_args + ) + return yg + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, dyg: paddle.Tensor): + g, A_log, dt_bias = ctx.saved_tensor() + dg, dA, dbias = gdn_gate_bwd(g=g, A_log=A_log, dt_bias=dt_bias, dyg=dyg) + all_grads = [dg, dA, dbias, None] + return tuple( + g if needs_grad else None + for g, is_tensor, needs_grad in zip(all_grads, ctx._tensor_mask, ctx._needs_grad) + if is_tensor + ) + + +def fused_gdn_gate( + g: paddle.Tensor, + A_log: paddle.Tensor, + dt_bias: paddle.Tensor | None = None, + output_dtype: paddle.dtype = paddle.float32, +) -> paddle.Tensor: + r""" + Fused GDN gate computation with autograd support. + + Computes: ``g = -A_log.exp() * softplus(g + dt_bias)`` + + Args: + g (paddle.Tensor): + Input tensor of shape `[..., HV]`. + A_log (paddle.Tensor): + Decay parameter tensor with `HV` elements. + dt_bias (paddle.Tensor | None): + Optional bias tensor added to `g` before activation, shape `[HV]`. + output_dtype (paddle.dtype): + The dtype of the output tensor. Default: `paddle.float32`. + + Returns: + Output tensor of shape `[..., HV]`. + """ + return GDNGateFunction.apply(g, A_log, dt_bias, output_dtype) diff --git a/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/naive.py b/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/naive.py new file mode 100644 index 00000000000..d3889c19178 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/naive.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# Adapted for PaddlePaddle + +import paddle + + +def naive_recurrent_gated_delta_rule( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + beta: paddle.Tensor, + g: paddle.Tensor, + scale: float = None, + initial_state: paddle.Tensor = None, + output_final_state: bool = False, +): + """ + Reference PaddlePaddle implementation of recurrent gated delta rule. 
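+
+    Per time step t the state S (shape [K, V]) is updated as in the loop below:
+
+        S~  = exp(g_t) * S_{t-1}                # gated decay
+        u_t = beta_t * (v_t - S~^T k_t)         # delta-rule correction
+        S_t = S~ + k_t u_t^T
+        o_t = S_t^T (scale * q_t)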
+ + Args: + q: [B, T, H, K] + k: [B, T, H, K] + v: [B, T, H, V] + beta: [B, T, H] + g: [B, T, H] + scale: float, optional + initial_state: [B, H, K, V], optional + output_final_state: bool + + Returns: + o: [B, T, H, V] + final_state: [B, H, K, V] if output_final_state else None + """ + q, k, v, beta, g = map( + lambda x: x.transpose([0, 2, 1, 3]).contiguous().cast(paddle.float32) if x.ndim == 4 + else x.transpose([0, 2, 1]).contiguous().cast(paddle.float32), + [q, k, v, beta, g] + ) + B, H, T, K = k.shape + V = v.shape[-1] + o = paddle.zeros([B, H, T, V], dtype=v.dtype) + h = paddle.zeros([B, H, K, V], dtype=v.dtype) + if initial_state is not None: + h = initial_state.cast(paddle.float32) + if scale is None: + scale = 1 / (q.shape[-1] ** 0.5) + q = q * scale + + for i in range(T): + b_q = q[:, :, i] + b_k = k[:, :, i] + b_v = v[:, :, i].clone() + h = h.clone() * g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) + b_beta = beta[:, :, i] + b_v = b_v - (h.clone() * b_k.unsqueeze(-1)).sum(-2) + b_v = b_v * b_beta.unsqueeze(-1) + h = h.clone() + b_k.unsqueeze(-1) * b_v.unsqueeze(-2) + o[:, :, i] = paddle.einsum('bhd,bhdm->bhm', b_q, h) + + if not output_final_state: + h = None + o = o.transpose([0, 2, 1, 3]).contiguous() + return o, h diff --git a/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/wy_fast.py b/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/wy_fast.py new file mode 100644 index 00000000000..388e2ae61a1 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/gated_delta_rule/wy_fast.py @@ -0,0 +1,362 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +import paddle +import triton +import triton.language as tl + +from flash_mask.linear_attn.ops.utils import prepare_chunk_indices +from flash_mask.linear_attn.ops.utils.op import exp, exp2 +from flash_mask.linear_attn.utils import IS_NVIDIA_BLACKWELL, autotune_cache_kwargs, check_shared_mem +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + +if IS_NVIDIA_BLACKWELL: + """ + Compute tl.dot with SM100 workaround. + + On SM100 (Blackwell) GPUs, wraps the result in inline assembly to prevent + the TritonGPUHoistTMEMAlloc pass from incorrectly fusing add and dot operations. + See: https://github.com/fla-org/flash-linear-attention/issues/638 + + TODO: Remove this workaround once the Triton compiler bug is fixed. 
+ Track upstream issue at: https://github.com/triton-lang/triton/issues/8695 + """ + @triton.jit + def safe_dot(a, b): + return tl.inline_asm_elementwise( + asm="mov.f32 $0, $1;", + constraints="=r,r", + args=[tl.dot(a, b)], + dtype=tl.float32, + is_pure=True, + pack=1, + ) +else: + @triton.jit + def safe_dot(a, b): + return tl.dot(a, b) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'HV', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def recompute_w_u_fwd_kernel( + k, + v, + beta, + w, + u, + A, + g, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // HV, i_bh % HV + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_b = tl.make_block_ptr(beta + bos*HV + i_h, (T,), (HV,), (i_t * BT,), (BT,), (0,)) + b_b = tl.load(p_b, boundary_check=(0,)) + + p_A = tl.make_block_ptr(A + (bos*HV + i_h) * BT, (T, BT), (HV*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos*HV + i_h) * V, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos*HV + i_h) * V, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_b[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + if USE_G: + p_g = tl.make_block_ptr(g + (bos*HV + i_h), (T,), (HV,), (i_t * BT,), (BT,), (0,)) + if USE_EXP2: + b_g = exp2(tl.load(p_g, boundary_check=(0,))) + else: + b_g = exp(tl.load(p_g, boundary_check=(0,))) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos*H + i_h // (HV // H)) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + (bos*HV + i_h) * K, (T, K), (HV*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = b_k * b_b[:, None] + if USE_G: + b_kb *= b_g[:, None] + b_w = tl.dot(b_A, b_kb.to(b_k.dtype)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + ], + key=['H', 'HV', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def prepare_wy_repr_bwd_kernel( + k, + v, + beta, + g, + A, + dw, + du, + dk, + dv, + db, + dg, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + HV: tl.constexpr, + K: 
tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // HV, i_bh % HV + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_b = tl.make_block_ptr(beta + (bos*HV + i_h), (T,), (HV,), (i_t * BT,), (BT,), (0,)) + p_db = tl.make_block_ptr(db + (bos*HV + i_h), (T,), (HV,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + (bos*HV + i_h) * BT, (BT, T), (1, HV*BT), (0, i_t * BT), (BT, BT), (0, 1)) + + b_b = tl.load(p_b, boundary_check=(0,)) + b_db = tl.zeros([BT], dtype=tl.float32) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + + if USE_G: + p_g = tl.make_block_ptr(g + (bos*HV + i_h), (T,), (HV,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + if USE_EXP2: + b_g_exp = exp2(b_g) + else: + b_g_exp = tl.exp(b_g) + b_dg = tl.zeros([BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos*H + i_h // (HV // H)) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*HV + i_h) * K, (T, K), (HV*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos*HV + i_h) * K, (T, K), (HV*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_G: + b_kbg = b_k * (b_b * b_g_exp)[:, None] + else: + b_kbg = b_k * b_b[:, None] + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + + b_dA += tl.dot(b_dw, tl.trans(b_kbg).to(b_dw.dtype)) + b_dkbg = tl.dot(b_A, b_dw) + if USE_G: + b_dk = b_dkbg * (b_g_exp * b_b)[:, None] + b_db += tl.sum(b_dkbg * b_k * b_g_exp[:, None], 1) + b_dg += tl.sum(b_dkbg * b_kbg, 1) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos*HV + i_h) * V, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos*HV + i_h) * V, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + (bos*HV + i_h) * V, (T, V), (HV*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_b[:, None]).to(b_v.dtype) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA += tl.dot(b_du, tl.trans(b_vb)) + b_dvb = tl.dot(b_A, b_du) + b_dv = b_dvb * b_b[:, None] + b_db += tl.sum(b_dvb * b_v, 1) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_dA = tl.where(m_A, b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), b_A) + b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) + + if USE_G: + if USE_EXP2: + b_dA *= exp2(b_g[:, None] - b_g[None, :]) + else: + b_dA *= exp(b_g[:, None] - b_g[None, :]) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + b_dA = tl.where(m_A, -b_dA, 0).to(k.dtype.element_ty) + + tl.debug_barrier() + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos*H + i_h // (HV // H)) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*HV + i_h) * K, (T, K), (HV*K, 1), (i_t * BT, i_k * BK), 
(BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kt = tl.trans(b_k) + b_kb = b_k * b_b[:, None] + + b_A += tl.dot(b_k, b_kt) + b_dkb = tl.dot(b_dA, b_k) + b_db += tl.sum(b_dkb * b_k, 1) + b_dk = b_dkb * b_b[:, None] + tl.trans(tl.dot(tl.trans(b_kb).to(b_dA.dtype), b_dA)) + b_dk += tl.load(p_dk, boundary_check=(0, 1)) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) + + b_A *= b_b[:, None] + if USE_G: + b_AdA = b_dA * b_A + p_dg = tl.make_block_ptr(dg + (bos*HV + i_h), (T,), (HV,), (i_t * BT,), (BT,), (0,)) + b_dg += tl.sum(b_AdA, axis=1) - tl.sum(b_AdA, axis=0) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + + +def recompute_w_u_fwd( + k: paddle.Tensor, + v: paddle.Tensor, + beta: paddle.Tensor, + A: paddle.Tensor, + g: paddle.Tensor | None = None, + cu_seqlens: paddle.Tensor | None = None, + chunk_indices: paddle.Tensor | None = None, + use_exp2: bool = False, +) -> tuple[paddle.Tensor, paddle.Tensor]: + B, T, H, K, V, HV = *k.shape, v.shape[-1], v.shape[2] + BT = A.shape[-1] + BK = 64 + BV = 64 + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + w = paddle.empty([B, T, HV, K], dtype=k.dtype) + u = paddle.empty_like(v) + recompute_w_u_fwd_kernel[(NT, B*HV)]( + k=k, + v=v, + beta=beta, + w=w, + u=u, + A=A, + g=g, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + HV=HV, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + USE_EXP2=use_exp2, + ) + return w, u + + +def prepare_wy_repr_bwd( + k: paddle.Tensor, + v: paddle.Tensor, + beta: paddle.Tensor, + A: paddle.Tensor, + dw: paddle.Tensor, + du: paddle.Tensor, + g: paddle.Tensor = None, + cu_seqlens: paddle.Tensor | None = None, + chunk_indices: paddle.Tensor | None = None, + use_exp2: bool = False, +) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + B, T, H, K, V, HV = *k.shape, v.shape[-1], v.shape[2] + BT = 64 + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + CONST_TILING = 64 if check_shared_mem() else 32 + BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING) + BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING) + + dk = paddle.empty([B, T, HV, K], dtype=k.dtype) + dv = paddle.empty_like(v) + dg = paddle.empty_like(g) if g is not None else None + db = paddle.empty_like(beta) + prepare_wy_repr_bwd_kernel[(NT, B * HV)]( + k=k, + v=v, + beta=beta, + g=g, + A=A, + dw=dw, + du=du, + dk=dk, + dv=dv, + db=db, + dg=dg, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + HV=HV, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + USE_EXP2=use_exp2, + ) + if H != HV: + dk = dk.reshape([B, T, H, HV // H, K]).sum(axis=3) + return dk, dv, db, dg + + +fwd_recompute_w_u = recompute_w_u_fwd +bwd_prepare_wy_repr = prepare_wy_repr_bwd diff --git a/flashmask/flash_mask/linear_attn/ops/gla/__init__.py b/flashmask/flash_mask/linear_attn/ops/gla/__init__.py new file mode 100644 index 00000000000..40a96afc6ff --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/gla/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/flashmask/flash_mask/linear_attn/ops/gla/chunk.py b/flashmask/flash_mask/linear_attn/ops/gla/chunk.py new file mode 100644 index 00000000000..4188e9a1185 --- 
/dev/null +++ b/flashmask/flash_mask/linear_attn/ops/gla/chunk.py @@ -0,0 +1,150 @@ +# Only the chunk_gla_fwd_o_gk function and its kernel, extracted from fla/ops/gla/chunk.py +# +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +import paddle +import triton +import triton.language as tl + +from flash_mask.linear_attn.ops.utils import prepare_chunk_indices +from flash_mask.linear_attn.ops.utils.op import exp, exp2 +from flash_mask.linear_attn.utils import autotune_cache_kwargs, check_shared_mem +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for BV in [64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BT', 'TRANSPOSE_STATE'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def chunk_gla_fwd_kernel_o( + q, + v, + g, + h, + o, + A, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_EXP2: tl.constexpr, + TRANSPOSE_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t.to(tl.int64) + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = (i_b * NT + i_t).to(tl.int64) + bos, eos = (i_b * T).to(tl.int64), (i_b * T + T).to(tl.int64) + + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + if TRANSPOSE_STATE: + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0)) + else: + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BT, BK] + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + # [BT, BK] + if USE_EXP2: + b_qg = (b_q * exp2(b_g)).to(b_q.dtype) + else: + b_qg = (b_q * exp(b_g)).to(b_q.dtype) + b_h = tl.load(p_h, boundary_check=(0, 1)) + if i_k >= 0: + if TRANSPOSE_STATE: + b_o += tl.dot(b_qg, tl.trans(b_h).to(b_qg.dtype)) + else: + b_o += tl.dot(b_qg, b_h.to(b_qg.dtype)) + b_o *= scale + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, 
boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_A = tl.where(m_s, b_A, 0.).to(b_v.dtype) + b_o += tl.dot(b_A, b_v) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gla_fwd_o_gk( + q: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + A: paddle.Tensor, + h: paddle.Tensor, + scale: float, + cu_seqlens: paddle.Tensor | None = None, + chunk_size: int = 64, + chunk_indices: paddle.Tensor | None = None, + use_exp2: bool = False, + transpose_state_layout: bool = False, +): + B, T, H, K, V = *q.shape, v.shape[-1] + BT = chunk_size + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + # Please ensure zeros, since vllm will use padding v + o = paddle.zeros_like(v) + def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) + chunk_gla_fwd_kernel_o[grid]( + q=q, + v=v, + g=g, + h=h, + o=o, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + USE_EXP2=use_exp2, + TRANSPOSE_STATE=transpose_state_layout, + ) + return o diff --git a/flashmask/flash_mask/linear_attn/ops/kda/__init__.py b/flashmask/flash_mask/linear_attn/ops/kda/__init__.py new file mode 100644 index 00000000000..f896d2b9bc3 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/kda/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +from .chunk import chunk_kda +from .fused_recurrent import fused_recurrent_kda + +__all__ = [ + "chunk_kda", + "fused_recurrent_kda", +] diff --git a/flashmask/flash_mask/linear_attn/ops/kda/chunk.py b/flashmask/flash_mask/linear_attn/ops/kda/chunk.py new file mode 100644 index 00000000000..be7238919ab --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/kda/chunk.py @@ -0,0 +1,358 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +# Related files are modified and supported by the Moonshot AI Team + +import paddle + +from flash_mask.linear_attn.modules.l2norm import l2norm_bwd, l2norm_fwd +from flash_mask.linear_attn.ops.kda.chunk_bwd import chunk_kda_bwd +from flash_mask.linear_attn.ops.kda.chunk_fwd import chunk_kda_fwd +from flash_mask.linear_attn.triton_utils import activate_paddle_driver, compat_kernel_wrapper_fastpath +from flash_mask.linear_attn.ops.utils.index import prepare_chunk_indices +from flash_mask.linear_attn.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +class ChunkKDAFunction(paddle.autograd.PyLayer): + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + A_log: paddle.Tensor, + dt_bias: paddle.Tensor, + scale: float, + initial_state: paddle.Tensor, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + use_gate_in_kernel: bool = False, + cu_seqlens: paddle.Tensor | None = None, + cu_seqlens_cpu: paddle.Tensor | None = None, + safe_gate: bool = False, + lower_bound: float | None = None, + disable_recompute: bool = False, + return_intermediate_states: bool = False, + cp_context=None, + transpose_state_layout: bool = False, + ): + chunk_size = 64 + with activate_paddle_driver(), compat_kernel_wrapper_fastpath(): + _orig_forward_args = [ + q, k, v, g, beta, A_log, dt_bias, scale, initial_state, + output_final_state, use_qk_l2norm_in_kernel, use_gate_in_kernel, + cu_seqlens, cu_seqlens_cpu, safe_gate, lower_bound, + disable_recompute, return_intermediate_states, cp_context, transpose_state_layout, + ] + + q_rstd, k_rstd = None, None + if use_qk_l2norm_in_kernel: + q, q_rstd = l2norm_fwd(q) + k, k_rstd = l2norm_fwd(k) + + chunk_indices = prepare_chunk_indices( + cu_seqlens, chunk_size, cu_seqlens_cpu=cu_seqlens_cpu) if cu_seqlens is not None else None + + g_input = g + + (o, final_state, g_cumsum, Aqk, Akk, w, u, qg, kg, v_new, h, initial_state) = chunk_kda_fwd( + q=q, + k=k, + v=v, + g=g_input, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + cu_seqlens_cpu=cu_seqlens_cpu, + chunk_indices=chunk_indices, + safe_gate=safe_gate, + lower_bound=lower_bound, + use_gate_in_kernel=use_gate_in_kernel, + A_log=A_log, + dt_bias=dt_bias, + disable_recompute=disable_recompute, + return_intermediate_states=return_intermediate_states, + cp_context=cp_context, + transpose_state_layout=transpose_state_layout, + ) + + if return_intermediate_states: + assert not paddle.is_grad_enabled(), "return_intermediate_states is only allowed in inference mode" + assert disable_recompute is False, "return_intermediate_states must be used with disable_recompute=False" + return o.cast(q.dtype), final_state, h + + saved_tensors = [ + q, q_rstd, k, k_rstd, v, g_cumsum, g_input, beta, A_log, dt_bias, Aqk, Akk, + initial_state, cu_seqlens, chunk_indices, + ] + if disable_recompute: + saved_tensors.extend([w, u, qg, kg, v_new, h]) + ctx.save_for_backward(*saved_tensors) + ctx.chunk_size = chunk_size + ctx.safe_gate = safe_gate + ctx.scale = scale + ctx.lower_bound = lower_bound + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + ctx.use_gate_in_kernel = use_gate_in_kernel + ctx.disable_recompute = disable_recompute + ctx.cp_context = cp_context + 
ctx.transpose_state_layout = transpose_state_layout + ctx.output_final_state = output_final_state + # Paddle PyLayer backward must return exactly as many values as tensor inputs. + # Record which forward args are tensors so backward can filter its return. + # Also record which tensor inputs need gradients (stop_gradient=False). + # IMPORTANT: Use _orig_forward_args (captured before q/k were reassigned by + # l2norm_fwd) so that stop_gradient reflects the *caller's* tensors. + ctx._tensor_mask = tuple(isinstance(a, paddle.Tensor) for a in _orig_forward_args) + ctx._needs_grad = tuple( + isinstance(a, paddle.Tensor) and not a.stop_gradient for a in _orig_forward_args + ) + # Paddle PyLayer forward cannot return None, use dummy tensor as placeholder + if final_state is None: + final_state = paddle.zeros([1], dtype=q.dtype) + return o.cast(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht=None): + # When output_final_state=False, forward returned a dummy tensor; + # restore dht to None so downstream bwd functions handle it correctly + if not ctx.output_final_state: + dht = None + with activate_paddle_driver(), compat_kernel_wrapper_fastpath(): + saved_tensors = ctx.saved_tensor() + if ctx.disable_recompute: + (q, q_rstd, k, k_rstd, v, g_cumsum, g_input, beta, A_log, dt_bias, Aqk, Akk, + initial_state, cu_seqlens, chunk_indices, + w, u, qg, kg, v_new, h) = saved_tensors + else: + (q, q_rstd, k, k_rstd, v, g_cumsum, g_input, beta, A_log, dt_bias, Aqk, Akk, + initial_state, cu_seqlens, chunk_indices) = saved_tensors + w = u = qg = kg = v_new = h = None + + dq, dk, dv, db, dg, dh0, dA, dbias = chunk_kda_bwd( + q=q, + k=k, + v=v, + g=g_cumsum, + beta=beta, + Aqk=Aqk, + Akk=Akk, + scale=ctx.scale, + initial_state=initial_state, + do=do, + dht=dht, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + chunk_size=ctx.chunk_size, + safe_gate=ctx.safe_gate, + g_org=g_input if ctx.use_gate_in_kernel else None, lower_bound=ctx.lower_bound, + use_gate_in_kernel=ctx.use_gate_in_kernel, + A_log=A_log, dt_bias=dt_bias, + disable_recompute=ctx.disable_recompute, + w=w, u=u, qg=qg, kg=kg, v_new=v_new, h=h, + cp_context=ctx.cp_context, + transpose_state_layout=ctx.transpose_state_layout, + ) + if ctx.use_qk_l2norm_in_kernel: + dq = l2norm_bwd(q, q_rstd, dq) + dk = l2norm_bwd(k, k_rstd, dk) + + # Build all grads in forward arg order, filter to tensor inputs only. 
+ # Order: q, k, v, g, beta, A_log, dt_bias, scale, initial_state, + # output_final_state, use_qk_l2norm_in_kernel, use_gate_in_kernel, + # cu_seqlens, cu_seqlens_cpu, safe_gate, lower_bound, + # disable_recompute, return_intermediate_states, cp_context, transpose_state_layout + all_grads = [ + dq.cast(q.dtype), dk.cast(k.dtype), dv.cast(v.dtype), dg.cast(g_input.dtype), db.cast(beta.dtype), + dA, dbias, None, dh0, None, None, None, None, None, None, None, None, None, None, None, + ] + return tuple( + g if needs_grad else None + for g, is_tensor, needs_grad in zip(all_grads, ctx._tensor_mask, ctx._needs_grad) + if is_tensor + ) + + +def chunk_kda( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + scale: float | None = None, + initial_state: paddle.Tensor | None = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + use_gate_in_kernel: bool = False, + cu_seqlens: paddle.Tensor | None = None, + cu_seqlens_cpu: paddle.Tensor | None = None, + safe_gate: bool = False, + lower_bound: float | None = None, + disable_recompute: bool = False, + return_intermediate_states: bool = False, + cp_context=None, + transpose_state_layout: bool = False, + **kwargs, +): + r""" + Args: + q (paddle.Tensor): + queries of shape `[B, T, H, K]`. + k (paddle.Tensor): + keys of shape `[B, T, H, K]`. + v (paddle.Tensor): + values of shape `[B, T, H, V]`. + g (paddle.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H, K]`. + beta (paddle.Tensor): + betas of shape `[B, T, H]`. + scale (Optional[float]): + Scale factor for the KDA attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[paddle.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + use_qk_l2norm_in_kernel (bool): + Whether to apply L2norm to the q,k tensor internally. Default: `False`. + use_gate_in_kernel (bool): + Whether to compute the log-space KDA decay internally. + - If `True`: + The passed `g` acts as the raw input for `-exp(A_log).view(H, -1) * softplus(g + dt_bias.view(H, K))`. + Note that as part of the input arguments, + `A_log` (shape `[H]`) and the optional `dt_bias` (shape `[H * K]`) should be provided. + - If `False`, `g` is expected to be the pre-computed decay value. + Default: `False`. + cu_seqlens (paddle.Tensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + cu_seqlens_cpu (paddle.Tensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + safe_gate (bool): + Whether the kernel can assume the gate values (in log space) are in a safe range + and use M=16 TensorCore acceleration for higher throughput. + The safe range is ``[lower_bound, 0)``. With the default ``lower_bound=-5``, + the per-step decay factor ``exp(g)`` is bounded in ``[exp(-5), 1) = [0.0067, 1)``, + meaning each step retains at least ~0.67% of the state -- a negligible loss that + has minimal impact on model quality while enabling significant speedup. + Requires ``lower_bound`` to be set. Default: ``False``. + lower_bound (Optional[float]): + Lower bound for the forget gate (in log space) when ``use_gate_in_kernel=True``. 
+ Changes the gate activation from ``-exp(A_log) * softplus(g + dt_bias)`` + to ``lower_bound * sigmoid(exp(A_log) * (g + dt_bias))``, + which naturally clamps the output to ``[lower_bound, 0)``. + Recommended value: ``-5`` (i.e., ``exp(-5) = 0.0067``). Default: ``None``. + disable_recompute (bool): + Whether to disable gradient recomputation in the kernel. When `True`, the kernel + will save all intermediate activations for backward pass, which is beneficial + for training small models at the cost of increased memory usage. Default: `False`. + return_intermediate_states (bool): + If True, returns intermediate state `h` for inference scenarios (e.g., vLLM). + Must be used outside `paddle.is_grad_enabled()` and will return a 3-tuple instead of 2-tuple. + This is not intended for training as it bypasses autograd. Default: `False`. + cp_context: + Context parallel context (skipped in Paddle migration). Default: `None`. + transpose_state_layout (Optional[bool]): + Whether to use the transposed state layout for the hidden state. + Default: `False`. + + Returns: + - Normal mode (return_intermediate_states=False): A tuple (o, final_state) + o (paddle.Tensor): + Outputs of shape `[B, T, H, V]`. + final_state (paddle.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + - Inference mode (return_intermediate_states=True): A tuple (o, final_state, h) + o (paddle.Tensor): + Outputs of shape `[B, T, H, V]`. + final_state (paddle.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + h (paddle.Tensor): + Intermediate states of shape `[B, NT, H, K, V]` and dtype `bfloat16` for caching or further processing. + - For equal-length sequences: `NT = #chunks_per_sequence` (typically `ceil(T / chunk_size)`) + - For variable-length sequences (cu_seqlens): B is always 1 (flattened), + NT is the total number of chunks across all sequences, + determined by `prepare_chunk_indices(cu_seqlens, chunk_size)` + """ + + # CP (Context Parallel) is skipped in Paddle migration + + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.", + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.", + ) + if initial_state is not None: + assert initial_state.dtype == paddle.float32, "initial_state must be in float32." + + A_log, dt_bias = None, None + if use_gate_in_kernel: + assert "A_log" in kwargs, "A_log must be provided when use_gate_in_kernel=True." + A_log, dt_bias = kwargs["A_log"], kwargs.get("dt_bias") + + if safe_gate and use_gate_in_kernel: + if lower_bound is None: + raise ValueError("`lower_bound` must be specified when `safe_gate=True` and `use_gate_in_kernel=True`.") + if not (-5 <= lower_bound < 0): + raise ValueError(f"`lower_bound` must be in the safe range [-5, 0), got {lower_bound}.") + + assert q.shape == k.shape == g.shape, "q, k, g must have the same shape." + assert k.shape[-1] <= 256, "Currently we only support key headdim <=256 for KDA :-(" + assert beta.shape == q.shape[:3], "beta must be of shape (batch size, seq len, num of head)." + assert v.shape == (*q.shape[:3], v.shape[-1]), "v must be of shape (batch size, seq len, num of head, head dim)." 
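+ # Usage sketch (illustrative only; shapes follow the docstring above, dtypes/values are hypothetical):
+ #   q, k, g: [B, T, H, K] (g in log space), v: [B, T, H, V], beta: [B, T, H]
+ #   o, final_state = chunk_kda(q, k, v, g, beta, output_final_state=True)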
+ + if scale is None: + scale = k.shape[-1] ** -0.5 + result = ChunkKDAFunction.apply( + q, + k, + v, + g, + beta, + A_log, + dt_bias, + scale, + initial_state, + output_final_state, + use_qk_l2norm_in_kernel, + use_gate_in_kernel, + cu_seqlens, + cu_seqlens_cpu, + safe_gate, + lower_bound, + disable_recompute, + return_intermediate_states, + cp_context, + transpose_state_layout, + ) + if return_intermediate_states: + o, final_state, h = result + if not output_final_state: + final_state = None + return o, final_state, h + o, final_state = result + # Convert dummy tensor back to None when output_final_state=False + if not output_final_state: + final_state = None + return o, final_state diff --git a/flashmask/flash_mask/linear_attn/ops/kda/chunk_bwd.py b/flashmask/flash_mask/linear_attn/ops/kda/chunk_bwd.py new file mode 100644 index 00000000000..f7f25a0ad65 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/kda/chunk_bwd.py @@ -0,0 +1,580 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +from functools import lru_cache + +import paddle +import triton +import triton.language as tl + +from flash_mask.linear_attn.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h +from flash_mask.linear_attn.ops.kda.chunk_intra import chunk_kda_bwd_intra +from flash_mask.linear_attn.ops.kda.gate import kda_gate_bwd, kda_gate_chunk_cumsum +from flash_mask.linear_attn.ops.kda.wy_fast import recompute_w_u_fwd +from flash_mask.linear_attn.ops.utils import chunk_local_cumsum, prepare_chunk_indices +from flash_mask.linear_attn.ops.utils.constant import RCP_LN2 +from flash_mask.linear_attn.ops.utils.op import exp2 +from flash_mask.linear_attn.utils import ( + IS_NVIDIA_HOPPER, + autotune_cache_kwargs, + check_shared_mem, +) +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + +BK_LIST = [32, 64] if check_shared_mem() else [16, 32] +BV_LIST = [64, 128] if check_shared_mem('ampere') else [16, 32] + + +@lru_cache(maxsize=None) +def _chunk_kda_tiling(device_idx: int) -> int: + if check_shared_mem('hopper'): + return 128 + if check_shared_mem('ada'): + return 64 + return 32 + + +@lru_cache(maxsize=None) +def _chunk_kda_launch_meta(device_idx: int, T: int, K: int, V: int, BT: int) -> tuple[int, int, int]: + const_tiling = _chunk_kda_tiling(device_idx) + BK = min(max(triton.next_power_of_2(K), 16), const_tiling) + BV = min(max(triton.next_power_of_2(V), 16), const_tiling) + NT = triton.cdiv(T, BT) + return BK, BV, NT +NUM_WARPS = [2, 4] if IS_NVIDIA_HOPPER else [2, 4, 8] + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def chunk_kda_bwd_kernel_dAv( + q, + k, + v, + A, + do, + dv, + dA, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % 
H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + do += (bos * H + i_h) * V + dv += (bos * H + i_h) * V + dA += (bos * H + i_h) * BT + + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0).to(do.dtype.element_ty) + + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BV, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BT, BT] + b_dA += tl.dot(b_do, b_v) + # [BT, BV] + b_dv = tl.dot(b_A.to(b_do.dtype), b_do) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + p_dA = tl.make_block_ptr(dA, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_dA = tl.where(o_t[:, None] >= o_t, b_dA * scale, 0.) + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BK_LIST + for BV in BV_LIST + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + if not (IS_NVIDIA_HOPPER and BK == 32 and num_warps == 4) + ], + key=['BT', 'TRANSPOSE_STATE'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def chunk_kda_bwd_kernel_wy_dqkg_fused( + q, + k, + v, + v_new, + g, + beta, + A, + h, + do, + dh, + dq, + dk, + dv, + dv2, + dg, + db, + dA, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + TRANSPOSE_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t.to(tl.int64) + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + T = (eos - bos).to(tl.int32) + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = (i_b * NT + i_t).to(tl.int64) + bos, eos = (i_b * T).to(tl.int64), (i_b * T + T).to(tl.int64) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_last = (o_t == min(T, i_t * BT + BT) - 1) + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + v_new += (bos * H + i_h) * V + g += (bos * H + i_h) * K + beta += bos * H + i_h + A += (bos * H + i_h) * BT + h += (i_tg * H + i_h) * K*V + do += (bos * H + i_h) * V + dh += (i_tg * H + i_h) * K*V + dq += (bos * H + i_h) * K + dk += (bos * H + i_h) * K + dv += (bos * H + i_h) * V + dv2 += (bos * H + i_h) * V + dg += (bos * H + i_h) * K + 
db += bos * H + i_h + dA += (bos * H + i_h) * BT + + p_beta = tl.make_block_ptr(beta, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + b_db = tl.zeros([BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + + p_gn = g + (min(T, i_t * BT + BT) - 1).to(tl.int64) * H*K + o_k + b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT, BK], dtype=tl.float32) + b_dgk = tl.zeros([BK], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v_new = tl.make_block_ptr(v_new, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + if TRANSPOSE_STATE: + p_h = tl.make_block_ptr(h, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0)) + p_dh = tl.make_block_ptr(dh, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0)) + else: + p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BV] + b_v_new = tl.load(p_v_new, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BT, BV] + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + + b_dgk += tl.sum(b_h * b_dh, axis=0) + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + b_dk += tl.dot(b_v_new, b_dh.to(b_v_new.dtype)) + b_dw += tl.dot(b_dv.to(b_v_new.dtype), b_h.to(b_v_new.dtype)) + tl.debug_barrier() # DO NOT REMOVE THIS LINE! 
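+ # The dv-side terms below (b_dA, b_dv2, b_db) depend only on the value tile and not on
+ # the current K tile, so they are computed and stored only on the first K-tile pass
+ # (i_k == 0) instead of being re-accumulated for every K block.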
+ if i_k == 0: + p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv2 = tl.make_block_ptr(dv2, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + b_v = tl.load(p_v, boundary_check=(0, 1)) + + b_dA += tl.dot(b_dv, tl.trans(b_v)) + + b_dvb = tl.dot(b_A, b_dv) + b_dv2 = b_dvb * b_beta[:, None] + b_db += tl.sum(b_dvb * b_v, 1) + + tl.store(p_dv2, b_dv2.to(p_dv2.dtype.element_ty), boundary_check=(0, 1)) + + b_gk_exp = exp2(b_g) + b_gb = b_gk_exp * b_beta[:, None] + b_dgk *= exp2(b_gn) + b_dq = b_dq * b_gk_exp * scale + b_dk = b_dk * tl.where(m_t[:, None], exp2(b_gn[None, :] - b_g), 0) + + b_kg = b_k * b_gk_exp + + b_dw = -b_dw.to(b_A.dtype) + b_dA += tl.dot(b_dw, tl.trans(b_kg.to(b_A.dtype))) + + b_dkgb = tl.dot(b_A, b_dw) + b_db += tl.sum(b_dkgb * b_kg, 1) + + p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_kdk = b_k * b_dk + b_dgk += tl.sum(b_kdk, axis=0) + b_dg = b_q * b_dq - b_kdk + m_last[:, None] * b_dgk + b_kg * b_dkgb * b_beta[:, None] + b_dk = b_dk + b_dkgb * b_gb + + p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_dA = tl.where(m_A, b_dA * b_beta[None, :], 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), b_A) + b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) + b_dA = tl.where(m_A, -b_dA, 0) + + p_dA = tl.make_block_ptr(dA, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_db = tl.make_block_ptr(db, (T,), (H,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) + + +def chunk_kda_bwd_dAv( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + do: paddle.Tensor, + A: paddle.Tensor | None = None, + scale: float = None, + cu_seqlens: paddle.Tensor | None = None, + chunk_size: int = 64, + chunk_indices: paddle.Tensor | None = None, +) -> tuple[paddle.Tensor, paddle.Tensor]: + B, T, H, K, V = *k.shape, do.shape[-1] + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + if cu_seqlens is None: + BK, BV, NT = _chunk_kda_launch_meta(k.place.gpu_device_id(), T, K, V, BT) + else: + const_tiling = _chunk_kda_tiling(k.place.gpu_device_id()) + BK = min(max(triton.next_power_of_2(K), 16), const_tiling) + BV = min(max(triton.next_power_of_2(V), 16), const_tiling) + NT = len(chunk_indices) + + dA = paddle.empty([B, T, H, BT], dtype=paddle.float32) + dv = paddle.empty_like(do) + grid = (NT, B * H) + chunk_kda_bwd_kernel_dAv[grid]( + q=q, + k=k, + v=v, + A=A, + do=do, + dv=dv, + dA=dA, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return dA, dv + + +def chunk_kda_bwd_wy_dqkg_fused( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + v_new: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + A: paddle.Tensor, + h: paddle.Tensor, + do: paddle.Tensor, + dh: paddle.Tensor, + dv: paddle.Tensor, + 
scale: float | None = None, + cu_seqlens: paddle.Tensor | None = None, + chunk_size: int = 64, + chunk_indices: paddle.Tensor | None = None, + transpose_state_layout: bool = False, +): + B, T, H, K, V = *k.shape, v.shape[-1] + BT = chunk_size + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + dq = paddle.empty_like(q, dtype=paddle.float32) + dk = paddle.empty_like(k, dtype=paddle.float32) + dv2 = paddle.empty_like(v) + dg = paddle.empty_like(g, dtype=paddle.float32) + db = paddle.empty_like(beta, dtype=paddle.float32) + dA = paddle.empty_like(A, dtype=paddle.float32) + + grid = (NT, B * H) + chunk_kda_bwd_kernel_wy_dqkg_fused[grid]( + q=q, + k=k, + v=v, + v_new=v_new, + g=g, + beta=beta, + A=A, + h=h, + do=do, + dh=dh, + dq=dq, + dk=dk, + dv=dv, + dv2=dv2, + dg=dg, + db=db, + dA=dA, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + TRANSPOSE_STATE=transpose_state_layout, + ) + dv = dv2 + return dq, dk, dv, db, dg, dA + + +def chunk_kda_bwd( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + beta: paddle.Tensor, + Aqk: paddle.Tensor, + Akk: paddle.Tensor, + scale: float, + initial_state: paddle.Tensor, + do: paddle.Tensor, + dht: paddle.Tensor, + g: paddle.Tensor | None = None, + g_org: paddle.Tensor | None = None, + cu_seqlens: paddle.Tensor | None = None, + chunk_indices: paddle.Tensor | None = None, + chunk_size: int = 64, + safe_gate: bool = False, + lower_bound: float | None = None, + use_gate_in_kernel: bool = False, + A_log: paddle.Tensor | None = None, + dt_bias: paddle.Tensor | None = None, + disable_recompute: bool = False, + cp_context=None, + transpose_state_layout: bool = False, + **kwargs, +): + if cp_context is not None: + raise NotImplementedError("CP not supported in paddle migration") + + if disable_recompute is False: + if use_gate_in_kernel: + g = kda_gate_chunk_cumsum( + g=g_org, + A_log=A_log, + dt_bias=dt_bias, + scale=RCP_LN2, + chunk_size=chunk_size, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + lower_bound=lower_bound + ) + w, u, qg, kg = recompute_w_u_fwd( + q=q, + k=k, + v=v, + beta=beta, + A=Akk, + gk=g, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + ) + h, v_new, _ = chunk_gated_delta_rule_fwd_h( + k=kg, + w=w, + u=u, + gk=g, + initial_state=initial_state, + output_final_state=False, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + use_exp2=True, + transpose_state_layout=transpose_state_layout, + ) + else: + w, u, qg, kg, v_new, h = kwargs["w"], kwargs["u"], kwargs["qg"], kwargs["kg"], kwargs["v_new"], kwargs["h"] + + # dAqk = do @ v.T + # dv = A @ do + dAqk, dv = chunk_kda_bwd_dAv( + q=q, + k=k, + v=v_new, + do=do, + A=Aqk, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + chunk_indices=chunk_indices, + ) + + dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu( + q=qg, + k=kg, + w=w, + gk=g, + h0=initial_state, + dht=dht, + do=do, + dv=dv, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + use_exp2=True, + transpose_state_layout=transpose_state_layout, + ) + + dq, dk, dv, db, dg, dAkk = chunk_kda_bwd_wy_dqkg_fused( + q=q, + k=k, + v=v, + v_new=v_new, + g=g, + beta=beta, + A=Akk, + h=h, + do=do, + dh=dh, + dv=dv, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + chunk_indices=chunk_indices, + transpose_state_layout=transpose_state_layout, + ) + + dq, dk, db, dg = 
chunk_kda_bwd_intra( + q=q, + k=k, + g=g, + beta=beta, + dAqk=dAqk, + dAkk=dAkk, + dq=dq, + dk=dk, + db=db, + dg=dg, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + chunk_indices=chunk_indices, + safe_gate=safe_gate + ) + + dA, dbias = None, None + dg = chunk_local_cumsum( + dg, + chunk_size=chunk_size, + reverse=True, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + ) + if use_gate_in_kernel: + dg, dA, dbias = kda_gate_bwd( + g=g_org, + A_log=A_log, + dt_bias=dt_bias, + dyg=dg, + lower_bound=lower_bound + ) + + return dq, dk, dv, db, dg, dh0, dA, dbias diff --git a/flashmask/flash_mask/linear_attn/ops/kda/chunk_fwd.py b/flashmask/flash_mask/linear_attn/ops/kda/chunk_fwd.py new file mode 100644 index 00000000000..93bbb39d8f2 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/kda/chunk_fwd.py @@ -0,0 +1,118 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +import paddle + +from flash_mask.linear_attn.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h +from flash_mask.linear_attn.ops.gla.chunk import chunk_gla_fwd_o_gk +from flash_mask.linear_attn.ops.kda.chunk_intra import chunk_kda_fwd_intra +from flash_mask.linear_attn.ops.kda.gate import kda_gate_chunk_cumsum +from flash_mask.linear_attn.ops.utils import chunk_local_cumsum +from flash_mask.linear_attn.ops.utils.constant import RCP_LN2 + + +def chunk_kda_fwd( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + scale: float, + initial_state: paddle.Tensor, + output_final_state: bool, + cu_seqlens: paddle.Tensor | None = None, + cu_seqlens_cpu: paddle.Tensor | None = None, + chunk_indices: paddle.Tensor | None = None, + chunk_size: int = 64, + safe_gate: bool = False, + lower_bound: float | None = None, + use_gate_in_kernel: bool = False, + A_log: paddle.Tensor | None = None, + dt_bias: paddle.Tensor | None = None, + disable_recompute: bool = False, + return_intermediate_states: bool = False, + cp_context=None, + transpose_state_layout: bool = False, +): + # Apply gate activation + g_org = None + if use_gate_in_kernel: + g_org = g + g = kda_gate_chunk_cumsum( + g=g_org, + A_log=A_log, + dt_bias=dt_bias, + scale=RCP_LN2, + chunk_size=chunk_size, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + lower_bound=lower_bound, + ) + else: + g = chunk_local_cumsum( + g=g, + scale=RCP_LN2, + chunk_size=chunk_size, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices + ) + + # qg = None if disable_recompute is False + w, u, qg, kg, Aqk, Akk = chunk_kda_fwd_intra( + q=q, + k=k, + v=v, + gk=g, + beta=beta, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + chunk_indices=chunk_indices, + safe_gate=safe_gate, + disable_recompute=disable_recompute + ) + + # CP (Context Parallel) is skipped in Paddle migration + + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=kg, + w=w, + u=u, + gk=g, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + cu_seqlens_cpu=cu_seqlens_cpu, + chunk_indices=chunk_indices, + use_exp2=True, + transpose_state_layout=transpose_state_layout, + ) + + # CP (Context Parallel) is skipped in Paddle migration + + o = chunk_gla_fwd_o_gk( + q=q, + v=v_new, + g=g, + A=Aqk, + h=h, + scale=scale, + cu_seqlens=cu_seqlens, + 
chunk_size=chunk_size, + chunk_indices=chunk_indices, + use_exp2=True, + transpose_state_layout=transpose_state_layout, + ) + if disable_recompute is False: + # Delete to save memory + w, u, qg, kg, v_new = None, None, None, None, None + if not return_intermediate_states: + # Only delete h if not requested for inference + h = None + if use_gate_in_kernel: + g = None + return o, final_state, g, Aqk, Akk, w, u, qg, kg, v_new, h, initial_state diff --git a/flashmask/flash_mask/linear_attn/ops/kda/chunk_intra.py b/flashmask/flash_mask/linear_attn/ops/kda/chunk_intra.py new file mode 100644 index 00000000000..bb42b8135bb --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/kda/chunk_intra.py @@ -0,0 +1,909 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +import paddle +import triton +import triton.language as tl + +from flash_mask.linear_attn.ops.kda.chunk_intra_token_parallel import chunk_kda_fwd_intra_token_parallel +from flash_mask.linear_attn.ops.kda.wy_fast import recompute_w_u_fwd +from flash_mask.linear_attn.ops.utils import prepare_chunk_indices +from flash_mask.linear_attn.ops.utils.op import exp2, gather +from flash_mask.linear_attn.utils import IS_GATHER_SUPPORTED, IS_TF32_SUPPORTED, autotune_cache_kwargs +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + +if IS_TF32_SUPPORTED: + SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32') +else: + SOLVE_TRIL_DOT_PRECISION = tl.constexpr('ieee') + +################################################################################ +# Fused inter + solve_tril kernel: compute off-diagonal Akk and solve in one pass +################################################################################ + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, num_warps=num_warps) + for BK in [32, 64] + for num_warps in [1, 2, 4] + ], + key=["H", "K", "BC"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def chunk_kda_fwd_kernel_inter_solve_fused( + q, + k, + g, + beta, + Aqk, + Akkd, + Akk, + scale, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_SAFE_GATE: tl.constexpr, +): + """ + Fused kernel: compute inter-subchunk Akk + solve_tril in one pass. + Prerequisite: token_parallel has already computed diagonal Akk blocks in Akkd. + + This kernel: + 1. Computes off-diagonal Aqk blocks -> writes to global + 2. Computes off-diagonal Akk blocks -> keeps in registers + 3. Loads diagonal Akk blocks from Akkd (fp32) + 4. Does forward substitution on diagonals + 5. Computes merged Akk_inv + 6. 
Writes Akk_inv to Akk + """ + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT >= T: + return + + i_tc0 = i_t * BT + i_tc1 = i_t * BT + BC + i_tc2 = i_t * BT + 2 * BC + i_tc3 = i_t * BT + 3 * BC + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + g += (bos * H + i_h) * K + Aqk += (bos * H + i_h) * BT + Akk += (bos * H + i_h) * BT + Akkd += (bos * H + i_h) * BC + + o_i = tl.arange(0, BC) + m_tc1 = (i_tc1 + o_i) < T + m_tc2 = (i_tc2 + o_i) < T + m_tc3 = (i_tc3 + o_i) < T + + b_Aqk10 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk10 = tl.zeros([BC, BC], dtype=tl.float32) + + b_Aqk20 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk20 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk21 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk21 = tl.zeros([BC, BC], dtype=tl.float32) + + b_Aqk30 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk30 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk31 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk31 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk32 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk32 = tl.zeros([BC, BC], dtype=tl.float32) + + ################################################################################ + # off-diagonal blocks + ################################################################################ + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + p_k0 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc0, i_k * BK), (BC, BK), (1, 0)) + p_g0 = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_tc0, i_k * BK), (BC, BK), (1, 0)) + b_k0 = tl.load(p_k0, boundary_check=(0, 1)).to(tl.float32) + b_g0 = tl.load(p_g0, boundary_check=(0, 1)).to(tl.float32) + + if i_tc1 < T: + p_q1 = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) + p_k1 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) + p_g1 = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) + # [BC, BK] + b_q1 = tl.load(p_q1, boundary_check=(0, 1)).to(tl.float32) + b_k1 = tl.load(p_k1, boundary_check=(0, 1)).to(tl.float32) + b_g1 = tl.load(p_g1, boundary_check=(0, 1)).to(tl.float32) + # [BK] + b_gn1 = tl.load(g + i_tc1 * H*K + o_k, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + b_gqn = tl.where(m_tc1[:, None], exp2(b_g1 - b_gn1[None, :]), 0) + # [BK, BC] + b_kgt = tl.trans(b_k0 * exp2(b_gn1[None, :] - b_g0)) + # [BC, BC] + b_Aqk10 += tl.dot(b_q1 * b_gqn, b_kgt) + b_Akk10 += tl.dot(b_k1 * b_gqn, b_kgt) + + if i_tc2 < T: + p_q2 = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) + p_k2 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) + p_g2 = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) + # [BC, BK] + b_q2 = tl.load(p_q2, boundary_check=(0, 1)).to(tl.float32) + b_k2 = tl.load(p_k2, boundary_check=(0, 1)).to(tl.float32) + b_g2 = tl.load(p_g2, boundary_check=(0, 1)).to(tl.float32) + # [BK] + b_gn2 = tl.load(g + i_tc2 * H*K + o_k, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + b_gqn2 = tl.where(m_tc2[:, None], exp2(b_g2 - b_gn2[None, :]), 0) + b_qg2 = b_q2 * b_gqn2 + b_kg2 = b_k2 * b_gqn2 + # [BK, BC] + b_kgt = tl.trans(b_k0 * exp2(b_gn2[None, :] - b_g0)) + b_Aqk20 += tl.dot(b_qg2, 
b_kgt) + b_Akk20 += tl.dot(b_kg2, b_kgt) + # [BC, BC] + b_kgt = tl.trans(b_k1 * exp2(b_gn2[None, :] - b_g1)) + # [BC, BC] + b_Aqk21 += tl.dot(b_qg2, b_kgt) + b_Akk21 += tl.dot(b_kg2, b_kgt) + + if i_tc3 < T: + p_q3 = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) + p_k3 = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) + p_g3 = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) + # [BC, BK] + b_q3 = tl.load(p_q3, boundary_check=(0, 1)).to(tl.float32) + b_k3 = tl.load(p_k3, boundary_check=(0, 1)).to(tl.float32) + b_g3 = tl.load(p_g3, boundary_check=(0, 1)).to(tl.float32) + # [BK] + b_gn3 = tl.load(g + i_tc3 * H*K + o_k, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + b_gqn3 = tl.where(m_tc3[:, None], exp2(b_g3 - b_gn3[None, :]), 0) + b_qg3 = b_q3 * b_gqn3 + b_kg3 = b_k3 * b_gqn3 + # [BK, BC] + b_kgt = tl.trans(b_k0 * exp2(b_gn3[None, :] - b_g0)) + # [BC, BC] + b_Aqk30 += tl.dot(b_qg3, b_kgt) + b_Akk30 += tl.dot(b_kg3, b_kgt) + # [BK, BC] + b_kgt = tl.trans(b_k1 * exp2(b_gn3[None, :] - b_g1)) + # [BC, BC] + b_Aqk31 += tl.dot(b_qg3, b_kgt) + b_Akk31 += tl.dot(b_kg3, b_kgt) + # [BK, BC] + b_kgt = tl.trans(b_k2 * exp2(b_gn3[None, :] - b_g2)) + # [BC, BC] + b_Aqk32 += tl.dot(b_qg3, b_kgt) + b_Akk32 += tl.dot(b_kg3, b_kgt) + + ################################################################################ + # save off-diagonal Aqk blocks and prepare Akk + ################################################################################ + if i_tc1 < T: + p_Aqk10 = tl.make_block_ptr(Aqk, (T, BT), (H*BT, 1), (i_tc1, 0), (BC, BC), (1, 0)) + tl.store(p_Aqk10, (b_Aqk10 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + + p_b1 = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_tc1,), (BC,), (0,)) + b_b1 = tl.load(p_b1, boundary_check=(0,)).to(tl.float32) + b_Akk10 = b_Akk10 * b_b1[:, None] + if i_tc2 < T: + p_Aqk20 = tl.make_block_ptr(Aqk, (T, BT), (H*BT, 1), (i_tc2, 0), (BC, BC), (1, 0)) + p_Aqk21 = tl.make_block_ptr(Aqk, (T, BT), (H*BT, 1), (i_tc2, BC), (BC, BC), (1, 0)) + tl.store(p_Aqk20, (b_Aqk20 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Aqk21, (b_Aqk21 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + + p_b2 = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_tc2,), (BC,), (0,)) + b_b2 = tl.load(p_b2, boundary_check=(0,)).to(tl.float32) + b_Akk20 = b_Akk20 * b_b2[:, None] + b_Akk21 = b_Akk21 * b_b2[:, None] + if i_tc3 < T: + p_Aqk30 = tl.make_block_ptr(Aqk, (T, BT), (H*BT, 1), (i_tc3, 0), (BC, BC), (1, 0)) + p_Aqk31 = tl.make_block_ptr(Aqk, (T, BT), (H*BT, 1), (i_tc3, BC), (BC, BC), (1, 0)) + p_Aqk32 = tl.make_block_ptr(Aqk, (T, BT), (H*BT, 1), (i_tc3, 2*BC), (BC, BC), (1, 0)) + tl.store(p_Aqk30, (b_Aqk30 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Aqk31, (b_Aqk31 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Aqk32, (b_Aqk32 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + + p_b3 = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_tc3,), (BC,), (0,)) + b_b3 = tl.load(p_b3, boundary_check=(0,)).to(tl.float32) + b_Akk30 = b_Akk30 * b_b3[:, None] + b_Akk31 = b_Akk31 * b_b3[:, None] + b_Akk32 = b_Akk32 * b_b3[:, None] + + p_Akk00 = tl.make_block_ptr(Akkd, (T, BC), (H*BC, 1), (i_tc0, 0), (BC, BC), (1, 0)) + p_Akk11 = tl.make_block_ptr(Akkd, (T, BC), (H*BC, 1), (i_tc1, 0), (BC, BC), (1, 0)) + p_Akk22 = tl.make_block_ptr(Akkd, (T, BC), (H*BC, 1), (i_tc2, 0), (BC, BC), (1, 0)) + p_Akk33 = 
tl.make_block_ptr(Akkd, (T, BC), (H*BC, 1), (i_tc3, 0), (BC, BC), (1, 0)) + b_Ai00 = tl.load(p_Akk00, boundary_check=(0, 1)).to(tl.float32) + b_Ai11 = tl.load(p_Akk11, boundary_check=(0, 1)).to(tl.float32) + b_Ai22 = tl.load(p_Akk22, boundary_check=(0, 1)).to(tl.float32) + b_Ai33 = tl.load(p_Akk33, boundary_check=(0, 1)).to(tl.float32) + + ################################################################################ + # forward substitution on diagonals + ################################################################################ + + if not USE_SAFE_GATE: + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + + b_Ai00 = -tl.where(m_A, b_Ai00, 0) + b_Ai11 = -tl.where(m_A, b_Ai11, 0) + b_Ai22 = -tl.where(m_A, b_Ai22, 0) + b_Ai33 = -tl.where(m_A, b_Ai33, 0) + + for i in range(2, min(BC, T - i_tc0)): + b_a00 = -tl.load(Akkd + (i_tc0 + i) * H*BC + o_i) + b_a00 = tl.where(o_i < i, b_a00, 0.) + b_a00 += tl.sum(b_a00[:, None] * b_Ai00, 0) + b_Ai00 = tl.where((o_i == i)[:, None], b_a00, b_Ai00) + for i in range(BC + 2, min(2*BC, T - i_tc0)): + b_a11 = -tl.load(Akkd + (i_tc0 + i) * H*BC + o_i) + b_a11 = tl.where(o_i < i - BC, b_a11, 0.) + b_a11 += tl.sum(b_a11[:, None] * b_Ai11, 0) + b_Ai11 = tl.where((o_i == i - BC)[:, None], b_a11, b_Ai11) + for i in range(2*BC + 2, min(3*BC, T - i_tc0)): + b_a22 = -tl.load(Akkd + (i_tc0 + i) * H*BC + o_i) + b_a22 = tl.where(o_i < i - 2*BC, b_a22, 0.) + b_a22 += tl.sum(b_a22[:, None] * b_Ai22, 0) + b_Ai22 = tl.where((o_i == i - 2*BC)[:, None], b_a22, b_Ai22) + for i in range(3*BC + 2, min(4*BC, T - i_tc0)): + b_a33 = -tl.load(Akkd + (i_tc0 + i) * H*BC + o_i) + b_a33 = tl.where(o_i < i - 3*BC, b_a33, 0.) + b_a33 += tl.sum(b_a33[:, None] * b_Ai33, 0) + b_Ai33 = tl.where((o_i == i - 3*BC)[:, None], b_a33, b_Ai33) + + b_Ai00 += m_I + b_Ai11 += m_I + b_Ai22 += m_I + b_Ai33 += m_I + + ################################################################################ + # compute merged inverse using off-diagonals + ################################################################################ + + # we used tf32 to maintain matrix inverse's precision whenever possible. 
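+ # Merge the four per-sub-chunk inverses into the inverse of the full BT x BT
+ # lower-triangular block via block forward substitution:
+ #   inv(L)[i][j] = -inv(L)[i][i] @ (sum over j <= m < i of L[i][m] @ inv(L)[m][j])
+ # e.g. [[A, 0], [C, B]]^-1 = [[inv(A), 0], [-inv(B) @ C @ inv(A), inv(B)]].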
+ b_Ai10 = -tl.dot( + tl.dot(b_Ai11, b_Akk10, input_precision=SOLVE_TRIL_DOT_PRECISION), + b_Ai00, + input_precision=SOLVE_TRIL_DOT_PRECISION + ) + b_Ai21 = -tl.dot( + tl.dot(b_Ai22, b_Akk21, input_precision=SOLVE_TRIL_DOT_PRECISION), + b_Ai11, + input_precision=SOLVE_TRIL_DOT_PRECISION + ) + b_Ai32 = -tl.dot( + tl.dot(b_Ai33, b_Akk32, input_precision=SOLVE_TRIL_DOT_PRECISION), + b_Ai22, + input_precision=SOLVE_TRIL_DOT_PRECISION + ) + + b_Ai20 = -tl.dot( + b_Ai22, + tl.dot(b_Akk20, b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk21, b_Ai10, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION + ) + b_Ai31 = -tl.dot( + b_Ai33, + tl.dot(b_Akk31, b_Ai11, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk32, b_Ai21, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION + ) + b_Ai30 = -tl.dot( + b_Ai33, + tl.dot(b_Akk30, b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk31, b_Ai10, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk32, b_Ai20, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION + ) + + ################################################################################ + # store full Akk_inv to Akk + ################################################################################ + + p_Akk00 = tl.make_block_ptr(Akk, (T, BT), (H*BT, 1), (i_tc0, 0), (BC, BC), (1, 0)) + p_Akk10 = tl.make_block_ptr(Akk, (T, BT), (H*BT, 1), (i_tc1, 0), (BC, BC), (1, 0)) + p_Akk11 = tl.make_block_ptr(Akk, (T, BT), (H*BT, 1), (i_tc1, BC), (BC, BC), (1, 0)) + p_Akk20 = tl.make_block_ptr(Akk, (T, BT), (H*BT, 1), (i_tc2, 0), (BC, BC), (1, 0)) + p_Akk21 = tl.make_block_ptr(Akk, (T, BT), (H*BT, 1), (i_tc2, BC), (BC, BC), (1, 0)) + p_Akk22 = tl.make_block_ptr(Akk, (T, BT), (H*BT, 1), (i_tc2, 2*BC), (BC, BC), (1, 0)) + p_Akk30 = tl.make_block_ptr(Akk, (T, BT), (H*BT, 1), (i_tc3, 0), (BC, BC), (1, 0)) + p_Akk31 = tl.make_block_ptr(Akk, (T, BT), (H*BT, 1), (i_tc3, BC), (BC, BC), (1, 0)) + p_Akk32 = tl.make_block_ptr(Akk, (T, BT), (H*BT, 1), (i_tc3, 2*BC), (BC, BC), (1, 0)) + p_Akk33 = tl.make_block_ptr(Akk, (T, BT), (H*BT, 1), (i_tc3, 3*BC), (BC, BC), (1, 0)) + + tl.store(p_Akk00, b_Ai00.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk10, b_Ai10.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk11, b_Ai11.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk20, b_Ai20.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk21, b_Ai21.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk22, b_Ai22.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk30, b_Ai30.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk31, b_Ai31.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk32, b_Ai32.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk33, b_Ai33.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BK', 'NC', 'BT'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['B', 'T']) +def chunk_kda_bwd_kernel_intra( + q, + k, + g, + beta, + dAqk, + dAkk, + dq, + dq2, + dk, + dk2, + dg, + dg2, + db, + cu_seqlens, + chunk_indices, + B, + T, + H: tl.constexpr, + K: 
tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + IS_VARLEN: tl.constexpr, + SAFE_GATE: tl.constexpr, + USE_GATHER: tl.constexpr, +): + i_kc, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_k, i_i = i_kc // NC, i_kc % NC + + all = B * T + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + + i_ti = i_t * BT + i_i * BC + if i_ti >= T: + return + + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + g += (bos * H + i_h) * K + beta += bos * H + i_h + + dAqk += (bos * H + i_h) * BT + dAkk += (bos * H + i_h) * BT + dq += (bos * H + i_h) * K + dq2 += (bos * H + i_h) * K + dk += (bos * H + i_h) * K + dk2 += (bos * H + i_h) * K + dg += (bos * H + i_h) * K + dg2 += (bos * H + i_h) * K + db += (i_k * all + bos) * H + i_h + + p_g = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + + p_b = tl.make_block_ptr(beta, (T,), (H,), (i_ti,), (BC,), (0,)) + b_b = tl.load(p_b, boundary_check=(0,)) + + b_dq2 = tl.zeros([BC, BK], dtype=tl.float32) + b_dk2 = tl.zeros([BC, BK], dtype=tl.float32) + if i_i > 0: + p_gn = g + i_ti * H*K + o_k + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)[None, :] + for i_j in range(0, i_i): + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (H*BT, 1), (i_ti, i_j * BC), (BC, BC), (1, 0)) + p_dAkk = tl.make_block_ptr(dAkk, (T, BT), (H*BT, 1), (i_ti, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = b_k * exp2(b_gn - b_gk) + # [BC, BC] + b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)) + b_dAkk = tl.load(p_dAkk, boundary_check=(0, 1)) + # [BC, BK] + b_dq2 += tl.dot(b_dAqk, b_kg) + b_dk2 += tl.dot(b_dAkk, b_kg) + b_gqn = exp2(b_g - b_gn) + b_dq2 *= b_gqn + b_dk2 *= b_gqn + + o_i = tl.arange(0, BC) + m_dA = (i_ti + o_i) < T + o_dA = (i_ti + o_i) * H*BT + i_i * BC + p_kj = k + i_ti * H*K + o_k + p_gkj = g + i_ti * H*K + o_k + + p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + + if SAFE_GATE: + if USE_GATHER: + b_gn = gather(b_g, tl.full([1, BK], min(BC//2, T - i_ti - 1), dtype=tl.int16), axis=0) + else: + p_gn = g + (i_ti + min(BC // 2, T - i_ti - 1)) * H*K + o_k + b_gn = tl.load(p_gn, mask=m_k, other=0)[None, :] + + p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (H*BT, 1), (i_ti, i_i * BC), (BC, BC), (1, 0)) + p_dAkk = tl.make_block_ptr(dAkk, (T, BT), (H*BT, 1), (i_ti, i_i * BC), (BC, BC), (1, 0)) + b_dAqk_diag_qk = tl.load(p_dAqk, boundary_check=(0, 1)).to(tl.float32) + b_dAkk_diag_qk = tl.load(p_dAkk, boundary_check=(0, 1)).to(tl.float32) + + m_i_diag_qk = (o_i[:, None] >= o_i[None, :]) & ((i_ti + o_i[:, None]) < T) & ((i_ti + o_i[None, :]) < T) + m_j_diag_qk = (i_ti + o_i[:, None]) < T + + b_dAqk_diag_qk = tl.where(m_i_diag_qk, 
b_dAqk_diag_qk, 0.) + b_dAkk_diag_qk = tl.where(m_i_diag_qk, b_dAkk_diag_qk, 0.) + b_g_diag_qk = tl.where(m_j_diag_qk, b_g - b_gn, 0.) + exp_b_g_diag_qk = tl.where(m_j_diag_qk, exp2(b_g_diag_qk), 0.) + exp_neg_b_g_diag_qk = tl.where(m_j_diag_qk, exp2(-b_g_diag_qk), 0.) + + b_k_exp_diag_qk = b_k * exp_neg_b_g_diag_qk + b_dq2 += tl.dot(b_dAqk_diag_qk, b_k_exp_diag_qk) * exp_b_g_diag_qk + b_dk2 += tl.dot(b_dAkk_diag_qk, b_k_exp_diag_qk) * exp_b_g_diag_qk + else: + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC] + b_dAqk = tl.load(dAqk + o_dA + j, mask=m_dA, other=0) + b_dAkk = tl.load(dAkk + o_dA + j, mask=m_dA, other=0) + # [BK] + b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) + b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] >= j + # [BC, BK] + b_gqk = exp2(b_g - b_gkj[None, :]) + b_dq2 += tl.where(m_i, b_dAqk[:, None] * b_kj[None, :] * b_gqk, 0.) + b_dk2 += tl.where(m_i, b_dAkk[:, None] * b_kj[None, :] * b_gqk, 0.) + + p_kj += H*K + p_gkj += H*K + + b_db = tl.sum(b_dk2 * b_k, 1) + b_dk2 *= b_b[:, None] + + p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dq2 = tl.make_block_ptr(dq2, (T, K), (H*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_db = tl.make_block_ptr(db, (T,), (H,), (i_ti,), (BC,), (0,)) + + b_dg2 = b_q * b_dq2 + b_dq2 = b_dq2 + tl.load(p_dq, boundary_check=(0, 1)) + tl.store(p_dq2, b_dq2.to(p_dq2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) + + tl.debug_barrier() + b_dkt = tl.zeros([BC, BK], dtype=tl.float32) + + NC = min(NC, tl.cdiv(T - i_t * BT, BC)) + if i_i < NC - 1: + p_gn = g + (min(i_ti + BC, T) - 1) * H*K + o_k + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)[None, :] + for i_j in range(i_i + 1, NC): + p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k*BK), (BC, BK), (1, 0)) + p_b = tl.make_block_ptr(beta, (T,), (H,), (i_t * BT + i_j * BC,), (BC,), (0,)) + p_dAqk = tl.make_block_ptr(dAqk, (BT, T), (1, H*BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + p_dAkk = tl.make_block_ptr(dAkk, (BT, T), (1, H*BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + # [BC] + b_b = tl.load(p_b, boundary_check=(0,)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_kb = tl.load(p_k, boundary_check=(0, 1)) * b_b[:, None] + b_gk = tl.load(p_gk, boundary_check=(0, 1)).to(tl.float32) + # [BC, BC] + b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)) + b_dAkk = tl.load(p_dAkk, boundary_check=(0, 1)) + + o_j = i_t * BT + i_j * BC + o_i + m_j = o_j < T + # [BC, BK] + b_gkn = exp2(b_gk - b_gn) + b_qg = b_q * tl.where(m_j[:, None], b_gkn, 0) + b_kbg = b_kb * tl.where(m_j[:, None], b_gkn, 0) + # [BC, BK] + # (SY 09/17) important to not use bf16 here to have a good precision. 
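+ # b_dkt gathers the gradient reaching k through rows of later sub-chunks, i.e. the
+ # transposed dAqk/dAkk contributions from q and beta-scaled k of those sub-chunks;
+ # it is rescaled against the current sub-chunk's gates via exp2(b_gn - b_g) below.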
+ b_dkt += tl.dot(b_dAqk, b_qg) + b_dkt += tl.dot(b_dAkk, b_kbg) + b_dkt *= exp2(b_gn - b_g) + o_dA = i_ti * H*BT + i_i * BC + o_i + p_qj = q + i_ti * H*K + o_k + p_kj = k + i_ti * H*K + o_k + p_gkj = g + i_ti * H*K + o_k + p_bj = beta + i_ti * H + + if SAFE_GATE: + if USE_GATHER: + b_gn = gather(b_g, tl.full([1, BK], min(BC//2, T - i_ti - 1), dtype=tl.int16), axis=0) + else: + p_gn = g + (i_ti + min(BC // 2, T - i_ti - 1)) * H*K + o_k + b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)[None, :] + p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + p_b = tl.make_block_ptr(beta, (T,), (H,), (i_ti,), (BC,), (0,)) + b_b = tl.load(p_b, boundary_check=(0,)) + + p_dAqk = tl.make_block_ptr(dAqk, (BT, T), (1, H*BT), (i_i * BC, i_ti), (BC, BC), (0, 1)) + p_dAkk = tl.make_block_ptr(dAkk, (BT, T), (1, H*BT), (i_i * BC, i_ti), (BC, BC), (0, 1)) + b_dAqk_diag_kk = tl.load(p_dAqk, boundary_check=(0, 1)).to(tl.float32) + b_dAkk_diag_kk = tl.load(p_dAkk, boundary_check=(0, 1)).to(tl.float32) + + m_i_diag_kk = (o_i[:, None] <= o_i[None, :]) & ((i_ti + o_i[:, None]) < T) & ((i_ti + o_i[None, :]) < T) + m_j_diag_kk = (i_ti + o_i[:, None]) < T + + b_dAqk_diag_kk = tl.where(m_i_diag_kk, b_dAqk_diag_kk, 0.) + b_dAkk_diag_kk = tl.where(m_i_diag_kk, b_dAkk_diag_kk, 0.) + # ensure numerical stability + b_g_diag_kk = tl.where(m_j_diag_kk, b_g - b_gn, 0.) + exp_b_g_diag_kk = tl.where(m_j_diag_kk, exp2(b_g_diag_kk), 0.) + exp_neg_b_g_diag_kk = tl.where(m_j_diag_kk, exp2(-b_g_diag_kk), 0.) + + b_q_exp = b_q * exp_b_g_diag_kk + b_kb_exp = b_k * b_b[:, None] * exp_b_g_diag_kk + + b_dkt += tl.dot(b_dAqk_diag_kk, b_q_exp) * exp_neg_b_g_diag_kk + b_dkt += tl.dot(b_dAkk_diag_kk, b_kb_exp) * exp_neg_b_g_diag_kk + else: + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC,] + b_dAqk = tl.load(dAqk + o_dA + j * H*BT) + b_dAkk = tl.load(dAkk + o_dA + j * H*BT) + # [BK,] + b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32) + b_kbj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) * tl.load(p_bj) + b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] <= j + b_gkq = exp2(b_gkj[None, :] - b_g) + b_dkt += tl.where(m_i, b_dAqk[:, None] * b_qj[None, :] * b_gkq, 0.) + b_dkt += tl.where(m_i, b_dAkk[:, None] * b_kbj[None, :] * b_gkq, 0.) 
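Aside: dg is assembled from elementwise products of each gated operand with its incoming gradient (b_dg2 = b_q * b_dq2 above, and (b_dk2 - b_dkt) * b_k below), because the derivative of q * 2^g with respect to g is proportional to q * 2^g itself; the log-base constant is a bookkeeping detail of how the cumulative gates are scaled upstream. A per-channel finite-difference check of that identity, with toy values, illustrative only:

import numpy as np

rng = np.random.default_rng(1)
D = 4
q, k = rng.standard_normal(D), rng.standard_normal(D)
gq, gk = rng.uniform(-1.0, 0.0, D), rng.uniform(-2.0, -1.0, D)

score = lambda gq: float(np.sum(q * np.exp2(gq) * k * np.exp2(-gk)))
analytic = np.log(2.0) * q * np.exp2(gq) * k * np.exp2(-gk)   # d(score)/d(gq), channel-wise

eps = 1e-5
numeric = np.array([
    (score(gq + eps * np.eye(D)[d]) - score(gq - eps * np.eye(D)[d])) / (2 * eps)
    for d in range(D)
])
assert np.allclose(numeric, analytic, atol=1e-6)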
+ + p_qj += H*K + p_kj += H*K + p_gkj += H*K + p_bj += H + p_dk = tl.make_block_ptr(dk, (T, K), (H*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dk2 = tl.make_block_ptr(dk2, (T, K), (H*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg, (T, K), (H*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dg2 = tl.make_block_ptr(dg2, (T, K), (H*K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + + b_dg2 += (b_dk2 - b_dkt) * b_k + tl.load(p_dg, boundary_check=(0, 1)) + b_dk2 += tl.load(p_dk, boundary_check=(0, 1)) + b_dk2 += b_dkt + + tl.store(p_dk2, b_dk2.to(p_dk2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg2, b_dg2.to(p_dg2.dtype.element_ty), boundary_check=(0, 1)) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["BT", "BC"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def chunk_kda_fwd_kernel_intra_sub_chunk( + q, + k, + g, + beta, + Aqk, + Akk, + scale, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_GATHER: tl.constexpr, +): + i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + i_ti = i_t * BT + i_i * BC + if i_ti >= T: + return + + o_c = i_ti + tl.arange(0, BC) + m_c = o_c < T + + q = q + (bos * H + i_h) * K + k = k + (bos * H + i_h) * K + g = g + (bos * H + i_h) * K + beta = beta + bos * H + i_h + Aqk = Aqk + (bos * H + i_h) * BT + Akk = Akk + (bos * H + i_h) * BC + + p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_ti, 0), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_ti, 0), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_ti, 0), (BC, BK), (1, 0)) + + p_beta = tl.make_block_ptr(beta, (T,), (H,), (i_ti,), (BC,), (0,)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + if USE_GATHER: + b_gn = gather(b_g, tl.full([1, BK], min(BC//2, T - i_ti - 1), dtype=tl.int16), axis=0) + else: + # caculate offset + p_gn = g + (i_ti + min(BC // 2, T - i_ti - 1)) * H*K + tl.arange(0, BK) + b_gn = tl.load(p_gn, mask=tl.arange(0, BK) < K, other=0.0) + b_gn = b_gn[None, :] + + # current block, keep numerical stability by subtracting the left boundary + # less than 85 to avoid overflow in exp2 + b_gm = (b_g - b_gn).to(tl.float32) + + b_gq = tl.where(m_c[:, None], exp2(b_gm), 0.) + b_gk = tl.where(m_c[:, None], exp2(-b_gm), 0.) 
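Aside: the forward-substitution loop at the end of this kernel builds (I + A_kk)^(-1) row by row, where A_kk is the strictly lower-triangular, beta-scaled key-key block computed just below. A minimal NumPy rendering of that recurrence, verified against a direct inverse; sizes and magnitudes are toy values, illustrative only:

import numpy as np

rng = np.random.default_rng(0)
BC = 16
L = np.tril(rng.standard_normal((BC, BC)) * 0.1, k=-1)   # stand-in for the strictly lower A_kk block

M = -L.copy()                                             # rows 0 and 1 are already correct
for i in range(2, BC):
    M[i, :i] = -L[i, :i] + (-L[i, :i]) @ M[:i, :i]        # same update as the b_Ai loop below
M += np.eye(BC)                                           # add the unit diagonal at the end

assert np.allclose(M @ (np.eye(BC) + L), np.eye(BC), atol=1e-6)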
+ + b_kgt = tl.trans(b_k * b_gk) + + b_Aqk = tl.dot(b_q * b_gq, b_kgt) * scale + b_Akk = tl.dot(b_k * b_gq, b_kgt) * b_beta[:, None] + + o_i = tl.arange(0, BC) + m_Aqk = o_i[:, None] >= o_i[None, :] + m_Akk = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + + b_Aqk = tl.where(m_Aqk, b_Aqk, 0.0) + b_Akk = tl.where(m_Akk, b_Akk, 0.0) + + p_Aqk = tl.make_block_ptr(Aqk, (T, BT), (H*BT, 1), (i_ti, i_i * BC), (BC, BC), (1, 0)) + p_Akk = tl.make_block_ptr(Akk, (T, BC), (H*BC, 1), (i_ti, 0), (BC, BC), (1, 0)) + tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk, b_Akk.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + + ################################################################################ + # forward substitution + ################################################################################ + + b_Ai = -b_Akk + for i in range(2, min(BC, T - i_ti)): + b_a = -tl.load(Akk + (i_ti + i) * H*BC + o_i) + b_a = tl.where(o_i < i, b_a, 0.) + b_a += tl.sum(b_a[:, None] * b_Ai, 0) + b_Ai = tl.where((o_i == i)[:, None], b_a, b_Ai) + b_Ai += m_I + tl.store(p_Akk, b_Ai.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_kda_fwd_intra( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + gk: paddle.Tensor | None = None, + beta: paddle.Tensor | None = None, + scale: float | None = None, + cu_seqlens: paddle.Tensor | None = None, + chunk_size: int = 64, + chunk_indices: paddle.Tensor | None = None, + safe_gate: bool = False, + disable_recompute: bool = False, +): + B, T, H, K = k.shape + BT = chunk_size + BC = 16 + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NC = triton.cdiv(BT, BC) + + Aqk = paddle.empty([B, T, H, BT], dtype=k.dtype) + # Akk must be zero-initialized - kernel only writes lower triangular + Akk = paddle.zeros([B, T, H, BT], dtype=k.dtype) + # Separate fp32 buffer for diagonal 16x16 blocks (for precision in solve_tril) + Akkd = paddle.empty([B, T, H, BC], dtype=paddle.float32) + + # Step 1: Run token_parallel first to compute diagonal blocks into Akkd (fp32) + # Step 1: compute diagonal blocks into Akk_diag (fp32) + if safe_gate: + grid = (NT, NC, B * H) + BK = triton.next_power_of_2(K) + chunk_kda_fwd_kernel_intra_sub_chunk[grid]( + q=q, + k=k, + g=gk, + beta=beta, + Aqk=Aqk, + Akk=Akkd, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + USE_GATHER=IS_GATHER_SUPPORTED, + ) + else: + Aqk, Akkd = chunk_kda_fwd_intra_token_parallel( + q=q, + k=k, + gk=gk, + beta=beta, + Aqk=Aqk, + Akk=Akkd, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=BT, + sub_chunk_size=BC, + ) + + # Step 2: Fused inter + solve_tril (works for both fixed-len and varlen) + grid = (NT, B * H) + chunk_kda_fwd_kernel_inter_solve_fused[grid]( + q=q, + k=k, + g=gk, + beta=beta, + Aqk=Aqk, + Akkd=Akkd, + Akk=Akk, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + USE_SAFE_GATE=safe_gate, + ) + w, u, qg, kg = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=Akk, + q=q if disable_recompute else None, + gk=gk, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + ) + return w, u, qg, kg, Aqk, Akk + + +def chunk_kda_bwd_intra( + q: paddle.Tensor, + k: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + dAqk: paddle.Tensor, + dAkk: paddle.Tensor, + 
dq: paddle.Tensor, + dk: paddle.Tensor, + db: paddle.Tensor, + dg: paddle.Tensor, + cu_seqlens: paddle.Tensor | None = None, + chunk_indices: paddle.Tensor | None = None, + chunk_size: int = 64, + safe_gate: bool = False, +): + B, T, H, K = k.shape + BT = chunk_size + BC = min(16, BT) + BK = min(32, triton.next_power_of_2(K)) + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NC = triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + + dq2 = paddle.empty_like(q) + dk2 = paddle.empty_like(k) + db2 = paddle.empty([NK] + list(beta.shape), dtype=paddle.float32) + dg2 = paddle.empty_like(dg, dtype=paddle.float32) + grid = (NK * NC, NT, B * H) + chunk_kda_bwd_kernel_intra[grid]( + q=q, + k=k, + g=g, + beta=beta, + dAqk=dAqk, + dAkk=dAkk, + dq=dq, + dq2=dq2, + dk=dk, + dk2=dk2, + dg=dg, + dg2=dg2, + db=db2, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + B=B, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + NC=NC, + SAFE_GATE=safe_gate, + USE_GATHER=IS_GATHER_SUPPORTED, + ) + dq = dq2 + dk = dk2 + db = db2.sum(axis=0).add_(db) + dg = dg2 + + return dq, dk, db, dg diff --git a/flashmask/flash_mask/linear_attn/ops/kda/chunk_intra_token_parallel.py b/flashmask/flash_mask/linear_attn/ops/kda/chunk_intra_token_parallel.py new file mode 100644 index 00000000000..3264e408385 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/kda/chunk_intra_token_parallel.py @@ -0,0 +1,177 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +# Token-parallel implementation of KDA intra chunk kernel + +import paddle +import triton +import triton.language as tl + +from flash_mask.linear_attn.ops.utils.op import exp2 +from flash_mask.linear_attn.utils import autotune_cache_kwargs +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BH': BH}, num_warps=num_warps) + for BH in [1, 2, 4, 8] + for num_warps in [1, 2, 4, 8] + ], + key=["K", "H"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T', 'N']) +def chunk_kda_fwd_kernel_intra_token_parallel( + q, + k, + g, + beta, + Aqk, + Akk, + scale, + cu_seqlens, + N, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BH: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_tg, i_hg = tl.program_id(0), tl.program_id(1) + + if IS_VARLEN: + i_n = 0 + left, right = 0, N + + # Unrolled binary search (max B=2^32) + # We can limit iterations based on expected max batch size if needed + # 20 iterations covers B=1M, usually enough + for _ in range(20): + if left < right: + mid = (left + right) // 2 + if i_tg < tl.load(cu_seqlens + mid + 1).to(tl.int32): + right = mid + else: + left = mid + 1 + i_n = left + + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + i_t = i_tg - bos + else: + bos = (i_tg // T) * T + i_t = i_tg % T + + if i_t >= T: + return + + i_c = i_t // BT + i_s = (i_t % BT) // BC + i_tc = i_c * BT + i_ts = i_tc + i_s * BC + + q += bos * H*K + k += bos * H*K + g += bos * H*K + Aqk 
+= bos * H*BT + Akk += bos * H*BC + beta += bos * H + + BK: tl.constexpr = triton.next_power_of_2(K) + o_h = tl.arange(0, BH) + o_k = tl.arange(0, BK) + m_h = (i_hg * BH + o_h) < H + m_k = o_k < K + + p_q = tl.make_block_ptr(q + i_t * H*K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_t * H*K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_t * H*K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0)) + p_beta = tl.make_block_ptr(beta + i_t * H, (H,), (1,), (i_hg * BH,), (BH,), (0,)) + # [BH, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32) + b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32) + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + b_k = b_k * tl.load(p_beta, boundary_check=(0,)).to(tl.float32)[:, None] + + for j in range(i_ts, min(i_t + 1, min(T, i_ts + BC))): + p_kj = tl.make_block_ptr(k + j * H*K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0)) + p_gj = tl.make_block_ptr(g + j * H*K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0)) + # [BH, BK] + b_kj = tl.load(p_kj, boundary_check=(0, 1)).to(tl.float32) + b_gj = tl.load(p_gj, boundary_check=(0, 1)).to(tl.float32) + + b_kgj = b_kj * exp2(b_g - b_gj) + + b_kgj = tl.where(m_k[None, :], b_kgj, 0.0) + # [BH] + b_Aqk = tl.sum(b_q * b_kgj, axis=1) * scale + b_Akk = tl.sum(b_k * b_kgj, axis=1) * tl.where(j < i_t, 1.0, 0.0) + + tl.store(Aqk + i_t * H*BT + (i_hg * BH + o_h) * BT + j % BT, b_Aqk.to(Aqk.dtype.element_ty), mask=m_h) + tl.store(Akk + i_t * H*BC + (i_hg * BH + o_h) * BC + j - i_ts, b_Akk.to(Akk.dtype.element_ty), mask=m_h) + + +def chunk_kda_fwd_intra_token_parallel( + q: paddle.Tensor, + k: paddle.Tensor, + gk: paddle.Tensor, + beta: paddle.Tensor, + Aqk: paddle.Tensor, + Akk: paddle.Tensor, + scale: float, + cu_seqlens: paddle.Tensor | None = None, + chunk_size: int = 64, + sub_chunk_size: int = 16, +) -> None: + """ + Token-parallel implementation: each token gets its own thread block. + Supports both fixed-length and variable-length sequences. + Reduces wasted computation on padding. + + Writes directly to Aqk and Akk tensors (in-place). + + Args: + q: [B, T, H, K] + k: [B, T, H, K] + gk: [B, T, H, K] cumsum of gates + beta: [B, T, H] + Aqk: [B, T, H, BT] output tensor to write to + Akk: [B, T, H, BC] output tensor for diagonal blocks (fp32) + scale: attention scale + chunk_size: BT (default 64) + sub_chunk_size: BC (default 16) + """ + B, T, H, K = q.shape + N = len(cu_seqlens) - 1 if cu_seqlens is not None else B + BT = chunk_size + BC = sub_chunk_size + + def grid(meta): return (B * T, triton.cdiv(H, meta['BH'])) + chunk_kda_fwd_kernel_intra_token_parallel[grid]( + q=q, + k=k, + g=gk, + beta=beta, + Aqk=Aqk, + Akk=Akk, + scale=scale, + cu_seqlens=cu_seqlens, + N=N, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + ) + return Aqk, Akk diff --git a/flashmask/flash_mask/linear_attn/ops/kda/fused_recurrent.py b/flashmask/flash_mask/linear_attn/ops/kda/fused_recurrent.py new file mode 100644 index 00000000000..ce3c01e8dfa --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/kda/fused_recurrent.py @@ -0,0 +1,441 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +# This kernel is modified from the Decode kernel of the vllm gdn/kda model. 
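Aside: stripped of the varlen and continuous-batching bookkeeping, the recurrence realized by the kernel below is a decayed delta-rule update on a [K, V] state per value head. A single-token NumPy step for one head, illustrative only (the full Paddle reference lives in ops/kda/naive.py as naive_recurrent_kda):

import numpy as np

def kda_step(h, q, k, v, g, beta, scale):
    """One token of the gated delta rule. h: [K, V] state, g: [K] log-decays (<= 0), beta: scalar in (0, 1)."""
    h = h * np.exp(g)[:, None]        # decay the state per key channel
    v_new = beta * (v - k @ h)        # delta rule: only write what the state does not already predict
    h = h + np.outer(k, v_new)        # rank-1 state update
    o = (scale * q) @ h               # read out
    return h, o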
+ +import paddle +import triton +import triton.language as tl + +from flash_mask.linear_attn.ops.utils.op import exp +from flash_mask.linear_attn.ops.utils.softplus import softplus +from flash_mask.linear_attn.utils import input_guard +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + + +@enable_compat_on_triton_kernel +@triton.heuristics( + { + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None, + "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, + "HAS_DT_BIAS": lambda args: args["dt_bias"] is not None, + "USE_LOWER_BOUND": lambda args: args["lower_bound"] is not None, + } +) +@triton.jit(do_not_specialize=["N", "T"]) +def fused_recurrent_kda_fwd_kernel( + q, + k, + v, + g, + beta, + A_log, + dt_bias, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + lower_bound, + scale: tl.constexpr, + N: tl.int64, # num of sequences + T: tl.int64, # num of tokens + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + USE_GATE_IN_KERNEL: tl.constexpr, + USE_LOWER_BOUND: tl.constexpr, + TRANSPOSE_STATE: tl.constexpr, + num_stages: tl.constexpr, +): + pid = tl.program_id(0) + NV = tl.cdiv(V, BV) + NK = tl.cdiv(K, BK) + i_k = pid % NK + pid_rest = pid // NK + + i_v = pid_rest % NV + i_nh = pid_rest // NV + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int64), + tl.load(cu_seqlens + i_n + 1).to(tl.int64), + ) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + else: + p_beta = beta + bos * HV + i_hv + + p_g = g + (bos * HV + i_hv) * K + o_k + p_o = o + (bos * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + if TRANSPOSE_STATE: + mask_h = mask_v[:, None] & mask_k[None, :] + else: + mask_h = mask_k[:, None] & mask_v[None, :] + + if TRANSPOSE_STATE: + b_h = tl.zeros([BV, BK], dtype=tl.float32) + else: + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + p_h0 = ( + h0 + + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to( + tl.int64 + ) + * stride_init_state_token + ) + if TRANSPOSE_STATE: + p_h0 = p_h0 + i_hv * K * V + o_v[:, None] * K + o_k[None, :] + else: + p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + 
o_v[None, :] + else: + if TRANSPOSE_STATE: + p_h0 = h0 + (i_n * HV + i_hv) * K * V + o_v[:, None] * K + o_k[None, :] + else: + p_h0 = h0 + (i_n * HV + i_hv) * K * V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in tl.range(0, T, num_stages=num_stages): + b_q = tl.load(p_q, mask=mask_k, other=0, eviction_policy='evict_last').to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0, eviction_policy='evict_last').to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0, eviction_policy='evict_first').to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + b_g = tl.load(p_g, eviction_policy='evict_last').to(tl.float32) + + if USE_GATE_IN_KERNEL: + b_A = tl.load(A_log + i_h).to(tl.float32) + + if HAS_DT_BIAS: + b_bias = tl.load(dt_bias + i_h * K + o_k, mask=mask_k, other=0).to(tl.float32) + b_g = b_g + b_bias + + if USE_LOWER_BOUND: + b_gk = lower_bound * tl.sigmoid(exp(b_A) * b_g) + else: + b_gk = -exp(b_A) * softplus(b_g) + else: + b_gk = b_g + + if TRANSPOSE_STATE: + b_h *= exp(b_gk[None, :]) + else: + b_h *= exp(b_gk[:, None]) + + if TRANSPOSE_STATE: + b_v -= tl.sum(b_h * b_k[None, :], 1) + else: + b_v -= tl.sum(b_h * b_k[:, None], 0) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0, eviction_policy='evict_first').to(tl.float32) + else: + b_beta = tl.load(p_beta, eviction_policy='evict_last').to(tl.float32) + b_v *= b_beta + if TRANSPOSE_STATE: + b_h += b_v[:, None] * b_k[None, :] + b_o = tl.sum(b_h * b_q[None, :], 1) + else: + b_h += b_k[:, None] * b_v[None, :] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v, eviction_policy='evict_first') + + if IS_CONTINUOUS_BATCHING: + if INPLACE_FINAL_STATE: + p_ht = ( + ht + + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to( + tl.int64 + ) + * stride_final_state_token + ) + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + if TRANSPOSE_STATE: + p_ht = p_ht + i_hv * K * V + o_v[:, None] * K + o_k[None, :] + else: + p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + p_q += H * K + p_k += H * K + p_o += HV * V + p_v += HV * V + p_g += HV * K + p_beta += HV * (V if IS_BETA_HEADWISE else 1) + + if not IS_CONTINUOUS_BATCHING: + if STORE_FINAL_STATE: + if TRANSPOSE_STATE: + p_ht = ht + (i_n * HV + i_hv) * K * V + o_v[:, None] * K + o_k[None, :] + else: + p_ht = ht + (i_n * HV + i_hv) * K * V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +def fused_recurrent_kda_fwd( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + A_log: paddle.Tensor | None = None, + dt_bias: paddle.Tensor | None = None, + initial_state: paddle.Tensor | None = None, + scale: float | None = None, + output_final_state: bool = False, + inplace_final_state: bool = True, + cu_seqlens: paddle.Tensor | None = None, + ssm_state_indices: paddle.Tensor | None = None, + num_accepted_tokens: paddle.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + use_gate_in_kernel: bool = False, + lower_bound: float | None = None, + out: paddle.Tensor | None = None, + transpose_state_layout: bool = False, + **kwargs, +) -> tuple[paddle.Tensor, paddle.Tensor]: + if scale is None: + scale = k.shape[-1] ** -0.5 + + B, T, H, K, V = *k.shape, v.shape[-1] + HV = 
v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK = triton.next_power_of_2(K) + BV = 32 + + if out is None: + out = paddle.zeros_like(v) + else: + assert out.shape == v.shape + if inplace_final_state: + assert initial_state is not None + final_state = initial_state + elif output_final_state: + if transpose_state_layout: + final_state = paddle.empty([N, HV, V, K], dtype=paddle.float32) + else: + final_state = paddle.empty([N, HV, K, V], dtype=paddle.float32) + else: + final_state = None + + stride_init_state_token = initial_state.stride()[0] if initial_state is not None else 1 + stride_final_state_token = final_state.stride()[0] if final_state is not None else 1 + + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()[0], 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + grid = (triton.cdiv(V, BV) * N * HV, ) + fused_recurrent_kda_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + A_log=A_log, + dt_bias=dt_bias, + o=out, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + lower_bound=lower_bound, + scale=scale, + N=N, + T=T, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + INPLACE_FINAL_STATE=inplace_final_state, + USE_GATE_IN_KERNEL=use_gate_in_kernel, + TRANSPOSE_STATE=transpose_state_layout, + num_warps=4, + num_stages=2, + ) + + return out, final_state + + +@input_guard +def fused_recurrent_kda( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + A_log: paddle.Tensor | None = None, + dt_bias: paddle.Tensor | None = None, + scale: float | None = None, + initial_state: paddle.Tensor = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + use_gate_in_kernel: bool = False, + lower_bound: float | None = None, + cu_seqlens: paddle.Tensor | None = None, + transpose_state_layout: bool = False, + **kwargs, +) -> tuple[paddle.Tensor, paddle.Tensor]: + r""" + Args: + q (paddle.Tensor): + queries of shape `[B, T, H, K]`. + k (paddle.Tensor): + keys of shape `[B, T, H, K]`. + v (paddle.Tensor): + values of shape `[B, T, HV, V]`. + GVA is applied if `HV > H`. + g (paddle.Tensor): + g (decays) of shape `[B, T, HV, K]`. + beta (paddle.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[float]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[paddle.Tensor]): + Initial state of shape `[N, HV, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, HV, K, V]`. Default: `False`. + use_qk_l2norm_in_kernel (Optional[bool]): + Whether to use L2 normalization in the kernel. Default: `False`. + cu_seqlens (paddle.Tensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. 
+ transpose_state_layout (bool): + Whether to use transposed state layout `[V, K]` instead of `[K, V]`. Default: `False`. + + Returns: + o (paddle.Tensor): + Outputs of shape `[B, T, HV, V]`. + final_state (paddle.Tensor): + Final state of shape `[N, HV, K, V]` if `output_final_state=True` else `None`. + + Examples:: + >>> import paddle + >>> import paddle.nn.functional as F + >>> from einops import rearrange + >>> from flash_mask.linear_attn.ops.kda import fused_recurrent_kda + # inputs with equal lengths + >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 + >>> q = paddle.randn([B, T, H, K]) + >>> k = F.normalize(paddle.randn([B, T, H, K]), p=2, axis=-1) + >>> v = paddle.randn([B, T, HV, V]) + >>> g = F.log_sigmoid(paddle.rand([B, T, HV, K])) + >>> beta = paddle.rand([B, T, HV]).sigmoid() + >>> h0 = paddle.randn([B, HV, K, V]) + >>> o, ht = fused_recurrent_kda( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = paddle.to_tensor([0, 2048, 4096, 6144, 8192], dtype=paddle.int64) + >>> o_var, ht_var = fused_recurrent_kda( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + """ + + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.", + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.", + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + + o, final_state = fused_recurrent_kda_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + A_log=A_log, + dt_bias=dt_bias, + scale=scale, + initial_state=initial_state, + inplace_final_state=False, + output_final_state=output_final_state, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + use_gate_in_kernel=use_gate_in_kernel, + lower_bound=lower_bound, + cu_seqlens=cu_seqlens, + transpose_state_layout=transpose_state_layout, + ) + return o, final_state diff --git a/flashmask/flash_mask/linear_attn/ops/kda/gate.py b/flashmask/flash_mask/linear_attn/ops/kda/gate.py new file mode 100644 index 00000000000..174adc21dbf --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/kda/gate.py @@ -0,0 +1,478 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +# This file is modified and supported by the Moonshot AI Team + +import paddle +import paddle.nn.functional as F +import triton +import triton.language as tl + +from flash_mask.linear_attn.ops.utils.index import prepare_chunk_indices +from flash_mask.linear_attn.ops.utils.op import exp +from flash_mask.linear_attn.ops.utils.softplus import softplus +from flash_mask.linear_attn.utils import IS_AMD, autocast_custom_bwd, autocast_custom_fwd, autotune_cache_kwargs, check_shared_mem, input_guard +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + +BS_LIST = [32, 64] if check_shared_mem() else [16, 32] +BT_LIST_AUTOTUNE = [32, 64, 128] +NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if IS_AMD else [4, 8, 16, 32] + + +def naive_kda_gate( + g: paddle.Tensor, + A_log: paddle.Tensor, + dt_bias: paddle.Tensor | None = None, + output_dtype: paddle.dtype = paddle.float32, +) -> paddle.Tensor: + """ + Paddle reference implementation for KDA gate computation. + + Computes: g = -A_log.exp().unsqueeze(-1) * softplus(g + dt_bias.view(g.shape[-2:])) + + Args: + g (paddle.Tensor): + Input tensor of shape `[..., H, K]`. + A_log (paddle.Tensor): + Parameter tensor with `H` elements. + dt_bias (paddle.Tensor | None): + Optional bias tensor added to `g` before activation, shape `[H * K]`. + + Returns: + Output tensor of shape `[..., H, K]` . + """ + H, _ = g.shape[-2:] + g = g.cast(paddle.float32) + if dt_bias is not None: + g = g + dt_bias.reshape([H, -1]) + + g = (-A_log.reshape([H, 1]).cast(paddle.float32).exp() * F.softplus(g.cast(paddle.float32))).cast(output_dtype) + return g + + +def naive_kda_lowerbound_gate( + g: paddle.Tensor, + A_log: paddle.Tensor, + dt_bias: paddle.Tensor | None = None, + lower_bound: float = -5.0, + output_dtype: paddle.dtype = paddle.float32, +) -> paddle.Tensor: + H, _ = g.shape[-2:] + g = g.cast(paddle.float32) + if dt_bias is not None: + g = g + dt_bias.reshape([H, -1]) + g = lower_bound * F.sigmoid(A_log.reshape([H, 1]).exp() * g) + return g.cast(output_dtype) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + "HAS_BIAS": lambda args: args["dt_bias"] is not None, + "HAS_BETA": lambda args: args["beta"] is not None, + 'USE_LOWER_BOUND': lambda args: args['lower_bound'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({"BT": BT}, num_warps=num_warps, num_stages=num_stages) + for BT in BT_LIST_AUTOTUNE + for num_warps in NUM_WARPS_AUTOTUNE + for num_stages in [2, 3] + ], + key=["H", "D"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def kda_gate_fwd_kernel( + g, + A_log, + dt_bias, + beta, + yg, + yb, + lower_bound, + T, + H: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_BETA: tl.constexpr, + USE_LOWER_BOUND: tl.constexpr, +): + i_t, i_h = tl.program_id(0), tl.program_id(1) + + b_A = tl.load(A_log + i_h).to(tl.float32) + + p_g = tl.make_block_ptr(g + i_h * D, (T, D), (H * D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + p_yg = tl.make_block_ptr(yg + i_h * D, (T, D), (H * D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + # [BT, BD] + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + if HAS_BIAS: + p_b = tl.make_block_ptr(dt_bias, (H * D,), (1,), (i_h * D,), (BD,), (0,)) + b_g = b_g + tl.load(p_b, boundary_check=(0,)).to(tl.float32) + if not USE_LOWER_BOUND: + b_yg = -exp(b_A) * softplus(b_g) + else: + b_yg = lower_bound * tl.sigmoid(exp(b_A) * b_g) 
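Aside: for the lower-bounded variant just computed, y = lower_bound * sigmoid(exp(A_log) * g) (with g the bias-adjusted gate input), the backward kernel further down in this file uses dy/dg = lower_bound * sigmoid'(.) * exp(A_log) and dy/dA_log = dy/dg * g. A scalar finite-difference check of those two identities, toy values only:

import numpy as np

sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
LB, A_log, g = -5.0, 0.3, -1.2                     # toy scalars; LB plays the role of lower_bound
y = lambda A_log, g: LB * sigmoid(np.exp(A_log) * g)

s = sigmoid(np.exp(A_log) * g)
dg = LB * s * (1.0 - s) * np.exp(A_log)            # mirrors b_dg in kda_gate_bwd_kernel
dA = dg * g                                        # mirrors b_dA = sum(b_dg * b_g)

eps = 1e-6
assert np.isclose((y(A_log, g + eps) - y(A_log, g - eps)) / (2 * eps), dg, atol=1e-6)
assert np.isclose((y(A_log + eps, g) - y(A_log - eps, g)) / (2 * eps), dA, atol=1e-6)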
+ tl.store(p_yg, b_yg.to(p_yg.dtype.element_ty), boundary_check=(0, 1)) + + if HAS_BETA: + p_b = tl.make_block_ptr(beta + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_yb = tl.make_block_ptr(yb + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_yb = tl.sigmoid(tl.load(p_b, boundary_check=(0,)).to(tl.float32)) + tl.store(p_yb, b_yb.to(p_yb.dtype.element_ty), boundary_check=(0,)) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + "HAS_BIAS": lambda args: args["dt_bias"] is not None, + "HAS_BETA": lambda args: args["beta"] is not None, + 'USE_LOWER_BOUND': lambda args: args['lower_bound'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in NUM_WARPS_AUTOTUNE + for num_stages in [2, 3] + ], + key=["H", "D"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def kda_gate_bwd_kernel( + g, + A_log, + dt_bias, + beta, + dyg, + dyb, + dg, + dA, + dbeta, + lower_bound, + T, + H: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_BETA: tl.constexpr, + USE_LOWER_BOUND: tl.constexpr, +): + i_t, i_h = tl.program_id(0), tl.program_id(1) + + b_A = tl.load(A_log + i_h).to(tl.float32) + + p_g = tl.make_block_ptr(g + i_h * D, (T, D), (H * D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + p_dg = tl.make_block_ptr(dg + i_h * D, (T, D), (H * D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + p_dyg = tl.make_block_ptr(dyg + i_h * D, (T, D), (H * D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + + # [BT, BD] + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + b_dyg = tl.load(p_dyg, boundary_check=(0, 1)).to(tl.float32) + + if HAS_BIAS: + p_b = tl.make_block_ptr(dt_bias, (H * D,), (1,), (i_h * D,), (BD,), (0,)) + b_g = b_g + tl.load(p_b, boundary_check=(0,)).to(tl.float32) + + # [BT, BD] + if not USE_LOWER_BOUND: + b_A = -exp(b_A) + b_yg = b_A * softplus(b_g) + b_dg = b_A * (b_dyg * tl.sigmoid(b_g)) + b_dA = tl.sum(tl.sum(b_dyg * b_yg, 1), 0) + else: + b_A = exp(b_A) + b_inner = b_A * b_g + b_sig = tl.sigmoid(b_inner) + b_dsig = b_sig * (1.0 - b_sig) + # Common term: dy * (LB * dsig) + b_d_inner_term = b_dyg * (lower_bound * b_dsig) + # dg = d_inner_term * A + b_dg = b_d_inner_term * b_A + b_dA = tl.sum(tl.sum(b_dg * b_g, 1), 0) + + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + tl.store(dA + i_t * H + i_h, b_dA) + + if HAS_BETA: + p_b = tl.make_block_ptr(beta + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_db = tl.make_block_ptr(dbeta + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_dyb = tl.make_block_ptr(dyb + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + + b_b = tl.load(p_b, boundary_check=(0,)).to(tl.float32) + b_db = tl.load(p_dyb, boundary_check=(0,)).to(tl.float32) * b_b * (1.0 - b_b) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) + + +def kda_gate_fwd( + g: paddle.Tensor, + A_log: paddle.Tensor, + dt_bias: paddle.Tensor | None = None, + lower_bound: float | None = None, + output_dtype: paddle.dtype = paddle.float32, +) -> paddle.Tensor: + H, K = g.shape[-2:] + T = g.numel().item() // (H * K) + + yg = paddle.empty_like(g).cast(output_dtype) + + def grid(meta): + return (triton.cdiv(T, meta["BT"]), H) + + kda_gate_fwd_kernel[grid]( + g=g, + A_log=A_log, + dt_bias=dt_bias, + beta=None, + yg=yg, + yb=None, + T=T, + H=H, + D=K, + BD=triton.next_power_of_2(K), + lower_bound=lower_bound, + ) + return yg + + +def kda_gate_bwd( + g: paddle.Tensor, + A_log: paddle.Tensor, + dt_bias: paddle.Tensor | None = None, + dyg: 
paddle.Tensor | None = None, + lower_bound: float | None = None, +) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor | None]: + H, K = g.shape[-2:] + T = g.numel().item() // (H * K) + BT = 32 + NT = triton.cdiv(T, BT) + + dg = paddle.empty_like(g).cast(paddle.float32) + dA = paddle.empty([NT, H], dtype=paddle.float32) + + grid = (triton.cdiv(T, BT), H) + kda_gate_bwd_kernel[grid]( + g=g, + A_log=A_log, + dt_bias=dt_bias, + beta=None, + dyg=dyg, + dyb=None, + dg=dg, + dA=dA, + dbeta=None, + T=T, + H=H, + D=K, + BT=BT, + BD=triton.next_power_of_2(K), + lower_bound=lower_bound, + ) + + # Compute dbias from dg while still in float32 (before casting to g.dtype which may be float16). + # PyTorch's .sum() on float16 internally promotes to float32 for accumulation, + # but Paddle's .sum() accumulates directly in the input dtype. Computing dbias + # in float32 avoids precision loss when summing over B*T elements. + dbias = dg.reshape([-1, H * K]).sum(axis=0).cast(dt_bias.dtype) if dt_bias is not None else None + dg = dg.reshape(g.shape).cast(g.dtype) + dA = dA.sum(axis=0).reshape(A_log.shape).cast(A_log.dtype) + + return dg, dA, dbias + + +class KDAGateFunction(paddle.autograd.PyLayer): + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + g: paddle.Tensor, + A_log: paddle.Tensor, + dt_bias: paddle.Tensor | None = None, + lower_bound: float | None = None, + output_dtype: paddle.dtype = paddle.float32, + ) -> paddle.Tensor: + yg = kda_gate_fwd( + g=g, + A_log=A_log, + dt_bias=dt_bias, + lower_bound=lower_bound, + output_dtype=output_dtype + ) + ctx.save_for_backward(g, A_log, dt_bias) + ctx.lower_bound = lower_bound + # Paddle PyLayer backward must return exactly as many values as tensor inputs. + _forward_args = [g, A_log, dt_bias, lower_bound, output_dtype] + ctx._tensor_mask = tuple(isinstance(a, paddle.Tensor) for a in _forward_args) + ctx._needs_grad = tuple( + isinstance(a, paddle.Tensor) and not a.stop_gradient for a in _forward_args + ) + return yg + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, dyg: paddle.Tensor): + g, A_log, dt_bias = ctx.saved_tensor() + dg, dA, dbias = kda_gate_bwd( + g=g, + A_log=A_log, + dt_bias=dt_bias, + dyg=dyg, + lower_bound=ctx.lower_bound + ) + all_grads = [dg, dA, dbias, None, None] + return tuple( + g if needs_grad else None + for g, is_tensor, needs_grad in zip(all_grads, ctx._tensor_mask, ctx._needs_grad) + if is_tensor + ) + + +def fused_kda_gate( + g: paddle.Tensor, + A_log: paddle.Tensor, + dt_bias: paddle.Tensor | None = None, + lower_bound: float | None = None, + output_dtype: paddle.dtype = paddle.float32, +) -> paddle.Tensor | tuple[paddle.Tensor, paddle.Tensor]: + """ + Fused KDA gate computation with autograd support. + + Computes: g = -A_log.exp().unsqueeze(-1) * softplus(g + dt_bias.view(g.shape[-2:])) + + Args: + g (paddle.Tensor): + Input tensor of shape `[..., H, K]`. + A_log (paddle.Tensor): + Parameter tensor with `H` elements. + dt_bias (paddle.Tensor | None): + Optional bias tensor added to `g` before activation, shape `[H * K]`. + + Returns: + Output tensor of shape `[..., H, K]`. 
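        Examples (illustrative sketch; shapes are hypothetical, and the comparison is against the naive Paddle reference defined earlier in this file)::
            >>> import paddle
            >>> from flash_mask.linear_attn.ops.kda.gate import fused_kda_gate, naive_kda_gate
            >>> B, T, H, K = 2, 1024, 4, 128
            >>> g = paddle.randn([B, T, H, K], dtype=paddle.bfloat16)
            >>> A_log = paddle.zeros([H], dtype=paddle.float32)
            >>> dt_bias = paddle.zeros([H * K], dtype=paddle.float32)
            >>> y = fused_kda_gate(g, A_log, dt_bias)
            >>> y_ref = naive_kda_gate(g, A_log, dt_bias)  # same formula, computed in plain Paddle
            >>> paddle.allclose(y, y_ref, rtol=1e-3, atol=1e-3)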
+ """ + return KDAGateFunction.apply(g, A_log, dt_bias, lower_bound, output_dtype) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + "HAS_BIAS": lambda args: args["dt_bias"] is not None, + 'HAS_SCALE': lambda args: args['scale'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, + 'USE_LOWER_BOUND': lambda args: args['lower_bound'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BS': BS}, num_warps=num_warps) + for BS in BS_LIST + for num_warps in [2, 4, 8] + ], + key=['H', 'S', 'BT', 'IS_VARLEN', 'REVERSE'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def kda_gate_chunk_cumsum_vector_kernel( + s, + A_log, + dt_bias, + o, + scale, + cu_seqlens, + chunk_indices, + lower_bound, + T, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_LOWER_BOUND: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + + # Apply dt_bias if exists + if HAS_BIAS: + p_b = tl.make_block_ptr(dt_bias + i_h * S, (S,), (1,), (i_s * BS,), (BS,), (0,)) + b_bias = tl.load(p_b, boundary_check=(0,)).to(tl.float32) + b_s = b_s + b_bias[None, :] + + b_A = tl.load(A_log + i_h).to(tl.float32) + if not USE_LOWER_BOUND: + # Apply gate: -exp(A_log) * softplus(g + bias) + b_gate = -exp(b_A) * softplus(b_s) + else: + b_gate = lower_bound * tl.sigmoid(exp(b_A) * b_s) + + # Apply chunk local cumsum + if REVERSE: + b_o = tl.cumsum(b_gate, axis=0, reverse=True) + else: + b_o = tl.cumsum(b_gate, axis=0) + + if HAS_SCALE: + b_o *= scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@input_guard +def kda_gate_chunk_cumsum( + g: paddle.Tensor, + A_log: paddle.Tensor, + chunk_size: int, + scale: float = None, + dt_bias: paddle.Tensor | None = None, + cu_seqlens: paddle.Tensor | None = None, + output_dtype: paddle.dtype | None = paddle.float32, + chunk_indices: paddle.Tensor | None = None, + lower_bound: float | None = None, + **kwargs, +) -> paddle.Tensor: + if cu_seqlens is not None: + assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + assert len(g.shape) == 4 + B, T, H, S = g.shape + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2" + + g_org, g = g, paddle.empty_like(g).cast(output_dtype or g.dtype) + def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H) + kda_gate_chunk_cumsum_vector_kernel[grid]( + s=g_org, + A_log=A_log, + dt_bias=dt_bias, + o=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + lower_bound=lower_bound, + T=T, + H=H, + S=S, + BT=BT, + 
REVERSE=False, + ) + return g diff --git a/flashmask/flash_mask/linear_attn/ops/kda/naive.py b/flashmask/flash_mask/linear_attn/ops/kda/naive.py new file mode 100644 index 00000000000..d089f87885d --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/kda/naive.py @@ -0,0 +1,105 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +import paddle +from einops import rearrange + + +def naive_recurrent_kda( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + scale: float | None = None, + initial_state: paddle.Tensor | None = None, + output_final_state: bool = False, +): + dtype = v.dtype + B, T, H, K, V = *q.shape, v.shape[-1] + if scale is None: + scale = K ** -0.5 + + q, k, v, g, beta = [x.cast(paddle.float32) for x in [q, k, v, g, beta]] + q = q * scale + + S = paddle.zeros([B, H, K, V], dtype=q.dtype) + if initial_state is not None: + S += initial_state + o = paddle.zeros_like(v) + for i in range(0, T): + q_i, k_i, v_i, g_i, b_i = q[:, i], k[:, i], v[:, i], g[:, i], beta[:, i] + S = S * g_i[..., None].exp() + S = S + paddle.einsum('b h k, b h v -> b h k v', b_i[..., None] * k_i, v_i - (k_i[..., None] * S).sum(-2)) + o[:, i] = paddle.einsum('b h k, b h k v -> b h v', q_i, S) + if not output_final_state: + S = None + return o.cast(dtype), S + + +def naive_chunk_kda( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + g: paddle.Tensor, + beta: paddle.Tensor, + scale: float | None = None, + initial_state: paddle.Tensor | None = None, + output_final_state: bool = False, + chunk_size: int = 64, +): + dtype = v.dtype + B, T, H, K, V = *q.shape, v.shape[-1] + BT = chunk_size + NT = T // BT + if scale is None: + scale = K ** -0.5 + assert T % BT == 0 + + q, k, v, g, beta = [rearrange(x, 'b (n c) h ... -> b h n c ...', c=BT).cast(paddle.float32) for x in [q, k, v, g, beta]] + q = q * scale + g = g.cumsum(-2) + + # note that diagonal is masked. + mask = paddle.triu(paddle.ones([BT, BT], dtype='bool'), diagonal=0) + + A = paddle.zeros([*q.shape[:-1], BT], dtype=paddle.float32) + for i in range(BT): + k_i = k[..., i, :] + g_i = g[..., i:i+1, :] + A[..., i] = paddle.einsum('... c d, ... d -> ... c', k * (g - g_i).exp(), k_i) + A = A * beta[..., None] + + A = -(A * (~mask).cast(A.dtype)) + for i in range(1, BT): + A[..., i, :i] = A[..., i, :i].clone() + (A[..., i, :, None].clone() * A[..., :, :i].clone()).sum(-2) + A = (A + paddle.eye(BT, dtype=paddle.float32)) * beta[..., None, :] + + w = A @ (g.exp() * k) + u = A @ v + + S = paddle.zeros([B, H, K, V], dtype=q.dtype) + if initial_state is not None: + S += initial_state + o = paddle.zeros_like(v) + mask = paddle.triu(paddle.ones([BT, BT], dtype='bool'), diagonal=1) + for i in range(0, NT): + # [B, H, BT, ...] + q_i, k_i, u_i, g_i, w_i = q[:, :, i], k[:, :, i], u[:, :, i], g[:, :, i], w[:, :, i] + A = paddle.zeros([B, H, BT, BT], dtype=paddle.float32) + for j in range(BT): + k_j = k[:, :, i, j] + g_j = g[:, :, i, j:j+1, :] + A[..., j] = paddle.einsum('... c d, ... d -> ... 
c', q_i * (g_i - g_j).exp(), k_j) + A = A * (~mask).cast(A.dtype) + v_i = u_i - w_i @ S + o[:, :, i] = (q_i * g_i.exp()) @ S + A @ v_i + S = S * rearrange(g_i[:, :, -1].exp(), 'b h k -> b h k 1') + S += rearrange((g_i[:, :, -1:] - g_i).exp() * k_i, 'b h c k -> b h k c') @ v_i + if not output_final_state: + S = None + return rearrange(o, 'b h n c d -> b (n c) h d').cast(dtype), S diff --git a/flashmask/flash_mask/linear_attn/ops/kda/wy_fast.py b/flashmask/flash_mask/linear_attn/ops/kda/wy_fast.py new file mode 100644 index 00000000000..34d527b6ca4 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/kda/wy_fast.py @@ -0,0 +1,337 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +from functools import lru_cache + +import paddle +import triton +import triton.language as tl + +from flash_mask.linear_attn.ops.utils import prepare_chunk_indices +from flash_mask.linear_attn.ops.utils.op import exp2 +from flash_mask.linear_attn.utils import autotune_cache_kwargs, check_shared_mem +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'STORE_QG': lambda args: args['qg'] is not None, + 'STORE_KG': lambda args: args['kg'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def recompute_w_u_fwd_kda_kernel( + q, + k, + qg, + kg, + v, + beta, + w, + u, + A, + gk, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + STORE_QG: tl.constexpr, + STORE_KG: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_b = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_b = tl.load(p_b, boundary_check=(0,)) + + p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_b[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = b_k * 
b_b[:, None] + + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)).to(tl.float32) + b_kb *= exp2(b_gk) + if STORE_QG: + p_q = tl.make_block_ptr(q + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_qg = tl.make_block_ptr(qg + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_qg = b_q * exp2(b_gk) + tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty), boundary_check=(0, 1)) + if STORE_KG: + last_idx = min(i_t * BT + BT, T) - 1 + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + b_gn = tl.load(gk + ((bos + last_idx) * H + i_h) * K + o_k, mask=m_k, other=0.).to(tl.float32) + b_kg = b_k * tl.where((i_t * BT + tl.arange(0, BT) < T)[:, None], exp2(b_gn[None, :] - b_gk), 0) + p_kg = tl.make_block_ptr(kg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty), boundary_check=(0, 1)) + + b_w = tl.dot(b_A, b_kb.to(b_k.dtype)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def prepare_wy_repr_bwd_kda_kernel( + k, + v, + beta, + gk, + A, + dA, + dw, + du, + dk, + dk2, + dv, + db, + dg, + dg2, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_b = tl.make_block_ptr(beta + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_db = tl.make_block_ptr(db + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + + b_b = tl.load(p_b, boundary_check=(0,)) + b_db = tl.zeros([BT], dtype=tl.float32) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk2 = tl.make_block_ptr(dk2 + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg2 = tl.make_block_ptr(dg2 + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 
0)) + b_gk_exp = exp2(tl.load(p_gk, boundary_check=(0, 1))) + b_kbg = b_k * b_b[:, None] * b_gk_exp + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + + b_dA += tl.dot(b_dw, tl.trans(b_kbg).to(b_dw.dtype)) + b_dkbg = tl.dot(b_A, b_dw) + b_dk = b_dkbg * b_gk_exp * b_b[:, None] + tl.load(p_dk, boundary_check=(0, 1)) + b_db += tl.sum(b_dkbg * b_k * b_gk_exp, 1) + b_dg = b_kbg * b_dkbg + tl.load(p_dg, boundary_check=(0, 1)) + + tl.store(p_dk2, b_dk.to(p_dk2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg2, b_dg.to(p_dg2.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_b[:, None]).to(b_v.dtype) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA += tl.dot(b_du, tl.trans(b_vb)) + b_dvb = tl.dot(b_A, b_du) + b_dv = b_dvb * b_b[:, None] + b_db += tl.sum(b_dvb * b_v, 1) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_dA = tl.where(m_A, b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), b_A) + b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) + + b_dA = tl.where(m_A, -b_dA, 0) + + # if using gk, save dA first and handle dk in another kernel + p_dA = tl.make_block_ptr(dA + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) + + +def recompute_w_u_fwd( + k: paddle.Tensor, + v: paddle.Tensor, + beta: paddle.Tensor, + A: paddle.Tensor, + q: paddle.Tensor | None = None, + gk: paddle.Tensor | None = None, + cu_seqlens: paddle.Tensor | None = None, + chunk_indices: paddle.Tensor | None = None, +) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor | None, paddle.Tensor | None]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = A.shape[-1] + BK = 64 + BV = 64 + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + w = paddle.empty_like(k) + u = paddle.empty_like(v) + qg = paddle.empty_like(q) if q is not None else None + kg = paddle.empty_like(k) if gk is not None else None + recompute_w_u_fwd_kda_kernel[(NT, B*H)]( + q=q, + k=k, + qg=qg, + kg=kg, + v=v, + beta=beta, + w=w, + u=u, + A=A, + gk=gk, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u, qg, kg + + +def prepare_wy_repr_bwd( + k: paddle.Tensor, + v: paddle.Tensor, + beta: paddle.Tensor, + gk: paddle.Tensor, + A: paddle.Tensor, + dk: paddle.Tensor, + dw: paddle.Tensor, + du: paddle.Tensor, + dg: paddle.Tensor, + cu_seqlens: paddle.Tensor | None = None, + chunk_indices: paddle.Tensor | None = None, +) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = 64 + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + if cu_seqlens is None: + BK, BV, NT = _wy_launch_meta(k.place.gpu_device_id(), T, K, V, BT) + else: + 
const_tiling = _wy_tiling(k.place.gpu_device_id()) + BK = min(max(triton.next_power_of_2(K), 16), const_tiling) + BV = min(max(triton.next_power_of_2(V), 16), const_tiling) + NT = len(chunk_indices) + + dk2 = paddle.empty_like(dk, dtype=paddle.float32) + dv = paddle.empty_like(v) + dg2 = paddle.empty_like(gk, dtype=paddle.float32) + dA = paddle.empty_like(A, dtype=paddle.float32) + db = paddle.empty_like(beta, dtype=paddle.float32) + prepare_wy_repr_bwd_kda_kernel[(NT, B * H)]( + k=k, + v=v, + beta=beta, + gk=gk, + A=A, + dA=dA, + dw=dw, + du=du, + dk=dk, + dk2=dk2, + dv=dv, + db=db, + dg=dg, + dg2=dg2, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + dk = dk2 + dg = dg2 + return dk, dv, db, dg, dA +@lru_cache(maxsize=None) +def _wy_tiling(device_idx: int) -> int: + return 64 if check_shared_mem() else 32 + + +@lru_cache(maxsize=None) +def _wy_launch_meta(device_idx: int, T: int, K: int, V: int, BT: int) -> tuple[int, int, int]: + const_tiling = _wy_tiling(device_idx) + BK = min(max(triton.next_power_of_2(K), 16), const_tiling) + BV = min(max(triton.next_power_of_2(V), 16), const_tiling) + NT = triton.cdiv(T, BT) + return BK, BV, NT + diff --git a/flashmask/flash_mask/linear_attn/ops/utils/__init__.py b/flashmask/flash_mask/linear_attn/ops/utils/__init__.py new file mode 100644 index 00000000000..38ba5644bca --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/utils/__init__.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# Adapted from fla/ops/utils/__init__.py for PaddlePaddle + +from .cumsum import ( + chunk_global_cumsum, + chunk_global_cumsum_scalar, + chunk_global_cumsum_vector, + chunk_local_cumsum, + chunk_local_cumsum_scalar, + chunk_local_cumsum_vector, +) +from .index import ( + get_max_num_splits, + prepare_chunk_indices, + prepare_chunk_offsets, + prepare_cu_seqlens_from_lens, + prepare_cu_seqlens_from_mask, + prepare_lens, + prepare_lens_from_mask, + prepare_position_ids, + prepare_sequence_ids, + prepare_token_indices, +) +from .softmax import softmax_bwd, softmax_fwd +from .softplus import softplus +from .solve_tril import solve_tril diff --git a/flashmask/flash_mask/linear_attn/ops/utils/constant.py b/flashmask/flash_mask/linear_attn/ops/utils/constant.py new file mode 100644 index 00000000000..dac7fbea17a --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/utils/constant.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +# Same as fla/ops/utils/constant.py - pure constants, no framework dependency + +RCP_LN2 = 1.4426950216 diff --git a/flashmask/flash_mask/linear_attn/ops/utils/cumsum.py b/flashmask/flash_mask/linear_attn/ops/utils/cumsum.py new file mode 100644 index 00000000000..4e666a91f0c --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/utils/cumsum.py @@ -0,0 +1,477 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +import paddle +import triton +import triton.language as tl + +from flash_mask.linear_attn.ops.utils.index import prepare_chunk_indices +from flash_mask.linear_attn.utils import autotune_cache_kwargs, check_shared_mem, input_guard +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + +BS_LIST = [32, 64] if check_shared_mem() else [16, 32] + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'HAS_SCALE': lambda args: args['scale'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8] + ], + key=['B', 'H', 'BT', 'IS_VARLEN', 'REVERSE'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def chunk_local_cumsum_scalar_kernel( + s, + o, + scale, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + # [BT] + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + if REVERSE: + b_z = tl.sum(b_s, axis=0) + b_o = -b_o + b_z[None] + b_s + if HAS_SCALE: + b_o *= scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'HAS_SCALE': lambda args: args['scale'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BS': BS}, num_warps=num_warps) + for BS in BS_LIST + for num_warps in [2, 4, 8] + ], + key=['B', 'H', 'S', 'BT', 'IS_VARLEN', 'REVERSE'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def chunk_local_cumsum_vector_kernel( + s, + o, + scale, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + else: + p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, 
S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + if REVERSE: + b_o = tl.cumsum(b_s, axis=0, reverse=True) + else: + b_o = tl.cumsum(b_s, axis=0) + if HAS_SCALE: + b_o *= scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'HAS_SCALE': lambda args: args['scale'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BT': BT}, num_warps=num_warps, num_stages=num_stages) + for BT in [32, 64, 128, 256] + for num_warps in [2, 4, 8] + for num_stages in [1, 2, 3, 4] + ], + key=['B', 'H', 'IS_VARLEN', 'REVERSE'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def chunk_global_cumsum_scalar_kernel( + s, + o, + scale, + cu_seqlens, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_nh = tl.program_id(0) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + T = eos - bos + + b_z = tl.zeros([], dtype=tl.float32) + NT = tl.cdiv(T, BT) + for i_c in range(NT): + i_t = NT - 1 - i_c if REVERSE else i_c + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + b_ss = tl.sum(b_s, 0) + if REVERSE: + b_o = -b_o + b_ss + b_s + b_o += b_z + if i_c >= 0: + b_z += b_ss + if HAS_SCALE: + b_o *= scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'HAS_SCALE': lambda args: args['scale'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BT': BT}, num_warps=num_warps, num_stages=num_stages) + for BT in [16, 32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [1, 2, 3, 4] + ], + key=['B', 'H', 'S', 'IS_VARLEN', 'REVERSE'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def chunk_global_cumsum_vector_kernel( + s, + o, + scale, + cu_seqlens, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + T = eos - bos + + b_z = tl.zeros([BS], dtype=tl.float32) + NT = tl.cdiv(T, BT) + for i_c in range(NT): + i_t = NT - 1 - i_c if REVERSE else i_c + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), 
(1, 0)) + else: + p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + if REVERSE: + b_c = b_z[None, :] + tl.cumsum(b_s, axis=0, reverse=True) + else: + b_c = b_z[None, :] + tl.cumsum(b_s, axis=0) + if HAS_SCALE: + b_c *= scale + tl.store(p_o, b_c.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + b_z += tl.sum(b_s, 0) + + +def chunk_local_cumsum_scalar( + g: paddle.Tensor, + chunk_size: int, + reverse: bool = False, + scale: float = None, + cu_seqlens: paddle.Tensor | None = None, + head_first: bool = False, + output_dtype=paddle.float32, + chunk_indices: paddle.Tensor | None = None, +) -> paddle.Tensor: + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2" + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + g_org, g = g, paddle.empty_like(g, dtype=output_dtype or g.dtype) + grid = (NT, B * H) + chunk_local_cumsum_scalar_kernel[grid]( + s=g_org, + o=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + B=B, + H=H, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return g + + +def chunk_local_cumsum_vector( + g: paddle.Tensor, + chunk_size: int, + reverse: bool = False, + scale: float = None, + cu_seqlens: paddle.Tensor | None = None, + head_first: bool = False, + output_dtype=paddle.float32, + chunk_indices: paddle.Tensor | None = None, +) -> paddle.Tensor: + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2" + + g_org, g = g, paddle.empty_like(g, dtype=output_dtype or g.dtype) + def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H) + # keep cummulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_local_cumsum_vector_kernel[grid]( + s=g_org, + o=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + B=B, + H=H, + S=S, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return g + + +@input_guard +def chunk_global_cumsum_scalar( + s: paddle.Tensor, + reverse: bool = False, + cu_seqlens: paddle.Tensor | None = None, + scale: float = None, + head_first: bool = False, + output_dtype=paddle.float32, +) -> paddle.Tensor: + if head_first: + B, H, T = s.shape + else: + B, T, H = s.shape + N = len(cu_seqlens) - 1 if cu_seqlens is not None else B + + z = paddle.empty_like(s, dtype=output_dtype or s.dtype) + grid = (N * H,) + chunk_global_cumsum_scalar_kernel[grid]( + s=s, + o=z, + scale=scale, + cu_seqlens=cu_seqlens, + T=T, + B=B, + H=H, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return z + + +@input_guard +def chunk_global_cumsum_vector( + s: paddle.Tensor, + reverse: bool = False, + cu_seqlens: paddle.Tensor | None = None, + scale: float = None, + head_first: bool = False, + output_dtype=paddle.float32, +) -> paddle.Tensor: + if head_first: + 
B, H, T, S = s.shape + else: + B, T, H, S = s.shape + N = len(cu_seqlens) - 1 if cu_seqlens is not None else B + BS = min(32, triton.next_power_of_2(S)) + + z = paddle.empty_like(s, dtype=output_dtype or s.dtype) + grid = (triton.cdiv(S, BS), N * H) + chunk_global_cumsum_vector_kernel[grid]( + s=s, + o=z, + scale=scale, + cu_seqlens=cu_seqlens, + T=T, + B=B, + H=H, + S=S, + BS=BS, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return z + + +@input_guard +def chunk_global_cumsum( + s: paddle.Tensor, + reverse: bool = False, + cu_seqlens: paddle.Tensor | None = None, + scale: float = None, + head_first: bool = False, + output_dtype=paddle.float32, +) -> paddle.Tensor: + if cu_seqlens is not None: + assert s.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(s.shape) == 3: + return chunk_global_cumsum_scalar( + s=s, + reverse=reverse, + cu_seqlens=cu_seqlens, + scale=scale, + head_first=head_first, + output_dtype=output_dtype, + ) + elif len(s.shape) == 4: + return chunk_global_cumsum_vector( + s=s, + reverse=reverse, + cu_seqlens=cu_seqlens, + scale=scale, + head_first=head_first, + output_dtype=output_dtype, + ) + else: + raise ValueError( + f"Unsupported input shape {s.shape}, " + f"which should be [B, T, H]/[B, T, H, D] if `head_first=False` " + f"or [B, H, T]/[B, H, T, D] otherwise", + ) + + +@input_guard +def chunk_local_cumsum( + g: paddle.Tensor, + chunk_size: int, + reverse: bool = False, + scale: float = None, + cu_seqlens: paddle.Tensor | None = None, + head_first: bool = False, + output_dtype=paddle.float32, + chunk_indices: paddle.Tensor | None = None, + **kwargs, +) -> paddle.Tensor: + if cu_seqlens is not None: + assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(g.shape) == 3: + return chunk_local_cumsum_scalar( + g=g, + chunk_size=chunk_size, + reverse=reverse, + scale=scale, + cu_seqlens=cu_seqlens, + head_first=head_first, + output_dtype=output_dtype, + chunk_indices=chunk_indices, + ) + elif len(g.shape) == 4: + return chunk_local_cumsum_vector( + g=g, + chunk_size=chunk_size, + reverse=reverse, + scale=scale, + cu_seqlens=cu_seqlens, + head_first=head_first, + output_dtype=output_dtype, + chunk_indices=chunk_indices, + ) + else: + raise ValueError( + f"Unsupported input shape {g.shape}, " + f"which should be (B, T, H, D) if `head_first=False` " + f"or (B, H, T, D) otherwise", + ) diff --git a/flashmask/flash_mask/linear_attn/ops/utils/index.py b/flashmask/flash_mask/linear_attn/ops/utils/index.py new file mode 100644 index 00000000000..589150af38b --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/utils/index.py @@ -0,0 +1,145 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +import paddle +import paddle.nn.functional as F +import triton +import triton.language as tl + +from flash_mask.linear_attn.utils import autotune_cache_kwargs, tensor_cache +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + + +@enable_compat_on_triton_kernel +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [4, 8, 16, 32] + ], + key=['B'], + **autotune_cache_kwargs, +) +@triton.jit +def prepare_position_ids_kernel( + y, + cu_seqlens, + B: tl.constexpr, +): + i_n = tl.program_id(0) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + + o = tl.arange(0, B) + for i in range(0, tl.cdiv(T, B) * B, B): + o_i = o + i + tl.store(y + bos + o_i, o_i, o_i < T) + + +@tensor_cache +def prepare_lens(cu_seqlens: paddle.Tensor) -> paddle.Tensor: + return paddle.diff(cu_seqlens) + + +@tensor_cache +def prepare_lens_from_mask(mask: paddle.Tensor) -> paddle.Tensor: + return mask.sum(axis=-1).cast(paddle.int32) + + +@tensor_cache +def prepare_cu_seqlens_from_lens( + lens: paddle.Tensor, + dtype=paddle.int32, +) -> paddle.Tensor: + return F.pad(lens.cumsum(axis=0).cast(dtype), (1, 0)) + + +@tensor_cache +def prepare_cu_seqlens_from_mask( + mask: paddle.Tensor, + dtype=paddle.int32, +) -> paddle.Tensor: + return prepare_cu_seqlens_from_lens(prepare_lens_from_mask(mask), dtype) + + +@tensor_cache +def prepare_split_cu_seqlens( + batch_size: int, + seq_len: int, + split_size: int, + cu_seqlens: paddle.Tensor | None = None, + dtype=paddle.int32, +) -> paddle.Tensor: + if cu_seqlens is None: + total_tokens = batch_size * seq_len + cu_seqlens = list(range(0, total_tokens, seq_len)) + [total_tokens] + else: + cu_seqlens = cu_seqlens.tolist() + return paddle.to_tensor( + [ + i + for bos, eos in zip(cu_seqlens[:-1], cu_seqlens[1:], strict=False) + for i in range(bos, eos, split_size) + ] + [cu_seqlens[-1]], + dtype=dtype, + ) + + +@tensor_cache +def prepare_position_ids(cu_seqlens: paddle.Tensor, cu_seqlens_cpu: paddle.Tensor | None = None) -> paddle.Tensor: + if cu_seqlens_cpu is not None: + return paddle.concat([ + paddle.arange(n, dtype=cu_seqlens.dtype) + for n in prepare_lens(cu_seqlens_cpu).unbind() + ]) + return paddle.concat([ + paddle.arange(n, dtype=cu_seqlens.dtype) + for n in prepare_lens(cu_seqlens).unbind() + ]) + + +@tensor_cache +def prepare_sequence_ids(cu_seqlens: paddle.Tensor, cu_seqlens_cpu: paddle.Tensor | None = None) -> paddle.Tensor: + return (prepare_position_ids(cu_seqlens, cu_seqlens_cpu) == 0).cast(paddle.int64).cumsum(axis=0) - 1 + + +@tensor_cache +def prepare_token_indices(cu_seqlens: paddle.Tensor, cu_seqlens_cpu: paddle.Tensor | None = None) -> paddle.Tensor: + position_ids = prepare_position_ids(cu_seqlens, cu_seqlens_cpu) + return paddle.stack([prepare_sequence_ids(cu_seqlens, cu_seqlens_cpu), position_ids], 1).cast(cu_seqlens.dtype) + + +@tensor_cache +def prepare_chunk_indices( + cu_seqlens: paddle.Tensor, + chunk_size: int, + cu_seqlens_cpu: paddle.Tensor | None = None, +) -> paddle.Tensor: + if cu_seqlens_cpu is not None: + indices = paddle.concat([paddle.arange(n) + for n in triton.cdiv(prepare_lens(cu_seqlens_cpu), chunk_size).tolist()]) + return paddle.stack([(indices == 0).cast(paddle.int64).cumsum(axis=0) - 1, indices], 1).cast(cu_seqlens.dtype) + indices = paddle.concat([paddle.arange(n) for n in 
triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()]) + return paddle.stack([(indices == 0).cast(paddle.int64).cumsum(axis=0) - 1, indices], 1).cast(cu_seqlens.dtype) + + +@tensor_cache +def prepare_chunk_offsets( + cu_seqlens: paddle.Tensor, + chunk_size: int, +) -> paddle.Tensor: + return F.pad(triton.cdiv(prepare_lens(cu_seqlens), chunk_size), (1, 0), value=0).cumsum(axis=-1) + + +@tensor_cache +def get_max_num_splits( + cu_seqlens: paddle.Tensor, + chunk_size: int, + cu_seqlens_cpu: paddle.Tensor | None = None +) -> int: + if cu_seqlens_cpu is not None: + return triton.cdiv(int(max(prepare_lens(cu_seqlens_cpu))), chunk_size) + return triton.cdiv(int(max(prepare_lens(cu_seqlens))), chunk_size) diff --git a/flashmask/flash_mask/linear_attn/ops/utils/op.py b/flashmask/flash_mask/linear_attn/ops/utils/op.py new file mode 100644 index 00000000000..edb5471d5e3 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/utils/op.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +import os + +import triton +import triton.language as tl +import triton.language.extra.libdevice as tldevice + +from flash_mask.linear_attn.utils import IS_GATHER_SUPPORTED + +if os.environ.get('FLA_USE_FAST_OPS', '0') == '1': + @triton.jit + def exp(x): return tldevice.fast_expf(x.to(tl.float32)) + @triton.jit + def exp2(x): return tldevice.exp2(x.to(tl.float32)) + @triton.jit + def log(x): return tldevice.fast_logf(x.to(tl.float32)) + @triton.jit + def log2(x): return tldevice.fast_log2f(x.to(tl.float32)) + @triton.jit + def tanh(x): return tldevice.fast_tanhf(x.to(tl.float32)) +else: + @triton.jit + def exp(x): return tl.exp(x.to(tl.float32)) + @triton.jit + def exp2(x): return tl.math.exp2(x.to(tl.float32)) + @triton.jit + def log(x): return tl.log(x.to(tl.float32)) + @triton.jit + def log2(x): return tl.log2(x.to(tl.float32)) + @triton.jit + def tanh(x): return tldevice.tanh(x.to(tl.float32)) + + +if not IS_GATHER_SUPPORTED: + @triton.jit + def gather(src, index, axis, _builder=None): + """ + Gather operation that works when tl.gather is not supported. + This is a fallback implementation that returns None. + Just to make triton compiler happy. + """ + return None +else: + gather = tl.gather + + +if hasattr(triton.language, '_experimental_make_tensor_descriptor'): + # For Triton 3.3.x + make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor +elif hasattr(triton.language, 'make_tensor_descriptor'): + # For Triton 3.4.x and later + make_tensor_descriptor = triton.language.make_tensor_descriptor +else: + """ + Fallback implementation when TMA is not supported. + Returns None to indicate TMA descriptors are unavailable. + Just make triton compiler happy. 
+ """ + @triton.jit + def make_tensor_descriptor( + base, + shape, + strides, + block_shape, + _builder=None, + ): + return None diff --git a/flashmask/flash_mask/linear_attn/ops/utils/softmax.py b/flashmask/flash_mask/linear_attn/ops/utils/softmax.py new file mode 100644 index 00000000000..c80660e828a --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/utils/softmax.py @@ -0,0 +1,113 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +import paddle +import triton +import triton.language as tl + +from flash_mask.linear_attn.ops.utils.op import exp +from flash_mask.linear_attn.utils import IS_AMD, autotune_cache_kwargs +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + +NUM_WARPS_AUTOTUNE = [1, 2, 4, 8, 16] if IS_AMD else [1, 2, 4, 8, 16, 32] + + +@enable_compat_on_triton_kernel +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in NUM_WARPS_AUTOTUNE + ], + key=['D'], + **autotune_cache_kwargs, +) +@triton.jit +def softmax_fwd_kernel( + x, + p, + D: tl.constexpr, + B: tl.constexpr, +): + i_n = tl.program_id(0) + o_d = tl.arange(0, B) + m_d = o_d < D + + b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf')) + b_m = tl.max(b_x, 0) + b_x = exp(b_x - b_m) + b_p = b_x / tl.sum(b_x, 0) + + tl.store(p + i_n * D + o_d, b_p.to(p.dtype.element_ty), mask=m_d) + + +@enable_compat_on_triton_kernel +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in NUM_WARPS_AUTOTUNE + ], + key=['D'], + **autotune_cache_kwargs, +) +@triton.jit +def softmax_bwd_kernel( + p, + dp, + ds, + D: tl.constexpr, + B: tl.constexpr, +): + i_n = tl.program_id(0) + o_d = tl.arange(0, B) + m_d = o_d < D + + b_p = tl.load(p + i_n * D + o_d, mask=m_d, other=0.) + b_dp = tl.load(dp + i_n * D + o_d, mask=m_d, other=0.) + b_pp = tl.sum(b_p * b_dp, 0) + b_ds = b_p * b_dp - b_p * b_pp + tl.store(ds + i_n * D + o_d, b_ds.to(ds.dtype.element_ty), mask=m_d) + + +def softmax_fwd( + x: paddle.Tensor, + dtype=paddle.float32, +) -> paddle.Tensor: + shape = x.shape + x = x.reshape([-1, x.shape[-1]]) + + N, D = x.shape + B = triton.next_power_of_2(D) + + p = paddle.empty_like(x).cast(dtype) + softmax_fwd_kernel[(N,)]( + x=x, + p=p, + D=D, + B=B, + ) + return p.reshape(shape) + + +def softmax_bwd( + p: paddle.Tensor, + dp: paddle.Tensor, + dtype=paddle.float32, +) -> paddle.Tensor: + shape = p.shape + p = p.reshape([-1, p.shape[-1]]) + ds = paddle.empty_like(p).cast(dtype) + + N, D = p.shape + B = triton.next_power_of_2(D) + softmax_bwd_kernel[(N,)]( + p=p, + dp=dp, + ds=ds, + D=D, + B=B, + ) + return ds.reshape(shape) diff --git a/flashmask/flash_mask/linear_attn/ops/utils/softplus.py b/flashmask/flash_mask/linear_attn/ops/utils/softplus.py new file mode 100644 index 00000000000..958cbde1d0c --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/utils/softplus.py @@ -0,0 +1,115 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +# REVISED FROM +# https://github.com/shawntan/stickbreaking-attention/blob/main/stickbreaking_attention/sb_varlen/softplus.py + +import triton +from triton import language as tl + +from flash_mask.linear_attn.utils import IS_NVIDIA + + +def _generate_softplus(num_pack): + template = """ + .reg .pred p; + setp.gt.f32 p, ${in_reg}, 20.; + @p mov.f32 ${out_reg}, ${in_reg}; + @!p mul.f32 ${out_reg}, ${in_reg}, 1.4426950408889634; + @!p ex2.approx.ftz.f32 ${out_reg}, ${out_reg}; + @!p add.f32 ${out_reg}, ${out_reg}, 1.0; + @!p lg2.approx.ftz.f32 ${out_reg}, ${out_reg}; + @!p mul.f32 ${out_reg}, ${out_reg}, 0.6931471805599453; + """ + out_str = "" + + for i in range(num_pack): + inner_str = template.format(out_reg=i, in_reg=i + num_pack) + out_str += "{" + inner_str + "}\n" + # flatten out because torch.compile doesn't like newlines + out_str = " ".join(out_str.split("\n")) + return out_str + + +def _generate_softplus2(num_pack): + template = """ + .reg .pred p; + setp.gt.f32 p, ${in_reg}, 15.; + @p mov.f32 ${out_reg}, ${in_reg}; + @!p ex2.approx.ftz.f32 ${out_reg}, ${in_reg}; + @!p add.f32 ${out_reg}, ${out_reg}, 1.0; + @!p lg2.approx.ftz.f32 ${out_reg}, ${out_reg}; + """ + out_str = "" + + for i in range(num_pack): + inner_str = template.format(out_reg=i, in_reg=i + num_pack) + out_str += "{" + inner_str + "}\n" + # flatten out because torch.compile doesn't like newlines + out_str = " ".join(out_str.split("\n")) + return out_str + + +def _generate_constraints(num_pack): + return ",".join("=r" for i in range(num_pack)) + "," + ",".join("r" for i in range(num_pack)) + + +_NUM_REG = 1 +s_softplus: tl.constexpr = tl.constexpr(_generate_softplus(_NUM_REG)) +s_softplus2: tl.constexpr = tl.constexpr(_generate_softplus2(_NUM_REG)) +s_constraints: tl.constexpr = tl.constexpr(_generate_constraints(_NUM_REG)) +NUM_REG: tl.constexpr = tl.constexpr(_NUM_REG) + + +@triton.jit +def softplus_nv(x): + # equivalent to: + # return tl.where(x < 20.0, tl.math.log(1 + tl.math.exp(x)), x) + return tl.inline_asm_elementwise( + asm=s_softplus, + constraints=s_constraints, + pack=NUM_REG, + args=[ + x, + ], + dtype=tl.float32, + is_pure=True, + ) + + +@triton.jit +def softplus_triton(x): + return tl.where(x < 20.0, tl.math.log(1 + tl.math.exp(x)), x) + + +@triton.jit +def softplus2_nv(x): + # equivalent to: + # return tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x) + return tl.inline_asm_elementwise( + asm=s_softplus2, + constraints=s_constraints, + pack=NUM_REG, + args=[ + x, + ], + dtype=tl.float32, + is_pure=True, + ) + + +@triton.jit +def softplus2_triton(x): + return tl.where(x < 15.0, tl.math.log2(1 + tl.math.exp2(x)), x) + + +if IS_NVIDIA: + softplus = softplus_nv + softplus2 = softplus2_nv +else: + softplus = softplus_triton + softplus2 = softplus2_triton diff --git a/flashmask/flash_mask/linear_attn/ops/utils/solve_tril.py b/flashmask/flash_mask/linear_attn/ops/utils/solve_tril.py new file mode 100644 index 00000000000..904a9920fd5 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/ops/utils/solve_tril.py @@ -0,0 +1,401 @@ +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+# For a list of all contributors, visit: +# https://github.com/fla-org/flash-linear-attention/graphs/contributors + +import os + +import paddle +import triton +import triton.language as tl + +from flash_mask.linear_attn.ops.utils.index import prepare_chunk_indices +from flash_mask.linear_attn.ops.utils.op import make_tensor_descriptor +from flash_mask.linear_attn.utils import IS_TMA_SUPPORTED, autotune_cache_kwargs, input_guard +from flash_mask.linear_attn.triton_utils import enable_compat_on_triton_kernel + +FLA_TRIL_PRECISION = os.environ.get('FLA_TRIL_PRECISION', 'ieee') +assert FLA_TRIL_PRECISION in ['ieee', 'tf32', 'tf32x3'], \ + f"FLA_TRIL_PRECISION must be one of 'ieee', 'tf32', or 'tf32x3', but got {FLA_TRIL_PRECISION}" +DOT_PRECISION_AUTOTUNE_LIST = ["ieee"] if not IS_TMA_SUPPORTED else list({"ieee", FLA_TRIL_PRECISION}) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'DOT_PRECISION': 'ieee'}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4, 5] + ], + key=['BT'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def solve_tril_16x16_kernel( + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + + A = A + (bos*H + i_h) * BT + Ai = Ai + (bos*H + i_h) * 16 + + offset = (i_t * 16) % BT + if not USE_TMA: + p_A = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * 16, offset), (16, 16), (1, 0)) + # [16, 16] + b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) + b_A = tl.where(m_A, b_A, 0) + else: + desc = make_tensor_descriptor(A, [T, BT], [H*BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, 16], [H*16, 1], [16, 16]) + b_A = desc.load([i_t * 16, offset]).to(tl.float32) + b_A = tl.where(m_A, b_A, 0) + b_A = -b_A + + for i in range(2, min(16, T - i_t * 16)): + # [16] + b_a = -tl.load(A + (i_t * 16 + i) * H*BT + o_i + offset) + b_a = tl.where(o_i < i, b_a, 0.) 
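+        # forward substitution: combine row i of -A with the already-inverted rows above it; the identity diagonal is added back after the loop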
+ b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) + b_A = tl.where((o_i == i)[:, None], b_a, b_A) + b_A += m_I + if not USE_TMA: + p_Ai = tl.make_block_ptr(Ai, (T, 16), (H*16, 1), (i_t * 16, 0), (16, 16), (1, 0)) + tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + else: + desc_o.store([i_t * 16, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne")) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'DOT_PRECISION': DOT_PRECISION}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4, 5] + for DOT_PRECISION in DOT_PRECISION_AUTOTUNE_LIST + ], + key=['H', 'BT', 'IS_VARLEN'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def merge_16x16_to_32x32_inverse_kernel( + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + A += (bos * H + i_h) * BT + Ai += (bos * H + i_h) * BT + + if not USE_TMA: + p_A_11 = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_A_22 = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32) + b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H*BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, BT], [H*BT, 1], [16, 16]) + b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32) + b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32) + + # [16, 16] + b_Ai_11 = -tl.where(m_A, b_Ai_11, 0) + b_Ai_22 = -tl.where(m_A, b_Ai_22, 0) + + for i in range(2, min(16, T - i_t * BT)): + b_a_11 = -tl.load(A + (i_t * BT + i) * H*BT + o_i) + b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0) + b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11) + for i in range(16 + 2, min(32, T - i_t * BT)): + b_a_22 = -tl.load(A + (i_t * BT + i) * H*BT + o_i + 16) + b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0) + b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22) + + b_Ai_11 += m_I + b_Ai_22 += m_I + + if not USE_TMA: + p_A_21 = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + else: + b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32) + + b_Ai_21 = -tl.dot(tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), b_Ai_11, input_precision=DOT_PRECISION) + + if not USE_TMA: + p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H*BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H*BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H*BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + tl.store(p_Ai_11, b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_22, b_Ai_22.to(p_Ai_22.dtype.element_ty, 
fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_21, b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + else: + desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")) + + +@enable_compat_on_triton_kernel +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'DOT_PRECISION': DOT_PRECISION}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4, 5] + for DOT_PRECISION in DOT_PRECISION_AUTOTUNE_LIST + ], + key=['H', 'BT', 'IS_VARLEN'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def merge_16x16_to_64x64_inverse_kernel( + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + A += (bos * H + i_h) * BT + Ai += (bos * H + i_h) * BT + + if not USE_TMA: + p_A_11 = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_A_22 = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + p_A_33 = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)) + p_A_44 = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)) + b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32) + b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32) + b_Ai_33 = tl.load(p_A_33, boundary_check=(0, 1)).to(tl.float32) + b_Ai_44 = tl.load(p_A_44, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H*BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, BT], [H*BT, 1], [16, 16]) + b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32) + b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32) + b_Ai_33 = desc.load([i_t * BT + 32, 32]).to(tl.float32) + b_Ai_44 = desc.load([i_t * BT + 48, 48]).to(tl.float32) + + # [16, 16] + b_Ai_11 = -tl.where(m_A, b_Ai_11, 0) + b_Ai_22 = -tl.where(m_A, b_Ai_22, 0) + b_Ai_33 = -tl.where(m_A, b_Ai_33, 0) + b_Ai_44 = -tl.where(m_A, b_Ai_44, 0) + + for i in range(2, min(16, T - i_t * BT)): + b_a_11 = -tl.load(A + (i_t * BT + i) * H*BT + o_i) + b_a_11 = tl.where(o_i < i, b_a_11, 0.) + b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0) + b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11) + for i in range(16 + 2, min(32, T - i_t * BT)): + b_a_22 = -tl.load(A + (i_t * BT + i) * H*BT + o_i + 16) + b_a_22 = tl.where(o_i < i - 16, b_a_22, 0.) + b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0) + b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22) + for i in range(32 + 2, min(48, T - i_t * BT)): + b_a_33 = -tl.load(A + (i_t * BT + i) * H*BT + o_i + 32) + b_a_33 = tl.where(o_i < i - 32, b_a_33, 0.) 
+ b_a_33 += tl.sum(b_a_33[:, None] * b_Ai_33, 0) + b_Ai_33 = tl.where((o_i == i - 32)[:, None], b_a_33, b_Ai_33) + for i in range(48 + 2, min(64, T - i_t * BT)): + b_a_44 = -tl.load(A + (i_t * BT + i) * H*BT + o_i + 48) + b_a_44 = tl.where(o_i < i - 48, b_a_44, 0.) + b_a_44 += tl.sum(b_a_44[:, None] * b_Ai_44, 0) + b_Ai_44 = tl.where((o_i == i - 48)[:, None], b_a_44, b_Ai_44) + b_Ai_11 += m_I + b_Ai_22 += m_I + b_Ai_33 += m_I + b_Ai_44 += m_I + + if not USE_TMA: + p_A_21 = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_A_31 = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)) + p_A_32 = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)) + p_A_41 = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)) + p_A_42 = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)) + p_A_43 = tl.make_block_ptr(A, (T, BT), (H*BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)) + b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + b_A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32) + b_A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32) + b_A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32) + b_A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32) + b_A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32) + else: + b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32) + b_A_31 = desc.load([i_t * BT + 32, 0]).to(tl.float32) + b_A_32 = desc.load([i_t * BT + 32, 16]).to(tl.float32) + b_A_41 = desc.load([i_t * BT + 48, 0]).to(tl.float32) + b_A_42 = desc.load([i_t * BT + 48, 16]).to(tl.float32) + b_A_43 = desc.load([i_t * BT + 48, 32]).to(tl.float32) + + b_Ai_21 = -tl.dot(tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), b_Ai_11, input_precision=DOT_PRECISION) + b_Ai_32 = -tl.dot(tl.dot(b_Ai_33, b_A_32, input_precision=DOT_PRECISION), b_Ai_22, input_precision=DOT_PRECISION) + b_Ai_43 = -tl.dot(tl.dot(b_Ai_44, b_A_43, input_precision=DOT_PRECISION), b_Ai_33, input_precision=DOT_PRECISION) + + b_Ai_31 = -tl.dot( + b_Ai_33, + tl.dot(b_A_31, b_Ai_11, input_precision=DOT_PRECISION) + + tl.dot(b_A_32, b_Ai_21, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + b_Ai_42 = -tl.dot( + b_Ai_44, + tl.dot(b_A_42, b_Ai_22, input_precision=DOT_PRECISION) + + tl.dot(b_A_43, b_Ai_32, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + b_Ai_41 = -tl.dot( + b_Ai_44, + tl.dot(b_A_41, b_Ai_11, input_precision=DOT_PRECISION) + + tl.dot(b_A_42, b_Ai_21, input_precision=DOT_PRECISION) + + tl.dot(b_A_43, b_Ai_31, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + + if not USE_TMA: + p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H*BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H*BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + p_Ai_33 = tl.make_block_ptr(Ai, (T, BT), (H*BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)) + p_Ai_44 = tl.make_block_ptr(Ai, (T, BT), (H*BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H*BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_Ai_31 = tl.make_block_ptr(Ai, (T, BT), (H*BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)) + p_Ai_32 = tl.make_block_ptr(Ai, (T, BT), (H*BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)) + p_Ai_41 = tl.make_block_ptr(Ai, (T, BT), (H*BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)) + p_Ai_42 = tl.make_block_ptr(Ai, (T, BT), (H*BT, 1), (i_t * BT + 48, 16), (16, 
16), (1, 0)) + p_Ai_43 = tl.make_block_ptr(Ai, (T, BT), (H*BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)) + tl.store(p_Ai_11, b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_22, b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_33, b_Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_44, b_Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_21, b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_31, b_Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_32, b_Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_41, b_Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_42, b_Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_43, b_Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + else: + desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 32, 32], b_Ai_33.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 48], b_Ai_44.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 32, 0], b_Ai_31.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 32, 16], b_Ai_32.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 0], b_Ai_41.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 16], b_Ai_42.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 32], b_Ai_43.to(desc_o.dtype, fp_downcast_rounding="rtne")) + + +@input_guard +def solve_tril( + A: paddle.Tensor, + cu_seqlens: paddle.Tensor | None = None, + chunk_indices: paddle.Tensor | None = None, + output_dtype=paddle.float32, +) -> paddle.Tensor: + """ + Compute the inverse of the matrix I + A + A should be strictly lower triangular, i.e., A.triu() == 0. + + Args: + A (paddle.Tensor): + [B, T, H, BT], where BT should only be 16, 32, or 64. + cu_seqlens (paddle.Tensor): + The cumulative sequence lengths of the input tensor. Default: `None`. + output_dtype: + The dtype of the output tensor. Default: `paddle.float32`. + If `None`, the output dtype will be the same as the input dtype. 
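+        chunk_indices (paddle.Tensor): + Precomputed chunk indices for variable-length inputs; if `None` and `cu_seqlens` is provided, they are derived via `prepare_chunk_indices`. Default: `None`.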
+ + Returns: + (I + A)^-1 with the same shape as A + """ + assert A.shape[-1] in [16, 32, 64] + output_dtype = A.dtype if output_dtype is None else output_dtype + + B, T, H, BT = A.shape + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) + + Ai = paddle.zeros_like(A).cast(output_dtype) + if BT == 16: + merge_fn = solve_tril_16x16_kernel + elif BT == 32: + merge_fn = merge_16x16_to_32x32_inverse_kernel + elif BT == 64: + merge_fn = merge_16x16_to_64x64_inverse_kernel + + merge_fn[NT, B * H]( + A=A, + Ai=Ai, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + USE_TMA=IS_TMA_SUPPORTED, + ) + return Ai diff --git a/flashmask/flash_mask/linear_attn/triton_utils.py b/flashmask/flash_mask/linear_attn/triton_utils.py new file mode 100644 index 00000000000..8fa66804f52 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/triton_utils.py @@ -0,0 +1,190 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2026, Songlin Yang, Yu Zhang, Zhiyuan Li +# Adapted for PaddlePaddle + +import os +from contextlib import contextmanager + +import paddle +from functools import cache +from importlib.metadata import PackageNotFoundError, distribution + + +@cache +def _is_package_installed(dist_name: str) -> bool: + try: + distribution(dist_name) + return True + except PackageNotFoundError: + return False + + +# Pre-create Paddle triton driver in mixed torch environment +paddle_driver = None +if _is_package_installed("torch"): + with paddle.use_compat_guard(enable=True, silent=True): + from triton.runtime.driver import _create_driver + paddle_driver = _create_driver() + + +# --------------------------------------------------------------------------- +# Driver probe: captures the active triton driver *during* kernel execution. +# Disabled by default (zero overhead). Tests enable it via enable_driver_probe(). +# Set FLA_BENCHMARK=1 to keep probing disabled even if tests try to enable it. 
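+# After a probed launch, get_driver_probe_result() reports 'paddle', 'torch', 'unknown', or 'not_probed' if no launch has happened yet.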
+# --------------------------------------------------------------------------- +_driver_probe_enabled: bool = False +_driver_probe_result: str = "not_probed" +_compat_wrapper_fastpath_depth: int = 0 + + +def enable_driver_probe(): + """Enable driver probing during kernel launch (for tests).""" + global _driver_probe_enabled, _driver_probe_result + if os.environ.get("FLA_BENCHMARK", "0") == "1": + return + _driver_probe_enabled = True + _driver_probe_result = "not_probed" + + +def disable_driver_probe(): + """Disable driver probing (restore zero overhead).""" + global _driver_probe_enabled + _driver_probe_enabled = False + + +def get_driver_probe_result() -> str: + """Return the driver framework detected during the last kernel launch.""" + return _driver_probe_result + + +def _detect_driver_framework(active_driver) -> str: + """Identify the framework behind a triton driver object.""" + fn = active_driver.get_current_stream + # Check __module__ first (most reliable) + mod = getattr(fn, '__module__', '') or '' + if 'paddle' in mod: + return 'paddle' + if 'torch' in mod: + return 'torch' + # Fallback: check string representation + fn_str = str(fn) + if 'paddle' in fn_str or '_get_current_raw_stream' in fn_str: + return 'paddle' + if 'torch' in fn_str or '_cuda_getCurrentRawStream' in fn_str: + return 'torch' + return 'unknown' + + +def _probe_active_driver(): + """Snapshot the active triton driver framework (called inside swap guard).""" + global _driver_probe_result + try: + from triton.runtime.driver import driver + _driver_probe_result = _detect_driver_framework(driver.active) + except Exception as e: + _driver_probe_result = f'error({e})' + + +def _wrap_probe_only(fn): + def wrapped_fn(*args, **kwargs): + if _driver_probe_enabled: + _probe_active_driver() + return fn(*args, **kwargs) + + return wrapped_fn + + +def swap_driver_guard(fn): + """Temporarily swap triton's active driver to Paddle driver.""" + from triton.runtime.driver import driver + + def wrapped_fn(*args, **kwargs): + if paddle_driver is None or driver.active is paddle_driver: + if _driver_probe_enabled: + _probe_active_driver() + return fn(*args, **kwargs) + driver.set_active(paddle_driver) + try: + if _driver_probe_enabled: + _probe_active_driver() + return fn(*args, **kwargs) + finally: + driver.reset_active() + + return wrapped_fn + + +def _should_bypass_compat_kernel_wrapper() -> bool: + if _compat_wrapper_fastpath_depth <= 0 or paddle_driver is None: + return False + try: + from triton.runtime.driver import driver + except Exception: + return False + return driver.active is paddle_driver + + +@contextmanager +def compat_kernel_wrapper_fastpath(): + """Allow compat-wrapped kernels to skip re-wrapping when Paddle driver is already active.""" + global _compat_wrapper_fastpath_depth + _compat_wrapper_fastpath_depth += 1 + try: + yield + finally: + _compat_wrapper_fastpath_depth -= 1 + + +@contextmanager +def activate_paddle_driver(): + """Activate the Paddle Triton driver for a wider Python region when available.""" + if paddle_driver is None: + yield + return + + from triton.runtime.driver import driver + + if driver.active is paddle_driver: + yield + return + + driver.set_active(paddle_driver) + try: + yield + finally: + driver.reset_active() + + +def enable_compat_on_triton_kernel(triton_kernel): + """ + Triton kernel compat decorator (ref: FastDeploy PR#6897). 
+ + - No torch env: return original kernel (zero overhead, relies on global enable_compat) + - Has torch env: wrap kernel to use Paddle driver on launch + + Usage: + @enable_compat_on_triton_kernel # outermost + @triton.autotune(...) # optional + @triton.jit + def my_kernel(...): + ... + """ + if not _is_package_installed("torch"): + return triton_kernel + + class WrappedTritonKernel: + def __init__(self, kernel): + self.kernel = kernel + + def __getitem__(self, index): + if _should_bypass_compat_kernel_wrapper(): + launcher = self.kernel[index] + if _driver_probe_enabled: + return _wrap_probe_only(launcher) + return launcher + return swap_driver_guard(self.kernel[index]) + + def __getattr__(self, name): + return getattr(self.kernel, name) + + return WrappedTritonKernel(triton_kernel) diff --git a/flashmask/flash_mask/linear_attn/utils.py b/flashmask/flash_mask/linear_attn/utils.py new file mode 100644 index 00000000000..c6a1f7c65f9 --- /dev/null +++ b/flashmask/flash_mask/linear_attn/utils.py @@ -0,0 +1,255 @@ +# -*- coding: utf-8 -*- +# Adapted from fla/utils.py for PaddlePaddle + +import os +import functools +import inspect + +import paddle +import triton + +# ===== Environment checks ===== +FLA_CI_ENV = os.environ.get('FLA_CI_ENV', '0') == '1' +FLA_CACHE_RESULTS = os.environ.get('FLA_CACHE_RESULTS', '1') == '1' +FLA_DISABLE_TENSOR_CACHE = os.environ.get('FLA_DISABLE_TENSOR_CACHE', '0') == '1' + + +# ===== Device detection ===== +def get_available_device(): + return 'cuda' + + +def get_multiprocessor_count(): + props = paddle.device.cuda.get_device_properties() + return props['multi_processor_count'] + + +def _get_device_name(): + return paddle.device.cuda.get_device_name() + + +device_name = _get_device_name() + +IS_NVIDIA = 'nvidia' in device_name.lower() or 'geforce' in device_name.lower() or 'tesla' in device_name.lower() +IS_AMD = 'amd' in device_name.lower() or 'instinct' in device_name.lower() +IS_INTEL = 'intel' in device_name.lower() + +try: + capability = paddle.device.cuda.get_device_capability() +except: + capability = (0, 0) + +IS_NVIDIA_HOPPER = IS_NVIDIA and capability[0] >= 9 +IS_NVIDIA_BLACKWELL = IS_NVIDIA and capability[0] >= 10 +IS_TF32_SUPPORTED = IS_NVIDIA and capability[0] >= 8 +IS_GATHER_SUPPORTED = True +IS_TMA_SUPPORTED = False # TMA not supported in Paddle migration for now + +USE_CUDA_GRAPH = os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1' + +# lowercase aliases +is_nvidia = IS_NVIDIA +is_amd = IS_AMD +is_intel = IS_INTEL + + +# ===== Backend enum for shared memory ===== +class Backend: + ADA = 101376 + AMPERE = 166912 + HOPPER = 232448 + DEFAULT = 102400 + + +def _get_device_property(props, key, default=None): + if isinstance(props, dict): + return props.get(key, default) + return getattr(props, key, default) + + +def _infer_max_shared_mem_from_device(): + try: + capability = paddle.device.cuda.get_device_capability() + except Exception: + capability = (0, 0) + + try: + name = paddle.device.cuda.get_device_name().lower() + except Exception: + name = '' + + if capability[0] >= 9: + return Backend.HOPPER + if capability[0] >= 8: + if any(token in name for token in ('ada', '4090', '4080', '4070', '4060', 'l40', 'l4')): + return Backend.ADA + return Backend.AMPERE + return 49152 + + +def _get_max_shared_mem(): + try: + props = paddle.device.cuda.get_device_properties() + except Exception: + props = None + + for key in ('shared_memory_per_block_optin', 'shared_memory_per_block'): + value = _get_device_property(props, key) + if value is not None: + return value + 
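# Neither shared-memory property was reported: fall back to an estimate from compute capability and device name. +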
return _infer_max_shared_mem_from_device() + + +def check_shared_mem(arch=None, tensor_idx=0): + """Check if device shared memory meets requirements.""" + max_smem = _get_max_shared_mem() + + if arch is None: + return max_smem >= Backend.DEFAULT + elif arch == 'ampere': + return max_smem >= Backend.AMPERE + elif arch == 'hopper': + return max_smem >= Backend.HOPPER + elif arch == 'ada': + return max_smem >= Backend.ADA + return max_smem >= Backend.DEFAULT + + +def get_all_max_shared_mem(): + return _get_max_shared_mem() + + +# ===== Triton version checks ===== +def _check_triton_version(min_version): + try: + from importlib.metadata import version + triton_ver = version('triton') + from packaging.version import Version + return Version(triton_ver) >= Version(min_version) + except: + return False + +TRITON_ABOVE_3_4_0 = _check_triton_version('3.4.0') +TRITON_ABOVE_3_5_1 = _check_triton_version('3.5.1') + +# ===== autotune cache ===== +SUPPORTS_AUTOTUNE_CACHE = hasattr(triton.autotune, '__wrapped__') or True +try: + # Check if triton.autotune supports cache_results + import inspect + sig = inspect.signature(triton.autotune) + SUPPORTS_AUTOTUNE_CACHE = 'cache_results' in sig.parameters +except: + SUPPORTS_AUTOTUNE_CACHE = False + +autotune_cache_kwargs = {} +if SUPPORTS_AUTOTUNE_CACHE and FLA_CACHE_RESULTS: + autotune_cache_kwargs = {'cache_results': True} + + +# ===== AMP adapters ===== +def autocast_custom_fwd(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + with paddle.amp.auto_cast(enable=False): + return fn(*args, **kwargs) + return wrapper + + +def autocast_custom_bwd(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + with paddle.amp.auto_cast(enable=False): + return fn(*args, **kwargs) + return wrapper + + +# ===== tensor_cache ===== +def tensor_cache(fn): + """Single-entry tensor function cache.""" + if FLA_DISABLE_TENSOR_CACHE: + return fn + + _cache = {} + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + key = (args, tuple(sorted(kwargs.items()))) + if key not in _cache: + _cache.clear() + _cache[key] = fn(*args, **kwargs) + return _cache[key] + return wrapper + + +# ===== input_guard ===== +def input_guard(fn=None, *, no_guard_contiguous=None): + """Ensure all tensor inputs are contiguous.""" + if fn is None: + return functools.partial(input_guard, no_guard_contiguous=no_guard_contiguous) + + skip_names = set(no_guard_contiguous) if no_guard_contiguous else set() + try: + params = list(inspect.signature(fn).parameters.keys()) + except (ValueError, TypeError): + params = [] + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + new_args = [] + for i, arg in enumerate(args): + param_name = params[i] if i < len(params) else '' + if isinstance(arg, paddle.Tensor) and param_name not in skip_names: + if not arg.is_contiguous(): + arg = arg.contiguous() + new_args.append(arg) + + new_kwargs = {} + for k, v in kwargs.items(): + if isinstance(v, paddle.Tensor) and k not in skip_names: + if not v.is_contiguous(): + v = v.contiguous() + new_kwargs[k] = v + return fn(*new_args, **new_kwargs) + return wrapper + + +def contiguous(fn): + """Alias for input_guard without parameters.""" + return input_guard(fn) + + +# ===== checkpoint ===== +def checkpoint(fn): + """Wrap function with recompute.""" + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return paddle.distributed.fleet.utils.recompute(fn, *args, **kwargs) + return wrapper + + +# ===== Testing helpers ===== +def get_abs_err(x, y): + return (x.detach() - y.detach()).flatten().abs().max().item() + + 
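+# Relative RMS error of `x` vs `y`; used together with get_abs_err by assert_close below.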
+def get_err_ratio(x, y): + err = (x.detach() - y.detach()).flatten().pow(2).mean().sqrt().item() + base = x.detach().flatten().pow(2).mean().sqrt().item() + return err / (base + 1e-8) + + +def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6): + abs_atol = get_abs_err(ref, tri) + error_rate = get_err_ratio(ref, tri) + msg = f"{prefix:>16} diff: {abs_atol:.6f} ratio: {error_rate:.6f}" + if abs_atol <= err_atol: + return + assert not paddle.isnan(ref).any(), f"{prefix}: NaN detected in ref" + assert not paddle.isnan(tri).any(), f"{prefix}: NaN detected in tri" + if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)): + if error_rate > ratio: + import warnings + warnings.warn(msg) + else: + assert error_rate < ratio, msg diff --git a/flashmask/setup.py b/flashmask/setup.py index b6940550a88..82b6f639410 100644 --- a/flashmask/setup.py +++ b/flashmask/setup.py @@ -43,9 +43,12 @@ BUILD_FA3 = FLASHMASK_BUILD in ('fa3', 'all') BUILD_FA4 = FLASHMASK_BUILD in ('fa4', 'all') +BUILD_FLA = os.environ.get('BUILD_FLA', '0') == '1' print(f"[flashmask] FLASHMASK_BUILD={FLASHMASK_BUILD} " - f"BUILD_FA3={BUILD_FA3} BUILD_FA4={BUILD_FA4}") + f"BUILD_FA3={BUILD_FA3} BUILD_FA4={BUILD_FA4} BUILD_FLA={BUILD_FLA}") +if BUILD_FLA: + print("[flashmask] Note: FLA (Flash Linear Attention) in flashmask currently only supports GDN and KDA operators.") # ============================================================ # Config @@ -87,6 +90,11 @@ def _get_version(): 'flash_mask.cute', 'flash_mask.cute.*', ] +if not BUILD_FLA: + exclude_packages += [ + 'flash_mask.linear_attn', + 'flash_mask.linear_attn.*', + ] packages = find_packages(exclude=exclude_packages) @@ -94,6 +102,8 @@ def _get_version(): # Dependencies # ============================================================ install_requires = ['typing_extensions'] +if BUILD_FLA: + install_requires += ['triton>=3.5.1'] if BUILD_FA4: install_requires += [ 'nvidia-cutlass==4.2.0.0', diff --git a/flashmask/tests/benchmarks/test_paddle_ops_runner.py b/flashmask/tests/benchmarks/test_paddle_ops_runner.py new file mode 100644 index 00000000000..ed041ff52ec --- /dev/null +++ b/flashmask/tests/benchmarks/test_paddle_ops_runner.py @@ -0,0 +1,160 @@ +# -*- coding: utf-8 -*- + +import importlib +from types import SimpleNamespace + +import paddle +import torch + + +def _load_torch_benchmark_run_module(): + return importlib.import_module('benchmarks.ops.run') + + +def _load_paddle_benchmark_run_module(): + return importlib.import_module('benchmarks.paddle_ops.run') + + +def test_paddle_benchmark_registry_registers_gdn_and_kda(): + from benchmarks.paddle_ops.registry import get_op, list_ops + + ops = list_ops() + + assert 'chunk_gdn' in ops + assert 'recurrent_gdn' in ops + assert 'chunk_kda' in ops + assert 'recurrent_kda' in ops + assert get_op('chunk_gdn').import_path == 'linear_attn.ops.gated_delta_rule' + assert get_op('recurrent_gdn').func_name == 'fused_recurrent_gated_delta_rule' + assert get_op('chunk_kda').import_path == 'linear_attn.ops.kda' + assert get_op('recurrent_kda').func_name == 'fused_recurrent_kda' + + +def test_paddle_benchmark_registry_generates_expected_tensor_shapes(): + from benchmarks.paddle_ops.registry import generate_inputs, get_op + + gdn_inputs = generate_inputs(get_op('chunk_gdn'), B=2, T=16, H=4, D=32, dtype=paddle.bfloat16, device='gpu') + kda_inputs = generate_inputs(get_op('chunk_kda'), B=2, T=16, H=4, D=32, dtype=paddle.bfloat16, device='gpu') + + assert list(gdn_inputs['q'].shape) == [2, 16, 4, 32] + assert 
list(gdn_inputs['g'].shape) == [2, 16, 4] + assert list(gdn_inputs['beta'].shape) == [2, 16, 4] + assert isinstance(gdn_inputs['q'], paddle.Tensor) + + assert list(kda_inputs['q'].shape) == [2, 16, 4, 32] + assert list(kda_inputs['g'].shape) == [2, 16, 4, 32] + assert list(kda_inputs['beta'].shape) == [2, 16, 4] + assert isinstance(kda_inputs['g'], paddle.Tensor) + + +def test_paddle_benchmark_runner_lists_registered_ops(capsys): + from benchmarks.paddle_ops import run + + run.main(['--list']) + + captured = capsys.readouterr() + assert 'chunk_gdn' in captured.out + assert 'chunk_kda' in captured.out + + +def test_torch_benchmark_runner_skip_backward_warms_up_forward_only(monkeypatch): + module = _load_torch_benchmark_run_module() + import triton + + config = SimpleNamespace( + name='recurrent_stub', + import_path='unused', + inputs={}, + skip_backward=True, + output_is_tuple=True, + extra_kwargs={}, + dim_constraints=None, + ) + + monkeypatch.setattr(module, 'get_op', lambda name: config) + monkeypatch.setattr(module, 'generate_inputs', lambda *args, **kwargs: {'x': torch.ones(1)}) + monkeypatch.setattr(module, '_warmup_autotune', lambda fn, n=None: fn()) + + def fake_op(**kwargs): + return (torch.ones(1),) + + def fake_do_bench(fn, quantiles, **kwargs): + fn() + return (1.0, 0.9, 1.1) + + monkeypatch.setattr(module, '_import_op', lambda cfg: fake_op) + monkeypatch.setattr(triton.testing, 'do_bench', fake_do_bench) + + results = module.benchmark_op('recurrent_stub', {'smoke': {'B': 1, 'T': 2, 'H': 3, 'D': 4}}) + + assert [row['mode'] for row in results] == ['fwd'] + + +def test_paddle_benchmark_runner_skip_backward_warms_up_forward_only(monkeypatch): + module = _load_paddle_benchmark_run_module() + import triton + + config = module.OpConfig( + name='recurrent_stub', + import_path='unused', + inputs={}, + skip_backward=True, + ) + + monkeypatch.setattr(module, 'get_op', lambda name: config) + monkeypatch.setattr(module, 'generate_inputs', lambda *args, **kwargs: {'x': paddle.ones([1])}) + monkeypatch.setattr(module, '_warmup_autotune', lambda fn, n=None: fn()) + + class FakeLoss: + def backward(self): + raise AssertionError('backward should not run during warmup for skip_backward ops') + + def fake_op(**kwargs): + return (paddle.ones([1]),) + + def fake_do_bench(fn, quantiles, **kwargs): + fn() + return (1.0, 0.9, 1.1) + + monkeypatch.setattr(module, '_import_op', lambda cfg: fake_op) + monkeypatch.setattr(module.paddle, 'sum', lambda tensor: FakeLoss()) + monkeypatch.setattr(triton.testing, 'do_bench', fake_do_bench) + + results = module.benchmark_op('recurrent_stub', {'smoke': {'B': 1, 'T': 2, 'H': 3, 'D': 4}}) + + assert [row['mode'] for row in results] == ['fwd'] + + +def test_paddle_benchmark_runner_uses_native_backward_api(monkeypatch): + module = _load_paddle_benchmark_run_module() + + calls = [] + + def fake_backward(tensors, grads): + calls.append((tensors, grads)) + + monkeypatch.setattr(module.paddle.autograd, 'backward', fake_backward) + + tensor = paddle.ones([2], dtype='float32') + grad = paddle.ones([2], dtype='float32') + + module._backward(tensor, grad) + + assert len(calls) == 1 + assert calls[0][0] == [tensor] + assert calls[0][1] == [grad] + + +def test_paddle_benchmark_runner_clears_gradients_by_setting_none(): + module = _load_paddle_benchmark_run_module() + + x = paddle.randn([2, 3], dtype='float32') + x.stop_gradient = False + y = (x * x).sum() + y.backward() + + assert x.grad is not None + + module._clear_gradients({'x': x}) + + assert x.grad is None diff --git 
a/flashmask/tests/linear_attn/__init__.py b/flashmask/tests/linear_attn/__init__.py new file mode 100644 index 00000000000..40a96afc6ff --- /dev/null +++ b/flashmask/tests/linear_attn/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/flashmask/tests/linear_attn/conftest.py b/flashmask/tests/linear_attn/conftest.py new file mode 100644 index 00000000000..df6e32a9a3a --- /dev/null +++ b/flashmask/tests/linear_attn/conftest.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +# Test fixtures for Paddle migration tests + +import logging +import os +import warnings +import importlib + +import paddle +import pytest + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# assert_close: matches torch fla.utils.assert_close semantics +# --------------------------------------------------------------------------- + +def get_abs_err(x, y): + return (x.detach() - y.detach()).flatten().abs().max().item() + + +def get_err_ratio(x, y): + err = (x.detach() - y.detach()).flatten().pow(2).mean().sqrt().item() + base = x.detach().flatten().pow(2).mean().sqrt().item() + return err / (base + 1e-8) + + +def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6): + abs_atol = get_abs_err(ref, tri) + error_rate = get_err_ratio(ref, tri) + msg = f"{prefix:>16} diff: {abs_atol:.6f} ratio: {error_rate:.6f}" + logger.info(msg) + if abs_atol <= err_atol: + return + assert not paddle.isnan(ref).any(), f"{prefix}: NaN detected in ref" + assert not paddle.isnan(tri).any(), f"{prefix}: NaN detected in tri" + if warning: + if error_rate > ratio: + warnings.warn(msg) + else: + assert error_rate < ratio, msg + + +class FrameworkTracker: + """Track the framework backend of tensors and triton driver at runtime. + + Driver detection modes: + - probe (default in tests): reads the result captured *during* kernel + execution by the swap_driver_guard probe hook. This is accurate even + in mixed torch+paddle environments. + - snapshot: inspects the triton active driver at call time (outside + kernel execution). Fast but shows the *default* driver, which is + torch when torch is installed. + + Set ``FLA_BENCHMARK=1`` to disable all probing overhead. In that mode, + ``detect_triton_driver()`` falls back to snapshot and the report omits + the Triton Driver line entirely. + """ + + def __init__(self): + self._benchmark = os.environ.get("FLA_BENCHMARK", "0") == "1" + + # ------ tensor framework ------ + + @staticmethod + def detect_tensor_framework(tensor) -> str: + """Detect which framework a tensor belongs to.""" + module = type(tensor).__module__ + if 'paddle' in module: + return 'paddle' + elif 'torch' in module: + return 'torch' + return f'unknown({module})' + + # ------ triton driver ------ + + def detect_triton_driver(self) -> str: + """Return the triton driver detected during the last kernel launch. + + Uses the probe result captured inside swap_driver_guard when probing + is enabled; falls back to a snapshot of the current active driver + otherwise. 
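+
+        Example (illustrative sketch; assumes a probed kernel launch has
+        already happened in the current test)::
+
+            tracker = FrameworkTracker()
+            tracker.detect_triton_driver()  # e.g. 'paddle'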
+ """ + if self._benchmark: + return self._detect_triton_driver_snapshot() + from linear_attn.triton_utils import get_driver_probe_result + result = get_driver_probe_result() + if result == "not_probed": + return self._detect_triton_driver_snapshot() + return result + + @staticmethod + def _detect_triton_driver_snapshot() -> str: + """Inspect the current triton active driver (outside kernel execution).""" + try: + from triton.runtime.driver import driver + from linear_attn.triton_utils import _detect_driver_framework + return _detect_driver_framework(driver.active) + except Exception as e: + return f'error({e})' + + # ------ autograd ------ + + @staticmethod + def detect_autograd_framework(tensor) -> str: + """Detect the autograd backend of a tensor.""" + import paddle + if isinstance(tensor, paddle.Tensor): + if not tensor.stop_gradient: + return 'paddle' + return 'paddle (no_grad)' + try: + import torch + if isinstance(tensor, torch.Tensor): + if tensor.requires_grad: + return 'torch' + return 'torch (no_grad)' + except ImportError: + pass + return 'unknown' + + # ------ report ------ + + def report(self, tensors: dict, label: str = ""): + """Generate a framework detection report.""" + lines = [] + if label: + lines.append(f"\n{'='*60}") + lines.append(f" Framework Detection Report: {label}") + lines.append(f"{'='*60}") + + if not self._benchmark: + lines.append(f" Triton Driver: {self.detect_triton_driver()}") + lines.append(f" {'─'*56}") + + for name, tensor in tensors.items(): + fw = self.detect_tensor_framework(tensor) + ag = self.detect_autograd_framework(tensor) + lines.append(f" {name:20s} | framework: {fw:8s} | autograd: {ag}") + + lines.append(f"{'='*60}\n") + return '\n'.join(lines) + + +@pytest.fixture(autouse=True) +def _driver_probe_lifecycle(): + """Enable driver probing before each test, disable after.""" + if os.environ.get("FLA_BENCHMARK", "0") == "1": + yield + return + from linear_attn.triton_utils import enable_driver_probe, disable_driver_probe + enable_driver_probe() + yield + disable_driver_probe() + + +@pytest.fixture(autouse=True) +def _linear_attn_cache_isolation(): + modules = [] + for name in ( + 'linear_attn.ops.common.chunk_o', + 'linear_attn.ops.kda.wy_fast', + ): + try: + modules.append(importlib.import_module(name)) + except Exception: + continue + for mod in modules: + for attr in ('_const_tiling', '_chunk_o_launch_meta', '_wy_tiling', '_wy_launch_meta'): + fn = getattr(mod, attr, None) + if fn is not None and hasattr(fn, 'cache_clear'): + fn.cache_clear() + yield + for mod in modules: + for attr in ('_const_tiling', '_chunk_o_launch_meta', '_wy_tiling', '_wy_launch_meta'): + fn = getattr(mod, attr, None) + if fn is not None and hasattr(fn, 'cache_clear'): + fn.cache_clear() + + +@pytest.fixture +def framework_tracker(): + return FrameworkTracker() diff --git a/flashmask/tests/linear_attn/test_gated_delta.py b/flashmask/tests/linear_attn/test_gated_delta.py new file mode 100644 index 00000000000..a5690941c1d --- /dev/null +++ b/flashmask/tests/linear_attn/test_gated_delta.py @@ -0,0 +1,971 @@ +# -*- coding: utf-8 -*- +# Tests for Gated Delta Rule operators on PaddlePaddle +# Aligned with tests/ops/test_gated_delta.py + +import os + +import paddle +import paddle.nn.functional as F +import pytest +from einops import repeat + +from linear_attn.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule +from linear_attn.ops.gated_delta_rule.gate import fused_gdn_gate, naive_gdn_gate +from linear_attn.ops.gated_delta_rule.naive 
import naive_recurrent_gated_delta_rule + +from .conftest import assert_close + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'HV', 'D', 'scale', 'gate_logit_normalizer', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-HV{}-D{}-scale{}-gate_logit_normalizer{}-{}".format(*test)) + for test in [ + (1, 63, 1, 1, 64, 1, 1, paddle.float32), + (2, 500, 4, 4, 60, 1, 1, paddle.float32), + (2, 1000, 2, 8, 128, 1, 0.1, paddle.float32), + (3, 1024, 2, 2, 128, 0.1, 1, paddle.float32), + (4, 1024, 3, 3, 128, 1, 10, paddle.float32), + (4, 2048, 4, 4, 64, 0.1, 1, paddle.float32), + (2, 1024, 4, 4, 128, 1, 0.1, paddle.float16), + (2, 1024, 4, 8, 128, 1, 10, paddle.float16), + ] + ], +) +def test_fused_recurrent( + B: int, + T: int, + H: int, + HV: int, + D: int, + scale: float, + gate_logit_normalizer: float, + dtype, +): + paddle.seed(42) + q = paddle.randn([B, T, H, D], dtype=paddle.float32) + k = paddle.randn([B, T, H, D], dtype=paddle.float32) + v = paddle.randn([B, T, HV, D], dtype=dtype) + beta = paddle.rand([B, T, HV], dtype=dtype).sigmoid() + g = F.log_sigmoid(paddle.rand([B, T, HV], dtype=paddle.float32)) + g = g / gate_logit_normalizer + h0 = paddle.randn([B, HV, D, D], dtype=paddle.float32) + + ref, ref_ht = naive_recurrent_gated_delta_rule( + q=F.normalize(repeat(q.clone(), 'b t h d -> b t (h g) d', g=HV // H), p=2, axis=-1).cast(dtype), + k=F.normalize(repeat(k.clone(), 'b t h d -> b t (h g) d', g=HV // H), p=2, axis=-1).cast(dtype), + v=v.clone(), + beta=beta.clone(), + g=g.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + ) + tri, tri_ht = fused_recurrent_gated_delta_rule( + q=q.clone(), + k=k.clone(), + v=v.clone(), + beta=beta.clone(), + g=g.clone(), + scale=scale, + initial_state=h0.clone(), + use_qk_l2norm_in_kernel=True, + output_final_state=True, + ) + assert_close('o', ref, tri, 0.002) + assert_close('ht', ref_ht, tri_ht, 0.002) + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'gate_logit_normalizer', 'mask_p', 'use_qk_l2norm_in_kernel', 'dtype'), + [ + pytest.param( + *test, + id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-mask_p{}-use_qk_l2norm_in_kernel{}-{}".format(*test), + ) + for test in [ + (2, 75, 4, 64, 1, 0.01, 0, False, paddle.float16), + (2, 500, 3, 60, 1, 1, 0, False, paddle.float16), + (2, 1000, 3, 64, 0.1, 1, 0.5, False, paddle.float16), + (3, 1024, 4, 100, 1, 0.1, 0, False, paddle.float16), + (4, 1024, 4, 128, 0.1, 1, 0, False, paddle.float16), + (4, 1024, 4, 128, 0.1, 1, 0, True, paddle.float16), + (2, 1500, 4, 128, 0.1, 10, 0, False, paddle.float16), + (4, 2048, 8, 64, 0.1, 1, 0, False, paddle.float16), + ] + ], +) +def test_chunk( + B: int, + T: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + mask_p: float, + use_qk_l2norm_in_kernel: bool, + dtype, +): + paddle.seed(42) + q = paddle.rand([B, T, H, D], dtype=dtype) + k = paddle.rand([B, T, H, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + beta = paddle.rand([B, T, H], dtype=paddle.float32).sigmoid() + g = F.log_sigmoid(paddle.rand([B, T, H], dtype=paddle.float32)) + g = g / gate_logit_normalizer + g = g * (paddle.rand(g.shape, dtype=g.dtype) > mask_p).cast(g.dtype) + h0 = paddle.zeros([B, H, D, D], dtype=paddle.float32) + for t in [q, k, v, beta, g, h0]: + t.stop_gradient = False + + tri, tri_ht = chunk_gated_delta_rule( + q=F.normalize(q.clone(), p=2, axis=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, axis=-1) if not use_qk_l2norm_in_kernel else k.clone(), + v=v.clone(), + 
g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + do = paddle.randn(v.shape, dtype=v.dtype) + dht = paddle.randn(h0.shape, dtype=h0.dtype) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dg, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g.grad.clone(), h0.grad.clone() + ) + q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + beta.clear_gradient() + g.clear_gradient() + h0.clear_gradient() + + ref, ref_ht = naive_recurrent_gated_delta_rule( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + beta=beta.clone(), + g=g.clone(), + scale=scale, + output_final_state=True, + initial_state=h0.clone(), + ) + + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dg, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g.grad.clone(), h0.grad.clone() + ) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.008) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.008) + assert_close('db', ref_dbeta, tri_dbeta, 0.02) + assert_close('dg', ref_dg, tri_dg, 0.02) + assert_close('dh0', ref_dh0, tri_dh0, 0.008) + + +@pytest.mark.parametrize( + ('B', 'T', 'Hq', 'H', 'D', 'scale', 'gate_logit_normalizer', 'use_qk_l2norm_in_kernel', 'dtype'), + [ + pytest.param( + *test, + id="B{}-T{}-Hq{}-H{}-D{}-scale{}-gate_logit_normalizer{}-use_qk_l2norm_in_kernel{}-{}".format(*test), + ) + for test in [ + (2, 256, 2, 4, 64, 1, 1, False, paddle.float16), + (2, 512, 1, 4, 64, 0.1, 1, False, paddle.float16), + (2, 512, 2, 8, 64, 1, 0.1, True, paddle.float16), + (2, 1024, 4, 8, 128, 0.1, 1, False, paddle.float16), + ] + ], +) +def test_chunk_gqa( + B: int, + T: int, + Hq: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + use_qk_l2norm_in_kernel: bool, + dtype, +): + paddle.seed(42) + assert H % Hq == 0 + G = H // Hq + + q = paddle.rand([B, T, Hq, D], dtype=dtype) + k = paddle.rand([B, T, Hq, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + beta = paddle.rand([B, T, H], dtype=paddle.float32).sigmoid() + g = F.log_sigmoid(paddle.rand([B, T, H], dtype=paddle.float32)) + g = g / gate_logit_normalizer + h0 = paddle.zeros([B, H, D, D], dtype=paddle.float32) + for t in [q, k, v, beta, g, h0]: + t.stop_gradient = False + + tri, tri_ht = chunk_gated_delta_rule( + q=F.normalize(q.clone(), p=2, axis=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, axis=-1) if not use_qk_l2norm_in_kernel else k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + do = paddle.randn(v.shape, dtype=v.dtype) + dht = paddle.randn(h0.shape, dtype=h0.dtype) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dg, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g.grad.clone(), h0.grad.clone() + ) + q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + beta.clear_gradient() + g.clear_gradient() + h0.clear_gradient() + + ref, ref_ht = naive_recurrent_gated_delta_rule( + q=F.normalize(repeat(q.clone(), 'b t h 
d -> b t (h g) d', g=G), p=2, axis=-1), + k=F.normalize(repeat(k.clone(), 'b t h d -> b t (h g) d', g=G), p=2, axis=-1), + v=v.clone(), + beta=beta.clone(), + g=g.clone(), + scale=scale, + output_final_state=True, + initial_state=h0.clone(), + ) + + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dg, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g.grad.clone(), h0.grad.clone() + ) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.008) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.008) + assert_close('db', ref_dbeta, tri_dbeta, 0.02) + assert_close('dg', ref_dg, tri_dg, 0.02) + assert_close('dh0', ref_dh0, tri_dh0, 0.008) + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'gate_logit_normalizer', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, 1, paddle.float16), + (2, 500, 3, 60, 1, 1, paddle.float16), + (3, 1024, 4, 128, 0.1, 1, paddle.float16), + (4, 2048, 8, 64, 0.1, 1, paddle.float16), + ] + ], +) +def test_chunk_transpose_state( + B: int, + T: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + dtype, +): + paddle.seed(42) + q = paddle.rand([B, T, H, D], dtype=dtype) + k = paddle.rand([B, T, H, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + beta = paddle.rand([B, T, H], dtype=dtype).sigmoid() + g = F.log_sigmoid(paddle.rand([B, T, H], dtype=paddle.float32)) + g = g / gate_logit_normalizer + h0_kv = paddle.randn([B, H, D, D], dtype=paddle.float32) + h0_vk = h0_kv.transpose([0, 1, 3, 2]).contiguous() + for t in [q, k, v, beta, g, h0_kv, h0_vk]: + t.stop_gradient = False + + tri, tri_ht = chunk_gated_delta_rule( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0_vk.clone(), + output_final_state=True, + transpose_state_layout=True, + ) + do = paddle.randn(v.shape, dtype=v.dtype) + dht_vk = paddle.randn([B, H, D, D], dtype=paddle.float32) + dht_kv = dht_vk.transpose([0, 1, 3, 2]).contiguous() + ((tri * do).sum() + (tri_ht * dht_vk).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dg, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g.grad.clone(), h0_vk.grad.clone() + ) + q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + beta.clear_gradient() + g.clear_gradient() + h0_vk.clear_gradient() + + ref, ref_ht = chunk_gated_delta_rule( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0_kv.clone(), + output_final_state=True, + transpose_state_layout=False, + ) + ((ref * do).sum() + (ref_ht * dht_kv).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dg, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g.grad.clone(), h0_kv.grad.clone() + ) + + assert_close('o', ref, tri, 1e-4) + assert_close('ht', ref_ht, tri_ht.transpose([0, 1, 3, 2]), 1e-4) + assert_close('dq', ref_dq, tri_dq, 1e-4) + assert_close('dk', ref_dk, tri_dk, 1e-4) + assert_close('dv', ref_dv, tri_dv, 1e-4) + assert_close('db', ref_dbeta, tri_dbeta, 1e-4) + assert_close('dg', ref_dg, tri_dg, 1e-4) + assert_close('dh0', 
ref_dh0, tri_dh0.transpose([0, 1, 3, 2]), 1e-4) + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'HV', 'D', 'scale', 'gate_logit_normalizer', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-HV{}-D{}-scale{}-gate_logit_normalizer{}-{}".format(*test)) + for test in [ + (1, 63, 1, 1, 64, 1, 1, paddle.float32), + (2, 500, 4, 4, 60, 1, 1, paddle.float32), + (2, 1000, 2, 8, 128, 1, 0.1, paddle.float32), + (3, 1024, 2, 2, 128, 0.1, 1, paddle.float32), + (4, 2048, 4, 4, 64, 0.1, 1, paddle.float32), + ] + ], +) +def test_fused_recurrent_transpose_state( + B: int, + T: int, + H: int, + HV: int, + D: int, + scale: float, + gate_logit_normalizer: float, + dtype, +): + paddle.seed(42) + q = paddle.randn([B, T, H, D], dtype=paddle.float32) + k = paddle.randn([B, T, H, D], dtype=paddle.float32) + v = paddle.randn([B, T, HV, D], dtype=dtype) + beta = paddle.rand([B, T, HV], dtype=dtype).sigmoid() + g = F.log_sigmoid(paddle.rand([B, T, HV], dtype=paddle.float32)) + g = g / gate_logit_normalizer + h0_kv = paddle.randn([B, HV, D, D], dtype=paddle.float32) + h0_vk = h0_kv.transpose([0, 1, 3, 2]).contiguous() + + ref, ref_ht = fused_recurrent_gated_delta_rule( + q=q.clone(), + k=k.clone(), + v=v.clone(), + beta=beta.clone(), + g=g.clone(), + scale=scale, + initial_state=h0_kv.clone(), + use_qk_l2norm_in_kernel=True, + output_final_state=True, + transpose_state_layout=False, + ) + tri, tri_ht = fused_recurrent_gated_delta_rule( + q=q.clone(), + k=k.clone(), + v=v.clone(), + beta=beta.clone(), + g=g.clone(), + scale=scale, + initial_state=h0_vk.clone(), + use_qk_l2norm_in_kernel=True, + output_final_state=True, + transpose_state_layout=True, + ) + assert_close('o', ref, tri, 1e-4) + assert_close('ht', ref_ht, tri_ht.transpose([0, 1, 3, 2]), 1e-4) + + +@pytest.mark.parametrize( + ('H', 'D', 'mask_p', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (4, 60, 0, [0, 15], paddle.float16), + (4, 64, 0, [0, 256, 500, 1000], paddle.float16), + (4, 64, 0.5, [0, 256, 500, 1000], paddle.float16), + (4, 100, 0, [0, 15, 100, 300, 1200, 2000], paddle.float16), + ] + ], +) +@pytest.mark.skipif( + os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', + reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set', +) +def test_chunk_varlen( + H: int, + D: int, + mask_p: float, + cu_seqlens: list, + dtype, +): + paddle.seed(42) + os.environ['TRITON_F32_DEFAULT'] = 'ieee' + cu_seqlens_t = paddle.to_tensor(cu_seqlens, dtype=paddle.int64) + T = cu_seqlens[-1] + N = len(cu_seqlens) - 1 + + q = paddle.randn([1, T, H, D], dtype=dtype) + k = F.normalize(paddle.randn([1, T, H, D], dtype=paddle.float32), p=2, axis=-1).cast(dtype) + v = paddle.randn([1, T, H, D], dtype=dtype) + g = F.log_sigmoid(paddle.rand([1, T, H], dtype=dtype)) + g = g * (paddle.rand(g.shape, dtype=g.dtype) > mask_p).cast(g.dtype) + beta = paddle.rand([1, T, H], dtype=paddle.float32).sigmoid() + h0 = paddle.randn([N, H, D, D], dtype=dtype) + + for t in [q, k, v, beta, g, h0]: + t.stop_gradient = False + do = paddle.randn(v.shape, dtype=v.dtype) + dht = paddle.rand(h0.shape, dtype=h0.dtype) + + tri, tri_ht = chunk_gated_delta_rule( + q=q.clone(), + k=k.clone(), + v=v.clone(), + beta=beta.clone(), + g=g.clone(), + initial_state=h0.clone(), + output_final_state=True, + cu_seqlens=cu_seqlens_t, + ) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dg, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), 
g.grad.clone(), h0.grad.clone() + ) + q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + beta.clear_gradient() + g.clear_gradient() + h0.clear_gradient() + + ref_list = [] + ref_ht_list = [] + for i in range(N): + s, e = cu_seqlens[i], cu_seqlens[i + 1] + ref_i, ref_ht_i = naive_recurrent_gated_delta_rule( + q=q[:, s:e], + k=k[:, s:e], + v=v[:, s:e], + beta=beta[:, s:e], + g=g[:, s:e], + initial_state=h0[i], + output_final_state=True, + ) + ref_list.append(ref_i) + ref_ht_list.append(ref_ht_i) + ref = paddle.concat(ref_list, axis=1) + ref_ht = paddle.concat(ref_ht_list, axis=0) + + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dg, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g.grad.clone(), h0.grad.clone() + ) + + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.007) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.007) + assert_close('db', ref_dbeta, tri_dbeta, 0.015) + assert_close('dg', ref_dg, tri_dg, 0.015) + assert_close('dh0', ref_dh0, tri_dh0, 0.007) + + +@pytest.mark.parametrize( + ('H', 'D', 'mask_p', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (4, 60, 0, [0, 8192], paddle.float16), + (4, 60, 0, [0, 15], paddle.float16), + (4, 64, 0, [0, 256, 500, 1000], paddle.float16), + (4, 64, 0.5, [0, 256, 500, 1000], paddle.float16), + (4, 100, 0, [0, 15, 100, 300, 1200, 2000], paddle.float16), + ] + ], +) +@pytest.mark.skipif( + os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', + reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set', +) +def test_chunk_varlen_prefill( + H: int, + D: int, + mask_p: float, + cu_seqlens: list, + dtype, +): + paddle.seed(42) + with paddle.no_grad(): + os.environ['TRITON_F32_DEFAULT'] = 'ieee' + cu_seqlens_t = paddle.to_tensor(cu_seqlens, dtype=paddle.int64) + T = cu_seqlens[-1] + N = len(cu_seqlens) - 1 + + q = paddle.randn([1, T, H, D], dtype=dtype) + k = F.normalize(paddle.randn([1, T, H, D], dtype=paddle.float32), p=2, axis=-1).cast(dtype) + v = paddle.randn([1, T, H, D], dtype=dtype) + g = F.log_sigmoid(paddle.rand([1, T, H], dtype=dtype)) + g = g * (paddle.rand(g.shape, dtype=g.dtype) > mask_p).cast(g.dtype) + beta = paddle.rand([1, T, H], dtype=dtype).sigmoid() + h0 = paddle.randn([N, H, D, D], dtype=dtype) + + tri, tri_ht = chunk_gated_delta_rule( + q=q.clone(), + k=k.clone(), + v=v.clone(), + beta=beta.clone(), + g=g.clone(), + initial_state=h0.clone(), + output_final_state=True, + cu_seqlens=cu_seqlens_t, + ) + + ref_list = [] + ref_ht_list = [] + for i in range(N): + s, e = cu_seqlens[i], cu_seqlens[i + 1] + ref_i, ref_ht_i = naive_recurrent_gated_delta_rule( + q=q[:, s:e], + k=k[:, s:e], + v=v[:, s:e], + beta=beta[:, s:e], + g=g[:, s:e], + initial_state=h0[i], + output_final_state=True, + ) + ref_list.append(ref_i) + ref_ht_list.append(ref_ht_i) + ref = paddle.concat(ref_list, axis=1) + ref_ht = paddle.concat(ref_ht_list, axis=0) + + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'has_dt_bias', 'use_qk_l2norm_in_kernel', 'dtype'), + [ + pytest.param( + *test, + id="B{}-T{}-H{}-D{}-scale{}-has_dt_bias{}-use_qk_l2norm{}-{}".format(*test), + ) + for test in [ + (2, 75, 4, 64, 1, True, True, paddle.float16), + (2, 500, 3, 60, 1, False, False, 
paddle.float16), + (2, 1000, 3, 64, 0.1, True, False, paddle.float16), + (3, 1024, 4, 100, 1, True, True, paddle.float16), + (4, 1024, 4, 128, 0.1, False, True, paddle.float16), + (4, 2048, 8, 64, 0.1, True, False, paddle.float16), + ] + ], +) +def test_chunk_gate_in_kernel( + B: int, + T: int, + H: int, + D: int, + scale: float, + has_dt_bias: bool, + use_qk_l2norm_in_kernel: bool, + dtype, +): + """Test use_gate_in_kernel=True path: fused gate activation + chunk cumsum inside kernel.""" + paddle.seed(42) + q = paddle.rand([B, T, H, D], dtype=dtype) + k = paddle.rand([B, T, H, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + beta = paddle.rand([B, T, H], dtype=paddle.float32).sigmoid() + g_raw = paddle.randn([B, T, H], dtype=paddle.float32) + A_log = paddle.randn([H], dtype=paddle.float32) + dt_bias = paddle.randn([H], dtype=paddle.float32) if has_dt_bias else None + h0 = paddle.zeros([B, H, D, D], dtype=paddle.float32) + + for t in [q, k, v, beta, g_raw, h0]: + t.stop_gradient = False + A_log.stop_gradient = False + if dt_bias is not None: + dt_bias.stop_gradient = False + + # === Triton path: use_gate_in_kernel=True === + tri, tri_ht = chunk_gated_delta_rule( + q=q.clone() if use_qk_l2norm_in_kernel else F.normalize(q.clone(), p=2, axis=-1), + k=k.clone() if use_qk_l2norm_in_kernel else F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g_raw.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + use_gate_in_kernel=True, + A_log=A_log.clone(), + dt_bias=dt_bias.clone() if dt_bias is not None else None, + ) + do = paddle.randn(v.shape, dtype=v.dtype) + dht = paddle.randn(h0.shape, dtype=h0.dtype) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dg, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g_raw.grad.clone(), h0.grad.clone() + ) + tri_dA_log = A_log.grad.clone() + tri_ddt_bias = dt_bias.grad.clone() if dt_bias is not None else None + q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + beta.clear_gradient() + g_raw.clear_gradient() + h0.clear_gradient() + A_log.clear_gradient() + if dt_bias is not None: + dt_bias.clear_gradient() + + # === Reference path: manually compute gate, then use_gate_in_kernel=False === + g_ref = naive_gdn_gate(g_raw, A_log, dt_bias) + ref, ref_ht = chunk_gated_delta_rule( + q=q.clone() if use_qk_l2norm_in_kernel else F.normalize(q.clone(), p=2, axis=-1), + k=k.clone() if use_qk_l2norm_in_kernel else F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g_ref, + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), h0.grad.clone() + ) + ref_dg = g_raw.grad.clone() + ref_dA_log = A_log.grad.clone() + ref_ddt_bias = dt_bias.grad.clone() if dt_bias is not None else None + + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.008) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.008) + assert_close('db', ref_dbeta, tri_dbeta, 0.02) + assert_close('dg', ref_dg, tri_dg, 0.02) + assert_close('dh0', ref_dh0, tri_dh0, 0.008) + assert_close('dA_log', ref_dA_log, 
tri_dA_log, 0.02) + if dt_bias is not None: + assert_close('ddt_bias', ref_ddt_bias, tri_ddt_bias, 0.02) + + +@pytest.mark.parametrize( + ('B', 'T', 'Hq', 'H', 'D', 'scale', 'has_dt_bias', 'dtype'), + [ + pytest.param( + *test, + id="B{}-T{}-Hq{}-H{}-D{}-scale{}-has_dt_bias{}-{}".format(*test), + ) + for test in [ + (2, 256, 2, 4, 64, 1, True, paddle.float16), + (2, 512, 1, 4, 64, 0.1, False, paddle.float16), + (2, 512, 2, 8, 64, 1, True, paddle.float16), + (2, 1024, 4, 8, 128, 0.1, True, paddle.float16), + ] + ], +) +def test_chunk_gate_in_kernel_gqa( + B: int, + T: int, + Hq: int, + H: int, + D: int, + scale: float, + has_dt_bias: bool, + dtype, +): + """Test use_gate_in_kernel=True with grouped value attention (HV > H).""" + paddle.seed(42) + assert H % Hq == 0 + + q = paddle.rand([B, T, Hq, D], dtype=dtype) + k = paddle.rand([B, T, Hq, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + beta = paddle.rand([B, T, H], dtype=paddle.float32).sigmoid() + g_raw = paddle.randn([B, T, H], dtype=paddle.float32) + A_log = paddle.randn([H], dtype=paddle.float32) + dt_bias = paddle.randn([H], dtype=paddle.float32) if has_dt_bias else None + h0 = paddle.zeros([B, H, D, D], dtype=paddle.float32) + + for t in [q, k, v, beta, g_raw, h0]: + t.stop_gradient = False + A_log.stop_gradient = False + if dt_bias is not None: + dt_bias.stop_gradient = False + + tri, tri_ht = chunk_gated_delta_rule( + q=q.clone(), + k=k.clone(), + v=v.clone(), + g=g_raw.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=True, + use_gate_in_kernel=True, + A_log=A_log.clone(), + dt_bias=dt_bias.clone() if dt_bias is not None else None, + ) + do = paddle.randn(v.shape, dtype=v.dtype) + dht = paddle.randn(h0.shape, dtype=h0.dtype) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dg, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g_raw.grad.clone(), h0.grad.clone() + ) + tri_dA_log = A_log.grad.clone() + tri_ddt_bias = dt_bias.grad.clone() if dt_bias is not None else None + q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + beta.clear_gradient() + g_raw.clear_gradient() + h0.clear_gradient() + A_log.clear_gradient() + if dt_bias is not None: + dt_bias.clear_gradient() + + g_ref = naive_gdn_gate(g_raw, A_log, dt_bias) + ref, ref_ht = chunk_gated_delta_rule( + q=q.clone(), + k=k.clone(), + v=v.clone(), + g=g_ref, + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=True, + ) + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), h0.grad.clone() + ) + ref_dg = g_raw.grad.clone() + ref_dA_log = A_log.grad.clone() + ref_ddt_bias = dt_bias.grad.clone() if dt_bias is not None else None + + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.008) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.008) + assert_close('db', ref_dbeta, tri_dbeta, 0.02) + assert_close('dg', ref_dg, tri_dg, 0.02) + assert_close('dh0', ref_dh0, tri_dh0, 0.008) + assert_close('dA_log', ref_dA_log, tri_dA_log, 0.02) + if dt_bias is not None: + assert_close('ddt_bias', ref_ddt_bias, tri_ddt_bias, 0.02) + + +@pytest.mark.parametrize( + ('H', 'D', 'has_dt_bias', 'cu_seqlens', 
'dtype'), + [ + pytest.param(*test, id="H{}-D{}-has_dt_bias{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (4, 60, True, [0, 15], paddle.float16), + (4, 64, False, [0, 256, 500, 1000], paddle.float16), + (4, 64, True, [0, 256, 500, 1000], paddle.float16), + (4, 100, True, [0, 15, 100, 300, 1200, 2000], paddle.float16), + ] + ], +) +@pytest.mark.skipif( + os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', + reason='Skipping test because SKIP_TEST_CHUNK_VARLEN is set', +) +def test_chunk_gate_in_kernel_varlen( + H: int, + D: int, + has_dt_bias: bool, + cu_seqlens: list, + dtype, +): + """Test use_gate_in_kernel=True with variable-length sequences.""" + paddle.seed(42) + os.environ['TRITON_F32_DEFAULT'] = 'ieee' + + cu_seqlens_t = paddle.to_tensor(cu_seqlens, dtype=paddle.int64) + T = cu_seqlens[-1] + N = len(cu_seqlens) - 1 + + q = paddle.randn([1, T, H, D], dtype=dtype) + k = paddle.randn([1, T, H, D], dtype=dtype) + v = paddle.randn([1, T, H, D], dtype=dtype) + beta = paddle.rand([1, T, H], dtype=paddle.float32).sigmoid() + g_raw = paddle.randn([1, T, H], dtype=paddle.float32) + A_log = paddle.randn([H], dtype=paddle.float32) + dt_bias = paddle.randn([H], dtype=paddle.float32) if has_dt_bias else None + h0 = paddle.randn([N, H, D, D], dtype=paddle.float32) + + for t in [q, k, v, beta, g_raw, h0]: + t.stop_gradient = False + A_log.stop_gradient = False + if dt_bias is not None: + dt_bias.stop_gradient = False + + do = paddle.randn(v.shape, dtype=v.dtype) + dht = paddle.rand(h0.shape, dtype=h0.dtype) + + tri, tri_ht = chunk_gated_delta_rule( + q=q.clone(), + k=k.clone(), + v=v.clone(), + g=g_raw.clone(), + beta=beta.clone(), + initial_state=h0.clone(), + output_final_state=True, + cu_seqlens=cu_seqlens_t, + use_qk_l2norm_in_kernel=True, + use_gate_in_kernel=True, + A_log=A_log.clone(), + dt_bias=dt_bias.clone() if dt_bias is not None else None, + ) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dg, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g_raw.grad.clone(), h0.grad.clone() + ) + tri_dA_log = A_log.grad.clone() + tri_ddt_bias = dt_bias.grad.clone() if dt_bias is not None else None + q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + beta.clear_gradient() + g_raw.clear_gradient() + h0.clear_gradient() + A_log.clear_gradient() + if dt_bias is not None: + dt_bias.clear_gradient() + + g_ref = naive_gdn_gate(g_raw, A_log, dt_bias) + ref, ref_ht = chunk_gated_delta_rule( + q=q.clone(), + k=k.clone(), + v=v.clone(), + g=g_ref, + beta=beta.clone(), + initial_state=h0.clone(), + output_final_state=True, + cu_seqlens=cu_seqlens_t, + use_qk_l2norm_in_kernel=True, + ) + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), h0.grad.clone() + ) + ref_dg = g_raw.grad.clone() + ref_dA_log = A_log.grad.clone() + ref_ddt_bias = dt_bias.grad.clone() if dt_bias is not None else None + + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.008) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.008) + assert_close('db', ref_dbeta, tri_dbeta, 0.02) + assert_close('dg', ref_dg, tri_dg, 0.02) + assert_close('dh0', ref_dh0, tri_dh0, 0.008) + assert_close('dA_log', ref_dA_log, tri_dA_log, 0.02) + if dt_bias is not None: + assert_close('ddt_bias', ref_ddt_bias, 
tri_ddt_bias, 0.02) + + +@pytest.mark.parametrize( + ('B', 'T', 'HV', 'HAS_BIAS'), + [ + pytest.param(*test, id="B{}-T{}-HV{}-bias{}".format(*test)) + for test in [ + (1, 32, 2, False), + (2, 64, 4, True), + (4, 128, 8, True), + (4, 128, 16, False), + ] + ], +) +def test_gate( + B: int, + T: int, + HV: int, + HAS_BIAS: bool, +): + paddle.seed(42) + g = paddle.randn([B, T, HV], dtype=paddle.float32) + A_log = paddle.log(paddle.uniform([HV], dtype=paddle.float32, min=1, max=16)) + dt_bias = paddle.randn([HV], dtype=paddle.float32) if HAS_BIAS else None + g.stop_gradient = False + A_log.stop_gradient = False + if dt_bias is not None: + dt_bias.stop_gradient = False + do = paddle.randn([B, T, HV], dtype=paddle.float32) + + ref = naive_gdn_gate( + g.clone(), A_log.clone(), dt_bias.clone() if dt_bias is not None else None, + ) + tri = fused_gdn_gate( + g.clone(), A_log.clone(), dt_bias.clone() if dt_bias is not None else None, + ) + (ref * do).sum().backward(retain_graph=True) + + ref_dg = g.grad.clone() + ref_dA = A_log.grad.clone() + ref_dbias = dt_bias.grad.clone() if dt_bias is not None else None + g.clear_gradient() + A_log.clear_gradient() + if dt_bias is not None: + dt_bias.clear_gradient() + + (tri * do).sum().backward(retain_graph=True) + tri_dg = g.grad.clone() + tri_dA = A_log.grad.clone() + tri_dbias = dt_bias.grad.clone() if dt_bias is not None else None + + assert_close("o", ref, tri, 1e-4) + assert_close("dg", ref_dg, tri_dg, 1e-4) + assert_close("dA", ref_dA, tri_dA, 1e-4) + if HAS_BIAS: + assert_close("dbias", ref_dbias, tri_dbias, 1e-4) diff --git a/flashmask/tests/linear_attn/test_kda.py b/flashmask/tests/linear_attn/test_kda.py new file mode 100644 index 00000000000..4cedf015aca --- /dev/null +++ b/flashmask/tests/linear_attn/test_kda.py @@ -0,0 +1,909 @@ +# -*- coding: utf-8 -*- +# Tests for KDA (Kimi Delta Attention) operators on PaddlePaddle +# Aligned with tests/ops/test_kda.py + +import paddle +import paddle.nn.functional as F +import pytest + +from linear_attn.ops.kda import chunk_kda, fused_recurrent_kda +from linear_attn.ops.kda.fused_recurrent import fused_recurrent_kda_fwd +from linear_attn.ops.kda.gate import fused_kda_gate, naive_kda_gate, naive_kda_lowerbound_gate +from linear_attn.ops.kda.naive import naive_chunk_kda, naive_recurrent_kda + +from .conftest import assert_close + + +@pytest.mark.parametrize( + ("B", "T", "H", "D", "scale", "gate_logit_normalizer", "dtype"), + [ + pytest.param( + *test, + id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-{}".format(*test), + ) + for test in [ + (1, 64, 1, 64, 1, 1, paddle.float32), + (2, 512, 3, 60, 1, 1, paddle.float32), + (4, 1024, 4, 128, 0.1, 1, paddle.float32), + (4, 1024, 4, 128, 1, 10, paddle.float32), + ] + ], +) +def test_naive_chunk( + B: int, + T: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + dtype, +): + paddle.seed(42) + q = paddle.rand([B, T, H, D], dtype=dtype) + k = paddle.rand([B, T, H, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + g = F.log_sigmoid(paddle.randn([B, T, H, D], dtype=paddle.float32)) / gate_logit_normalizer + beta = paddle.randn([B, T, H], dtype=dtype).sigmoid() + h0 = paddle.randn([B, H, D, D], dtype=paddle.float32) + + ref, ref_ht = naive_recurrent_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + ) + + tri, tri_ht = naive_chunk_kda( + q=F.normalize(q.clone(), p=2, axis=-1), 
+ k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + ) + assert_close("o", ref, tri, 0.005) + assert_close("ht", ref_ht, tri_ht, 0.005) + + +@pytest.mark.parametrize( + ("B", "T", "H", "D", "scale", "gate_logit_normalizer", "use_qk_l2norm_in_kernel", "dtype"), + [ + pytest.param( + *test, + id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-use_qk_l2norm_in_kernel{}-{}".format(*test), + ) + for test in [ + (1, 64, 1, 64, 1, 1, False, paddle.float32), + (2, 512, 3, 60, 1, 1, False, paddle.float32), + (3, 1000, 4, 100, 0.1, 1, True, paddle.float32), + (4, 1024, 4, 128, 0.1, 1, False, paddle.float32), + ] + ], +) +def test_fused_recurrent( + B: int, + T: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + use_qk_l2norm_in_kernel: bool, + dtype, +): + paddle.seed(42) + q = paddle.rand([B, T, H, D], dtype=dtype) + k = paddle.rand([B, T, H, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + g = F.log_sigmoid(paddle.randn([B, T, H, D], dtype=paddle.float32)) / gate_logit_normalizer + beta = paddle.randn([B, T, H], dtype=dtype).sigmoid() + h0 = paddle.randn([B, H, D, D], dtype=paddle.float32) + + ref, ref_ht = naive_recurrent_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + ) + + tri, tri_ht = fused_recurrent_kda( + q=F.normalize(q.clone(), p=2, axis=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, axis=-1) if not use_qk_l2norm_in_kernel else k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + assert_close("o", ref, tri, 0.005) + assert_close("ht", ref_ht, tri_ht, 0.005) + + +@pytest.mark.parametrize( + ("B", "T", "H", "D", "scale", "gate_logit_normalizer", "dtype"), + [ + pytest.param( + *test, + id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-{}".format(*test), + ) + for test in [ + (1, 64, 1, 64, 1, 1, paddle.float32), + (2, 512, 3, 60, 1, 1, paddle.float32), + (3, 1000, 4, 100, 0.1, 1, paddle.float32), + (4, 1024, 4, 128, 0.1, 1, paddle.float32), + ] + ], +) +def test_fused_recurrent_transpose_state( + B: int, + T: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + dtype, +): + paddle.seed(42) + q = paddle.rand([B, T, H, D], dtype=dtype) + k = paddle.rand([B, T, H, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + g = F.log_sigmoid(paddle.randn([B, T, H, D], dtype=paddle.float32)) / gate_logit_normalizer + beta = paddle.randn([B, T, H], dtype=dtype).sigmoid() + h0_kv = paddle.randn([B, H, D, D], dtype=paddle.float32) + h0_vk = h0_kv.transpose([0, 1, 3, 2]).contiguous() + + ref, ref_ht = fused_recurrent_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0_kv.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=False, + transpose_state_layout=False, + ) + tri, tri_ht = fused_recurrent_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0_vk.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=False, + 
transpose_state_layout=True, + ) + assert_close("o", ref, tri, 1e-4) + assert_close("ht", ref_ht, tri_ht.transpose([0, 1, 3, 2]), 1e-4) + + +@pytest.mark.parametrize( + ("B", "H", "D", "scale", "gate_logit_normalizer", "use_qk_l2norm_in_kernel", "use_gate_in_kernel", "safe_gate", "dtype"), + [ + pytest.param( + *test, + id="B{}-H{}-D{}-scale{}-norm{}-qk_l2{}-gate{}-safe_gate{}-dtype{}".format(*test), + ) + for test in [ + (16, 16, 128, 0.1, 1.0, True, False, False, paddle.bfloat16), + (32, 8, 64, 1.0, 1.0, False, False, False, paddle.float16), + (16, 16, 128, 0.1, 1.0, True, True, False, paddle.bfloat16), + (32, 8, 64, 1.0, 1.0, False, True, False, paddle.float16), + (7, 32, 128, 0.5, 0.5, True, True, True, paddle.bfloat16), + ] + ], +) +def test_fused_recurrent_vllm_decode( + B: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + use_qk_l2norm_in_kernel: bool, + use_gate_in_kernel: bool, + safe_gate: bool, + dtype, +): + """Test vLLM-style decoding with continuous batching and paged state storage.""" + paddle.seed(42) + + # Setup cache pool and inputs + max_cache_slots = B * 3 + state_pool = paddle.randn([max_cache_slots, H, D, D], dtype=paddle.float32) + state_indices = paddle.randperm(max_cache_slots)[:B].cast(paddle.int32) + + # Fill unaccessed slots with a huge value to detect out-of-bound access + HUGE_VALUE = 1e30 + mask = paddle.ones([max_cache_slots], dtype='bool') + mask[state_indices.cast(paddle.int64)] = False + state_pool[mask] = HUGE_VALUE + + T = 1 + total_tokens = B * T + + q = paddle.rand([1, total_tokens, H, D], dtype=dtype) + k = paddle.rand([1, total_tokens, H, D], dtype=dtype) + v = paddle.rand([1, total_tokens, H, D], dtype=dtype) + g = paddle.randn([1, total_tokens, H, D], dtype=paddle.float32 if not use_gate_in_kernel else dtype) + + if use_gate_in_kernel: + A_log = paddle.log(paddle.uniform([1, 1, H, 1], dtype=paddle.float32, min=1, max=16)).squeeze() + dt_bias = paddle.randn([H * D], dtype=paddle.float32) + lower_bound = -5.0 if safe_gate else None + naive_kda_gate_fn = naive_kda_lowerbound_gate if safe_gate else naive_kda_gate + else: + g = F.log_sigmoid(g) / gate_logit_normalizer + A_log = None + dt_bias = None + lower_bound = None + naive_kda_gate_fn = None + + beta = paddle.randn([1, total_tokens, H], dtype=dtype).sigmoid() + + cu_seqlens = paddle.arange(0, total_tokens + 1, step=T, dtype=paddle.int32) + ref_state_pool = state_pool.clone() + tri_state_pool = state_pool.clone() + + # Reference implementation (loop over batch) + ref_outputs = [] + for i in range(B): + start, end = i, i + 1 + slot_idx = state_indices[i].item() + + q_i = q[:, start:end].clone() + k_i = k[:, start:end].clone() + v_i = v[:, start:end].clone() + g_i = g[:, start:end].clone() + beta_i = beta[:, start:end].clone() + + h_init = ref_state_pool[slot_idx].clone().unsqueeze(0) + ref_o_i, ref_ht_i = naive_recurrent_kda( + q=F.normalize(q_i, p=2, axis=-1), + k=F.normalize(k_i, p=2, axis=-1), + v=v_i, + g=(naive_kda_gate_fn(g_i, A_log, dt_bias) if use_gate_in_kernel else g_i), + beta=beta_i, + scale=scale, + initial_state=h_init, + output_final_state=True + ) + ref_outputs.append(ref_o_i) + ref_state_pool[slot_idx] = ref_ht_i.squeeze(0) + + ref_out = paddle.concat(ref_outputs, axis=1) + + # Triton kernel + q_in = q.clone() + k_in = k.clone() + if not use_qk_l2norm_in_kernel: + q_in = F.normalize(q_in, p=2, axis=-1) + k_in = F.normalize(k_in, p=2, axis=-1) + + tri_out, _ = fused_recurrent_kda_fwd( + q=q_in, + k=k_in, + v=v, + g=g, + beta=beta, + A_log=A_log, + 
dt_bias=dt_bias, + initial_state=tri_state_pool, + scale=scale, + output_final_state=False, + inplace_final_state=True, + cu_seqlens=cu_seqlens, + ssm_state_indices=state_indices, + num_accepted_tokens=None, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + use_gate_in_kernel=use_gate_in_kernel, + lower_bound=lower_bound, + ) + + # Verify results + assert_close("o", ref_out, tri_out, 0.005) + assert_close("ht", ref_state_pool[state_indices.cast(paddle.int64)], + tri_state_pool[state_indices.cast(paddle.int64)], 0.005) + + mask = paddle.ones([max_cache_slots], dtype='bool') + mask[state_indices.cast(paddle.int64)] = False + assert_close("Untouched ht", ref_state_pool[mask], tri_state_pool[mask], 0.0) + + +@pytest.mark.parametrize( + ( + "B", "T", "H", "D", "scale", "gate_logit_normalizer", + "mask_p", "use_qk_l2norm_in_kernel", "use_gate_in_kernel", + "dtype", "safe_gate", "disable_recompute", + ), + [ + pytest.param( + *test, + id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-mask_p{}-qk_l2norm{}-gate{}-dtype{}-safe_gate{}-disable_recompute{}".format( + *test), + ) + for test in [ + (1, 63, 1, 64, 1, 1, 0, False, False, paddle.float16, True, False), + (2, 500, 3, 60, 1, 1, 0, False, False, paddle.float16, True, True), + (2, 1000, 3, 64, 0.1, 1, 0.5, False, False, paddle.float16, False, True), + (3, 1024, 4, 100, 1, 0.1, 0, False, False, paddle.float16, False, False), + (4, 1024, 4, 128, 0.1, 1, 0, False, False, paddle.float16, True, True), + (4, 1024, 4, 128, 0.1, 1, 0, True, False, paddle.float16, True, False), + (2, 1500, 4, 128, 0.1, 10, 0, False, True, paddle.float16, False, True), + (4, 2048, 8, 64, 0.1, 1, 0, False, True, paddle.float16, True, True), + ] + ], +) +def test_chunk( + B: int, + T: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + mask_p: float, + use_qk_l2norm_in_kernel: bool, + use_gate_in_kernel: bool, + dtype, + safe_gate: bool, + disable_recompute: bool, +): + paddle.seed(42) + q = paddle.rand([B, T, H, D], dtype=dtype) + k = paddle.rand([B, T, H, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + g = paddle.randn([B, T, H, D], dtype=paddle.float32 if not use_gate_in_kernel else dtype) + if use_gate_in_kernel: + A_log = paddle.randn([H], dtype=paddle.float32) + dt_bias = paddle.randn([H * D], dtype=paddle.float32) + else: + g = F.log_sigmoid(g) / gate_logit_normalizer + g = g * (paddle.rand(g.shape, dtype=g.dtype) > mask_p).cast(g.dtype) + if safe_gate: + lower_bound = -5.0 + if not use_gate_in_kernel: + g = g.clip(-5, 0) + naive_kda_gate_fn = naive_kda_lowerbound_gate + else: + lower_bound = None + naive_kda_gate_fn = naive_kda_gate + + beta = paddle.randn([B, T, H], dtype=dtype).sigmoid() + h0 = paddle.randn([B, H, D, D], dtype=paddle.float32) + + if use_gate_in_kernel: + A_log.stop_gradient = False + dt_bias.stop_gradient = False + for t in [q, k, v, g, beta, h0]: + t.stop_gradient = False + + do = paddle.randn(v.shape, dtype=v.dtype) + dht = paddle.randn(h0.shape, dtype=h0.dtype) + + ref, ref_ht = naive_recurrent_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=(naive_kda_gate_fn(g, A_log, dt_bias) if use_gate_in_kernel else g.clone()), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + ) + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + if use_gate_in_kernel: + ref_dA = A_log.grad.clone() + A_log.clear_gradient() + ref_dbias = dt_bias.grad.clone() + dt_bias.clear_gradient() + ref_dq, ref_dk, 
ref_dv, ref_dg, ref_db, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + g.grad.clone(), beta.grad.clone(), h0.grad.clone() + ) + q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + g.clear_gradient() + beta.clear_gradient() + h0.clear_gradient() + + tri, tri_ht = chunk_kda( + q=F.normalize(q.clone(), p=2, axis=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, axis=-1) if not use_qk_l2norm_in_kernel else k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + A_log=(A_log.clone() if use_gate_in_kernel else None), + dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + use_gate_in_kernel=use_gate_in_kernel, + safe_gate=safe_gate, + lower_bound=lower_bound, + disable_recompute=disable_recompute, + ) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + if use_gate_in_kernel: + tri_dA = A_log.grad.clone() + A_log.clear_gradient() + tri_dbias = dt_bias.grad.clone() + dt_bias.clear_gradient() + tri_dq, tri_dk, tri_dv, tri_dg, tri_db, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + g.grad.clone(), beta.grad.clone(), h0.grad.clone() + ) + + assert_close("o", ref, tri, 0.005) + assert_close("ht", ref_ht, tri_ht, 0.005) + assert_close("dq", ref_dq, tri_dq, 0.008) + assert_close("dk", ref_dk, tri_dk, 0.008) + assert_close("dv", ref_dv, tri_dv, 0.008) + assert_close("dg", ref_dg, tri_dg, 0.02) + assert_close("db", ref_db, tri_db, 0.02) + if use_gate_in_kernel: + assert_close("dA", ref_dA, tri_dA, 0.003, warning=True) + assert_close("dbias", ref_dbias, tri_dbias, 0.008) + assert_close("dh0", ref_dh0, tri_dh0, 0.008) + + +@pytest.mark.parametrize( + ("B", "T", "H", "D", "scale", "gate_logit_normalizer", "dtype"), + [ + pytest.param( + *test, + id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-{}".format(*test), + ) + for test in [ + (1, 63, 1, 64, 1, 1, paddle.float16), + (2, 500, 3, 60, 1, 1, paddle.float16), + (3, 1024, 4, 128, 0.1, 1, paddle.float16), + (4, 2048, 8, 64, 0.1, 1, paddle.float16), + ] + ], +) +def test_chunk_transpose_state( + B: int, + T: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + dtype, +): + paddle.seed(42) + q = paddle.rand([B, T, H, D], dtype=dtype) + k = paddle.rand([B, T, H, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + g = F.log_sigmoid(paddle.randn([B, T, H, D], dtype=paddle.float32)) / gate_logit_normalizer + beta = paddle.randn([B, T, H], dtype=dtype).sigmoid() + h0_kv = paddle.randn([B, H, D, D], dtype=paddle.float32) + h0_vk = h0_kv.transpose([0, 1, 3, 2]).contiguous() + + for t in [q, k, v, g, beta, h0_kv, h0_vk]: + t.stop_gradient = False + + do = paddle.randn(v.shape, dtype=v.dtype) + dht_vk = paddle.randn([B, H, D, D], dtype=paddle.float32) + dht_kv = dht_vk.transpose([0, 1, 3, 2]).contiguous() + + tri, tri_ht = chunk_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0_vk.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=False, + transpose_state_layout=True, + ) + ((tri * do).sum() + (tri_ht * dht_vk).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dg, tri_db, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + g.grad.clone(), beta.grad.clone(), h0_vk.grad.clone() + ) + q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + 
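+    # Note: every reused leaf tensor is cleared here because the reference run
+    # below (transpose_state_layout=False, fed with h0_kv) shares the same
+    # q/k/v/g/beta leaves; without clearing, its gradients would accumulate on
+    # top of the transpose_state_layout=True gradients captured above.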
g.clear_gradient() + beta.clear_gradient() + h0_vk.clear_gradient() + + ref, ref_ht = chunk_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0_kv.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=False, + transpose_state_layout=False, + ) + ((ref * do).sum() + (ref_ht * dht_kv).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dg, ref_db, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + g.grad.clone(), beta.grad.clone(), h0_kv.grad.clone() + ) + + assert_close("o", ref, tri, 1e-4) + assert_close("ht", ref_ht, tri_ht.transpose([0, 1, 3, 2]), 1e-4) + assert_close("dq", ref_dq, tri_dq, 1e-4) + assert_close("dk", ref_dk, tri_dk, 1e-4) + assert_close("dv", ref_dv, tri_dv, 1e-4) + assert_close("dg", ref_dg, tri_dg, 1e-4) + assert_close("db", ref_db, tri_db, 1e-4) + assert_close("dh0", ref_dh0, tri_dh0.transpose([0, 1, 3, 2]), 1e-4) + + +@pytest.mark.parametrize( + ("H", "D", "mask_p", "cu_seqlens", "dtype", "use_gate_in_kernel", "safe_gate", "disable_recompute"), + [ + pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}-gate{}-safe_gate{}-disable_recompute{}".format(*test)) + for test in [ + (4, 60, 0.1, [0, 15], paddle.float16, True, False, False), + (4, 64, 0.9, [0, 256, 500, 1000], paddle.float16, True, False, False), + (4, 128, 0.5, [0, 256, 500, 1000], paddle.float16, False, False, False), + (4, 100, 0, [0, 15, 100, 300, 1200, 2000], paddle.float16, True, False, False), + (4, 256, 0, [0, 100, 300, 1200, 3000, 4096], paddle.float16, False, True, True), + ] + ], +) +def test_chunk_varlen( + H: int, + D: int, + mask_p: float, + cu_seqlens: list, + dtype, + use_gate_in_kernel: bool, + safe_gate: bool, + disable_recompute: bool, +): + paddle.seed(42) + cu_seqlens_t = paddle.to_tensor(cu_seqlens, dtype=paddle.int64) + cu_seqlens_cpu = paddle.to_tensor(cu_seqlens, dtype=paddle.int64) + T = cu_seqlens[-1] + N = len(cu_seqlens) - 1 + + q = paddle.randn([1, T, H, D], dtype=dtype) + k = F.normalize(paddle.randn([1, T, H, D], dtype=paddle.float32), p=2, axis=-1).cast(dtype) + v = paddle.randn([1, T, H, D], dtype=dtype) + g = paddle.randn([1, T, H, D], dtype=paddle.float32 if not use_gate_in_kernel else dtype) + if use_gate_in_kernel: + A_log = paddle.log(paddle.uniform([1, 1, H, 1], dtype=paddle.float32, min=1, max=16)) + dt_bias = paddle.randn([H * D], dtype=paddle.float32) + else: + g = F.log_sigmoid(g) + g = g * (paddle.rand(g.shape, dtype=g.dtype) > mask_p).cast(g.dtype) + mask = (paddle.rand(g.shape, dtype=g.dtype) > mask_p).cast(g.dtype) + g = g * mask + (1 - mask) * (-1000) + if safe_gate: + assert use_gate_in_kernel is False + g = g.clip(-5, 0) + + beta = paddle.rand([1, T, H], dtype=dtype).sigmoid() + h0 = paddle.randn([N, H, D, D], dtype=paddle.float32) + + for t in [q, k, v, g, beta, h0]: + t.stop_gradient = False + if use_gate_in_kernel: + A_log.stop_gradient = False + dt_bias.stop_gradient = False + do = paddle.randn(v.shape, dtype=v.dtype) + dht = paddle.rand(h0.shape, dtype=h0.dtype) + + tri, tri_ht = chunk_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + A_log=(A_log.clone() if use_gate_in_kernel else None), + dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), + initial_state=h0.clone(), + output_final_state=True, + cu_seqlens=cu_seqlens_t, + cu_seqlens_cpu=cu_seqlens_cpu, + use_gate_in_kernel=use_gate_in_kernel, + safe_gate=safe_gate, + 
disable_recompute=disable_recompute, + ) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dg, tri_db, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + g.grad.clone(), beta.grad.clone(), h0.grad.clone() + ) + q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + g.clear_gradient() + beta.clear_gradient() + h0.clear_gradient() + if use_gate_in_kernel: + tri_dA = A_log.grad.clone() + A_log.clear_gradient() + tri_dbias = dt_bias.grad.clone() + dt_bias.clear_gradient() + + ref_list = [] + ref_ht_list = [] + for i in range(N): + s, e = cu_seqlens[i], cu_seqlens[i + 1] + ref_i, ref_ht_i = naive_recurrent_kda( + q=F.normalize(q[:, s:e], p=2, axis=-1), + k=k[:, s:e], + v=v[:, s:e], + beta=beta[:, s:e], + g=(naive_kda_gate(g[:, s:e].cast(paddle.float32), A_log.cast(paddle.float32), + dt_bias.cast(paddle.float32)) if use_gate_in_kernel else g[:, s:e]), + initial_state=h0[i], + output_final_state=True, + ) + ref_list.append(ref_i) + ref_ht_list.append(ref_ht_i) + ref = paddle.concat(ref_list, axis=1) + ref_ht = paddle.concat(ref_ht_list, axis=0) + + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dg, ref_db, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + g.grad.clone(), beta.grad.clone(), h0.grad.clone() + ) + if use_gate_in_kernel: + ref_dA = A_log.grad.clone() + ref_dbias = dt_bias.grad.clone() + + assert_close("o", ref, tri, 0.005) + assert_close("ht", ref_ht, tri_ht, 0.005) + assert_close("dq", ref_dq, tri_dq, 0.007) + assert_close("dk", ref_dk, tri_dk, 0.008) + assert_close("dv", ref_dv, tri_dv, 0.007) + assert_close("dg", ref_dg, tri_dg, 0.015) + assert_close("db", ref_db, tri_db, 0.015) + assert_close("dh0", ref_dh0, tri_dh0, 0.007) + if use_gate_in_kernel: + assert_close("dA", ref_dA, tri_dA, 0.008, warning=True) + assert_close("dbias", ref_dbias, tri_dbias, 0.005) + + +@pytest.mark.parametrize( + ("H", "D", "mask_p", "cu_seqlens", "dtype", "use_gate_in_kernel", "safe_gate", "disable_recompute"), + [ + pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}-gate{}-safe_gate{}-disable_recompute{}".format(*test)) + for test in [ + (4, 60, 0.1, [0, 8192], paddle.float16, True, False, False), + (4, 64, 0.9, [0, 256, 500, 1000], paddle.float16, True, False, False), + (4, 128, 0.5, [0, 256, 500, 1000], paddle.float16, False, False, False), + (4, 100, 0, [0, 15, 100, 300, 1200, 2000], paddle.float16, True, False, False), + (4, 256, 0, [0, 100, 300, 1200, 3000, 4096], paddle.float16, False, True, True), + ] + ], +) +def test_chunk_varlen_prefill( + H: int, + D: int, + mask_p: float, + cu_seqlens: list, + dtype, + use_gate_in_kernel: bool, + safe_gate: bool, + disable_recompute: bool, +): + paddle.seed(42) + with paddle.no_grad(): + cu_seqlens_t = paddle.to_tensor(cu_seqlens, dtype=paddle.int64) + cu_seqlens_cpu = paddle.to_tensor(cu_seqlens, dtype=paddle.int64) + T = cu_seqlens[-1] + N = len(cu_seqlens) - 1 + + q = paddle.randn([1, T, H, D], dtype=dtype) + k = F.normalize(paddle.randn([1, T, H, D], dtype=paddle.float32), p=2, axis=-1).cast(dtype) + v = paddle.randn([1, T, H, D], dtype=dtype) + g = paddle.randn([1, T, H, D], dtype=paddle.float32 if not use_gate_in_kernel else dtype) + if use_gate_in_kernel: + A_log = paddle.log(paddle.uniform([1, 1, H, 1], dtype=paddle.float32, min=1, max=16)) + dt_bias = paddle.randn([H * D], dtype=paddle.float32) + else: + g = F.log_sigmoid(g) + g = g * (paddle.rand(g.shape, dtype=g.dtype) > mask_p).cast(g.dtype) + mask = 
(paddle.rand(g.shape, dtype=g.dtype) > mask_p).cast(g.dtype) + g = g * mask + (1 - mask) * (-1000) + if safe_gate: + assert use_gate_in_kernel is False + g = g.clip(-5, 0) + + beta = paddle.rand([1, T, H], dtype=dtype).sigmoid() + h0 = paddle.randn([N, H, D, D], dtype=paddle.float32) + + tri, tri_ht = chunk_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + A_log=(A_log.clone() if use_gate_in_kernel else None), + dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), + initial_state=h0.clone(), + output_final_state=True, + cu_seqlens=cu_seqlens_t, + cu_seqlens_cpu=cu_seqlens_cpu, + use_gate_in_kernel=use_gate_in_kernel, + safe_gate=safe_gate, + disable_recompute=disable_recompute, + ) + + ref_list = [] + ref_ht_list = [] + for i in range(N): + s, e = cu_seqlens[i], cu_seqlens[i + 1] + ref_i, ref_ht_i = naive_recurrent_kda( + q=F.normalize(q[:, s:e], p=2, axis=-1), + k=k[:, s:e], + v=v[:, s:e], + beta=beta[:, s:e], + g=(naive_kda_gate(g[:, s:e].cast(paddle.float32), A_log.cast(paddle.float32), + dt_bias.cast(paddle.float32)) if use_gate_in_kernel else g[:, s:e]), + initial_state=h0[i], + output_final_state=True, + ) + ref_list.append(ref_i) + ref_ht_list.append(ref_ht_i) + ref = paddle.concat(ref_list, axis=1) + ref_ht = paddle.concat(ref_ht_list, axis=0) + + assert_close("o", ref, tri, 0.005) + assert_close("ht", ref_ht, tri_ht, 0.005) + + +@pytest.mark.parametrize( + ("B", "T", "H", "D", "HAS_BIAS", "LOWER_BOUND"), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-bias{}-lowerbound{}".format(*test)) + for test in [ + (1, 2, 2, 12, False, -5.0), + (1, 32, 2, 16, False, -5.0), + (2, 64, 4, 32, False, -5.0), + (4, 128, 8, 64, False, -5.0), + (4, 128, 8, 128, False, None), + (1, 2, 2, 12, True, None), + (1, 32, 2, 16, True, None), + (2, 64, 4, 32, True, None), + (4, 128, 8, 64, True, None), + (4, 128, 8, 128, True, None), + ] + ], +) +def test_gate( + B: int, + T: int, + H: int, + D: int, + HAS_BIAS: bool, + LOWER_BOUND, +): + paddle.seed(42) + g = paddle.randn([B, T, H, D], dtype=paddle.float32) * 10 + A_log = paddle.log(paddle.uniform([1, 1, H, 1], dtype=paddle.float32, min=1, max=16)) + dt_bias = paddle.randn([H * D], dtype=paddle.float32) if HAS_BIAS else None + g.stop_gradient = False + A_log.stop_gradient = False + if dt_bias is not None: + dt_bias.stop_gradient = False + do = paddle.randn([B, T, H, D], dtype=paddle.float32) + + if LOWER_BOUND is not None: + ref = naive_kda_lowerbound_gate( + g.clone(), A_log.clone(), dt_bias.clone() if dt_bias is not None else None, LOWER_BOUND + ) + else: + ref = naive_kda_gate( + g.clone(), A_log.clone(), dt_bias.clone() if dt_bias is not None else None, + ) + tri = fused_kda_gate( + g.clone(), A_log.clone(), dt_bias.clone() if dt_bias is not None else None, + lower_bound=LOWER_BOUND + ) + (ref * do).sum().backward(retain_graph=True) + + ref_dg = g.grad.clone() + ref_dA = A_log.grad.clone() + ref_dbias = dt_bias.grad.clone() if dt_bias is not None else None + g.clear_gradient() + A_log.clear_gradient() + if dt_bias is not None: + dt_bias.clear_gradient() + + ((tri * do).sum()).backward(retain_graph=True) + tri_dg = g.grad.clone() + tri_dA = A_log.grad.clone() + tri_dbias = dt_bias.grad.clone() if dt_bias is not None else None + + assert_close("o", ref, tri, 1e-4) + assert_close("dg", ref_dg, tri_dg, 1e-4) + assert_close("dA", ref_dA, tri_dA, 1e-4) + if HAS_BIAS: + assert_close("dbias", ref_dbias, tri_dbias, 1e-4) + + +@pytest.mark.parametrize("dtype", [paddle.bfloat16]) +def 
test_chunk_return_intermediate_states(dtype): + """Test that return_intermediate_states=True works in inference mode and returns h with correct shape.""" + paddle.seed(42) + B, T, H, D = 2, 1024, 4, 128 + chunk_size = 64 + + with paddle.no_grad(): + q = paddle.randn([B, T, H, D], dtype=dtype) + k = paddle.randn([B, T, H, D], dtype=dtype) + v = paddle.randn([B, T, H, D], dtype=dtype) + g = paddle.randn([B, T, H, D], dtype=dtype) + beta = paddle.rand([B, T, H], dtype=dtype) + + # Test equal-length sequences + o, final_state, h = chunk_kda( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=None, + output_final_state=True, + return_intermediate_states=True, + disable_recompute=False, + ) + + # Verify shapes + assert list(o.shape) == [B, T, H, D], f"Output shape mismatch: {o.shape}" + assert list(final_state.shape) == [B, H, D, D], f"Final state shape mismatch: {final_state.shape}" + + expected_nt = (T + chunk_size - 1) // chunk_size + assert list(h.shape) == [B, expected_nt, H, D, D], f"h shape mismatch: {h.shape}" + assert h.dtype == dtype, f"h dtype should be {dtype}, got: {h.dtype}" + + # Test variable-length sequences + total_tokens = 1024 + N = 2 + seq_len = total_tokens // N + cu_seqlens = paddle.to_tensor([0, seq_len, total_tokens], dtype=paddle.int64) + + q_varlen = paddle.randn([1, total_tokens, H, D], dtype=dtype) + k_varlen = paddle.randn([1, total_tokens, H, D], dtype=dtype) + v_varlen = paddle.randn([1, total_tokens, H, D], dtype=dtype) + g_varlen = paddle.randn([1, total_tokens, H, D], dtype=dtype) + beta_varlen = paddle.rand([1, total_tokens, H], dtype=dtype) + + o_varlen, final_state_varlen, h_varlen = chunk_kda( + q=q_varlen, + k=k_varlen, + v=v_varlen, + g=g_varlen, + beta=beta_varlen, + initial_state=None, + output_final_state=True, + cu_seqlens=cu_seqlens, + return_intermediate_states=True, + disable_recompute=False, + ) + + assert list(o_varlen.shape) == [1, total_tokens, H, D], f"Varlen output shape mismatch: {o_varlen.shape}" + assert list(final_state_varlen.shape) == [N, H, D, D], f"Varlen final state shape mismatch: {final_state_varlen.shape}" + assert h_varlen.shape[0] == 1, f"Varlen h batch dim should be 1, got: {h_varlen.shape[0]}" + assert list(h_varlen.shape[2:]) == [H, D, D], f"Varlen h dims mismatch: {h_varlen.shape[2:]}" + assert h_varlen.dtype == dtype, f"Varlen h dtype should be {dtype}, got: {h_varlen.dtype}"
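
Note: the test above only checks shapes for return_intermediate_states=True. As a minimal sketch of how the per-chunk states might be consumed downstream (assuming chunk_kda is importable from linear_attn.ops.kda, that the kernel's internal chunk size is 64 as the test assumes, and that h[:, i] holds the recurrent state at the boundary of chunk i — these are assumptions, not confirmed by this diff), one could resume the recurrence from a chunk boundary:

    import paddle
    import paddle.nn.functional as F
    from linear_attn.ops.kda import chunk_kda  # assumed import path

    B, T, H, D, chunk_size = 2, 1024, 4, 128, 64
    dtype = paddle.bfloat16
    q = paddle.randn([B, T, H, D], dtype=dtype)
    k = paddle.randn([B, T, H, D], dtype=dtype)
    v = paddle.randn([B, T, H, D], dtype=dtype)
    # Negative log-space gates, as in the tests above.
    g = F.log_sigmoid(paddle.randn([B, T, H, D], dtype=paddle.float32)).cast(dtype)
    beta = paddle.rand([B, T, H], dtype=dtype)

    with paddle.no_grad():
        # Full pass, keeping the state at every chunk boundary.
        o, final_state, h = chunk_kda(
            q=q, k=k, v=v, g=g, beta=beta,
            output_final_state=True,
            return_intermediate_states=True,
        )
        # Resume from chunk boundary i (index chosen arbitrarily) by feeding
        # h[:, i] as the initial state for the remaining tokens; the suffix
        # output should roughly match o[:, s:] if the assumptions hold.
        i = 7
        s = (i + 1) * chunk_size
        o_tail, _ = chunk_kda(
            q=q[:, s:], k=k[:, s:], v=v[:, s:], g=g[:, s:], beta=beta[:, s:],
            initial_state=h[:, i].cast(paddle.float32),
            output_final_state=True,
        )

This mirrors how a prefill/decode split would hand a cached state to a later call, which is the forward-only scenario the varlen prefill test above exercises.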