提交 6eda4baf authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic Bastien

testing infer_shape: op CAReduce

上级 6aedf054
...@@ -783,6 +783,32 @@ class T_prod_without_zeros_dtype(unittest.TestCase): ...@@ -783,6 +783,32 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
x) x)
idx += 1 idx += 1
class TestElemwise(unittest_tools.InferShapeTester):
def test_infer_shape(self):
for s_left, s_right in [((5, 6), (5, 6)),
((5, 6), (5, 1)),
((5, 6), (1, 6)),
((5, 1), (5, 6)),
((1, 6), (5, 6)),
((2, 3, 4, 5), (2, 3, 4, 5)),
((2, 3, 4, 5), (2, 3, 1, 5)),
((2, 3, 4, 5), (1, 3, 4, 5)),
((2, 1, 4, 5), (2, 3, 4, 5)),
((2, 3, 4, 1), (2, 3, 4, 5))]:
dtype = theano.config.floatX
t_left = TensorType(dtype, [(entry == 1) for entry in s_left])()
t_right = TensorType(dtype, [(entry == 1) for entry in s_right])()
t_left_val = numpy.zeros(s_left)
t_right_val = numpy.zeros(s_right)
self._compile_and_check([t_left, t_right],
[Elemwise(add)(t_left, t_right)],
[t_left_val, t_right_val], Elemwise)
""" """
if __name__ == '__main__': if __name__ == '__main__':
#unittest.main() #unittest.main()
...@@ -797,7 +823,7 @@ if __name__ == '__main__': ...@@ -797,7 +823,7 @@ if __name__ == '__main__':
if __name__ == '__main__': if __name__ == '__main__':
t = test_CAReduce('setUp') t = TestElemwise('setUp')
t.setUp() t.setUp()
t.test_infer_shape() t.test_infer_shape()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论