提交 6fcd2cd0 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #4418 from shabanian/norm

Norm
...@@ -105,3 +105,29 @@ class TestKeepDims(unittest.TestCase): ...@@ -105,3 +105,29 @@ class TestKeepDims(unittest.TestCase):
ans1, ans2 = f(a) ans1, ans2 = f(a)
assert numpy.allclose(ans1, ans2) assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape assert ans1.shape == ans2.shape
def test_norm(self):
x = tensor.dtensor3()
a = numpy.random.rand(3, 2, 4).astype(theano.config.floatX)
mode = theano.compile.Mode(optimizer="fast_compile", linker="py")
for axis in [0, 1, 2, [0], [1], [2], None,
[0, 1], [1, 2], [0, 1, 2],
[-1], [-2], [-3], [-1, -2], [-1, -2, -3], [0, -2, 2]]:
f = function([x], [x.norm(L=1, axis=axis, keepdims=True),
self.makeKeepDims_local(x, x.norm(L=1, axis=axis, keepdims=False), axis)
], mode=mode)
ans1, ans2 = f(a)
assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
g = function([x], [x.norm(L=2, axis=axis, keepdims=True),
self.makeKeepDims_local(x, x.norm(L=2, axis=axis, keepdims=False), axis)
], mode=mode)
ans1, ans2 = g(a)
assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
...@@ -596,15 +596,19 @@ class _tensor_py_operators(object): ...@@ -596,15 +596,19 @@ class _tensor_py_operators(object):
dtype=dtype, keepdims=keepdims, dtype=dtype, keepdims=keepdims,
acc_dtype=acc_dtype) acc_dtype=acc_dtype)
def norm(self, L, axis=None): def norm(self, L, axis=None, keepdims=False):
if L == 0: if L == 0:
raise NotImplementedError() raise NotImplementedError()
if numpy.isinf(L): if numpy.isinf(L):
raise NotImplementedError() raise NotImplementedError()
# optimizations will/should catch cases like L=1, L=2 # optimizations will/should catch cases like L=1, L=2
return theano.tensor.basic.pow( y = theano.tensor.basic.pow(
theano.tensor.basic.pow( theano.tensor.basic.pow(
theano.tensor.basic.abs_(self), L).sum(axis=axis), 1.0 / L) theano.tensor.basic.abs_(self), L).sum(axis=axis), 1.0 / L)
if keepdims:
return theano.tensor.basic.makeKeepDims(self, y, axis)
else:
return y
def mean(self, axis=None, dtype=None, keepdims=False, acc_dtype=None): def mean(self, axis=None, dtype=None, keepdims=False, acc_dtype=None):
"""See `theano.tensor.mean`.""" """See `theano.tensor.mean`."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论