提交 96672f62 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

map_variables: test replacement on the correct graph

上级 b48630c2
...@@ -149,9 +149,7 @@ class TestMapVariables(unittest.TestCase): ...@@ -149,9 +149,7 @@ class TestMapVariables(unittest.TestCase):
d = tensor.scalar() d = tensor.scalar()
u = theano.OpFromGraph([a, b], [r])(c, d) u = theano.OpFromGraph([a, b], [r])(c, d)
t = z * u t = z * u
v, = map_variables( v, = map_variables(self.replacer, [t])
self.replacer, [u],
additional_inputs=[outer, shared])
t2 = z * v t2 = z * v
f = theano.function([c, d, outer], [t, t2]) f = theano.function([c, d, outer], [t, t2])
...@@ -162,5 +160,4 @@ class TestMapVariables(unittest.TestCase): ...@@ -162,5 +160,4 @@ class TestMapVariables(unittest.TestCase):
# variable with updates crashes # variable with updates crashes
shared.update = shared + 1 shared.update = shared + 1
self.assertRaises(NotImplementedError, self.assertRaises(NotImplementedError,
map_variables, self.replacer, [u], map_variables, self.replacer, [t])
additional_inputs=[outer, shared])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论