Skip to content

Commit 5eda868

Browse files
committed
feat: parallelize per-library spatial graph construction
Add an `n_jobs` parameter to `spatial_neighbors()` to compute per-library graphs in parallel via joblib. Defaults to 1 (sequential, no behavior change); set to -1 to use all CPUs. When `library_key` is set, each library's graph is already computed independently, so the workload is embarrassingly parallel. For datasets with many libraries (e.g., multi-sample spatial transcriptomics), this gives a near-linear speedup.
1 parent 9690a55 commit 5eda868

File tree

1 file changed

+23
-7
lines changed

1 file changed

+23
-7
lines changed

src/squidpy/gr/_build.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from anndata import AnnData
1515
from anndata.utils import make_index_unique
1616
from fast_array_utils import stats as fau_stats
17+
from joblib import Parallel, delayed
1718
from numba import njit, prange
1819
from scipy.sparse import (
1920
SparseEfficiencyWarning,
@@ -77,6 +78,7 @@ def spatial_neighbors(
7778
set_diag: bool = False,
7879
key_added: str = "spatial",
7980
copy: bool = False,
81+
n_jobs: int = 1,
8082
) -> SpatialNeighborsResult | None:
8183
"""
8284
Create a graph from spatial coordinates.
@@ -131,6 +133,9 @@ def spatial_neighbors(
131133
key_added
132134
Key which controls where the results are saved if ``copy = False``.
133135
%(copy)s
136+
n_jobs
137+
Number of parallel jobs for computing per-library graphs. Only used when ``library_key`` is not ``None``.
138+
``1`` (default) disables parallelism. ``-1`` uses all available CPUs.
134139
135140
Returns
136141
-------
@@ -230,14 +235,25 @@ def spatial_neighbors(
230235
)
231236

232237
if library_key is not None:
233-
mats: list[tuple[spmatrix, spmatrix]] = []
238+
239+
def _compute_one(lib: Any) -> tuple[np.ndarray, spmatrix, spmatrix]:
240+
idx = np.where(adata.obs[library_key] == lib)[0]
241+
adj, dst = _build_fun(adata[adata.obs[library_key] == lib])
242+
return idx, adj, dst
243+
244+
results = Parallel(n_jobs=n_jobs)(delayed(_compute_one)(lib) for lib in libs)
245+
234246
ixs: list[int] = []
235-
for lib in libs:
236-
ixs.extend(np.where(adata.obs[library_key] == lib)[0])
237-
mats.append(_build_fun(adata[adata.obs[library_key] == lib]))
238-
ixs = cast(list[int], np.argsort(ixs).tolist())
239-
Adj = block_diag([m[0] for m in mats], format="csr")[ixs, :][:, ixs]
240-
Dst = block_diag([m[1] for m in mats], format="csr")[ixs, :][:, ixs]
247+
mats_adj: list[spmatrix] = []
248+
mats_dst: list[spmatrix] = []
249+
for idx, adj, dst in results:
250+
ixs.extend(idx)
251+
mats_adj.append(adj)
252+
mats_dst.append(dst)
253+
254+
ixs_order = cast(list[int], np.argsort(ixs).tolist())
255+
Adj = block_diag(mats_adj, format="csr")[ixs_order, :][:, ixs_order]
256+
Dst = block_diag(mats_dst, format="csr")[ixs_order, :][:, ixs_order]
241257
else:
242258
Adj, Dst = _build_fun(adata)
243259

0 commit comments

Comments (0)