|
14 | 14 | from anndata import AnnData |
15 | 15 | from anndata.utils import make_index_unique |
16 | 16 | from fast_array_utils import stats as fau_stats |
| 17 | +from joblib import Parallel, delayed |
17 | 18 | from numba import njit, prange |
18 | 19 | from scipy.sparse import ( |
19 | 20 | SparseEfficiencyWarning, |
@@ -77,6 +78,7 @@ def spatial_neighbors( |
77 | 78 | set_diag: bool = False, |
78 | 79 | key_added: str = "spatial", |
79 | 80 | copy: bool = False, |
| 81 | + n_jobs: int = 1, |
80 | 82 | ) -> SpatialNeighborsResult | None: |
81 | 83 | """ |
82 | 84 | Create a graph from spatial coordinates. |
@@ -131,6 +133,9 @@ def spatial_neighbors( |
131 | 133 | key_added |
132 | 134 | Key which controls where the results are saved if ``copy = False``. |
133 | 135 | %(copy)s |
| 136 | + n_jobs |
| 137 | + Number of parallel jobs for computing per-library graphs. Only used when ``library_key`` is not ``None``. |
| 138 | + ``1`` (default) disables parallelism. ``-1`` uses all available CPUs. |
134 | 139 |
|
135 | 140 | Returns |
136 | 141 | ------- |
@@ -230,14 +235,25 @@ def spatial_neighbors( |
230 | 235 | ) |
231 | 236 |
|
232 | 237 | if library_key is not None: |
233 | | - mats: list[tuple[spmatrix, spmatrix]] = [] |
| 238 | + |
| 239 | + def _compute_one(lib: Any) -> tuple[np.ndarray, spmatrix, spmatrix]: |
| 240 | + idx = np.where(adata.obs[library_key] == lib)[0] |
| 241 | + adj, dst = _build_fun(adata[adata.obs[library_key] == lib]) |
| 242 | + return idx, adj, dst |
| 243 | + |
| 244 | + results = Parallel(n_jobs=n_jobs)(delayed(_compute_one)(lib) for lib in libs) |
| 245 | + |
234 | 246 | ixs: list[int] = [] |
235 | | - for lib in libs: |
236 | | - ixs.extend(np.where(adata.obs[library_key] == lib)[0]) |
237 | | - mats.append(_build_fun(adata[adata.obs[library_key] == lib])) |
238 | | - ixs = cast(list[int], np.argsort(ixs).tolist()) |
239 | | - Adj = block_diag([m[0] for m in mats], format="csr")[ixs, :][:, ixs] |
240 | | - Dst = block_diag([m[1] for m in mats], format="csr")[ixs, :][:, ixs] |
| 247 | + mats_adj: list[spmatrix] = [] |
| 248 | + mats_dst: list[spmatrix] = [] |
| 249 | + for idx, adj, dst in results: |
| 250 | + ixs.extend(idx) |
| 251 | + mats_adj.append(adj) |
| 252 | + mats_dst.append(dst) |
| 253 | + |
| 254 | + ixs_order = cast(list[int], np.argsort(ixs).tolist()) |
| 255 | + Adj = block_diag(mats_adj, format="csr")[ixs_order, :][:, ixs_order] |
| 256 | + Dst = block_diag(mats_dst, format="csr")[ixs_order, :][:, ixs_order] |
241 | 257 | else: |
242 | 258 | Adj, Dst = _build_fun(adata) |
243 | 259 |
|
|
0 commit comments