@@ -73,76 +73,6 @@ class GRPOBatch:
7373 prompt_lengths : list [int ]
7474 rewards : torch .Tensor
7575
def clip_outputs_at_think_token(rollouts, tokenizer, think_tokens, answer_len=64):
    """
    Clip each completion's token_ids and logprobs shortly after the start of
    the ``</think>`` token sequence, recomputing the text and cumulative
    logprob from the clipped tokens.

    The clip point is ``answer_len`` tokens past the *start* of the
    ``</think>`` sequence, so the kept suffix includes the think tokens
    themselves (only ``answer_len - len(think_tokens)`` tokens remain after
    the sequence ends).

    Args:
        rollouts: List of rollout objects (vLLM ``RequestOutput``-like), each
            exposing an ``outputs`` list of ``CompletionOutput``-like objects.
        tokenizer: Tokenizer instance providing ``decode``.
        think_tokens: Token-id sequence that encodes ``</think>``.
        answer_len: Number of tokens to keep counting from the start of the
            ``</think>`` sequence.

    Returns:
        List of new rollout objects with clipped outputs. Outputs that do not
        contain the ``</think>`` sequence are passed through unchanged.
    """
    # Normalize once up front: token_ids may be a tuple (newer vLLM) while
    # think_tokens is a list — a tuple slice never compares equal to a list,
    # which made the original search silently never match.
    think_ids = list(think_tokens)
    think_len = len(think_ids)

    ret = []
    for rollout in rollouts:
        clipped_outputs = []

        for output in rollout.outputs:
            token_ids = list(output.token_ids)
            clip_index = None

            # Locate the start of the </think> sequence in the sampled tokens.
            for i in range(len(token_ids) - think_len + 1):
                if token_ids[i:i + think_len] == think_ids:
                    clip_index = i + answer_len
                    break

            if clip_index is None:
                # </think> not found: keep the original output untouched.
                clipped_outputs.append(output)
                continue

            clipped_token_ids = output.token_ids[:clip_index]
            clipped_logprobs = output.logprobs[:clip_index]

            # Recompute derived fields from the clipped token sequence.
            clipped_text = tokenizer.decode(clipped_token_ids)

            # Each logprobs entry maps candidate token id -> Logprob; the
            # first entry is the sampled token, so sum those for the total.
            cumulative_logprob = sum(
                next(iter(lp_dict.values())).logprob for lp_dict in clipped_logprobs
            )

            clipped_outputs.append(
                type(output)(
                    index=output.index,
                    text=clipped_text,
                    token_ids=clipped_token_ids,
                    cumulative_logprob=cumulative_logprob,
                    logprobs=clipped_logprobs,
                    finish_reason=output.finish_reason,
                    stop_reason=output.stop_reason,
                )
            )

        # Build a fresh rollout carrying over every non-output attribute.
        ret.append(
            type(rollout)(
                outputs=clipped_outputs,
                **{k: v for k, v in vars(rollout).items() if k != "outputs"},
            )
        )

    return ret
144-
145-
14676def clip_outputs_after_think_token (rollouts , tokenizer , think_tokens , num_tokens ):
14777 """
14878 Clip token_ids and logprobs to keep only num_tokens after the </think> token sequence ends,
0 commit comments