Skip to content

Commit cece59a

Browse files
Utilities
1 parent 2f77897 commit cece59a

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from pytensor.graph.basic import Constant
2+
from pytensor.tensor.assumptions.core import FactState
3+
4+
5+
def true_if(cond: bool) -> list[FactState]:
    """Translate a boolean condition into a single-element fact-state list.

    Returns ``[FactState.TRUE]`` when *cond* holds and
    ``[FactState.UNKNOWN]`` otherwise.
    """
    if cond:
        return [FactState.TRUE]
    return [FactState.UNKNOWN]
8+
9+
10+
def propagate_first(op, feature, fgraph, node, input_states) -> list[FactState]:
    """Forward the assumption from the first input only.

    The single output is ``TRUE`` exactly when the first input state is
    truthy; every other input is ignored.
    """
    first_state = input_states[0]
    return true_if(first_state)
13+
14+
15+
def all_inputs_have_key(op, feature, fgraph, node, input_states) -> list[FactState]:
    """Output carries the assumption only when *every* input carries it."""
    every_input_holds = all(state for state in input_states)
    return true_if(every_input_holds)
18+
19+
20+
def eye_is_identity(node) -> bool:
    """Return True when an :class:`Eye` node produces the identity matrix.

    Identity requires a square shape (``n == m``) and a zero diagonal
    offset (``k == 0``).
    """
    n, m, k = node.inputs
    # The diagonal offset must be a compile-time constant equal to zero.
    if not isinstance(k, Constant) or k.data.item() != 0:
        return False
    # Same symbolic variable for both dims => square by construction.
    if n is m:
        return True
    # Otherwise squareness is only decidable when both dims are constants.
    if isinstance(n, Constant) and isinstance(m, Constant):
        return n.data.item() == m.data.item()
    return False
30+
31+
32+
def _same_variable(a, b) -> bool:
33+
"""True when *a* and *b* represent the same graph variable, including ``ScalarFromTensor`` wrappers."""
34+
if a is b:
35+
return True
36+
if (
37+
a.owner is not None
38+
and b.owner is not None
39+
and type(a.owner.op) is type(b.owner.op)
40+
and len(a.owner.inputs) == 1
41+
and len(b.owner.inputs) == 1
42+
and a.owner.inputs[0] is b.owner.inputs[0]
43+
):
44+
return True
45+
return False
46+
47+
48+
def indexes_diagonal(node) -> bool:
    """True when an ``*IncSubtensor*`` node modifies only diagonal entries."""
    from pytensor.tensor.subtensor import AdvancedIncSubtensor, IncSubtensor

    op = node.op

    if isinstance(op, AdvancedIncSubtensor):
        # inputs layout: (x, y, *index_arrays).  A diagonal write uses the
        # same variable for the trailing two index arrays.
        indices = node.inputs[2:]
        if len(indices) < 2:
            return False
        return _same_variable(indices[-2], indices[-1])

    if not isinstance(op, IncSubtensor):
        return False

    # idx_list entries: slice = static slice (no dynamic input),
    # non-slice = scalar index consuming one input from node.inputs[2:],
    # in order of appearance.
    idx_list = op.idx_list
    if len(idx_list) < 2:
        return False
    # The trailing two indexing entries must both be scalar indices.
    if isinstance(idx_list[-1], slice) or isinstance(idx_list[-2], slice):
        return False
    dynamic_inputs = list(node.inputs[2:])
    scalar_positions = [
        pos for pos, entry in enumerate(idx_list) if not isinstance(entry, slice)
    ]
    if len(scalar_positions) < 2:
        return False
    # A non-slice entry's position in the dynamic-input list equals how
    # many non-slice entries precede it, i.e. its rank in scalar_positions,
    # so the last two scalar indices map to ranks len-2 and len-1.
    dyn_a = len(scalar_positions) - 2
    dyn_b = len(scalar_positions) - 1
    if dyn_b >= len(dynamic_inputs):
        return False
    return _same_variable(dynamic_inputs[dyn_a], dynamic_inputs[dyn_b])

0 commit comments

Comments
 (0)