Unverified 提交 864ebd1a authored 作者: Rob Zinkov's avatar Rob Zinkov 提交者: GitHub

Handle `compute_map=None` in Scan (#1435)

上级 92eef5ed
......@@ -1647,8 +1647,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
p=p, i=node_input_storage, o=node_output_storage, n=node, allow_gc=allow_gc
):
r = p(n, [x[0] for x in i], o)
for o in node.outputs:
compute_map[o][0] = True
if compute_map is not None:
for o in node.outputs:
compute_map[o][0] = True
if allow_gc:
self.fn.free()
return r
......
......@@ -27,6 +27,7 @@ from pytensor.compile.monitormode import MonitorMode
from pytensor.compile.sharedvalue import shared
from pytensor.configdefaults import config
from pytensor.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian
from pytensor.graph import vectorize_graph
from pytensor.graph.basic import Apply, ancestors, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
......@@ -1178,6 +1179,17 @@ class TestScan:
utt.verify_grad(get_sum_of_grad, inputs_test_values, rng=rng)
def test_blockwise_scan(self):
x = pt.tensor("x", shape=())
out, _ = scan(lambda x: x + 1, outputs_info=[x], n_steps=10)
x_vec = pt.tensor("x_vec", shape=(None,))
out_vec = vectorize_graph(out, {x: x_vec})
fn = function([x_vec], out_vec)
o1 = fn([1, 2, 3])
o2 = np.arange(2, 12) + np.arange(3).reshape(-1, 1)
assert np.allclose(o1, o2)
def test_connection_pattern(self):
"""Test `Scan.connection_pattern` in the presence of recurrent outputs with multiple taps."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论