提交 180777ee authored 作者: Eric Larsen's avatar Eric Larsen

testing infer_shape: OP MakeVector

上级 9cc19f69
...@@ -533,6 +533,9 @@ class MakeVector(T.Op): ...@@ -533,6 +533,9 @@ class MakeVector(T.Op):
# assume that out has correct dtype. there is no cheap way to check # assume that out has correct dtype. there is no cheap way to check
out[0][...] = inputs out[0][...] = inputs
def infer_shape(self, node, ishapes):
return [(len(ishapes),)]
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
# If the output is of an integer dtype, no gradient shall pass # If the output is of an integer dtype, no gradient shall pass
if 'int' in self.dtype: if 'int' in self.dtype:
......
...@@ -29,7 +29,8 @@ from theano.tensor.opt import ( ...@@ -29,7 +29,8 @@ from theano.tensor.opt import (
mul_canonizer, mul_canonizer,
out2in, out2in,
Shape_i, Shape_i,
Assert Assert,
MakeVector
) )
from theano import tensor from theano import tensor
from theano import tensor as T from theano import tensor as T
...@@ -3386,7 +3387,12 @@ class T_local_sum_dimshuffle(unittest.TestCase): ...@@ -3386,7 +3387,12 @@ class T_local_sum_dimshuffle(unittest.TestCase):
# test_local_sum_divprod_dimshuffle ((a * b) / (c * d)) # test_local_sum_divprod_dimshuffle ((a * b) / (c * d))
def test_make_vector(): class TestMakeVector(utt.InferShapeTester):
def setUp(self):
super(TestMakeVector, self).setUp()
def test_make_vector():
b = T.bscalar() b = T.bscalar()
i = T.iscalar() i = T.iscalar()
d = T.dscalar() d = T.dscalar()
...@@ -3468,6 +3474,32 @@ def test_make_vector(): ...@@ -3468,6 +3474,32 @@ def test_make_vector():
except AssertionError: except AssertionError:
pass pass
def test_infer_shape(self):
adscal = dscalar()
bdscal = dscalar()
aiscal = iscalar()
biscal = iscalar()
ciscal = iscalar()
discal = iscalar()
adscal_val = numpy.random.rand()
bdscal_val = numpy.random.rand()
aiscal_val = numpy.random.randint(10)
biscal_val = numpy.random.randint(10)
ciscal_val = numpy.random.randint(10)
discal_val = numpy.random.randint(10)
self._compile_and_check([adscal, aiscal],
[MakeVector('float64')(adscal, aiscal)],
[adscal_val, aiscal_val], MakeVector)
self._compile_and_check([adscal, bdscal, aiscal],
[MakeVector('float64')(adscal, bdscal, aiscal)],
[adscal_val, bdscal_val, aiscal_val], MakeVector)
self._compile_and_check([aiscal, biscal, ciscal, discal],
[MakeVector('int32')(aiscal, biscal, ciscal, discal)],
[aiscal_val, biscal_val, ciscal_val, discal_val],
MakeVector)
def test_local_join_1(): def test_local_join_1():
#test for vector #test for vector
...@@ -3684,9 +3716,9 @@ class TestShape_i(utt.InferShapeTester): ...@@ -3684,9 +3716,9 @@ class TestShape_i(utt.InferShapeTester):
if __name__ == '__main__': if __name__ == '__main__':
t = TestShape_i('setUp') t = TestMakeVector('setUp')
t.setUp() t.setUp()
t.test_perform() #t.test_perform()
t.test_infer_shape() t.test_infer_shape()
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论