提交 ddc73246 authored 作者: Reyhane Askari's avatar Reyhane Askari

fix for tests with OutputGuard

上级 c24938c6
...@@ -579,8 +579,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -579,8 +579,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
theano.compile.mode.optdb.query( theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph) theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard' assert (fgraph.outputs[0].owner.op ==
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias) crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_optimizations_vector(self): def test_softmax_optimizations_vector(self):
...@@ -594,8 +593,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -594,8 +593,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
theano.compile.mode.optdb.query( theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph) theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard' assert (fgraph.outputs[0].owner.op ==
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias) crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_optimizations_w_bias(self): def test_softmax_optimizations_w_bias(self):
...@@ -624,10 +622,8 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -624,10 +622,8 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
# print node.op # print node.op
# print printing.pprint(node.outputs[0]) # print printing.pprint(node.outputs[0])
# print '====' # print '===='
assert len(fgraph.toposort()) == 2 assert len(fgraph.toposort()) == 1
assert (fgraph.outputs[0].owner.op ==
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard'
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias) crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_optimizations_w_bias2(self): def test_softmax_optimizations_w_bias2(self):
...@@ -654,10 +650,9 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -654,10 +650,9 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
# for node in fgraph.toposort(): # for node in fgraph.toposort():
# print node.op # print node.op
# print '====' # print '===='
assert len(fgraph.toposort()) == 3 assert len(fgraph.toposort()) == 2
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard' assert (fgraph.outputs[0].owner.op ==
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias) crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_optimizations_w_bias_vector(self): def test_softmax_optimizations_w_bias_vector(self):
...@@ -681,9 +676,8 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -681,9 +676,8 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
# for node in fgraph.toposort(): # for node in fgraph.toposort():
# print node.op # print node.op
# print '====' # print '===='
assert len(fgraph.toposort()) == 3 assert len(fgraph.toposort()) == 2
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard' assert (fgraph.outputs[0].owner.op ==
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias) crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_grad_optimizations(self): def test_softmax_grad_optimizations(self):
...@@ -1338,9 +1332,8 @@ def test_argmax_pushdown(): ...@@ -1338,9 +1332,8 @@ def test_argmax_pushdown():
# print 'AFTER' # print 'AFTER'
# for node in fgraph.toposort(): # for node in fgraph.toposort():
# print node.op # print node.op
assert len(fgraph.toposort()) == 2 # an output_guard is second assert len(fgraph.toposort()) == 1 # an output_guard is second
assert fgraph.toposort()[0].op == tensor.basic._argmax assert fgraph.toposort()[0].op == tensor.basic._argmax
assert str(fgraph.toposort()[1].op) == 'OutputGuard'
assert check_stack_trace( assert check_stack_trace(
fgraph, ops_to_check=tensor.basic._argmax) fgraph, ops_to_check=tensor.basic._argmax)
x = tensor.matrix() x = tensor.matrix()
...@@ -1364,12 +1357,11 @@ def test_argmax_pushdown(): ...@@ -1364,12 +1357,11 @@ def test_argmax_pushdown():
# print 'AFTER' # print 'AFTER'
# for node in fgraph.toposort(): # for node in fgraph.toposort():
# print node.op # print node.op
assert len(fgraph.toposort()) == 4 # an output_guard is second assert len(fgraph.toposort()) == 3 # an output_guard is second
assert isinstance(fgraph.toposort()[0].op, tensor.Elemwise) assert isinstance(fgraph.toposort()[0].op, tensor.Elemwise)
assert isinstance(fgraph.toposort()[1].op, Softmax) assert isinstance(fgraph.toposort()[1].op, Softmax)
assert isinstance(fgraph.toposort()[2].op, tensor.CAReduce) assert isinstance(fgraph.toposort()[2].op, tensor.CAReduce)
assert isinstance(fgraph.toposort()[2].op.scalar_op, theano.scalar.Maximum) assert isinstance(fgraph.toposort()[2].op.scalar_op, theano.scalar.Maximum)
assert str(fgraph.toposort()[3].op) == 'OutputGuard'
def test_argmax_pushdown_bias(): def test_argmax_pushdown_bias():
...@@ -1388,10 +1380,9 @@ def test_argmax_pushdown_bias(): ...@@ -1388,10 +1380,9 @@ def test_argmax_pushdown_bias():
# for node in fgraph.toposort(): # for node in fgraph.toposort():
# print node.op # print node.op
types_to_check = (tensor.DimShuffle, tensor.Elemwise, tensor.Argmax) types_to_check = (tensor.DimShuffle, tensor.Elemwise, tensor.Argmax)
assert len(fgraph.toposort()) == 4 assert len(fgraph.toposort()) == 3
for i, type in enumerate(types_to_check): for i, type in enumerate(types_to_check):
assert isinstance(fgraph.toposort()[i].op, type) assert isinstance(fgraph.toposort()[i].op, type)
assert str(fgraph.toposort()[3].op) == 'OutputGuard'
assert check_stack_trace(fgraph, ops_to_check=types_to_check) assert check_stack_trace(fgraph, ops_to_check=types_to_check)
x = tensor.matrix() x = tensor.matrix()
...@@ -1412,11 +1403,10 @@ def test_argmax_pushdown_bias(): ...@@ -1412,11 +1403,10 @@ def test_argmax_pushdown_bias():
# print 'AFTER' # print 'AFTER'
# for node in fgraph.toposort(): # for node in fgraph.toposort():
# print node.op # print node.op
assert len(fgraph.toposort()) == 3 assert len(fgraph.toposort()) == 2
assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias) assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias)
assert isinstance(fgraph.toposort()[1].op, tensor.CAReduce) assert isinstance(fgraph.toposort()[1].op, tensor.CAReduce)
assert isinstance(fgraph.toposort()[1].op.scalar_op, theano.scalar.Maximum) assert isinstance(fgraph.toposort()[1].op.scalar_op, theano.scalar.Maximum)
assert str(fgraph.toposort()[2].op) == 'OutputGuard'
assert check_stack_trace( assert check_stack_trace(
fgraph, ops_to_check=(SoftmaxWithBias, tensor.CAReduce)) fgraph, ops_to_check=(SoftmaxWithBias, tensor.CAReduce))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论