提交 e3eaaef1 authored 作者: James Bergstra's avatar James Bergstra

adjusted test of softmax optimization to account for the less-aggressive dimshuffle lift we use now

上级 58f6f6b1
...@@ -206,18 +206,21 @@ class T_CrossentropyCategorical1Hot(unittest.TestCase): ...@@ -206,18 +206,21 @@ class T_CrossentropyCategorical1Hot(unittest.TestCase):
print 'BEFORE' print 'BEFORE'
for node in env.toposort(): for node in env.toposort():
print node.op print node.op, node.inputs
print '----' print '----'
theano.compile.mode.optdb.query( theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(env) theano.compile.mode.OPT_FAST_RUN).optimize(env)
print 'AFTER' print 'AFTER'
for node in env.toposort(): for node in env.toposort():
print node.op print node.op, node.inputs
# the function has 9 ops because the dimshuffle and elemwise{second} aren't getting
# cleaned up as well as we'd like.
assert env.toposort()[3].op == crossentropy_softmax_argmax_1hot_with_bias assert env.toposort()[3].op == crossentropy_softmax_argmax_1hot_with_bias
assert env.toposort()[5].op == crossentropy_softmax_1hot_with_bias_dx assert env.toposort()[8].op == crossentropy_softmax_1hot_with_bias_dx
assert len(env.toposort()) == 6 #shorthand for actually checking what I really assert len(env.toposort()) == 9 #shorthand for actually checking what I really
def test_argmax_pushdown(): def test_argmax_pushdown():
x = tensor.dmatrix() x = tensor.dmatrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论