
System for algebraic reasoning about linear algebra #2032

Open
jessegrabowski wants to merge 11 commits into pymc-devs:v3 from jessegrabowski:assumption-system

Conversation

@jessegrabowski
Member

@jessegrabowski jessegrabowski commented Apr 7, 2026

This PR is a proposal for a typing system for linear algebra primitives. The purpose is to enable graph-wide reasoning about the kinds of matrices, so that we can rewrite to efficient computational forms.

Current State

We currently have several linear algebra rewrites and plan to add more. These are tracked in #573. This is important because linear algebra is 1) ubiquitous, 2) expensive, and 3) inscrutable. Pytensor's static graph representation and rewrite system are well positioned to help users write the best possible programs involving heavy linear algebra, if only we can figure out what is going on.

Consider the motivating case of solve(A, b), when A = pt.diag(pt.arange(100_000)). Solve is an O(n^3) operation that calls out to specialized routines. But there is no need for that here: since A is diagonal, we can write this as the elementwise division b / pt.extract_diag(A).
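The equivalence is easy to sanity-check numerically. The snippet below is plain NumPy, independent of anything in this PR, and just confirms that for a diagonal A the general solver and the elementwise division agree:

```python
import numpy as np

# For diagonal A, solve(A, b) reduces to elementwise division by the diagonal.
d = np.arange(1.0, 6.0)            # nonzero diagonal entries
A = np.diag(d)
b = np.array([10.0, 20.0, 30.0, 40.0, 50.0])

dense = np.linalg.solve(A, b)      # O(n^3) general-purpose routine
cheap = b / d                      # O(n) elementwise division

assert np.allclose(dense, cheap)
```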

How do we decide to do this? We:

  • Track the Solve Op
  • Check if either input "seems diagonal"
  • If so, do the rewrite

What seems diagonal? We assume an input is diagonal if it was created by pt.eye or pt.diag. Users cannot themselves specify whether input data is diagonal. If an Op gets in between the known "diagonalish" Op and the Solve, we cannot detect diagonality. For example, we cannot rewrite solve(A * 3, b), because now the first input is Elemwise(Mul)(A, 3). Multiplication is diagonal-preserving (because it is zero-preserving), but since the known diagonal Op is now buried inside the Elemwise, we're out of luck.

Proposal

My proposal is to reason about algebraic properties of matrices the same way we reason about shapes. For shapes, we attach a ShapeFeature to FunctionGraphs, and each Op has an infer_shape method that explains how the static shape propagates. Likewise, I propose an AssumptionFeature. Ops do not have infer_assumption methods; that would be too messy. Instead, we have a central ASSUMPTION_INFER_REGISTRY with keys (Op, AssumptionKey) and values InferFactFn.

  • An AssumptionKey is just a marker class corresponding to an algebraic fact about a matrix, like DIAGONAL, LOWER_TRIANGULAR, ORTHOGONAL, POSITIVE, SEMIDEFINITE, and so on.
  • An InferFactFn has the following signature:
def infer_diagonal(op: Op, assumption_feature: AssumptionFeature, fgraph: FunctionGraph, node: Apply, input_facts: list[FactState]) -> list[FactState]:

Like other symbolic operations, an InferFactFn takes an Op (plus global information about the graph it lives in, fgraph and assumption_feature) and information about its inputs (the list of FactStates), and returns a list of information about its outputs.

A FactState is a three-valued logic for assumption inference. The possible values are UNKNOWN, TRUE, or FALSE. A fourth state, CONFLICT, exists, but should never arise.
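As a sketch of those semantics (hypothetical code, not necessarily the enum exactly as implemented in this PR), FactState can be modeled as a small lattice where UNKNOWN is the identity of the join operation and contradictory definite states collapse into CONFLICT:

```python
from enum import Enum

class FactState(Enum):
    """Hypothetical sketch of the three-valued logic described above."""
    UNKNOWN = 0
    TRUE = 1
    FALSE = 2
    CONFLICT = 3  # join(TRUE, FALSE); should never arise in practice

    @staticmethod
    def join(a, b):
        # UNKNOWN contributes no information; equal states are idempotent;
        # contradictory definite states collapse to CONFLICT.
        if a is b:
            return a
        if a is FactState.UNKNOWN:
            return b
        if b is FactState.UNKNOWN:
            return a
        return FactState.CONFLICT

assert FactState.join(FactState.UNKNOWN, FactState.TRUE) is FactState.TRUE
assert FactState.join(FactState.TRUE, FactState.FALSE) is FactState.CONFLICT
```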

All facts about all Ops are assumed to be UNKNOWN unless we can prove otherwise. Proof comes from each Op's registered InferFactFns. The AssumptionFeature is responsible for gathering all the rules of fact propagation. A simple example of a fact is that every Eye Op is DIAGONAL, provided it is 1) square and 2) has an offset of zero:

def true_if(cond: bool) -> list[FactState]:
    """``[TRUE]`` when *cond* holds, ``[UNKNOWN]`` otherwise."""
    return [FactState.TRUE] if cond else [FactState.UNKNOWN]

def eye_is_identity(node) -> bool:
    """True when an :class:`Eye` node produces the identity matrix (square, k == 0)."""
    n, m, k = node.inputs
    if not (isinstance(k, Constant) and k.data.item() == 0):
        return False
    if n is m:
        return True
    if isinstance(n, Constant) and isinstance(m, Constant):
        return n.data.item() == m.data.item()
    return False

@register_assumption(DIAGONAL, Eye)
def _eye(op, feature, fgraph, node, input_states):
    return true_if(eye_is_identity(node))

In a program, we can use AssumptionFeature.get(x, AssumptionKey) to query the state of a Variable. Here, we ask "is x diagonal?". Obviously it is:

import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.assumptions import AssumptionFeature, DIAGONAL
x = pt.eye(5)
fg = FunctionGraph([], [x])
af = AssumptionFeature()
fg.attach_feature(af)

af.get(x, DIAGONAL) # <FactState.TRUE: 1>

Where this becomes powerful is in accumulating InferFactFns. Many Ops preserve diagonality, like Cholesky or Inverse. We can use information about the inputs to these Ops to propagate fact information through the graph:

@register_assumption(DIAGONAL, Cholesky)
def _cholesky(op, feature, fgraph, node, input_states):
    return true_if(input_states[0])

@register_assumption(DIAGONAL, MatrixInverse)
def _inv(op, feature, fgraph, node, input_states):
    return true_if(input_states[0])

Now we don't lose information about x in deeper graphs, and are free to do more rewrites:

x = pt.eye(5)
y = pt.linalg.cholesky(x)
z = pt.linalg.inv(y)

fg = FunctionGraph([], [z])
af = AssumptionFeature()
fg.attach_feature(af)

af.get(y, DIAGONAL) # <FactState.TRUE: 1>
af.get(z, DIAGONAL) # <FactState.TRUE: 1>

Of course, we can also reason conditionally. An IncSubtensor might be diagonal-preserving if we can prove that we're setting a value on the diagonal of the matrix. Otherwise, we fall back to UNKNOWN:

from pytensor.tensor.subtensor import IncSubtensor

x = pt.eye(5)
i, j = pt.iscalars('i', 'j')
y = x[i, j].inc(3) # Cannot prove i == j at runtime, so UNKNOWN
z = x[i, i].inc(3) # i == i provable, so this is diagonal-preserving
fg = FunctionGraph([i, j], [y, z])
af = AssumptionFeature()
fg.attach_feature(af)

af.get(y, DIAGONAL) #<FactState.UNKNOWN: 0>
af.get(z, DIAGONAL) #<FactState.TRUE: 1>

Facts can also imply other facts: DIAGONAL matrices are also SYMMETRIC. These general relationships can be registered and encoded as well. Continuing the example above:

from pytensor.tensor.assumptions import SYMMETRIC
af.get(z, SYMMETRIC) # <FactState.TRUE: 1>
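One plausible way to encode such implications (a hypothetical sketch; the table and helper names here are invented for illustration, not taken from the PR) is a simple mapping from each key to the keys it entails, closed transitively at query time:

```python
# Hypothetical implication table: a key entails every key it maps to.
IMPLICATIONS = {
    "DIAGONAL": ["SYMMETRIC", "LOWER_TRIANGULAR"],
    "LOWER_TRIANGULAR": [],
    "SYMMETRIC": [],
}

def implied_facts(key):
    """Return *key* plus everything it entails (transitive closure)."""
    seen, stack = set(), [key]
    while stack:
        k = stack.pop()
        if k not in seen:
            seen.add(k)
            stack.extend(IMPLICATIONS.get(k, []))
    return seen

assert implied_facts("DIAGONAL") == {"DIAGONAL", "SYMMETRIC", "LOWER_TRIANGULAR"}
```

With such a closure, proving DIAGONAL for a Variable would automatically answer SYMMETRIC queries as well.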

Finally, users can specify facts about matrices using pt.specify_assumptions, the same way they are able to specify shapes.

x = pt.specify_assumptions(pt.dmatrix('x'), diagonal=True)
fg = FunctionGraph(outputs=[x])
af = AssumptionFeature()
fg.attach_feature(af)

af.get(x, DIAGONAL) # <FactState.TRUE: 1>

Benefits for rewriting:

  • FactStates are trivial to check. We can check any fact about any Variable in 5 lines:
def _is_diagonal(var, fgraph):
    """Check if *var* is diagonal using the AssumptionFeature."""

    af = getattr(fgraph, "assumption_feature", None)
    if af is None:
        af = AssumptionFeature()
        fgraph.attach_feature(af)

    return af.check(var, DIAGONAL)
  • We can reason globally about the graph
    As noted above, information flows through the graph. As long as we have good coverage of fact rules for Ops, we can make statements about Variables at all levels of computation.

  • InferFactFunctions are lightweight and easy to write
    It is trivial to add new InferFactFunctions. LLMs can bang them out. They require only local reasoning about the Op and its immediate inputs.

  • FactStates allow non-trivial combinations of rewrites
    One example I hit while working on this was a rewrite for DirectSolveLyapunov given diagonal inputs. Because there is a rule that kron preserves diagonality if both inputs are diagonal, the chain of rewrites from solve_discrete_lyapunov(diag(a), Q) -> Q.ravel() / (1 - outer(a, a).ravel()) is discovered by the rewrite system via:

  • rewrite_kron_diag_to_diag_outer: kron(diag(a), diag(b)) → diag(outer(a, b).ravel())

  • rewrite_solve_diag_to_division: solve(diag(x), b) → b / x

The key is that after the first rewrite produces a diagonal matrix (via alloc_diag), the assumption system recognizes it as diagonal (via AllocDiag being registered with DIAGONAL), and then the second rewrite kicks in.
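Both identities behind that rewrite chain are easy to verify numerically. The snippet below is a plain-NumPy sanity check, not the PR's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.uniform(0.1, 0.9, size=3)   # |a_i| < 1, so the Lyapunov solution exists
b = rng.uniform(0.1, 0.9, size=3)
Q = rng.normal(size=(3, 3))

# Identity 1: kron(diag(a), diag(b)) == diag(outer(a, b).ravel())
assert np.allclose(np.kron(np.diag(a), np.diag(b)),
                   np.diag(np.outer(a, b).ravel()))

# Identity 2: for A = diag(a), the solution X of A @ X @ A.T - X + Q = 0
# is X_ij = Q_ij / (1 - a_i * a_j), i.e. an elementwise division.
X = Q / (1 - np.outer(a, a))
residual = np.diag(a) @ X @ np.diag(a).T - X + Q
assert np.allclose(residual, 0.0)
```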

Non-Goals

The purpose of this system is not to introduce a complete, closed algebra over all types. That is impossible. The goal is also not to complete the project of German romanticism. That is also impossible.

  • There should never be any "global state" (no Wolfram Assuming). Assumptions live on FunctionGraphs. There is never a need to deal with global context, logical combinations, or relational assumptions.
  • We are not trying to be smart or figure out as much as possible. On the contrary, we want to be very dumb. The Assumptions should be maximally conservative, and fall back to UNKNOWN whenever there is runtime ambiguity.
  • Assumption logic lives separately from all other machinery. No other part of Pytensor needs to know anything about them. Evaluation logic does not check assumptions. Perform methods don't dispatch on the FactState of the Node. We check during rewrites and that's it.
  • We do not aspire to provide Maple-style conditional dispatches. If we don't know, we just don't know. There is no LinearAlgebra[IsDefinite]. No symbolic conditionals as part of the core API.
  • We are not, and cannot be, a theorem prover, and we do not present assumptions to the user as a tool for one. Assumptions are primarily an internal API, and rewrite-focused.

@ricardoV94
Member

ricardoV94 commented Apr 7, 2026

I like the top message (didn't look at the code). Two notes:

1

Missing some discussion on how to preserve information across rewrites. ShapeFeature combines shape information when you replace a -> b and you knew something about a and something about b (or just one of them).

This is especially relevant for constant folding (e.g., eye(int(1e6))), where it's much cheaper/more obvious before than after.

But then, when do you decide to ask/start checking assumptions? Because this can also be useless work (e.g., a graph without any linalg stuff).

Ordering is hard

2

Why do you assume that you can do all the reasoning you need to from the op and its inputs? It's a small note, and I don't think a requirement/restriction in your proposal.

But e.g. checking for tridiagonal matrix creation (which we do) requires checking 2/3 nested set_subtensor nodes.

@maresb
Contributor

maresb commented Apr 7, 2026

Very disappointed that German romanticism is out of scope

But seriously, this looks amazing. I'm not really capable of reviewing this but I am very excited for this one.

try:
    val = get_underlying_scalar_constant_value(node.inputs[0])
    if val == 0:
        return [FactState.TRUE]

Do we want FactState.FALSE when it is not zero but still a known constant?

from pytensor.tensor.basic import as_tensor_variable


class SpecifyAssumptions(Op):

Subclass TypeCastOp, which does most of the things you need dispatch/runtime-wise

return output_grads


def specify_assumptions(

Shorter name? pytensor.assume?

(Also this need not be in the tensor module; it seems more generic than that)

return False


def _same_variable(a, b) -> bool:

I feel this is doing too much; a is b should be enough after MergeOptimization runs. I feel we have too many methods eagerly trying to assert equality in their own way.

if isinstance(op, AdvancedIncSubtensor):
    # inputs: (x, y, *index_arrays)
    index_arrays = node.inputs[2:]
    if len(index_arrays) >= 2:

This is wrong. There can be many indexing variables and they can have different meanings. Maybe the first is a slice start point and the second an integer index (along a completely different dimension).


diag_a = diagonal(A, axis1=-2, axis2=-1)
diag_b = diagonal(B, axis1=-2, axis2=-1)
kron_diag = outer(diag_a, diag_b).ravel()

sounds wrong wrt batch dims? Also it's join_dims time baby

if inputs is None:
    from pytensor.graph.traversal import graph_inputs

    inputs = graph_inputs(outputs)

Not needed; if you build FunctionGraph(outputs=x) and leave inputs as None, it finds them for itself.

return true_if(eye_is_identity(node))


@register_assumption(ORTHOGONAL, MatrixInverse)

Use BlockwiseOf?

Member

@lucianopaz lucianopaz left a comment


I really like the goal of the PR. I couldn't finish reviewing but I'll just leave the comments I've written down so far. I'll try to come back to this later and write something more decent.

        f"infer_assumption returned {len(output_states)} states for "
        f"{len(node.outputs)} outputs on node {node!r}"
    )
    return [FactState(s) for s in output_states]

Suggested change
return [FactState(s) for s in output_states]
return output_states

Isn't it guaranteed that the output_states will be a list of FactState objects?

) -> list[FactState]:
"""Determine the *key* fact for every output of *node*.

Resolution order:

This is the opposite resolution order than the one used in blockwise. Which should be preferred?

state = FactState(state)
cache_key = (var, key)
old = self.user_facts.get(cache_key, FactState.UNKNOWN)
new = FactState.join(old, state)

If old is CONFLICT, won't join also return a conflict? It won't overwrite the old fact. Is this intended?

"""Return ``True`` iff the assumption is definitively TRUE for ``var``."""
return bool(self.get(var, key))

def set_user_fact(self, var: Any, key: AssumptionKey, state: FactState) -> None:

I found the name confusing: I thought it would set the state without joining it with the existing state. I only noticed the docstring after I looked at the following method, replace_user_fact.

@register_assumption(DIAGONAL, DimShuffle)
def _dimshuffle(op, feature, fgraph, node, input_states):
    if not input_states[0]:
        return [FactState.UNKNOWN]

Why not return input_states[0] directly? I'm having a hard time understanding the need for FactState.FALSE. It looks like everything is either true or unknown.

@register_assumption(ORTHOGONAL, Eye)
def _eye(op, feature, fgraph, node, input_states):
    return true_if(eye_is_identity(node))


Maybe also add something about the transpose? Or a dimshuffle of the last two dimensions?


@register_assumption(DIAGONAL, Dot)
def _dot(op, feature, fgraph, node, input_states):
    return true_if(all(input_states))

I just came back to this after having looked at the orthogonal assumptions module. How would you handle the case where Q @ Q.T = eye? In other words, if the two inputs are orthogonal and one is the transpose of the other, their product is an identity matrix. Would you be able to get the assumption from a different set (orthogonality) while working through diagonality?

Member Author


Good question. The system as it exists doesn't have good tooling for handling cross-facts like that. I need to pause and ponder.

