提交 652c8540 authored 作者: James Bergstra's avatar James Bergstra

added test of optimization of crossentropy with biased softmax using advanced indexing syntax

上级 3e05912e
......@@ -322,9 +322,11 @@ def test_get_rid_of_advanced_indexing_version_of_xent():
rng = numpy.random.RandomState(utt.fetch_seed())
x_val = rng.randn(3,5)
b_val = rng.randn(5)
y_val = numpy.asarray([2,4,1])
x = T.dmatrix('x')
b = T.dvector('b')
y = T.lvector('y')
expressions_to_test = [
......@@ -334,9 +336,10 @@ def test_get_rid_of_advanced_indexing_version_of_xent():
T.sum(-T.log(softmax(x))[T.arange(y.shape[0]), y])]
def assert_optimizer_worked(expr):
f = theano.function([x,y], expr)
for i, node in enumerate(f.maker.env.toposort()):
print i, node
f = theano.function([x,y], expr, mode='FAST_RUN')
if 0:
for i, node in enumerate(f.maker.env.toposort()):
print i, node
f(x_val, y_val)
assert len(f.maker.env.toposort()) == 4
for expr in expressions_to_test:
......@@ -345,16 +348,27 @@ def test_get_rid_of_advanced_indexing_version_of_xent():
## Gradient wrt x
for expr in expressions_to_test:
grad_x = T.grad(expr, x)
g = theano.function([x, y], grad_x)
for i, node in enumerate(g.maker.env.toposort()):
print i, node
g = theano.function([x, y], grad_x, mode='FAST_RUN')
if 0:
for i, node in enumerate(g.maker.env.toposort()):
print i, node
g(x_val, y_val)
assert len(g.maker.env.toposort()) == 4
#TODO: Case with bias
# hint - call local_softmax_with_bias from within the other optimization
# hint - call the argmax push-down optimization first too
## Test that a biased softmax is optimized correctly
for expr in [
T.sum(-T.log(softmax(x+b)[T.arange(y.shape[0]), y])),
-T.sum(T.log(softmax(b+x)[T.arange(y.shape[0]), y])),
-T.sum(T.log(softmax(x+b))[T.arange(y.shape[0]), y]),
T.sum(-T.log(softmax(b+x))[T.arange(y.shape[0]), y])]:
f = theano.function([x,b,y], expr, mode='FAST_RUN')
if 0:
for i, node in enumerate(f.maker.env.toposort()):
print i, node
assert len(g.maker.env.toposort()) == 2 # [big_op, sum]
f(x_val, b_val, y_val)
# hint - call the argmax push-down optimization first too
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论