提交 7301a45e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in `save_mem_new_scan` due to broadcasting by set_subtensor

上级 bdface7e
......@@ -1516,13 +1516,17 @@ def save_mem_new_scan(fgraph, node):
if (
nw_inputs[offset + idx].owner
and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor)
and nw_inputs[offset + idx].owner.op.set_instead_of_inc
and isinstance(
nw_inputs[offset + idx].owner.op.idx_list[0], slice
)
):
assert isinstance(
nw_inputs[offset + idx].owner.op, IncSubtensor
# Don't try to create a smart Alloc, if set_subtensor is broadcasting the fill value
# As it happens in set_subtensor(empty(2)[:], 0)
and not (
nw_inputs[offset + idx].ndim
> nw_inputs[offset + idx].owner.inputs[1].ndim
)
):
_nw_input = nw_inputs[offset + idx].owner.inputs[1]
cval = at.as_tensor_variable(val)
initl = at.as_tensor_variable(init_l[i])
......
......@@ -1487,6 +1487,22 @@ class TestSaveMem:
assert stored_ys_steps == 2
assert stored_zs_steps == 1
def test_vector_zeros_init(self):
ys, _ = pytensor.scan(
fn=lambda ytm2, ytm1: ytm1 + ytm2,
outputs_info=[{"initial": at.zeros(2), "taps": range(-2, 0)}],
n_steps=100,
)
fn = pytensor.function([], ys[-50:], mode=self.mode)
assert tuple(fn().shape) == (50,)
# Check that rewrite worked
[scan_node] = (n for n in fn.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
_, ys_trace = scan_node.inputs
debug_fn = pytensor.function([], ys_trace.shape[0], accept_inplace=True)
assert debug_fn() == 50
def test_inner_replace_dot():
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论