提交 d07714f7 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

map_variables: test introduction of outer graph variables into inner graph

上级 7c82d893
...@@ -200,25 +200,41 @@ class TestMapVariables(X): ...@@ -200,25 +200,41 @@ class TestMapVariables(X):
assert f(m, n) == [m + n, m - n] assert f(m, n) == [m + n, m - n]
def test_scan(self): def test_scan(self):
import numpy from theano import shared, scan, function
from theano import function, scan
x = tensor.vector('x')
# we will insert a subgraph involving these variables into the inner
# graph of scan. since they were not previously in the inner graph,
# they are like non_sequences to scan(). scan() infers these and
# imports them into the inner graph properly, and map_variables()
# should do this as well.
outer = tensor.scalar("outer")
shared = shared(1, name="shared")
constant = tensor.constant(1, name="constant")
# z will equal 1 so multiplying by it doesn't change any values
z = outer * (shared + constant)
def step(x, a): def step(x, a):
r = a + x r = a + x
r.tag.replacement = a - x r.tag.replacement = z * (a - x)
return r return r
x = tensor.vector('x') def replacer(graph):
return getattr(graph.tag, "replacement", graph)
s, _ = scan(step, sequences=x, s, _ = scan(step, sequences=x,
outputs_info=[numpy.array(0.)]) outputs_info=[numpy.array(0.)])
s2, = map_variables( # ensure z is owned by the outer graph so map_variables() will need to
lambda x: getattr(x.tag, "replacement", x), # jump through additional hoops to placate FunctionGraph.
[s]) t = z * s
s2, = map_variables(replacer, [t])
f = function([x], [s, s2]) t2 = z * s2
assert numpy.array_equal(f(numpy.array([1, 2, 3])),
[[ 1, 3, 6], f = function([x, outer], [t, t2])
[-1, -3, -6]]) rval = f(x=numpy.array([1, 2, 3]), outer=0.5)
assert numpy.array_equal(rval, [[ 1, 3, 6], [-1, -3, -6]])
def test_leaf_inside_scan(self): def test_leaf_inside_scan(self):
import numpy import numpy
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论