Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions src/art/serverless/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,26 @@ def _extract_step_from_wandb_artifact(artifact: "wandb.Artifact") -> int | None:
return None


def _wandb_checkpoint_collection_path(
*,
from_model: str,
from_project: str,
model_entity: str | None,
default_entity: str | None,
from_entity: str | None = None,
) -> str:
"""Build the W&B artifact collection path for a source checkpoint.

Resolves the entity from the explicit ``from_entity`` first, then the
destination model's entity, then the W&B default entity, so a checkpoint can
be forked from an entity other than the destination's.
"""
resolved_entity = from_entity or model_entity or default_entity
if resolved_entity is None:
raise ValueError("A W&B entity is required to locate the source checkpoint")
return f"{resolved_entity}/{from_project}/{from_model}"


_UPSTREAM_TRAIN_METRIC_KEYS = {
"reward": "reward",
"reward_std_dev": "reward_std_dev",
Expand Down Expand Up @@ -879,6 +899,7 @@ async def _experimental_fork_checkpoint(
model: "Model",
from_model: str,
from_project: str | None = None,
from_entity: str | None = None,
from_s3_bucket: str | None = None,
not_after_step: int | None = None,
verbose: bool = False,
Expand All @@ -897,6 +918,10 @@ async def _experimental_fork_checkpoint(
model: The destination model to fork to.
from_model: The name of the source model to fork from.
from_project: The project of the source model. Defaults to model.project.
from_entity: The W&B entity of the source model. Defaults to
model.entity, then the W&B API's default entity. Set this to fork
from a checkpoint that lives in a different entity than the
destination model.
from_s3_bucket: Optional S3 bucket to pull the checkpoint from.
not_after_step: If provided, uses the latest checkpoint <= this step.
verbose: Whether to print verbose output.
Expand Down Expand Up @@ -963,12 +988,17 @@ async def _experimental_fork_checkpoint(
else:
# Pull from W&B artifacts
api = wandb.Api(api_key=self._client.api_key) # ty:ignore[possibly-missing-attribute]
from_entity = model.entity or api.default_entity

# Iterate all artifact versions to find the best step.
# We avoid relying on the W&B `:latest` alias because it
# may not correspond to the highest training step.
collection_path = f"{from_entity}/{from_project}/{from_model}"
collection_path = _wandb_checkpoint_collection_path(
from_model=from_model,
from_project=from_project,
from_entity=from_entity,
model_entity=model.entity,
default_entity=api.default_entity,
)
versions = api.artifacts("lora", collection_path)

best_step: int | None = None
Expand Down
96 changes: 96 additions & 0 deletions tests/unit/test_serverless_fork_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Tests for cross-entity checkpoint forking (issue #649).

``_experimental_fork_checkpoint`` previously located the source checkpoint under
the *destination* model's entity, so forking from a checkpoint in another W&B
entity was impossible. These cover the new ``from_entity`` parameter and the
entity-resolution helper it flows through.
"""

import sys
from types import SimpleNamespace

import pytest

from art.serverless.backend import (
ServerlessBackend,
_wandb_checkpoint_collection_path,
)


def test_collection_path_prefers_explicit_from_entity():
    """An explicit ``from_entity`` wins over the model and default entities."""
    expected = "src-entity/src-project/src-model"
    resolved = _wandb_checkpoint_collection_path(
        from_entity="src-entity",
        model_entity="dst-entity",
        default_entity="default-entity",
        from_project="src-project",
        from_model="src-model",
    )
    assert resolved == expected


def test_collection_path_falls_back_to_model_entity():
    """Without ``from_entity``, the destination model's entity is used."""
    expected = "dst-entity/src-project/src-model"
    resolved = _wandb_checkpoint_collection_path(
        from_entity=None,
        model_entity="dst-entity",
        default_entity="default-entity",
        from_project="src-project",
        from_model="src-model",
    )
    assert resolved == expected


def test_collection_path_falls_back_to_default_entity():
    """With neither ``from_entity`` nor a model entity, the default is used."""
    expected = "default-entity/src-project/src-model"
    resolved = _wandb_checkpoint_collection_path(
        from_entity=None,
        model_entity=None,
        default_entity="default-entity",
        from_project="src-project",
        from_model="src-model",
    )
    assert resolved == expected


def test_collection_path_requires_an_entity():
    """Path resolution must fail when no entity is available from any source."""
    no_entity = None
    with pytest.raises(ValueError, match="W&B entity"):
        _wandb_checkpoint_collection_path(
            from_entity=no_entity,
            model_entity=no_entity,
            default_entity=no_entity,
            from_project="src-project",
            from_model="src-model",
        )


@pytest.mark.asyncio
async def test_fork_checkpoint_queries_explicit_source_entity(monkeypatch):
    """Forking with ``from_entity`` set must query W&B artifacts under that
    entity, not under the destination model's entity."""
    seen_queries = []

    class StubWandbApi:
        default_entity = "default-entity"

        def __init__(self, api_key):
            assert api_key == "test-api-key"

        def artifacts(self, artifact_type, collection_path):
            seen_queries.append((artifact_type, collection_path))
            # An empty listing makes the method raise "No checkpoints found".
            return []

    fake_wandb = SimpleNamespace(Api=StubWandbApi)
    monkeypatch.setitem(sys.modules, "wandb", fake_wandb)

    destination = SimpleNamespace(
        entity="dst-entity", project="dst-project", name="dst-model"
    )
    backend = ServerlessBackend.__new__(ServerlessBackend)
    backend._client = SimpleNamespace(api_key="test-api-key")

    with pytest.raises(ValueError, match="No checkpoints found"):
        await backend._experimental_fork_checkpoint(
            destination,
            from_model="src-model",
            from_project="src-project",
            from_entity="src-entity",
        )

    expected_query = ("lora", "src-entity/src-project/src-model")
    assert seen_queries == [expected_query]