Skip to content

Commit 3b8b885

Browse files
authored
Handle tar file extraction errors (#58)
Closes #33. Closes #40.
1 parent 09ff1c5 commit 3b8b885

File tree

5 files changed

+116
-41
lines changed

5 files changed

+116
-41
lines changed

sagemaker_shim/app.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,13 @@
3636
@asynccontextmanager
3737
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
3838
async with get_s3_resources() as s3_resources:
39-
async with AuxiliaryData(s3_resources=s3_resources):
39+
auxiliary_data = AuxiliaryData(s3_resources=s3_resources)
40+
await auxiliary_data.setup()
41+
42+
try:
4043
yield
44+
finally:
45+
await auxiliary_data.teardown()
4146

4247

4348
app = FastAPI(lifespan=lifespan)

sagemaker_shim/cli.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from pydantic import ValidationError
1515

1616
from sagemaker_shim.app import app
17+
from sagemaker_shim.exceptions import UserSafeError
1718
from sagemaker_shim.logging import LOGGING_CONFIG
1819
from sagemaker_shim.models import (
1920
AuxiliaryData,
@@ -86,7 +87,16 @@ async def invoke(tasks: str, file: str) -> None:
8687
tasks=tasks, file=file, s3_resources=s3_resources
8788
)
8889

89-
async with AuxiliaryData(s3_resources=s3_resources):
90+
auxiliary_data = AuxiliaryData(s3_resources=s3_resources)
91+
92+
try:
93+
try:
94+
await auxiliary_data.setup()
95+
except* UserSafeError as exception_group:
96+
for exception in exception_group.exceptions:
97+
logger.error(msg=str(exception), extra={"internal": False})
98+
raise SystemExit(1) from exception_group
99+
90100
for task in parsed_tasks.root:
91101
# Only run one task at a time
92102
result = await task.invoke(s3_resources=s3_resources)
@@ -96,9 +106,11 @@ async def invoke(tasks: str, file: str) -> None:
96106
logger.error(
97107
f"Stopping due to failure of task {result.pk}"
98108
)
99-
raise SystemExit(0)
109+
raise SystemExit(result.return_code)
100110

101-
logger.info("Model invocation complete")
111+
logger.info("Model invocation complete")
112+
finally:
113+
await auxiliary_data.teardown()
102114

103115

104116
async def _parse_tasks(

sagemaker_shim/models.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from importlib.metadata import version
2323
from pathlib import Path
2424
from tempfile import SpooledTemporaryFile, TemporaryDirectory
25-
from types import TracebackType
2625
from typing import TYPE_CHECKING, Any, NamedTuple
2726
from zipfile import BadZipFile
2827

@@ -309,8 +308,14 @@ async def download_and_extract_tarball(
309308

310309
f.seek(0)
311310

312-
with ProcUserTarfile.open(fileobj=f, mode="r") as tar:
313-
tar.extractall(path=dest, filter="data")
311+
try:
312+
with ProcUserTarfile.open(fileobj=f, mode="r") as tar:
313+
tar.extractall(path=dest, filter="data")
314+
except (tarfile.TarError, FileNotFoundError) as error:
315+
logger.error(
316+
f"Tarfile could not be extracted: {error}", exc_info=error
317+
)
318+
raise UserSafeError("Tarfile could not be extracted") from error
314319

315320

316321
class AuxiliaryData:
@@ -378,23 +383,24 @@ def post_clean_directories(self) -> list[Path]:
378383
logger.debug(f"{post_clean_directories=}")
379384
return post_clean_directories
380385

381-
async def __aenter__(self) -> "AuxiliaryData":
386+
async def setup(self) -> None:
382387
logger.info("Setting up Auxiliary Data")
383388

384389
self.ensure_directories_are_writable()
385390

386-
async with asyncio.TaskGroup() as task_group:
387-
task_group.create_task(self.download_model())
388-
task_group.create_task(self.download_ground_truth())
391+
try:
392+
await self.download_model()
393+
except UserSafeError as error:
394+
raise UserSafeError(f"Could not setup model: {error}") from error
389395

390-
return self
396+
try:
397+
await self.download_ground_truth()
398+
except UserSafeError as error:
399+
raise UserSafeError(
400+
f"Could not setup ground truth: {error}"
401+
) from error
391402

392-
async def __aexit__(
393-
self,
394-
exc_type: type[BaseException] | None,
395-
exc_val: BaseException | None,
396-
exc_tb: TracebackType | None,
397-
) -> None:
403+
async def teardown(self) -> None:
398404
logger.info("Cleaning up Auxiliary Data")
399405
for p in self.post_clean_directories:
400406
logger.info(f"Cleaning {p=}")

tests/test_cli.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import io
22
import json
3+
import os
34
import resource
45
from unittest.mock import patch
56
from uuid import uuid4
@@ -192,7 +193,7 @@ def test_bad_command_inference_from_task_list(minio, monkeypatch):
192193
f'{{"log": "Stopping due to failure of task {pk1}", "level": "ERROR", '
193194
'"source": "stderr", "internal": true, "task": null}' in result.output
194195
)
195-
assert result.exit_code == 0
196+
assert result.exit_code == 1
196197

197198

198199
def test_good_command_inference_from_s3_uri(minio, monkeypatch):
@@ -317,7 +318,7 @@ def test_bad_command_inference_from_s3_uri(minio, monkeypatch):
317318
f'{{"log": "Stopping due to failure of task {pk1}", "level": "ERROR", '
318319
'"source": "stderr", "internal": true, "task": null}' in result.output
319320
)
320-
assert result.exit_code == 0
321+
assert result.exit_code == 1
321322

322323

323324
def test_logging_setup(minio, monkeypatch):
@@ -440,3 +441,50 @@ def test_memory_limit_defined(minio, monkeypatch):
440441
'{"log": "Setting memory limit to 1337 MB", "level": "INFO", '
441442
'"source": "stdout", "internal": true, "task": null}'
442443
) in result.output
444+
445+
446+
def test_aux_data_failure(minio, monkeypatch, tmp_path):
447+
pk = str(uuid4())
448+
prefix = f"tasks/{pk}"
449+
model_key = f"{prefix}/sub/dodgy.tar"
450+
model_destination = tmp_path / "model"
451+
tasks = [
452+
{
453+
"pk": pk,
454+
"inputs": [],
455+
"output_bucket_name": minio.output_bucket_name,
456+
"output_prefix": prefix,
457+
"timeout": "PT10S",
458+
}
459+
]
460+
461+
monkeypatch.setenv(
462+
"GRAND_CHALLENGE_COMPONENT_CMD_B64J",
463+
encode_b64j(val=["echo", "hello"]),
464+
)
465+
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_SET_EXTRA_GROUPS", "False")
466+
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_USE_LINKED_INPUT", "False")
467+
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_WRITABLE_DIRECTORIES", "")
468+
monkeypatch.setenv(
469+
"GRAND_CHALLENGE_COMPONENT_MODEL",
470+
f"s3://{minio.input_bucket_name}/{model_key}",
471+
)
472+
monkeypatch.setenv(
473+
"GRAND_CHALLENGE_COMPONENT_MODEL_DEST", str(model_destination)
474+
)
475+
476+
sync_s3_operation(
477+
method=s3_upload_fileobj,
478+
Fileobj=io.BytesIO(os.urandom(8)),
479+
Bucket=minio.input_bucket_name,
480+
Key=model_key,
481+
)
482+
483+
runner = CliRunner()
484+
result = runner.invoke(cli, ["invoke", "-t", json.dumps(tasks)])
485+
486+
assert result.exit_code == 1
487+
assert result.stderr.splitlines()[-1] == (
488+
'{"log": "Could not setup model: Tarfile could not be extracted", '
489+
'"level": "ERROR", "source": "stderr", "internal": false, "task": null}'
490+
)

