提交 d3705773 authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic Bastien

testing infer_shape: op DimShuffle

上级 b3205b23
...@@ -24,7 +24,7 @@ def FunctionGraph(i, o): ...@@ -24,7 +24,7 @@ def FunctionGraph(i, o):
return e return e
class test_DimShuffle(unittest.TestCase): class test_DimShuffle(unittest_tools.InferShapeTester):
def with_linker(self, linker): def with_linker(self, linker):
for xsh, shuffle, zsh in [((2, 3), (1, 'x', 0), (3, 1, 2)), for xsh, shuffle, zsh in [((2, 3), (1, 'x', 0), (3, 1, 2)),
...@@ -74,6 +74,24 @@ class test_DimShuffle(unittest.TestCase): ...@@ -74,6 +74,24 @@ class test_DimShuffle(unittest.TestCase):
# But This will test DimShuffle c code # But This will test DimShuffle c code
self.with_linker(gof.OpWiseCLinker()) self.with_linker(gof.OpWiseCLinker())
def test_infer_shape(self):
for xsh, shuffle in [((2, 3), (1, 'x', 0)),
((1, 2, 3), (1, 2)),
((1, 2, 1, 3), (1, 3)),
((2, 3, 4), (2, 1, 0)),
((2, 3, 4), ('x', 2, 1, 0, 'x')),
((1, 4, 3, 2, 1), (3, 2, 1)),
((1, 1, 4), (1, 2)),
((1, 1, 1), ()),
((1,), ('x', 'x'))]:
ib = [(entry == 1) for entry in xsh]
adtens = TensorType('float64', ib)('x')
adtens_val = numpy.ones(xsh)
self._compile_and_check([adtens],
[DimShuffle(ib, shuffle)(adtens)],
[adtens_val], DimShuffle)
class test_Broadcast(unittest.TestCase): class test_Broadcast(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -744,7 +762,7 @@ class T_prod_without_zeros_dtype(unittest.TestCase): ...@@ -744,7 +762,7 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
x) x)
idx += 1 idx += 1
"""
if __name__ == '__main__': if __name__ == '__main__':
#unittest.main() #unittest.main()
suite = unittest.TestSuite([test_Prod('test_mul_without_zeros_zeros')]) suite = unittest.TestSuite([test_Prod('test_mul_without_zeros_zeros')])
...@@ -752,3 +770,14 @@ if __name__ == '__main__': ...@@ -752,3 +770,14 @@ if __name__ == '__main__':
#suite.addTest(test_Prod('test_prod_without_zeros')) #suite.addTest(test_Prod('test_prod_without_zeros'))
#suite.addTest(test_Prod('test_other_grad_tests')) #suite.addTest(test_Prod('test_other_grad_tests'))
unittest.TextTestRunner().run(suite) unittest.TextTestRunner().run(suite)
"""
if __name__ == '__main__':
t = test_DimShuffle('setUp')
t.setUp()
t.test_infer_shape()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论