@@ -166,6 +166,7 @@ def _trigger_repository_dispatch(event_type: str = "calibration-updated"):
166166def _fit_weights_impl (
167167 branch : str ,
168168 epochs : int ,
169+ output_prefix : str = "" ,
169170 target_config : str = None ,
170171 beta : float = None ,
171172 lambda_l0 : float = None ,
@@ -183,6 +184,7 @@ def _fit_weights_impl(
183184 artifacts = artifacts_dir if artifacts_dir else f"{ PIPELINE_MOUNT } /artifacts"
184185 db_path = f"{ artifacts } /policy_data.db"
185186 dataset_path = f"{ artifacts } /source_imputed_stratified_extended_cps.h5"
187+ checkpoint_path = f"{ artifacts } /{ output_prefix } calibration_checkpoint.pt"
186188 for label , p in [("database" , db_path ), ("dataset" , dataset_path )]:
187189 if not os .path .exists (p ):
188190 raise RuntimeError (
@@ -203,7 +205,11 @@ def _fit_weights_impl(
203205 db_path ,
204206 "--dataset" ,
205207 dataset_path ,
208+ "--checkpoint-output" ,
209+ checkpoint_path ,
206210 ]
211+ if os .path .exists (checkpoint_path ):
212+ cmd .extend (["--resume-from" , checkpoint_path ])
207213 if target_config :
208214 cmd .extend (["--target-config" , target_config ])
209215 if not skip_county :
@@ -212,11 +218,15 @@ def _fit_weights_impl(
212218 cmd .extend (["--workers" , str (workers )])
213219 _append_hyperparams (cmd , beta , lambda_l0 , lambda_l2 , learning_rate , log_freq )
214220
215- cal_rc , cal_lines = _run_streaming (
216- cmd ,
217- env = os .environ .copy (),
218- label = "calibrate" ,
219- )
221+ try :
222+ cal_rc , cal_lines = _run_streaming (
223+ cmd ,
224+ env = os .environ .copy (),
225+ label = "calibrate" ,
226+ )
227+ finally :
228+ if os .path .exists (checkpoint_path ):
229+ pipeline_vol .commit ()
220230 if cal_rc != 0 :
221231 raise RuntimeError (f"Script failed with code { cal_rc } " )
222232
@@ -277,15 +287,17 @@ def _fit_from_package_impl(
277287
278288 print (f"Running command: { ' ' .join (cmd )} " , flush = True )
279289
280- cal_rc , cal_lines = _run_streaming (
281- cmd ,
282- env = os .environ .copy (),
283- label = "calibrate" ,
284- )
290+ try :
291+ cal_rc , cal_lines = _run_streaming (
292+ cmd ,
293+ env = os .environ .copy (),
294+ label = "calibrate" ,
295+ )
296+ finally :
297+ if os .path .exists (checkpoint_path ):
298+ pipeline_vol .commit ()
285299 if cal_rc != 0 :
286300 raise RuntimeError (f"Script failed with code { cal_rc } " )
287-
288- pipeline_vol .commit ()
289301 return _collect_outputs (cal_lines )
290302
291303
@@ -511,6 +523,7 @@ def check_volume_package(artifacts_dir: str = "") -> dict:
511523def fit_weights_t4 (
512524 branch : str = "main" ,
513525 epochs : int = 200 ,
526+ output_prefix : str = "" ,
514527 target_config : str = None ,
515528 beta : float = None ,
516529 lambda_l0 : float = None ,
@@ -522,14 +535,15 @@ def fit_weights_t4(
522535 artifacts_dir : str = "" ,
523536) -> dict :
524537 return _fit_weights_impl (
525- branch ,
526- epochs ,
527- target_config ,
528- beta ,
529- lambda_l0 ,
530- lambda_l2 ,
531- learning_rate ,
532- log_freq ,
538+ branch = branch ,
539+ epochs = epochs ,
540+ output_prefix = output_prefix ,
541+ target_config = target_config ,
542+ beta = beta ,
543+ lambda_l0 = lambda_l0 ,
544+ lambda_l2 = lambda_l2 ,
545+ learning_rate = learning_rate ,
546+ log_freq = log_freq ,
533547 skip_county = skip_county ,
534548 workers = workers ,
535549 artifacts_dir = artifacts_dir ,
@@ -548,6 +562,7 @@ def fit_weights_t4(
548562def fit_weights_a10 (
549563 branch : str = "main" ,
550564 epochs : int = 200 ,
565+ output_prefix : str = "" ,
551566 target_config : str = None ,
552567 beta : float = None ,
553568 lambda_l0 : float = None ,
@@ -559,14 +574,15 @@ def fit_weights_a10(
559574 artifacts_dir : str = "" ,
560575) -> dict :
561576 return _fit_weights_impl (
562- branch ,
563- epochs ,
564- target_config ,
565- beta ,
566- lambda_l0 ,
567- lambda_l2 ,
568- learning_rate ,
569- log_freq ,
577+ branch = branch ,
578+ epochs = epochs ,
579+ output_prefix = output_prefix ,
580+ target_config = target_config ,
581+ beta = beta ,
582+ lambda_l0 = lambda_l0 ,
583+ lambda_l2 = lambda_l2 ,
584+ learning_rate = learning_rate ,
585+ log_freq = log_freq ,
570586 skip_county = skip_county ,
571587 workers = workers ,
572588 artifacts_dir = artifacts_dir ,
@@ -585,6 +601,7 @@ def fit_weights_a10(
585601def fit_weights_a100_40 (
586602 branch : str = "main" ,
587603 epochs : int = 200 ,
604+ output_prefix : str = "" ,
588605 target_config : str = None ,
589606 beta : float = None ,
590607 lambda_l0 : float = None ,
@@ -596,14 +613,15 @@ def fit_weights_a100_40(
596613 artifacts_dir : str = "" ,
597614) -> dict :
598615 return _fit_weights_impl (
599- branch ,
600- epochs ,
601- target_config ,
602- beta ,
603- lambda_l0 ,
604- lambda_l2 ,
605- learning_rate ,
606- log_freq ,
616+ branch = branch ,
617+ epochs = epochs ,
618+ output_prefix = output_prefix ,
619+ target_config = target_config ,
620+ beta = beta ,
621+ lambda_l0 = lambda_l0 ,
622+ lambda_l2 = lambda_l2 ,
623+ learning_rate = learning_rate ,
624+ log_freq = log_freq ,
607625 skip_county = skip_county ,
608626 workers = workers ,
609627 artifacts_dir = artifacts_dir ,
@@ -622,6 +640,7 @@ def fit_weights_a100_40(
622640def fit_weights_a100_80 (
623641 branch : str = "main" ,
624642 epochs : int = 200 ,
643+ output_prefix : str = "" ,
625644 target_config : str = None ,
626645 beta : float = None ,
627646 lambda_l0 : float = None ,
@@ -633,14 +652,15 @@ def fit_weights_a100_80(
633652 artifacts_dir : str = "" ,
634653) -> dict :
635654 return _fit_weights_impl (
636- branch ,
637- epochs ,
638- target_config ,
639- beta ,
640- lambda_l0 ,
641- lambda_l2 ,
642- learning_rate ,
643- log_freq ,
655+ branch = branch ,
656+ epochs = epochs ,
657+ output_prefix = output_prefix ,
658+ target_config = target_config ,
659+ beta = beta ,
660+ lambda_l0 = lambda_l0 ,
661+ lambda_l2 = lambda_l2 ,
662+ learning_rate = learning_rate ,
663+ log_freq = log_freq ,
644664 skip_county = skip_county ,
645665 workers = workers ,
646666 artifacts_dir = artifacts_dir ,
@@ -659,6 +679,7 @@ def fit_weights_a100_80(
659679def fit_weights_h100 (
660680 branch : str = "main" ,
661681 epochs : int = 200 ,
682+ output_prefix : str = "" ,
662683 target_config : str = None ,
663684 beta : float = None ,
664685 lambda_l0 : float = None ,
@@ -670,14 +691,15 @@ def fit_weights_h100(
670691 artifacts_dir : str = "" ,
671692) -> dict :
672693 return _fit_weights_impl (
673- branch ,
674- epochs ,
675- target_config ,
676- beta ,
677- lambda_l0 ,
678- lambda_l2 ,
679- learning_rate ,
680- log_freq ,
694+ branch = branch ,
695+ epochs = epochs ,
696+ output_prefix = output_prefix ,
697+ target_config = target_config ,
698+ beta = beta ,
699+ lambda_l0 = lambda_l0 ,
700+ lambda_l2 = lambda_l2 ,
701+ learning_rate = learning_rate ,
702+ log_freq = log_freq ,
681703 skip_county = skip_county ,
682704 workers = workers ,
683705 artifacts_dir = artifacts_dir ,
0 commit comments