提交 ee502492 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

clean-up: Use FunctionGraph instead of Env in tests

上级 67cd24d4
......@@ -3,7 +3,7 @@ from theano.gof.type import Type
from theano.gof.graph import Variable, Apply, Constant
from theano.gof.op import Op
from theano.gof.opt import * # noqa
from theano.gof.fg import FunctionGraph as Env
from theano.gof.fg import FunctionGraph
from theano.gof.toolbox import * # noqa
from theano import tensor as T
......@@ -100,7 +100,7 @@ class TestPatternOptimizer:
# replacing the whole graph
x, y, z = inputs()
e = op1(op2(x, y), z)
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op2, '1', '2'), '3'),
(op4, '3', '2')).optimize(g)
assert str(g) == "[Op4(z, y)]"
......@@ -108,7 +108,7 @@ class TestPatternOptimizer:
def test_nested_out_pattern(self):
x, y, z = inputs()
e = op1(x, y)
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, '1', '2'),
(op4, (op1, '1'), (op2, '2'), (op3, '1', '2'))).optimize(g)
assert str(g) == "[Op4(Op1(x), Op2(y), Op3(x, y))]"
......@@ -116,7 +116,7 @@ class TestPatternOptimizer:
def test_unification_1(self):
x, y, z = inputs()
e = op1(op2(x, x), z) # the arguments to op2 are the same
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op2, '1', '1'), '2'), # they are the same in the pattern
(op4, '2', '1')).optimize(g)
# So the replacement should occur
......@@ -125,7 +125,7 @@ class TestPatternOptimizer:
def test_unification_2(self):
x, y, z = inputs()
e = op1(op2(x, y), z) # the arguments to op2 are different
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op2, '1', '1'), '2'), # they are the same in the pattern
(op4, '2', '1')).optimize(g)
# The replacement should NOT occur
......@@ -135,7 +135,7 @@ class TestPatternOptimizer:
# replacing inside the graph
x, y, z = inputs()
e = op1(op2(x, y), z)
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op2, '1', '2'),
(op1, '2', '1')).optimize(g)
assert str(g) == "[Op1(Op1(y, x), z)]"
......@@ -146,7 +146,7 @@ class TestPatternOptimizer:
# it should do the replacement and stop
x, y, z = inputs()
e = op1(op2(x, y), z)
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op2, '1', '2'),
(op2, '2', '1'), ign=True).optimize(g)
assert str(g) == "[Op1(Op2(y, x), z)]"
......@@ -155,7 +155,7 @@ class TestPatternOptimizer:
# it should replace all occurrences of the pattern
x, y, z = inputs()
e = op1(op2(x, y), op2(x, y), op2(y, z))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op2, '1', '2'),
(op4, '1')).optimize(g)
assert str(g) == "[Op1(Op4(x), Op4(x), Op4(y))]"
......@@ -165,7 +165,7 @@ class TestPatternOptimizer:
# should work
x, y, z = inputs()
e = op1(op1(op1(op1(x))))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op1, '1')),
'1').optimize(g)
assert str(g) == "[x]"
......@@ -173,7 +173,7 @@ class TestPatternOptimizer:
def test_nested_odd(self):
x, y, z = inputs()
e = op1(op1(op1(op1(op1(x)))))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op1, '1')),
'1').optimize(g)
assert str(g) == "[Op1(x)]"
......@@ -181,7 +181,7 @@ class TestPatternOptimizer:
def test_expand(self):
x, y, z = inputs()
e = op1(op1(op1(x)))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, '1'),
(op2, (op1, '1')), ign=True).optimize(g)
assert str(g) == "[Op2(Op1(Op2(Op1(Op2(Op1(x))))))]"
......@@ -192,7 +192,7 @@ class TestPatternOptimizer:
# = True or with other NavigatorOptimizers may differ.
x, y, z = inputs()
e = op1(op1(op1(op1(op1(x)))))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
TopoPatternOptimizer((op1, (op1, '1')),
(op1, '1'), ign=False).optimize(g)
assert str(g) == "[Op1(x)]"
......@@ -202,7 +202,7 @@ class TestPatternOptimizer:
y = MyVariable('y')
z = Constant(MyType(), 2, name='z')
e = op1(op1(x, y), y)
g = Env([y], [e])
g = FunctionGraph([y], [e])
PatternOptimizer((op1, z, '1'),
(op2, '1', z)).optimize(g)
assert str(g) == "[Op1(Op2(y, z), y)]"
......@@ -210,7 +210,7 @@ class TestPatternOptimizer:
def test_constraints(self):
x, y, z = inputs()
e = op4(op1(op2(x, y)), op1(op1(x, y)))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
def constraint(r):
# Only replacing if the input is an instance of Op2
......@@ -223,7 +223,7 @@ class TestPatternOptimizer:
def test_match_same(self):
x, y, z = inputs()
e = op1(x, x)
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, 'x', 'y'),
(op3, 'x', 'y')).optimize(g)
assert str(g) == "[Op3(x, x)]"
......@@ -231,7 +231,7 @@ class TestPatternOptimizer:
def test_match_same_illegal(self):
x, y, z = inputs()
e = op2(op1(x, x), op1(x, y))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
def constraint(r):
# Only replacing if the input is an instance of Op2
......@@ -245,7 +245,7 @@ class TestPatternOptimizer:
x, y, z = inputs()
e0 = op1(x, y)
e = op3(op4(e0), e0)
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op4, (op1, 'x', 'y')),
(op3, 'x', 'y')).optimize(g)
assert str(g) == "[Op3(Op4(*1 -> Op1(x, y)), *1)]"
......@@ -254,7 +254,7 @@ class TestPatternOptimizer:
# replacing the whole graph
x, y, z = inputs()
e = op1(op_y(x, y), z)
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op_z, '1', '2'), '3'),
(op4, '3', '2')).optimize(g)
str_g = str(g)
......@@ -265,7 +265,7 @@ class TestPatternOptimizer:
# x, y, z = inputs()
# e0 = op1(x, y)
# e = op4(e0, e0)
# g = Env([x, y, z], [e])
# g = FunctionGraph([x, y, z], [e])
# PatternOptimizer((op4, (op1, 'x', 'y'), (op1, 'x', 'y')),
# (op3, 'x', 'y')).optimize(g)
# assert str(g) == "[Op3(x, y)]"
......@@ -280,14 +280,14 @@ class TestOpSubOptimizer:
def test_straightforward(self):
x, y, z = inputs()
e = op1(op1(op1(op1(op1(x)))))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
OpSubOptimizer(op1, op2).optimize(g)
assert str(g) == "[Op2(Op2(Op2(Op2(Op2(x)))))]"
def test_straightforward_2(self):
x, y, z = inputs()
e = op1(op2(x), op3(y), op4(z))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
OpSubOptimizer(op3, op4).optimize(g)
assert str(g) == "[Op1(Op2(x), Op4(y), Op4(z))]"
......@@ -297,7 +297,7 @@ class TestMergeOptimizer:
def test_straightforward(self):
x, y, z = inputs()
e = op1(op2(x, y), op2(x, y), op2(x, z))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g)
assert str(g) == "[Op1(*1 -> Op2(x, y), *1, Op2(x, z))]"
......@@ -306,7 +306,7 @@ class TestMergeOptimizer:
y = Constant(MyType(), 2, name='y')
z = Constant(MyType(), 2, name='z')
e = op1(op2(x, y), op2(x, y), op2(x, z))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g)
strg = str(g)
assert strg == "[Op1(*1 -> Op2(x, y), *1, *1)]" \
......@@ -315,14 +315,14 @@ class TestMergeOptimizer:
def test_deep_merge(self):
x, y, z = inputs()
e = op1(op3(op2(x, y), z), op4(op3(op2(x, y), z)))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g)
assert str(g) == "[Op1(*1 -> Op3(Op2(x, y), z), Op4(*1))]"
def test_no_merge(self):
x, y, z = inputs()
e = op1(op3(op2(x, y)), op3(op2(y, x)))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g)
assert str(g) == "[Op1(Op3(Op2(x, y)), Op3(Op2(y, x)))]"
......@@ -330,7 +330,7 @@ class TestMergeOptimizer:
x, y, z = inputs()
e1 = op3(op2(x, y))
e2 = op3(op2(x, y))
g = Env([x, y, z], [e1, e2])
g = FunctionGraph([x, y, z], [e1, e2])
MergeOptimizer().optimize(g)
assert str(g) == "[*1 -> Op3(Op2(x, y)), *1]"
......@@ -339,7 +339,7 @@ class TestMergeOptimizer:
e1 = op1(x, y)
e2 = op2(op3(x), y, z)
e = op1(e1, op4(e2, e1), op1(e2))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g)
strg = str(g)
# note: graph.as_string can only produce the following two possibilities, but if
......@@ -357,7 +357,7 @@ class TestMergeOptimizer:
e1 = op1(y, z)
finally:
config.compute_test_value = ctv_backup
g = Env([x, y, z], [e1])
g = FunctionGraph([x, y, z], [e1])
MergeOptimizer().optimize(g)
strg = str(g)
assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]'
......@@ -367,7 +367,7 @@ class TestMergeOptimizer:
x1 = T.matrix('x1')
x2 = T.matrix('x2')
e = T.dot(x1, x2) + T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2)
g = Env([x1, x2], [e])
g = FunctionGraph([x1, x2], [e])
MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str')
strref = '''Elemwise{add,no_inplace} [@A] '' 4
......@@ -391,7 +391,7 @@ class TestMergeOptimizer:
x3 = T.matrix('x3')
e = T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) +\
T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2)
g = Env([x1, x2, x3], [e])
g = FunctionGraph([x1, x2, x3], [e])
MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str')
strref1 = '''Elemwise{add,no_inplace} [@A] '' 6
......@@ -434,7 +434,7 @@ class TestMergeOptimizer:
x3 = T.matrix('x3')
e = T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) +\
T.dot(x1, T.opt.assert_op(x2, (x2 > x3).all()))
g = Env([x1, x2, x3], [e])
g = FunctionGraph([x1, x2, x3], [e])
MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str')
strref = '''Elemwise{add,no_inplace} [@A] '' 7
......@@ -463,7 +463,7 @@ class TestMergeOptimizer:
x3 = T.matrix('x3')
e = T.dot(x1, T.opt.assert_op(x2, (x2 > x3).all())) +\
T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2)
g = Env([x1, x2, x3], [e])
g = FunctionGraph([x1, x2, x3], [e])
MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str')
strref = '''Elemwise{add,no_inplace} [@A] '' 7
......@@ -491,7 +491,7 @@ class TestEquilibrium(object):
def test_1(self):
x, y, z = map(MyVariable, 'xyz')
e = op3(op4(x, y))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
# print g
opt = EquilibriumOptimizer(
[PatternSub((op1, 'x', 'y'), (op2, 'x', 'y')),
......@@ -506,7 +506,7 @@ class TestEquilibrium(object):
def test_2(self):
x, y, z = map(MyVariable, 'xyz')
e = op1(op1(op3(x, y)))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
# print g
opt = EquilibriumOptimizer(
[PatternSub((op1, (op2, 'x', 'y')), (op4, 'x', 'y')),
......@@ -522,7 +522,7 @@ class TestEquilibrium(object):
def test_low_use_ratio(self):
x, y, z = map(MyVariable, 'xyz')
e = op3(op4(x, y))
g = Env([x, y, z], [e])
g = FunctionGraph([x, y, z], [e])
# print 'before', g
# display pesky warnings along with stdout
# also silence logger for 'theano.gof.opt'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论