Skip to content

Commit fe4ae27

Browse files
committed
Implement new user-facing looping signature
1 parent 4273eb8 commit fe4ae27

File tree

5 files changed

+689
-7
lines changed

5 files changed

+689
-7
lines changed

pytensor/link/numba/dispatch/scan.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,8 @@ def add_output_storage_post_proc_stmt(
291291
storage_name = outer_in_to_storage_name[outer_in_name]
292292

293293
is_tensor_type = isinstance(outer_in_var.type, TensorType)
294-
if is_tensor_type:
294+
is_untraced = outer_in_name in outer_in_untraced_sit_sot_names
295+
if is_tensor_type and not is_untraced:
295296
storage_size_name = f"{outer_in_name}_len"
296297
storage_size_stmt = f"{storage_size_name} = {outer_in_name}.shape[0]"
297298
input_taps = inner_in_names_to_input_taps[outer_in_name]

pytensor/looping.py

Lines changed: 357 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,357 @@
1+
from collections.abc import Sequence
2+
from dataclasses import dataclass
3+
from functools import reduce
4+
from typing import Any
5+
6+
import numpy as np
7+
8+
from pytensor.graph.basic import Constant, Variable
9+
from pytensor.graph.replace import graph_replace
10+
from pytensor.graph.traversal import graph_inputs, truncated_graph_inputs
11+
from pytensor.scan.op import Scan, ScanInfo
12+
from pytensor.scan.utils import expand_empty
13+
from pytensor.tensor import TensorVariable, as_tensor, minimum
14+
15+
16+
@dataclass(frozen=True)
class ShiftedArg:
    """User-facing pairing of a variable with the taps at which a loop reads it.

    Instances are created by `shift` and consumed by `loop`, which maps them
    onto Scan's mit-sot/sit-sot (init) or sequence semantics.
    """

    # The wrapped variable. NOTE(review): presumably indexed along its leading
    # dimension by the taps -- confirm against `loop`'s slicing of `x.x`.
    x: Any
    # Sorted tuple of integer taps; `shift` enforces that entries are either
    # all negative (recurrent state) or all non-negative (sequence).
    by: tuple[int, ...]
21+
22+
@dataclass(frozen=True)
class InnerShiftedArg:
    """Inner-graph counterpart of `ShiftedArg`.

    Holds one dummy inner variable per tap in ``x`` and, for writable
    (recurrent-state) arguments, the update value pushed by the user function.
    Sequence arguments are marked ``readonly`` and reject updates.
    """

    x: Any
    by: tuple[int, ...]
    readonly: bool
    update: Variable | None = None

    def push(self, x: Variable) -> "InnerShiftedArg":
        """Return a copy of self carrying ``x`` as the state update.

        Allowed exactly once, and never on read-only (sequence) arguments.
        """
        if self.readonly:
            raise ValueError(
                "Cannot push to a read-only ShiftedArg (xs shifts cannot be updated)"
            )
        if self.update is not None:
            raise ValueError("ShiftedArg can only have a value pushed once")
        cls = type(self)
        return cls(x=self.x, by=self.by, readonly=self.readonly, update=x)

    def __getitem__(self, idx):
        # List-like indexing over the per-tap dummies, honouring negative
        # indices but rejecting anything out of the tap range.
        n_taps = len(self.by)
        if idx < -n_taps or idx >= n_taps:
            raise IndexError()
        return self.x[idx]

    def __len__(self):
        # One entry per tap.
        return len(self.by)
46+
47+
48+
def shift(x: Any, by: int | Sequence[int] = -1):
49+
by = (by,) if isinstance(by, int) else tuple(by)
50+
if by != tuple(sorted(by)):
51+
raise ValueError(f"by entries must be sorted, got {by}")
52+
if min(by) < 0 and max(by) >= 0:
53+
raise ValueError(
54+
f"by cannot contain both negative and non-negative entries, got {by}"
55+
)
56+
# TODO: If shape is known statically, validate the input is as big as the min/max taps
57+
return ShiftedArg(x, by=by)
58+
59+
60+
def flatten_tree(x, subsume_none: bool = False) -> tuple[tuple, Any]:
    """Flatten a nested tuple/list tree into its leaves plus a rebuild spec.

    Returns a tuple of leaf values (in left-to-right order) and a spec that
    mirrors the tree's shape: ``"x"`` marks a leaf, ``None`` marks a subsumed
    ``None`` (when ``subsume_none`` is true), and tuples/lists nest as in the
    input. The spec is the exact input expected by `unflatten_tree`.
    """

    def walk(node, out_spec):
        if isinstance(node, (tuple, list)):
            child_spec = []
            for child in node:
                yield from walk(child, child_spec)
            # Preserve the container kind in the spec so unflattening can
            # rebuild tuples as tuples and lists as lists.
            if isinstance(node, tuple):
                child_spec = tuple(child_spec)
            out_spec.append(child_spec)
        elif node is None and subsume_none:
            out_spec.append(None)
        else:
            out_spec.append("x")
            yield node

    root_spec: list[Any] = []
    leaves = tuple(walk(x, out_spec=root_spec))
    return leaves, root_spec[0]
79+
80+
81+
def unflatten_tree(x, spec):
    """Inverse of `flatten_tree`: rebuild the nested structure from leaves.

    ``x`` is an iterable of leaf values consumed left-to-right; ``spec`` is the
    shape descriptor produced by `flatten_tree`. Raises ``ValueError`` if ``x``
    holds more entries than the spec accounts for.
    """

    def rebuild(values, spec):
        # Note: the literal "x" check must come first -- a leaf marker, not a
        # container -- followed by subsumed None, then the two container kinds.
        if spec == "x":
            return next(values)
        if spec is None:
            return None
        if isinstance(spec, tuple):
            return tuple(rebuild(values, child) for child in spec)
        if isinstance(spec, list):
            return [rebuild(values, child) for child in spec]
        raise ValueError(f"Unrecognized spec: {spec}")

    values = iter(x)
    result = rebuild(values, spec)
    # Verify the iterable was fully consumed: leftovers mean x and spec disagree.
    try:
        next(values)
    except StopIteration:
        return result
    raise ValueError(f"x {x} has more entries than expected from the spec: {spec}")
104+
105+
106+
def loop(
107+
f,
108+
init,
109+
xs=None,
110+
length=None,
111+
truncate_gradient=False,
112+
**scan_kwargs,
113+
):
114+
# Flatten and process user init and xs
115+
init_flat, init_tree = flatten_tree(init, subsume_none=True)
116+
init_flat = [
117+
i if isinstance(i, (Variable, ShiftedArg)) else as_tensor(i) for i in init_flat
118+
]
119+
120+
# Convert to inner inputs, (also learn about how they map to Scan outputs_info semantics)
121+
mit_sot_idxs = []
122+
sit_sot_idxs = []
123+
implicit_sit_sot_idxs = []
124+
untraced_sit_sot_idxs = []
125+
init_flat_inner_with_shifts = []
126+
for i, init_i in enumerate(init_flat):
127+
if isinstance(init_i, ShiftedArg):
128+
if max(init_i.by) >= 0:
129+
raise ValueError(f"Init shifts must be negative, got by={i.by}")
130+
elem_type = init_i.x.type.clone(shape=init_i.x.type.shape[1:])
131+
init_inner = InnerShiftedArg(
132+
x=[elem_type() for _ in init_i.by], by=init_i.by, readonly=False
133+
)
134+
if min(init_i.by) < -1:
135+
mit_sot_idxs.append(i)
136+
else:
137+
sit_sot_idxs.append(i)
138+
init_flat_inner_with_shifts.append(init_inner)
139+
elif isinstance(init_i, TensorVariable):
140+
implicit_sit_sot_idxs.append(i)
141+
init_flat_inner_with_shifts.append(init_i.type())
142+
else:
143+
untraced_sit_sot_idxs.append(i)
144+
init_flat_inner_with_shifts.append(init_i.type())
145+
146+
# Do the same for sequences
147+
xs_flat, x_tree = flatten_tree(xs, subsume_none=True)
148+
xs_flat = [
149+
x if isinstance(x, (Variable, ShiftedArg)) else as_tensor(x) for x in xs_flat
150+
]
151+
xs_flat_inner_with_shifts = []
152+
for x in xs_flat:
153+
if isinstance(x, ShiftedArg):
154+
if min(x.by) < 0:
155+
raise ValueError(f"Sequence shifts must be non-negative, got by={x.by}")
156+
elem_type = x.x.type.clone(shape=x.x.type.shape[1:])
157+
xs_flat_inner_with_shifts.append(
158+
InnerShiftedArg(x=[elem_type() for _ in x.by], by=x.by, readonly=True)
159+
)
160+
elif isinstance(x, TensorVariable):
161+
xs_flat_inner_with_shifts.append(x.type.clone(shape=x.type.shape[1:])())
162+
else:
163+
raise ValueError(f"xs must be TensorVariable got {x} of type {type(x)}")
164+
165+
# Obtain inner outputs
166+
res = f(
167+
unflatten_tree(init_flat_inner_with_shifts, init_tree),
168+
unflatten_tree(xs_flat_inner_with_shifts, x_tree),
169+
)
170+
ys_inner, break_cond_inner = None, None
171+
match res:
172+
case (update_inner, ys_inner):
173+
pass
174+
case (update_inner, ys_inner, break_cond_inner):
175+
pass
176+
case _:
177+
raise ValueError("Scan f must return a tuple with 2 or 3 outputs")
178+
179+
# Validate outputs
180+
update_flat_inner_with_shifts, update_tree = flatten_tree(
181+
update_inner, subsume_none=True
182+
)
183+
if init_tree != update_tree:
184+
raise ValueError(
185+
"The update expression (first output of f) does not match the init expression (first input of f), ",
186+
f"expected: {init_tree}, got: {update_tree}",
187+
)
188+
update_flat_inner = []
189+
for u in update_flat_inner_with_shifts:
190+
if isinstance(u, InnerShiftedArg):
191+
if u.update is None:
192+
raise ValueError(f"No update pushed for shifted argument {u}")
193+
update_flat_inner.append(u.update)
194+
else:
195+
update_flat_inner.append(u)
196+
197+
ys_flat_inner, y_tree = flatten_tree(ys_inner, subsume_none=True)
198+
for y in ys_flat_inner:
199+
if not isinstance(y, TensorVariable):
200+
raise TypeError(
201+
f"ys outputs must be TensorVariables, got {y} of type {type(y)}. "
202+
"Non-traceable types like RNG states should be carried in init, not returned as ys."
203+
)
204+
205+
if break_cond_inner is not None:
206+
# TODO: validate
207+
raise NotImplementedError
208+
209+
# Get inputs aligned for Scan and unpack ShiftedArgs
210+
scan_inner_inputs, _ = flatten_tree(
211+
(
212+
[
213+
s.x if isinstance(s, InnerShiftedArg) else s
214+
for s in xs_flat_inner_with_shifts
215+
],
216+
[init_flat_inner_with_shifts[idx].x for idx in mit_sot_idxs],
217+
[init_flat_inner_with_shifts[idx].x for idx in sit_sot_idxs],
218+
[init_flat_inner_with_shifts[idx] for idx in implicit_sit_sot_idxs],
219+
[init_flat_inner_with_shifts[idx] for idx in untraced_sit_sot_idxs],
220+
)
221+
)
222+
# Get outputs aligned for Scan and unpack ShiftedArgs.update
223+
scan_inner_outputs, _ = flatten_tree(
224+
(
225+
[update_flat_inner[idx] for idx in mit_sot_idxs],
226+
[update_flat_inner[idx] for idx in sit_sot_idxs],
227+
[update_flat_inner[idx] for idx in implicit_sit_sot_idxs],
228+
ys_flat_inner,
229+
[update_flat_inner[idx] for idx in untraced_sit_sot_idxs],
230+
break_cond_inner,
231+
),
232+
subsume_none=True,
233+
)
234+
235+
# TODO: if any of the ys, is the same as the update values, we could return a slice of the trace
236+
# that discards the initial values and reduced the number of nit_sots
237+
# (useful even if there's already a Scan rewrite for this)
238+
239+
# Use graph analysis to get the smallest closure of loop-invariant constants
240+
# Expand ShiftedArgs into their individual tap variables for graph analysis
241+
def _find_scan_constants(inputs, outputs) -> list[Variable]:
242+
def _depends_only_on_constants(var: Variable) -> bool:
243+
if isinstance(var, Constant):
244+
return True
245+
if var.owner is None:
246+
return False
247+
return all(isinstance(v, Constant) for v in graph_inputs([var]))
248+
249+
inputs_set = set(inputs)
250+
return [
251+
arg
252+
for arg in truncated_graph_inputs(outputs, inputs)
253+
if (arg not in inputs_set and not _depends_only_on_constants(arg))
254+
]
255+
256+
constants = _find_scan_constants(scan_inner_inputs, scan_inner_outputs)
257+
inner_constants = [c.type() for c in constants]
258+
if inner_constants:
259+
# These constants belong to the outer graph, we need to remake inner outputs using dummies
260+
scan_inner_inputs = (*scan_inner_inputs, *inner_constants)
261+
scan_inner_outputs = graph_replace(
262+
scan_inner_outputs,
263+
replace=tuple(zip(constants, inner_constants)),
264+
strict=True,
265+
)
266+
267+
# Now build Scan Op
268+
info = ScanInfo(
269+
n_seqs=sum(
270+
len(x.by) if isinstance(x, InnerShiftedArg) else 1
271+
for x in xs_flat_inner_with_shifts
272+
),
273+
mit_mot_in_slices=(),
274+
mit_mot_out_slices=(),
275+
mit_sot_in_slices=tuple(init_flat[idx].by for idx in mit_sot_idxs),
276+
sit_sot_in_slices=((-1,),) * (len(sit_sot_idxs) + len(implicit_sit_sot_idxs)),
277+
n_untraced_sit_sot=len(untraced_sit_sot_idxs),
278+
n_nit_sot=len(ys_flat_inner),
279+
n_non_seqs=len(inner_constants),
280+
as_while=break_cond_inner is not None,
281+
)
282+
283+
scan_op = Scan(
284+
list(scan_inner_inputs),
285+
list(scan_inner_outputs),
286+
info,
287+
truncate_gradient=truncate_gradient,
288+
strict=True,
289+
**scan_kwargs,
290+
)
291+
292+
# Create outer sequences (learning about their length as we go)
293+
outer_sequences = []
294+
sequences_lengths = []
295+
for x in xs_flat:
296+
if isinstance(x, ShiftedArg):
297+
maxtap = max(x.by)
298+
sequences_lengths.append(x.x.shape[0] - maxtap)
299+
for start in x.by:
300+
end = None if start == maxtap else -(maxtap - start)
301+
outer_sequences.append(x.x[start:end])
302+
else:
303+
sequences_lengths.append(x.shape[0])
304+
outer_sequences.append(x)
305+
306+
if length is not None:
307+
n_steps = as_tensor(length)
308+
elif sequences_lengths:
309+
n_steps = reduce(minimum, sequences_lengths)
310+
else:
311+
raise ValueError("length must be provided when there are no xs")
312+
313+
# Build outer input traces with as many entries as n_steps + lags
314+
mit_sot_outer_inputs = [
315+
expand_empty(init_flat[idx].x, n_steps) for idx in mit_sot_idxs
316+
]
317+
sit_sot_outer_inputs = [
318+
expand_empty(init_flat[idx].x, n_steps) for idx in sit_sot_idxs
319+
]
320+
implicit_sit_sot_outer_inputs = [
321+
expand_empty(init_flat[idx], n_steps, new_dim=True)
322+
for idx in implicit_sit_sot_idxs
323+
]
324+
untraced_sit_sot_outer_inputs = [init_flat[idx] for idx in untraced_sit_sot_idxs]
325+
326+
scan_outer_inputs, _ = flatten_tree(
327+
(
328+
n_steps,
329+
outer_sequences,
330+
mit_sot_outer_inputs,
331+
sit_sot_outer_inputs,
332+
implicit_sit_sot_outer_inputs,
333+
untraced_sit_sot_outer_inputs,
334+
((n_steps,) * info.n_nit_sot),
335+
constants,
336+
)
337+
)
338+
scan_outputs = scan_op(*scan_outer_inputs, return_list=True)
339+
340+
# Extract final values from traced/untraced_outputs
341+
final_values, _ = flatten_tree(
342+
(
343+
[mit_sot[-1] for mit_sot in scan_op.outer_mitsot_outs(scan_outputs)],
344+
[sit_sot[-1] for sit_sot in scan_op.outer_sitsot_outs(scan_outputs)],
345+
scan_op.outer_untraced_sit_sot_outs(scan_outputs),
346+
)
347+
)
348+
# These need to be reordered to the user order
349+
flat_idxs, _ = flatten_tree(
350+
(mit_sot_idxs, sit_sot_idxs, implicit_sit_sot_idxs, untraced_sit_sot_idxs)
351+
)
352+
final = unflatten_tree(
353+
[final_values[rev_idx] for rev_idx in np.argsort(flat_idxs)], init_tree
354+
)
355+
ys = unflatten_tree(scan_op.outer_nitsot_outs(scan_outputs), y_tree)
356+
357+
return final, ys

0 commit comments

Comments
 (0)