提交 01e2ff6a authored 作者: Ziye Fan's avatar Ziye Fan

new test case added; fix bug (create new assert node when different condition)

上级 0172bb18
...@@ -634,7 +634,9 @@ class MergeFeature(object): ...@@ -634,7 +634,9 @@ class MergeFeature(object):
new_inputs.append(cand_i) new_inputs.append(cand_i)
else: else:
new_inputs.append( new_inputs.append(
theano.tensor.and_(node_cond, cand_cond)) theano.tensor.opt.assert_op(
node_i.owner.inputs[0],
theano.tensor.and_(node_cond, cand_cond)))
# node_i is assert, cand_i is not assert # node_i is assert, cand_i is not assert
else: else:
......
...@@ -6,6 +6,9 @@ from theano.gof.opt import * # noqa ...@@ -6,6 +6,9 @@ from theano.gof.opt import * # noqa
from theano.gof.fg import FunctionGraph as Env from theano.gof.fg import FunctionGraph as Env
from theano.gof.toolbox import * # noqa from theano.gof.toolbox import * # noqa
from theano.tensor.opt import Assert
from theano import tensor as T
def as_variable(x): def as_variable(x):
if not isinstance(x, Variable): if not isinstance(x, Variable):
...@@ -360,6 +363,41 @@ class TestMergeOptimizer: ...@@ -360,6 +363,41 @@ class TestMergeOptimizer:
strg = str(g) strg = str(g)
assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]' assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]'
def test_assert_merge(self):
x1 = T.matrix('x1')
x2 = T.matrix('x2')
y1 = T.opt.assert_op(x1, (x1 < 0).all()) +\
T.opt.assert_op(x2, (x1 < 0).all())
y2 = T.opt.assert_op(x1, (x2 > 0).all()) + x2
g = Env([x1, x2], [y1, y2])
MergeOptimizer().optimize(g)
strg = theano.printing.debugprint(g, file='str')
strref = '''
Elemwise{add,no_inplace} [@A] '' 9
|Assert{msg='Theano Assert failed!'} [@B] '' 8
| |x1 [@C]
| |Elemwise{and_,no_inplace} [@D] '' 7
| |Elemwise{and_,no_inplace} [@E] '' 6
| | |All [@F] '' 3
| | | |Elemwise{lt,no_inplace} [@G] '' 1
| | | |x1 [@C]
| | | |DimShuffle{x,x} [@H] '' 0
| | | |TensorConstant{0} [@I]
| | |All [@J] '' 4
| | |Elemwise{gt,no_inplace} [@K] '' 2
| | |x2 [@L]
| | |DimShuffle{x,x} [@H] '' 0
| |All [@J] '' 4
|Assert{msg='Theano Assert failed!'} [@M] '' 5
|x2 [@L]
|All [@F] '' 3
Elemwise{add,no_inplace} [@A] '' 9
'''
print(strg)
print(strref)
assert strg.strip() == strref.strip()
class TestEquilibrium(object): class TestEquilibrium(object):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论