Skip to content

Commit 538897a

Browse files
Reorganize linear algebra (#2037)
* Move matrix decomposition Ops into a sub-module * Move solve Ops into a sub-module * Move misc solve Ops * Move inverse Ops to sub-module * Move matrix product Ops to sub-module * Move matrix construction Ops to sub-module * Move matrix summarization Ops to sub-module * Re-import all public functions via `linalg.py` * Issue depreciation warning when importing from nlinalg/slinalg * Update rewrite imports to use _linalg locations * Update link dispatch imports to use _linalg locations * Update numba dispatch imports to use _linalg locations * Organize linalg rewrites according to library * Organize linalg dispatches according to library * Organize numba dispatches according to library * Organize linalg tests according to library * Organize linalg rewrite tests according to library * Rename tests/tensor/test_linalg -> tests/tensor/linalg * Reorganize dispatch test according to library * Clean up stale linalg imports * Update docs * Rename tests/link/*/test_linalg -> tests/link/*/linalg * Remove stray empty test_linalg/__init__.py under numba linalg * Update numba link CI part to new linalg/ directory * Move solve rewrites under tensor/rewriting/linalg The scan-tracking solve rewrites previously lived in pytensor/tensor/_linalg/solve/rewriting.py and were registered via a side-effect import chain from tensor/_linalg/__init__.py. That chain was broken in the linalg reorg (the _linalg package init was emptied), and registration only kept working incidentally because tensor/linalg.py happens to import from _linalg.solve. Move them into tensor/rewriting/linalg/solve.py alongside the other solve rewrites. The imports also needed updating: the v3 file imported via tensor.slinalg, which is now the deprecation shim and would emit DeprecationWarnings on every load. Also drop the speculative __all__ and inv_to_solve re-export from tensor/rewriting/linalg/__init__.py; the one test that imported inv_to_solve from there now imports it from .inverse directly. * Stop importing from deprecated nlinalg/slinalg in unrelated tests These three test files happened to import individual Ops/functions from pytensor.tensor.{nlinalg,slinalg}. Since those modules are now deprecation shims, every test run was emitting DeprecationWarnings for imports that have nothing to do with testing the shim itself. Retarget the imports to the new pytensor.tensor._linalg.* locations, matching the convention used by the new tests/tensor/linalg/ tests. * Stop importing from deprecated nlinalg/slinalg in legacy test suites tests/tensor/test_nlinalg.py and tests/tensor/test_slinalg.py were left untouched by the linalg reorg and still imported individual symbols from pytensor.tensor.{nlinalg,slinalg}. With those modules now being deprecation shims, every collection of these files emitted dozens of DeprecationWarnings. These files are not testing the deprecation shim itself — they're the original Op test suites. Retarget the imports to the new pytensor.tensor._linalg.* locations. * Add targeted test for nlinalg/slinalg deprecation shim Asserts that accessing a moved name on the deprecated module emits a DeprecationWarning, that the shim forwards to the same object exposed by the new pytensor.tensor.linalg public API, and that unknown names still raise AttributeError. Also fixes a stale doc comment in _linalg/summary.py that pointed at the old nlinalg path. * Drop automodule from deprecated nlinalg/slinalg doc pages The deprecation pages still carried `.. automodule:: pytensor.tensor.{nlinalg,slinalg} :members:` directives. Each automodule walks the module's `__dir__`, which on the new shims returns every moved name and triggers `__getattr__` per name — emitting a DeprecationWarning per attribute on every doc build, and producing duplicate doc entries that already exist on the new :ref:`libdoc_linalg` page. Drop the automodule blocks; the pages are now pure deprecation notices that point at `pytensor.tensor.linalg`. * Group linalg API docs by category and pin them to __all__ The linalg.rst page used `.. automodule:: pytensor.tensor.linalg :members:`, which renders all 36 functions from `__all__` as one unstructured alphabetical list. Replace it with hand-grouped sections (Constructors / Decomposition / Inverse / Products / Solve / Summary) that mirror the internal `_linalg/` package layout, using explicit `autofunction::` directives. To stop the rst from drifting away from the public API, add `tests/tensor/linalg/test_doc_api.py`. It parses every `.. autofunction:: pytensor.tensor.linalg.<name>` directive out of the rst and asserts the resulting set is exactly equal to `pytensor.tensor.linalg.__all__`. Any future addition or removal on either side fails the test with a precise diff. Op classes remain undocumented in the rst — the public surface is intentionally functions only. * Tweak linalg directory structure --------- Co-authored-by: Ricardo Vieira <ricardo.vieira1994@gmail.com>
1 parent 6f08a6e commit 538897a

File tree

154 files changed

+12013
-9062
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

154 files changed

+12013
-9062
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,11 @@ jobs:
101101
default-mode: "CVM"
102102
python-version: "3.12"
103103
os: "ubuntu-latest"
104-
- part: ["numba link", "tests/link/numba --ignore=tests/link/numba/test_slinalg.py"]
104+
- part: ["numba link", "tests/link/numba --ignore=tests/link/numba/linalg"]
105105
default-mode: "CVM"
106106
python-version: "3.12"
107107
os: "ubuntu-latest"
108-
- part: ["numba link slinalg", "tests/link/numba/test_slinalg.py"]
108+
- part: ["numba link linalg", "tests/link/numba/linalg"]
109109
default-mode: "CVM"
110110
python-version: "3.13"
111111
os: "ubuntu-latest"

doc/library/tensor/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ symbolic expressions using calls that look just like numpy calls, such as
2424
elemwise
2525
extra_ops
2626
io
27+
linalg
2728
slinalg
2829
nlinalg
2930
fft

doc/library/tensor/linalg.rst

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
.. _libdoc_linalg:
2+
3+
===================================================================
4+
:mod:`tensor.linalg` -- Linear Algebra Operations
5+
===================================================================
6+
7+
.. module:: tensor.linalg
8+
:platform: Unix, Windows
9+
:synopsis: Linear Algebra Operations
10+
11+
The :mod:`pytensor.tensor.linalg` module exposes the user-facing linear
12+
algebra API.
13+
14+
Constructors
15+
============
16+
17+
.. autofunction:: pytensor.tensor.linalg.block_diag
18+
19+
Decomposition
20+
=============
21+
22+
.. autofunction:: pytensor.tensor.linalg.cholesky
23+
.. autofunction:: pytensor.tensor.linalg.lu
24+
.. autofunction:: pytensor.tensor.linalg.lu_factor
25+
.. autofunction:: pytensor.tensor.linalg.pivot_to_permutation
26+
.. autofunction:: pytensor.tensor.linalg.qr
27+
.. autofunction:: pytensor.tensor.linalg.svd
28+
.. autofunction:: pytensor.tensor.linalg.eig
29+
.. autofunction:: pytensor.tensor.linalg.eigh
30+
.. autofunction:: pytensor.tensor.linalg.eigvalsh
31+
.. autofunction:: pytensor.tensor.linalg.schur
32+
.. autofunction:: pytensor.tensor.linalg.qz
33+
.. autofunction:: pytensor.tensor.linalg.ordqz
34+
35+
Inverse
36+
=======
37+
38+
.. autofunction:: pytensor.tensor.linalg.inv
39+
.. autofunction:: pytensor.tensor.linalg.pinv
40+
.. autofunction:: pytensor.tensor.linalg.tensorinv
41+
42+
Products
43+
========
44+
45+
.. autofunction:: pytensor.tensor.linalg.kron
46+
.. autofunction:: pytensor.tensor.linalg.matrix_dot
47+
.. autofunction:: pytensor.tensor.linalg.matrix_power
48+
.. autofunction:: pytensor.tensor.linalg.expm
49+
50+
Solve
51+
=====
52+
53+
.. autofunction:: pytensor.tensor.linalg.solve
54+
.. autofunction:: pytensor.tensor.linalg.solve_triangular
55+
.. autofunction:: pytensor.tensor.linalg.lu_solve
56+
.. autofunction:: pytensor.tensor.linalg.cho_solve
57+
.. autofunction:: pytensor.tensor.linalg.lstsq
58+
.. autofunction:: pytensor.tensor.linalg.tensorsolve
59+
.. autofunction:: pytensor.tensor.linalg.tridiagonal_lu_factor
60+
.. autofunction:: pytensor.tensor.linalg.tridiagonal_lu_solve
61+
.. autofunction:: pytensor.tensor.linalg.solve_continuous_lyapunov
62+
.. autofunction:: pytensor.tensor.linalg.solve_discrete_lyapunov
63+
.. autofunction:: pytensor.tensor.linalg.solve_discrete_are
64+
.. autofunction:: pytensor.tensor.linalg.solve_sylvester
65+
66+
Summary
67+
=======
68+
69+
.. autofunction:: pytensor.tensor.linalg.det
70+
.. autofunction:: pytensor.tensor.linalg.slogdet
71+
.. autofunction:: pytensor.tensor.linalg.norm
72+
.. autofunction:: pytensor.tensor.linalg.trace

doc/library/tensor/nlinalg.rst

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,16 @@
1-
.. ../../../../pytensor/sandbox/nlinalg.py
2-
3-
.. _libdoc_linalg:
1+
.. _libdoc_nlinalg:
42

53
===================================================================
6-
:mod:`tensor.nlinalg` -- Linear Algebra Ops Using Numpy
4+
:mod:`tensor.nlinalg` -- Linear Algebra Ops Using Numpy (deprecated)
75
===================================================================
86

97
.. module:: tensor.nlinalg
108
:platform: Unix, Windows
11-
:synopsis: Linear Algebra Ops Using Numpy
9+
:synopsis: Linear Algebra Ops Using Numpy (deprecated)
1210
.. moduleauthor:: LISA
1311

14-
.. note::
15-
16-
This module is not imported by default. You need to import it to use it.
17-
18-
API
19-
===
20-
21-
.. automodule:: pytensor.tensor.nlinalg
22-
:members:
12+
.. deprecated:: 2.x
13+
The ``nlinalg`` module is deprecated. Use :mod:`pytensor.tensor.linalg` instead.
14+
All symbols previously exported from ``nlinalg`` are available from
15+
``pytensor.tensor.linalg`` — see :ref:`libdoc_linalg`. Imports from
16+
``nlinalg`` will be removed in PyTensor 3.0.

doc/library/tensor/slinalg.rst

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,16 @@
1-
.. ../../../../pytensor/sandbox/slinalg.py
2-
31
.. _libdoc_slinalg:
42

53
===================================================================
6-
:mod:`tensor.slinalg` -- Linear Algebra Ops Using Scipy
4+
:mod:`tensor.slinalg` -- Linear Algebra Ops Using Scipy (deprecated)
75
===================================================================
86

97
.. module:: tensor.slinalg
108
:platform: Unix, Windows
11-
:synopsis: Linear Algebra Ops Using Scipy
9+
:synopsis: Linear Algebra Ops Using Scipy (deprecated)
1210
.. moduleauthor:: LISA
1311

14-
.. note::
15-
16-
This module is not imported by default. You need to import it to use it.
17-
18-
API
19-
===
20-
21-
.. automodule:: pytensor.tensor.slinalg
22-
:members:
23-
12+
.. deprecated:: 2.x
13+
The ``slinalg`` module is deprecated. Use :mod:`pytensor.tensor.linalg` instead.
14+
All symbols previously exported from ``slinalg`` are available from
15+
``pytensor.tensor.linalg`` — see :ref:`libdoc_linalg`. Imports from
16+
``slinalg`` will be removed in PyTensor 3.0.

pytensor/graph/rewriting/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1495,7 +1495,7 @@ def output_fn(fgraph, node, s):
14951495
from pytensor.graph.rewriting.basic import PatternNodeRewriter
14961496
from pytensor.graph.rewriting.unify import OpPattern, LiteralString
14971497
from pytensor.tensor.blockwise import Blockwise
1498-
from pytensor.tensor.slinalg import Solve
1498+
from pytensor.tensor.linalg.solvers.general import Solve
14991499
15001500
PatternNodeRewriter(
15011501
(

pytensor/graph/rewriting/unify.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ class OpPattern:
287287
from pytensor.graph.rewriting.unify import OpPattern
288288
from pytensor.tensor.elemwise import CAReduce
289289
from pytensor.tensor.blockwise import Blockwise
290-
from pytensor.tensor.slinalg import Solve
290+
from pytensor.tensor.linalg.solvers.general import Solve
291291
292292
@node_rewriter(tracks=[OpPattern(CAReduce, axis=None)])
293293
def local_car_reduce_all_rewriter(fgraph, node):
@@ -352,7 +352,7 @@ def output_fn(fgraph, node, s):
352352
import pytensor.tensor as pt
353353
from pytensor.graph.rewriting.unify import OpPattern
354354
from pytensor.tensor.blockwise import Blockwise
355-
from pytensor.tensor.slinalg import Solve
355+
from pytensor.tensor.linalg.solvers.general import Solve
356356
357357
A = var("A")
358358
b = var("b")

pytensor/link/jax/dispatch/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@
99
import pytensor.link.jax.dispatch.extra_ops
1010
import pytensor.link.jax.dispatch.pad
1111
import pytensor.link.jax.dispatch.math
12-
import pytensor.link.jax.dispatch.nlinalg
12+
import pytensor.link.jax.dispatch.linalg
1313
import pytensor.link.jax.dispatch.random
1414
import pytensor.link.jax.dispatch.scalar
1515
import pytensor.link.jax.dispatch.scan
1616
import pytensor.link.jax.dispatch.shape
1717
import pytensor.link.jax.dispatch.signal
18-
import pytensor.link.jax.dispatch.slinalg
1918
import pytensor.link.jax.dispatch.sort
2019
import pytensor.link.jax.dispatch.sparse
2120
import pytensor.link.jax.dispatch.subtensor
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from pytensor.link.jax.dispatch.linalg import (
2+
constructors,
3+
decomposition,
4+
inverse,
5+
products,
6+
solvers,
7+
summary,
8+
)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import jax
2+
3+
from pytensor.link.jax.dispatch.basic import jax_funcify
4+
from pytensor.tensor.linalg.constructors import BlockDiagonal
5+
6+
7+
@jax_funcify.register(BlockDiagonal)
8+
def jax_funcify_BlockDiagonalMatrix(op, **kwargs):
9+
def block_diag(*inputs):
10+
return jax.scipy.linalg.block_diag(*inputs)
11+
12+
return block_diag

0 commit comments

Comments
 (0)