提交 1e3505b4 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add ability to merge apply nodes without inputs

上级 94c6aff4
...@@ -484,9 +484,11 @@ class MergeFeature(object): ...@@ -484,9 +484,11 @@ class MergeFeature(object):
# signature -> variable (for constants) # signature -> variable (for constants)
self.const_sig_inv = _metadict() self.const_sig_inv = _metadict()
# For all variables # For all Apply nodes
# Set of distinct (not mergeable) nodes # Set of distinct (not mergeable) nodes
self.nodes_seen = set() self.nodes_seen = set()
# Ordered set of distinct (not mergeable) nodes without any input
self.noinput_nodes = OrderedSet()
# Each element of scheduled is a list of list of (out, new_out) pairs. # Each element of scheduled is a list of list of (out, new_out) pairs.
# Each list of pairs represent the substitution needed to replace all # Each list of pairs represent the substitution needed to replace all
...@@ -514,6 +516,10 @@ class MergeFeature(object): ...@@ -514,6 +516,10 @@ class MergeFeature(object):
self.nodes_seen.discard(node) self.nodes_seen.discard(node)
self.process_node(fgraph, node) self.process_node(fgraph, node)
# Since we are in on_change_input, node should have inputs.
if not isinstance(node, string_types):
assert node.inputs
if isinstance(new_r, graph.Constant): if isinstance(new_r, graph.Constant):
self.process_constant(fgraph, new_r) self.process_constant(fgraph, new_r)
...@@ -526,6 +532,8 @@ class MergeFeature(object): ...@@ -526,6 +532,8 @@ class MergeFeature(object):
def on_prune(self, fgraph, node, reason): def on_prune(self, fgraph, node, reason):
self.nodes_seen.discard(node) self.nodes_seen.discard(node)
if not node.inputs:
self.noinput_nodes.discard(node)
for c in node.inputs: for c in node.inputs:
if isinstance(c, graph.Constant) and (len(c.clients) <= 1): if isinstance(c, graph.Constant) and (len(c.clients) <= 1):
# This was the last node using this constant # This was the last node using this constant
...@@ -592,7 +600,10 @@ class MergeFeature(object): ...@@ -592,7 +600,10 @@ class MergeFeature(object):
merge_candidates.extend(assert_clients) merge_candidates.extend(assert_clients)
else: else:
merge_candidates = [] # If two nodes have no input, but perform the same operation,
# they are not always constant-folded, so we want to merge them.
# In that case, the candidates are all the nodes without inputs.
merge_candidates = self.noinput_nodes
replacement_candidates = [] replacement_candidates = []
for candidate in merge_candidates: for candidate in merge_candidates:
...@@ -672,6 +683,8 @@ class MergeFeature(object): ...@@ -672,6 +683,8 @@ class MergeFeature(object):
self.scheduled.append(replacement_candidates) self.scheduled.append(replacement_candidates)
else: else:
self.nodes_seen.add(node) self.nodes_seen.add(node)
if not node.inputs:
self.noinput_nodes.add(node)
def get_merged_assert_input(self, node, candidate): def get_merged_assert_input(self, node, candidate):
new_inputs = [] new_inputs = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论