提交 8c488a3b authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: op SoftMaxWithBias

上级 e6ab53eb
...@@ -22,7 +22,7 @@ from theano.tensor.nnet import (categorical_crossentropy, ...@@ -22,7 +22,7 @@ from theano.tensor.nnet import (categorical_crossentropy,
softmax_with_bias, softmax_with_bias,
Prepend_scalar_constant_to_each_row, Prepend_scalar_constant_to_each_row,
Prepend_scalar_to_each_row) Prepend_scalar_to_each_row)
from theano.tensor import dmatrix from theano.tensor import dmatrix, dvector
class T_sigmoid(unittest.TestCase): class T_sigmoid(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -73,9 +73,8 @@ class T_Softmax(utt.InferShapeTester): ...@@ -73,9 +73,8 @@ class T_Softmax(utt.InferShapeTester):
utt.verify_grad(f, [numpy.random.rand(4)]) utt.verify_grad(f, [numpy.random.rand(4)])
class T_SoftmaxWithBias(unittest.TestCase): class T_SoftmaxWithBias(utt.InferShapeTester):
def setUp(self):
utt.seed_rng()
def test0(self): def test0(self):
def f(a, b): def f(a, b):
return softmax_with_bias(a, b)[:,0] return softmax_with_bias(a, b)[:,0]
...@@ -118,8 +117,13 @@ class T_SoftmaxWithBias(unittest.TestCase): ...@@ -118,8 +117,13 @@ class T_SoftmaxWithBias(unittest.TestCase):
#print f.maker.fgraph.toposort() #print f.maker.fgraph.toposort()
def test_infer_shape(self): def test_infer_shape(self):
fff=theano.function([],outputs=softmax_with_bias(numpy.random.rand(3,4),numpy.random.rand(4)).shape) admat = dmatrix()
assert all(fff()==[3,4]) advec = dvector()
admat_val = numpy.random.rand(3, 4)
advec_val = numpy.random.rand(4)
self._compile_and_check([admat, advec], [SoftmaxWithBias()(admat, advec)],
[admat_val, advec_val], SoftmaxWithBias)
class T_SoftmaxGrad(unittest.TestCase): class T_SoftmaxGrad(unittest.TestCase):
def test_infer_shape(self): def test_infer_shape(self):
...@@ -1099,7 +1103,7 @@ class Test_softmax_opt: ...@@ -1099,7 +1103,7 @@ class Test_softmax_opt:
if __name__ == '__main__': if __name__ == '__main__':
t = T_Softmax('setUp') t = T_SoftmaxWithBias('setUp')
t.setUp() t.setUp()
t.test_infer_shape() t.test_infer_shape()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论