Commit 893a4c74 authored by ricardoV94, committed by Ricardo Vieira

Fix scan_save_mem with 0 steps

Parent 9004e5f2
...@@ -1459,15 +1459,6 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1459,15 +1459,6 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
real_steps = None real_steps = None
nw_steps = select_min(select_max(sym_steps, real_steps), node.inputs[0]) nw_steps = select_min(select_max(sym_steps, real_steps), node.inputs[0])
# FIXME: This is not correct. Scan with 0 steps seems to be supported
# Make sure the ScanSaveMem optimization never makes the new
# number of steps to be 0 (this could happen, for instance, if
# the optimization detects that the outputs of the Scan go through
# subtensor nodes that end up taking no elements) because Scan with
# 0 iterations are not supported. Make sure the new number of steps
# is at least 1.
nw_steps = select_max(nw_steps, 1)
# 2.4 Loop over the clients again now looking just to see how many # 2.4 Loop over the clients again now looking just to see how many
# intermediate steps to store. Skip mit_mot outputs as their # intermediate steps to store. Skip mit_mot outputs as their
# store_steps is always 0 (all intermediate values are needed). # store_steps is always 0 (all intermediate values are needed).
...@@ -1537,7 +1528,13 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1537,7 +1528,13 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
if prealloc_outs and preallocable_output: if prealloc_outs and preallocable_output:
# TODO: If there's only one output or other outputs do not depend # TODO: If there's only one output or other outputs do not depend
# on the same input, we could reduce the buffer size to the minimum # on the same input, we could reduce the buffer size to the minimum
pval = select_max(nw_steps - start + init_l[i], init_l[i] + 1) # The extra entry to prevent aliasing between the new
# state and the oldest tap is only needed when the
# scan actually runs (nw_steps >= 1).
pval = select_max(
nw_steps - start + init_l[i],
init_l[i] + minimum(nw_steps, 1),
)
else: else:
pval = select_max(nw_steps - start + init_l[i], init_l[i]) pval = select_max(nw_steps - start + init_l[i], init_l[i])
...@@ -1648,10 +1645,6 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1648,10 +1645,6 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
inv_compress_map = {v: k for k, v in compress_map.items()} inv_compress_map = {v: k for k, v in compress_map.items()}
# 3.6 Compose the new scan # 3.6 Compose the new scan
# TODO: currently we don't support scan with 0 step. So
# don't create one.
if get_scalar_constant_value(node_ins[0], raise_not_constant=False) == 0:
return False
# Do not call make_node for test_value # Do not call make_node for test_value
new_op = Scan( new_op = Scan(
......
...@@ -1575,10 +1575,7 @@ class TestSaveMem: ...@@ -1575,10 +1575,7 @@ class TestSaveMem:
def test_savemem_opt_0_step(self): def test_savemem_opt_0_step(self):
""" """
Test a case where the savemem optimization has the opportunity to Test a case where the savemem optimization has the opportunity to
lower the number of steps of a Scan to 0. It tests that the lower the number of steps of a Scan to 0.
optimization doesn't do so since Scan nodes with 0
steps are not currently supported and doing so would result in a
crash during the function execution.
""" """
def inner_scan_step(x_t_t, h_tm1, w): def inner_scan_step(x_t_t, h_tm1, w):
...@@ -1628,6 +1625,32 @@ class TestSaveMem: ...@@ -1628,6 +1625,32 @@ class TestSaveMem:
output = f(x_value, w_value) output = f(x_value, w_value)
utt.assert_allclose(output, expected_output) utt.assert_allclose(output, expected_output)
def test_savemem_0_steps_does_not_point_to_unitialized_memory(self):
    """Check that the Scan output buffer never exposes uninitialized memory.

    Regression test for https://github.com/pymc-devs/pytensor/issues/1878:
    after the save-mem rewrite, a Scan with ``n_steps=0`` must still keep
    the initial state in its output buffer instead of leaving the last
    entry pointing at uninitialized memory.
    """
    n = pt.tensor("n", shape=(), dtype=int)
    init_state = pt.tensor("init_state", shape=(3,))
    buffer_without_init = pytensor.scan(
        fn=lambda xtm1: xtm1 * 2,
        outputs_info=[init_state],
        n_steps=n,
        return_updates=False,
    )

    # Access the last state of the Scan output buffer (which includes the
    # initial state). It should never point to uninitialized memory.
    full_buffer = buffer_without_init.owner.inputs[0]
    buffer_last_entry = full_buffer[-1]

    fn = pytensor.function([init_state, n], buffer_last_entry)
    init_state_val = np.ones((3,))
    # With 0 steps the last buffer entry is the (untouched) initial state.
    np.testing.assert_allclose(fn(init_state=init_state_val, n=0), init_state_val)
    np.testing.assert_allclose(
        fn(init_state=init_state_val, n=1), init_state_val * 2
    )
    np.testing.assert_allclose(
        fn(init_state=init_state_val, n=2), init_state_val * 4
    )
@pytest.mark.skip( @pytest.mark.skip(
reason="The 'assertion' of this test relied on something that no longer exists " reason="The 'assertion' of this test relied on something that no longer exists "
) )
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to post a comment