提交 42b861a0 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3379 from lamblin/fix_pooldesc_merge

Enable merging of GpuDnnPoolDesc
...@@ -484,9 +484,11 @@ class MergeFeature(object): ...@@ -484,9 +484,11 @@ class MergeFeature(object):
# signature -> variable (for constants) # signature -> variable (for constants)
self.const_sig_inv = _metadict() self.const_sig_inv = _metadict()
# For all variables # For all Apply nodes
# Set of distinct (not mergeable) nodes # Set of distinct (not mergeable) nodes
self.nodes_seen = set() self.nodes_seen = set()
# Ordered set of distinct (not mergeable) nodes without any input
self.noinput_nodes = OrderedSet()
# Each element of scheduled is a list of list of (out, new_out) pairs. # Each element of scheduled is a list of list of (out, new_out) pairs.
# Each list of pairs represent the substitution needed to replace all # Each list of pairs represent the substitution needed to replace all
...@@ -514,6 +516,10 @@ class MergeFeature(object): ...@@ -514,6 +516,10 @@ class MergeFeature(object):
self.nodes_seen.discard(node) self.nodes_seen.discard(node)
self.process_node(fgraph, node) self.process_node(fgraph, node)
# Since we are in on_change_input, node should have inputs.
if not isinstance(node, string_types):
assert node.inputs
if isinstance(new_r, graph.Constant): if isinstance(new_r, graph.Constant):
self.process_constant(fgraph, new_r) self.process_constant(fgraph, new_r)
...@@ -526,6 +532,8 @@ class MergeFeature(object): ...@@ -526,6 +532,8 @@ class MergeFeature(object):
def on_prune(self, fgraph, node, reason): def on_prune(self, fgraph, node, reason):
self.nodes_seen.discard(node) self.nodes_seen.discard(node)
if not node.inputs:
self.noinput_nodes.discard(node)
for c in node.inputs: for c in node.inputs:
if isinstance(c, graph.Constant) and (len(c.clients) <= 1): if isinstance(c, graph.Constant) and (len(c.clients) <= 1):
# This was the last node using this constant # This was the last node using this constant
...@@ -592,7 +600,10 @@ class MergeFeature(object): ...@@ -592,7 +600,10 @@ class MergeFeature(object):
merge_candidates.extend(assert_clients) merge_candidates.extend(assert_clients)
else: else:
merge_candidates = [] # If two nodes have no input, but perform the same operation,
# they are not always constant-folded, so we want to merge them.
# In that case, the candidates are all the nodes without inputs.
merge_candidates = self.noinput_nodes
replacement_candidates = [] replacement_candidates = []
for candidate in merge_candidates: for candidate in merge_candidates:
...@@ -672,6 +683,8 @@ class MergeFeature(object): ...@@ -672,6 +683,8 @@ class MergeFeature(object):
self.scheduled.append(replacement_candidates) self.scheduled.append(replacement_candidates)
else: else:
self.nodes_seen.add(node) self.nodes_seen.add(node)
if not node.inputs:
self.noinput_nodes.add(node)
def get_merged_assert_input(self, node, candidate): def get_merged_assert_input(self, node, candidate):
new_inputs = [] new_inputs = []
...@@ -2217,7 +2230,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2217,7 +2230,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
process_count = {} process_count = {}
for o in (opt.global_optimizers + for o in (opt.global_optimizers +
list(opt.get_local_optimizers()) + list(opt.get_local_optimizers()) +
opt.final_optimizers): list(opt.final_optimizers)):
process_count.setdefault(o, 0) process_count.setdefault(o, 0)
for count in loop_process_count: for count in loop_process_count:
for o, v in iteritems(count): for o, v in iteritems(count):
...@@ -2246,7 +2259,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2246,7 +2259,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
# Skip opt that have 0 times, they probably wasn't even tried. # Skip opt that have 0 times, they probably wasn't even tried.
print(blanc + " ", ' %.3fs - %s' % (t, o), file=stream) print(blanc + " ", ' %.3fs - %s' % (t, o), file=stream)
print(file=stream) print(file=stream)
gf_opts = [o for o in opt.global_optimizers + opt.final_optimizers gf_opts = [o for o in (opt.global_optimizers +
list(opt.final_optimizers))
if o.print_profile.func_code is not if o.print_profile.func_code is not
Optimizer.print_profile.func_code] Optimizer.print_profile.func_code]
if not gf_opts: if not gf_opts:
......
...@@ -3,7 +3,7 @@ from theano.gof.type import Type ...@@ -3,7 +3,7 @@ from theano.gof.type import Type
from theano.gof.graph import Variable, Apply, Constant from theano.gof.graph import Variable, Apply, Constant
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.opt import * # noqa from theano.gof.opt import * # noqa
from theano.gof.fg import FunctionGraph as Env from theano.gof.fg import FunctionGraph
from theano.gof.toolbox import * # noqa from theano.gof.toolbox import * # noqa
from theano import tensor as T from theano import tensor as T
...@@ -100,7 +100,7 @@ class TestPatternOptimizer: ...@@ -100,7 +100,7 @@ class TestPatternOptimizer:
# replacing the whole graph # replacing the whole graph
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), z) e = op1(op2(x, y), z)
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op2, '1', '2'), '3'), PatternOptimizer((op1, (op2, '1', '2'), '3'),
(op4, '3', '2')).optimize(g) (op4, '3', '2')).optimize(g)
assert str(g) == "[Op4(z, y)]" assert str(g) == "[Op4(z, y)]"
...@@ -108,7 +108,7 @@ class TestPatternOptimizer: ...@@ -108,7 +108,7 @@ class TestPatternOptimizer:
def test_nested_out_pattern(self): def test_nested_out_pattern(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(x, y) e = op1(x, y)
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, '1', '2'), PatternOptimizer((op1, '1', '2'),
(op4, (op1, '1'), (op2, '2'), (op3, '1', '2'))).optimize(g) (op4, (op1, '1'), (op2, '2'), (op3, '1', '2'))).optimize(g)
assert str(g) == "[Op4(Op1(x), Op2(y), Op3(x, y))]" assert str(g) == "[Op4(Op1(x), Op2(y), Op3(x, y))]"
...@@ -116,7 +116,7 @@ class TestPatternOptimizer: ...@@ -116,7 +116,7 @@ class TestPatternOptimizer:
def test_unification_1(self): def test_unification_1(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, x), z) # the arguments to op2 are the same e = op1(op2(x, x), z) # the arguments to op2 are the same
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op2, '1', '1'), '2'), # they are the same in the pattern PatternOptimizer((op1, (op2, '1', '1'), '2'), # they are the same in the pattern
(op4, '2', '1')).optimize(g) (op4, '2', '1')).optimize(g)
# So the replacement should occur # So the replacement should occur
...@@ -125,7 +125,7 @@ class TestPatternOptimizer: ...@@ -125,7 +125,7 @@ class TestPatternOptimizer:
def test_unification_2(self): def test_unification_2(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), z) # the arguments to op2 are different e = op1(op2(x, y), z) # the arguments to op2 are different
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op2, '1', '1'), '2'), # they are the same in the pattern PatternOptimizer((op1, (op2, '1', '1'), '2'), # they are the same in the pattern
(op4, '2', '1')).optimize(g) (op4, '2', '1')).optimize(g)
# The replacement should NOT occur # The replacement should NOT occur
...@@ -135,7 +135,7 @@ class TestPatternOptimizer: ...@@ -135,7 +135,7 @@ class TestPatternOptimizer:
# replacing inside the graph # replacing inside the graph
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), z) e = op1(op2(x, y), z)
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op2, '1', '2'), PatternOptimizer((op2, '1', '2'),
(op1, '2', '1')).optimize(g) (op1, '2', '1')).optimize(g)
assert str(g) == "[Op1(Op1(y, x), z)]" assert str(g) == "[Op1(Op1(y, x), z)]"
...@@ -146,7 +146,7 @@ class TestPatternOptimizer: ...@@ -146,7 +146,7 @@ class TestPatternOptimizer:
# it should do the replacement and stop # it should do the replacement and stop
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), z) e = op1(op2(x, y), z)
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op2, '1', '2'), PatternOptimizer((op2, '1', '2'),
(op2, '2', '1'), ign=True).optimize(g) (op2, '2', '1'), ign=True).optimize(g)
assert str(g) == "[Op1(Op2(y, x), z)]" assert str(g) == "[Op1(Op2(y, x), z)]"
...@@ -155,7 +155,7 @@ class TestPatternOptimizer: ...@@ -155,7 +155,7 @@ class TestPatternOptimizer:
# it should replace all occurrences of the pattern # it should replace all occurrences of the pattern
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), op2(x, y), op2(y, z)) e = op1(op2(x, y), op2(x, y), op2(y, z))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op2, '1', '2'), PatternOptimizer((op2, '1', '2'),
(op4, '1')).optimize(g) (op4, '1')).optimize(g)
assert str(g) == "[Op1(Op4(x), Op4(x), Op4(y))]" assert str(g) == "[Op1(Op4(x), Op4(x), Op4(y))]"
...@@ -165,7 +165,7 @@ class TestPatternOptimizer: ...@@ -165,7 +165,7 @@ class TestPatternOptimizer:
# should work # should work
x, y, z = inputs() x, y, z = inputs()
e = op1(op1(op1(op1(x)))) e = op1(op1(op1(op1(x))))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op1, '1')), PatternOptimizer((op1, (op1, '1')),
'1').optimize(g) '1').optimize(g)
assert str(g) == "[x]" assert str(g) == "[x]"
...@@ -173,7 +173,7 @@ class TestPatternOptimizer: ...@@ -173,7 +173,7 @@ class TestPatternOptimizer:
def test_nested_odd(self): def test_nested_odd(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op1(op1(op1(op1(x))))) e = op1(op1(op1(op1(op1(x)))))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op1, '1')), PatternOptimizer((op1, (op1, '1')),
'1').optimize(g) '1').optimize(g)
assert str(g) == "[Op1(x)]" assert str(g) == "[Op1(x)]"
...@@ -181,7 +181,7 @@ class TestPatternOptimizer: ...@@ -181,7 +181,7 @@ class TestPatternOptimizer:
def test_expand(self): def test_expand(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op1(op1(x))) e = op1(op1(op1(x)))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, '1'), PatternOptimizer((op1, '1'),
(op2, (op1, '1')), ign=True).optimize(g) (op2, (op1, '1')), ign=True).optimize(g)
assert str(g) == "[Op2(Op1(Op2(Op1(Op2(Op1(x))))))]" assert str(g) == "[Op2(Op1(Op2(Op1(Op2(Op1(x))))))]"
...@@ -192,7 +192,7 @@ class TestPatternOptimizer: ...@@ -192,7 +192,7 @@ class TestPatternOptimizer:
# = True or with other NavigatorOptimizers may differ. # = True or with other NavigatorOptimizers may differ.
x, y, z = inputs() x, y, z = inputs()
e = op1(op1(op1(op1(op1(x))))) e = op1(op1(op1(op1(op1(x)))))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
TopoPatternOptimizer((op1, (op1, '1')), TopoPatternOptimizer((op1, (op1, '1')),
(op1, '1'), ign=False).optimize(g) (op1, '1'), ign=False).optimize(g)
assert str(g) == "[Op1(x)]" assert str(g) == "[Op1(x)]"
...@@ -202,7 +202,7 @@ class TestPatternOptimizer: ...@@ -202,7 +202,7 @@ class TestPatternOptimizer:
y = MyVariable('y') y = MyVariable('y')
z = Constant(MyType(), 2, name='z') z = Constant(MyType(), 2, name='z')
e = op1(op1(x, y), y) e = op1(op1(x, y), y)
g = Env([y], [e]) g = FunctionGraph([y], [e])
PatternOptimizer((op1, z, '1'), PatternOptimizer((op1, z, '1'),
(op2, '1', z)).optimize(g) (op2, '1', z)).optimize(g)
assert str(g) == "[Op1(Op2(y, z), y)]" assert str(g) == "[Op1(Op2(y, z), y)]"
...@@ -210,7 +210,7 @@ class TestPatternOptimizer: ...@@ -210,7 +210,7 @@ class TestPatternOptimizer:
def test_constraints(self): def test_constraints(self):
x, y, z = inputs() x, y, z = inputs()
e = op4(op1(op2(x, y)), op1(op1(x, y))) e = op4(op1(op2(x, y)), op1(op1(x, y)))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
def constraint(r): def constraint(r):
# Only replacing if the input is an instance of Op2 # Only replacing if the input is an instance of Op2
...@@ -223,7 +223,7 @@ class TestPatternOptimizer: ...@@ -223,7 +223,7 @@ class TestPatternOptimizer:
def test_match_same(self): def test_match_same(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(x, x) e = op1(x, x)
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, 'x', 'y'), PatternOptimizer((op1, 'x', 'y'),
(op3, 'x', 'y')).optimize(g) (op3, 'x', 'y')).optimize(g)
assert str(g) == "[Op3(x, x)]" assert str(g) == "[Op3(x, x)]"
...@@ -231,7 +231,7 @@ class TestPatternOptimizer: ...@@ -231,7 +231,7 @@ class TestPatternOptimizer:
def test_match_same_illegal(self): def test_match_same_illegal(self):
x, y, z = inputs() x, y, z = inputs()
e = op2(op1(x, x), op1(x, y)) e = op2(op1(x, x), op1(x, y))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
def constraint(r): def constraint(r):
# Only replacing if the input is an instance of Op2 # Only replacing if the input is an instance of Op2
...@@ -245,7 +245,7 @@ class TestPatternOptimizer: ...@@ -245,7 +245,7 @@ class TestPatternOptimizer:
x, y, z = inputs() x, y, z = inputs()
e0 = op1(x, y) e0 = op1(x, y)
e = op3(op4(e0), e0) e = op3(op4(e0), e0)
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op4, (op1, 'x', 'y')), PatternOptimizer((op4, (op1, 'x', 'y')),
(op3, 'x', 'y')).optimize(g) (op3, 'x', 'y')).optimize(g)
assert str(g) == "[Op3(Op4(*1 -> Op1(x, y)), *1)]" assert str(g) == "[Op3(Op4(*1 -> Op1(x, y)), *1)]"
...@@ -254,7 +254,7 @@ class TestPatternOptimizer: ...@@ -254,7 +254,7 @@ class TestPatternOptimizer:
# replacing the whole graph # replacing the whole graph
x, y, z = inputs() x, y, z = inputs()
e = op1(op_y(x, y), z) e = op1(op_y(x, y), z)
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op1, (op_z, '1', '2'), '3'), PatternOptimizer((op1, (op_z, '1', '2'), '3'),
(op4, '3', '2')).optimize(g) (op4, '3', '2')).optimize(g)
str_g = str(g) str_g = str(g)
...@@ -265,7 +265,7 @@ class TestPatternOptimizer: ...@@ -265,7 +265,7 @@ class TestPatternOptimizer:
# x, y, z = inputs() # x, y, z = inputs()
# e0 = op1(x, y) # e0 = op1(x, y)
# e = op4(e0, e0) # e = op4(e0, e0)
# g = Env([x, y, z], [e]) # g = FunctionGraph([x, y, z], [e])
# PatternOptimizer((op4, (op1, 'x', 'y'), (op1, 'x', 'y')), # PatternOptimizer((op4, (op1, 'x', 'y'), (op1, 'x', 'y')),
# (op3, 'x', 'y')).optimize(g) # (op3, 'x', 'y')).optimize(g)
# assert str(g) == "[Op3(x, y)]" # assert str(g) == "[Op3(x, y)]"
...@@ -280,24 +280,37 @@ class TestOpSubOptimizer: ...@@ -280,24 +280,37 @@ class TestOpSubOptimizer:
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op1(op1(op1(op1(x))))) e = op1(op1(op1(op1(op1(x)))))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
OpSubOptimizer(op1, op2).optimize(g) OpSubOptimizer(op1, op2).optimize(g)
assert str(g) == "[Op2(Op2(Op2(Op2(Op2(x)))))]" assert str(g) == "[Op2(Op2(Op2(Op2(Op2(x)))))]"
def test_straightforward_2(self): def test_straightforward_2(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x), op3(y), op4(z)) e = op1(op2(x), op3(y), op4(z))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
OpSubOptimizer(op3, op4).optimize(g) OpSubOptimizer(op3, op4).optimize(g)
assert str(g) == "[Op1(Op2(x), Op4(y), Op4(z))]" assert str(g) == "[Op1(Op2(x), Op4(y), Op4(z))]"
class NoInputOp(Op):
__props__ = ('param',)
def __init__(self, param):
self.param = param
def make_node(self):
return Apply(self, [], [MyType()()])
def perform(self, node, inputs, output_storage):
output_storage[0][0] = self.param
class TestMergeOptimizer: class TestMergeOptimizer:
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), op2(x, y), op2(x, z)) e = op1(op2(x, y), op2(x, y), op2(x, z))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
assert str(g) == "[Op1(*1 -> Op2(x, y), *1, Op2(x, z))]" assert str(g) == "[Op1(*1 -> Op2(x, y), *1, Op2(x, z))]"
...@@ -306,7 +319,7 @@ class TestMergeOptimizer: ...@@ -306,7 +319,7 @@ class TestMergeOptimizer:
y = Constant(MyType(), 2, name='y') y = Constant(MyType(), 2, name='y')
z = Constant(MyType(), 2, name='z') z = Constant(MyType(), 2, name='z')
e = op1(op2(x, y), op2(x, y), op2(x, z)) e = op1(op2(x, y), op2(x, y), op2(x, z))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = str(g) strg = str(g)
assert strg == "[Op1(*1 -> Op2(x, y), *1, *1)]" \ assert strg == "[Op1(*1 -> Op2(x, y), *1, *1)]" \
...@@ -315,14 +328,14 @@ class TestMergeOptimizer: ...@@ -315,14 +328,14 @@ class TestMergeOptimizer:
def test_deep_merge(self): def test_deep_merge(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op3(op2(x, y), z), op4(op3(op2(x, y), z))) e = op1(op3(op2(x, y), z), op4(op3(op2(x, y), z)))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
assert str(g) == "[Op1(*1 -> Op3(Op2(x, y), z), Op4(*1))]" assert str(g) == "[Op1(*1 -> Op3(Op2(x, y), z), Op4(*1))]"
def test_no_merge(self): def test_no_merge(self):
x, y, z = inputs() x, y, z = inputs()
e = op1(op3(op2(x, y)), op3(op2(y, x))) e = op1(op3(op2(x, y)), op3(op2(y, x)))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
assert str(g) == "[Op1(Op3(Op2(x, y)), Op3(Op2(y, x)))]" assert str(g) == "[Op1(Op3(Op2(x, y)), Op3(Op2(y, x)))]"
...@@ -330,7 +343,7 @@ class TestMergeOptimizer: ...@@ -330,7 +343,7 @@ class TestMergeOptimizer:
x, y, z = inputs() x, y, z = inputs()
e1 = op3(op2(x, y)) e1 = op3(op2(x, y))
e2 = op3(op2(x, y)) e2 = op3(op2(x, y))
g = Env([x, y, z], [e1, e2]) g = FunctionGraph([x, y, z], [e1, e2])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
assert str(g) == "[*1 -> Op3(Op2(x, y)), *1]" assert str(g) == "[*1 -> Op3(Op2(x, y)), *1]"
...@@ -339,7 +352,7 @@ class TestMergeOptimizer: ...@@ -339,7 +352,7 @@ class TestMergeOptimizer:
e1 = op1(x, y) e1 = op1(x, y)
e2 = op2(op3(x), y, z) e2 = op2(op3(x), y, z)
e = op1(e1, op4(e2, e1), op1(e2)) e = op1(e1, op4(e2, e1), op1(e2))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = str(g) strg = str(g)
# note: graph.as_string can only produce the following two possibilities, but if # note: graph.as_string can only produce the following two possibilities, but if
...@@ -357,7 +370,7 @@ class TestMergeOptimizer: ...@@ -357,7 +370,7 @@ class TestMergeOptimizer:
e1 = op1(y, z) e1 = op1(y, z)
finally: finally:
config.compute_test_value = ctv_backup config.compute_test_value = ctv_backup
g = Env([x, y, z], [e1]) g = FunctionGraph([x, y, z], [e1])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = str(g) strg = str(g)
assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]' assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]'
...@@ -367,7 +380,7 @@ class TestMergeOptimizer: ...@@ -367,7 +380,7 @@ class TestMergeOptimizer:
x1 = T.matrix('x1') x1 = T.matrix('x1')
x2 = T.matrix('x2') x2 = T.matrix('x2')
e = T.dot(x1, x2) + T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2) e = T.dot(x1, x2) + T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2)
g = Env([x1, x2], [e]) g = FunctionGraph([x1, x2], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str') strg = theano.printing.debugprint(g, file='str')
strref = '''Elemwise{add,no_inplace} [@A] '' 4 strref = '''Elemwise{add,no_inplace} [@A] '' 4
...@@ -391,7 +404,7 @@ class TestMergeOptimizer: ...@@ -391,7 +404,7 @@ class TestMergeOptimizer:
x3 = T.matrix('x3') x3 = T.matrix('x3')
e = T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) +\ e = T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) +\
T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2) T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2)
g = Env([x1, x2, x3], [e]) g = FunctionGraph([x1, x2, x3], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str') strg = theano.printing.debugprint(g, file='str')
strref1 = '''Elemwise{add,no_inplace} [@A] '' 6 strref1 = '''Elemwise{add,no_inplace} [@A] '' 6
...@@ -434,7 +447,7 @@ class TestMergeOptimizer: ...@@ -434,7 +447,7 @@ class TestMergeOptimizer:
x3 = T.matrix('x3') x3 = T.matrix('x3')
e = T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) +\ e = T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) +\
T.dot(x1, T.opt.assert_op(x2, (x2 > x3).all())) T.dot(x1, T.opt.assert_op(x2, (x2 > x3).all()))
g = Env([x1, x2, x3], [e]) g = FunctionGraph([x1, x2, x3], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str') strg = theano.printing.debugprint(g, file='str')
strref = '''Elemwise{add,no_inplace} [@A] '' 7 strref = '''Elemwise{add,no_inplace} [@A] '' 7
...@@ -463,7 +476,7 @@ class TestMergeOptimizer: ...@@ -463,7 +476,7 @@ class TestMergeOptimizer:
x3 = T.matrix('x3') x3 = T.matrix('x3')
e = T.dot(x1, T.opt.assert_op(x2, (x2 > x3).all())) +\ e = T.dot(x1, T.opt.assert_op(x2, (x2 > x3).all())) +\
T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2)
g = Env([x1, x2, x3], [e]) g = FunctionGraph([x1, x2, x3], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str') strg = theano.printing.debugprint(g, file='str')
strref = '''Elemwise{add,no_inplace} [@A] '' 7 strref = '''Elemwise{add,no_inplace} [@A] '' 7
...@@ -485,13 +498,25 @@ class TestMergeOptimizer: ...@@ -485,13 +498,25 @@ class TestMergeOptimizer:
print(strg) print(strg)
assert strg == strref, (strg, strref) assert strg == strref, (strg, strref)
def test_merge_noinput(self):
# Check that identical Apply nodes without inputs will be merged
x = NoInputOp(param=0)()
y = NoInputOp(param=0)()
z = NoInputOp(param=1)()
fg = FunctionGraph([], [x, y, z])
MergeOptimizer().optimize(fg)
no_input_ops = [n for n in fg.apply_nodes
if isinstance(n.op, NoInputOp)]
assert len(no_input_ops) == 2, fg.apply_nodes
class TestEquilibrium(object): class TestEquilibrium(object):
def test_1(self): def test_1(self):
x, y, z = map(MyVariable, 'xyz') x, y, z = map(MyVariable, 'xyz')
e = op3(op4(x, y)) e = op3(op4(x, y))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
# print g # print g
opt = EquilibriumOptimizer( opt = EquilibriumOptimizer(
[PatternSub((op1, 'x', 'y'), (op2, 'x', 'y')), [PatternSub((op1, 'x', 'y'), (op2, 'x', 'y')),
...@@ -506,7 +531,7 @@ class TestEquilibrium(object): ...@@ -506,7 +531,7 @@ class TestEquilibrium(object):
def test_2(self): def test_2(self):
x, y, z = map(MyVariable, 'xyz') x, y, z = map(MyVariable, 'xyz')
e = op1(op1(op3(x, y))) e = op1(op1(op3(x, y)))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
# print g # print g
opt = EquilibriumOptimizer( opt = EquilibriumOptimizer(
[PatternSub((op1, (op2, 'x', 'y')), (op4, 'x', 'y')), [PatternSub((op1, (op2, 'x', 'y')), (op4, 'x', 'y')),
...@@ -522,7 +547,7 @@ class TestEquilibrium(object): ...@@ -522,7 +547,7 @@ class TestEquilibrium(object):
def test_low_use_ratio(self): def test_low_use_ratio(self):
x, y, z = map(MyVariable, 'xyz') x, y, z = map(MyVariable, 'xyz')
e = op3(op4(x, y)) e = op3(op4(x, y))
g = Env([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
# print 'before', g # print 'before', g
# display pesky warnings along with stdout # display pesky warnings along with stdout
# also silence logger for 'theano.gof.opt' # also silence logger for 'theano.gof.opt'
......
...@@ -68,6 +68,19 @@ def test_dnn_conv_desc_merge(): ...@@ -68,6 +68,19 @@ def test_dnn_conv_desc_merge():
assert d1 == d2 assert d1 == d2
def test_dnn_pool_desc_merge():
if not cuda.dnn.dnn_available():
raise SkipTest(cuda.dnn.dnn_available.msg)
x = theano.tensor.ftensor4('x')
y = dnn.dnn_pool(x, (2, 2))
z = dnn.dnn_pool(x, (2, 2))
f = theano.function([x], [y, z])
descs = [n for n in f.maker.fgraph.apply_nodes
if isinstance(n.op, dnn.GpuDnnPoolDesc)]
assert len(descs) == 1, f.maker.fgraph
def test_dnn_conv_merge(): def test_dnn_conv_merge():
"""This test that we merge correctly multiple dnn_conv. """This test that we merge correctly multiple dnn_conv.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论