提交 b48630c2 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

map_variables: reorder tests

上级 b4ccf879
...@@ -40,40 +40,19 @@ class TestMapVariables(unittest.TestCase): ...@@ -40,40 +40,19 @@ class TestMapVariables(unittest.TestCase):
assert u.owner.inputs == [a, b] assert u.owner.inputs == [a, b]
assert v.owner.inputs == [a, c] assert v.owner.inputs == [a, c]
def test_opfromgraph(self): def test_leaf_inside_scan(self):
# as with the scan tests above, insert foreign inputs into the x = tensor.vector('x')
# inner graph. y = tensor.scalar('y')
outer = tensor.scalar("outer") z = tensor.scalar('z')
shared = theano.shared(1, name="shared")
constant = tensor.constant(1, name="constant")
z = outer * (shared + constant)
# construct the inner graph
a = tensor.scalar()
b = tensor.scalar()
r = a + b
r.tag.replacement = z * (a - b)
# construct the outer graph y.tag.replacement = z
c = tensor.scalar()
d = tensor.scalar()
u = theano.OpFromGraph([a, b], [r])(c, d)
t = z * u
v, = map_variables(
self.replacer, [u],
additional_inputs=[outer, shared])
t2 = z * v
f = theano.function([c, d, outer], [t, t2]) s, _ = theano.scan(lambda x: x * y, sequences=x)
for m, n in itertools.combinations(range(10), 2): s2, = map_variables(self.replacer, [s])
assert f(m, n, outer=0.5) == [m + n, m - n]
# test that the unsupported case of replacement with a shared f = theano.function([x, y, z], [s, s2])
# variable with updates crashes rval = f(x=numpy.array([1, 2, 3], dtype=numpy.float32), y=1, z=2)
shared.update = shared + 1 assert numpy.array_equal(rval, [[1, 2, 3], [2, 4, 6]])
self.assertRaises(NotImplementedError,
map_variables, self.replacer, [u],
additional_inputs=[outer, shared])
def test_scan(self): def test_scan(self):
x = tensor.vector('x') x = tensor.vector('x')
...@@ -151,16 +130,37 @@ class TestMapVariables(unittest.TestCase): ...@@ -151,16 +130,37 @@ class TestMapVariables(unittest.TestCase):
self.assertRaises(NotImplementedError, self.assertRaises(NotImplementedError,
map_variables, self.replacer, [s]) map_variables, self.replacer, [s])
def test_leaf_inside_scan(self): def test_opfromgraph(self):
x = tensor.vector('x') # as with the scan tests above, insert foreign inputs into the
y = tensor.scalar('y') # inner graph.
z = tensor.scalar('z') outer = tensor.scalar("outer")
shared = theano.shared(1, name="shared")
constant = tensor.constant(1, name="constant")
z = outer * (shared + constant)
y.tag.replacement = z # construct the inner graph
a = tensor.scalar()
b = tensor.scalar()
r = a + b
r.tag.replacement = z * (a - b)
s, _ = theano.scan(lambda x: x * y, sequences=x) # construct the outer graph
s2, = map_variables(self.replacer, [s]) c = tensor.scalar()
d = tensor.scalar()
u = theano.OpFromGraph([a, b], [r])(c, d)
t = z * u
v, = map_variables(
self.replacer, [u],
additional_inputs=[outer, shared])
t2 = z * v
f = theano.function([x, y, z], [s, s2]) f = theano.function([c, d, outer], [t, t2])
rval = f(x=numpy.array([1, 2, 3], dtype=numpy.float32), y=1, z=2) for m, n in itertools.combinations(range(10), 2):
assert numpy.array_equal(rval, [[1, 2, 3], [2, 4, 6]]) assert f(m, n, outer=0.5) == [m + n, m - n]
# test that the unsupported case of replacement with a shared
# variable with updates crashes
shared.update = shared + 1
self.assertRaises(NotImplementedError,
map_variables, self.replacer, [u],
additional_inputs=[outer, shared])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论