提交 28fd9366 authored 作者: Samira Shabanian's avatar Samira Shabanian

Added test for keepdims in norm function

上级 8239717a
......@@ -105,3 +105,29 @@ class TestKeepDims(unittest.TestCase):
ans1, ans2 = f(a)
assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
def test_norm(self):
x = tensor.dtensor3()
a = numpy.random.rand(3, 2, 4)
mode = theano.compile.Mode(optimizer="fast_compile", linker="py")
for axis in [0, 1, 2, [0], [1], [2], None,
[0, 1], [1, 2], [0, 1, 2],
[-1], [-2], [-3], [-1, -2], [-1, -2, -3], [0, -2, 2]]:
f = function([x], [x.norm(L=1, axis=axis, keepdims=True),
self.makeKeepDims_local(x, x.norm(L=1, axis=axis, keepdims=False), axis)
], mode=mode)
ans1, ans2 = f(a)
assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
g = function([x], [x.norm(L=2, axis=axis, keepdims=True),
self.makeKeepDims_local(x, x.norm(L=2, axis=axis, keepdims=False), axis)
], mode=mode)
ans1, ans2 = g(a)
assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论