提交 9685f1d1 authored 作者: Ziye Fan's avatar Ziye Fan

add 3 new test cases for different cases of merging asserts; bug fix

上级 01e2ff6a
...@@ -515,6 +515,7 @@ class MergeFeature(object): ...@@ -515,6 +515,7 @@ class MergeFeature(object):
if node in self.nodes_seen: if node in self.nodes_seen:
return return
# import ipdb;ipdb.set_trace()
node_has_assert = False node_has_assert = False
# These asserts ensure that the fgraph has set the clients field # These asserts ensure that the fgraph has set the clients field
...@@ -533,6 +534,14 @@ class MergeFeature(object): ...@@ -533,6 +534,14 @@ class MergeFeature(object):
node_has_assert = True node_has_assert = True
assert_clients = [c for (c, _) in i.owner.inputs[0].clients assert_clients = [c for (c, _) in i.owner.inputs[0].clients
if c in self.nodes_seen] if c in self.nodes_seen]
for idx in range(len(assert_clients)):
client = assert_clients[idx]
if isinstance(i.owner.op, theano.tensor.opt.Assert):
for c in client.outputs[0].clients:
if c[0] in self.nodes_seen:
assert_clients.append(c[0])
merge_candidates.extend(assert_clients) merge_candidates.extend(assert_clients)
else: else:
merge_candidates = [] merge_candidates = []
...@@ -556,11 +565,6 @@ class MergeFeature(object): ...@@ -556,11 +565,6 @@ class MergeFeature(object):
else: else:
cand_inputs_assert_removed.append(i) cand_inputs_assert_removed.append(i)
# if the current node has assert input, it should not be
# replaced with a candidate node which has no assert input
if node_has_assert and not cand_has_assert:
continue
# Get input list of the node with assert removed # Get input list of the node with assert removed
if node_has_assert: if node_has_assert:
node_inputs_assert_removed = [] node_inputs_assert_removed = []
...@@ -589,6 +593,13 @@ class MergeFeature(object): ...@@ -589,6 +593,13 @@ class MergeFeature(object):
pairs = list(zip(node.outputs, pairs = list(zip(node.outputs,
candidate.outputs, candidate.outputs,
['merge'] * len(node.outputs))) ['merge'] * len(node.outputs)))
# if the current node has assert input, it should not be
# replaced with a candidate node which has no assert input
elif node_has_assert and not cand_has_assert:
pairs = list(zip(candidate.outputs,
node.outputs,
['merge'] * len(node.outputs)))
else: else:
new_inputs = self.get_merged_assert_input(node, candidate) new_inputs = self.get_merged_assert_input(node, candidate)
new_node = node.op(*new_inputs) new_node = node.op(*new_inputs)
...@@ -626,17 +637,15 @@ class MergeFeature(object): ...@@ -626,17 +637,15 @@ class MergeFeature(object):
isinstance(cand_i.owner.op, isinstance(cand_i.owner.op,
theano.tensor.opt.Assert)): theano.tensor.opt.Assert)):
# Here two assert nodes are merged. # Here two assert nodes are merged.
# Step 1. Check if two conditions the same one # Step 1. Merge conditions of both assert nodes.
# Step 2. Combine the two with T.and_(a, b) # Step 2. Make the new assert node
node_cond = node_i.owner.inputs[1] node_cond = node_i.owner.inputs[1:]
cand_cond = cand_i.owner.inputs[1] cand_cond = cand_i.owner.inputs[1:]
if node_cond.owner is cand_cond.owner: new_cond = list(set(node_cond + cand_cond))
new_inputs.append(cand_i) new_inputs.append(
else: theano.tensor.opt.assert_op(
new_inputs.append( node_i.owner.inputs[0],
theano.tensor.opt.assert_op( *new_cond))
node_i.owner.inputs[0],
theano.tensor.and_(node_cond, cand_cond)))
# node_i is assert, cand_i is not assert # node_i is assert, cand_i is not assert
else: else:
......
...@@ -67,6 +67,9 @@ class MyOp(Op): ...@@ -67,6 +67,9 @@ class MyOp(Op):
else: else:
return id(self) return id(self)
def __gt__(self):
return True
op1 = MyOp('Op1') op1 = MyOp('Op1')
op2 = MyOp('Op2') op2 = MyOp('Op2')
...@@ -363,40 +366,84 @@ class TestMergeOptimizer: ...@@ -363,40 +366,84 @@ class TestMergeOptimizer:
strg = str(g) strg = str(g)
assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]' assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]'
def test_assert_merge(self): def test_one_assert_merge(self):
# Merge two nodes, one has assert, the other not.
x1 = T.matrix('x1')
x2 = T.matrix('x2')
e = T.dot(x1, x2) + T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2)
g = Env([x1, x2], [e])
MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str')
strref = '''Elemwise{add,no_inplace} [@A] '' 4
|dot [@B] '' 3
| |Assert{msg='Theano Assert failed!'} [@C] '' 2
| | |x1 [@D]
| | |All [@E] '' 1
| | |Elemwise{gt,no_inplace} [@F] '' 0
| | |x1 [@D]
| | |x2 [@G]
| |x2 [@G]
|dot [@B] '' 3
'''
assert strg == strref
def test_both_assert_merge_1(self):
# Merge two nodes, both have assert on the same node
# with different conditions.
x1 = T.matrix('x1') x1 = T.matrix('x1')
x2 = T.matrix('x2') x2 = T.matrix('x2')
y1 = T.opt.assert_op(x1, (x1 < 0).all()) +\ x3 = T.matrix('x3')
T.opt.assert_op(x2, (x1 < 0).all()) e = T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) +\
y2 = T.opt.assert_op(x1, (x2 > 0).all()) + x2 T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2)
g = Env([x1, x2], [y1, y2]) g = Env([x1, x2, x3], [e])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str') strg = theano.printing.debugprint(g, file='str')
strref = ''' strref = '''Elemwise{add,no_inplace} [@A] '' 6
Elemwise{add,no_inplace} [@A] '' 9 |dot [@B] '' 5
|Assert{msg='Theano Assert failed!'} [@B] '' 8 | |Assert{msg='Theano Assert failed!'} [@C] '' 4
| |x1 [@C] | | |x1 [@D]
| |Elemwise{and_,no_inplace} [@D] '' 7 | | |All [@E] '' 3
| |Elemwise{and_,no_inplace} [@E] '' 6 | | | |Elemwise{gt,no_inplace} [@F] '' 1
| | |All [@F] '' 3 | | | |x1 [@D]
| | | |Elemwise{lt,no_inplace} [@G] '' 1 | | | |x3 [@G]
| | | |x1 [@C] | | |All [@H] '' 2
| | | |DimShuffle{x,x} [@H] '' 0 | | |Elemwise{gt,no_inplace} [@I] '' 0
| | | |TensorConstant{0} [@I] | | |x1 [@D]
| | |All [@J] '' 4 | | |x2 [@J]
| | |Elemwise{gt,no_inplace} [@K] '' 2 | |x2 [@J]
| | |x2 [@L] |dot [@B] '' 5
| | |DimShuffle{x,x} [@H] '' 0
| |All [@J] '' 4
|Assert{msg='Theano Assert failed!'} [@M] '' 5
|x2 [@L]
|All [@F] '' 3
Elemwise{add,no_inplace} [@A] '' 9
''' '''
print(strg) # print(strg)
print(strref) assert strg == strref
assert strg.strip() == strref.strip()
def test_both_assert_merge_2(self):
# Merge two nodes, both have assert on different node
x1 = T.matrix('x1')
x2 = T.matrix('x2')
x3 = T.matrix('x3')
e = T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) +\
T.dot(x1, T.opt.assert_op(x2, (x2 > x3).all()))
g = Env([x1, x2, x3], [e])
MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str')
strref = '''Elemwise{add,no_inplace} [@A] '' 7
|dot [@B] '' 6
| |Assert{msg='Theano Assert failed!'} [@C] '' 5
| | |x1 [@D]
| | |All [@E] '' 3
| | |Elemwise{gt,no_inplace} [@F] '' 1
| | |x1 [@D]
| | |x3 [@G]
| |Assert{msg='Theano Assert failed!'} [@H] '' 4
| |x2 [@I]
| |All [@J] '' 2
| |Elemwise{gt,no_inplace} [@K] '' 0
| |x2 [@I]
| |x3 [@G]
|dot [@B] '' 6
'''
# print(strg)
assert strg == strref
class TestEquilibrium(object): class TestEquilibrium(object):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论