|
| 1 | +from collections.abc import Sequence |
| 2 | +from dataclasses import dataclass |
| 3 | +from functools import reduce |
| 4 | +from typing import Any |
| 5 | + |
| 6 | +import numpy as np |
| 7 | + |
| 8 | +from pytensor.graph.basic import Constant, Variable |
| 9 | +from pytensor.graph.replace import graph_replace |
| 10 | +from pytensor.graph.traversal import graph_inputs, truncated_graph_inputs |
| 11 | +from pytensor.scan.op import Scan, ScanInfo |
| 12 | +from pytensor.scan.utils import expand_empty |
| 13 | +from pytensor.tensor import TensorVariable, as_tensor, minimum |
| 14 | + |
| 15 | + |
@dataclass(frozen=True)
class ShiftedArg:
    """User-facing marker pairing a value with the taps at which a loop reads it.

    Created by ``shift`` and consumed by ``loop``. ``loop`` requires the taps
    to be all negative for ``init`` entries (lags) and all non-negative for
    ``xs`` entries (leads).
    """

    # The wrapped value whose leading axis is stepped over by the loop.
    x: Any
    # Sorted tuple of integer taps (validated in ``shift``).
    by: tuple[int, ...]
| 20 | + |
| 21 | + |
@dataclass(frozen=True)
class InnerShiftedArg:
    """Inner-graph counterpart of ``ShiftedArg``: one dummy variable per tap.

    ``x`` holds the per-tap inner variables; ``update`` (set via ``push``)
    carries the recurrent value the step function produced for this carry.
    Read-only instances (xs shifts) never accept an update.
    """

    x: Any
    by: tuple[int, ...]
    readonly: bool
    update: Variable | None = None

    def push(self, x: Variable) -> "InnerShiftedArg":
        """Return a copy of ``self`` carrying ``x`` as its recurrent update."""
        if self.readonly:
            raise ValueError(
                "Cannot push to a read-only ShiftedArg (xs shifts cannot be updated)"
            )
        if self.update is not None:
            raise ValueError("ShiftedArg can only have a value pushed once")
        return type(self)(x=self.x, by=self.by, update=x, readonly=self.readonly)

    def __getitem__(self, idx):
        # Bounds-check against the number of taps; negative indices wrap as usual.
        n_taps = len(self.by)
        if not (-n_taps <= idx < n_taps):
            raise IndexError()
        return self.x[idx]

    def __len__(self):
        return len(self.by)
| 46 | + |
| 47 | + |
| 48 | +def shift(x: Any, by: int | Sequence[int] = -1): |
| 49 | + by = (by,) if isinstance(by, int) else tuple(by) |
| 50 | + if by != tuple(sorted(by)): |
| 51 | + raise ValueError(f"by entries must be sorted, got {by}") |
| 52 | + if min(by) < 0 and max(by) >= 0: |
| 53 | + raise ValueError( |
| 54 | + f"by cannot contain both negative and non-negative entries, got {by}" |
| 55 | + ) |
| 56 | + # TODO: If shape is known statically, validate the input is as big as the min/max taps |
| 57 | + return ShiftedArg(x, by=by) |
| 58 | + |
| 59 | + |
def flatten_tree(x, subsume_none: bool = False) -> tuple[tuple, Any]:
    """Flatten a nested tuple/list tree into its leaves plus a structure spec.

    The spec mirrors the tree: tuples stay tuples, lists stay lists, each leaf
    becomes the string ``"x"``, and ``None`` is recorded as ``None`` (and not
    yielded as a leaf) when ``subsume_none`` is True. The inverse operation is
    ``unflatten_tree``.
    """
    leaves: list[Any] = []

    def build_spec(node):
        if isinstance(node, (tuple, list)):
            children = [build_spec(child) for child in node]
            return tuple(children) if isinstance(node, tuple) else children
        if node is None and subsume_none:
            return None
        # Leaf: record it in depth-first, left-to-right order.
        leaves.append(node)
        return "x"

    tree_spec = build_spec(x)
    return tuple(leaves), tree_spec
| 79 | + |
| 80 | + |
def unflatten_tree(x, spec):
    """Rebuild the nested structure described by ``spec`` from the flat leaves ``x``.

    Inverse of ``flatten_tree``: ``"x"`` entries consume the next leaf, ``None``
    entries produce ``None``, and tuple/list specs reproduce their container
    type. Raises ``ValueError`` if ``x`` has leftover entries or the spec is
    unrecognized.
    """
    leaf_iter = iter(x)

    def rebuild(sub_spec):
        if sub_spec == "x":
            return next(leaf_iter)
        if sub_spec is None:
            return None
        if isinstance(sub_spec, tuple):
            return tuple(rebuild(s) for s in sub_spec)
        if isinstance(sub_spec, list):
            return [rebuild(s) for s in sub_spec]
        raise ValueError(f"Unrecognized spec: {sub_spec}")

    result = rebuild(spec)
    # Check we consumed the whole iterable
    try:
        next(leaf_iter)
    except StopIteration:
        return result
    raise ValueError(f"x {x} has more entries than expected from the spec: {spec}")
