提交 6eda4baf authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic Bastien

testing infer_shape: op CAReduce

上级 6aedf054
......@@ -783,6 +783,32 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
x)
idx += 1
class TestElemwise(unittest_tools.InferShapeTester):
def test_infer_shape(self):
for s_left, s_right in [((5, 6), (5, 6)),
((5, 6), (5, 1)),
((5, 6), (1, 6)),
((5, 1), (5, 6)),
((1, 6), (5, 6)),
((2, 3, 4, 5), (2, 3, 4, 5)),
((2, 3, 4, 5), (2, 3, 1, 5)),
((2, 3, 4, 5), (1, 3, 4, 5)),
((2, 1, 4, 5), (2, 3, 4, 5)),
((2, 3, 4, 1), (2, 3, 4, 5))]:
dtype = theano.config.floatX
t_left = TensorType(dtype, [(entry == 1) for entry in s_left])()
t_right = TensorType(dtype, [(entry == 1) for entry in s_right])()
t_left_val = numpy.zeros(s_left)
t_right_val = numpy.zeros(s_right)
self._compile_and_check([t_left, t_right],
[Elemwise(add)(t_left, t_right)],
[t_left_val, t_right_val], Elemwise)
"""
if __name__ == '__main__':
#unittest.main()
......@@ -794,10 +820,10 @@ if __name__ == '__main__':
"""
if __name__ == '__main__':
t = test_CAReduce('setUp')
t = TestElemwise('setUp')
t.setUp()
t.test_infer_shape()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论