Skip to content

Commit b3dd131

Browse files
ZLkanyo009 authored and zovonoir committed
[BugFix] enable deepseek r1 fp4
1 parent 108a70e commit b3dd131

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

atom/plugin/sglang/attention_backend/radix_attention.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ def __init__(
8888
torch.tensor([1.0], dtype=torch.float32, device="cuda"),
8989
requires_grad=False,
9090
)
91+
if self.attn.k_scale_float is None:
92+
self.attn.k_scale_float = 1.0
9193
if self.attn.v_scale is None:
9294
self.attn.v_scale = torch.nn.Parameter(
9395
torch.tensor([1.0], dtype=torch.float32, device="cuda"),

atom/plugin/sglang/attention_backend/sgl_attention_mla.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -909,10 +909,14 @@ def _split_and_assign_kc_vc(
909909
w_vc = w_vc.contiguous()
910910
attn.w_vc = bind_or_assign(attn.w_vc, w_vc)
911911

912-
if hasattr(attn.kv_b_proj, "weight_scale") and attn.w_scale is None:
913-
attn.w_scale = bind_or_assign(attn.w_scale, attn.kv_b_proj.weight_scale)
914-
if _is_hip:
915-
attn.w_scale *= 2.0
912+
kv_weight_scale = getattr(attn.kv_b_proj, "weight_scale", None)
913+
if (
914+
kv_weight_scale is not None
915+
and attn.w_scale is None
916+
and w.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz)
917+
):
918+
scale = kv_weight_scale * 2.0 if _is_hip else kv_weight_scale
919+
attn.w_scale = bind_or_assign(attn.w_scale, scale)
916920

917921
if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn:
918922
attn.w_kc = attn.w_kc.to(torch.bfloat16) * attn.w_scale

0 commit comments

Comments (0)