提交 ac09d883 authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: op SoftmaxGrad

上级 8c488a3b
...@@ -19,7 +19,7 @@ from theano.tensor.nnet import (categorical_crossentropy, ...@@ -19,7 +19,7 @@ from theano.tensor.nnet import (categorical_crossentropy,
crossentropy_softmax_argmax_1hot_with_bias, crossentropy_softmax_argmax_1hot_with_bias,
sigmoid, softplus, sigmoid, softplus,
Softmax, softmax, SoftmaxWithBias, softmax_grad, Softmax, softmax, SoftmaxWithBias, softmax_grad,
softmax_with_bias, softmax_with_bias, SoftmaxGrad,
Prepend_scalar_constant_to_each_row, Prepend_scalar_constant_to_each_row,
Prepend_scalar_to_each_row) Prepend_scalar_to_each_row)
from theano.tensor import dmatrix, dvector from theano.tensor import dmatrix, dvector
...@@ -124,13 +124,18 @@ class T_SoftmaxWithBias(utt.InferShapeTester): ...@@ -124,13 +124,18 @@ class T_SoftmaxWithBias(utt.InferShapeTester):
self._compile_and_check([admat, advec], [SoftmaxWithBias()(admat, advec)], self._compile_and_check([admat, advec], [SoftmaxWithBias()(admat, advec)],
[admat_val, advec_val], SoftmaxWithBias) [admat_val, advec_val], SoftmaxWithBias)
class T_SoftmaxGrad(unittest.TestCase): class T_SoftmaxGrad(utt.InferShapeTester):
def test_infer_shape(self): def test_infer_shape(self):
a=T.constant(numpy.random.rand(3,4))
b=T.constant(numpy.random.rand(3,4)) admat = dmatrix()
f=theano.function([],softmax_grad(a,b).shape) bdmat = dmatrix()
assert numpy.all(f()==[3,4]) admat_val = numpy.random.rand(3, 4)
bdmat_val = numpy.random.rand(3, 4)
self._compile_and_check([admat, bdmat], [SoftmaxGrad()(admat, bdmat)],
[admat_val, bdmat_val], SoftmaxGrad)
class T_CrossentropySoftmax1Hot(unittest.TestCase): class T_CrossentropySoftmax1Hot(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -1103,7 +1108,7 @@ class Test_softmax_opt: ...@@ -1103,7 +1108,7 @@ class Test_softmax_opt:
if __name__ == '__main__': if __name__ == '__main__':
t = T_SoftmaxWithBias('setUp') t = T_SoftmaxGrad('setUp')
t.setUp() t.setUp()
t.test_infer_shape() t.test_infer_shape()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论