General purpose sq.tl.align() function with STAlign core logic#1162
General purpose sq.tl.align() function with STAlign core logic#1162timtreis wants to merge 7 commits intoscverse:mainfrom
sq.tl.align() function with STAlign core logic#1162Conversation
Introduces a backend-agnostic alignment API (align_obs, align_images, align_by_landmarks) with STalign and landmark backends, lazy JAX imports, and e2e tests. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
16c576b to
42d74ab
Compare
- Replace private `sdata._gen_elements()` with public `gen_elements()` - Replace dict-style `sdata[key]` lookup with explicit element-type search - Add subprocess timeout (30s) to lazy-import hygiene test - Document shallow X sharing in `materialise_obs` docstring - Document JAX array retention in stalign metadata comment - Document camelCase convention in STalignRegistrationConfig docstring - Broaden landmark type hints to accept Sequence and np.ndarray - Remove stale TODO comment from _stalign_helpers.py Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
3258f6e to
837c258
Compare
- Resolve JAX_DTYPE lazily via jax_dtype() to respect runtime x64 config - Replace 14-line config field unpack with dataclasses.asdict(registration) - Remove unreachable ValueError in _writeback (already validated upstream) - Remove task-tracking comment from moscot stub - Clean up PR-line-number reference in stalign comment Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Adds a hatch-test.py3.13-stable-jax environment that installs the [jax] optional extra so the STalign solver and e2e alignment tests run in CI. Excluded from macOS to avoid doubling runner cost. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
sq.tl.align() function with STAlign core logic
- align_obs: default output_mode="obs" (was "affine", which crashed with the default stalign backend). Auto-generate key_added from query name when not provided for SpatialData inputs. - align_by_landmarks: make cs_name_ref/query and landmarks_ref/query keyword-only required args (were Optional with None defaults that immediately errored). Remove unused scale_ref/scale_query params. Wire get_extent validation for landmark bounds checking against cs extent. Fix docstring (y, x) -> (x, y). - align_images: make img_ref/query_name keyword-only required. Remove from public __all__ (no backend implements it yet). - align_obs docstring: note that inplace only affects SpatialData. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
JAX selects the appropriate device based on its install (CPU/GPU) and runtime context managers. The explicit device arg added unnecessary complexity with no benefit over JAX's built-in device management. Removes device from: align_obs, align_images, AlignBackend protocol, StAlignBackend, MoscotBackend, and require_jax. Simplifies require_jax to a pure import guard. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
- need to rename backend -> method everywhere
- it is not moscot since we don't need moscot, both stalign and gridot use jax but Ideally we'd want possible numba implementations?
- why is AlignPair only for
AnnData | DataTree? Landmarks should be another possibility here - again why in STAlign class there is align_obs and align_images but not align_landmarks? (it's fine to have partial support since GridOT class won't have maybe align_landmarks as well)
- why
model? I'd name it ransac or whatever and expect the inputs based on that. My method should clarify what input I need. GridOT assumes the point clouds are in the same space for example and different algorithms can minimize different models of the input given.
I'd rather have
For inputs: AlignObsPair, AlignImagesPair, AlignLandmarkPair as possible inputs which are separated nicely.
For algorithms: StAlignStrategy, GridOtStrategy, LandmarkAffineStrategy, LandmarkSimilarityStrategy, *NumbaStAlignStrategy, *OpticalFlowStrategy... in which we can encapulate the heavy dependencies to classes.
We are also not nicely conceptualizing the relation of alignment to input to what the user actually gave. When the input is spatialdata the alignment pairs can be all, when it's anndata it's either the obs or uns for landmarks, or when it's numpy arrays its landmarks right? Why not go on that pattern?
If I have spatialdata I can have a landmark based algorithm run and apply that to my images and my anndata depending on what I will give.
What do you guys also think? What makes sense for you usually?
|
Would this kind of refactor make sense: Public surface: 2 functions, 3 output modesThe split is driven by input source, not by element kind:
The
def align(
data_ref: AnnData | SpatialData,
data_query: AnnData | SpatialData,
*,
method: str = "stalign",
on: Literal["obs", "image"] = "obs",
ref_key: str | None = None, # SpatialData only: which table / image
query_key: str | None = None,
output_mode: Literal["object", "copy", "inplace"] = "inplace",
key_added: str | None = None,
**method_kwargs,
) -> AlignResult | AnnData | SpatialData | None: ...
def align_by_landmarks(
landmarks_ref: np.ndarray | Sequence[tuple[float, float]],
landmarks_query: np.ndarray | Sequence[tuple[float, float]],
*,
method: Literal["similarity", "affine"] = "similarity",
data: AnnData | SpatialData | None = None, # required for copy / inplace
cs_name_ref: str | None = None,
cs_name_query: str | None = None,
output_mode: Literal["object", "copy", "inplace"] = "object",
key_added: str | None = None,
method_kwargs
) -> AlignResult | AnnData | SpatialData | None: ...
Naming notes:
Strategy hierarchy: two families, pure
|
| public function | fit args | target source |
|---|---|---|
align(...) |
extracted from data_query |
data_query itself -> AlignTarget |
align_by_landmarks(...) |
landmarks_ref, landmarks_query |
data= kwarg -> AlignTarget (or None) |
Two flat (method) -> Strategy registries (one per family), no
(method, mode) tuple keys -- the family already pins it.
Result hierarchy: polymorphic apply_to(target)
AlignResult (ABC)
metadata: dict
def apply_to(self, target: AlignTarget, *, key_added=None) -> AnnData | SpatialData
├── AffineResult # 3x3; SpatialData -> set_transformation, AnnData -> obsm
├── DisplacementResult # (N, 2); per-obs deltas; obsm-only on either container
└── DiffeomorphismResult # velocity field + integrator (future, lightweight)
Invariant: results and inputs carry pointers / parametric forms, not
bulk data. DiffeomorphismResult keeps the velocity field; the dense
(H, W, 2) warp is materialised only when actually needed.
Each result subclass dispatches on target.container (and on
target.key) inside apply_to, absorbing the three writeback paths
currently in _io.py::apply_affine_to_cs. Coordinate-system info comes
from the result's own transform (source_cs, target_cs), not from the
target. _writeback collapses to:
def _writeback(target: AlignTarget | None, result: AlignResult, *, output_mode, key_added):
if output_mode == "object":
return result
if target is None:
raise ValueError(f"output_mode={output_mode!r} requires a target container.")
if output_mode == "copy":
target = _copy_for_writeback(target)
return result.apply_to(target, key_added=key_added)No isinstance ladder on the result type, no
output_mode in {"obs", "affine"} branch -- the result class owns the
"can I be applied to this target?" question and raises a single, typed
NotImplementedError when a combination doesn't make sense (e.g.
DisplacementResult.apply_to(<SpatialData target without a key>)).
Implements the following functions:
This registers a spatialdata Affine transformation on every element in cs_b, mapping it into cs_a. After the call, sdata is updated in-place, all elements in section_b now know how to transform into section_a's coordinate system.
output_mode controls what you get back:
Currently dead but will be implemente via STalign. Already in here to better align function internals