Skip to content

Commit 67fad07

Browse files
committed
Eliminate Join buffer copies via WriteSplit/WriteJoin
1 parent 66c2720 commit 67fad07

File tree

5 files changed

+417
-1
lines changed

5 files changed

+417
-1
lines changed

pytensor/link/numba/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytensor.link.numba.dispatch.compile_ops
77
import pytensor.link.numba.dispatch.elemwise
88
import pytensor.link.numba.dispatch.extra_ops
9+
import pytensor.link.numba.dispatch.join_inplace
910
import pytensor.link.numba.dispatch.nlinalg
1011
import pytensor.link.numba.dispatch.random
1112
import pytensor.link.numba.dispatch.scan
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import numpy as np
2+
3+
from pytensor.link.numba.cache import compile_numba_function_src
4+
from pytensor.link.numba.dispatch import basic as numba_basic
5+
from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key
6+
from pytensor.tensor.rewriting.join_inplace import WriteJoin, WriteSplit
7+
8+
9+
@register_funcify_default_op_cache_key(WriteSplit)
def numba_funcify_WriteSplit(op, node, **kwargs):
    """Build the Numba implementation of ``WriteSplit``.

    Source code for ``write_split(buffer, s0, ..., s{n-1})`` is generated,
    compiled and jitted. The generated function returns a tuple holding
    ``op.n_splits`` consecutive slices of ``buffer`` along ``op.axis``;
    slice ``i`` has length ``s{i}.item()`` and starts where slice ``i - 1``
    ended (the first one starts at 0).

    Parameters
    ----------
    op
        The ``WriteSplit`` op; ``op.n_splits`` and ``op.axis`` are read.
    node
        The apply node; the number of dimensions of ``node.inputs[0]``
        determines how many index positions each slice expression has.

    Raises
    ------
    ValueError
        If ``op.axis`` is out of bounds for the buffer's dimensionality.
    """
    n_splits = op.n_splits
    ndim = node.inputs[0].type.ndim

    # Normalize the axis so that a negative value still selects a dimension;
    # comparing a negative axis against ``range(ndim)`` would never match and
    # every output would silently become a slice of the whole buffer.
    axis = op.axis
    if not -ndim <= axis < ndim:
        raise ValueError(
            f"axis {axis} is out of bounds for a buffer with {ndim} dimension(s)"
        )
    if axis < 0:
        axis += ndim

    body_lines = []
    # Name of the expression holding the running offset along ``axis``.
    # Tracking it in a generated variable keeps each emitted line short
    # instead of re-expanding ``0 + sz_0 + sz_1 + ...`` for every split.
    start = "0"
    for i in range(n_splits):
        stop = f"stop_{i}"
        index = ", ".join(
            f"{start}:{stop}" if d == axis else ":" for d in range(ndim)
        )
        body_lines += [
            f"    sz_{i} = s{i}.item()",
            f"    {stop} = {start} + sz_{i}",
            f"    out_{i} = buffer[{index}]",
        ]
        start = stop

    # A trailing comma after every element gives a valid tuple literal for
    # any count, including ``()`` when there are no splits.
    returned = "".join(f"out_{i}, " for i in range(n_splits))
    body_lines.append(f"    return ({returned})")

    params = ", ".join(["buffer", *(f"s{i}" for i in range(n_splits))])
    body = "\n".join(body_lines)
    func_src = f"""
def write_split({params}):
{body}
"""
    fn = compile_numba_function_src(func_src, "write_split", {"np": np})
    return numba_basic.numba_njit(fn)
35+
36+
37+
@register_funcify_default_op_cache_key(WriteJoin)
def numba_funcify_WriteJoin(op, node, **kwargs):
    """Build the Numba implementation of ``WriteJoin``.

    The generated ``write_join(buffer, dep0, ..., dep{k-1})`` accepts one
    positional argument per node input and returns ``buffer`` unchanged;
    the ``dep`` arguments are accepted but never used in the body.
    """
    # Every input after the first is a dependency placeholder.
    dep_count = len(node.inputs) - 1
    arg_names = ["buffer"]
    arg_names.extend(f"dep{k}" for k in range(dep_count))
    sig = ", ".join(arg_names)

    func_src = f"""
def write_join({sig}):
    return buffer
"""
    compiled = compile_numba_function_src(func_src, "write_join")
    return numba_basic.numba_njit(compiled)

pytensor/tensor/rewriting/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytensor.tensor.rewriting.elemwise
77
import pytensor.tensor.rewriting.extra_ops
88
import pytensor.tensor.rewriting.jax
9+
import pytensor.tensor.rewriting.join_inplace
910
import pytensor.tensor.rewriting.linalg
1011
import pytensor.tensor.rewriting.math
1112
import pytensor.tensor.rewriting.numba

0 commit comments

Comments
 (0)