提交 8239717a authored 作者: Bhavishya Pohani's avatar Bhavishya Pohani 提交者: Samira Shabanian

Added support for keepdims in norm function

上级 a536464a
...@@ -593,15 +593,19 @@ class _tensor_py_operators(object): ...@@ -593,15 +593,19 @@ class _tensor_py_operators(object):
dtype=dtype, keepdims=keepdims, dtype=dtype, keepdims=keepdims,
acc_dtype=acc_dtype) acc_dtype=acc_dtype)
def norm(self, L, axis=None): def norm(self, L, axis=None, keepdims=False):
if L == 0: if L == 0:
raise NotImplementedError() raise NotImplementedError()
if numpy.isinf(L): if numpy.isinf(L):
raise NotImplementedError() raise NotImplementedError()
# optimizations will/should catch cases like L=1, L=2 # optimizations will/should catch cases like L=1, L=2
return theano.tensor.basic.pow( y = theano.tensor.basic.pow(
theano.tensor.basic.pow( theano.tensor.basic.pow(
theano.tensor.basic.abs_(self), L).sum(axis=axis), 1.0 / L) theano.tensor.basic.abs_(self), L).sum(axis=axis), 1.0 / L)
if keepdims:
return theano.tensor.basic.makeKeepDims(self, y, axis)
else:
return y
def mean(self, axis=None, dtype=None, keepdims=False, acc_dtype=None): def mean(self, axis=None, dtype=None, keepdims=False, acc_dtype=None):
"""See `theano.tensor.mean`.""" """See `theano.tensor.mean`."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论