Commit 42b861a0, authored by Frédéric Bastien

Merge pull request #3379 from lamblin/fix_pooldesc_merge

Enable merging of GpuDnnPoolDesc
......@@ -484,9 +484,11 @@ class MergeFeature(object):
# signature -> variable (for constants)
self.const_sig_inv = _metadict()
# For all variables
# For all Apply nodes
# Set of distinct (not mergeable) nodes
self.nodes_seen = set()
# Ordered set of distinct (not mergeable) nodes without any input
self.noinput_nodes = OrderedSet()
# Each element of scheduled is a list of list of (out, new_out) pairs.
# Each list of pairs represent the substitution needed to replace all
......@@ -514,6 +516,10 @@ class MergeFeature(object):
self.nodes_seen.discard(node)
self.process_node(fgraph, node)
# Since we are in on_change_input, node should have inputs.
if not isinstance(node, string_types):
assert node.inputs
if isinstance(new_r, graph.Constant):
self.process_constant(fgraph, new_r)
......@@ -526,6 +532,8 @@ class MergeFeature(object):
def on_prune(self, fgraph, node, reason):
self.nodes_seen.discard(node)
if not node.inputs:
self.noinput_nodes.discard(node)
for c in node.inputs:
if isinstance(c, graph.Constant) and (len(c.clients) <= 1):
# This was the last node using this constant
......@@ -592,7 +600,10 @@ class MergeFeature(object):
merge_candidates.extend(assert_clients)
else:
merge_candidates = []
# If two nodes have no input, but perform the same operation,
# they are not always constant-folded, so we want to merge them.
# In that case, the candidates are all the nodes without inputs.
merge_candidates = self.noinput_nodes
replacement_candidates = []
for candidate in merge_candidates:
......@@ -672,6 +683,8 @@ class MergeFeature(object):
self.scheduled.append(replacement_candidates)
else:
self.nodes_seen.add(node)
if not node.inputs:
self.noinput_nodes.add(node)
def get_merged_assert_input(self, node, candidate):
new_inputs = []
......@@ -2217,7 +2230,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
process_count = {}
for o in (opt.global_optimizers +
list(opt.get_local_optimizers()) +
opt.final_optimizers):
list(opt.final_optimizers)):
process_count.setdefault(o, 0)
for count in loop_process_count:
for o, v in iteritems(count):
......@@ -2246,7 +2259,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
# Skip opt that have 0 times, they probably wasn't even tried.
print(blanc + " ", ' %.3fs - %s' % (t, o), file=stream)
print(file=stream)
gf_opts = [o for o in opt.global_optimizers + opt.final_optimizers
gf_opts = [o for o in (opt.global_optimizers +
list(opt.final_optimizers))
if o.print_profile.func_code is not
Optimizer.print_profile.func_code]
if not gf_opts:
......
......@@ -68,6 +68,19 @@ def test_dnn_conv_desc_merge():
assert d1 == d2
def test_dnn_pool_desc_merge():
    """Check that two identical GpuDnnPoolDesc apply nodes are merged.

    Builds two structurally identical dnn_pool ops on the same input and
    asserts that, after graph optimization, only a single pooling
    descriptor node remains in the compiled function's graph.
    """
    if not cuda.dnn.dnn_available():
        raise SkipTest(cuda.dnn.dnn_available.msg)
    inp = theano.tensor.ftensor4('x')
    # Two identical pooling ops: the merge optimizer should collapse
    # their (input-less) GpuDnnPoolDesc nodes into one.
    pool_a = dnn.dnn_pool(inp, (2, 2))
    pool_b = dnn.dnn_pool(inp, (2, 2))
    fn = theano.function([inp], [pool_a, pool_b])
    desc_nodes = [node for node in fn.maker.fgraph.apply_nodes
                  if isinstance(node.op, dnn.GpuDnnPoolDesc)]
    # Dump the whole graph in the failure message for easier debugging.
    assert len(desc_nodes) == 1, fn.maker.fgraph
def test_dnn_conv_merge():
"""This test that we merge correctly multiple dnn_conv.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论