提交 0172bb18 authored 作者: Ziye Fan's avatar Ziye Fan

MergeOptimizer: merge nodes whose inputs are wrapped in Assert

上级 2ddaca06
...@@ -503,7 +503,7 @@ class MergeFeature(object): ...@@ -503,7 +503,7 @@ class MergeFeature(object):
# we adopt convention to keep the last name # we adopt convention to keep the last name
if c.name: if c.name:
other_c.name = c.name other_c.name = c.name
self.scheduled.append([[(c, other_c)]]) self.scheduled.append([[(c, other_c, 'merge')]])
else: else:
# this is a new constant # this is a new constant
self.const_sig[c] = sig self.const_sig[c] = sig
...@@ -515,6 +515,8 @@ class MergeFeature(object): ...@@ -515,6 +515,8 @@ class MergeFeature(object):
if node in self.nodes_seen: if node in self.nodes_seen:
return return
node_has_assert = False
# These asserts ensure that the fgraph has set the clients field # These asserts ensure that the fgraph has set the clients field
# properly. # properly.
# The clients should at least contain `node` itself! # The clients should at least contain `node` itself!
...@@ -523,6 +525,15 @@ class MergeFeature(object): ...@@ -523,6 +525,15 @@ class MergeFeature(object):
assert (node, 0) in node.inputs[0].clients assert (node, 0) in node.inputs[0].clients
merge_candidates = [c for (c, i) in node.inputs[0].clients merge_candidates = [c for (c, i) in node.inputs[0].clients
if c in self.nodes_seen] if c in self.nodes_seen]
# Put all clients of Assert inputs (if exist) into merge_candidates
for i in node.inputs:
if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert):
node_has_assert = True
assert_clients = [c for (c, _) in i.owner.inputs[0].clients
if c in self.nodes_seen]
merge_candidates.extend(assert_clients)
else: else:
merge_candidates = [] merge_candidates = []
...@@ -533,19 +544,64 @@ class MergeFeature(object): ...@@ -533,19 +544,64 @@ class MergeFeature(object):
if len(node.inputs) != len(candidate.inputs): if len(node.inputs) != len(candidate.inputs):
continue continue
cand_has_assert = False
# Get input list of the candidate with assert removed
cand_inputs_assert_removed = []
for i in candidate.inputs:
if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert):
cand_has_assert = True
cand_inputs_assert_removed.append(i.owner.inputs[0])
else:
cand_inputs_assert_removed.append(i)
# if the current node has assert input, it should not be
# replaced with a candidate node which has no assert input
if node_has_assert and not cand_has_assert:
continue
# Get input list of the node with assert removed
if node_has_assert:
node_inputs_assert_removed = []
for i in node.inputs:
if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert):
node_inputs_assert_removed.append(i.owner.inputs[0])
else:
node_inputs_assert_removed.append(i)
else:
node_inputs_assert_removed = node.inputs
inputs_match = all(node_in is cand_in inputs_match = all(node_in is cand_in
for node_in, cand_in in zip(node.inputs, for node_in, cand_in
candidate.inputs)) in zip(node_inputs_assert_removed,
cand_inputs_assert_removed))
if inputs_match and node.op == candidate.op: if inputs_match and node.op == candidate.op:
if (node, candidate) in self.blacklist: if (node, candidate) in self.blacklist:
# They were already tried, and there was an error # They were already tried, and there was an error
continue continue
# Schedule transfer of clients from node to candidate # replace node with candidate
pairs = list(zip(node.outputs, candidate.outputs)) if not (node_has_assert or cand_has_assert):
# Schedule transfer of clients from node to candidate
pairs = list(zip(node.outputs,
candidate.outputs,
['merge'] * len(node.outputs)))
else:
new_inputs = self.get_merged_assert_input(node, candidate)
new_node = node.op(*new_inputs)
pairs = list(zip(node.outputs,
new_node.owner.outputs,
['new_node'] * len(node.outputs))) +\
list(zip(candidate.outputs,
new_node.owner.outputs,
['new_node'] * len(node.outputs)))
# transfer names # transfer names
for node_output, cand_output in pairs: for pair in pairs:
node_output, cand_output = pair[:2]
# clobber old name with new one # clobber old name with new one
# it's arbitrary... one of the names has to go # it's arbitrary... one of the names has to go
if node_output.name: if node_output.name:
...@@ -558,6 +614,37 @@ class MergeFeature(object): ...@@ -558,6 +614,37 @@ class MergeFeature(object):
else: else:
self.nodes_seen.add(node) self.nodes_seen.add(node)
def get_merged_assert_input(self, node, candidate):
    """Build the inputs of the node that replaces `node` and `candidate`.

    `node` and `candidate` are walked in lockstep, input by input.  For
    each position the result holds the most constrained of the two
    inputs:

    * `node`'s input is not an Assert output: `candidate`'s input is
      used as is (whether or not it is an Assert output).
    * only `node`'s input is an Assert output: `node`'s input is used.
    * both are Assert outputs: the result carries the conditions of both.
      If `candidate`'s Assert already holds every condition of `node`'s
      Assert, it is reused; otherwise a new Assert is built on the
      guarded value with the union of the conditions.

    Parameters
    ----------
    node, candidate
        Apply nodes with the same number of inputs.  The caller visible
        in this file only reaches this method after checking that the
        inputs are pairwise identical once Asserts are stripped.

    Returns
    -------
    list
        One variable per input position, suitable for `node.op(*result)`.

    """
    def _is_assert(var):
        # A variable is "an assert" when it is produced by an Assert op.
        return bool(var.owner and
                    isinstance(var.owner.op, theano.tensor.opt.Assert))

    new_inputs = []
    for node_i, cand_i in zip(node.inputs, candidate.inputs):
        if not _is_assert(node_i):
            # Nothing to preserve from node_i; cand_i is at least as
            # constrained.
            new_inputs.append(cand_i)
            continue
        if not _is_assert(cand_i):
            # Only node_i carries a check: keep it.
            new_inputs.append(node_i)
            continue

        # Both inputs are asserts.  inputs[0] is the guarded value and
        # inputs[1:] are the conditions; every condition has to survive
        # the merge, not only the first one.
        cand_conds = list(cand_i.owner.inputs[1:])
        missing = []
        for cond in node_i.owner.inputs[1:]:
            # Compare by identity: two distinct variables without an
            # owner (graph inputs, constants) are different conditions
            # even though their `owner` is None in both cases.
            if not any(cond is known for known in cand_conds + missing):
                missing.append(cond)

        if not missing:
            # cand_i already checks everything node_i checks.
            new_inputs.append(cand_i)
        else:
            # The new input must be the guarded value wrapped in an
            # Assert, not the boolean condition itself.  node_i's op is
            # reused, so its message is the one kept on the merged
            # assert.
            new_inputs.append(
                node_i.owner.op(node_i.owner.inputs[0],
                                *(cand_conds + missing)))

    return new_inputs
