提交 5c985943 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

map_variables: introduce some tests

上级 3a6152d9
......@@ -11,7 +11,7 @@ from theano import (
from theano.gof.graph import (
Apply,
as_string, clone, general_toposort, inputs, io_toposort,
is_same_graph, Variable)
is_same_graph, Variable, map_variables)
from theano.gof.op import Op
from theano.gof.type import Type
from theano.sandbox.cuda.var import (
......@@ -158,6 +158,69 @@ class TestClone(X):
assert self.str(inputs(node.outputs), node.outputs) == ["MyOp(MyOp(R1, R2), R5)"]
#################
# map_variables #
#################
class TestMapVariables(X):
def test_leaf(self):
a = tensor.scalar("a")
b = tensor.scalar("b")
c = tensor.scalar("c")
b.tag.replacement = c
u = a + b
v, = map_variables(
lambda x: getattr(x.tag, "replacement", x),
[u])
assert u.owner.inputs == [a, b]
assert v.owner.inputs == [a, c]
def test_opfromgraph(self):
from theano import OpFromGraph, function
import itertools
a = tensor.scalar()
b = tensor.scalar()
r = a + b
r.tag.replacement = a - b
c = tensor.scalar()
d = tensor.scalar()
u = OpFromGraph([a, b], [r])(c, d)
v, = map_variables(
lambda x: getattr(x.tag, "replacement", x),
[u])
f = function([c, d], [u, v])
for m, n in itertools.combinations(xrange(10), 2):
assert f(m, n) == [m + n, m - n]
def test_scan(self):
import numpy
from theano import function, scan
def step(x, a):
r = a + x
r.tag.replacement = a - x
return r
x = tensor.vector('x')
s, _ = scan(step, sequences=x,
outputs_info=[numpy.array(0.)])
s2, = map_variables(
lambda x: getattr(x.tag, "replacement", x),
[s])
f = function([x], [s, s2])
assert numpy.array_equal(f(numpy.array([1, 2, 3])),
[[ 1, 3, 6],
[-1, -3, -6]])
############
# toposort #
############
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论