提交 b4ccf879 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

map_variables: test insertion of inputs into OpFromGraph inner graph

上级 0be1af83
......@@ -41,20 +41,39 @@ class TestMapVariables(unittest.TestCase):
assert v.owner.inputs == [a, c]
def test_opfromgraph(self):
# as with the scan tests above, insert foreign inputs into the
# inner graph.
outer = tensor.scalar("outer")
shared = theano.shared(1, name="shared")
constant = tensor.constant(1, name="constant")
z = outer * (shared + constant)
# construct the inner graph
a = tensor.scalar()
b = tensor.scalar()
r = a + b
r.tag.replacement = a - b
r.tag.replacement = z * (a - b)
# construct the outer graph
c = tensor.scalar()
d = tensor.scalar()
u = theano.OpFromGraph([a, b], [r])(c, d)
v, = map_variables(self.replacer, [u])
t = z * u
v, = map_variables(
self.replacer, [u],
additional_inputs=[outer, shared])
t2 = z * v
f = theano.function([c, d], [u, v])
f = theano.function([c, d, outer], [t, t2])
for m, n in itertools.combinations(range(10), 2):
assert f(m, n) == [m + n, m - n]
assert f(m, n, outer=0.5) == [m + n, m - n]
# test that the unsupported case of replacement with a shared
# variable with updates crashes
shared.update = shared + 1
self.assertRaises(NotImplementedError,
map_variables, self.replacer, [u],
additional_inputs=[outer, shared])
def test_scan(self):
x = tensor.vector('x')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论