
General purpose sq.tl.align() function with STAlign core logic#1162

Open
timtreis wants to merge 7 commits into scverse:main from timtreis:align-skeleton


@timtreis timtreis commented Apr 16, 2026

Implements the following functions:

sq.experimental.tl.align_by_landmarks(
    sdata,
    cs_name_ref="section_a",
    cs_name_query="section_b",
    landmarks_ref=((x1, y1), (x2, y2), (x3, y3)),
    landmarks_query=((x1, y1), (x2, y2), (x3, y3)),
    model="similarity",  # or "affine" for 6-DOF
)

This registers a spatialdata Affine transformation on every element in the query coordinate system (section_b), mapping it into the reference coordinate system (section_a). The call updates sdata in place: afterwards, every element in section_b knows how to transform into section_a's coordinate system.
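For readers unfamiliar with landmark fitting, here is a minimal sketch of how a 6-DOF affine could be estimated from paired landmarks by least squares. It is not the PR's implementation: the function name `fit_affine` and the 3x3 homogeneous output convention are assumptions made for illustration.

```python
import numpy as np


def fit_affine(landmarks_ref, landmarks_query):
    """Least-squares 2D affine that maps query landmarks onto reference landmarks.

    Hypothetical helper for illustration; returns a 3x3 homogeneous matrix.
    """
    ref = np.asarray(landmarks_ref, dtype=float)
    query = np.asarray(landmarks_query, dtype=float)
    if ref.ndim != 2 or ref.shape[1] != 2 or ref.shape != query.shape:
        raise ValueError("Expected two (N, 2) landmark arrays of equal shape.")
    if len(ref) < 3:
        raise ValueError("An affine fit needs at least 3 landmark pairs.")

    # Homogeneous query coordinates: each row is (x, y, 1).
    homog = np.column_stack([query, np.ones(len(query))])
    # Solve homog @ coef ~= ref for coef with shape (3, 2).
    coef, *_ = np.linalg.lstsq(homog, ref, rcond=None)

    matrix = np.eye(3)
    matrix[:2, :] = coef.T  # top two rows hold the linear part and the translation
    return matrix
```

A "similarity" model would additionally constrain the linear part to a rotation with uniform scale, which this sketch does not do.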

result = sq.experimental.tl.align_obs(
    adata_ref,
    adata_query,
    flavour="stalign", # will later accept moscot
    inplace=False, # otherwise, creates new aligned adata in sdata
    output_mode="obs",
)

output_mode controls what you get back:

  • "obs": a new AnnData with baked-in aligned coordinates
  • "affine": registers a transform on the SpatialData element (only works if the backend produces an affine; LDDMM doesn't)
  • "return": raw AlignResult with the displacement field and full solver state for power users
sq.experimental.tl.align_images(
    ...
)

Currently dead code, but it will be implemented via STalign. It is already included here to keep the function internals aligned.

Introduces a backend-agnostic alignment API (align_obs, align_images,
align_by_landmarks) with STalign and landmark backends, lazy JAX imports,
and e2e tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

codecov Bot commented Apr 16, 2026

Codecov Report

❌ Patch coverage is 31.13855% with 502 lines in your changes missing coverage. Please review.
✅ Project coverage is 69.54%. Comparing base (5b3c95d) to head (1d4c135).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| .../experimental/tl/_align/_backends/_stalign_core.py | 0.00% | 156 Missing ⚠️ |
| ...perimental/tl/_align/_backends/_stalign_helpers.py | 0.00% | 115 Missing ⚠️ |
| ...experimental/tl/_align/_backends/_stalign_tools.py | 0.00% | 92 Missing ⚠️ |
| src/squidpy/experimental/tl/_align/_io.py | 55.26% | 43 Missing and 8 partials ⚠️ |
| ...uidpy/experimental/tl/_align/_backends/_stalign.py | 36.66% | 18 Missing and 1 partial ⚠️ |
| src/squidpy/experimental/tl/_align/_api.py | 69.23% | 12 Missing and 4 partials ⚠️ |
| src/squidpy/experimental/tl/_align/_types.py | 70.37% | 14 Missing and 2 partials ⚠️ |
| ...quidpy/experimental/tl/_align/_backends/_moscot.py | 0.00% | 11 Missing ⚠️ |
| .../squidpy/experimental/tl/_align/_backends/_base.py | 0.00% | 10 Missing ⚠️ |
| ...idpy/experimental/tl/_align/_backends/_landmark.py | 76.47% | 5 Missing and 3 partials ⚠️ |

... and 2 more

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1162      +/-   ##
==========================================
- Coverage   73.56%   69.54%   -4.02%     
==========================================
  Files          44       56      +12     
  Lines        6929     7658     +729     
  Branches     1174     1272      +98     
==========================================
+ Hits         5097     5326     +229     
- Misses       1347     1825     +478     
- Partials      485      507      +22     

| Files with missing lines | Coverage Δ |
|---|---|
| src/squidpy/experimental/tl/_align/_jax.py | 88.88% <88.88%> (ø) |
| src/squidpy/experimental/tl/_align/_validation.py | 86.53% <86.53%> (ø) |
| ...idpy/experimental/tl/_align/_backends/_landmark.py | 76.47% <76.47%> (ø) |
| .../squidpy/experimental/tl/_align/_backends/_base.py | 0.00% <0.00%> (ø) |
| ...quidpy/experimental/tl/_align/_backends/_moscot.py | 0.00% <0.00%> (ø) |
| src/squidpy/experimental/tl/_align/_api.py | 69.23% <69.23%> (ø) |
| src/squidpy/experimental/tl/_align/_types.py | 70.37% <70.37%> (ø) |
| ...uidpy/experimental/tl/_align/_backends/_stalign.py | 36.66% <36.66%> (ø) |
| src/squidpy/experimental/tl/_align/_io.py | 55.26% <55.26%> (ø) |
| ...experimental/tl/_align/_backends/_stalign_tools.py | 0.00% <0.00%> (ø) |

... and 2 more

... and 1 file with indirect coverage changes


- Replace private `sdata._gen_elements()` with public `gen_elements()`
- Replace dict-style `sdata[key]` lookup with explicit element-type search
- Add subprocess timeout (30s) to lazy-import hygiene test
- Document shallow X sharing in `materialise_obs` docstring
- Document JAX array retention in stalign metadata comment
- Document camelCase convention in STalignRegistrationConfig docstring
- Broaden landmark type hints to accept Sequence and np.ndarray
- Remove stale TODO comment from _stalign_helpers.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
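
As an illustration of the lazy-import hygiene test with a subprocess timeout mentioned in the commit above, a check of this kind could look roughly like the sketch below. The helper name `assert_not_imported_by` is hypothetical and the PR's actual test may be structured differently; only the 30-second timeout is taken from the commit message.

```python
import subprocess
import sys


def assert_not_imported_by(import_stmt: str, forbidden: str, timeout: float = 30.0) -> None:
    """Run `import_stmt` in a fresh interpreter and fail if `forbidden` got imported.

    Hypothetical helper: a fresh subprocess is used so that modules already
    loaded in the test session cannot mask an eager import.
    """
    code = (
        "import sys\n"
        f"{import_stmt}\n"
        f"assert {forbidden!r} not in sys.modules, 'eagerly imported'\n"
    )
    proc = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
        timeout=timeout,  # raises subprocess.TimeoutExpired if the child hangs
    )
    if proc.returncode != 0:
        raise AssertionError(proc.stderr)
```

In this PR's context the call would presumably be something like `assert_not_imported_by("import squidpy", "jax")`.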
timtreis and others added 2 commits April 16, 2026 12:52
- Resolve JAX_DTYPE lazily via jax_dtype() to respect runtime x64 config
- Replace 14-line config field unpack with dataclasses.asdict(registration)
- Remove unreachable ValueError in _writeback (already validated upstream)
- Remove task-tracking comment from moscot stub
- Clean up PR-line-number reference in stalign comment

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
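
The `dataclasses.asdict` replacement described in the commit above can be sketched as follows. The config class, its fields, and `solve` are hypothetical stand-ins, not the PR's actual `STalignRegistrationConfig` or solver signature.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class RegistrationConfig:
    """Hypothetical config; the field names are illustrative only."""

    niter: int = 100
    sigmaM: float = 1.0
    epV: float = 1000.0


def solve(points, *, niter, sigmaM, epV):
    """Stand-in solver that just echoes the keyword arguments it received."""
    return {"n_points": len(points), "niter": niter, "sigmaM": sigmaM, "epV": epV}


def run(points, registration: RegistrationConfig):
    # One call replaces a field-by-field unpack such as
    # solve(points, niter=registration.niter, sigmaM=registration.sigmaM, ...).
    return solve(points, **asdict(registration))
```

Note that `asdict` recurses into nested dataclasses and copies container fields, so it fits flat configs of plain values best.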
Adds a hatch-test.py3.13-stable-jax environment that installs the [jax]
optional extra so the STalign solver and e2e alignment tests run in CI.
Excluded from macOS to avoid doubling runner cost.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@timtreis timtreis changed the title from "Migrate logic into general skeleton" to "General purpose sq.tl.align() function with STAlign core logic" Apr 16, 2026
timtreis and others added 3 commits April 16, 2026 14:13
- align_obs: default output_mode="obs" (was "affine", which crashed
  with the default stalign backend). Auto-generate key_added from
  query name when not provided for SpatialData inputs.
- align_by_landmarks: make cs_name_ref/query and landmarks_ref/query
  keyword-only required args (were Optional with None defaults that
  immediately errored). Remove unused scale_ref/scale_query params.
  Wire get_extent validation for landmark bounds checking against cs
  extent. Fix docstring (y, x) -> (x, y).
- align_images: make img_ref/query_name keyword-only required. Remove
  from public __all__ (no backend implements it yet).
- align_obs docstring: note that inplace only affects SpatialData.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
JAX selects the appropriate device based on its install (CPU/GPU) and
runtime context managers. The explicit device arg added unnecessary
complexity with no benefit over JAX's built-in device management.

Removes device from: align_obs, align_images, AlignBackend protocol,
StAlignBackend, MoscotBackend, and require_jax. Simplifies require_jax
to a pure import guard.
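
A "pure import guard" in the sense of the commit above could look roughly like this sketch. The generic name `require_module` and the error wording are assumptions; the PR's `require_jax` may differ.

```python
from __future__ import annotations

import importlib
from types import ModuleType


def require_module(name: str, *, extra: str | None = None) -> ModuleType:
    """Import `name` on demand, or raise an ImportError with an install hint.

    Hypothetical helper: it takes no device argument and does no configuration,
    it only imports the module and returns it.
    """
    try:
        return importlib.import_module(name)
    except ImportError as err:
        hint = f" Install it with the '{extra}' optional extra." if extra else ""
        raise ImportError(f"'{name}' is required for this feature.{hint}") from err
```

Device selection then stays entirely with the imported library, as the commit message argues.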

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@timtreis timtreis requested a review from selmanozleyen April 19, 2026 22:33

@selmanozleyen selmanozleyen left a comment


  • need to rename backend -> method everywhere
  • it shouldn't be called moscot, since we don't need moscot; both stalign and gridot use jax, but ideally we'd also want possible numba implementations?
  • why is AlignPair only for AnnData | DataTree? Landmarks should be another possibility here
  • again, why does the STAlign class have align_obs and align_images but not align_landmarks? (it's fine to have partial support, since the GridOT class may not have align_landmarks either)
  • why model? I'd name it ransac or whatever and expect the inputs based on that. My method should make clear what input I need. GridOT, for example, assumes the point clouds are in the same space, and different algorithms can minimize different models of the given input.

I'd rather have

For inputs: AlignObsPair, AlignImagesPair, AlignLandmarkPair as possible inputs which are separated nicely.

For algorithms: StAlignStrategy, GridOtStrategy, LandmarkAffineStrategy, LandmarkSimilarityStrategy, *NumbaStAlignStrategy, *OpticalFlowStrategy... in which we can encapsulate the heavy dependencies inside classes.

We are also not cleanly conceptualizing how the alignment relates to the input the user actually gave. When the input is spatialdata, the alignment pairs can be any of these; when it's anndata, it's either the obs or uns for landmarks; and when it's numpy arrays, it's landmarks, right? Why not follow that pattern?

If I have spatialdata I can have a landmark based algorithm run and apply that to my images and my anndata depending on what I will give.

Cc @FrancescaDr @sarajimenez

What do you guys also think? What makes sense for you usually?

@selmanozleyen

Would this kind of refactor make sense:

Public surface: 2 functions, 3 output modes

The split is driven by input source, not by element kind:

| input source | function |
|---|---|
| AnnData / SpatialData container | align(...) |
| in-memory landmark arrays / tuples | align_by_landmarks(...) |

The obs vs image distinction is one kwarg of align(), not a
separate function -- both consume containers and produce a transform.
Output is a single enum across both functions:

| output_mode | returns |
|---|---|
| "object" | raw AlignResult (no container is touched) |
| "copy" | modified copy of the writeback target |
| "inplace" | mutates the target, returns None |

def align(
    data_ref: AnnData | SpatialData,
    data_query: AnnData | SpatialData,
    *,
    method: str = "stalign",
    on: Literal["obs", "image"] = "obs",
    ref_key: str | None = None,           # SpatialData only: which table / image
    query_key: str | None = None,
    output_mode: Literal["object", "copy", "inplace"] = "inplace",
    key_added: str | None = None,
    **method_kwargs,
) -> AlignResult | AnnData | SpatialData | None: ...

def align_by_landmarks(
    landmarks_ref: np.ndarray | Sequence[tuple[float, float]],
    landmarks_query: np.ndarray | Sequence[tuple[float, float]],
    *,
    method: Literal["similarity", "affine"] = "similarity",
    data: AnnData | SpatialData | None = None,   # required for copy / inplace
    cs_name_ref: str | None = None,
    cs_name_query: str | None = None,
    output_mode: Literal["object", "copy", "inplace"] = "object",
    key_added: str | None = None,
    **method_kwargs,
) -> AlignResult | AnnData | SpatialData | None: ...

align_by_landmarks defaults to output_mode="object" because the only
input is arrays in memory -- there is no implicit container target.
Choosing "copy" or "inplace" requires data=.

Naming notes:

  • ref_key / query_key (not element_ref / element_query): "element"
    is SpatialData jargon; _key is the standard scverse suffix
    (cluster_key, library_key, spatial_key, connectivity_key).
  • data= (not target=) on align_by_landmarks: from the user's
    perspective they are passing the data on which to apply the
    alignment, not a "writeback dump". The data slot also matches
    align's data_ref / data_query vocabulary.

Strategy hierarchy: two families, pure fit

Strategy.fit* is a pure function -- raw fit args in, AlignResult
out, no container in scope. Containers are routed separately as a
writeback AlignTarget and only enter the picture after the fit. The
landmark family in particular never sees an AnnData or SpatialData
during the fit.

AlignStrategy (ABC)
    name: ClassVar[str]
    requires: ClassVar[tuple[str, ...]]
├── AlignDataStrategy (ABC)
│       on: ClassVar[frozenset[str]]   # subset of {"obs", "image"}
│       def fit_obs(self, ref: AnnData,        query: AnnData)        -> AlignResult: ...
│       def fit_image(self, ref: xr.DataArray, query: xr.DataArray)   -> AlignResult: ...
│   ├── StAlignStrategy             # on={"obs","image"}; requires=("jax",)
│   ├── NumbaStalignStrategy        # on={"obs","image"}; requires=("numba",)
│   └── GridOtStrategy              # on={"obs"};         requires=("jax",)
└── AlignLandmarkStrategy (ABC)
        def fit(self, landmarks_ref: np.ndarray, landmarks_query: np.ndarray) -> AlignResult: ...
    ├── LandmarkSimilarityStrategy
    └── LandmarkAffineStrategy

Writeback target (separate concept, kept)

@dataclass(frozen=True)
class AlignTarget:
    container: AnnData | SpatialData
    key: str | None = None           # element key (SpatialData only; ignored for AnnData)

Two fields, both intent-neutral. Coordinate-system names live on the
AlignResult (the affine carries source_cs / target_cs), not on the
writeback target -- they describe the transform, not where to put it.

AlignTarget survives across the resolver -> strategy -> writeback
boundary (the fit args do not), so it earns being a typed object.
Resolvers in _io.py produce (ref_data, query_data, target | None):

| public function | fit args | target source |
|---|---|---|
| align(...) | extracted from data_query | data_query itself -> AlignTarget |
| align_by_landmarks(...) | landmarks_ref, landmarks_query | data= kwarg -> AlignTarget (or None) |

Two flat (method) -> Strategy registries (one per family), no
(method, mode) tuple keys -- the family already pins it.
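
To make the "two flat registries" idea concrete, here is a rough sketch. The `register` decorator and `resolve` helper are hypothetical names; only the strategy class names and the method strings come from the proposal above.

```python
from __future__ import annotations

from abc import ABC
from typing import ClassVar


class AlignStrategy(ABC):
    name: ClassVar[str]


# One flat registry per family; keys are plain method names, not (method, mode) tuples.
DATA_STRATEGIES: dict[str, type[AlignStrategy]] = {}
LANDMARK_STRATEGIES: dict[str, type[AlignStrategy]] = {}


def register(registry):
    """Class decorator that files a strategy under its `name` (hypothetical helper)."""

    def decorator(cls):
        registry[cls.name] = cls
        return cls

    return decorator


@register(DATA_STRATEGIES)
class StAlignStrategy(AlignStrategy):
    name = "stalign"


@register(LANDMARK_STRATEGIES)
class LandmarkSimilarityStrategy(AlignStrategy):
    name = "similarity"


@register(LANDMARK_STRATEGIES)
class LandmarkAffineStrategy(AlignStrategy):
    name = "affine"


def resolve(registry, method: str) -> AlignStrategy:
    """Look up and instantiate a strategy, with a readable error for unknown names."""
    if method not in registry:
        raise ValueError(f"Unknown method {method!r}; expected one of {sorted(registry)}.")
    return registry[method]()
```

Under this sketch, `align(...)` would consult only `DATA_STRATEGIES` and `align_by_landmarks(...)` only `LANDMARK_STRATEGIES`, so the calling function already determines the family.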

Result hierarchy: polymorphic apply_to(target)

AlignResult (ABC)
    metadata: dict
    def apply_to(self, target: AlignTarget, *, key_added=None) -> AnnData | SpatialData
├── AffineResult         # 3x3; SpatialData -> set_transformation, AnnData -> obsm
├── DisplacementResult   # (N, 2); per-obs deltas; obsm-only on either container
└── DiffeomorphismResult # velocity field + integrator (future, lightweight)

Invariant: results and inputs carry pointers / parametric forms, not
bulk data. DiffeomorphismResult keeps the velocity field; the dense
(H, W, 2) warp is materialised only when actually needed.

Each result subclass dispatches on target.container (and on
target.key) inside apply_to, absorbing the three writeback paths
currently in _io.py::apply_affine_to_cs. Coordinate-system info comes
from the result's own transform (source_cs, target_cs), not from the
target. _writeback collapses to:

def _writeback(target: AlignTarget | None, result: AlignResult, *, output_mode, key_added):
    if output_mode == "object":
        return result
    if target is None:
        raise ValueError(f"output_mode={output_mode!r} requires a target container.")
    if output_mode == "copy":
        target = _copy_for_writeback(target)
    return result.apply_to(target, key_added=key_added)

No isinstance ladder on the result type, no
output_mode in {"obs", "affine"} branch -- the result class owns the
"can I be applied to this target?" question and raises a single, typed
NotImplementedError when a combination doesn't make sense (e.g.
DisplacementResult.apply_to(<SpatialData target without a key>)).
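
A self-contained toy version of this dispatch might look like the sketch below. `ToyContainer` is a hypothetical stand-in for AnnData (just an obsm-like dict), the default key `"spatial_aligned"` is an assumption, and `writeback` follows the output-mode table above in returning None for "inplace", which the shorter `_writeback` snippet does not spell out.

```python
from __future__ import annotations

import copy
from abc import ABC, abstractmethod
from dataclasses import dataclass, field

import numpy as np


@dataclass
class ToyContainer:
    """Hypothetical stand-in for AnnData: only an obsm-like mapping."""

    obsm: dict = field(default_factory=dict)


@dataclass(frozen=True)
class AlignTarget:
    container: ToyContainer
    key: str | None = None


class AlignResult(ABC):
    @abstractmethod
    def apply_to(self, target: AlignTarget, *, key_added=None): ...


@dataclass
class AffineResult(AlignResult):
    matrix: np.ndarray  # 3x3 homogeneous transform

    def apply_to(self, target, *, key_added=None):
        xy = target.container.obsm["spatial"]
        homog = np.column_stack([xy, np.ones(len(xy))])
        aligned = (homog @ self.matrix.T)[:, :2]
        target.container.obsm[key_added or "spatial_aligned"] = aligned
        return target.container


def writeback(target, result, *, output_mode, key_added=None):
    if output_mode == "object":
        return result  # no container is touched
    if target is None:
        raise ValueError(f"output_mode={output_mode!r} requires a target container.")
    if output_mode == "copy":
        copied = AlignTarget(copy.deepcopy(target.container), target.key)
        return result.apply_to(copied, key_added=key_added)
    result.apply_to(target, key_added=key_added)  # "inplace": mutate the target
    return None
```

The result class owns the writeback logic here, so `writeback` never inspects the result type.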
