提交 0172bb18 作者: Ziye Fan

MergeOptimizer: merge nodes that have Assert inputs

上级 2ddaca06
......@@ -503,7 +503,7 @@ class MergeFeature(object):
# we adopt convention to keep the last name
if c.name:
other_c.name = c.name
self.scheduled.append([[(c, other_c)]])
self.scheduled.append([[(c, other_c, 'merge')]])
else:
# this is a new constant
self.const_sig[c] = sig
......@@ -515,6 +515,8 @@ class MergeFeature(object):
if node in self.nodes_seen:
return
node_has_assert = False
# These asserts ensure that the fgraph has set the clients field
# properly.
# The clients should at least contain `node` itself!
......@@ -523,6 +525,15 @@ class MergeFeature(object):
assert (node, 0) in node.inputs[0].clients
merge_candidates = [c for (c, i) in node.inputs[0].clients
if c in self.nodes_seen]
# Put all clients of Assert inputs (if exist) into merge_candidates
for i in node.inputs:
if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert):
node_has_assert = True
assert_clients = [c for (c, _) in i.owner.inputs[0].clients
if c in self.nodes_seen]
merge_candidates.extend(assert_clients)
else:
merge_candidates = []
......@@ -533,19 +544,64 @@ class MergeFeature(object):
if len(node.inputs) != len(candidate.inputs):
continue
cand_has_assert = False
# Get input list of the candidate with assert removed
cand_inputs_assert_removed = []
for i in candidate.inputs:
if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert):
cand_has_assert = True
cand_inputs_assert_removed.append(i.owner.inputs[0])
else:
cand_inputs_assert_removed.append(i)
# if the current node has assert input, it should not be
# replaced with a candidate node which has no assert input
if node_has_assert and not cand_has_assert:
continue
# Get input list of the node with assert removed
if node_has_assert:
node_inputs_assert_removed = []
for i in node.inputs:
if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert):
node_inputs_assert_removed.append(i.owner.inputs[0])
else:
node_inputs_assert_removed.append(i)
else:
node_inputs_assert_removed = node.inputs
inputs_match = all(node_in is cand_in
for node_in, cand_in in zip(node.inputs,
candidate.inputs))
for node_in, cand_in
in zip(node_inputs_assert_removed,
cand_inputs_assert_removed))
if inputs_match and node.op == candidate.op:
if (node, candidate) in self.blacklist:
# They were already tried, and there was an error
continue
# Schedule transfer of clients from node to candidate
pairs = list(zip(node.outputs, candidate.outputs))
# replace node with candidate
if not (node_has_assert or cand_has_assert):
# Schedule transfer of clients from node to candidate
pairs = list(zip(node.outputs,
candidate.outputs,
['merge'] * len(node.outputs)))
else:
new_inputs = self.get_merged_assert_input(node, candidate)
new_node = node.op(*new_inputs)
pairs = list(zip(node.outputs,
new_node.owner.outputs,
['new_node'] * len(node.outputs))) +\
list(zip(candidate.outputs,
new_node.owner.outputs,
['new_node'] * len(node.outputs)))
# transfer names
for node_output, cand_output in pairs:
for pair in pairs:
node_output, cand_output = pair[:2]
# clobber old name with new one
# it's arbitrary... one of the names has to go
if node_output.name:
......@@ -558,6 +614,37 @@ class MergeFeature(object):
else:
self.nodes_seen.add(node)
def get_merged_assert_input(self, node, candidate):
    """Build the input list for the node that replaces `node` and `candidate`.

    `node` and `candidate` are walked input by input. For each position:

    * if the input of `node` is not produced by an Assert, the input of
      `candidate` is used;
    * if only the input of `node` is produced by an Assert, the input of
      `node` is used, so the assertion is kept;
    * if both are produced by an Assert, the result is an Assert output
      that carries the conditions of both. The existing output of
      `candidate` is reused when it already checks every condition of
      `node`.

    Parameters
    ----------
    node, candidate
        Apply nodes being merged. The caller is expected to have checked
        that their inputs are identical once Assert wrappers are removed
        (not verified here).

    Returns
    -------
    list
        One variable per input position, in the same order as
        `node.inputs`.
    """
    assert_cls = theano.tensor.opt.Assert

    def produced_by_assert(var):
        # True when `var` is the output of an Assert apply node.
        return var.owner and isinstance(var.owner.op, assert_cls)

    new_inputs = []
    for node_i, cand_i in zip(node.inputs, candidate.inputs):
        if not produced_by_assert(node_i):
            # Nothing to preserve on the node side: take the candidate
            # input, which may itself be an Assert output.
            new_inputs.append(cand_i)
            continue

        if not produced_by_assert(cand_i):
            # Only the node side is asserted: keep that assertion.
            new_inputs.append(node_i)
            continue

        # Both sides are Assert outputs. inputs[0] is the asserted value,
        # inputs[1:] are the conditions (an Assert may have several).
        cand_conds = list(cand_i.owner.inputs[1:])
        node_conds = list(node_i.owner.inputs[1:])

        # Compare conditions by identity of the variables themselves.
        # Comparing owners would treat two distinct ownerless conditions
        # (owner is None) as equal and silently drop one of them.
        missing = []
        for cond in node_conds:
            known = any(cond is seen for seen in cand_conds + missing)
            if not known:
                missing.append(cond)

        if not missing:
            # The candidate already checks everything the node checks.
            new_inputs.append(cand_i)
        else:
            # Re-assert the *value* under the union of the conditions.
            # The merged op must receive the asserted value as input,
            # not a boolean combination of the conditions.
            merged_conds = cand_conds + missing
            new_inputs.append(
                cand_i.owner.op(cand_i.owner.inputs[0], *merged_conds))

    return new_inputs
