Skip to content

Commit 36a88e0

Browse files
logprobs fixed
1 parent a9b98b9 commit 36a88e0

File tree

4 files changed

+102
-35
lines changed

4 files changed

+102
-35
lines changed

weboperator/models/azure_openai.py

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def wrapper(*args, **kwargs): # type: ignore
4343
return func(*args, **kwargs)
4444
# Retry on specified errors
4545
except Exception as e:
46-
print(f"Error {e}")
46+
print(f"Error {type(e)} -> {e} ")
4747
# Increment retries
4848
num_retries += 1
4949

@@ -66,24 +66,71 @@ def chat(self, messages: list[dict], **kwargs) -> str | list[str]:
6666
Chat completion using the chat/completions endpoint.
6767
Supports multi-modal inputs (text + images) for vision models.
6868
"""
69-
response = AzureClient.chat.completions.create(
70-
model=self.name,
71-
messages=messages,
72-
max_tokens=self.max_tokens,
73-
temperature=self.temperature,
74-
top_p=self.top_p,
75-
# reasoning_effort=self.reasoning_effort,
76-
n=kwargs.get("n", self.n),
77-
logprobs=True,
78-
top_logprobs=10,
79-
)
69+
try:
70+
if "gpt-5" in self.name:
71+
# For gpt-5 models, we might want to set different parameters
72+
response = AzureClient.chat.completions.create(
73+
model=self.name,
74+
messages=messages,
75+
# max_completion_tokens=self.max_tokens,
76+
temperature=self.temperature,
77+
top_p=self.top_p,
78+
n=kwargs.get("n", self.n),
79+
)
80+
elif "o4" in self.name:
81+
response = AzureClient.chat.completions.create(
82+
model=self.name,
83+
messages=messages,
84+
max_tokens=self.max_tokens,
85+
reasoning_effort="medium",
86+
stream=False,
87+
)
88+
else:
89+
if kwargs.get("logprobs", False):
90+
response = AzureClient.chat.completions.create(
91+
model=self.name,
92+
messages=messages,
93+
# max_tokens=self.max_tokens,
94+
temperature=self.temperature,
95+
top_p=self.top_p,
96+
# reasoning_effort=self.reasoning_effort,
97+
n=kwargs.get("n", self.n),
98+
logprobs=True,
99+
top_logprobs=10,
100+
)
101+
else:
102+
response = AzureClient.chat.completions.create(
103+
model=self.name,
104+
messages=messages,
105+
# max_tokens=self.max_tokens,
106+
temperature=self.temperature,
107+
top_p=self.top_p,
108+
# reasoning_effort=self.reasoning_effort,
109+
n=kwargs.get("n", self.n),
110+
)
111+
except openai.BadRequestError as e:
112+
print(f"BadRequestError: {e}")
113+
return "", []
114+
except Exception as e:
115+
raise e
80116

81117
if len(response.choices) == 0:
82118
raise ValueError("No choices returned from the model.")
119+
120+
predictions = [
121+
choice.message.content.strip()
122+
for choice in response.choices
123+
if choice.message.content.strip()
124+
]
125+
126+
if len(predictions) == 0:
127+
raise ValueError("No valid predictions returned from the model.")
83128

84129
top_logprobs = [
85130
choice.logprobs.content
86131
for choice in response.choices
87132
if hasattr(choice, "logprobs") and choice.logprobs
88133
]
89-
return response.choices[0].message.content.strip(), top_logprobs[0]
134+
if kwargs.get("logprobs", False):
135+
return predictions[0], top_logprobs[0]
136+
return predictions[0], []

weboperator/models/openhf.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,26 @@ def chat(self, messages: list[dict], **kwargs) -> str:
8282
except requests.exceptions.RequestException as e:
8383
raise ConnectionError(f"Could not connect to HUGGING_FACE_API_SERVER: {e}")
8484

85-
response = OpenHFClient.chat.completions.create(
86-
model=self.name,
87-
messages=messages,
88-
# max_tokens=self.max_tokens,
89-
temperature=self.temperature,
90-
top_p=self.top_p,
91-
n=kwargs.get("n", self.n),
92-
logprobs=True,
93-
top_logprobs=10,
94-
)
85+
if kwargs.get("logprobs", False):
86+
response = OpenHFClient.chat.completions.create(
87+
model=self.name,
88+
messages=messages,
89+
# max_tokens=self.max_tokens,
90+
temperature=self.temperature,
91+
top_p=self.top_p,
92+
n=kwargs.get("n", self.n),
93+
logprobs=True,
94+
top_logprobs=10,
95+
)
96+
else:
97+
response = OpenHFClient.chat.completions.create(
98+
model=self.name,
99+
messages=messages,
100+
# max_tokens=self.max_tokens,
101+
temperature=self.temperature,
102+
top_p=self.top_p,
103+
n=kwargs.get("n", self.n),
104+
)
95105

96106
# Raise OpenHFError if we get invalid response to trigger retry
97107
if not response or not hasattr(response, "choices") or not response.choices:

weboperator/models/openrouter.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,16 +106,26 @@ def chat(self, messages: list[dict], **kwargs) -> str:
106106
Chat completion using the chat/completions endpoint.
107107
Supports multi-modal inputs (text + images) for vision models.
108108
"""
109-
response = self.client.chat.completions.create(
110-
model=self.name,
111-
messages=messages,
112-
# max_tokens=self.max_tokens,
113-
temperature=self.temperature,
114-
top_p=self.top_p,
115-
n=kwargs.get("n", self.n),
116-
logprobs=True,
117-
top_logprobs=10,
118-
)
109+
if kwargs.get("logprobs", False):
110+
response = self.client.chat.completions.create(
111+
model=self.name,
112+
messages=messages,
113+
# max_tokens=self.max_tokens,
114+
temperature=self.temperature,
115+
top_p=self.top_p,
116+
n=kwargs.get("n", self.n),
117+
logprobs=True,
118+
top_logprobs=10,
119+
)
120+
else:
121+
response = self.client.chat.completions.create(
122+
model=self.name,
123+
messages=messages,
124+
# max_tokens=self.max_tokens,
125+
temperature=self.temperature,
126+
top_p=self.top_p,
127+
n=kwargs.get("n", self.n),
128+
)
119129
# print(response.choices[0])
120130
usage = getattr(response, "usage", None)
121131
if usage:

weboperator/webprm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def _get_brief_context(trajectory: List[Dict[str, Any]]) -> str:
201201
if len(trajectory) < 2:
202202
return "No previous actions."
203203

204-
# Show trajectory[-4:-2] for context (2 steps before current)
204+
# Show up to the last 9 prior steps (excluding current)
205205
context_steps = trajectory[-10:-1] if len(trajectory) >= 10 else trajectory[:-1]
206206
# print(f"Context steps for evaluation: {context_steps}")
207207
if not context_steps:
@@ -641,7 +641,7 @@ def evaluate(
641641
# print("USER PROMPT: ", message[-1]["content"])
642642

643643
for _ in range(3): # Try up to 3 times to get a valid answer
644-
response, scores = self.reward_model.chat(message)
644+
response, scores = self.reward_model.chat(message, logprobs=True)
645645
generated_text = response
646646
if "# Checklist Evaluation" in generated_text:
647647
# print(scores)

0 commit comments

Comments
 (0)