提交 97082358 authored 作者: Reyhane Askari's avatar Reyhane Askari

added check for cycle_detection flag in test

上级 1cef6dac
......@@ -580,8 +580,13 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
assert (fgraph.outputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
if theano.config.cycle_detection == 'fast':
assert (fgraph.outputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
else:
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard'
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_optimizations_vector(self):
x = tensor.vector('x')
......@@ -594,8 +599,13 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
assert (fgraph.outputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
if theano.config.cycle_detection == 'fast':
assert (fgraph.outputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
else:
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard'
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_optimizations_w_bias(self):
x = tensor.matrix('x')
......@@ -623,9 +633,15 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
# print node.op
# print printing.pprint(node.outputs[0])
# print '===='
assert len(fgraph.toposort()) == 1
assert (fgraph.outputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
if theano.config.cycle_detection == 'fast':
assert len(fgraph.toposort()) == 1
assert (fgraph.outputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
else:
assert len(fgraph.toposort()) == 2
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard'
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_optimizations_w_bias2(self):
x = tensor.matrix('x')
......@@ -651,10 +667,15 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
# for node in fgraph.toposort():
# print node.op
# print '===='
assert len(fgraph.toposort()) == 2
assert (fgraph.outputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
if theano.config.cycle_detection == 'fast':
assert len(fgraph.toposort()) == 2
assert (fgraph.outputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
else:
assert len(fgraph.toposort()) == 3
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard'
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_optimizations_w_bias_vector(self):
x = tensor.vector('x')
......@@ -677,9 +698,15 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
# for node in fgraph.toposort():
# print node.op
# print '===='
assert len(fgraph.toposort()) == 2
assert (fgraph.outputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
if theano.config.cycle_detection == 'fast':
assert len(fgraph.toposort()) == 2
assert (fgraph.outputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
else:
assert len(fgraph.toposort()) == 3
assert str(fgraph.outputs[0].owner.op) == 'OutputGuard'
assert (fgraph.outputs[0].owner.inputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias)
def test_softmax_grad_optimizations(self):
x = tensor.matrix('x')
......@@ -1381,7 +1408,13 @@ def test_argmax_pushdown_bias():
# for node in fgraph.toposort():
# print node.op
types_to_check = (tensor.DimShuffle, tensor.Elemwise, tensor.Argmax)
assert len(fgraph.toposort()) == 3
if theano.config.cycle_detection == 'fast':
assert len(fgraph.toposort()) == 3
else:
assert len(fgraph.toposort()) == 4
assert str(fgraph.toposort()[3].op) == 'OutputGuard'
for i, type in enumerate(types_to_check):
assert isinstance(fgraph.toposort()[i].op, type)
assert check_stack_trace(fgraph, ops_to_check=types_to_check)
......@@ -1404,7 +1437,11 @@ def test_argmax_pushdown_bias():
# print 'AFTER'
# for node in fgraph.toposort():
# print node.op
assert len(fgraph.toposort()) == 2
if theano.config.cycle_detection == 'fast':
assert len(fgraph.toposort()) == 2
else:
assert len(fgraph.toposort()) == 3
assert str(fgraph.toposort()[2].op) == 'OutputGuard'
assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias)
assert isinstance(fgraph.toposort()[1].op, tensor.CAReduce)
assert isinstance(fgraph.toposort()[1].op.scalar_op, theano.scalar.Maximum)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论