提交 88b0c613 authored 作者: James Bergstra's avatar James Bergstra

upgraded test_nnet to work with output_guard ops

上级 9b7319ba
......@@ -163,7 +163,8 @@ class T_CrossentropyCategorical1Hot(unittest.TestCase):
theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(env)
assert env.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
assert str(env.outputs[0].owner.op) == 'OutputGuard'
assert env.outputs[0].owner.inputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
def test_softmax_optimizations_w_bias(self):
x = tensor.matrix('x')
......@@ -186,9 +187,10 @@ class T_CrossentropyCategorical1Hot(unittest.TestCase):
theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(env)
assert len(env.toposort()) == 1
assert len(env.toposort()) == 2
assert env.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
assert str(env.outputs[0].owner.op) == 'OutputGuard'
assert env.outputs[0].owner.inputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
def test_softmax_grad_optimizations(self):
......@@ -249,7 +251,7 @@ def test_argmax_pushdown():
#print 'AFTER'
#for node in env.toposort():
#print node.op
assert len(env.toposort()) == 1
assert len(env.toposort()) == 2 # an output_guard is second
assert env.toposort()[0].op == tensor._max_and_argmax
def test_argmax_pushdown_bias():
......@@ -263,10 +265,14 @@ def test_argmax_pushdown_bias():
theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(env)
#print 'AFTER'
#for node in env.toposort():
#print node.op
assert len(env.toposort()) == 3
print 'AFTER'
for node in env.toposort():
print node.op
assert len(env.toposort()) == 4
assert isinstance(env.toposort()[0].op, tensor.DimShuffle)
assert isinstance(env.toposort()[1].op, tensor.Elemwise)
assert isinstance(env.toposort()[2].op, tensor.MaxAndArgmax)
assert str(env.toposort()[3].op) == 'OutputGuard'
def test_asymptotic_32():
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论