提交 ee502492 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

clean-up: Use FunctionGraph instead of Env in tests

上级 67cd24d4
...@@ -3,7 +3,7 @@ from theano.gof.type import Type ...@@ -3,7 +3,7 @@ from theano.gof.type import Type
from theano.gof.graph import Variable, Apply, Constant from theano.gof.graph import Variable, Apply, Constant
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.opt import * # noqa from theano.gof.opt import * # noqa
from theano.gof.fg import FunctionGraph as Env from theano.gof.fg import FunctionGraph
from theano.gof.toolbox import * # noqa from theano.gof.toolbox import * # noqa
from theano import tensor as T from theano import tensor as T
...@@ -100,7 +100,7 @@ class TestPatternOptimizer: ...@@ -100,7 +100,7 @@ class TestPatternOptimizer:
# replacing the whole graph # replacing the whole graph
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), z) e = op1(op2(x, y), z)
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op2, '1', '2'), '3'), PatternOptimizer((op1, (op2, '1', '2'), '3'),
(op4, '3', '2')).optimize(g) (op4, '3', '2')).optimize(g)
assert str(g) == "[Op4(z, y)]" assert str(g) == "[Op4(z, y)]"
...@@ -108,7 +108,7 @@ class TestPatternOptimizer: ...@@ -108,7 +108,7 @@ class TestPatternOptimizer:
def test_nested_out_pattern(self): def test_nested_out_pattern(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(x, y) e = op1(x, y)
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, '1', '2'), PatternOptimizer((op1, '1', '2'),
(op4, (op1, '1'), (op2, '2'), (op3, '1', '2'))).optimize(g) (op4, (op1, '1'), (op2, '2'), (op3, '1', '2'))).optimize(g)
assert str(g) == "[Op4(Op1(x), Op2(y), Op3(x, y))]" assert str(g) == "[Op4(Op1(x), Op2(y), Op3(x, y))]"
...@@ -116,7 +116,7 @@ class TestPatternOptimizer: ...@@ -116,7 +116,7 @@ class TestPatternOptimizer:
def test_unification_1(self): def test_unification_1(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, x), z) # the arguments to op2 are the same e = op1(op2(x, x), z) # the arguments to op2 are the same
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op2, '1', '1'), '2'), # they are the same in the pattern PatternOptimizer((op1, (op2, '1', '1'), '2'), # they are the same in the pattern
(op4, '2', '1')).optimize(g) (op4, '2', '1')).optimize(g)
# So the replacement should occur # So the replacement should occur
...@@ -125,7 +125,7 @@ class TestPatternOptimizer: ...@@ -125,7 +125,7 @@ class TestPatternOptimizer:
def test_unification_2(self): def test_unification_2(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), z) # the arguments to op2 are different e = op1(op2(x, y), z) # the arguments to op2 are different
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op2, '1', '1'), '2'), # they are the same in the pattern PatternOptimizer((op1, (op2, '1', '1'), '2'), # they are the same in the pattern
(op4, '2', '1')).optimize(g) (op4, '2', '1')).optimize(g)
# The replacement should NOT occur # The replacement should NOT occur
...@@ -135,7 +135,7 @@ class TestPatternOptimizer: ...@@ -135,7 +135,7 @@ class TestPatternOptimizer:
# replacing inside the graph # replacing inside the graph
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), z) e = op1(op2(x, y), z)
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op2, '1', '2'), PatternOptimizer((op2, '1', '2'),
(op1, '2', '1')).optimize(g) (op1, '2', '1')).optimize(g)
assert str(g) == "[Op1(Op1(y, x), z)]" assert str(g) == "[Op1(Op1(y, x), z)]"
...@@ -146,7 +146,7 @@ class TestPatternOptimizer: ...@@ -146,7 +146,7 @@ class TestPatternOptimizer:
# it should do the replacement and stop # it should do the replacement and stop
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), z) e = op1(op2(x, y), z)
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op2, '1', '2'), PatternOptimizer((op2, '1', '2'),
(op2, '2', '1'), ign=True).optimize(g) (op2, '2', '1'), ign=True).optimize(g)
assert str(g) == "[Op1(Op2(y, x), z)]" assert str(g) == "[Op1(Op2(y, x), z)]"
...@@ -155,7 +155,7 @@ class TestPatternOptimizer: ...@@ -155,7 +155,7 @@ class TestPatternOptimizer:
# it should replace all occurrences of the pattern # it should replace all occurrences of the pattern
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), op2(x, y), op2(y, z)) e = op1(op2(x, y), op2(x, y), op2(y, z))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op2, '1', '2'), PatternOptimizer((op2, '1', '2'),
(op4, '1')).optimize(g) (op4, '1')).optimize(g)
assert str(g) == "[Op1(Op4(x), Op4(x), Op4(y))]" assert str(g) == "[Op1(Op4(x), Op4(x), Op4(y))]"
...@@ -165,7 +165,7 @@ class TestPatternOptimizer: ...@@ -165,7 +165,7 @@ class TestPatternOptimizer:
# should work # should work
x, y, z = inputs() x, y, z = inputs()
e = op1(op1(op1(op1(x)))) e = op1(op1(op1(op1(x))))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op1, '1')), PatternOptimizer((op1, (op1, '1')),
'1').optimize(g) '1').optimize(g)
assert str(g) == "[x]" assert str(g) == "[x]"
...@@ -173,7 +173,7 @@ class TestPatternOptimizer: ...@@ -173,7 +173,7 @@ class TestPatternOptimizer:
def test_nested_odd(self): def test_nested_odd(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op1(op1(op1(op1(x))))) e = op1(op1(op1(op1(op1(x)))))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op1, '1')), PatternOptimizer((op1, (op1, '1')),
'1').optimize(g) '1').optimize(g)
assert str(g) == "[Op1(x)]" assert str(g) == "[Op1(x)]"
...@@ -181,7 +181,7 @@ class TestPatternOptimizer: ...@@ -181,7 +181,7 @@ class TestPatternOptimizer:
def test_expand(self): def test_expand(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op1(op1(x))) e = op1(op1(op1(x)))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, '1'), PatternOptimizer((op1, '1'),
(op2, (op1, '1')), ign=True).optimize(g) (op2, (op1, '1')), ign=True).optimize(g)
assert str(g) == "[Op2(Op1(Op2(Op1(Op2(Op1(x))))))]" assert str(g) == "[Op2(Op1(Op2(Op1(Op2(Op1(x))))))]"
...@@ -192,7 +192,7 @@ class TestPatternOptimizer: ...@@ -192,7 +192,7 @@ class TestPatternOptimizer:
# = True or with other NavigatorOptimizers may differ. # = True or with other NavigatorOptimizers may differ.
x, y, z = inputs() x, y, z = inputs()
e = op1(op1(op1(op1(op1(x))))) e = op1(op1(op1(op1(op1(x)))))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
TopoPatternOptimizer((op1, (op1, '1')), TopoPatternOptimizer((op1, (op1, '1')),
(op1, '1'), ign=False).optimize(g) (op1, '1'), ign=False).optimize(g)
assert str(g) == "[Op1(x)]" assert str(g) == "[Op1(x)]"
...@@ -202,7 +202,7 @@ class TestPatternOptimizer: ...@@ -202,7 +202,7 @@ class TestPatternOptimizer:
y = MyVariable('y') y = MyVariable('y')
z = Constant(MyType(), 2, name='z') z = Constant(MyType(), 2, name='z')
e = op1(op1(x, y), y) e = op1(op1(x, y), y)
g = Env([y], [e]) g = FunctionGraph([y], [e])
PatternOptimizer((op1, z, '1'), PatternOptimizer((op1, z, '1'),
(op2, '1', z)).optimize(g) (op2, '1', z)).optimize(g)
assert str(g) == "[Op1(Op2(y, z), y)]" assert str(g) == "[Op1(Op2(y, z), y)]"
...@@ -210,7 +210,7 @@ class TestPatternOptimizer: ...@@ -210,7 +210,7 @@ class TestPatternOptimizer:
def test_constraints(self): def test_constraints(self):
x, y, z = inputs() x, y, z = inputs()
e = op4(op1(op2(x, y)), op1(op1(x, y))) e = op4(op1(op2(x, y)), op1(op1(x, y)))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
def constraint(r): def constraint(r):
# Only replacing if the input is an instance of Op2 # Only replacing if the input is an instance of Op2
...@@ -223,7 +223,7 @@ class TestPatternOptimizer: ...@@ -223,7 +223,7 @@ class TestPatternOptimizer:
def test_match_same(self): def test_match_same(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(x, x) e = op1(x, x)
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, 'x', 'y'), PatternOptimizer((op1, 'x', 'y'),
(op3, 'x', 'y')).optimize(g) (op3, 'x', 'y')).optimize(g)
assert str(g) == "[Op3(x, x)]" assert str(g) == "[Op3(x, x)]"
...@@ -231,7 +231,7 @@ class TestPatternOptimizer: ...@@ -231,7 +231,7 @@ class TestPatternOptimizer:
def test_match_same_illegal(self): def test_match_same_illegal(self):
x, y, z = inputs() x, y, z = inputs()
e = op2(op1(x, x), op1(x, y)) e = op2(op1(x, x), op1(x, y))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
def constraint(r): def constraint(r):
# Only replacing if the input is an instance of Op2 # Only replacing if the input is an instance of Op2
...@@ -245,7 +245,7 @@ class TestPatternOptimizer: ...@@ -245,7 +245,7 @@ class TestPatternOptimizer:
x, y, z = inputs() x, y, z = inputs()
e0 = op1(x, y) e0 = op1(x, y)
e = op3(op4(e0), e0) e = op3(op4(e0), e0)
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op4, (op1, 'x', 'y')), PatternOptimizer((op4, (op1, 'x', 'y')),
(op3, 'x', 'y')).optimize(g) (op3, 'x', 'y')).optimize(g)
assert str(g) == "[Op3(Op4(*1 -> Op1(x, y)), *1)]" assert str(g) == "[Op3(Op4(*1 -> Op1(x, y)), *1)]"
...@@ -254,7 +254,7 @@ class TestPatternOptimizer: ...@@ -254,7 +254,7 @@ class TestPatternOptimizer:
# replacing the whole graph # replacing the whole graph
x, y, z = inputs() x, y, z = inputs()
e = op1(op_y(x, y), z) e = op1(op_y(x, y), z)
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op_z, '1', '2'), '3'), PatternOptimizer((op1, (op_z, '1', '2'), '3'),
(op4, '3', '2')).optimize(g) (op4, '3', '2')).optimize(g)
str_g = str(g) str_g = str(g)
...@@ -265,7 +265,7 @@ class TestPatternOptimizer: ...@@ -265,7 +265,7 @@ class TestPatternOptimizer:
# x, y, z = inputs() # x, y, z = inputs()
# e0 = op1(x, y) # e0 = op1(x, y)
# e = op4(e0, e0) # e = op4(e0, e0)
# g = Env([x, y, z], [e]) # g = FunctionGraph([x, y, z], [e])
# PatternOptimizer((op4, (op1, 'x', 'y'), (op1, 'x', 'y')), # PatternOptimizer((op4, (op1, 'x', 'y'), (op1, 'x', 'y')),
# (op3, 'x', 'y')).optimize(g) # (op3, 'x', 'y')).optimize(g)
# assert str(g) == "[Op3(x, y)]" # assert str(g) == "[Op3(x, y)]"
...@@ -280,14 +280,14 @@ class TestOpSubOptimizer: ...@@ -280,14 +280,14 @@ class TestOpSubOptimizer:
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op1(op1(op1(op1(x))))) e = op1(op1(op1(op1(op1(x)))))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
OpSubOptimizer(op1, op2).optimize(g) OpSubOptimizer(op1, op2).optimize(g)
assert str(g) == "[Op2(Op2(Op2(Op2(Op2(x)))))]" assert str(g) == "[Op2(Op2(Op2(Op2(Op2(x)))))]"
def test_straightforward_2(self): def test_straightforward_2(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x), op3(y), op4(z)) e = op1(op2(x), op3(y), op4(z))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
OpSubOptimizer(op3, op4).optimize(g) OpSubOptimizer(op3, op4).optimize(g)
assert str(g) == "[Op1(Op2(x), Op4(y), Op4(z))]" assert str(g) == "[Op1(Op2(x), Op4(y), Op4(z))]"
...@@ -297,7 +297,7 @@ class TestMergeOptimizer: ...@@ -297,7 +297,7 @@ class TestMergeOptimizer:
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), op2(x, y), op2(x, z)) e = op1(op2(x, y), op2(x, y), op2(x, z))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
assert str(g) == "[Op1(*1 -> Op2(x, y), *1, Op2(x, z))]" assert str(g) == "[Op1(*1 -> Op2(x, y), *1, Op2(x, z))]"
...@@ -306,7 +306,7 @@ class TestMergeOptimizer: ...@@ -306,7 +306,7 @@ class TestMergeOptimizer:
y = Constant(MyType(), 2, name='y') y = Constant(MyType(), 2, name='y')
z = Constant(MyType(), 2, name='z') z = Constant(MyType(), 2, name='z')
e = op1(op2(x, y), op2(x, y), op2(x, z)) e = op1(op2(x, y), op2(x, y), op2(x, z))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = str(g) strg = str(g)
assert strg == "[Op1(*1 -> Op2(x, y), *1, *1)]" \ assert strg == "[Op1(*1 -> Op2(x, y), *1, *1)]" \
...@@ -315,14 +315,14 @@ class TestMergeOptimizer: ...@@ -315,14 +315,14 @@ class TestMergeOptimizer:
def test_deep_merge(self): def test_deep_merge(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op3(op2(x, y), z), op4(op3(op2(x, y), z))) e = op1(op3(op2(x, y), z), op4(op3(op2(x, y), z)))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
assert str(g) == "[Op1(*1 -> Op3(Op2(x, y), z), Op4(*1))]" assert str(g) == "[Op1(*1 -> Op3(Op2(x, y), z), Op4(*1))]"
def test_no_merge(self): def test_no_merge(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op3(op2(x, y)), op3(op2(y, x))) e = op1(op3(op2(x, y)), op3(op2(y, x)))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
assert str(g) == "[Op1(Op3(Op2(x, y)), Op3(Op2(y, x)))]" assert str(g) == "[Op1(Op3(Op2(x, y)), Op3(Op2(y, x)))]"
...@@ -330,7 +330,7 @@ class TestMergeOptimizer: ...@@ -330,7 +330,7 @@ class TestMergeOptimizer:
x, y, z = inputs() x, y, z = inputs()
e1 = op3(op2(x, y)) e1 = op3(op2(x, y))
e2 = op3(op2(x, y)) e2 = op3(op2(x, y))
g = Env([x, y, z], [e1, e2]) g = FunctionGraph([x, y, z], [e1, e2])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
assert str(g) == "[*1 -> Op3(Op2(x, y)), *1]" assert str(g) == "[*1 -> Op3(Op2(x, y)), *1]"
...@@ -339,7 +339,7 @@ class TestMergeOptimizer: ...@@ -339,7 +339,7 @@ class TestMergeOptimizer:
e1 = op1(x, y) e1 = op1(x, y)
e2 = op2(op3(x), y, z) e2 = op2(op3(x), y, z)
e = op1(e1, op4(e2, e1), op1(e2)) e = op1(e1, op4(e2, e1), op1(e2))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = str(g) strg = str(g)
# note: graph.as_string can only produce the following two possibilities, but if # note: graph.as_string can only produce the following two possibilities, but if
...@@ -357,7 +357,7 @@ class TestMergeOptimizer: ...@@ -357,7 +357,7 @@ class TestMergeOptimizer:
e1 = op1(y, z) e1 = op1(y, z)
finally: finally:
config.compute_test_value = ctv_backup config.compute_test_value = ctv_backup
g = Env([x, y, z], [e1]) g = FunctionGraph([x, y, z], [e1])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = str(g) strg = str(g)
assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]' assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]'
...@@ -367,7 +367,7 @@ class TestMergeOptimizer: ...@@ -367,7 +367,7 @@ class TestMergeOptimizer:
x1 = T.matrix('x1') x1 = T.matrix('x1')
x2 = T.matrix('x2') x2 = T.matrix('x2')
e = T.dot(x1, x2) + T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2) e = T.dot(x1, x2) + T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2)
g = Env([x1, x2], [e]) g = FunctionGraph([x1, x2], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str') strg = theano.printing.debugprint(g, file='str')
strref = '''Elemwise{add,no_inplace} [@A] '' 4 strref = '''Elemwise{add,no_inplace} [@A] '' 4
...@@ -391,7 +391,7 @@ class TestMergeOptimizer: ...@@ -391,7 +391,7 @@ class TestMergeOptimizer:
x3 = T.matrix('x3') x3 = T.matrix('x3')
e = T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) +\ e = T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) +\
T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2) T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2)
g = Env([x1, x2, x3], [e]) g = FunctionGraph([x1, x2, x3], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str') strg = theano.printing.debugprint(g, file='str')
strref1 = '''Elemwise{add,no_inplace} [@A] '' 6 strref1 = '''Elemwise{add,no_inplace} [@A] '' 6
...@@ -434,7 +434,7 @@ class TestMergeOptimizer: ...@@ -434,7 +434,7 @@ class TestMergeOptimizer:
x3 = T.matrix('x3') x3 = T.matrix('x3')
e = T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) +\ e = T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) +\
T.dot(x1, T.opt.assert_op(x2, (x2 > x3).all())) T.dot(x1, T.opt.assert_op(x2, (x2 > x3).all()))
g = Env([x1, x2, x3], [e]) g = FunctionGraph([x1, x2, x3], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str') strg = theano.printing.debugprint(g, file='str')
strref = '''Elemwise{add,no_inplace} [@A] '' 7 strref = '''Elemwise{add,no_inplace} [@A] '' 7
...@@ -463,7 +463,7 @@ class TestMergeOptimizer: ...@@ -463,7 +463,7 @@ class TestMergeOptimizer:
x3 = T.matrix('x3') x3 = T.matrix('x3')
e = T.dot(x1, T.opt.assert_op(x2, (x2 > x3).all())) +\ e = T.dot(x1, T.opt.assert_op(x2, (x2 > x3).all())) +\
T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2)
g = Env([x1, x2, x3], [e]) g = FunctionGraph([x1, x2, x3], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str') strg = theano.printing.debugprint(g, file='str')
strref = '''Elemwise{add,no_inplace} [@A] '' 7 strref = '''Elemwise{add,no_inplace} [@A] '' 7
...@@ -491,7 +491,7 @@ class TestEquilibrium(object): ...@@ -491,7 +491,7 @@ class TestEquilibrium(object):
def test_1(self): def test_1(self):
x, y, z = map(MyVariable, 'xyz') x, y, z = map(MyVariable, 'xyz')
e = op3(op4(x, y)) e = op3(op4(x, y))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
# print g # print g
opt = EquilibriumOptimizer( opt = EquilibriumOptimizer(
[PatternSub((op1, 'x', 'y'), (op2, 'x', 'y')), [PatternSub((op1, 'x', 'y'), (op2, 'x', 'y')),
...@@ -506,7 +506,7 @@ class TestEquilibrium(object): ...@@ -506,7 +506,7 @@ class TestEquilibrium(object):
def test_2(self): def test_2(self):
x, y, z = map(MyVariable, 'xyz') x, y, z = map(MyVariable, 'xyz')
e = op1(op1(op3(x, y))) e = op1(op1(op3(x, y)))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
# print g # print g
opt = EquilibriumOptimizer( opt = EquilibriumOptimizer(
[PatternSub((op1, (op2, 'x', 'y')), (op4, 'x', 'y')), [PatternSub((op1, (op2, 'x', 'y')), (op4, 'x', 'y')),
...@@ -522,7 +522,7 @@ class TestEquilibrium(object): ...@@ -522,7 +522,7 @@ class TestEquilibrium(object):
def test_low_use_ratio(self): def test_low_use_ratio(self):
x, y, z = map(MyVariable, 'xyz') x, y, z = map(MyVariable, 'xyz')
e = op3(op4(x, y)) e = op3(op4(x, y))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
# print 'before', g # print 'before', g
# display pesky warnings along with stdout # display pesky warnings along with stdout
# also silence logger for 'theano.gof.opt' # also silence logger for 'theano.gof.opt'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论