提交 0fdce438 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Still merge identical Asserts while Assert merging is disabled.

Although the fancy merging of conditions is disabled, identical Asserts can still be combined. See #3344.
上级 89849eb7
......@@ -615,6 +615,7 @@ class MergeFeature(object):
# Put all clients of Assert inputs (if exist) into merge_candidates
# TODO: Deactivated for now as this cause cycle in the graph.
# (There is a second deactivation part below.)
for i in []: # node.inputs:
if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert):
......@@ -647,7 +648,8 @@ class MergeFeature(object):
# Get input list of the candidate with assert removed
cand_inputs_assert_removed = []
for i in candidate.inputs:
# TODO: Deactivated while Assert merging is disabled. (See above and below.)
for i in []: # candidate.inputs:
if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert):
cand_has_assert = True
......@@ -655,6 +657,11 @@ class MergeFeature(object):
else:
cand_inputs_assert_removed.append(i)
# TODO: Remove this when Assert merging is re-enabled. (See above.)
# Without Assert merging we can still look for identical Asserts,
# so we should not treat Asserts separately for now.
cand_inputs_assert_removed = candidate.inputs
# Get input list of the node with assert removed
if node_has_assert:
node_inputs_assert_removed = []
......
......@@ -399,6 +399,30 @@ class TestMergeOptimizer:
'''
assert strg == strref, (strg, strref)
def test_both_assert_merge_identical(self):
# Merge two nodes, both have assert on the same node
# with the same conditions.
x1 = T.matrix('x1')
x2 = T.matrix('x2')
e = T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2) +\
T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2)
g = FunctionGraph([x1, x2], [e])
MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str')
strref = '''Elemwise{add,no_inplace} [id A] '' 4
|dot [id B] '' 3
| |Assert{msg='Theano Assert failed!'} [id C] '' 2
| | |x1 [id D]
| | |All [id E] '' 1
| | |Elemwise{gt,no_inplace} [id F] '' 0
| | |x1 [id D]
| | |x2 [id G]
| |x2 [id G]
|dot [id B] '' 3
'''
# print(strg)
assert strg == strref, (strg, strref)
def est_both_assert_merge_1(self):
# Merge two nodes, both have assert on the same node
# with different conditions.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论