提交 8cff931b authored 作者: nouiz's avatar nouiz

Merge pull request #731 from larseeri/shape_elemwise

Add infer_shape test for elemwise, dimshuffle and carreduce op.
...@@ -24,7 +24,7 @@ def FunctionGraph(i, o): ...@@ -24,7 +24,7 @@ def FunctionGraph(i, o):
return e return e
class test_DimShuffle(unittest.TestCase): class test_DimShuffle(unittest_tools.InferShapeTester):
def with_linker(self, linker): def with_linker(self, linker):
for xsh, shuffle, zsh in [((2, 3), (1, 'x', 0), (3, 1, 2)), for xsh, shuffle, zsh in [((2, 3), (1, 'x', 0), (3, 1, 2)),
...@@ -74,6 +74,24 @@ class test_DimShuffle(unittest.TestCase): ...@@ -74,6 +74,24 @@ class test_DimShuffle(unittest.TestCase):
# But This will test DimShuffle c code # But This will test DimShuffle c code
self.with_linker(gof.OpWiseCLinker()) self.with_linker(gof.OpWiseCLinker())
def test_infer_shape(self):
for xsh, shuffle in [((2, 3), (1, 'x', 0)),
((1, 2, 3), (1, 2)),
((1, 2, 1, 3), (1, 3)),
((2, 3, 4), (2, 1, 0)),
((2, 3, 4), ('x', 2, 1, 0, 'x')),
((1, 4, 3, 2, 1), (3, 2, 1)),
((1, 1, 4), (1, 2)),
((1, 1, 1), ()),
((1,), ('x', 'x'))]:
ib = [(entry == 1) for entry in xsh]
adtens = TensorType('float64', ib)('x')
adtens_val = numpy.ones(xsh)
self._compile_and_check([adtens],
[DimShuffle(ib, shuffle)(adtens)],
[adtens_val], DimShuffle)
class test_Broadcast(unittest.TestCase): class test_Broadcast(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -184,9 +202,7 @@ class test_Broadcast(unittest.TestCase): ...@@ -184,9 +202,7 @@ class test_Broadcast(unittest.TestCase):
assert (f(xv) == zv).all() assert (f(xv) == zv).all()
class test_CAReduce(unittest.TestCase): class test_CAReduce(unittest_tools.InferShapeTester):
def setUp(self):
unittest_tools.seed_rng()
def with_linker(self, linker, scalar_op=scalar.add, dtype="floatX", def with_linker(self, linker, scalar_op=scalar.add, dtype="floatX",
test_nan=False, tensor_op=None): test_nan=False, tensor_op=None):
...@@ -393,6 +409,31 @@ class test_CAReduce(unittest.TestCase): ...@@ -393,6 +409,31 @@ class test_CAReduce(unittest.TestCase):
self.with_linker(gof.CLinker(), scalar.maximum, dtype=dtype, self.with_linker(gof.CLinker(), scalar.maximum, dtype=dtype,
test_nan=True) test_nan=True)
def test_infer_shape(self):
for xsh, tosum in [((5, 6), None),
((5, 6), (0, 1)),
((5, 6), (0, )),
((5, 6), (1, )),
((5, 6), (-1, )),
((5, 6), (-2, )),
((2, 3, 4, 5), (0, 1, 3)),
((2, 3, 4, 5), (-2, -3)),
((5, 0), None),
((5, 0), (0, )),
((5, 0), (1, )),
((5, 6), ()),
((5, 0), ()),
((), None),
((), ())]:
dtype = theano.config.floatX
x = TensorType(dtype, [(entry == 1) for entry in xsh])('x')
if tosum is None:
tosum = range(len(xsh))
xv = numpy.asarray(numpy.random.rand(*xsh))
self._compile_and_check([x],
[CAReduce(add, axis=tosum)(x)],
[xv], CAReduce, ["local_cut_useless_reduce"])
class test_Prod(unittest.TestCase): class test_Prod(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -745,6 +786,32 @@ class T_prod_without_zeros_dtype(unittest.TestCase): ...@@ -745,6 +786,32 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
idx += 1 idx += 1
class TestElemwise(unittest_tools.InferShapeTester):
def test_infer_shape(self):
for s_left, s_right in [((5, 6), (5, 6)),
((5, 6), (5, 1)),
((5, 6), (1, 6)),
((5, 1), (5, 6)),
((1, 6), (5, 6)),
((2, 3, 4, 5), (2, 3, 4, 5)),
((2, 3, 4, 5), (2, 3, 1, 5)),
((2, 3, 4, 5), (1, 3, 4, 5)),
((2, 1, 4, 5), (2, 3, 4, 5)),
((2, 3, 4, 1), (2, 3, 4, 5))]:
dtype = theano.config.floatX
t_left = TensorType(dtype, [(entry == 1) for entry in s_left])()
t_right = TensorType(dtype, [(entry == 1) for entry in s_right])()
t_left_val = numpy.zeros(s_left)
t_right_val = numpy.zeros(s_right)
self._compile_and_check([t_left, t_right],
[Elemwise(add)(t_left, t_right)],
[t_left_val, t_right_val], Elemwise)
"""
if __name__ == '__main__': if __name__ == '__main__':
#unittest.main() #unittest.main()
suite = unittest.TestSuite([test_Prod('test_mul_without_zeros_zeros')]) suite = unittest.TestSuite([test_Prod('test_mul_without_zeros_zeros')])
...@@ -752,3 +819,14 @@ if __name__ == '__main__': ...@@ -752,3 +819,14 @@ if __name__ == '__main__':
#suite.addTest(test_Prod('test_prod_without_zeros')) #suite.addTest(test_Prod('test_prod_without_zeros'))
#suite.addTest(test_Prod('test_other_grad_tests')) #suite.addTest(test_Prod('test_other_grad_tests'))
unittest.TextTestRunner().run(suite) unittest.TextTestRunner().run(suite)
"""
if __name__ == '__main__':
t = TestElemwise('setUp')
t.setUp()
t.test_infer_shape()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论