提交 63f8d6e7 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Optimize while scans when only last state is needed

上级 01e92baa
......@@ -1182,7 +1182,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# these are states that do not feed anything back in the recurrent
# computation, and hence they do not have an initial state. The scan
# node however receives an input for each such argument, the input
# in this case is just a int saying how many steps of this output we
# in this case is just an int saying how many steps of this output we
# need to store. This input does not have the same dtype, nor is it the same
# type of tensor as the output, it is always a scalar int.
new_inputs += [as_tensor_variable(ons) for ons in self.outer_nitsot(inputs)]
......
差异被折叠。
......@@ -479,6 +479,7 @@ def local_subtensor_merge(fgraph, node):
expresses all slices in a canonical form, and then merges them together.
"""
from pytensor.scan.op import Scan
if isinstance(node.op, Subtensor):
u = node.inputs[0]
......@@ -489,6 +490,16 @@ def local_subtensor_merge(fgraph, node):
# slices of the first applied subtensor
slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
slices2 = get_idx_list(node.inputs, node.op.idx_list)
# Don't try to do the optimization on do-while scan outputs,
# as it will create a dependency on the shape of the outputs
if (
x.owner is not None
and isinstance(x.owner.op, Scan)
and x.owner.op.info.as_while
):
return None
# Get the shapes of the vectors !
try:
# try not to introduce new shape into the graph
......
......@@ -1395,6 +1395,98 @@ class TestSaveMem:
rng = np.random.default_rng(utt.fetch_seed())
my_f(rng.uniform(size=(3,)), 4, np.int64([2, 2, 3]))
def test_while_scan_taps(self):
n_steps = scalar("n_steps", dtype="int64")
x0 = vector("x0")
ys, _ = pytensor.scan(
# Fibonacci Sequence
lambda xtm2, xtm1: (xtm1 + xtm2, {}, until(xtm1 >= 34)),
outputs_info=[{"initial": x0, "taps": [-2, -1]}],
n_steps=n_steps,
)
# Save memory is triggered by choosing only last value
y = ys[-1]
f = pytensor.function(
[n_steps, x0], y, mode=get_default_mode().including("scan")
)
np.testing.assert_equal(f(n_steps=1000, x0=[1, 1]), 55)
np.testing.assert_equal(f(n_steps=1, x0=[1, 1]), 2)
with pytest.raises(AssertionError, match="n_steps > 0"):
f(n_steps=0, x0=[1, 1])
# ys_trace is an Alloc that controls the size of the inner buffer,
# it should have shape[0] == 3, with two entries for the taps and one
# entry for the intermediate output
[scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
_, ys_trace = scan_node.inputs
debug_fn = pytensor.function(
[n_steps, x0], ys_trace.shape[0], accept_inplace=True
)
assert debug_fn(n_steps=1000, x0=[1, 1]) == 3
def test_while_scan_map(self):
xs = vector("xs")
ys, _ = pytensor.scan(
lambda x: (x + 1, {}, until(x + 1 >= 10)),
outputs_info=[None],
sequences=[xs],
)
# Save memory is triggered by choosing only last value
y = ys[-1]
f = pytensor.function([xs], y, mode=get_default_mode().including("scan"))
np.testing.assert_equal(f(xs=np.arange(100, dtype=config.floatX)), 10)
np.testing.assert_equal(f(xs=[0]), 1)
with pytest.raises(IndexError):
f(xs=[])
# len_ys is a numerical input that controls the shape of the inner buffer
# It should be 1, as only the last output is needed
[scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
_, _, len_ys = scan_node.inputs
debug_fn = pytensor.function([xs], len_ys, accept_inplace=True)
assert debug_fn(xs=np.zeros((100,), dtype=config.floatX)) == 1
def test_while_scan_taps_and_map(self):
x0 = scalar("x0")
seq = vector("seq")
n_steps = scalar("n_steps", dtype="int64")
# while loop
[ys, zs], _ = pytensor.scan(
lambda s, xtm1: ((xtm1 + 1, xtm1 + 1 + s), {}, until(xtm1 >= 99)),
sequences=[seq],
outputs_info=[x0, None],
n_steps=n_steps,
)
# Save memory is triggered by choosing only last value
y = ys[-1]
z = zs[-1]
f = pytensor.function(
[x0, seq, n_steps], [y, z], mode=get_default_mode().including("scan")
)
test_seq = np.zeros(200, dtype=config.floatX)
np.testing.assert_allclose(f(x0=0, seq=test_seq, n_steps=200), 100)
np.testing.assert_allclose(f(x0=1, seq=test_seq, n_steps=20), 21)
np.testing.assert_allclose(f(x0=np.e, seq=test_seq, n_steps=1), np.e + 1)
with pytest.raises(AssertionError, match="n_steps > 0"):
f(x0=0, seq=test_seq, n_steps=0)
# Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly.
# If a MissingInputError is raised, it means the rewrite failed
[scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
_, _, ys_trace, len_zs = scan_node.inputs
debug_fn = pytensor.function(
[n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True
)
stored_ys_steps, stored_zs_steps = debug_fn(n_steps=200)
assert stored_ys_steps == 2
assert stored_zs_steps == 1
def test_inner_replace_dot():
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论