提交 9685f1d1 authored 作者: Ziye Fan's avatar Ziye Fan

add 3 new test cases for different cases of merging asserts; bug fix

上级 01e2ff6a
......@@ -515,6 +515,7 @@ class MergeFeature(object):
if node in self.nodes_seen:
return
# import ipdb;ipdb.set_trace()
node_has_assert = False
# These asserts ensure that the fgraph has set the clients field
......@@ -533,6 +534,14 @@ class MergeFeature(object):
node_has_assert = True
assert_clients = [c for (c, _) in i.owner.inputs[0].clients
if c in self.nodes_seen]
for idx in range(len(assert_clients)):
client = assert_clients[idx]
if isinstance(i.owner.op, theano.tensor.opt.Assert):
for c in client.outputs[0].clients:
if c[0] in self.nodes_seen:
assert_clients.append(c[0])
merge_candidates.extend(assert_clients)
else:
merge_candidates = []
......@@ -556,11 +565,6 @@ class MergeFeature(object):
else:
cand_inputs_assert_removed.append(i)
# if the current node has assert input, it should not be
# replaced with a candidate node which has no assert input
if node_has_assert and not cand_has_assert:
continue
# Get input list of the node with assert removed
if node_has_assert:
node_inputs_assert_removed = []
......@@ -589,6 +593,13 @@ class MergeFeature(object):
pairs = list(zip(node.outputs,
candidate.outputs,
['merge'] * len(node.outputs)))
# if the current node has assert input, it should not be
# replaced with a candidate node which has no assert input
elif node_has_assert and not cand_has_assert:
pairs = list(zip(candidate.outputs,
node.outputs,
['merge'] * len(node.outputs)))
else:
new_inputs = self.get_merged_assert_input(node, candidate)
new_node = node.op(*new_inputs)
......@@ -626,17 +637,15 @@ class MergeFeature(object):
isinstance(cand_i.owner.op,
theano.tensor.opt.Assert)):
# Here two assert nodes are merged.
# Step 1. Check if two conditions the same one
# Step 2. Combine the two with T.and_(a, b)
node_cond = node_i.owner.inputs[1]
cand_cond = cand_i.owner.inputs[1]
if node_cond.owner is cand_cond.owner:
new_inputs.append(cand_i)
else:
new_inputs.append(
theano.tensor.opt.assert_op(
node_i.owner.inputs[0],
theano.tensor.and_(node_cond, cand_cond)))
# Step 1. Merge conditions of both assert nodes.
# Step 2. Make the new assert node
node_cond = node_i.owner.inputs[1:]
cand_cond = cand_i.owner.inputs[1:]
new_cond = list(set(node_cond + cand_cond))
new_inputs.append(
theano.tensor.opt.assert_op(
node_i.owner.inputs[0],
*new_cond))
# node_i is assert, cand_i is not assert
else:
......
......@@ -67,6 +67,9 @@ class MyOp(Op):
else:
return id(self)
def __gt__(self):
return True
op1 = MyOp('Op1')
op2 = MyOp('Op2')
......@@ -363,40 +366,84 @@ class TestMergeOptimizer:
strg = str(g)
assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]'
def test_assert_merge(self):
def test_one_assert_merge(self):
# Merge two nodes, one has assert, the other not.
x1 = T.matrix('x1')
x2 = T.matrix('x2')
e = T.dot(x1, x2) + T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2)
g = Env([x1, x2], [e])
MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str')
strref = '''Elemwise{add,no_inplace} [@A] '' 4
|dot [@B] '' 3
| |Assert{msg='Theano Assert failed!'} [@C] '' 2
| | |x1 [@D]
| | |All [@E] '' 1
| | |Elemwise{gt,no_inplace} [@F] '' 0
| | |x1 [@D]
| | |x2 [@G]
| |x2 [@G]
|dot [@B] '' 3
'''
assert strg == strref
def test_both_assert_merge_1(self):
# Merge two nodes, both have assert on the same node
# with different conditions.
x1 = T.matrix('x1')
x2 = T.matrix('x2')
y1 = T.opt.assert_op(x1, (x1 < 0).all()) +\
T.opt.assert_op(x2, (x1 < 0).all())
y2 = T.opt.assert_op(x1, (x2 > 0).all()) + x2
g = Env([x1, x2], [y1, y2])
x3 = T.matrix('x3')
e = T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) +\
T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2)
g = Env([x1, x2, x3], [e])
MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str')
strref = '''
Elemwise{add,no_inplace} [@A] '' 9
|Assert{msg='Theano Assert failed!'} [@B] '' 8
| |x1 [@C]
| |Elemwise{and_,no_inplace} [@D] '' 7
| |Elemwise{and_,no_inplace} [@E] '' 6
| | |All [@F] '' 3
| | | |Elemwise{lt,no_inplace} [@G] '' 1
| | | |x1 [@C]
| | | |DimShuffle{x,x} [@H] '' 0
| | | |TensorConstant{0} [@I]
| | |All [@J] '' 4
| | |Elemwise{gt,no_inplace} [@K] '' 2
| | |x2 [@L]
| | |DimShuffle{x,x} [@H] '' 0
| |All [@J] '' 4
|Assert{msg='Theano Assert failed!'} [@M] '' 5
|x2 [@L]
|All [@F] '' 3
Elemwise{add,no_inplace} [@A] '' 9
strref = '''Elemwise{add,no_inplace} [@A] '' 6
|dot [@B] '' 5
| |Assert{msg='Theano Assert failed!'} [@C] '' 4
| | |x1 [@D]
| | |All [@E] '' 3
| | | |Elemwise{gt,no_inplace} [@F] '' 1
| | | |x1 [@D]
| | | |x3 [@G]
| | |All [@H] '' 2
| | |Elemwise{gt,no_inplace} [@I] '' 0
| | |x1 [@D]
| | |x2 [@J]
| |x2 [@J]
|dot [@B] '' 5
'''
print(strg)
print(strref)
assert strg.strip() == strref.strip()
# print(strg)
assert strg == strref
def test_both_assert_merge_2(self):
# Merge two nodes, both have assert on different node
x1 = T.matrix('x1')
x2 = T.matrix('x2')
x3 = T.matrix('x3')
e = T.dot(T.opt.assert_op(x1, (x1 > x3).all()), x2) +\
T.dot(x1, T.opt.assert_op(x2, (x2 > x3).all()))
g = Env([x1, x2, x3], [e])
MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str')
strref = '''Elemwise{add,no_inplace} [@A] '' 7
|dot [@B] '' 6
| |Assert{msg='Theano Assert failed!'} [@C] '' 5
| | |x1 [@D]
| | |All [@E] '' 3
| | |Elemwise{gt,no_inplace} [@F] '' 1
| | |x1 [@D]
| | |x3 [@G]
| |Assert{msg='Theano Assert failed!'} [@H] '' 4
| |x2 [@I]
| |All [@J] '' 2
| |Elemwise{gt,no_inplace} [@K] '' 0
| |x2 [@I]
| |x3 [@G]
|dot [@B] '' 6
'''
# print(strg)
assert strg == strref
class TestEquilibrium(object):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论