提交 c1ecbe0e 作者: Ricardo Vieira 提交者: Ricardo Vieira

Avoid large allocation for taps of length 1 in ScanSaveMem

上级 f6958407
...@@ -53,6 +53,7 @@ from pytensor.scan.utils import ( ...@@ -53,6 +53,7 @@ from pytensor.scan.utils import (
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
AllocEmpty, AllocEmpty,
atleast_Nd,
get_scalar_constant_value, get_scalar_constant_value,
) )
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
...@@ -1186,8 +1187,8 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node): ...@@ -1186,8 +1187,8 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
return subtensor_merge_replacements return subtensor_merge_replacements
def _is_default_scan_buffer(x: TensorVariable) -> bool: def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool:
node = x.owner node = final_buffer.owner
if node is None: if node is None:
return False return False
...@@ -1200,8 +1201,10 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool: ...@@ -1200,8 +1201,10 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool:
): ):
return False return False
x, y, *_ = node.inputs init_buffer, init_value, *_ = node.inputs
if not (x.owner is not None and isinstance(x.owner.op, AllocEmpty)): if not (
init_buffer.owner is not None and isinstance(init_buffer.owner.op, AllocEmpty)
):
return False return False
# The value may have been broadcast to fill in the initial taps. # The value may have been broadcast to fill in the initial taps.
...@@ -1218,10 +1221,16 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool: ...@@ -1218,10 +1221,16 @@ def _is_default_scan_buffer(x: TensorVariable) -> bool:
# 1. alloc_empty(2 + nsteps)[:2].broadcastable == x.broadcastable # 1. alloc_empty(2 + nsteps)[:2].broadcastable == x.broadcastable
# But due to laziness we use the slightly more conservative check: # But due to laziness we use the slightly more conservative check:
# 2. alloc_empty(2 + nsteps).broadcastable == x.broadcastable # 2. alloc_empty(2 + nsteps).broadcastable == x.broadcastable
if broadcasted_by(y, x): if taps > 1:
return False return not broadcasted_by(init_value, init_buffer)
else:
return True # In this case we know we have alloc_empty(1 + nsteps, ...)[:1].set(init_value)
# The first dimension cannot possibly broadcast in the subtensor assignment,
# so we exclude it from `broadcasted_by`. To exclude it we squeeze it out,
# after adding any other implicit expand_dims. We select into the first entry of
# the buffer, to check for potential broadcasting in other dimensions.
init_value_ = atleast_Nd(init_value, n=init_buffer.ndim)
return not broadcasted_by(init_value_.squeeze(0), init_buffer[0])
def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool): def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
...@@ -1574,15 +1583,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1574,15 +1583,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
# If the memory for this output has been pre-allocated # If the memory for this output has been pre-allocated
# before going into the scan op (by an alloc node) # before going into the scan op (by an alloc node)
if idx < op_info.n_mit_sot + op_info.n_sit_sot: if idx < op_info.n_mit_sot + op_info.n_sit_sot:
taps = init_l[i]
nw_input = nw_inputs[offset + idx] nw_input = nw_inputs[offset + idx]
# Recreate default buffers with new size # Recreate default buffers with new size
if _is_default_scan_buffer(nw_input): if _is_default_scan_buffer(nw_input, taps):
extra_size = 1 if required_orphan else val - init_l[i] extra_size = 1 if required_orphan else val - taps
nw_input = expand_empty(nw_input.owner.inputs[1], extra_size) nw_input = expand_empty(nw_input.owner.inputs[1], extra_size)
# Otherwise, just trim with a slice # Otherwise, just trim with a slice
else: else:
stop = init_l[i] if required_orphan else val stop = taps if required_orphan else val
nw_input = nw_input[:stop] nw_input = nw_input[:stop]
nw_inputs[offset + idx] = nw_input nw_inputs[offset + idx] = nw_input
...@@ -1626,14 +1636,13 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1626,14 +1636,13 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
# val == 0 means that we want to keep all intermediate # val == 0 means that we want to keep all intermediate
# results for that state, including the initial values. # results for that state, including the initial values.
if idx < op_info.n_mit_sot + op_info.n_sit_sot: if idx < op_info.n_mit_sot + op_info.n_sit_sot:
taps = init_l[op_info.n_mit_mot + idx]
in_idx = offset + idx in_idx = offset + idx
nw_input = nw_inputs[in_idx] nw_input = nw_inputs[in_idx]
if _is_default_scan_buffer(nw_input): if _is_default_scan_buffer(nw_input, taps):
nw_input = expand_empty(nw_input.owner.inputs[1], nw_steps) nw_input = expand_empty(nw_input.owner.inputs[1], nw_steps)
else: else:
# Number of steps in the initial state nw_input = nw_input[: (taps + nw_steps)]
init_l_pt = pt.as_tensor(init_l[op_info.n_mit_mot + idx])
nw_input = nw_input[: (init_l_pt + nw_steps)]
nw_inputs[in_idx] = nw_input nw_inputs[in_idx] = nw_input
elif ( elif (
......
...@@ -96,9 +96,11 @@ def broadcasted_by(x: TensorVariable, y: TensorVariable) -> bool: ...@@ -96,9 +96,11 @@ def broadcasted_by(x: TensorVariable, y: TensorVariable) -> bool:
""" """
bx = x.type.broadcastable bx = x.type.broadcastable
by = y.type.broadcastable by = y.type.broadcastable
if len(bx) < len(by): bx_len = len(bx)
by_len = len(by)
if bx_len < by_len:
return True return True
bx = bx[-len(by) :] bx = bx[bx_len - by_len :]
return any(bx_dim and not by_dim for bx_dim, by_dim in zip(bx, by, strict=True)) return any(bx_dim and not by_dim for bx_dim, by_dim in zip(bx, by, strict=True))
......
...@@ -9,13 +9,14 @@ from pytensor.compile.io import In ...@@ -9,13 +9,14 @@ from pytensor.compile.io import In
from pytensor.compile.mode import get_default_mode from pytensor.compile.mode import get_default_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import grad, jacobian from pytensor.gradient import grad, jacobian
from pytensor.graph.basic import Constant, equal_computations from pytensor.graph.basic import Constant, ancestors, equal_computations
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge
from pytensor.scan.utils import until from pytensor.scan.utils import until
from pytensor.tensor import stack from pytensor.tensor import stack
from pytensor.tensor.basic import AllocEmpty
from pytensor.tensor.blas import Dot22 from pytensor.tensor.blas import Dot22
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import Dot, dot, sigmoid, tanh from pytensor.tensor.math import Dot, dot, sigmoid, tanh
...@@ -1207,7 +1208,7 @@ class TestScanInplaceOptimizer: ...@@ -1207,7 +1208,7 @@ class TestScanInplaceOptimizer:
class TestSaveMem: class TestSaveMem:
mode = get_default_mode().including("scan_save_mem") mode = get_default_mode().including("scan_save_mem").excluding("scan_pushout")
def test_save_mem(self): def test_save_mem(self):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
...@@ -1371,7 +1372,7 @@ class TestSaveMem: ...@@ -1371,7 +1372,7 @@ class TestSaveMem:
) )
def test_save_mem_store_steps(self): def test_save_mem_store_steps(self):
def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1): def step(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
return ( return (
u_t + 1.0, u_t + 1.0,
u_t + 2.0, u_t + 2.0,
...@@ -1388,7 +1389,7 @@ class TestSaveMem: ...@@ -1388,7 +1389,7 @@ class TestSaveMem:
x30 = vector("x30") x30 = vector("x30")
x40 = scalar("x40") x40 = scalar("x40")
[x1, x2, x3, x4, x5, x6, x7], updates = scan( [x1, x2, x3, x4, x5, x6, x7], updates = scan(
f_rnn, step,
u, u,
[ [
None, None,
...@@ -1404,7 +1405,7 @@ class TestSaveMem: ...@@ -1404,7 +1405,7 @@ class TestSaveMem:
go_backwards=False, go_backwards=False,
) )
f2 = function( f = function(
[u, x10, x20, x30, x40], [u, x10, x20, x30, x40],
[x1[-7], x2[-3:-1], x3[-6:], x4[-1], x5[-1]], [x1[-7], x2[-3:-1], x3[-6:], x4[-1], x5[-1]],
updates=updates, updates=updates,
...@@ -1417,13 +1418,51 @@ class TestSaveMem: ...@@ -1417,13 +1418,51 @@ class TestSaveMem:
v_u = rng.uniform(-5.0, 5.0, size=(20,)) v_u = rng.uniform(-5.0, 5.0, size=(20,))
# compute the output in numpy # compute the output in numpy
tx1, tx2, tx3, tx4, tx5 = f2(v_u, [0, 0], 0, [0, 0], 0) tx1, tx2, tx3, tx4, tx5 = f(v_u, [0, 0], 0, [0, 0], 0)
rtol = 1e-7 if config.floatX == "float64" else 1e-6
utt.assert_allclose(tx1, v_u[-7] + 1.0) np.testing.assert_allclose(tx1, v_u[-7] + 1.0, rtol=rtol)
utt.assert_allclose(tx2, v_u[-3:-1] + 2.0) np.testing.assert_allclose(tx2, v_u[-3:-1] + 2.0, rtol=rtol)
utt.assert_allclose(tx3, v_u[-6:] + 3.0) np.testing.assert_allclose(tx3, v_u[-6:] + 3.0, rtol=rtol)
utt.assert_allclose(tx4, v_u[-1] + 4.0) np.testing.assert_allclose(tx4, v_u[-1] + 4.0, rtol=rtol)
utt.assert_allclose(tx5, v_u[-1] + 5.0) np.testing.assert_allclose(tx5, v_u[-1] + 5.0, rtol=rtol)
# Confirm reduction in buffer sizes
[scan_node] = [
node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
]
# x6 and x7 are dropped because they are not used
[n_steps, seq, x4_buffer, x5_buffer, x1_len, x2_len, x3_len] = scan_node.inputs
[x4_underlying_alloc] = [
var
for var in ancestors([x4_buffer])
if var.owner and isinstance(var.owner.op, AllocEmpty)
]
[x5_underlying_alloc] = [
var
for var in ancestors([x5_buffer])
if var.owner and isinstance(var.owner.op, AllocEmpty)
]
buffer_lengths = pytensor.function(
[u, x10, x20, x30, x40],
[
x1_len,
x2_len,
x3_len,
x4_underlying_alloc.shape[0],
x5_underlying_alloc.shape[0],
],
accept_inplace=True,
on_unused_input="ignore",
allow_input_downcast=True,
)(v_u, [0, 0], 0, [0, 0], 0)
# ScanSaveMem keeps +1 entries to handle taps with preallocated outputs
assert [int(i) for i in buffer_lengths] == [
7, # entry -7 of a map variable is kept, we need at least that many
3, # entries [-3, -2] of a map variable are kept, we need at least 3
6, # last six entries of a map variable are kept
2 + 1, # last entry of a double tap variable is kept
1 + 1, # last entry of a single tap variable is kept
]
def test_savemem_does_not_duplicate_number_of_scan_nodes(self): def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
var = pt.ones(()) var = pt.ones(())
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论