Skip to content

Commit 9090bd5

Browse files
unused
1 parent 6e9dc7f commit 9090bd5

File tree

1 file changed

+0
-70
lines changed
  • src/fairseq2/recipes/lm/_online_finetune

1 file changed

+0
-70
lines changed

src/fairseq2/recipes/lm/_online_finetune/_grpo.py

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -73,76 +73,6 @@ class GRPOBatch:
7373
prompt_lengths: list[int]
7474
rewards: torch.Tensor
7575

76-
def clip_outputs_at_think_token(rollouts, tokenizer, think_tokens, answer_len=64):
    """Clip each rollout's outputs at the start of the ``</think>`` token sequence.

    For every completion output, searches ``token_ids`` for the first
    occurrence of ``think_tokens``. If found, keeps only the tokens up to
    ``answer_len`` positions past the start of that match, truncates the
    per-token logprobs accordingly, re-decodes the text from the clipped
    tokens, and recomputes the cumulative logprob from the clipped logprobs.
    Outputs without a match are kept unchanged.

    Args:
        rollouts: List of rollout objects, each exposing an ``outputs`` list
            of completion outputs (vLLM-style ``CompletionOutput``).
        tokenizer: Tokenizer used to re-decode the clipped token ids.
        think_tokens: Token-id sequence that encodes ``</think>``.
        answer_len: Number of tokens to keep starting at the match position.

    Returns:
        A new list of rollout objects whose outputs have been clipped;
        the input rollouts are not mutated.
    """
    think_token_len = len(think_tokens)
    # Normalize so the slice comparison below also works when token_ids
    # is a tuple rather than a list.
    think_tokens = list(think_tokens)

    ret = []
    for rollout in rollouts:
        clipped_outputs = []

        for output in rollout.outputs:
            token_ids = output.token_ids

            # Find the position where the </think> token sequence starts.
            clip_index = None
            for i in range(len(token_ids) - think_token_len + 1):
                if list(token_ids[i : i + think_token_len]) == think_tokens:
                    # NOTE(review): the budget is counted from the *start*
                    # of the match, so the </think> tokens themselves eat
                    # into answer_len — the docstring says "after </think>";
                    # confirm which semantics are intended.
                    clip_index = i + answer_len
                    break

            if clip_index is None:
                # No </think> marker found: keep the original output as-is.
                clipped_outputs.append(output)
                continue

            clipped_token_ids = token_ids[:clip_index]
            clipped_logprobs = output.logprobs[:clip_index]

            # Re-decode so the text matches the clipped token ids.
            clipped_text = tokenizer.decode(clipped_token_ids)

            # FIX: recompute the cumulative logprob by looking up the logprob
            # of the token that was actually sampled at each position.  The
            # old code took the *first key* of each per-step logprob dict,
            # but that dict also contains the top-k candidate tokens with no
            # guaranteed ordering, so the first key need not be the sampled
            # token.
            cumulative_logprob = sum(
                logprob_dict[token_id].logprob
                for token_id, logprob_dict in zip(clipped_token_ids, clipped_logprobs)
            )

            # Build a new output of the same type with the clipped data;
            # the original output object is left untouched.
            clipped_outputs.append(
                type(output)(
                    index=output.index,
                    text=clipped_text,
                    token_ids=clipped_token_ids,
                    cumulative_logprob=cumulative_logprob,
                    logprobs=clipped_logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                )
            )

        # Rebuild the rollout with the clipped outputs, copying every other
        # attribute from the original verbatim.
        clipped_rollout = type(rollout)(
            outputs=clipped_outputs,
            **{k: v for k, v in vars(rollout).items() if k != "outputs"},
        )
        ret.append(clipped_rollout)

    return ret
144-
145-
14676
def clip_outputs_after_think_token(rollouts, tokenizer, think_tokens, num_tokens):
14777
"""
14878
Clip token_ids and logprobs to keep only num_tokens after the </think> token sequence ends,

0 commit comments

Comments
 (0)