Skip to content

Commit 6acc3eb

Browse files
committed
Finish resume support across modal fit paths
1 parent 18cbd11 commit 6acc3eb

File tree

3 files changed

+237
-56
lines changed

3 files changed

+237
-56
lines changed

modal_app/remote_calibration_runner.py

Lines changed: 74 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def _trigger_repository_dispatch(event_type: str = "calibration-updated"):
166166
def _fit_weights_impl(
167167
branch: str,
168168
epochs: int,
169+
output_prefix: str = "",
169170
target_config: str = None,
170171
beta: float = None,
171172
lambda_l0: float = None,
@@ -183,6 +184,7 @@ def _fit_weights_impl(
183184
artifacts = artifacts_dir if artifacts_dir else f"{PIPELINE_MOUNT}/artifacts"
184185
db_path = f"{artifacts}/policy_data.db"
185186
dataset_path = f"{artifacts}/source_imputed_stratified_extended_cps.h5"
187+
checkpoint_path = f"{artifacts}/{output_prefix}calibration_checkpoint.pt"
186188
for label, p in [("database", db_path), ("dataset", dataset_path)]:
187189
if not os.path.exists(p):
188190
raise RuntimeError(
@@ -203,7 +205,11 @@ def _fit_weights_impl(
203205
db_path,
204206
"--dataset",
205207
dataset_path,
208+
"--checkpoint-output",
209+
checkpoint_path,
206210
]
211+
if os.path.exists(checkpoint_path):
212+
cmd.extend(["--resume-from", checkpoint_path])
207213
if target_config:
208214
cmd.extend(["--target-config", target_config])
209215
if not skip_county:
@@ -212,11 +218,15 @@ def _fit_weights_impl(
212218
cmd.extend(["--workers", str(workers)])
213219
_append_hyperparams(cmd, beta, lambda_l0, lambda_l2, learning_rate, log_freq)
214220

215-
cal_rc, cal_lines = _run_streaming(
216-
cmd,
217-
env=os.environ.copy(),
218-
label="calibrate",
219-
)
221+
try:
222+
cal_rc, cal_lines = _run_streaming(
223+
cmd,
224+
env=os.environ.copy(),
225+
label="calibrate",
226+
)
227+
finally:
228+
if os.path.exists(checkpoint_path):
229+
pipeline_vol.commit()
220230
if cal_rc != 0:
221231
raise RuntimeError(f"Script failed with code {cal_rc}")
222232

@@ -277,15 +287,17 @@ def _fit_from_package_impl(
277287

278288
print(f"Running command: {' '.join(cmd)}", flush=True)
279289

280-
cal_rc, cal_lines = _run_streaming(
281-
cmd,
282-
env=os.environ.copy(),
283-
label="calibrate",
284-
)
290+
try:
291+
cal_rc, cal_lines = _run_streaming(
292+
cmd,
293+
env=os.environ.copy(),
294+
label="calibrate",
295+
)
296+
finally:
297+
if os.path.exists(checkpoint_path):
298+
pipeline_vol.commit()
285299
if cal_rc != 0:
286300
raise RuntimeError(f"Script failed with code {cal_rc}")
287-
288-
pipeline_vol.commit()
289301
return _collect_outputs(cal_lines)
290302

291303

@@ -511,6 +523,7 @@ def check_volume_package(artifacts_dir: str = "") -> dict:
511523
def fit_weights_t4(
512524
branch: str = "main",
513525
epochs: int = 200,
526+
output_prefix: str = "",
514527
target_config: str = None,
515528
beta: float = None,
516529
lambda_l0: float = None,
@@ -522,14 +535,15 @@ def fit_weights_t4(
522535
artifacts_dir: str = "",
523536
) -> dict:
524537
return _fit_weights_impl(
525-
branch,
526-
epochs,
527-
target_config,
528-
beta,
529-
lambda_l0,
530-
lambda_l2,
531-
learning_rate,
532-
log_freq,
538+
branch=branch,
539+
epochs=epochs,
540+
output_prefix=output_prefix,
541+
target_config=target_config,
542+
beta=beta,
543+
lambda_l0=lambda_l0,
544+
lambda_l2=lambda_l2,
545+
learning_rate=learning_rate,
546+
log_freq=log_freq,
533547
skip_county=skip_county,
534548
workers=workers,
535549
artifacts_dir=artifacts_dir,
@@ -548,6 +562,7 @@ def fit_weights_t4(
548562
def fit_weights_a10(
549563
branch: str = "main",
550564
epochs: int = 200,
565+
output_prefix: str = "",
551566
target_config: str = None,
552567
beta: float = None,
553568
lambda_l0: float = None,
@@ -559,14 +574,15 @@ def fit_weights_a10(
559574
artifacts_dir: str = "",
560575
) -> dict:
561576
return _fit_weights_impl(
562-
branch,
563-
epochs,
564-
target_config,
565-
beta,
566-
lambda_l0,
567-
lambda_l2,
568-
learning_rate,
569-
log_freq,
577+
branch=branch,
578+
epochs=epochs,
579+
output_prefix=output_prefix,
580+
target_config=target_config,
581+
beta=beta,
582+
lambda_l0=lambda_l0,
583+
lambda_l2=lambda_l2,
584+
learning_rate=learning_rate,
585+
log_freq=log_freq,
570586
skip_county=skip_county,
571587
workers=workers,
572588
artifacts_dir=artifacts_dir,
@@ -585,6 +601,7 @@ def fit_weights_a10(
585601
def fit_weights_a100_40(
586602
branch: str = "main",
587603
epochs: int = 200,
604+
output_prefix: str = "",
588605
target_config: str = None,
589606
beta: float = None,
590607
lambda_l0: float = None,
@@ -596,14 +613,15 @@ def fit_weights_a100_40(
596613
artifacts_dir: str = "",
597614
) -> dict:
598615
return _fit_weights_impl(
599-
branch,
600-
epochs,
601-
target_config,
602-
beta,
603-
lambda_l0,
604-
lambda_l2,
605-
learning_rate,
606-
log_freq,
616+
branch=branch,
617+
epochs=epochs,
618+
output_prefix=output_prefix,
619+
target_config=target_config,
620+
beta=beta,
621+
lambda_l0=lambda_l0,
622+
lambda_l2=lambda_l2,
623+
learning_rate=learning_rate,
624+
log_freq=log_freq,
607625
skip_county=skip_county,
608626
workers=workers,
609627
artifacts_dir=artifacts_dir,
@@ -622,6 +640,7 @@ def fit_weights_a100_40(
622640
def fit_weights_a100_80(
623641
branch: str = "main",
624642
epochs: int = 200,
643+
output_prefix: str = "",
625644
target_config: str = None,
626645
beta: float = None,
627646
lambda_l0: float = None,
@@ -633,14 +652,15 @@ def fit_weights_a100_80(
633652
artifacts_dir: str = "",
634653
) -> dict:
635654
return _fit_weights_impl(
636-
branch,
637-
epochs,
638-
target_config,
639-
beta,
640-
lambda_l0,
641-
lambda_l2,
642-
learning_rate,
643-
log_freq,
655+
branch=branch,
656+
epochs=epochs,
657+
output_prefix=output_prefix,
658+
target_config=target_config,
659+
beta=beta,
660+
lambda_l0=lambda_l0,
661+
lambda_l2=lambda_l2,
662+
learning_rate=learning_rate,
663+
log_freq=log_freq,
644664
skip_county=skip_county,
645665
workers=workers,
646666
artifacts_dir=artifacts_dir,
@@ -659,6 +679,7 @@ def fit_weights_a100_80(
659679
def fit_weights_h100(
660680
branch: str = "main",
661681
epochs: int = 200,
682+
output_prefix: str = "",
662683
target_config: str = None,
663684
beta: float = None,
664685
lambda_l0: float = None,
@@ -670,14 +691,15 @@ def fit_weights_h100(
670691
artifacts_dir: str = "",
671692
) -> dict:
672693
return _fit_weights_impl(
673-
branch,
674-
epochs,
675-
target_config,
676-
beta,
677-
lambda_l0,
678-
lambda_l2,
679-
learning_rate,
680-
log_freq,
694+
branch=branch,
695+
epochs=epochs,
696+
output_prefix=output_prefix,
697+
target_config=target_config,
698+
beta=beta,
699+
lambda_l0=lambda_l0,
700+
lambda_l2=lambda_l2,
701+
learning_rate=learning_rate,
702+
log_freq=log_freq,
681703
skip_county=skip_county,
682704
workers=workers,
683705
artifacts_dir=artifacts_dir,

policyengine_us_data/calibration/unified_calibration.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,9 +1149,7 @@ def run_calibration(
11491149
"block_geoid": package.get("block_geoid"),
11501150
"base_n_records": package_base_n_records,
11511151
"n_clones": (
1152-
int(package_n_clones)
1153-
if package_n_clones is not None
1154-
else n_clones
1152+
int(package_n_clones) if package_n_clones is not None else n_clones
11551153
),
11561154
}
11571155
return (

0 commit comments

Comments
 (0)