提交 7fc5b57e authored 作者: James Bergstra's avatar James Bergstra

Added nnet unit test test_get_rid_of_advanced_indexing_version_of_xent

上级 c95da640
...@@ -317,6 +317,37 @@ def test_asymptotic_32(): ...@@ -317,6 +317,37 @@ def test_asymptotic_32():
assert gxval[0,1] == 0.25 assert gxval[0,1] == 0.25
def test_get_rid_of_advanced_indexing_version_of_xent():
rng = numpy.random.RandomState(234343)
x_val = rng.randn(3,5)
y_val = numpy.asarray([2,4,1])
y = T.lvector('y')
x = T.dmatrix('x')
def assert_optimizer_worked(expr):
f = theano.function([x,y], expr)
for i, node in enumerate(f.maker.env.toposort()):
print i, node
f(x_val, y_val)
assert len(f.maker.env.toposort()) == 4
for expr in [
T.sum(-T.log(softmax(x)[T.arange(y.shape[0]), y])),
-T.sum(T.log(softmax(x)[T.arange(y.shape[0]), y])),
-T.sum(T.log(softmax(x))[T.arange(y.shape[0]), y]),
T.sum(-T.log(softmax(x))[T.arange(y.shape[0]), y])]:
assert_optimizer_worked(expr)
#TODO: Case
#TODO: Case with bias
# hint - call local_softmax_with_bias from within the other optimization
# hint - call the argmax push-down optimization first too
#TODO: Case with derivative
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论