Commit 63f8d6e7 authored by Ricardo Vieira, committed by Ricardo Vieira

Optimize while scans when only last state is needed

Parent 01e92baa
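Before the diff, a quick illustration of the pattern this commit targets (a minimal sketch; the variables and numbers are illustrative, not taken from the diff):

import pytensor
import pytensor.tensor as at
from pytensor.scan.utils import until

n_steps = at.scalar("n_steps", dtype="int64")
x0 = at.scalar("x0")

# A do-while scan: keeps doubling x until the result exceeds 100
# (or n_steps runs out)
xs, _ = pytensor.scan(
    lambda xtm1: (2 * xtm1, {}, until(2 * xtm1 > 100)),
    outputs_info=[x0],
    n_steps=n_steps,
)

# Only the last state is requested. With this commit, the save-memory
# rewrites can shrink the while scan's storage buffer in this situation
# instead of keeping every intermediate state.
f = pytensor.function([n_steps, x0], xs[-1])
print(f(1000, 1.0))  # 128.0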
@@ -1182,7 +1182,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
         # these are states that do not feed anything back in the recurrent
         # computation, and hence they do not have an initial state. The scan
         # node however receives an input for each such argument, the input
-        # in this case is just a int saying how many steps of this output we
+        # in this case is just an int saying how many steps of this output we
         # need to store. This input does not have the same dtype, nor is it the same
         # type of tensor as the output, it is always a scalar int.
         new_inputs += [as_tensor_variable(ons) for ons in self.outer_nitsot(inputs)]
...
@@ -28,10 +28,18 @@ from pytensor.graph.features import ReplaceValidate
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import compute_test_value
 from pytensor.graph.replace import clone_replace
-from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
+from pytensor.graph.rewriting.basic import (
+    GraphRewriter,
+    copy_stack_trace,
+    in2out,
+    node_rewriter,
+)
 from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
+from pytensor.graph.rewriting.utils import get_clients_at_depth
 from pytensor.graph.type import HasShape
 from pytensor.graph.utils import InconsistencyError
+from pytensor.raise_op import Assert
+from pytensor.scalar import ScalarConstant
 from pytensor.scan.op import Scan, ScanInfo
 from pytensor.scan.utils import (
     ScanArgs,
@@ -1103,6 +1111,71 @@ def sanitize(x):
         return at.as_tensor_variable(x)
 
 
+@node_rewriter([Scan])
+def while_scan_merge_subtensor_last_element(fgraph, scan_node):
+    """
+    Replace while_scan_out[abs(min(tap)):][-1] by while_scan_out[-1], for
+    recurring outputs, asserting that at least one step occurs.
+    Only the first step can be ensured by the inputs alone (i.e., `n_steps > 0`),
+    as the while scan could abort earlier anytime after that. This means it is
+    not possible to replace while_scan_out[abs(min(tap)):][-i]
+    by while_scan_out[-i], for -i != -1.
+    """
+    op = scan_node.op
+
+    if not op.info.as_while:
+        return None
+
+    # Optimization is not implemented for mit-mot
+    recurrent_outputs = op.outer_mitsot_outs(scan_node.outputs) + op.outer_sitsot_outs(
+        scan_node.outputs
+    )
+    recurrent_outputs_taps_slices = (
+        op.info.mit_sot_in_slices + op.info.sit_sot_in_slices
+    )
+
+    n_steps = scan_node.inputs[0]
+    non_zero_steps_cond = n_steps > 0
+    assert_non_zero_steps_op = Assert("n_steps > 0")
+
+    subtensor_merge_replacements = {}
+
+    # Iterate over all nodes that are two computations below the while scan
+    for node2 in get_clients_at_depth(fgraph, scan_node, depth=2):
+        if not isinstance(node2.op, Subtensor):
+            continue
+
+        node1 = node2.inputs[0].owner
+        if not (node1 and isinstance(node1.op, Subtensor)):
+            continue
+
+        x = node1.inputs[0]
+        if x not in recurrent_outputs:
+            continue
+
+        slice1 = get_idx_list(node1.inputs, node1.op.idx_list)
+        slice2 = get_idx_list(node2.inputs, node2.op.idx_list)
+
+        min_tap = abs(min(recurrent_outputs_taps_slices[recurrent_outputs.index(x)]))
+
+        if (
+            len(slice1) == 1
+            and isinstance(slice1[0], slice)
+            and isinstance(slice1[0].start, aes.ScalarConstant)
+            and slice1[0].start.data == min_tap
+            and slice1[0].stop is None
+            and slice1[0].step is None
+            and len(slice2) == 1
+            and isinstance(slice2[0], aes.ScalarConstant)
+            and slice2[0].data == -1
+        ):
+            out = assert_non_zero_steps_op(x[-1], non_zero_steps_cond)
+            copy_stack_trace([node2.outputs[0], node2.inputs[0]], out)
+            subtensor_merge_replacements[node2.outputs[0]] = out
+
+    return subtensor_merge_replacements
+
+
 @node_rewriter([Scan])
 def save_mem_new_scan(fgraph, node):
     r"""Graph optimizer that reduces scan memory consumption.
@@ -1124,6 +1197,17 @@ def save_mem_new_scan(fgraph, node):
     that SITSOT output. Only the most recently computed timestep ever needs to
     be kept in memory.
 
+    There are two ways in which the Scan buffer size is controlled:
+    1. Each recurring output is saved in an input empty tensor x with the initial
+    state written at x[:abs(min(taps))]. The remaining x[abs(min(taps)):]
+    positions determine how many intermediate results should be stored.
+    This rewrite shortens x[abs(min(taps)):] to the smallest possible size.
+    2. Each non-recurrent output (nit-sot) is associated with a scalar integer
+    input that determines how many steps should be saved in the perform method.
+    This rewrite reduces this number to the smallest possible.
+    The scan perform implementation takes the output sizes into consideration,
+    saving the newest results over the oldest ones whenever the buffer is filled.
+
     """
     if not isinstance(node.op, Scan):
         return False
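The circular-buffer behaviour described in the docstring above can be pictured with plain NumPy (assumed semantics, for illustration only; this is not the actual Scan perform code):

import numpy as np

taps = [-2, -1]      # taps of the recurring output
store_steps = 3      # shortened buffer: abs(min(taps)) + 1 slot

buf = np.empty(store_steps)
buf[:2] = [1.0, 1.0]  # initial states live at x[:abs(min(taps))]

# Run 10 Fibonacci steps, writing the newest result over the oldest slot
for t in range(2, 12):
    xtm2 = buf[(t - 2) % store_steps]
    xtm1 = buf[(t - 1) % store_steps]
    buf[t % store_steps] = xtm1 + xtm2

# Only the most recent values survive, matching "saving the newest results
# over the oldest ones whenever the buffer is filled"
print(buf[t % store_steps])  # 144.0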
@@ -1172,13 +1256,16 @@ def save_mem_new_scan(fgraph, node):
     # index(step) for any output scan actually needs to compute
     # In other words n_steps should be equal to this maximal !
     # Note: if we have a shared variable that gets updated at every step
-    # of the loop, reducing the number of steps will affect the the
-    # value of the shared variable after the loop so we need not to
+    # of the loop, reducing the number of steps will affect the
+    # value of the shared variable after the loop so we cannot
     # change the number of steps in that case. To do this we set
     # global_nsteps to None which is seen as a flag that nothing needs
-    # to be done
+    # to be done.
+    # Note: For simplicity while Scans also have global_nsteps set to None.
+    # All step optimizations require knowing the shape of the output, which
+    # cannot be determined from the inputs alone.
     assert len(node.outputs) >= c_outs
-    if len(node.outputs) == c_outs:
+    if len(node.outputs) == c_outs and not op.info.as_while:
         global_nsteps = {"real": -1, "sym": []}
     else:
         global_nsteps = None
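As a concrete instance of the step-truncation logic above (an illustrative sketch): when every client of an ordinary scan only uses a prefix of an output, the rewrite can lower `n_steps` itself; for a do-while scan this is now skipped by forcing `global_nsteps` to None:

import pytensor
import pytensor.tensor as at

x0 = at.scalar("x0")
ys, _ = pytensor.scan(lambda xtm1: xtm1 + 1, outputs_info=[x0], n_steps=100)

# Only the first 5 entries are ever used, so ScanSaveMem can reduce the
# scan to 5 iterations. A while scan (as_while=True) would not be
# truncated, since its number of steps cannot be known from the inputs.
f = pytensor.function([x0], ys[:5])
print(f(0.0))  # [1. 2. 3. 4. 5.]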
@@ -1257,7 +1344,7 @@ def save_mem_new_scan(fgraph, node):
                 else:
                     # there is a **gotcha** here ! Namely, scan returns an
                     # array that contains the initial state of the output
-                    # as well. Which means that if have a initial state of
+                    # as well. Which means that if y has an initial state of
                     # length 3, and you look for 5 steps you get an output
                     # y of length 8. If you only use y[:5], this does not
                     # mean that you only need to loop for 5 steps but
@@ -1285,9 +1372,9 @@ def save_mem_new_scan(fgraph, node):
 
     # 2.3. Analyze global_nsteps to figure out for how many steps scan
     # needs to iterate
-    if global_nsteps is not None:
-        nw_steps = node.inputs[0]
+    if global_nsteps is None:
+        nw_steps = node.inputs[0]
+    else:
         # there are some symbolic tensors that limit the number of
         # steps
         if len(global_nsteps["sym"]) == 0:
@@ -1303,6 +1390,7 @@ def save_mem_new_scan(fgraph, node):
             real_steps = None
         nw_steps = select_min(select_max(sym_steps, real_steps), node.inputs[0])
 
+        # FIXME: This is not correct. Scan with 0 steps seems to be supported
         # Make sure the ScanSaveMem optimization never makes the new
         # number of steps to be 0 (this could happen, for instance, if
         # the optimization detects that the outputs of the Scan go through
@@ -1310,9 +1398,6 @@ def save_mem_new_scan(fgraph, node):
         # 0 iterations are not supported. Make sure the new number of steps
         # is at least 1.
         nw_steps = select_max(nw_steps, 1)
-    else:
-        nw_steps = node.inputs[0]
-        global_nsteps = None
 
     # 2.4 Loop over the clients again now looking just to see how many
     # intermediate steps to store
@@ -1335,19 +1420,33 @@ def save_mem_new_scan(fgraph, node):
                             store_steps[i] = 0
                             break
 
-                    if i > op_info.n_mit_mot:
-                        length = node.inputs[0] + init_l[i]
-                    else:
-                        try:
-                            length = shape_of[out][0]
-                        except KeyError:
-                            length = out.shape[0]
-                    cf_slice = get_canonical_form_slice(this_slice[0], length)
-                    if isinstance(cf_slice[0], slice):
-                        start = at.extract_constant(cf_slice[0].start)
-                    else:
-                        start = at.extract_constant(cf_slice[0])
+                    # Special case for recurrent outputs where only the last result
+                    # is requested. This is needed for this rewrite to apply to
+                    # do-while Scans at all. Otherwise, `get_canonical_form_slice` in
+                    # the `else` branch would reintroduce a shape dependency on the
+                    # original Scan which would lead this rewrite to abort in the end.
+                    if (
+                        i <= op.info.n_mit_mot
+                        and isinstance(this_slice[0], ScalarConstant)
+                        and this_slice[0].value == -1
+                    ):
+                        start = nw_steps - 1
+                    else:
+                        if i <= op.info.n_mit_mot:
+                            try:
+                                length = shape_of[out][0]
+                            except KeyError:
+                                length = out.shape[0]
+                        else:
+                            length = node.inputs[0] + init_l[i]
+
+                        cf_slice = get_canonical_form_slice(this_slice[0], length)
+
+                        if isinstance(cf_slice[0], slice):
+                            start = at.extract_constant(cf_slice[0].start)
+                        else:
+                            start = at.extract_constant(cf_slice[0])
 
                     if start == 0 or store_steps[i] == 0:
                         store_steps[i] = 0
                     else:
@@ -1498,6 +1597,7 @@ def save_mem_new_scan(fgraph, node):
                     nw_input = expand_empty(_nw_input, nw_steps)
                     nw_inputs[in_idx] = nw_input
                 else:
+                    # FIXME: This is never used
                     nw_input = nw_inputs[in_idx][: (initl + nw_steps)]
 
             elif (
@@ -1554,8 +1654,8 @@ def save_mem_new_scan(fgraph, node):
                     )
                 else:
                     fslice = sanitize(cnf_slice[0])
                 nw_slice = (fslice,) + tuple(old_slices[1:])
 
                 nw_pos = inv_compress_map[idx]
 
                 subtens = Subtensor(nw_slice)
@@ -1603,6 +1703,13 @@ def save_mem_new_scan(fgraph, node):
                                 ),
                             ) + tuple(old_slices[1:])
 
                         else:
-                            position = (
-                                cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
+                            # Special case when only last value is requested
+                            if (
+                                isinstance(old_slices[0], ScalarConstant)
+                                and old_slices[0].value == -1
+                            ):
+                                position = old_slices[0]
+                            else:
+                                position = (
+                                    cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
@@ -2403,6 +2510,12 @@ scan_seqopt1.register(
     position=5,
 )
 
+scan_eqopt2.register(
+    "while_scan_merge_subtensor_last_element",
+    in2out(while_scan_merge_subtensor_last_element, ignore_newtrees=True),
+    "fast_run",
+    "scan",
+)
 
 scan_eqopt2.register(
     "constant_folding_for_scan2",
...
@@ -479,6 +479,7 @@ def local_subtensor_merge(fgraph, node):
     expresses all slices in a canonical form, and then merges them together.
     """
+    from pytensor.scan.op import Scan
 
     if isinstance(node.op, Subtensor):
         u = node.inputs[0]
@@ -489,6 +490,16 @@ def local_subtensor_merge(fgraph, node):
             # slices of the first applied subtensor
             slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
             slices2 = get_idx_list(node.inputs, node.op.idx_list)
+
+            # Don't try to do the optimization on do-while scan outputs,
+            # as it will create a dependency on the shape of the outputs
+            if (
+                x.owner is not None
+                and isinstance(x.owner.op, Scan)
+                and x.owner.op.info.as_while
+            ):
+                return None
+
             # Get the shapes of the vectors !
             try:
                 # try not to introduce new shape into the graph
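To see what this guard prevents (an illustrative sketch): merging `x[2:][-1]` into a single Subtensor requires re-expressing the index relative to the length of `x`, and for a do-while Scan output that length is only known after the loop has run:

import pytensor
import pytensor.tensor as at

x = at.vector("x")
y = x[2:][-1]

# On an ordinary tensor the two Subtensors merge into one slice expressed
# via x.shape[0] (roughly x[x.shape[0] - 1]). That new shape dependency is
# harmless here, but on a do-while Scan output it would tie the graph to
# the buffer's runtime length and block the ScanSaveMem rewrite; hence the
# early return above.
pytensor.dprint(pytensor.function([x], y))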
...
@@ -1395,6 +1395,98 @@ class TestSaveMem:
         rng = np.random.default_rng(utt.fetch_seed())
         my_f(rng.uniform(size=(3,)), 4, np.int64([2, 2, 3]))
 
+    def test_while_scan_taps(self):
+        n_steps = scalar("n_steps", dtype="int64")
+        x0 = vector("x0")
+
+        ys, _ = pytensor.scan(
+            # Fibonacci Sequence
+            lambda xtm2, xtm1: (xtm1 + xtm2, {}, until(xtm1 >= 34)),
+            outputs_info=[{"initial": x0, "taps": [-2, -1]}],
+            n_steps=n_steps,
+        )
+        # Save memory is triggered by choosing only last value
+        y = ys[-1]
+
+        f = pytensor.function(
+            [n_steps, x0], y, mode=get_default_mode().including("scan")
+        )
+
+        np.testing.assert_equal(f(n_steps=1000, x0=[1, 1]), 55)
+        np.testing.assert_equal(f(n_steps=1, x0=[1, 1]), 2)
+        with pytest.raises(AssertionError, match="n_steps > 0"):
+            f(n_steps=0, x0=[1, 1])
+
+        # ys_trace is an Alloc that controls the size of the inner buffer,
+        # it should have shape[0] == 3, with two entries for the taps and one
+        # entry for the intermediate output
+        [scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
+        _, ys_trace = scan_node.inputs
+        debug_fn = pytensor.function(
+            [n_steps, x0], ys_trace.shape[0], accept_inplace=True
+        )
+        assert debug_fn(n_steps=1000, x0=[1, 1]) == 3
+
+    def test_while_scan_map(self):
+        xs = vector("xs")
+
+        ys, _ = pytensor.scan(
+            lambda x: (x + 1, {}, until(x + 1 >= 10)),
+            outputs_info=[None],
+            sequences=[xs],
+        )
+        # Save memory is triggered by choosing only last value
+        y = ys[-1]
+
+        f = pytensor.function([xs], y, mode=get_default_mode().including("scan"))
+        np.testing.assert_equal(f(xs=np.arange(100, dtype=config.floatX)), 10)
+        np.testing.assert_equal(f(xs=[0]), 1)
+        with pytest.raises(IndexError):
+            f(xs=[])
+
+        # len_ys is a numerical input that controls the shape of the inner buffer
+        # It should be 1, as only the last output is needed
+        [scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
+        _, _, len_ys = scan_node.inputs
+        debug_fn = pytensor.function([xs], len_ys, accept_inplace=True)
+        assert debug_fn(xs=np.zeros((100,), dtype=config.floatX)) == 1
+
+    def test_while_scan_taps_and_map(self):
+        x0 = scalar("x0")
+        seq = vector("seq")
+        n_steps = scalar("n_steps", dtype="int64")
+
+        # while loop
+        [ys, zs], _ = pytensor.scan(
+            lambda s, xtm1: ((xtm1 + 1, xtm1 + 1 + s), {}, until(xtm1 >= 99)),
+            sequences=[seq],
+            outputs_info=[x0, None],
+            n_steps=n_steps,
+        )
+        # Save memory is triggered by choosing only last value
+        y = ys[-1]
+        z = zs[-1]
+
+        f = pytensor.function(
+            [x0, seq, n_steps], [y, z], mode=get_default_mode().including("scan")
+        )
+        test_seq = np.zeros(200, dtype=config.floatX)
+        np.testing.assert_allclose(f(x0=0, seq=test_seq, n_steps=200), 100)
+        np.testing.assert_allclose(f(x0=1, seq=test_seq, n_steps=20), 21)
+        np.testing.assert_allclose(f(x0=np.e, seq=test_seq, n_steps=1), np.e + 1)
+        with pytest.raises(AssertionError, match="n_steps > 0"):
+            f(x0=0, seq=test_seq, n_steps=0)
+
+        # Evaluate the shape of ys_trace and len_zs to confirm the rewrite
+        # worked correctly. If a MissingInputError is raised, the rewrite failed.
+        [scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
+        _, _, ys_trace, len_zs = scan_node.inputs
+        debug_fn = pytensor.function(
+            [n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True
+        )
+        stored_ys_steps, stored_zs_steps = debug_fn(n_steps=200)
+        assert stored_ys_steps == 2
+        assert stored_zs_steps == 1
+
 
 def test_inner_replace_dot():
     """
...