run_full_experiment.py
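"""Run the full LOCK-R LM Studio experiment.

With --smoke, the script runs a reduced API smoke check on a truncated
episode list instead of the full benchmark suite.
"""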
from __future__ import annotations

import argparse
import csv
import json
import os
import shutil
from pathlib import Path

from lockr.runners.benchmark import BenchmarkRunner
from lockr.schemas import BenchmarkConfig


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Run the full LOCK-R LM Studio experiment.")
    parser.add_argument("--config", type=Path, default=Path("configs/pilot_20.json"))
    parser.add_argument("--output", type=Path, default=Path("outputs"))
    parser.add_argument("--smoke", action="store_true", help="Run a 2-episode API smoke check.")
    return parser


def load_config(path: Path) -> BenchmarkConfig:
    return BenchmarkConfig.model_validate(json.loads(path.read_text(encoding="utf-8")))


def prepare_config(config: BenchmarkConfig, *, smoke: bool) -> BenchmarkConfig:
    config.agent.kind = "openai_compatible_json"
    config.agent.proposal_generation_mode = "qwen_nonthinking_eval"
    config.agent.verifier_generation_mode = "qwen_nonthinking_eval"
    config.agent.repair_generation_mode = "qwen_nonthinking_eval"
    config.parallel_workers = 4
    config.regimes = ["same_model_locked_agent"] if smoke else ["same_model_locked_agent", "blind_checker"]
    if smoke:
        config.suite_name = "lmstudio_smoke"
        config.episodes = config.episodes[:1]
    else:
        config.suite_name = "lmstudio_full_experiment"
    return config


def copy_figures(*, suite_name: str, output_dir: Path) -> None:
    source_dir = output_dir / "figures" / suite_name
    target_dir = output_dir / "figures"
    target_dir.mkdir(parents=True, exist_ok=True)
    for filename in [
        "posterior_trajectories.png",
        "r_by_regime.png",
        "k_c_by_regime.png",
        "anchor_strength_effect.png",
    ]:
        source_path = source_dir / filename
        if source_path.exists():
            shutil.copy2(source_path, target_dir / filename)


def write_aggregate_metrics_csv(output_dir: Path) -> None:
    source = output_dir / "regime_comparison.csv"
    target = output_dir / "aggregate_metrics.csv"
    if source.exists():
        shutil.copy2(source, target)
        return
    episode_metrics = output_dir / "episode_metrics.csv"
    if not episode_metrics.exists():
        return
    rows: dict[str, list[dict[str, str]]] = {}
    with episode_metrics.open("r", encoding="utf-8", newline="") as handle:
        reader = csv.DictReader(handle)
        for row in reader:
            rows.setdefault(row["regime"], []).append(row)
    with target.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=["regime", "mean_r_mean", "mean_k_c", "mean_cear"])
        writer.writeheader()
        for regime, entries in sorted(rows.items()):
            writer.writerow(
                {
                    "regime": regime,
                    "mean_r_mean": sum(float(entry["r_mean"]) for entry in entries) / len(entries),
                    "mean_k_c": sum(float(entry["k_c"]) for entry in entries) / len(entries),
                    "mean_cear": sum(float(entry["cear"]) for entry in entries) / len(entries),
                }
            )


def main() -> None:
    args = build_parser().parse_args()
    os.environ.setdefault("OPENAI_API_KEY", "lm-studio")
    config = prepare_config(load_config(args.config), smoke=args.smoke)
    args.output.mkdir(parents=True, exist_ok=True)
    runner = BenchmarkRunner(config=config, output_dir=args.output)
    summary = runner.run()
    copy_figures(suite_name=config.suite_name, output_dir=args.output)
    write_aggregate_metrics_csv(args.output)
    print(json.dumps(summary.model_dump(mode="json"), indent=2))


if __name__ == "__main__":
    main()
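# Example invocations (the config/output paths below are the argparse
# defaults defined in build_parser(); adjust them for your setup):
#
#   python run_full_experiment.py --smoke
#   python run_full_experiment.py --config configs/pilot_20.json --output outputs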