|
| 1 | +from pytensor.graph.basic import Constant |
| 2 | +from pytensor.tensor.assumptions.core import FactState |
| 3 | + |
| 4 | + |
| 5 | +def true_if(cond: bool) -> list[FactState]: |
| 6 | + """``[TRUE]`` when *cond* holds, ``[UNKNOWN]`` otherwise.""" |
| 7 | + return [FactState.TRUE] if cond else [FactState.UNKNOWN] |
| 8 | + |
| 9 | + |
| 10 | +def propagate_first(op, feature, fgraph, node, input_states) -> list[FactState]: |
| 11 | + """Output inherits the assumption iff the first input has it.""" |
| 12 | + return true_if(input_states[0]) |
| 13 | + |
| 14 | + |
| 15 | +def all_inputs_have_key(op, feature, fgraph, node, input_states) -> list[FactState]: |
| 16 | + """Output inherits the assumption iff *every* input has it.""" |
| 17 | + return true_if(all(input_states)) |
| 18 | + |
| 19 | + |
| 20 | +def eye_is_identity(node) -> bool: |
| 21 | + """True when an :class:`Eye` node produces the identity matrix (square, k == 0).""" |
| 22 | + n, m, k = node.inputs |
| 23 | + if not (isinstance(k, Constant) and k.data.item() == 0): |
| 24 | + return False |
| 25 | + if n is m: |
| 26 | + return True |
| 27 | + if isinstance(n, Constant) and isinstance(m, Constant): |
| 28 | + return n.data.item() == m.data.item() |
| 29 | + return False |
| 30 | + |
| 31 | + |
| 32 | +def _same_variable(a, b) -> bool: |
| 33 | + """True when *a* and *b* represent the same graph variable, including ``ScalarFromTensor`` wrappers.""" |
| 34 | + if a is b: |
| 35 | + return True |
| 36 | + if ( |
| 37 | + a.owner is not None |
| 38 | + and b.owner is not None |
| 39 | + and type(a.owner.op) is type(b.owner.op) |
| 40 | + and len(a.owner.inputs) == 1 |
| 41 | + and len(b.owner.inputs) == 1 |
| 42 | + and a.owner.inputs[0] is b.owner.inputs[0] |
| 43 | + ): |
| 44 | + return True |
| 45 | + return False |
| 46 | + |
| 47 | + |
| 48 | +def indexes_diagonal(node) -> bool: |
| 49 | + """True when an ``*IncSubtensor*`` node modifies only diagonal entries.""" |
| 50 | + from pytensor.tensor.subtensor import AdvancedIncSubtensor, IncSubtensor |
| 51 | + |
| 52 | + op = node.op |
| 53 | + if isinstance(op, AdvancedIncSubtensor): |
| 54 | + # inputs: (x, y, *index_arrays) |
| 55 | + index_arrays = node.inputs[2:] |
| 56 | + if len(index_arrays) >= 2: |
| 57 | + return _same_variable(index_arrays[-2], index_arrays[-1]) |
| 58 | + return False |
| 59 | + |
| 60 | + if isinstance(op, IncSubtensor): |
| 61 | + # idx_list entries: int = scalar index (consumes a dynamic input), |
| 62 | + # slice = slice (no dynamic input for static parts) |
| 63 | + # Dynamic inputs are in node.inputs[2:], one per non-slice entry. |
| 64 | + idx_list = op.idx_list |
| 65 | + if len(idx_list) < 2: |
| 66 | + return False |
| 67 | + # Last two entries must both be scalar indices (not slices) |
| 68 | + if isinstance(idx_list[-1], slice) or isinstance(idx_list[-2], slice): |
| 69 | + return False |
| 70 | + # Map each non-slice idx_list entry to its dynamic input |
| 71 | + dynamic_inputs = list(node.inputs[2:]) |
| 72 | + non_slice_positions = [ |
| 73 | + i for i, entry in enumerate(idx_list) if not isinstance(entry, slice) |
| 74 | + ] |
| 75 | + if len(non_slice_positions) < 2: |
| 76 | + return False |
| 77 | + # The last two non-slice positions correspond to the last two dynamic inputs |
| 78 | + pos_a = non_slice_positions[-2] |
| 79 | + pos_b = non_slice_positions[-1] |
| 80 | + # idx in dynamic_inputs list = count of non-slice entries before this one |
| 81 | + dyn_idx_a = sum(1 for e in idx_list[:pos_a] if not isinstance(e, slice)) |
| 82 | + dyn_idx_b = sum(1 for e in idx_list[:pos_b] if not isinstance(e, slice)) |
| 83 | + if dyn_idx_a < len(dynamic_inputs) and dyn_idx_b < len(dynamic_inputs): |
| 84 | + return _same_variable(dynamic_inputs[dyn_idx_a], dynamic_inputs[dyn_idx_b]) |
| 85 | + return False |
| 86 | + |
| 87 | + return False |
0 commit comments