class MergeOptimizer(Optimizer): class MergeOptimizer(Optimizer):
""" """
...@@ -594,7 +681,7 @@ class MergeOptimizer(Optimizer): ...@@ -594,7 +681,7 @@ class MergeOptimizer(Optimizer):
while sched: while sched:
pairs_list = sched.pop() pairs_list = sched.pop()
success = True success = True
for pairs in pairs_list: for pairs_ in pairs_list:
# We must check again the equivalence, as the graph # We must check again the equivalence, as the graph
# can have changed. If so, doing the replacement can # can have changed. If so, doing the replacement can
# introduce node that depend on itself. Doing the # introduce node that depend on itself. Doing the
...@@ -603,17 +690,49 @@ class MergeOptimizer(Optimizer): ...@@ -603,17 +690,49 @@ class MergeOptimizer(Optimizer):
# doing the full cycle check. The full cycle check is # doing the full cycle check. The full cycle check is
# skipped by validate() if the graph don't contain # skipped by validate() if the graph don't contain
# destroyers. # destroyers.
var = pairs[0][0] var, candidate, merge_mode = pairs_[0]
candidate = pairs[0][1] if merge_mode == "new_node" and hasattr(var, 'fgraph'):
if (not hasattr(var, 'fgraph') or pass
not hasattr(candidate, 'fgraph')): elif (not hasattr(var, 'fgraph') or
not hasattr(candidate, 'fgraph')):
continue continue
# Keep len(item) == 2 for item in pairs
pairs = [pair[:2] for pair in pairs_]
if var.owner and candidate.owner: if var.owner and candidate.owner:
node = var.owner node = var.owner
candidate = candidate.owner candidate = candidate.owner
inputs_match = all(node_in is cand_in
for node_in, cand_in in zip( # Get input list of the candidate node with assert
node.inputs, candidate.inputs)) # nodes removed
cand_inputs_assert_removed = []
for i in candidate.inputs:
if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert):
cand_inputs_assert_removed.append(
i.owner.inputs[0])
else:
cand_inputs_assert_removed.append(i)
# Get input list of the node with assert nodes removed
node_inputs_assert_removed = []
for i in node.inputs:
if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert):
node_inputs_assert_removed.append(
i.owner.inputs[0])
else:
node_inputs_assert_removed.append(i)
if merge_mode == "new_node":
inputs_match = True
else:
inputs_match = all(node_in is cand_in
for node_in, cand_in in
zip(node_inputs_assert_removed,
cand_inputs_assert_removed))
# No need to compare the op again, as it doesn't change. # No need to compare the op again, as it doesn't change.
if not inputs_match: if not inputs_match:
continue continue
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论