提交 d07714f7 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

map_variables: test introduction of outer graph variables into inner graph

上级 7c82d893
......@@ -200,25 +200,41 @@ class TestMapVariables(X):
assert f(m, n) == [m + n, m - n]
def test_scan(self):
import numpy
from theano import function, scan
from theano import shared, scan, function
x = tensor.vector('x')
# we will insert a subgraph involving these variables into the inner
# graph of scan. since they were not previously in the inner graph,
# they are like non_sequences to scan(). scan() infers these and
# imports them into the inner graph properly, and map_variables()
# should do this as well.
outer = tensor.scalar("outer")
shared = shared(1, name="shared")
constant = tensor.constant(1, name="constant")
# z will equal 1 so multiplying by it doesn't change any values
z = outer * (shared + constant)
def step(x, a):
r = a + x
r.tag.replacement = a - x
r.tag.replacement = z * (a - x)
return r
x = tensor.vector('x')
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
s, _ = scan(step, sequences=x,
outputs_info=[numpy.array(0.)])
s2, = map_variables(
lambda x: getattr(x.tag, "replacement", x),
[s])
f = function([x], [s, s2])
assert numpy.array_equal(f(numpy.array([1, 2, 3])),
[[ 1, 3, 6],
[-1, -3, -6]])
# ensure z is owned by the outer graph so map_variables() will need to
# jump through additional hoops to placate FunctionGraph.
t = z * s
s2, = map_variables(replacer, [t])
t2 = z * s2
f = function([x, outer], [t, t2])
rval = f(x=numpy.array([1, 2, 3]), outer=0.5)
assert numpy.array_equal(rval, [[ 1, 3, 6], [-1, -3, -6]])
def test_leaf_inside_scan(self):
import numpy
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论