提交 c0f62796 authored 作者: abergeron's avatar abergeron

Merge pull request #3232 from t13m/merge_assert

make MergeOptimizer merge nodes with assert input
...@@ -503,7 +503,7 @@ class MergeFeature(object): ...@@ -503,7 +503,7 @@ class MergeFeature(object):
# we adopt convention to keep the last name # we adopt convention to keep the last name
if c.name: if c.name:
other_c.name = c.name other_c.name = c.name
self.scheduled.append([[(c, other_c)]]) self.scheduled.append([[(c, other_c, 'merge')]])
else: else:
# this is a new constant # this is a new constant
self.const_sig[c] = sig self.const_sig[c] = sig
...@@ -515,6 +515,8 @@ class MergeFeature(object): ...@@ -515,6 +515,8 @@ class MergeFeature(object):
if node in self.nodes_seen: if node in self.nodes_seen:
return return
node_has_assert = False
# These asserts ensure that the fgraph has set the clients field # These asserts ensure that the fgraph has set the clients field
# properly. # properly.
# The clients should at least contain `node` itself! # The clients should at least contain `node` itself!
...@@ -523,6 +525,23 @@ class MergeFeature(object): ...@@ -523,6 +525,23 @@ class MergeFeature(object):
assert (node, 0) in node.inputs[0].clients assert (node, 0) in node.inputs[0].clients
merge_candidates = [c for (c, i) in node.inputs[0].clients merge_candidates = [c for (c, i) in node.inputs[0].clients
if c in self.nodes_seen] if c in self.nodes_seen]
# Put all clients of Assert inputs (if exist) into merge_candidates
for i in node.inputs:
if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert):
node_has_assert = True
assert_clients = [c for (c, _) in i.owner.inputs[0].clients
if c in self.nodes_seen]
for idx in range(len(assert_clients)):
client = assert_clients[idx]
if isinstance(i.owner.op, theano.tensor.opt.Assert):
for c in client.outputs[0].clients:
if c[0] in self.nodes_seen:
assert_clients.append(c[0])
merge_candidates.extend(assert_clients)
else: else:
merge_candidates = [] merge_candidates = []
...@@ -533,19 +552,66 @@ class MergeFeature(object): ...@@ -533,19 +552,66 @@ class MergeFeature(object):
if len(node.inputs) != len(candidate.inputs): if len(node.inputs) != len(candidate.inputs):
continue continue
cand_has_assert = False
# Get input list of the candidate with assert removed
cand_inputs_assert_removed = []
for i in candidate.inputs:
if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert):
cand_has_assert = True
cand_inputs_assert_removed.append(i.owner.inputs[0])
else:
cand_inputs_assert_removed.append(i)
# Get input list of the node with assert removed
if node_has_assert:
node_inputs_assert_removed = []
for i in node.inputs:
if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert):
node_inputs_assert_removed.append(i.owner.inputs[0])
else:
node_inputs_assert_removed.append(i)
else:
node_inputs_assert_removed = node.inputs
inputs_match = all(node_in is cand_in inputs_match = all(node_in is cand_in
for node_in, cand_in in zip(node.inputs, for node_in, cand_in
candidate.inputs)) in zip(node_inputs_assert_removed,
cand_inputs_assert_removed))
if inputs_match and node.op == candidate.op: if inputs_match and node.op == candidate.op:
if (node, candidate) in self.blacklist: if (node, candidate) in self.blacklist:
# They were already tried, and there was an error # They were already tried, and there was an error
continue continue
# Schedule transfer of clients from node to candidate # replace node with candidate
pairs = list(zip(node.outputs, candidate.outputs)) if not (node_has_assert or cand_has_assert):
# Schedule transfer of clients from node to candidate
pairs = list(zip(node.outputs,
candidate.outputs,
['merge'] * len(node.outputs)))
# if the current node has assert input, it should not be
# replaced with a candidate node which has no assert input
elif node_has_assert and not cand_has_assert:
pairs = list(zip(candidate.outputs,
node.outputs,
['merge'] * len(node.outputs)))
else:
new_inputs = self.get_merged_assert_input(node, candidate)
new_node = node.op(*new_inputs)
pairs = list(zip(node.outputs,
new_node.owner.outputs,
['new_node'] * len(node.outputs))) +\
list(zip(candidate.outputs,
new_node.owner.outputs,
['new_node'] * len(node.outputs)))
# transfer names # transfer names
for node_output, cand_output in pairs: for pair in pairs:
node_output, cand_output = pair[:2]
# clobber old name with new one # clobber old name with new one
# it's arbitrary... one of the names has to go # it's arbitrary... one of the names has to go
if node_output.name: if node_output.name:
...@@ -558,6 +624,37 @@ class MergeFeature(object): ...@@ -558,6 +624,37 @@ class MergeFeature(object):
else: else:
self.nodes_seen.add(node) self.nodes_seen.add(node)
def get_merged_assert_input(self, node, candidate):
new_inputs = []
for node_i, cand_i in zip(node.inputs, candidate.inputs):
# if node_i is assert
if (node_i.owner and
isinstance(node_i.owner.op,
theano.tensor.opt.Assert)):
# node_i is assert, cand_i is assert
if (cand_i.owner and
isinstance(cand_i.owner.op,
theano.tensor.opt.Assert)):
# Here two assert nodes are merged.
# Step 1. Merge conditions of both assert nodes.
# Step 2. Make the new assert node
node_cond = node_i.owner.inputs[1:]
cand_cond = cand_i.owner.inputs[1:]
new_cond = list(set(node_cond + cand_cond))
new_inputs.append(
theano.tensor.opt.assert_op(
node_i.owner.inputs[0],
*new_cond))
# node_i is assert, cand_i is not assert
else:
new_inputs.append(node_i)
else:
# if node_i is not an assert node, append cand_i
new_inputs.append(cand_i)
return new_inputs
class MergeOptimizer(Optimizer): class MergeOptimizer(Optimizer):
""" """
...@@ -594,7 +691,7 @@ class MergeOptimizer(Optimizer): ...@@ -594,7 +691,7 @@ class MergeOptimizer(Optimizer):
while sched: while sched:
pairs_list = sched.pop() pairs_list = sched.pop()
success = True success = True
for pairs in pairs_list: for pairs_ in pairs_list:
# We must check again the equivalence, as the graph # We must check again the equivalence, as the graph
# can have changed. If so, doing the replacement can # can have changed. If so, doing the replacement can
# introduce node that depend on itself. Doing the # introduce node that depend on itself. Doing the
...@@ -603,17 +700,49 @@ class MergeOptimizer(Optimizer): ...@@ -603,17 +700,49 @@ class MergeOptimizer(Optimizer):
# doing the full cycle check. The full cycle check is # doing the full cycle check. The full cycle check is
# skipped by validate() if the graph don't contain # skipped by validate() if the graph don't contain
# destroyers. # destroyers.
var = pairs[0][0] var, candidate, merge_mode = pairs_[0]
candidate = pairs[0][1] if merge_mode == "new_node" and hasattr(var, 'fgraph'):
if (not hasattr(var, 'fgraph') or pass
not hasattr(candidate, 'fgraph')): elif (not hasattr(var, 'fgraph') or
not hasattr(candidate, 'fgraph')):
continue continue
# Keep len(item) == 2 for item in pairs
pairs = [pair[:2] for pair in pairs_]
if var.owner and candidate.owner: if var.owner and candidate.owner:
node = var.owner node = var.owner
candidate = candidate.owner candidate = candidate.owner
inputs_match = all(node_in is cand_in
for node_in, cand_in in zip( # Get input list of the candidate node with assert
node.inputs, candidate.inputs)) # nodes removed
cand_inputs_assert_removed = []
for i in candidate.inputs:
if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert):
cand_inputs_assert_removed.append(
i.owner.inputs[0])
else:
cand_inputs_assert_removed.append(i)
# Get input list of the node with assert nodes removed
node_inputs_assert_removed = []
for i in node.inputs:
if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert):
node_inputs_assert_removed.append(
i.owner.inputs[0])
else:
node_inputs_assert_removed.append(i)
if merge_mode == "new_node":
inputs_match = True
else:
inputs_match = all(node_in is cand_in
for node_in, cand_in in
zip(node_inputs_assert_removed,
cand_inputs_assert_removed))
# No need to compare the op again, as it don't change. # No need to compare the op again, as it don't change.
if not inputs_match: if not inputs_match:
continue continue
......
...@@ -6,6 +6,9 @@ from theano.gof.opt import * # noqa ...@@ -6,6 +6,9 @@ from theano.gof.opt import * # noqa
from theano.gof.fg import FunctionGraph as Env from theano.gof.fg import FunctionGraph as Env
from theano.gof.toolbox import * # noqa from theano.gof.toolbox import * # noqa
from theano.tensor.opt import Assert
from theano import tensor as T
def as_variable(x): def as_variable(x):
if not isinstance(x, Variable): if not isinstance(x, Variable):
...@@ -360,6 +363,129 @@ class TestMergeOptimizer: ...@@ -360,6 +363,129 @@ class TestMergeOptimizer:
strg = str(g) strg = str(g)
assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]' assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]'
def test_one_assert_merge(self):
# Merge two nodes, one has assert, the other not.
x1 = T.matrix('x1')
x2 = T.matrix('x2')
e = T.dot(x1, x2) + T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2)
g = Env([x1, x2], [e])
MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str')
strref = '''Elemwise{add,no_inplace} [@A] '' 4
|dot [@B] '' 3
| |Assert{msg='Theano Assert failed!'} [@C] '' 2
| | |x1 [@D]
| | |All [@E] '' 1
| | |Elemwise{gt,no_inplace} [@F] '' 0
| | |x1 [@D]
| | |x2 [@G]
| |x2 [@G]
|dot [@B] '' 3
'''
assert strg == strref, (strg, strref)
def test_both_assert_merge_1(self):
# Merge two nodes, both have assert on the same node
# with different conditions.
x1 = T.matrix('x1')
x2 = T.matrix('x2')
x3 = T.matrix('x3')
e = T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) +\
T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2)
g = Env([x1, x2, x3], [e])
MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str')
strref1 = '''Elemwise{add,no_inplace} [@A] '' 6
|dot [@B] '' 5
| |Assert{msg='Theano Assert failed!'} [@C] '' 4
| | |x1 [@D]
| | |All [@E] '' 3
| | | |Elemwise{gt,no_inplace} [@F] '' 1
| | | |x1 [@D]
| | | |x3 [@G]
| | |All [@H] '' 2
| | |Elemwise{gt,no_inplace} [@I] '' 0
| | |x1 [@D]
| | |x2 [@J]
| |x2 [@J]
|dot [@B] '' 5
'''
strref2 = '''Elemwise{add,no_inplace} [@A] '' 6
|dot [@B] '' 5
| |Assert{msg='Theano Assert failed!'} [@C] '' 4
| | |x1 [@D]
| | |All [@E] '' 3
| | | |Elemwise{gt,no_inplace} [@F] '' 1
| | | |x1 [@D]
| | | |x2 [@G]
| | |All [@H] '' 2
| | |Elemwise{gt,no_inplace} [@I] '' 0
| | |x1 [@D]
| | |x3 [@J]
| |x2 [@G]
|dot [@B] '' 5
'''
# print(strg)
assert strg == strref1 or strg == strref2, (strg, strref1, strref2)
def test_both_assert_merge_2(self):
# Merge two nodes, both have assert on different node
x1 = T.matrix('x1')
x2 = T.matrix('x2')
x3 = T.matrix('x3')
e = T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) +\
T.dot(x1, T.opt.assert_op(x2, (x2 > x3).all()))
g = Env([x1, x2, x3], [e])
MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str')
strref = '''Elemwise{add,no_inplace} [@A] '' 7
|dot [@B] '' 6
| |Assert{msg='Theano Assert failed!'} [@C] '' 5
| | |x1 [@D]
| | |All [@E] '' 3
| | |Elemwise{gt,no_inplace} [@F] '' 1
| | |x1 [@D]
| | |x3 [@G]
| |Assert{msg='Theano Assert failed!'} [@H] '' 4
| |x2 [@I]
| |All [@J] '' 2
| |Elemwise{gt,no_inplace} [@K] '' 0
| |x2 [@I]
| |x3 [@G]
|dot [@B] '' 6
'''
# print(strg)
assert strg == strref, (strg, strref)
def test_both_assert_merge_2_reverse(self):
# Test case "test_both_assert_merge_2" but in reverse order
x1 = T.matrix('x1')
x2 = T.matrix('x2')
x3 = T.matrix('x3')
e = T.dot(x1, T.opt.assert_op(x2, (x2 > x3).all())) +\
T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2)
g = Env([x1, x2, x3], [e])
MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str')
strref = '''Elemwise{add,no_inplace} [@A] '' 7
|dot [@B] '' 6
| |Assert{msg='Theano Assert failed!'} [@C] '' 5
| | |x1 [@D]
| | |All [@E] '' 3
| | |Elemwise{gt,no_inplace} [@F] '' 1
| | |x1 [@D]
| | |x3 [@G]
| |Assert{msg='Theano Assert failed!'} [@H] '' 4
| |x2 [@I]
| |All [@J] '' 2
| |Elemwise{gt,no_inplace} [@K] '' 0
| |x2 [@I]
| |x3 [@G]
|dot [@B] '' 6
'''
print(strg)
assert strg == strref, (strg, strref)
class TestEquilibrium(object): class TestEquilibrium(object):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论