提交 a24f5345 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in ScanSaveMem with broadcasted initial value

上级 c822a8e6
...@@ -58,7 +58,11 @@ from pytensor.tensor.basic import ( ...@@ -58,7 +58,11 @@ from pytensor.tensor.basic import (
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Dot, dot, maximum, minimum from pytensor.tensor.math import Dot, dot, maximum, minimum
from pytensor.tensor.rewriting.basic import constant_folding, local_useless_switch from pytensor.tensor.rewriting.basic import (
broadcasted_by,
constant_folding,
local_useless_switch,
)
from pytensor.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs from pytensor.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs
from pytensor.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink from pytensor.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink
from pytensor.tensor.shape import shape from pytensor.tensor.shape import shape
...@@ -1182,6 +1186,44 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node): ...@@ -1182,6 +1186,44 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
return subtensor_merge_replacements return subtensor_merge_replacements
def _is_default_scan_buffer(x: TensorVariable) -> bool:
    """Return True if `x` is an untouched default Scan output buffer.

    A default buffer is the graph ``alloc_empty(...)[:k].set(init)``: a
    ``set_subtensor`` over a leading ``[:stop]`` slice whose sliced target is
    produced by ``AllocEmpty``, and whose fill value is not broadcast by the
    assignment.  Only such buffers can be safely rebuilt at a different size
    from the fill value alone.
    """
    node = x.owner
    if node is None:
        return False

    op = node.op
    looks_like_set_slice = (
        isinstance(op, IncSubtensor)
        and op.set_instead_of_inc
        and op.idx_list == [slice(None, ps.int64)]
    )
    if not looks_like_set_slice:
        return False

    buffer, fill_value, *_ = node.inputs
    if buffer.owner is None or not isinstance(buffer.owner.op, AllocEmpty):
        return False

    # The fill value may have been broadcast to populate the initial taps.
    # For example, with:
    #   x = scalar(); init = alloc(x, 2);
    #   outputs_info=[init, taps=(-2, -1)]
    # Scan first builds
    #   alloc_empty(2 + nsteps)[:2].set(alloc(x, 2))
    # which PyTensor later rewrites into
    #   alloc_empty(2 + nsteps)[:2].set(x)
    # When the set_subtensor broadcasts the fill value (x), a resized buffer
    # cannot be reconstructed from x alone.  Ideally we would verify:
    #   1. alloc_empty(2 + nsteps)[:2].broadcastable == x.broadcastable
    # but out of laziness we apply the slightly stricter check:
    #   2. alloc_empty(2 + nsteps).broadcastable == x.broadcastable
    return not broadcasted_by(fill_value, buffer)
def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool): def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
r"""Graph optimizer that reduces scan memory consumption. r"""Graph optimizer that reduces scan memory consumption.
...@@ -1520,51 +1562,28 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1520,51 +1562,28 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
# 3.2 check orphane outputs to see if we can eliminate any # 3.2 check orphane outputs to see if we can eliminate any
required, not_required = scan_can_remove_outs(node.op, orphane_outs) required, not_required = scan_can_remove_outs(node.op, orphane_outs)
# 3.3. compose replace pairs for those nodes that need not
# to store everything in memory ( or ar orphane and required # 3.3. compose replace pairs for those nodes that need not store everything in memory
# by the inner function .. ) # (or ar orphan but required by the inner function)
replaced_outs = [] replaced_outs = []
offset = 1 + op_info.n_seqs + op_info.n_mit_mot offset = 1 + op_info.n_seqs + op_info.n_mit_mot
for idx, _val in enumerate(store_steps[op_info.n_mit_mot :]): for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
i = idx + op_info.n_mit_mot i = idx + op_info.n_mit_mot
if not (isinstance(_val, int) and _val <= 0 and i not in required): if not (isinstance(val, int) and val <= 0 and i not in required):
if idx + op_info.n_mit_mot in required: required_orphan = idx + op_info.n_mit_mot in required
val = 1
else:
val = _val
# If the memory for this output has been pre-allocated # If the memory for this output has been pre-allocated
# before going into the scan op (by an alloc node) # before going into the scan op (by an alloc node)
if idx < op_info.n_mit_sot + op_info.n_sit_sot: if idx < op_info.n_mit_sot + op_info.n_sit_sot:
# In case the input is still an alloc node, we nw_input = nw_inputs[offset + idx]
# actually have two options:
# a) the input is a set_subtensor, in that case we # Recreate default buffers with new size
# can replace the initial tensor by a slice, if _is_default_scan_buffer(nw_input):
# b) it is not, and we simply take a slice of it. extra_size = 1 if required_orphan else val - init_l[i]
# TODO: commit change below with Razvan nw_input = expand_empty(nw_input.owner.inputs[1], extra_size)
if ( # Otherwise, just trim with a slice
nw_inputs[offset + idx].owner
and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor)
and nw_inputs[offset + idx].owner.op.set_instead_of_inc
and isinstance(
nw_inputs[offset + idx].owner.op.idx_list[0], slice
)
# Don't try to create a smart Alloc, if set_subtensor is broadcasting the fill value
# As it happens in set_subtensor(empty(2)[:], 0)
and not (
nw_inputs[offset + idx].ndim
> nw_inputs[offset + idx].owner.inputs[1].ndim
)
):
_nw_input = nw_inputs[offset + idx].owner.inputs[1]
cval = pt.as_tensor_variable(val)
initl = pt.as_tensor_variable(init_l[i])
tmp_idx = pt.switch(cval < initl, cval + initl, cval - initl)
nw_input = expand_empty(_nw_input, tmp_idx)
else: else:
tmp = pt.as_tensor_variable(val) stop = init_l[i] if required_orphan else val
initl = pt.as_tensor_variable(init_l[i]) nw_input = nw_input[:stop]
tmp = maximum(tmp, initl)
nw_input = nw_inputs[offset + idx][:tmp]
nw_inputs[offset + idx] = nw_input nw_inputs[offset + idx] = nw_input
replaced_outs.append(op_info.n_mit_mot + idx) replaced_outs.append(op_info.n_mit_mot + idx)
...@@ -1588,7 +1607,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1588,7 +1607,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
+ op_info.n_shared_outs + op_info.n_shared_outs
) )
if nw_inputs[pos] == node.inputs[0]: if nw_inputs[pos] == node.inputs[0]:
nw_inputs[pos] = val nw_inputs[pos] = 1 if required_orphan else val
odx = op_info.n_mit_mot + idx odx = op_info.n_mit_mot + idx
replaced_outs.append(odx) replaced_outs.append(odx)
old_outputs += [ old_outputs += [
...@@ -1600,8 +1619,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1600,8 +1619,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
], ],
) )
] ]
# 3.4. Recompute inputs for everything else based on the new # 3.4. Recompute inputs for everything else based on the new number of steps
# number of steps
if global_nsteps is not None: if global_nsteps is not None:
for idx, val in enumerate(store_steps[op_info.n_mit_mot :]): for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
if val == 0: if val == 0:
...@@ -1609,28 +1627,14 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1609,28 +1627,14 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
# results for that state, including the initial values. # results for that state, including the initial values.
if idx < op_info.n_mit_sot + op_info.n_sit_sot: if idx < op_info.n_mit_sot + op_info.n_sit_sot:
in_idx = offset + idx in_idx = offset + idx
# Number of steps in the initial state nw_input = nw_inputs[in_idx]
initl = init_l[op_info.n_mit_mot + idx] if _is_default_scan_buffer(nw_input):
nw_input = expand_empty(nw_input.owner.inputs[1], nw_steps)
# If the initial buffer has the form
# inc_subtensor(zeros(...)[...], _nw_input)
# we want to make the zeros tensor as small as
# possible (nw_steps + initl), and call
# inc_subtensor on that instead.
# Otherwise, simply take 0:(nw_steps+initl).
if (
nw_inputs[in_idx].owner
and isinstance(nw_inputs[in_idx].owner.op, IncSubtensor)
and isinstance(
nw_inputs[in_idx].owner.op.idx_list[0], slice
)
):
_nw_input = nw_inputs[in_idx].owner.inputs[1]
nw_input = expand_empty(_nw_input, nw_steps)
nw_inputs[in_idx] = nw_input
else: else:
# FIXME: This is never used # Number of steps in the initial state
nw_input = nw_inputs[in_idx][: (initl + nw_steps)] init_l_pt = pt.as_tensor(init_l[op_info.n_mit_mot + idx])
nw_input = nw_input[: (init_l_pt + nw_steps)]
nw_inputs[in_idx] = nw_input
elif ( elif (
idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot
......
...@@ -1634,21 +1634,33 @@ class TestSaveMem: ...@@ -1634,21 +1634,33 @@ class TestSaveMem:
assert stored_ys_steps == 2 assert stored_ys_steps == 2
assert stored_zs_steps == 1 assert stored_zs_steps == 1
def test_vector_zeros_init(self): @pytest.mark.parametrize("val_ndim", (0, 1))
@pytest.mark.parametrize("keep_beginning", (False, True))
def test_broadcasted_init(self, keep_beginning, val_ndim):
# Regression test when the original value is a broadcasted alloc
# The scan save mem rewrite used to wrongly slice on the unbroadcasted value
val_shape = (1,) * val_ndim
val = pt.tensor("val", shape=val_shape)
val_test = np.zeros(val_shape, dtype=val.dtype)
init = pt.full((2,), val)
ys, _ = pytensor.scan( ys, _ = pytensor.scan(
fn=lambda ytm2, ytm1: ytm1 + ytm2, fn=lambda *args: pt.add(*args),
outputs_info=[{"initial": pt.zeros(2), "taps": range(-2, 0)}], outputs_info=[{"initial": init, "taps": (-2, -1)}],
n_steps=100, n_steps=100,
) )
fn = pytensor.function([], ys[-50:], mode=self.mode) out = ys[:-50] if keep_beginning else ys[-50:]
assert tuple(fn().shape) == (50,) fn = pytensor.function([val], out, mode=self.mode)
assert fn(val_test).shape == (50,)
# Check that rewrite worked # Check that rewrite worked
[scan_node] = (n for n in fn.maker.fgraph.apply_nodes if isinstance(n.op, Scan)) [scan_node] = (n for n in fn.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
_, ys_trace = scan_node.inputs _, ys_trace = scan_node.inputs
debug_fn = pytensor.function([], ys_trace.shape[0], accept_inplace=True) buffer_size_fn = pytensor.function(
assert debug_fn() == 50 [val], ys_trace.shape[0], accept_inplace=True
)
assert buffer_size_fn(val_test) == 52 if keep_beginning else 50
def test_inner_replace_dot(): def test_inner_replace_dot():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论