System for algebraic reasoning about linear algebra #2032
jessegrabowski wants to merge 11 commits into pymc-devs:v3 from
Conversation
cece59a to 9f6e7f5 (Compare)
I like the top message (didn't look at the code). Two notes:

1. Missing some discussion on how to preserve information across rewrites. ShapeFeature combines shape information when you replace a -> b and you knew something of a and something of b (or just one of them). This is especially relevant for constant folding (e.g., …). But then, when do you decide to ask/start checking assumptions? Because this can also be useless work (e.g. a graph without any linalg stuff). Ordering is hard.
2. Why do you assume that you can do all the reasoning you need to from op and inputs? It's a small note and I don't think a requirement/restriction in your proposal. But e.g. checking for tridiagonal matrix creation (which we do) requires checking 2/3 nested set_subtensor nodes.
Very disappointed that German romanticism is out of scope. But seriously, this looks amazing. I'm not really capable of reviewing this but I am very excited for this one.
```python
try:
    val = get_underlying_scalar_constant_value(node.inputs[0])
    if val == 0:
        return [FactState.TRUE]
```
Do we want FactState.FALSE when it is not zero but still a known constant?
```python
from pytensor.tensor.basic import as_tensor_variable


class SpecifyAssumptions(Op):
```
Subclass TypeCastOp, which does most of the things you need dispatch/runtime-wise
```python
    return output_grads


def specify_assumptions(
```
Shorter name? pytensor.assume?
(Also, this need not be in the tensor module; it seems more generic than that.)
```python
    return False


def _same_variable(a, b) -> bool:
```
I feel this is doing too much; `a is b` should be enough after MergeOptimization runs. I feel we have too many methods eagerly trying to assert equality in their own way.
```python
if isinstance(op, AdvancedIncSubtensor):
    # inputs: (x, y, *index_arrays)
    index_arrays = node.inputs[2:]
    if len(index_arrays) >= 2:
```
This is wrong. There can be many indexing variables and they can have different meanings. Maybe the first is a slice start point and the second an integer index (along a completely different dimension).
|
|
```python
diag_a = diagonal(A, axis1=-2, axis2=-1)
diag_b = diagonal(B, axis1=-2, axis2=-1)
kron_diag = outer(diag_a, diag_b).ravel()
```
sounds wrong wrt batch dims? Also it's join_dims time baby
```python
if inputs is None:
    from pytensor.graph.traversal import graph_inputs

    inputs = graph_inputs(outputs)
```
not needed, if you build FunctionGraph(outputs=x) and leave inputs as None, it finds them for itself
```python
    return true_if(eye_is_identity(node))


@register_assumption(ORTHOGONAL, MatrixInverse)
```
lucianopaz left a comment
I really like the goal of the PR. I couldn't finish reviewing but I'll just leave the comments I've written down so far. I'll try to come back to this later and write something more decent.
```python
        f"infer_assumption returned {len(output_states)} states for "
        f"{len(node.outputs)} outputs on node {node!r}"
    )
    return [FactState(s) for s in output_states]
```

Suggested change:

```diff
-    return [FactState(s) for s in output_states]
+    return output_states
```
Isn't it guaranteed that the output_states will be a list of FactState objects?
```python
) -> list[FactState]:
    """Determine the *key* fact for every output of *node*.

    Resolution order:
```
This is the opposite resolution order than the one used in blockwise. Which should be preferred?
```python
state = FactState(state)
cache_key = (var, key)
old = self.user_facts.get(cache_key, FactState.UNKNOWN)
new = FactState.join(old, state)
```
If old is CONFLICT, won't join also return a conflict? It won't overwrite the old fact. Is this intended?
```python
        """Return ``True`` iff the assumption is definitively TRUE for ``var``."""
        return bool(self.get(var, key))

    def set_user_fact(self, var: Any, key: AssumptionKey, state: FactState) -> None:
```
I found the name confusing. I thought that it would have set the state without joining it with the existing state. I noticed the docstring after I looked at the following method, replace_user_fact
```python
@register_assumption(DIAGONAL, DimShuffle)
def _dimshuffle(op, feature, fgraph, node, input_states):
    if not input_states[0]:
        return [FactState.UNKNOWN]
```
Why not return input_states[0] directly? I'm having a hard time understanding the need for FactState.FALSE. It looks like everything is either true or unknown.
```python
@register_assumption(ORTHOGONAL, Eye)
def _eye(op, feature, fgraph, node, input_states):
    return true_if(eye_is_identity(node))
```
Maybe also add something about the transpose? Or a dimshuffle of the last two dimensions?
|
|
```python
@register_assumption(DIAGONAL, Dot)
def _dot(op, feature, fgraph, node, input_states):
    return true_if(all(input_states))
```
I just came back to this after having looked at the orthogonal assumptions module. How would you handle the case where Q @ Q.T = eye? In other words, if the two inputs are orthogonal and one is the transpose of the other. Their product would produce an identity matrix. Would you be able to get the assumption from a different set (orthogonality) while working through diagonality?
Good question. The system as it exists doesn't have good tooling for handling cross-facts like that. Need to pause and ponder.
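For reference, the identity behind the question is easy to check numerically (a NumPy sketch; `Q` here comes from a QR decomposition, which is one way to get an orthogonal matrix):

```python
import numpy as np

rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(4, 4)))  # Q is orthogonal

# The product of an orthogonal matrix and its transpose is the identity,
# which is diagonal, even though neither factor is diagonal on its own.
assert np.allclose(Q @ Q.T, np.eye(4))
assert not np.allclose(Q, np.diag(np.diag(Q)))
```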
This PR is a proposal for a typing system for linear algebra primitives. The purpose is to enable graph-wide reasoning about the kinds of matrices, so that we can rewrite to efficient computational forms.
Current State
We currently have several linear algebra rewrites and plan to add more. These are tracked in #573. This is important because linear algebra is 1) ubiquitous, 2) expensive, and 3) inscrutable. PyTensor's static graph representation and rewrite system is well positioned to help users write the best possible programs involving heavy linear algebra, if only we can figure out what is going on.
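The payoff in the motivating case below rests on a simple algebraic identity, easy to check numerically (a NumPy sketch, independent of PyTensor):

```python
import numpy as np

d = np.arange(1.0, 6.0)                    # nonzero diagonal entries
b = np.array([2.0, 3.0, 5.0, 7.0, 11.0])
A = np.diag(d)

# O(n^3) general solve vs. O(n) elementwise division
assert np.allclose(np.linalg.solve(A, b), b / np.diag(A))
```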
Consider the motivating case of `solve(A, b)`, when `A = pt.diag(pt.arange(100_000))`. This is an O(n^3) operation that will call out to specialized routines. But there is no need for this. Since `A` is diagonal, we can write this as an elementwise division, `b / pt.extract_diag(A)`. How do we decide to do this? We look for `Solve` `Op`s whose first input seems diagonal.

What seems diagonal?

We assume an input is diagonal if it was created by `pt.eye` or `pt.diag`. Users cannot specify themselves whether input data is diagonal. If an `Op` gets in between the known "diagonalish" `Op` and the `Solve`, we cannot detect diagonality. For example, we cannot rewrite `solve(A * 3, b)`, because now the first input is `Elemwise(Mul)(A, 3)`. Multiplication is diagonal-preserving (because it is zero-preserving), but since the known diagonal op is now buried inside the `Elemwise`, we're out of luck.

Proposal
My proposal is to reason about algebraic properties of matrices the same way we reason about shapes. For shapes, we attach a `ShapeFeature` to `FunctionGraph`s. Each `Op` has an `infer_shape` method that explains how the static shape propagates. Likewise, I propose an `AssumptionFeature`. `Op`s do not have `infer_assumption` methods; that would be too messy. Instead, we have a central `ASSUMPTION_INFER_REGISTRY` with keys `(Op, AssumptionKey)` and values `InferFactFn`. An `AssumptionKey` is just a marker class corresponding to an algebraic fact about a matrix, like `DIAGONAL`, `LOWER_TRIANGULAR`, `ORTHOGONAL`, `POSITIVE`, `SEMIDEFINITE`, and so on. `InferFactFn` has the following signature:

Like other symbolic operations, an `InferFactFn` takes an `Op` (plus global information about the graph it lives in, `fgraph` and `assumption_feature`) and information about its inputs (the list of `FactState`s), and returns a list of information about its outputs.
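A self-contained toy sketch of the registry and an `InferFactFn` as described (all implementations here are stand-ins; only the names and the `(op, feature, fgraph, node, input_states)` calling convention come from the proposal):

```python
from enum import Enum


class FactState(Enum):
    """Three-valued logic for assumption inference (toy stand-in)."""
    UNKNOWN = "unknown"
    TRUE = "true"
    FALSE = "false"
    CONFLICT = "conflict"  # should never arise

    def __bool__(self):
        # Only a proven TRUE is truthy, so `if input_states[0]:`
        # reads as "if this fact is proven for the first input".
        return self is FactState.TRUE


DIAGONAL = "DIAGONAL"  # AssumptionKey marker (a plain sentinel in this toy)

# Central registry: (Op type, AssumptionKey) -> InferFactFn
ASSUMPTION_INFER_REGISTRY = {}


def register_assumption(key, op_type):
    def decorator(fn):
        ASSUMPTION_INFER_REGISTRY[(op_type, key)] = fn
        return fn
    return decorator


class Eye:        # stand-in Ops
    pass


class Cholesky:
    pass


# InferFactFns: (op, feature, fgraph, node, input_states) -> list[FactState]
@register_assumption(DIAGONAL, Eye)
def _eye(op, feature, fgraph, node, input_states):
    return [FactState.TRUE]  # the real rule also checks square / zero offset


@register_assumption(DIAGONAL, Cholesky)
def _cholesky(op, feature, fgraph, node, input_states):
    # Cholesky preserves diagonality of its input
    return [input_states[0] if input_states[0] else FactState.UNKNOWN]


def infer(op, input_states):
    fn = ASSUMPTION_INFER_REGISTRY.get((type(op), DIAGONAL))
    if fn is None:
        return [FactState.UNKNOWN]
    return fn(op, None, None, None, input_states)


# Propagation: eye -> cholesky stays provably diagonal
[eye_state] = infer(Eye(), [])
[chol_state] = infer(Cholesky(), [eye_state])
assert chol_state is FactState.TRUE
assert infer(object(), []) == [FactState.UNKNOWN]
```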
A `FactState` is a three-valued logic for assumption inference. The possible values are `UNKNOWN`, `TRUE`, or `FALSE`. A fourth state, `CONFLICT`, exists, but should never arise.

All facts about all Ops are assumed to be `UNKNOWN` unless we can prove otherwise. Proof comes from each Op's registered `InferFactFn`s. The `AssumptionFeature` is responsible for gathering all the rules of fact propagation. An example of a simple fact is that every `Eye` `Op` is `DIAGONAL`, provided it is 1) square and 2) has an offset of zero.

In a program, we can use `AssumptionFeature.get(x, AssumptionKey)` to query the state of a `Variable`. Here, we ask "is x diagonal?" Obviously it is.

Where this becomes powerful is by accumulating `InferFactFn`s. Many `Op`s preserve diagonality, like `Cholesky` or `Inverse`. We can use information about the inputs to these `Op`s to propagate fact information through the graph. Now we don't lose information about `x` in deeper graphs, and are free to do more rewrites.

Of course, we can also reason conditionally. An `IncSubtensor` might be diagonal-preserving if we can prove that we're setting a value on the diagonal of the matrix. Otherwise we fall back to `UNKNOWN`.

Facts can also imply other facts.
`DIAGONAL` matrices are also symmetrical. These general relationships can be registered and encoded as well, continuing the example above.

Finally, users can specify facts about matrices using `pt.specify_assumptions`, the same way they are able to specify shapes.

Benefits for rewriting:

- `FactState`s are trivial to check. We can check any fact about any Variable in 5 lines.
- We can reason globally about the graph. As noted above, information flows through the graph. As long as we have good coverage of fact rules for `Op`s, we can make statements about Variables at all levels of computation.
- `InferFactFn`s are lightweight and easy to write. It is trivial to add new `InferFactFn`s. LLMs can bang them out. They require only local reasoning about the `Op` and its immediate inputs.
- `FactState`s allow non-trivial combinations of rewrites. One example I hit while working on this was a rewrite for `DirectSolveLyapunov` given diagonal inputs. Because there is a rule that `kron` preserves diagonality if both inputs are diagonal, the chain of rewrites from `solve_discrete_lyapunov(diag(a), Q) -> Q.ravel() / (1 - outer(a, a).ravel())` is discovered by the rewrite system via:
  1. `rewrite_kron_diag_to_diag_outer`: `kron(diag(a), diag(b))` → `diag(outer(a, b).ravel())`
  2. `rewrite_solve_diag_to_division`: `solve(diag(x), b)` → `b / x`

  The key is that after the first rewrite produces a diagonal matrix (via `alloc_diag`), the assumption system recognizes it as diagonal (via `AllocDiag` being registered with `DIAGONAL`), and then the second rewrite kicks in.
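The closed form that this rewrite chain arrives at can be verified numerically, assuming the discrete Lyapunov convention A X Aᵀ - X + Q = 0 (a NumPy sketch, independent of PyTensor):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.uniform(-0.9, 0.9, size=4)  # stable diagonal entries of A
Q = rng.normal(size=(4, 4))
A = np.diag(a)

# Claimed closed-form solution when A is diagonal:
# X_ij = Q_ij / (1 - a_i * a_j)
X = (Q.ravel() / (1 - np.outer(a, a).ravel())).reshape(4, 4)

# X satisfies the discrete Lyapunov equation A X A^T - X + Q = 0
assert np.allclose(A @ X @ A.T - X + Q, 0)
```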
Non-Goals
The purpose of this system is not to introduce a complete, closed algebra over all types. That is impossible. The goal is also not to complete the project of German romanticism. That is also impossible.
- …`Assuming`). Assumptions live on `FunctionGraph`s. There is never a need to deal with global context, logical combinations, or relational assumptions.
- …the `FactState` of the `Node`. We check during rewrites and that's it.
- …`LinearAlgebra[IsDefinite]`. No symbolic conditionals as part of the core API.