提交 c9565520 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5309 from gvtulder/f-merge-identical-assert

Still merge identical Asserts while Assert merging is disabled
...@@ -615,6 +615,7 @@ class MergeFeature(object): ...@@ -615,6 +615,7 @@ class MergeFeature(object):
# Put all clients of Assert inputs (if exist) into merge_candidates # Put all clients of Assert inputs (if exist) into merge_candidates
# TODO: Deactivated for now as this cause cycle in the graph. # TODO: Deactivated for now as this cause cycle in the graph.
# (There is a second deactivation part below.)
for i in []: # node.inputs: for i in []: # node.inputs:
if i.owner and isinstance(i.owner.op, if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert): theano.tensor.opt.Assert):
...@@ -647,7 +648,8 @@ class MergeFeature(object): ...@@ -647,7 +648,8 @@ class MergeFeature(object):
# Get input list of the candidate with assert removed # Get input list of the candidate with assert removed
cand_inputs_assert_removed = [] cand_inputs_assert_removed = []
for i in candidate.inputs: # TODO: Deactivated while Assert merging is disabled. (See above and below.)
for i in []: # candidate.inputs:
if i.owner and isinstance(i.owner.op, if i.owner and isinstance(i.owner.op,
theano.tensor.opt.Assert): theano.tensor.opt.Assert):
cand_has_assert = True cand_has_assert = True
...@@ -655,6 +657,11 @@ class MergeFeature(object): ...@@ -655,6 +657,11 @@ class MergeFeature(object):
else: else:
cand_inputs_assert_removed.append(i) cand_inputs_assert_removed.append(i)
# TODO: Remove this when Assert merging is re-enabled. (See above.)
# Without Assert merging we can still look for identical Asserts,
# so we should not treat Asserts separately for now.
cand_inputs_assert_removed = candidate.inputs
# Get input list of the node with assert removed # Get input list of the node with assert removed
if node_has_assert: if node_has_assert:
node_inputs_assert_removed = [] node_inputs_assert_removed = []
......
...@@ -399,6 +399,30 @@ class TestMergeOptimizer: ...@@ -399,6 +399,30 @@ class TestMergeOptimizer:
''' '''
assert strg == strref, (strg, strref) assert strg == strref, (strg, strref)
def test_both_assert_merge_identical(self):
# Merge two nodes, both have assert on the same node
# with the same conditions.
x1 = T.matrix('x1')
x2 = T.matrix('x2')
e = T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2) +\
T.dot(T.opt.assert_op(x1, (x1 > x2).all()), x2)
g = FunctionGraph([x1, x2], [e])
MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str')
strref = '''Elemwise{add,no_inplace} [id A] '' 4
|dot [id B] '' 3
| |Assert{msg='Theano Assert failed!'} [id C] '' 2
| | |x1 [id D]
| | |All [id E] '' 1
| | |Elemwise{gt,no_inplace} [id F] '' 0
| | |x1 [id D]
| | |x2 [id G]
| |x2 [id G]
|dot [id B] '' 3
'''
# print(strg)
assert strg == strref, (strg, strref)
def est_both_assert_merge_1(self): def est_both_assert_merge_1(self):
# Merge two nodes, both have assert on the same node # Merge two nodes, both have assert on the same node
# with different conditions. # with different conditions.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论