Commit a24f5345, authored by Ricardo Vieira, committed by Ricardo Vieira

Fix bug in ScanSaveMem with broadcasted initial value

Parent c822a8e6
......@@ -58,7 +58,11 @@ from pytensor.tensor.basic import (
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Dot, dot, maximum, minimum
from pytensor.tensor.rewriting.basic import constant_folding, local_useless_switch
from pytensor.tensor.rewriting.basic import (
broadcasted_by,
constant_folding,
local_useless_switch,
)
from pytensor.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs
from pytensor.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink
from pytensor.tensor.shape import shape
......@@ -1182,6 +1186,44 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
return subtensor_merge_replacements
def _is_default_scan_buffer(x: TensorVariable) -> bool:
    """Return True if `x` is an untouched default Scan output buffer.

    The default buffer built by Scan has the form
    ``alloc_empty(n_taps + n_steps)[:n_taps].set(init)``: a ``set_subtensor``
    writing the initial taps into an ``AllocEmpty``, where the initial value
    is not broadcast by the assignment.
    """
    apply_node = x.owner
    if apply_node is None:
        return False

    set_op = apply_node.op
    looks_like_set_of_leading_slice = (
        isinstance(set_op, IncSubtensor)
        and set_op.set_instead_of_inc
        and set_op.idx_list == [slice(None, ps.int64)]
    )
    if not looks_like_set_of_leading_slice:
        return False

    buffer, init_value, *_ = apply_node.inputs
    if buffer.owner is None or not isinstance(buffer.owner.op, AllocEmpty):
        return False

    # The value may have been broadcast to fill in the initial taps.
    # If the user specified outputs as:
    #   x = scalar(); init = alloc(x, 2);
    #   outputs_info=[init, taps=(-2, -1)]
    # Scan will generate an initial buffer that looks like
    #   alloc_empty(2 + nsteps)[:2].set(alloc(x, 2))
    # PyTensor will then rewrite it as:
    #   alloc_empty(2 + nsteps)[:2].set(x)
    # When the initial value (x) is being broadcast by the set_subtensor
    # we can't recreate a newly sized buffer working with x alone.
    # We want to check that:
    #   1. alloc_empty(2 + nsteps)[:2].broadcastable == x.broadcastable
    # But due to laziness we use the slightly more conservative check:
    #   2. alloc_empty(2 + nsteps).broadcastable == x.broadcastable
    return not broadcasted_by(init_value, buffer)
def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
r"""Graph optimizer that reduces scan memory consumption.
......@@ -1520,51 +1562,28 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
# 3.2 check orphane outputs to see if we can eliminate any
required, not_required = scan_can_remove_outs(node.op, orphane_outs)
# 3.3. compose replace pairs for those nodes that need not
# to store everything in memory ( or ar orphane and required
# by the inner function .. )
# 3.3. compose replace pairs for those nodes that need not store everything in memory
# (or are orphan but required by the inner function)
replaced_outs = []
offset = 1 + op_info.n_seqs + op_info.n_mit_mot
for idx, _val in enumerate(store_steps[op_info.n_mit_mot :]):
for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
i = idx + op_info.n_mit_mot
if not (isinstance(_val, int) and _val <= 0 and i not in required):
if idx + op_info.n_mit_mot in required:
val = 1
else:
val = _val
if not (isinstance(val, int) and val <= 0 and i not in required):
required_orphan = idx + op_info.n_mit_mot in required
# If the memory for this output has been pre-allocated
# before going into the scan op (by an alloc node)
if idx < op_info.n_mit_sot + op_info.n_sit_sot:
# In case the input is still an alloc node, we
# actually have two options:
# a) the input is a set_subtensor, in that case we
# can replace the initial tensor by a slice,
# b) it is not, and we simply take a slice of it.
# TODO: commit change below with Razvan
if (
nw_inputs[offset + idx].owner
and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor)
and nw_inputs[offset + idx].owner.op.set_instead_of_inc
and isinstance(
nw_inputs[offset + idx].owner.op.idx_list[0], slice
)
# Don't try to create a smart Alloc, if set_subtensor is broadcasting the fill value
# As it happens in set_subtensor(empty(2)[:], 0)
and not (
nw_inputs[offset + idx].ndim
> nw_inputs[offset + idx].owner.inputs[1].ndim
)
):
_nw_input = nw_inputs[offset + idx].owner.inputs[1]
cval = pt.as_tensor_variable(val)
initl = pt.as_tensor_variable(init_l[i])
tmp_idx = pt.switch(cval < initl, cval + initl, cval - initl)
nw_input = expand_empty(_nw_input, tmp_idx)
nw_input = nw_inputs[offset + idx]
# Recreate default buffers with new size
if _is_default_scan_buffer(nw_input):
extra_size = 1 if required_orphan else val - init_l[i]
nw_input = expand_empty(nw_input.owner.inputs[1], extra_size)
# Otherwise, just trim with a slice
else:
tmp = pt.as_tensor_variable(val)
initl = pt.as_tensor_variable(init_l[i])
tmp = maximum(tmp, initl)
nw_input = nw_inputs[offset + idx][:tmp]
stop = init_l[i] if required_orphan else val
nw_input = nw_input[:stop]
nw_inputs[offset + idx] = nw_input
replaced_outs.append(op_info.n_mit_mot + idx)
......@@ -1588,7 +1607,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
+ op_info.n_shared_outs
)
if nw_inputs[pos] == node.inputs[0]:
nw_inputs[pos] = val
nw_inputs[pos] = 1 if required_orphan else val
odx = op_info.n_mit_mot + idx
replaced_outs.append(odx)
old_outputs += [
......@@ -1600,8 +1619,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
],
)
]
# 3.4. Recompute inputs for everything else based on the new
# number of steps
# 3.4. Recompute inputs for everything else based on the new number of steps
if global_nsteps is not None:
for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
if val == 0:
......@@ -1609,28 +1627,14 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
# results for that state, including the initial values.
if idx < op_info.n_mit_sot + op_info.n_sit_sot:
in_idx = offset + idx
# Number of steps in the initial state
initl = init_l[op_info.n_mit_mot + idx]
# If the initial buffer has the form
# inc_subtensor(zeros(...)[...], _nw_input)
# we want to make the zeros tensor as small as
# possible (nw_steps + initl), and call
# inc_subtensor on that instead.
# Otherwise, simply take 0:(nw_steps+initl).
if (
nw_inputs[in_idx].owner
and isinstance(nw_inputs[in_idx].owner.op, IncSubtensor)
and isinstance(
nw_inputs[in_idx].owner.op.idx_list[0], slice
)
):
_nw_input = nw_inputs[in_idx].owner.inputs[1]
nw_input = expand_empty(_nw_input, nw_steps)
nw_inputs[in_idx] = nw_input
nw_input = nw_inputs[in_idx]
if _is_default_scan_buffer(nw_input):
nw_input = expand_empty(nw_input.owner.inputs[1], nw_steps)
else:
# FIXME: This is never used
nw_input = nw_inputs[in_idx][: (initl + nw_steps)]
# Number of steps in the initial state
init_l_pt = pt.as_tensor(init_l[op_info.n_mit_mot + idx])
nw_input = nw_input[: (init_l_pt + nw_steps)]
nw_inputs[in_idx] = nw_input
elif (
idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot
......
......@@ -1634,21 +1634,33 @@ class TestSaveMem:
assert stored_ys_steps == 2
assert stored_zs_steps == 1
@pytest.mark.parametrize("val_ndim", (0, 1))
@pytest.mark.parametrize("keep_beginning", (False, True))
def test_broadcasted_init(self, keep_beginning, val_ndim):
    """Regression test for the scan save-mem rewrite with a broadcasted init.

    The rewrite used to wrongly slice on the unbroadcasted value when the
    initial state was an ``alloc`` broadcasting a smaller ``val``.
    """
    val_shape = (1,) * val_ndim
    val = pt.tensor("val", shape=val_shape)
    val_test = np.zeros(val_shape, dtype=val.dtype)

    # Initial taps are a broadcast of `val` (the problematic case).
    init = pt.full((2,), val)
    ys, _ = pytensor.scan(
        fn=lambda *args: pt.add(*args),
        outputs_info=[{"initial": init, "taps": (-2, -1)}],
        n_steps=100,
    )

    out = ys[:-50] if keep_beginning else ys[-50:]
    fn = pytensor.function([val], out, mode=self.mode)
    assert fn(val_test).shape == (50,)

    # Check that the rewrite actually shrank the scan buffer.
    [scan_node] = (n for n in fn.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
    _, ys_trace = scan_node.inputs
    buffer_size_fn = pytensor.function(
        [val], ys_trace.shape[0], accept_inplace=True
    )
    # Parenthesize the conditional: `==` binds tighter than `if`/`else`,
    # so the unparenthesized `assert x == 52 if keep_beginning else 50`
    # degenerated to asserting the always-truthy `50` when
    # keep_beginning was False.
    assert buffer_size_fn(val_test) == (52 if keep_beginning else 50)
def test_inner_replace_dot():
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment