提交 03b62a33 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow inplacing of SITSOT and last MITSOT in numba Scan, when they are discarded immediately

上级 1fa22d8f
...@@ -55,7 +55,7 @@ def array0d_range(x): ...@@ -55,7 +55,7 @@ def array0d_range(x):
@numba_funcify.register(Scan) @numba_funcify.register(Scan)
def numba_funcify_Scan(op, node, **kwargs): def numba_funcify_Scan(op: Scan, node, **kwargs):
# Apply inner rewrites # Apply inner rewrites
# TODO: Not sure this is the right place to do this, should we have a rewrite that # TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of Scan? # explicitly triggers the optimization of the inner graphs of Scan?
...@@ -67,9 +67,32 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -67,9 +67,32 @@ def numba_funcify_Scan(op, node, **kwargs):
.optimizer .optimizer
) )
fgraph = op.fgraph fgraph = op.fgraph
# When the buffer can only hold one SITSOT or as as many MITSOT as there are taps,
# We must always discard the oldest tap, so it's safe to destroy it in the inner function.
# TODO: Allow inplace for MITMOT
destroyable_sitsot = [
inner_sitsot
for outer_sitsot, inner_sitsot in zip(
op.outer_sitsot(node.inputs), op.inner_sitsot(fgraph.inputs), strict=True
)
if outer_sitsot.type.shape[0] == 1
]
destroyable_mitsot = [
oldest_inner_mitmot
for outer_mitsot, oldest_inner_mitmot, taps in zip(
op.outer_mitsot(node.inputs),
op.oldest_inner_mitsot(fgraph.inputs),
op.info.mit_sot_in_slices,
strict=True,
)
if outer_mitsot.type.shape[0] == abs(min(taps))
]
destroyable = {*destroyable_sitsot, *destroyable_mitsot}
add_supervisor_to_fgraph( add_supervisor_to_fgraph(
fgraph=fgraph, fgraph=fgraph,
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs], input_specs=[
In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs
],
accept_inplace=True, accept_inplace=True,
) )
rewriter(fgraph) rewriter(fgraph)
......
...@@ -321,6 +321,16 @@ class ScanMethodsMixin: ...@@ -321,6 +321,16 @@ class ScanMethodsMixin:
self.info.n_seqs + n_mitmot_taps : self.info.n_seqs + ntaps_upto_sit_sot self.info.n_seqs + n_mitmot_taps : self.info.n_seqs + ntaps_upto_sit_sot
] ]
def oldest_inner_mitsot(self, list_inputs):
inner_mitsot_inputs = self.inner_mitsot(list_inputs)
oldest_inner_mitsot_inputs = []
offset = 0
for taps in self.info.mit_sot_in_slices:
oldest_tap = np.argmin(taps)
oldest_inner_mitsot_inputs += [inner_mitsot_inputs[offset + oldest_tap]]
offset += len(taps)
return oldest_inner_mitsot_inputs
def outer_mitsot(self, list_inputs): def outer_mitsot(self, list_inputs):
offset = 1 + self.info.n_seqs + self.info.n_mit_mot offset = 1 + self.info.n_seqs + self.info.n_mit_mot
return list_inputs[offset : offset + self.info.n_mit_sot] return list_inputs[offset : offset + self.info.n_mit_sot]
......
...@@ -451,6 +451,80 @@ def test_vector_taps_benchmark(benchmark): ...@@ -451,6 +451,80 @@ def test_vector_taps_benchmark(benchmark):
benchmark(numba_fn, *test.values()) benchmark(numba_fn, *test.values())
@pytest.mark.parametrize("n_steps_constant", (True, False))
def test_inplace_taps(n_steps_constant):
"""Test that numba will inplace in the inner_function of the oldest sit-sot, mit-sot taps."""
n_steps = 10 if n_steps_constant else scalar("n_steps", dtype=int)
a = scalar("a")
x0 = scalar("x0")
y0 = vector("y0", shape=(2,))
z0 = vector("z0", shape=(3,))
def step(ztm3, ztm1, xtm1, ytm1, ytm2, a):
z = ztm1 + 1 + ztm3 + a
x = xtm1 + 1
y = ytm1 + 1 + ytm2 + a
return z, x, z + x + y, y
[zs, xs, ws, ys], _ = scan(
fn=step,
outputs_info=[
dict(initial=z0, taps=[-3, -1]),
dict(initial=x0, taps=[-1]),
None,
dict(initial=y0, taps=[-1, -2]),
],
non_sequences=[a],
n_steps=n_steps,
)
numba_fn, _ = compare_numba_and_py(
[n_steps] * (not n_steps_constant) + [a, x0, y0, z0],
[zs[-1], xs[-1], ws[-1], ys[-1]],
[10] * (not n_steps_constant) + [np.pi, np.e, [1, np.euler_gamma], [0, 1, 2]],
numba_mode="NUMBA",
eval_obj_mode=False,
)
[scan_op] = [
node.op
for node in numba_fn.maker.fgraph.toposort()
if isinstance(node.op, Scan)
]
# Scan reorders inputs internally, so we need to check its ordering
inner_inps = scan_op.fgraph.inputs
mit_sot_inps = scan_op.inner_mitsot(inner_inps)
oldest_mit_sot_inps = [
# Implicitly assume that the first mit-sot input is the one with 3 taps
# This is not a required behavior and the test can change if we need to change Scan.
mit_sot_inps[:2][scan_op.info.mit_sot_in_slices[0].index(-3)],
mit_sot_inps[2:][scan_op.info.mit_sot_in_slices[1].index(-2)],
]
[sit_sot_inp] = scan_op.inner_sitsot(inner_inps)
inner_outs = scan_op.fgraph.outputs
mit_sot_outs = scan_op.inner_mitsot_outs(inner_outs)
[sit_sot_out] = scan_op.inner_sitsot_outs(inner_outs)
[nit_sot_out] = scan_op.inner_nitsot_outs(inner_outs)
if n_steps_constant:
assert mit_sot_outs[0].owner.op.destroy_map == {
0: [mit_sot_outs[0].owner.inputs.index(oldest_mit_sot_inps[0])]
}
assert mit_sot_outs[1].owner.op.destroy_map == {
0: [mit_sot_outs[1].owner.inputs.index(oldest_mit_sot_inps[1])]
}
assert sit_sot_out.owner.op.destroy_map == {
0: [sit_sot_out.owner.inputs.index(sit_sot_inp)]
}
else:
# This is not a feature, but a current limitation
# https://github.com/pymc-devs/pytensor/issues/1283
assert mit_sot_outs[0].owner.op.destroy_map == {}
assert mit_sot_outs[1].owner.op.destroy_map == {}
assert sit_sot_out.owner.op.destroy_map == {}
assert nit_sot_out.owner.op.destroy_map == {}
@pytest.mark.parametrize( @pytest.mark.parametrize(
"buffer_size", ("unit", "aligned", "misaligned", "whole", "whole+init") "buffer_size", ("unit", "aligned", "misaligned", "whole", "whole+init")
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论