提交 1bf87a89 authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

batterie de tests pour keepdims et corrections correspondantes de basic

上级 02e78aca
......@@ -1465,10 +1465,10 @@ class _tensor_py_operators:
# We can't implement __len__ to provide a better error message.
def any(self, axis=None, keepdims=False):
return elemwise.Any(axis, keepdims)(self)
return any(self, axis=axis, keepdims=keepdims)
def all(self, axis=None, keepdims=False):
return elemwise.All(axis, keepdims)(self)
return all(self, axis=axis, keepdims=keepdims)
# Otherwise TensorVariable[:-1] does not work as Python 2.5.1 calls
# __len__ before calling __getitem__. It also does not catch the raised
......@@ -1622,7 +1622,7 @@ class _tensor_py_operators:
"""See `theano.tensor.sum`"""
return sum(self, axis=axis, dtype=dtype, keepdims=keepdims)
def prod(self, axis=None, dtype=None, keepdims=False)
def prod(self, axis=None, dtype=None, keepdims=False):
"""See `theano.tensor.prod`"""
return prod(self, axis=axis, dtype=dtype, keepdims=keepdims)
......@@ -2313,6 +2313,8 @@ class MaxAndArgmax(Op):
def __str__(self):
return self.__class__.__name__
_max_and_argmax = MaxAndArgmax()
......@@ -2322,8 +2324,13 @@ def makeKeepDims(x, y, axis):
in a prior reduction of x. With this option, the resulting tensor will
broadcast correctly against the original tensor x.
"""
x = as_tensor_variable(x)
y = as_tensor_variable(y)
if axis is None:
axis = range(x.type.ndim)
i = 0
new_dims = []
for j, _ in enumerate(x.type.broadcastable):
if j in axis:
new_dims.append('x')
......@@ -2333,7 +2340,7 @@ def makeKeepDims(x, y, axis):
return DimShuffle(y.type.broadcastable, new_dims)(y)
@_constructor
@constructor
def max_and_argmax(a, axis=None, keepdims=False):
"""
Returns maximum elements and their indices obtained by iterating over
......@@ -2421,7 +2428,7 @@ def min(x, axis=None, keepdims=False):
str_x_type = str(x.dtype)
if str_x_type.startswith('float') or str_x_type in int_dtypes:
out = -max(-x, axis=axis, keepdims=keepdims)
return -max(-x, axis=axis, keepdims=keepdims)
else:
#Be careful about unsigned integers, complex
raise NotImplementedError()
......
import numpy
from theano import tensor, function
class TestKeepDims:
def makeKeepDims_local(self, x, y, axis):
x = tensor.as_tensor_variable(x)
y = tensor.as_tensor_variable(y)
if axis is None:
axis = numpy.arange(x.ndim)
i = 0
new_dims = []
for j, _ in enumerate(x.shape):
if j in axis:
new_dims.append('x')
else:
new_dims.append(i)
i += 1
return tensor.DimShuffle(y.type.broadcastable, new_dims)(y)
def test_keepdims(self):
x = tensor.dtensor3()
a = numpy.random.rand(3, 2, 4)
# 'max_and_argmax' has two outputs and can be specified with either
# a single or every axis:
for axis in [[0], [1], [2]]:
op = tensor.max_and_argmax
keep_param = function([x], op(x, axis=axis, keepdims=True)[0])
keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=axis, keepdims=False)[0], axis))
# FRED: choisir l'une ou l'autre de ces verifications:
assert numpy.allclose(keep_param(a), keep_synth(a))
assert keep_param(a).shape == keep_synth(a).shape
keep_param = function([x], op(x, axis=axis, keepdims=True)[1])
keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=axis, keepdims=False)[1], axis))
# FRED: choisir l'une ou l'autre de ces verifications:
assert numpy.allclose(keep_param(a), keep_synth(a))
assert keep_param(a).shape == keep_synth(a).shape
# the following ops can be specified with either a single axis or every
# axis:
for op in ([tensor.argmax, tensor.max, tensor.argmin, tensor.min]):
for axis in [[0], [1], [2]]:
keep_param = function([x], op(x, axis=axis, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=axis, keepdims=False), axis))
# FRED: choisir l'une ou l'autre de ces verifications:
assert numpy.allclose(keep_param(a), keep_synth(a))
assert keep_param(a).shape == keep_synth(a).shape
keep_param = function([x], op(x, axis=None, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=None, keepdims=False), None))
# FRED: choisir l'une ou l'autre de ces verifications:
assert numpy.allclose(keep_param(a), keep_synth(a))
assert keep_param(a).shape == keep_synth(a).shape
# the following ops can be specified with a freely specified axis
# parameter
for op in ([tensor.sum, tensor.prod, tensor.mean, tensor.var,
tensor.std]):
# FRED: il faudra ajouter les ops suivantes a la boucle ci-dessus:
# tensor.all, tensor.any
# Celles-ci semblent presentement defectueuses puisqu'elles plantent
# a la compilation dans un interpreteur distinct.
for axis in [[0], [1], [2], [0, 1], [1, 2], [0, 1, 2]]:
keep_param = function([x], op(x, axis=axis, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=axis, keepdims=False), axis))
# FRED: choisir l'une ou l'autre de ces verifications:
assert numpy.allclose(keep_param(a), keep_synth(a))
assert keep_param(a).shape == keep_synth(a).shape
keep_param = function([x], op(x, axis=None, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=None, keepdims=False), None))
# FRED: choisir l'une ou l'autre de ces verifications:
assert numpy.allclose(keep_param(a), keep_synth(a))
assert keep_param(a).shape == keep_synth(a).shape
if __name__ == '__main__':
TestKeepDims().test_keepdims()
\ No newline at end of file
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论