Skip to content

Commit 2410ecc

Browse files
core design
1 parent 58e3ff1 commit 2410ecc

File tree

3 files changed

+470
-0
lines changed

3 files changed

+470
-0
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from pytensor.tensor.assumptions.core import (
2+
AssumptionFeature,
3+
AssumptionKey,
4+
FactState,
5+
lookup_assumption_rule,
6+
register_assumption,
7+
register_implies,
8+
)
9+
from pytensor.tensor.assumptions.diagonal import DIAGONAL
10+
11+
12+
ALL_KEYS = (DIAGONAL,)
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
from collections import deque
2+
from collections.abc import Callable, Iterable
3+
from dataclasses import dataclass
4+
from enum import IntFlag, auto
5+
from typing import Any
6+
7+
from pytensor.graph import Apply, FunctionGraph, Op
8+
from pytensor.graph.features import AlreadyThere, Feature
9+
10+
11+
class FactState(IntFlag):
12+
"""Three-valued logic for assumption inference.
13+
14+
The three fact states are TRUE, FALSE and UNKNOWN.
15+
16+
UNKNOWN is the default condition, in which we cannot confirm or deny the fact. TRUE and FALSE are definitive
17+
states. If we have evidence that a fact is both TRUE and FALSE, we get CONFLICT. In general, a CONFLICT state
18+
should not be possible.
19+
"""
20+
21+
UNKNOWN = 0
22+
TRUE = auto()
23+
FALSE = auto()
24+
CONFLICT = TRUE | FALSE
25+
26+
def __bool__(self) -> bool:
27+
return self is FactState.TRUE
28+
29+
@classmethod
30+
def join(cls, left: "FactState", right: "FactState") -> "FactState":
31+
"""Combine two pieces of evidence about the *same* (variable, key)."""
32+
return cls(left | right)
33+
34+
35+
@dataclass(frozen=True)
36+
class AssumptionKey:
37+
"""Identifies a named structural property (e.g. "diagonal" or "triangular")."""
38+
39+
name: str
40+
41+
def __repr__(self) -> str:
42+
return self.name
43+
44+
45+
# An inference function takes an Op, the current AssumptionFeature, the FunctionGraph, the Apply node being analyzed,
46+
# the states of the input variables for the current key, and returns a list of FactState (one per output).
47+
InferFactFn = Callable[
48+
[Op, "AssumptionFeature", FunctionGraph, Apply, list[FactState]],
49+
list[FactState],
50+
]
51+
52+
# The global inference registry maps (AssumptionKey, Op type) pairs to inference functions. The most specific
53+
# applicable rule is used for each node.
54+
ASSUMPTION_INFER_REGISTRY: dict[tuple[AssumptionKey, type], InferFactFn] = {}
55+
56+
# Registry mapping assumptions to other assumptions they imply. For example, a "diagonal" matrix is also "symmetric"
57+
# and "triangular". This is consulted after all other inference rules to derive additional facts.
58+
IMPLIES: dict[AssumptionKey, list[AssumptionKey]] = {}
59+
60+
61+
def register_implies(stronger: AssumptionKey, *weaker: AssumptionKey) -> None:
62+
"""Declare that *stronger* being TRUE implies each *weaker* key is also TRUE."""
63+
IMPLIES.setdefault(stronger, []).extend(weaker)
64+
65+
66+
def register_assumption(
67+
key: AssumptionKey, *op_types: type
68+
) -> Callable[[InferFactFn], InferFactFn]:
69+
"""Decorator that registers an inference rule for ``(key, op_type)`` pairs.
70+
71+
The decorated function is called as ``fn(op, feature, fgraph, node, input_states)``
72+
and must return a list of :class:`FactState` with one entry per node output.
73+
"""
74+
75+
def decorator(fn: InferFactFn) -> InferFactFn:
76+
for op_type in op_types:
77+
ASSUMPTION_INFER_REGISTRY[(key, op_type)] = fn
78+
return fn
79+
80+
return decorator
81+
82+
83+
def lookup_assumption_rule(key: AssumptionKey, op: Any) -> InferFactFn | None:
84+
"""Find the most specific registered rule for *(key, type(op))*, walking the MRO."""
85+
for cls in type(op).__mro__:
86+
fn = ASSUMPTION_INFER_REGISTRY.get((key, cls))
87+
if fn is not None:
88+
return fn
89+
return None
90+
91+
92+
def _default_infer_assumption(node: Any) -> list[FactState]:
93+
"""Absent evidence, all facts are assumed to be UNKNOWN for all outputs of all Ops."""
94+
return [FactState.UNKNOWN] * len(node.outputs)
95+
96+
97+
def _validate_output_states(
98+
node: Any, output_states: list[FactState]
99+
) -> list[FactState]:
100+
if len(output_states) != len(node.outputs):
101+
raise ValueError(
102+
f"infer_assumption returned {len(output_states)} states for "
103+
f"{len(node.outputs)} outputs on node {node!r}"
104+
)
105+
return [FactState(s) for s in output_states]
106+
107+
108+
def infer_assumption_for_node(
109+
op: Op,
110+
key: AssumptionKey,
111+
feature: "AssumptionFeature",
112+
fgraph: FunctionGraph,
113+
node: Apply,
114+
input_states: list[FactState],
115+
) -> list[FactState]:
116+
"""Determine the *key* fact for every output of *node*.
117+
118+
Resolution order:
119+
1. ``op.infer_assumption(key, feature, fgraph, node, input_states)``
120+
2. Registered rule via :func:`register_assumption`
121+
3. Conservative ``UNKNOWN`` for every output.
122+
"""
123+
meth = getattr(op, "infer_assumption", None)
124+
if meth is not None:
125+
output_states = meth(key, feature, fgraph, node, input_states)
126+
if output_states is not NotImplemented:
127+
return _validate_output_states(node, output_states)
128+
129+
fn = lookup_assumption_rule(key, op)
130+
if fn is not None:
131+
output_states = fn(op, feature, fgraph, node, input_states)
132+
return _validate_output_states(node, output_states)
133+
134+
return _default_infer_assumption(node)
135+
136+
137+
class AssumptionFeature(Feature):
138+
"""``FunctionGraph`` feature that tracks symbolic assumptions about variables.
139+
140+
Assumptions (e.g. "this matrix is diagonal") are represented as ``(variable, AssumptionKey) -> FactState``
141+
mappings. Facts are inferred lazily via per-Op rules registered with :func:`register_assumption` or via
142+
an ``infer_assumption`` method on the Op itself.
143+
144+
Results are cached and automatically invalidated when the graph changes.
145+
"""
146+
147+
__slots__ = ("cache", "fgraph", "user_facts")
148+
149+
def on_attach(self, fgraph: Any) -> None:
150+
if hasattr(fgraph, "assumption_feature"):
151+
raise AlreadyThere("AssumptionFeature is already attached")
152+
self.fgraph = fgraph
153+
self.cache: dict[tuple[Any, AssumptionKey], FactState] = {}
154+
self.user_facts: dict[tuple[Any, AssumptionKey], FactState] = {}
155+
fgraph.assumption_feature = self
156+
157+
def on_detach(self, fgraph: Any) -> None:
158+
self.cache = {}
159+
self.user_facts = {}
160+
self.fgraph = None
161+
del fgraph.assumption_feature
162+
163+
def on_import(self, fgraph, node, reason) -> None:
164+
self.invalidate_from_vars(node.outputs)
165+
166+
def on_change_input(self, fgraph, node, i, old_var, new_var, reason=None) -> None:
167+
if node is not None:
168+
self.invalidate_from_vars(node.outputs)
169+
170+
def on_prune(self, fgraph, node, reason) -> None:
171+
self.invalidate_from_vars(node.outputs)
172+
173+
def clone(self) -> "AssumptionFeature":
174+
return AssumptionFeature()
175+
176+
def get(self, var: Any, key: AssumptionKey) -> FactState:
177+
"""Return the inferred :class:`FactState` for ``(var, key)``"""
178+
cache_key = (var, key)
179+
if cache_key not in self.cache:
180+
self.cache[cache_key] = self._compute(var, key)
181+
return self.cache[cache_key]
182+
183+
def check(self, var: Any, key: AssumptionKey) -> bool:
184+
"""Return ``True`` iff the assumption is definitively TRUE for ``var``."""
185+
return bool(self.get(var, key))
186+
187+
def set_user_fact(self, var: Any, key: AssumptionKey, state: FactState) -> None:
188+
"""Join *state* with any existing user evidence for ``(var, key)``."""
189+
state = FactState(state)
190+
cache_key = (var, key)
191+
old = self.user_facts.get(cache_key, FactState.UNKNOWN)
192+
new = FactState.join(old, state)
193+
if new != old:
194+
self.user_facts[cache_key] = new
195+
self.invalidate_from_vars([var])
196+
197+
def replace_user_fact(self, var: Any, key: AssumptionKey, state: FactState) -> None:
198+
"""Overwrite user evidence for ``(var, key)``."""
199+
self.user_facts[(var, key)] = FactState(state)
200+
self.invalidate_from_vars([var])
201+
202+
def clear_user_fact(self, var: Any, key: AssumptionKey) -> None:
203+
cache_key = (var, key)
204+
if cache_key in self.user_facts:
205+
del self.user_facts[cache_key]
206+
self.invalidate_from_vars([var])
207+
208+
def _compute(self, var: Any, key: AssumptionKey) -> FactState:
209+
"""Propagate the knowledge state through the function graph to determine the fact state for ``(var, key)``."""
210+
state = FactState.UNKNOWN
211+
state = FactState.join(state, self.static_fact(var, key))
212+
state = FactState.join(
213+
state, self.user_facts.get((var, key), FactState.UNKNOWN)
214+
)
215+
216+
owner = getattr(var, "owner", None)
217+
if owner is not None:
218+
prev_key = getattr(self, "_current_key", None)
219+
self._current_key = key
220+
try:
221+
input_states = [self.get(inp, key) for inp in owner.inputs]
222+
output_states = infer_assumption_for_node(
223+
owner.op, key, self, self.fgraph, owner, input_states
224+
)
225+
finally:
226+
self._current_key = prev_key
227+
228+
out_idx = owner.outputs.index(var)
229+
state = FactState.join(state, output_states[out_idx])
230+
231+
if not state:
232+
for stronger, weaker_list in IMPLIES.items():
233+
if key in weaker_list and self.get(var, stronger):
234+
state = FactState.join(state, FactState.TRUE)
235+
break
236+
237+
return state
238+
239+
def static_fact(self, var: Any, key: AssumptionKey) -> FactState:
240+
"""Hook for non-Op fact sources. Returns UNKNOWN by default."""
241+
return FactState.UNKNOWN
242+
243+
def invalidate_from_vars(self, start_vars: Iterable[Any]) -> None:
244+
"""Clear cached facts for *start_vars* and everything downstream."""
245+
queue = deque(start_vars)
246+
seen = {id(v) for v in start_vars}
247+
while queue:
248+
var = queue.popleft()
249+
self._clear_cached_var(var)
250+
for client_node, _ in self.fgraph.clients.get(var, ()):
251+
for out in client_node.outputs:
252+
if id(out) not in seen:
253+
seen.add(id(out))
254+
queue.append(out)
255+
256+
def _clear_cached_var(self, var: Any) -> None:
257+
stale = [k for k in self.cache if k[0] is var]
258+
for k in stale:
259+
del self.cache[k]

0 commit comments

Comments
 (0)