Commit c1ecbe0e authored by Ricardo Vieira, committed by Ricardo Vieira

Avoid large allocation for taps of length 1 in ScanSaveMem

Parent f6958407
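The problem, in short: for taps of length 1, the initial value enters Scan's default output buffer as alloc_empty(1 + nsteps)[:1].set(init_value). The leading length-1 dimension of the padded initial value always looks broadcast relative to the buffer, so the old check in _is_default_scan_buffer rejected these buffers, and ScanSaveMem fell back to trimming with a slice, which keeps the original large (1 + nsteps)-sized allocation in the graph. A minimal sketch of a graph that hits this path (a hypothetical repro, not part of the commit):

import pytensor
import pytensor.tensor as pt

x0 = pt.scalar("x0")
# Single-tap (sit-sot) recurrence; only the final value is requested, so
# ScanSaveMem should shrink the default (1 + n_steps)-long buffer to 2 entries.
xs, _ = pytensor.scan(lambda x_tm1: x_tm1 + 1.0, outputs_info=[x0], n_steps=1000)
f = pytensor.function([x0], xs[-1])
assert f(0.0) == 1000.0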
@@ -53,6 +53,7 @@ from pytensor.scan.utils import (
 from pytensor.tensor.basic import (
     Alloc,
     AllocEmpty,
+    atleast_Nd,
     get_scalar_constant_value,
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -1186,8 +1187,8 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
     return subtensor_merge_replacements


-def _is_default_scan_buffer(x: TensorVariable) -> bool:
-    node = x.owner
+def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool:
+    node = final_buffer.owner
     if node is None:
         return False
@@ -1200,8 +1201,10 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool:
     ):
         return False

-    x, y, *_ = node.inputs
-    if not (x.owner is not None and isinstance(x.owner.op, AllocEmpty)):
+    init_buffer, init_value, *_ = node.inputs
+    if not (
+        init_buffer.owner is not None and isinstance(init_buffer.owner.op, AllocEmpty)
+    ):
         return False

     # The value may have been broadcast to fill in the initial taps.
@@ -1218,10 +1221,16 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool:
     #  1. alloc_empty(2 + nsteps)[:2].broadcastable == x.broadcastable
     # But due to laziness we use the slightly more conservative check:
     #  2. alloc_empty(2 + nsteps).broadcastable == x.broadcastable
-    if broadcasted_by(y, x):
-        return False
-
-    return True
+    if taps > 1:
+        return not broadcasted_by(init_value, init_buffer)
+    else:
+        # In this case we know we have alloc_empty(1 + nsteps, ...)[:1].set(init_value)
+        # The first dimension cannot possibly broadcast in the subtensor assignment,
+        # so we exclude it from `broadcasted_by`. To exclude it we squeeze it out,
+        # after adding any other implicit expand_dims. We select into the first entry of
+        # the buffer, to check for potential broadcasting in other dimensions.
+        init_value_ = atleast_Nd(init_value, n=init_buffer.ndim)
+        return not broadcasted_by(init_value_.squeeze(0), init_buffer[0])


 def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
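For intuition, here is the taps == 1 branch replayed on concrete types outside the rewrite (a minimal sketch; the shapes and variable names are illustrative):

import pytensor.tensor as pt
from pytensor.tensor.basic import atleast_Nd

# Stand-in for alloc_empty(1 + nsteps, 5): no broadcastable dims.
init_buffer = pt.empty((11, 5))

# A full-shape initial value: after the implicit expand_dims its pattern is
# (True, False); squeezing the leading dim leaves (False,), which matches
# init_buffer[0], so no other dimension is broadcast -> default buffer.
init_value = pt.vector("init_value", shape=(5,))
init_value_ = atleast_Nd(init_value, n=init_buffer.ndim)
assert init_value_.squeeze(0).type.broadcastable == init_buffer[0].type.broadcastable

# A length-1 initial value is still stretched along the trailing dimension,
# so the buffer is not a default one and must not be recreated.
bad_init_ = atleast_Nd(pt.vector("bad_init", shape=(1,)), n=init_buffer.ndim)
assert bad_init_.squeeze(0).type.broadcastable == (True,)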
@@ -1574,15 +1583,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                 # If the memory for this output has been pre-allocated
                 # before going into the scan op (by an alloc node)
                 if idx < op_info.n_mit_sot + op_info.n_sit_sot:
+                    taps = init_l[i]
                     nw_input = nw_inputs[offset + idx]

                     # Recreate default buffers with new size
-                    if _is_default_scan_buffer(nw_input):
-                        extra_size = 1 if required_orphan else val - init_l[i]
+                    if _is_default_scan_buffer(nw_input, taps):
+                        extra_size = 1 if required_orphan else val - taps
                         nw_input = expand_empty(nw_input.owner.inputs[1], extra_size)
                     # Otherwise, just trim with a slice
                     else:
-                        stop = init_l[i] if required_orphan else val
+                        stop = taps if required_orphan else val
                         nw_input = nw_input[:stop]

                     nw_inputs[offset + idx] = nw_input
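For context, expand_empty (from pytensor.scan.utils) is what actually rebuilds the buffer: it allocates a fresh AllocEmpty whose leading dimension is the initial value's plus the requested extra size, and sets the initial value into the front. The shrunken buffer therefore replaces the original allocation rather than slicing into it. A small sketch under those assumptions (shapes are illustrative):

import numpy as np
import pytensor.tensor as pt
from pytensor.scan.utils import expand_empty

init_core = pt.vector("init", shape=(5,))
init = pt.shape_padleft(init_core)  # shape (1, 5), like a single-tap initial state
buf = expand_empty(init, 3)         # empty (1 + 3, 5) buffer with init written to buf[:1]
out = buf.eval({init_core: np.zeros(5, dtype=init_core.dtype)})
assert out.shape == (4, 5)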
@@ -1626,14 +1636,13 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                 # val == 0 means that we want to keep all intermediate
                 # results for that state, including the initial values.
                 if idx < op_info.n_mit_sot + op_info.n_sit_sot:
+                    taps = init_l[op_info.n_mit_mot + idx]
                     in_idx = offset + idx
                     nw_input = nw_inputs[in_idx]
-                    if _is_default_scan_buffer(nw_input):
+                    if _is_default_scan_buffer(nw_input, taps):
                         nw_input = expand_empty(nw_input.owner.inputs[1], nw_steps)
                     else:
-                        # Number of steps in the initial state
-                        init_l_pt = pt.as_tensor(init_l[op_info.n_mit_mot + idx])
-                        nw_input = nw_input[: (init_l_pt + nw_steps)]
+                        nw_input = nw_input[: (taps + nw_steps)]
                     nw_inputs[in_idx] = nw_input
             elif (
......
@@ -96,9 +96,11 @@ def broadcasted_by(x: TensorVariable, y: TensorVariable) -> bool:
     """
     bx = x.type.broadcastable
     by = y.type.broadcastable
-    if len(bx) < len(by):
+    bx_len = len(bx)
+    by_len = len(by)
+    if bx_len < by_len:
         return True
-    bx = bx[-len(by) :]
+    bx = bx[bx_len - by_len :]
     return any(bx_dim and not by_dim for bx_dim, by_dim in zip(bx, by, strict=True))
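The subtle fix in this hunk is the slice: when y is 0d, the old `bx[-len(by) :]` evaluated as `bx[-0:]`, i.e. all of `bx`, instead of an empty tuple. A self-contained mirror of the updated helper with a couple of illustrative calls (inlined here since the diff elides the file header):

import pytensor.tensor as pt

def broadcasted_by(x, y) -> bool:
    # True iff x would be stretched (broadcast) by y in an elemwise operation.
    bx = x.type.broadcastable
    by = y.type.broadcastable
    bx_len = len(bx)
    by_len = len(by)
    if bx_len < by_len:
        return True
    bx = bx[bx_len - by_len :]
    return any(bx_dim and not by_dim for bx_dim, by_dim in zip(bx, by, strict=True))

row = pt.row("row")      # broadcastable pattern (True, False)
mat = pt.matrix("mat")   # broadcastable pattern (False, False)
assert broadcasted_by(row, mat)      # row's leading dim is stretched to match mat
assert not broadcasted_by(mat, row)  # mat is never stretched by row
assert not broadcasted_by(mat, pt.scalar("s"))  # 0d y: bx[-0:] would have kept all dims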
......
@@ -9,13 +9,14 @@ from pytensor.compile.io import In
 from pytensor.compile.mode import get_default_mode
 from pytensor.configdefaults import config
 from pytensor.gradient import grad, jacobian
-from pytensor.graph.basic import Constant, equal_computations
+from pytensor.graph.basic import Constant, ancestors, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.replace import clone_replace
 from pytensor.scan.op import Scan
 from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge
 from pytensor.scan.utils import until
 from pytensor.tensor import stack
+from pytensor.tensor.basic import AllocEmpty
 from pytensor.tensor.blas import Dot22
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.math import Dot, dot, sigmoid, tanh
@@ -1207,7 +1208,7 @@ class TestScanInplaceOptimizer:

 class TestSaveMem:
-    mode = get_default_mode().including("scan_save_mem")
+    mode = get_default_mode().including("scan_save_mem").excluding("scan_pushout")

     def test_save_mem(self):
         rng = np.random.default_rng(utt.fetch_seed())
@@ -1371,7 +1372,7 @@ class TestSaveMem:
         )

     def test_save_mem_store_steps(self):
-        def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
+        def step(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
             return (
                 u_t + 1.0,
                 u_t + 2.0,
@@ -1388,7 +1389,7 @@ class TestSaveMem:
         x30 = vector("x30")
         x40 = scalar("x40")
         [x1, x2, x3, x4, x5, x6, x7], updates = scan(
-            f_rnn,
+            step,
             u,
             [
                 None,
@@ -1404,7 +1405,7 @@ class TestSaveMem:
             go_backwards=False,
         )

-        f2 = function(
+        f = function(
             [u, x10, x20, x30, x40],
             [x1[-7], x2[-3:-1], x3[-6:], x4[-1], x5[-1]],
             updates=updates,
@@ -1417,13 +1418,51 @@ class TestSaveMem:
         v_u = rng.uniform(-5.0, 5.0, size=(20,))

         # compute the output in numpy
-        tx1, tx2, tx3, tx4, tx5 = f2(v_u, [0, 0], 0, [0, 0], 0)
-
-        utt.assert_allclose(tx1, v_u[-7] + 1.0)
-        utt.assert_allclose(tx2, v_u[-3:-1] + 2.0)
-        utt.assert_allclose(tx3, v_u[-6:] + 3.0)
-        utt.assert_allclose(tx4, v_u[-1] + 4.0)
-        utt.assert_allclose(tx5, v_u[-1] + 5.0)
+        tx1, tx2, tx3, tx4, tx5 = f(v_u, [0, 0], 0, [0, 0], 0)
+        rtol = 1e-7 if config.floatX == "float64" else 1e-6
+        np.testing.assert_allclose(tx1, v_u[-7] + 1.0, rtol=rtol)
+        np.testing.assert_allclose(tx2, v_u[-3:-1] + 2.0, rtol=rtol)
+        np.testing.assert_allclose(tx3, v_u[-6:] + 3.0, rtol=rtol)
+        np.testing.assert_allclose(tx4, v_u[-1] + 4.0, rtol=rtol)
+        np.testing.assert_allclose(tx5, v_u[-1] + 5.0, rtol=rtol)
+
+        # Confirm reduction in buffer sizes
+        [scan_node] = [
+            node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
+        ]
+        # x6 and x7 are dropped because they are not used
+        [n_steps, seq, x4_buffer, x5_buffer, x1_len, x2_len, x3_len] = scan_node.inputs
+        [x4_underlying_alloc] = [
+            var
+            for var in ancestors([x4_buffer])
+            if var.owner and isinstance(var.owner.op, AllocEmpty)
+        ]
+        [x5_underlying_alloc] = [
+            var
+            for var in ancestors([x5_buffer])
+            if var.owner and isinstance(var.owner.op, AllocEmpty)
+        ]
+        buffer_lengths = pytensor.function(
+            [u, x10, x20, x30, x40],
+            [
+                x1_len,
+                x2_len,
+                x3_len,
+                x4_underlying_alloc.shape[0],
+                x5_underlying_alloc.shape[0],
+            ],
+            accept_inplace=True,
+            on_unused_input="ignore",
+            allow_input_downcast=True,
+        )(v_u, [0, 0], 0, [0, 0], 0)
+        # ScanSaveMem keeps +1 entries to handle taps with preallocated outputs
+        assert [int(i) for i in buffer_lengths] == [
+            7,  # entry -7 of a map variable is kept, we need at least that many
+            3,  # entries [-3, -2] of a map variable are kept, we need at least 3
+            6,  # last six entries of a map variable are kept
+            2 + 1,  # last entry of a double tap variable is kept
+            1 + 1,  # last entry of a single tap variable is kept
+        ]

     def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
         var = pt.ones(())
......