Commit 6e456b0

ChenhanYu and claude committed

fix: PTQ 1GPU, export PP divisibility, hidden states conversations key

- megatron_lm_ptq.yaml: Qwen3-8B to single GPU for L40 clusters
- quantize.sh: auto-find largest PP dividing model num_hidden_layers for export (Qwen3-8B has 36 layers, not divisible by 8)
- compute_hidden_states_trtllm.py: use messages with conversations fallback (matching the HF version)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>

1 parent e9a4989 · commit 6e456b0

3 files changed: 23 additions & 12 deletions

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py (1 addition, 1 deletion)

@@ -256,7 +256,7 @@ async def submit_generates():
     for entry in dataset:
         conversation_id = entry.get("conversation_id", entry.get("uuid"))

-        conversations = entry["conversations"]
+        conversations = entry.get("messages") or entry.get("conversations")
         if not conversations or not isinstance(conversations, list):
             num_invalid += 1
             continue
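The fallback above can be exercised in isolation. This is a minimal sketch (the `get_conversations` helper is hypothetical, not part of the diff; dataset entries are assumed to be dicts keyed by either `messages` or `conversations`):

```python
def get_conversations(entry: dict):
    # Prefer the HF-style "messages" key, then fall back to "conversations",
    # matching the changed line in compute_hidden_states_trtllm.py.
    conversations = entry.get("messages") or entry.get("conversations")
    # Mirror the diff's validity check: must be a non-empty list.
    if not conversations or not isinstance(conversations, list):
        return None
    return conversations


# Both entry shapes resolve through the same code path.
hf_entry = {"messages": [{"role": "user", "content": "hi"}]}
legacy_entry = {"conversations": [{"from": "human", "value": "hi"}]}
bad_entry = {"conversations": "not-a-list"}
```

Entries with neither key (or a non-list value) count toward `num_invalid` and are skipped, exactly as in the loop above.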

tools/launcher/common/megatron_lm/quantize/quantize.sh (14 additions, 3 deletions)

@@ -41,11 +41,22 @@ TP=${TP:-1} PP=${PP:-1} EP=${EP:-1} ETP=${ETP:-1} ${QUANTIZE_EXE} ${MLM_MODEL_CF
 export MLM_EXTRA_ARGS="--mmlu-dataset ${MMLU_DATASET:-/hf-local/cais/mmlu} --fraction 0.01 --lower-bound ${MMLU_LOWER_BOUND:-0.38} --disable-tqdm"
 TP=${TP:-1} PP=${PP:-1} EP=${EP:-1} ETP=${ETP:-1} MLM_MODEL_CKPT=${MLM_MODEL_SAVE} ${MMLU_EXE} ${MLM_MODEL_CFG}

-# Export quantized checkpoint to HF format (PP=all GPUs)
+# Export quantized checkpoint to HF format
+# Use largest PP <= total GPUs that divides the model's num_hidden_layers
 TOTAL_GPUS=$(python3 -c "import torch; print(torch.cuda.device_count())" 2>/dev/null || echo ${NUM_GPUS:-1})
-echo "=== Exporting ${MLM_MODEL_CFG} ${QUANT_CFG} (PP=${TOTAL_GPUS}) ==="
+EXPORT_PP=$(python3 -c "
+import json, os
+cfg = os.path.join('${HF_MODEL_CKPT}', 'config.json')
+n_layers = json.load(open(cfg)).get('num_hidden_layers', 1) if os.path.exists(cfg) else 1
+gpus = ${TOTAL_GPUS}
+pp = gpus
+while pp > 1 and n_layers % pp != 0:
+    pp -= 1
+print(pp)
+" 2>/dev/null || echo ${TOTAL_GPUS})
+echo "=== Exporting ${MLM_MODEL_CFG} ${QUANT_CFG} (PP=${EXPORT_PP}, ${TOTAL_GPUS} GPUs) ==="
 export MLM_EXTRA_ARGS=
-TP=1 PP=${TOTAL_GPUS} EP=1 ETP=1 MLM_MODEL_CKPT=${MLM_MODEL_SAVE} ${EXPORT_EXE} ${MLM_MODEL_CFG}
+TP=1 PP=${EXPORT_PP} EP=1 ETP=1 MLM_MODEL_CKPT=${MLM_MODEL_SAVE} ${EXPORT_EXE} ${MLM_MODEL_CFG}
 ls ${EXPORT_DIR}
 cat ${EXPORT_DIR}/hf_quant_config.json
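The PP-selection loop embedded in quantize.sh can be checked on its own. A minimal sketch of the same logic (the `export_pp` function name is an illustration, not part of the script):

```python
def export_pp(n_layers: int, gpus: int) -> int:
    """Largest pipeline-parallel size <= gpus that evenly divides n_layers."""
    pp = gpus
    # Walk down from the GPU count until the layer count divides evenly;
    # pp=1 always divides, so the loop terminates there in the worst case.
    while pp > 1 and n_layers % pp != 0:
        pp -= 1
    return pp


# Qwen3-8B has 36 hidden layers: on an 8-GPU node, 36 % 8 != 0 and
# 36 % 7 != 0, so the export falls back to PP=6.
```

This is why the commit message notes that Qwen3-8B's 36 layers are "not divisible by 8": the old `PP=${TOTAL_GPUS}` export would fail on an 8-GPU node, while the new loop picks the largest workable value.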

tools/launcher/examples/Qwen/Qwen3-8B/megatron_lm_ptq.yaml (8 additions, 8 deletions)

@@ -24,7 +24,7 @@ pipeline:
     config:
       model: Qwen/Qwen3-8B
       quant_cfg: NVFP4_DEFAULT_CFG
-      tp: 8
+      tp: 1
       calib_dataset: abisee/cnn_dailymail
       calib_size: 32
       mmlu_dataset: cais/mmlu
@@ -33,15 +33,15 @@ pipeline:
     slurm_config:
       _factory_: "slurm_factory"
       nodes: 1
-      ntasks_per_node: 8
-      gpus_per_node: 8
+      ntasks_per_node: 1
+      gpus_per_node: 1

   task_1:
     _target_: common.megatron_lm.quantize.task.MegatronLMQuantizeTask
     config:
       model: Qwen/Qwen3-8B
       quant_cfg: FP8_DEFAULT_CFG
-      tp: 8
+      tp: 1
       calib_dataset: abisee/cnn_dailymail
       calib_size: 32
       mmlu_dataset: cais/mmlu
@@ -50,18 +50,18 @@ pipeline:
     slurm_config:
       _factory_: "slurm_factory"
       nodes: 1
-      ntasks_per_node: 8
-      gpus_per_node: 8
+      ntasks_per_node: 1
+      gpus_per_node: 1

   # Step 3: TRT-LLM eval MMLU on all exported checkpoints
   task_2:
     script: common/tensorrt_llm/eval.sh
     environment:
       - HF_MODEL_CKPT: /scratchspace/export
-      - TP: "8"
+      - TP: "1"
       - EP: "1"
     slurm_config:
       _factory_: "slurm_factory"
       nodes: 1
       ntasks_per_node: 1
-      gpus_per_node: 8
+      gpus_per_node: 1
