提交 7c82d893 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

map_variables: test replacement of inputs inside scan

上级 aafce809
...@@ -220,6 +220,27 @@ class TestMapVariables(X): ...@@ -220,6 +220,27 @@ class TestMapVariables(X):
[[ 1, 3, 6], [[ 1, 3, 6],
[-1, -3, -6]]) [-1, -3, -6]])
def test_leaf_inside_scan(self):
import numpy
from theano import function, scan
x = tensor.vector('x')
y = tensor.scalar('y')
z = tensor.scalar('z')
y.tag.replacement = z
s, _ = scan(lambda x: x * y, sequences=x)
s2, = map_variables(
lambda x: getattr(x.tag, "replacement", x),
[s])
f = function([x, y, z], [s, s2])
assert numpy.array_equal(
f(x=numpy.array([1, 2, 3]), y=1, z=2),
[[ 1, 2, 3],
[ 2, 4, 6]])
############ ############
# toposort # # toposort #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论