| 104 | + |
| 105 | + |
| 106 | +def loop( |
| 107 | + f, |
| 108 | + init, |
| 109 | + xs=None, |
| 110 | + length=None, |
| 111 | + truncate_gradient=False, |
| 112 | + **scan_kwargs, |
| 113 | +): |
| 114 | + # Flatten and process user init and xs |
| 115 | + init_flat, init_tree = flatten_tree(init, subsume_none=True) |
| 116 | + init_flat = [ |
| 117 | + i if isinstance(i, (Variable, ShiftedArg)) else as_tensor(i) for i in init_flat |
| 118 | + ] |
| 119 | + |
| 120 | + # Convert to inner inputs, (also learn about how they map to Scan outputs_info semantics) |
| 121 | + mit_sot_idxs = [] |
| 122 | + sit_sot_idxs = [] |
| 123 | + implicit_sit_sot_idxs = [] |
| 124 | + untraced_sit_sot_idxs = [] |
| 125 | + init_flat_inner_with_shifts = [] |
| 126 | + for i, init_i in enumerate(init_flat): |
| 127 | + if isinstance(init_i, ShiftedArg): |
| 128 | + if max(init_i.by) >= 0: |
| 129 | + raise ValueError(f"Init shifts must be negative, got by={i.by}") |
| 130 | + elem_type = init_i.x.type.clone(shape=init_i.x.type.shape[1:]) |
| 131 | + init_inner = InnerShiftedArg( |
| 132 | + x=[elem_type() for _ in init_i.by], by=init_i.by, readonly=False |
| 133 | + ) |
| 134 | + if min(init_i.by) < -1: |
| 135 | + mit_sot_idxs.append(i) |
| 136 | + else: |
| 137 | + sit_sot_idxs.append(i) |
| 138 | + init_flat_inner_with_shifts.append(init_inner) |
| 139 | + elif isinstance(init_i, TensorVariable): |
| 140 | + implicit_sit_sot_idxs.append(i) |
| 141 | + init_flat_inner_with_shifts.append(init_i.type()) |
| 142 | + else: |
| 143 | + untraced_sit_sot_idxs.append(i) |
| 144 | + init_flat_inner_with_shifts.append(init_i.type()) |
| 145 | + |
| 146 | + # Do the same for sequences |
| 147 | + xs_flat, x_tree = flatten_tree(xs, subsume_none=True) |
| 148 | + xs_flat = [ |
| 149 | + x if isinstance(x, (Variable, ShiftedArg)) else as_tensor(x) for x in xs_flat |
| 150 | + ] |
| 151 | + xs_flat_inner_with_shifts = [] |
| 152 | + for x in xs_flat: |
| 153 | + if isinstance(x, ShiftedArg): |
| 154 | + if min(x.by) < 0: |
| 155 | + raise ValueError(f"Sequence shifts must be non-negative, got by={x.by}") |
| 156 | + elem_type = x.x.type.clone(shape=x.x.type.shape[1:]) |
| 157 | + xs_flat_inner_with_shifts.append( |
| 158 | + InnerShiftedArg(x=[elem_type() for _ in x.by], by=x.by, readonly=True) |
| 159 | + ) |
| 160 | + elif isinstance(x, TensorVariable): |
| 161 | + xs_flat_inner_with_shifts.append(x.type.clone(shape=x.type.shape[1:])()) |
| 162 | + else: |
| 163 | + raise ValueError(f"xs must be TensorVariable got {x} of type {type(x)}") |
| 164 | + |
| 165 | + # Obtain inner outputs |
| 166 | + res = f( |
| 167 | + unflatten_tree(init_flat_inner_with_shifts, init_tree), |
| 168 | + unflatten_tree(xs_flat_inner_with_shifts, x_tree), |
| 169 | + ) |
| 170 | + ys_inner, break_cond_inner = None, None |
| 171 | + match res: |
| 172 | + case (update_inner, ys_inner): |
| 173 | + pass |
| 174 | + case (update_inner, ys_inner, break_cond_inner): |
| 175 | + pass |
| 176 | + case _: |
| 177 | + raise ValueError("Scan f must return a tuple with 2 or 3 outputs") |
| 178 | + |
| 179 | + # Validate outputs |
| 180 | + update_flat_inner_with_shifts, update_tree = flatten_tree( |
| 181 | + update_inner, subsume_none=True |
| 182 | + ) |
| 183 | + if init_tree != update_tree: |
| 184 | + raise ValueError( |
| 185 | + "The update expression (first output of f) does not match the init expression (first input of f), ", |
| 186 | + f"expected: {init_tree}, got: {update_tree}", |
| 187 | + ) |
| 188 | + update_flat_inner = [] |
| 189 | + for u in update_flat_inner_with_shifts: |
| 190 | + if isinstance(u, InnerShiftedArg): |
| 191 | + if u.update is None: |
| 192 | + raise ValueError(f"No update pushed for shifted argument {u}") |
| 193 | + update_flat_inner.append(u.update) |
| 194 | + else: |
| 195 | + update_flat_inner.append(u) |
| 196 | + |
| 197 | + ys_flat_inner, y_tree = flatten_tree(ys_inner, subsume_none=True) |
| 198 | + for y in ys_flat_inner: |
| 199 | + if not isinstance(y, TensorVariable): |
| 200 | + raise TypeError( |
| 201 | + f"ys outputs must be TensorVariables, got {y} of type {type(y)}. " |
| 202 | + "Non-traceable types like RNG states should be carried in init, not returned as ys." |
| 203 | + ) |
| 204 | + |
| 205 | + if break_cond_inner is not None: |
| 206 | + # TODO: validate |
| 207 | + raise NotImplementedError |
| 208 | + |
| 209 | + # Get inputs aligned for Scan and unpack ShiftedArgs |
| 210 | + scan_inner_inputs, _ = flatten_tree( |
| 211 | + ( |
| 212 | + [ |
| 213 | + s.x if isinstance(s, InnerShiftedArg) else s |
| 214 | + for s in xs_flat_inner_with_shifts |
| 215 | + ], |
| 216 | + [init_flat_inner_with_shifts[idx].x for idx in mit_sot_idxs], |
| 217 | + [init_flat_inner_with_shifts[idx].x for idx in sit_sot_idxs], |
| 218 | + [init_flat_inner_with_shifts[idx] for idx in implicit_sit_sot_idxs], |
| 219 | + [init_flat_inner_with_shifts[idx] for idx in untraced_sit_sot_idxs], |
| 220 | + ) |
| 221 | + ) |
| 222 | + # Get outputs aligned for Scan and unpack ShiftedArgs.update |
| 223 | + scan_inner_outputs, _ = flatten_tree( |
| 224 | + ( |
| 225 | + [update_flat_inner[idx] for idx in mit_sot_idxs], |
| 226 | + [update_flat_inner[idx] for idx in sit_sot_idxs], |
| 227 | + [update_flat_inner[idx] for idx in implicit_sit_sot_idxs], |
| 228 | + ys_flat_inner, |
| 229 | + [update_flat_inner[idx] for idx in untraced_sit_sot_idxs], |
| 230 | + break_cond_inner, |
| 231 | + ), |
| 232 | + subsume_none=True, |
| 233 | + ) |
| 234 | + |
| 235 | + # TODO: if any of the ys, is the same as the update values, we could return a slice of the trace |
| 236 | + # that discards the initial values and reduced the number of nit_sots |
| 237 | + # (useful even if there's already a Scan rewrite for this) |
| 238 | + |
| 239 | + # Use graph analysis to get the smallest closure of loop-invariant constants |
| 240 | + # Expand ShiftedArgs into their individual tap variables for graph analysis |
| 241 | + def _find_scan_constants(inputs, outputs) -> list[Variable]: |
| 242 | + def _depends_only_on_constants(var: Variable) -> bool: |
| 243 | + if isinstance(var, Constant): |
| 244 | + return True |
| 245 | + if var.owner is None: |
| 246 | + return False |
| 247 | + return all(isinstance(v, Constant) for v in graph_inputs([var])) |
| 248 | + |
| 249 | + inputs_set = set(inputs) |
| 250 | + return [ |
| 251 | + arg |
| 252 | + for arg in truncated_graph_inputs(outputs, inputs) |
| 253 | + if (arg not in inputs_set and not _depends_only_on_constants(arg)) |
| 254 | + ] |
| 255 | + |
| 256 | + constants = _find_scan_constants(scan_inner_inputs, scan_inner_outputs) |
| 257 | + inner_constants = [c.type() for c in constants] |
| 258 | + if inner_constants: |
| 259 | + # These constants belong to the outer graph, we need to remake inner outputs using dummies |
| 260 | + scan_inner_inputs = (*scan_inner_inputs, *inner_constants) |
| 261 | + scan_inner_outputs = graph_replace( |
| 262 | + scan_inner_outputs, |
| 263 | + replace=tuple(zip(constants, inner_constants)), |
| 264 | + strict=True, |
| 265 | + ) |
| 266 | + |
| 267 | + # Now build Scan Op |
| 268 | + info = ScanInfo( |
| 269 | + n_seqs=sum( |
| 270 | + len(x.by) if isinstance(x, InnerShiftedArg) else 1 |
| 271 | + for x in xs_flat_inner_with_shifts |
| 272 | + ), |
| 273 | + mit_mot_in_slices=(), |
| 274 | + mit_mot_out_slices=(), |
| 275 | + mit_sot_in_slices=tuple(init_flat[idx].by for idx in mit_sot_idxs), |
| 276 | + sit_sot_in_slices=((-1,),) * (len(sit_sot_idxs) + len(implicit_sit_sot_idxs)), |
| 277 | + n_untraced_sit_sot=len(untraced_sit_sot_idxs), |
| 278 | + n_nit_sot=len(ys_flat_inner), |
| 279 | + n_non_seqs=len(inner_constants), |
| 280 | + as_while=break_cond_inner is not None, |
| 281 | + ) |
| 282 | + |
| 283 | + scan_op = Scan( |
| 284 | + list(scan_inner_inputs), |
| 285 | + list(scan_inner_outputs), |
| 286 | + info, |
| 287 | + truncate_gradient=truncate_gradient, |
| 288 | + strict=True, |
| 289 | + **scan_kwargs, |
| 290 | + ) |
| 291 | + |
| 292 | + # Create outer sequences (learning about their length as we go) |
| 293 | + outer_sequences = [] |
| 294 | + sequences_lengths = [] |
| 295 | + for x in xs_flat: |
| 296 | + if isinstance(x, ShiftedArg): |
| 297 | + maxtap = max(x.by) |
| 298 | + sequences_lengths.append(x.x.shape[0] - maxtap) |
| 299 | + for start in x.by: |
| 300 | + end = None if start == maxtap else -(maxtap - start) |
| 301 | + outer_sequences.append(x.x[start:end]) |
| 302 | + else: |
| 303 | + sequences_lengths.append(x.shape[0]) |
| 304 | + outer_sequences.append(x) |
| 305 | + |
| 306 | + if length is not None: |
| 307 | + n_steps = as_tensor(length) |
| 308 | + elif sequences_lengths: |
| 309 | + n_steps = reduce(minimum, sequences_lengths) |
| 310 | + else: |
| 311 | + raise ValueError("length must be provided when there are no xs") |
| 312 | + |
| 313 | + # Build outer input traces with as many entries as n_steps + lags |
| 314 | + mit_sot_outer_inputs = [ |
| 315 | + expand_empty(init_flat[idx].x, n_steps) for idx in mit_sot_idxs |
| 316 | + ] |
| 317 | + sit_sot_outer_inputs = [ |
| 318 | + expand_empty(init_flat[idx].x, n_steps) for idx in sit_sot_idxs |
| 319 | + ] |
| 320 | + implicit_sit_sot_outer_inputs = [ |
| 321 | + expand_empty(init_flat[idx], n_steps, new_dim=True) |
| 322 | + for idx in implicit_sit_sot_idxs |
| 323 | + ] |
| 324 | + untraced_sit_sot_outer_inputs = [init_flat[idx] for idx in untraced_sit_sot_idxs] |
| 325 | + |
| 326 | + scan_outer_inputs, _ = flatten_tree( |
| 327 | + ( |
| 328 | + n_steps, |
| 329 | + outer_sequences, |
| 330 | + mit_sot_outer_inputs, |
| 331 | + sit_sot_outer_inputs, |
| 332 | + implicit_sit_sot_outer_inputs, |
| 333 | + untraced_sit_sot_outer_inputs, |
| 334 | + ((n_steps,) * info.n_nit_sot), |
| 335 | + constants, |
| 336 | + ) |
| 337 | + ) |
| 338 | + scan_outputs = scan_op(*scan_outer_inputs, return_list=True) |
| 339 | + |
| 340 | + # Extract final values from traced/untraced_outputs |
| 341 | + final_values, _ = flatten_tree( |
| 342 | + ( |
| 343 | + [mit_sot[-1] for mit_sot in scan_op.outer_mitsot_outs(scan_outputs)], |
| 344 | + [sit_sot[-1] for sit_sot in scan_op.outer_sitsot_outs(scan_outputs)], |
| 345 | + scan_op.outer_untraced_sit_sot_outs(scan_outputs), |
| 346 | + ) |
| 347 | + ) |
| 348 | + # These need to be reordered to the user order |
| 349 | + flat_idxs, _ = flatten_tree( |
| 350 | + (mit_sot_idxs, sit_sot_idxs, implicit_sit_sot_idxs, untraced_sit_sot_idxs) |
| 351 | + ) |
| 352 | + final = unflatten_tree( |
| 353 | + [final_values[rev_idx] for rev_idx in np.argsort(flat_idxs)], init_tree |
| 354 | + ) |
| 355 | + ys = unflatten_tree(scan_op.outer_nitsot_outs(scan_outputs), y_tree) |
| 356 | + |
| 357 | + return final, ys |
0 commit comments