class MergeOptimizer(Optimizer):
"""
......@@ -594,7 +681,7 @@ class MergeOptimizer(Optimizer):
while sched:
pairs_list = sched.pop()
success = True
for pairs in pairs_list:
for pairs_ in pairs_list:
# We must check again the equivalence, as the graph
# can have changed. If so, doing the replacement can
# introduce a node that depends on itself. Doing the
......@@ -603,17 +690,49 @@ class MergeOptimizer(Optimizer):
# doing the full cycle check. The full cycle check is
# skipped by validate() if the graph doesn't contain
# destroyers.
var = pairs[0][0]
candidate = pairs[0][1]
if (not hasattr(var, 'fgraph') or
not hasattr(candidate, 'fgraph')):
var, candidate, merge_mode = pairs_[0]
if merge_mode == "new_node" and hasattr(var, 'fgraph'):
pass
elif (not hasattr(var, 'fgraph') or
not hasattr(candidate, 'fgraph')):
continue
# Keep len(item) == 2 for item in pairs
pairs = [pair[:2] for pair in pairs_]
if var.owner and candidate.owner:
node = var.owner
candidate = candidate.owner
inputs_match = all(node_in is cand_in
for node_in, cand_in in zip(
node.inputs, candidate.inputs))
# Get input list of the candidate node with assert
# nodes removed
cand_inputs_assert_removed = []
for i in candidate.inputs:
if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert):
cand_inputs_assert_removed.append(
i.owner.inputs[0])
else:
cand_inputs_assert_removed.append(i)
# Get input list of the node with assert nodes removed
node_inputs_assert_removed = []
for i in node.inputs:
if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert):
node_inputs_assert_removed.append(
i.owner.inputs[0])
else:
node_inputs_assert_removed.append(i)
if merge_mode == "new_node":
inputs_match = True
else:
inputs_match = all(node_in is cand_in
for node_in, cand_in in
zip(node_inputs_assert_removed,
cand_inputs_assert_removed))
# No need to compare the op again, as it doesn't change.
if not inputs_match:
continue
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论