提交 1e3505b4 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add ability to merge apply nodes without inputs

上级 94c6aff4
......@@ -484,9 +484,11 @@ class MergeFeature(object):
# signature -> variable (for constants)
self.const_sig_inv = _metadict()
# For all variables
# For all Apply nodes
# Set of distinct (not mergeable) nodes
self.nodes_seen = set()
# Ordered set of distinct (not mergeable) nodes without any input
self.noinput_nodes = OrderedSet()
# Each element of scheduled is a list of list of (out, new_out) pairs.
# Each list of pairs represent the substitution needed to replace all
......@@ -514,6 +516,10 @@ class MergeFeature(object):
self.nodes_seen.discard(node)
self.process_node(fgraph, node)
# Since we are in on_change_input, node should have inputs.
if not isinstance(node, string_types):
assert node.inputs
if isinstance(new_r, graph.Constant):
self.process_constant(fgraph, new_r)
......@@ -526,6 +532,8 @@ class MergeFeature(object):
def on_prune(self, fgraph, node, reason):
self.nodes_seen.discard(node)
if not node.inputs:
self.noinput_nodes.discard(node)
for c in node.inputs:
if isinstance(c, graph.Constant) and (len(c.clients) <= 1):
# This was the last node using this constant
......@@ -592,7 +600,10 @@ class MergeFeature(object):
merge_candidates.extend(assert_clients)
else:
merge_candidates = []
# If two nodes have no input, but perform the same operation,
# they are not always constant-folded, so we want to merge them.
# In that case, the candidates are all the nodes without inputs.
merge_candidates = self.noinput_nodes
replacement_candidates = []
for candidate in merge_candidates:
......@@ -672,6 +683,8 @@ class MergeFeature(object):
self.scheduled.append(replacement_candidates)
else:
self.nodes_seen.add(node)
if not node.inputs:
self.noinput_nodes.add(node)
def get_merged_assert_input(self, node, candidate):
new_inputs = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论