提交 d555ab80 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Added gradient case to test_get_rid_of_advanced_indexing_version_of_xent

上级 a414e5f3
...@@ -319,13 +319,19 @@ def test_asymptotic_32(): ...@@ -319,13 +319,19 @@ def test_asymptotic_32():
def test_get_rid_of_advanced_indexing_version_of_xent(): def test_get_rid_of_advanced_indexing_version_of_xent():
rng = numpy.random.RandomState(234343) rng = numpy.random.RandomState(utt.fetch_seed())
x_val = rng.randn(3,5) x_val = rng.randn(3,5)
y_val = numpy.asarray([2,4,1]) y_val = numpy.asarray([2,4,1])
y = T.lvector('y')
x = T.dmatrix('x') x = T.dmatrix('x')
y = T.lvector('y')
expressions_to_test = [
T.sum(-T.log(softmax(x)[T.arange(y.shape[0]), y])),
-T.sum(T.log(softmax(x)[T.arange(y.shape[0]), y])),
-T.sum(T.log(softmax(x))[T.arange(y.shape[0]), y]),
T.sum(-T.log(softmax(x))[T.arange(y.shape[0]), y])]
def assert_optimizer_worked(expr): def assert_optimizer_worked(expr):
f = theano.function([x,y], expr) f = theano.function([x,y], expr)
...@@ -333,21 +339,22 @@ def test_get_rid_of_advanced_indexing_version_of_xent(): ...@@ -333,21 +339,22 @@ def test_get_rid_of_advanced_indexing_version_of_xent():
print i, node print i, node
f(x_val, y_val) f(x_val, y_val)
assert len(f.maker.env.toposort()) == 4 assert len(f.maker.env.toposort()) == 4
for expr in [ for expr in expressions_to_test:
T.sum(-T.log(softmax(x)[T.arange(y.shape[0]), y])),
-T.sum(T.log(softmax(x)[T.arange(y.shape[0]), y])),
-T.sum(T.log(softmax(x))[T.arange(y.shape[0]), y]),
T.sum(-T.log(softmax(x))[T.arange(y.shape[0]), y])]:
assert_optimizer_worked(expr) assert_optimizer_worked(expr)
## Gradient wrt x
for expr in expressions_to_test:
grad_x = T.grad(expr, x)
g = theano.function([x, y], grad_x)
for i, node in enumerate(g.maker.env.toposort()):
print i, node
g(x_val, y_val)
assert len(g.maker.env.toposort()) == 4
#TODO: Case
#TODO: Case with bias #TODO: Case with bias
# hint - call local_softmax_with_bias from within the other optimization # hint - call local_softmax_with_bias from within the other optimization
# hint - call the argmax push-down optimization first too # hint - call the argmax push-down optimization first too
#TODO: Case with derivative
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论