tests/test_models.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -395,12 +395,14 @@ async def test_model_and_ground_truth_extraction(
395395
f"{ground_truth_pk}/ground_truth.tar.gz",
396396
)
397397

398-
async with AuxiliaryData(s3_resources=s3_resources):
399-
downloaded_files = {
400-
str(f.relative_to(tmp_path))
401-
for f in tmp_path.rglob("**/*")
402-
if f.is_file()
403-
}
398+
auxiliary_data = AuxiliaryData(s3_resources=s3_resources)
399+
await auxiliary_data.setup()
400+
downloaded_files = {
401+
str(f.relative_to(tmp_path))
402+
for f in tmp_path.rglob("**/*")
403+
if f.is_file()
404+
}
405+
await auxiliary_data.teardown()
404406

405407
assert downloaded_files == {
406408
"model/model-file1.txt",
@@ -425,17 +427,18 @@ async def test_ensure_directories_are_writable_unset(monkeypatch):
425427
monkeypatch.setenv("GRAND_CHALLENGE_COMPONENT_POST_CLEAN_DIRECTORIES", "")
426428

427429
async with get_s3_resources() as s3_resources:
428-
async with AuxiliaryData(s3_resources=s3_resources) as d:
429-
assert d.writable_directories == []
430-
assert d.post_clean_directories == []
431-
assert d.model_source is None
432-
assert d.model_dest == Path("/opt/ml/model")
433-
assert not d.model_dest.exists()
434-
assert d.ground_truth_source is None
435-
assert d.ground_truth_dest == Path(
436-
"/opt/ml/input/data/ground_truth"
437-
)
438-
assert not d.ground_truth_dest.exists()
430+
auxiliary_data = AuxiliaryData(s3_resources=s3_resources)
431+
432+
assert auxiliary_data.writable_directories == []
433+
assert auxiliary_data.post_clean_directories == []
434+
assert auxiliary_data.model_source is None
435+
assert auxiliary_data.model_dest == Path("/opt/ml/model")
436+
assert not auxiliary_data.model_dest.exists()
437+
assert auxiliary_data.ground_truth_source is None
438+
assert auxiliary_data.ground_truth_dest == Path(
439+
"/opt/ml/input/data/ground_truth"
440+
)
441+
assert not auxiliary_data.ground_truth_dest.exists()
439442

440443

441444
@pytest.mark.asyncio
@@ -456,8 +459,8 @@ async def test_ensure_directories_are_writable_set(
456459
)
457460

458461
async with get_s3_resources() as s3_resources:
459-
async with AuxiliaryData(s3_resources=s3_resources) as d:
460-
assert d.writable_directories == expected
462+
auxiliary_data = AuxiliaryData(s3_resources=s3_resources)
463+
assert auxiliary_data.writable_directories == expected
461464

462465

463466
@pytest.mark.asyncio
@@ -480,8 +483,9 @@ async def test_ensure_directories_are_writable(tmp_path, monkeypatch):
480483
)
481484

482485
async with get_s3_resources() as s3_resources:
483-
async with AuxiliaryData(s3_resources=s3_resources):
484-
pass
486+
auxiliary_data = AuxiliaryData(s3_resources=s3_resources)
487+
await auxiliary_data.setup()
488+
await auxiliary_data.teardown()
485489

486490
assert data.stat().st_mode == 0o40777
487491
assert model.stat().st_mode == 0o40777

0 commit comments

Comments
 (0)