Commit 262b0278, authored by: lamblin

Merge pull request #693 from larseeri/keepdims

ajout parametre keepdims
...@@ -1464,11 +1464,11 @@ class _tensor_py_operators: ...@@ -1464,11 +1464,11 @@ class _tensor_py_operators:
size = property(lambda self: prod(self.shape)) size = property(lambda self: prod(self.shape))
# We can't implement __len__ to provide a better error message. # We can't implement __len__ to provide a better error message.
def any(self, axis=None): def any(self, axis=None, keepdims=False):
return elemwise.Any(axis)(self) return any(self, axis=axis, keepdims=keepdims)
def all(self, axis=None): def all(self, axis=None, keepdims=False):
return elemwise.All(axis)(self) return all(self, axis=axis, keepdims=keepdims)
# Otherwise TensorVariable[:-1] does not work as Python 2.5.1 calls # Otherwise TensorVariable[:-1] does not work as Python 2.5.1 calls
# __len__ before calling __getitem__. It also does not catch the raised # __len__ before calling __getitem__. It also does not catch the raised
...@@ -1618,13 +1618,13 @@ class _tensor_py_operators: ...@@ -1618,13 +1618,13 @@ class _tensor_py_operators:
def __rdot__(right, left): def __rdot__(right, left):
return dot(left, right) return dot(left, right)
def sum(self, axis=None, dtype=None): def sum(self, axis=None, dtype=None, keepdims=False):
"""See `theano.tensor.sum`""" """See `theano.tensor.sum`"""
return sum(self, axis=axis, dtype=dtype) return sum(self, axis=axis, dtype=dtype, keepdims=keepdims)
def prod(self, axis=None, dtype=None): def prod(self, axis=None, dtype=None, keepdims=False):
"""See `theano.tensor.prod`""" """See `theano.tensor.prod`"""
return prod(self, axis=axis, dtype=dtype) return prod(self, axis=axis, dtype=dtype, keepdims=keepdims)
def norm(self, L, axis=None): def norm(self, L, axis=None):
if L == 0: if L == 0:
...@@ -1634,21 +1634,21 @@ class _tensor_py_operators: ...@@ -1634,21 +1634,21 @@ class _tensor_py_operators:
#optimizations will/should catch cases like L=1, L=2 #optimizations will/should catch cases like L=1, L=2
return pow(pow(abs_(self), L).sum(axis=axis), 1.0 / L) return pow(pow(abs_(self), L).sum(axis=axis), 1.0 / L)
def mean(self, axis=None, dtype=None): def mean(self, axis=None, dtype=None, keepdims=False):
"""See `theano.tensor.mean`""" """See `theano.tensor.mean`"""
return mean(self, axis=axis, dtype=dtype) return mean(self, axis=axis, dtype=dtype, keepdims=keepdims)
def var(self, axis=None): def var(self, axis=None, keepdims=False):
"""See `theano.tensor.var`""" """See `theano.tensor.var`"""
return var(self, axis) return var(self, axis, keepdims=keepdims)
def min(self, axis=None): def min(self, axis=None, keepdims=False):
"""See `theano.tensor.min`""" """See `theano.tensor.min`"""
return min(self, axis) return min(self, axis, keepdims=keepdims)
def max(self, axis=None): def max(self, axis=None, keepdims=False):
"""See `theano.tensor.max`""" """See `theano.tensor.max`"""
return max(self, axis) return max(self, axis, keepdims=keepdims)
#TO TRUMP NUMPY OPERATORS #TO TRUMP NUMPY OPERATORS
__array_priority__ = 1000 __array_priority__ = 1000
...@@ -2182,7 +2182,7 @@ specify_shape = SpecifyShape() ...@@ -2182,7 +2182,7 @@ specify_shape = SpecifyShape()
class MaxAndArgmax(Op): class MaxAndArgmax(Op):
"""Calculate the max and argmax over a given axis. """Calculate the max and argmax over a given axis or over all axes.
""" """
nin = 2 # tensor, axis nin = 2 # tensor, axis
nout = 2 # max val, max idx nout = 2 # max val, max idx
...@@ -2203,8 +2203,8 @@ class MaxAndArgmax(Op): ...@@ -2203,8 +2203,8 @@ class MaxAndArgmax(Op):
list(axis) list(axis)
axis.sort() axis.sort()
assert axis == range(x.type.ndim), ( assert axis == range(x.type.ndim), (
"MaxAndArgmax don't support multiple" "MaxAndArgmax does not support multiple"
" axis. the max fct support it.") " axes. the max fct supports it.")
# we make the axis all positive to make the infer_shape work # we make the axis all positive to make the infer_shape work
# with negative axis # with negative axis
if x.type.ndim > 0 and axis is not None: if x.type.ndim > 0 and axis is not None:
...@@ -2274,6 +2274,11 @@ class MaxAndArgmax(Op): ...@@ -2274,6 +2274,11 @@ class MaxAndArgmax(Op):
max_pos], None] max_pos], None]
def grad(self, inp, grads): def grad(self, inp, grads):
# The strict sense mathematical gradient of the maximum function is
# not calculated here for it is not defined at every point where some
# coordinates are identical. However, since the latter set has null
# Lebesgue measure, the result may be interpreted as weak gradient.
# @note: This function should work correctly for L{vector}s. # @note: This function should work correctly for L{vector}s.
# (x, y), (gz, gw) # (x, y), (gz, gw)
# gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy # gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy
...@@ -2308,61 +2313,144 @@ class MaxAndArgmax(Op): ...@@ -2308,61 +2313,144 @@ class MaxAndArgmax(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
_max_and_argmax = MaxAndArgmax() _max_and_argmax = MaxAndArgmax()
@_redefine_asRoutine(_max_and_argmax) def makeKeepDims(x, y, axis):
def max_and_argmax(a): """
pass Reintroduces in y with length one the axes of x which have been left out
in a prior reduction of x. With this option, the resulting tensor will
broadcast correctly against the original tensor x.
"""
x = as_tensor_variable(x)
y = as_tensor_variable(y)
if axis is None:
axis = range(x.type.ndim)
i = 0
new_dims = []
for j, _ in enumerate(x.type.broadcastable):
if j in axis:
new_dims.append('x')
else:
new_dims.append(i)
i += 1
return DimShuffle(y.type.broadcastable, new_dims)(y)
@constructor @constructor
def max(x, axis=None): def max_and_argmax(a, axis=None, keepdims=False):
""" """
Return maximum elements obtained by iterating over given axis Returns maximum elements and their indices obtained by iterating over
given axis
Default axis is None: max over all dimensions. When axis is None (the default value), the max is performed
over the flattened tensor.
keepdims: If this is set to True, the axes which are reduced are left in
the result as dimensions with size one. With this option, the result
will broadcast correctly against the original tensor.
"""
out, argout = _max_and_argmax(a, axis)
if keepdims:
out = makeKeepDims(a, out, axis)
argout = makeKeepDims(a, argout, axis)
return [out, argout]
@constructor
def max(x, axis=None, keepdims=False):
"""
Returns maximum elements obtained by iterating over given axis
When axis is None (the default value), the max is performed
over the flattened tensor.
keepdims: If this is set to True, the axes which are reduced are left in
the result as dimensions with size one. With this option, the result
will broadcast correctly against the original tensor.
:note: we return an error as numpy when we reduce a dim with a shape of 0 :note: we return an error as numpy when we reduce a dim with a shape of 0
""" """
if isinstance(axis, (list, tuple)) and len(axis) > 1: if isinstance(axis, (list, tuple)) and len(axis) > 1:
return CAReduce(scal.maximum, axis)(x) out = CAReduce(scal.maximum, axis)(x)
try: else:
const = get_constant_value(axis) try:
return CAReduce(scal.maximum, list(const))(x) const = get_constant_value(axis)
except Exception: out = CAReduce(scal.maximum, list(const))(x)
return max_and_argmax(x, axis)[0] except Exception:
out = max_and_argmax(x, axis)[0]
if keepdims:
out = makeKeepDims(x, out, axis)
return out
@constructor @constructor
def argmax(x, axis=None): def argmax(x, axis=None, keepdims=False):
""" """
Return indexes of maximum elements obtained by iterating over given axis Returns indices of maximum elements obtained by iterating over given axis
When axis is None (the default value), the argmax is performed When axis is None (the default value), the argmax is performed
over the flattened tensor. over the flattened tensor.
keepdims: If this is set to True, the axes which are reduced are left in
the result as dimensions with size one. With this option, the result
will broadcast correctly against the original tensor.
""" """
# In python (using MaxAndArgmax.perform()) this leads to an wasteful # In python (using MaxAndArgmax.perform()) this leads to a wasteful
# implementation that goes through the data twice instead of once # implementation that goes through the data twice instead of once
# but when Argmax.c_impl() is in place, it should be fine. # but when Argmax.c_impl() is in place, it should be fine.
return max_and_argmax(x, axis)[1]
argout = max_and_argmax(x, axis)[1]
if keepdims:
argout = makeKeepDims(x, argout, axis)
return argout
@constructor @constructor
def min(x, axis=None): def min(x, axis=None, keepdims=False):
"""
Returns minimum elements obtained by iterating over given axis
When axis is None (the default value), the min is performed
over the flattened tensor.
keepdims: If this is set to True, the axes which are reduced are left in
the result as dimensions with size one. With this option, the result
will broadcast correctly against the original tensor.
"""
str_x_type = str(x.dtype) str_x_type = str(x.dtype)
if str_x_type.startswith('float') or str_x_type in int_dtypes: if str_x_type.startswith('float') or str_x_type in int_dtypes:
return -max(-x, axis=axis) return -max(-x, axis=axis, keepdims=keepdims)
else: else:
#Be careful about unsigned integers, complex #Be careful about unsigned integers, complex
raise NotImplementedError() raise NotImplementedError()
@constructor @constructor
def argmin(x, axis=None): def argmin(x, axis=None, keepdims=False):
"""
Returns indices of minimum elements obtained by iterating over given axis
When axis is None (the default value), the argmin is performed
over the flattened tensor.
keepdims: If this is set to True, the axes which are reduced are left in
the result as dimensions with size one. With this option, the result
will broadcast correctly against the original tensor.
"""
str_x_type = str(x.dtype) str_x_type = str(x.dtype)
if str_x_type.startswith('float') or str_x_type in int_dtypes: if str_x_type.startswith('float') or str_x_type in int_dtypes:
return argmax(-x, axis=axis) return argmax(-x, axis=axis, keepdims=keepdims)
else: else:
#Be careful about unsigned integers, complex #Be careful about unsigned integers, complex
raise NotImplementedError() raise NotImplementedError()
...@@ -3029,27 +3117,51 @@ pprint.assign(tensor_copy, printing.IgnorePrinter()) ...@@ -3029,27 +3117,51 @@ pprint.assign(tensor_copy, printing.IgnorePrinter())
@constructor @constructor
def sum(input, axis=None, dtype=None): def sum(input, axis=None, dtype=None, keepdims=False):
""" """
Sum a tensor along the given axis(es). Computes the sum along the given axis(es) of a tensor `input`
When axis is None (the default value), the sum is performed
over the flattened tensor.
keepdims: If this is set to True, the axes which are reduced are left in
the result as dimensions with size one. With this option, the result
will broadcast correctly against the original tensor.
For full documentation see ``tensor.elemwise.Sum``. For full documentation see ``tensor.elemwise.Sum``.
In particular please pay attention to the important warning when using In particular please pay attention to the important warning when using
a custom dtype. a custom dtype.
""" """
return elemwise.Sum(axis=axis, dtype=dtype)(input)
out = elemwise.Sum(axis=axis, dtype=dtype)(input)
if keepdims:
out = makeKeepDims(input, out, axis)
return out
pprint.assign(Sum(), printing.FunctionPrinter('sum')) pprint.assign(Sum(), printing.FunctionPrinter('sum'))
@constructor @constructor
def prod(input, axis=None, dtype=None): def prod(input, axis=None, dtype=None, keepdims=False):
""" """
Returns the Product of a tensor's elements along the given axis(es). Computes the product along the given axis(es) of a tensor `input`
When axis is None (the default value), the product is performed
over the flattened tensor.
keepdims: If this is set to True, the axes which are reduced are left in
the result as dimensions with size one. With this option, the result
will broadcast correctly against the original tensor.
For full documentation see ``tensor.elemwise.Prod``. For full documentation see ``tensor.elemwise.Prod``.
""" """
return elemwise.Prod(axis, dtype=dtype)(input)
out = elemwise.Prod(axis, dtype=dtype)(input)
if keepdims:
out = makeKeepDims(input, out, axis)
return out
class Mean(elemwise.CAReduce): class Mean(elemwise.CAReduce):
...@@ -3088,8 +3200,9 @@ class Mean(elemwise.CAReduce): ...@@ -3088,8 +3200,9 @@ class Mean(elemwise.CAReduce):
@constructor @constructor
def mean(input, axis=None, dtype=None, op=False): def mean(input, axis=None, dtype=None, op=False, keepdims=False):
"""Compute the mean value along the given axis of a tensor `input` """
Computes the mean value along the given axis(es) of a tensor `input`
:param axis: compute the mean along this axis of the tensor. :param axis: compute the mean along this axis of the tensor.
None means all axes (like numpy). None means all axes (like numpy).
...@@ -3102,9 +3215,14 @@ def mean(input, axis=None, dtype=None, op=False): ...@@ -3102,9 +3215,14 @@ def mean(input, axis=None, dtype=None, op=False):
If None, then we use the same rules as `sum()`. If None, then we use the same rules as `sum()`.
:type dtype: None or string :type dtype: None or string
:param keepdims: If this is set to True, the axes which are reduced are
left in the result as dimensions with size one. With this option,
the result will broadcast correctly against the original tensor.
:note: for gpu, if you specify dtype=float32, everything will be done :note: for gpu, if you specify dtype=float32, everything will be done
on the gpu. on the gpu.
""" """
if op: if op:
if dtype not in (None, 'float64'): if dtype not in (None, 'float64'):
raise NotImplementedError( raise NotImplementedError(
...@@ -3112,7 +3230,10 @@ def mean(input, axis=None, dtype=None, op=False): ...@@ -3112,7 +3230,10 @@ def mean(input, axis=None, dtype=None, op=False):
'and will always use float64. If you want to specify ' 'and will always use float64. If you want to specify '
'the dtype, call tensor.mean(..., op=False).', 'the dtype, call tensor.mean(..., op=False).',
dtype) dtype)
return Mean(axis)(input) out = Mean(axis)(input)
if keepdims:
out = makeKeepDims(input, out, axis)
return out
if dtype is not None: if dtype is not None:
# The summation will be done with the specified dtype. # The summation will be done with the specified dtype.
...@@ -3122,7 +3243,7 @@ def mean(input, axis=None, dtype=None, op=False): ...@@ -3122,7 +3243,7 @@ def mean(input, axis=None, dtype=None, op=False):
# Let sum() infer the appropriate dtype. # Let sum() infer the appropriate dtype.
sum_dtype = None sum_dtype = None
s = sum(input, axis=axis, dtype=sum_dtype) s = sum(input, axis=axis, dtype=sum_dtype, keepdims=keepdims)
shp = shape(input) shp = shape(input)
# Cast shp into a float type # Cast shp into a float type
...@@ -3138,6 +3259,7 @@ def mean(input, axis=None, dtype=None, op=False): ...@@ -3138,6 +3259,7 @@ def mean(input, axis=None, dtype=None, op=False):
elif isinstance(axis, int): elif isinstance(axis, int):
axis = [axis] axis = [axis]
# This sequential division will possibly be optimized by Theano:
for i in axis: for i in axis:
s = true_div(s, shp[i]) s = true_div(s, shp[i])
...@@ -3145,54 +3267,50 @@ def mean(input, axis=None, dtype=None, op=False): ...@@ -3145,54 +3267,50 @@ def mean(input, axis=None, dtype=None, op=False):
@constructor @constructor
def var(input, axis=None): def var(input, axis=None, keepdims=False):
"""Compute the variance along the given axis of a tensor `input`. """
Computes the variance along the given axis(es) of a tensor `input`.
:param axis: Compute the variance along this axis of the tensor. :param axis: Compute the variance along this axis of the tensor.
None means all axes (like numpy). None means all axes (like numpy).
:type axis: None or int or (list of int) (see `Sum`) :type axis: None or int or (list of int) (see `Sum`)
:param keepdims: If this is set to True, the axes which are reduced are
left in the result as dimensions with size one. With this option,
the result will broadcast correctly against the original tensor.
""" """
input_ndim = input.type.ndim input_ndim = input.type.ndim
if axis is None: if axis is None:
axis = range(input_ndim) axis = range(input_ndim)
if isinstance(axis, int): if isinstance(axis, int):
axis = [axis] axis = [axis]
#make a pattern that will undo the reduction of dimensions caused by mean
pattern = []
next_dim = 0
for i in xrange(input_ndim):
if i in axis:
pattern.append('x')
else:
pattern.append(next_dim)
next_dim += 1
#compute the axis-wise mean #compute the axis-wise mean
mean_input_reduced = mean(input, axis) mean_input = mean(input, axis, keepdims=True)
#broadcast that back out to match input
mean_input = DimShuffle(
list(mean_input_reduced.type.broadcastable),
pattern)(mean_input_reduced)
#center the input #center the input
centered_input = input - mean_input centered_input = input - mean_input
#return the mean sqr #return the mean sqr
return mean((centered_input ** 2), axis) return mean((centered_input ** 2), axis, keepdims=keepdims)
@constructor @constructor
def std(input, axis=None): def std(input, axis=None, keepdims=False):
"""Compute the standard deviation along the given axis of a tensor `input`. """
Computes the standard deviation along the given axis(es) of a tensor `input`.
:param axis: Compute the standard deviation along this axis of the tensor. :param axis: Compute the standard deviation along this axis of the tensor.
None means all axes (like numpy). None means all axes (like numpy).
:type axis: None or int or (list of int) (see `Sum`) :type axis: None or int or (list of int) (see `Sum`)
:param keepdims: If this is set to True, the axes which are reduced are
left in the result as dimensions with size one. With this option,
the result will broadcast correctly against the original tensor.
""" """
return sqrt(var(input=input, axis=axis))
return sqrt(var(input=input, axis=axis, keepdims=keepdims))
if 0: if 0:
## COMMENTED OUT FEB 17 2010 ## COMMENTED OUT FEB 17 2010
...@@ -6330,9 +6448,17 @@ def outer(x, y): ...@@ -6330,9 +6448,17 @@ def outer(x, y):
y.dimshuffle('x', 0)) y.dimshuffle('x', 0))
def any(x, axis=None): def any(x, axis=None, keepdims=False):
return elemwise.Any(axis)(x) out = elemwise.Any(axis)(x)
if keepdims:
out = makeKeepDims(x, out, axis)
return out
def all(x, axis=None, keepdims=False):
out = elemwise.All(axis)(x)
def all(x, axis=None): if keepdims:
return elemwise.All(axis)(x) out = makeKeepDims(x, out, axis)
return out
...@@ -1154,7 +1154,10 @@ class CAReduce(Op): ...@@ -1154,7 +1154,10 @@ class CAReduce(Op):
axis2.append(a) axis2.append(a)
assert len(axis) == len(axis2) assert len(axis) == len(axis2)
axis = tuple(axis2) axis = tuple(axis2)
op = self.__class__(self.scalar_op, axis) # We can't call self.__class__() as there is class that
# inherit from CAReduce that don't have the same signature
op = copy(self)
op.axis = axis
else: else:
op = self op = self
broadcastable = [x for i, x in enumerate(input.type.broadcastable) broadcastable = [x for i, x in enumerate(input.type.broadcastable)
...@@ -1409,6 +1412,12 @@ class All(CAReduce): ...@@ -1409,6 +1412,12 @@ class All(CAReduce):
else: else:
return "All{%s}" % ", ".join(map(str, self.axis)) return "All{%s}" % ", ".join(map(str, self.axis))
def make_node(self, input):
    """Build the Apply node for the `all` reduction.

    Inputs whose dtype is not already int8/uint8 are first mapped to a
    0/1 tensor with `neq(input, 0)`, so that the underlying bitwise-AND
    reduction behaves like a logical `all`.
    """
    already_boolean = input.dtype in ["int8", "uint8"]
    if not already_boolean:
        # Normalize to 0/1 values before reducing.
        input = theano.tensor.neq(input, 0)
    return super(All, self).make_node(input)
class Any(CAReduce): class Any(CAReduce):
""" Applies `bitwise or` to all the values of a tensor along the """ Applies `bitwise or` to all the values of a tensor along the
...@@ -1428,6 +1437,12 @@ class Any(CAReduce): ...@@ -1428,6 +1437,12 @@ class Any(CAReduce):
else: else:
return "Any{%s}" % ", ".join(map(str, self.axis)) return "Any{%s}" % ", ".join(map(str, self.axis))
def make_node(self, input):
    """Build the Apply node for the `any` reduction.

    Inputs whose dtype is not already int8/uint8 are first mapped to a
    0/1 tensor with `neq(input, 0)`, so that the underlying bitwise-OR
    reduction behaves like a logical `any`.
    """
    already_boolean = input.dtype in ["int8", "uint8"]
    if not already_boolean:
        # Normalize to 0/1 values before reducing.
        input = theano.tensor.neq(input, 0)
    return super(Any, self).make_node(input)
class CAReduceDtype(CAReduce): class CAReduceDtype(CAReduce):
""" """
......
import cPickle, time, unittest import cPickle
from copy import copy
from itertools import imap from itertools import imap
import time
import unittest
import numpy
from numpy.testing import dec from numpy.testing import dec
import theano
from theano.gof import Variable, Op from theano.gof import Variable, Op
from theano import gof from theano import gof, scalar, config
from theano.scalar import *
from theano import tensor from theano import tensor
from theano.tensor import TensorType
from theano.compile.mode import get_default_mode from theano.compile.mode import get_default_mode
from theano.tensor.elemwise import * from theano.tensor.elemwise import (CAReduce, Elemwise, DimShuffle,
Prod, ProdWithoutZeros)
from theano.tests import unittest_tools from theano.tests import unittest_tools
...@@ -18,6 +23,7 @@ def Env(i, o): ...@@ -18,6 +23,7 @@ def Env(i, o):
e = gof.Env(i, o) e = gof.Env(i, o)
return e return e
class test_DimShuffle(unittest.TestCase): class test_DimShuffle(unittest.TestCase):
def with_linker(self, linker): def with_linker(self, linker):
...@@ -25,11 +31,12 @@ class test_DimShuffle(unittest.TestCase): ...@@ -25,11 +31,12 @@ class test_DimShuffle(unittest.TestCase):
((1, 2, 3), (1, 2), (2, 3)), ((1, 2, 3), (1, 2), (2, 3)),
((1, 2, 1, 3), (1, 3), (2, 3)), ((1, 2, 1, 3), (1, 3), (2, 3)),
((2, 3, 4), (2, 1, 0), (4, 3, 2)), ((2, 3, 4), (2, 1, 0), (4, 3, 2)),
((2, 3, 4), ('x', 2, 1, 0, 'x'), (1, 4, 3, 2, 1)), ((2, 3, 4), ('x', 2, 1, 0, 'x'),
(1, 4, 3, 2, 1)),
((1, 4, 3, 2, 1), (3, 2, 1), (2, 3, 4)), ((1, 4, 3, 2, 1), (3, 2, 1), (2, 3, 4)),
((1, 1, 4), (1, 2), (1, 4)), ((1, 1, 4), (1, 2), (1, 4)),
((1, 1, 1), (), ()), ((1, 1, 1), (), ()),
((1,), ('x', 'x'), (1, 1)),]: ((1,), ('x', 'x'), (1, 1))]:
ib = [(entry == 1) for entry in xsh] ib = [(entry == 1) for entry in xsh]
x = TensorType('float64', ib)('x') x = TensorType('float64', ib)('x')
e = DimShuffle(ib, shuffle)(x) e = DimShuffle(ib, shuffle)(x)
...@@ -67,6 +74,7 @@ class test_DimShuffle(unittest.TestCase): ...@@ -67,6 +74,7 @@ class test_DimShuffle(unittest.TestCase):
# But This will test DimShuffle c code # But This will test DimShuffle c code
self.with_linker(gof.OpWiseCLinker()) self.with_linker(gof.OpWiseCLinker())
class test_Broadcast(unittest.TestCase): class test_Broadcast(unittest.TestCase):
def setUp(self): def setUp(self):
unittest_tools.seed_rng() unittest_tools.seed_rng()
...@@ -83,7 +91,7 @@ class test_Broadcast(unittest.TestCase): ...@@ -83,7 +91,7 @@ class test_Broadcast(unittest.TestCase):
((), ())]: ((), ())]:
x = TensorType('float64', [(entry == 1) for entry in xsh])('x') x = TensorType('float64', [(entry == 1) for entry in xsh])('x')
y = TensorType('float64', [(entry == 1) for entry in ysh])('y') y = TensorType('float64', [(entry == 1) for entry in ysh])('y')
e = Elemwise(add)(x, y) e = Elemwise(scalar.add)(x, y)
f = copy(linker).accept(Env([x, y], [e])).make_function() f = copy(linker).accept(Env([x, y], [e])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh)) xv = numpy.asarray(numpy.random.rand(*xsh))
yv = numpy.asarray(numpy.random.rand(*ysh)) yv = numpy.asarray(numpy.random.rand(*ysh))
...@@ -93,12 +101,12 @@ class test_Broadcast(unittest.TestCase): ...@@ -93,12 +101,12 @@ class test_Broadcast(unittest.TestCase):
#test Elemwise.infer_shape #test Elemwise.infer_shape
#the Shape op don't implement c_code! #the Shape op don't implement c_code!
if isinstance(linker,gof.PerformLinker): if isinstance(linker, gof.PerformLinker):
x = TensorType('float64', [(entry == 1) for entry in xsh])('x') x = TensorType('float64', [(entry == 1) for entry in xsh])('x')
y = TensorType('float64', [(entry == 1) for entry in ysh])('y') y = TensorType('float64', [(entry == 1) for entry in ysh])('y')
e = Elemwise(add)(x, y) e = Elemwise(scalar.add)(x, y)
f = copy(linker).accept(Env([x, y], [e.shape])).make_function() f = copy(linker).accept(Env([x, y], [e.shape])).make_function()
assert tuple(f(xv, yv))==tuple(zv.shape) assert tuple(f(xv, yv)) == tuple(zv.shape)
def with_linker_inplace(self, linker): def with_linker_inplace(self, linker):
for xsh, ysh in [((5, 5), (5, 5)), for xsh, ysh in [((5, 5), (5, 5)),
...@@ -111,7 +119,7 @@ class test_Broadcast(unittest.TestCase): ...@@ -111,7 +119,7 @@ class test_Broadcast(unittest.TestCase):
((), ())]: ((), ())]:
x = TensorType('float64', [(entry == 1) for entry in xsh])('x') x = TensorType('float64', [(entry == 1) for entry in xsh])('x')
y = TensorType('float64', [(entry == 1) for entry in ysh])('y') y = TensorType('float64', [(entry == 1) for entry in ysh])('y')
e = Elemwise(Add(transfer_type(0)), {0:0})(x, y) e = Elemwise(scalar.Add(scalar.transfer_type(0)), {0: 0})(x, y)
f = copy(linker).accept(Env([x, y], [e])).make_function() f = copy(linker).accept(Env([x, y], [e])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh)) xv = numpy.asarray(numpy.random.rand(*xsh))
yv = numpy.asarray(numpy.random.rand(*ysh)) yv = numpy.asarray(numpy.random.rand(*ysh))
...@@ -122,10 +130,10 @@ class test_Broadcast(unittest.TestCase): ...@@ -122,10 +130,10 @@ class test_Broadcast(unittest.TestCase):
self.assertTrue((xv == zv).all()) self.assertTrue((xv == zv).all())
#test Elemwise.infer_shape #test Elemwise.infer_shape
#the Shape op don't implement c_code! #the Shape op don't implement c_code!
if isinstance(linker,gof.PerformLinker): if isinstance(linker, gof.PerformLinker):
x = TensorType('float64', [(entry == 1) for entry in xsh])('x') x = TensorType('float64', [(entry == 1) for entry in xsh])('x')
y = TensorType('float64', [(entry == 1) for entry in ysh])('y') y = TensorType('float64', [(entry == 1) for entry in ysh])('y')
e = Elemwise(Add(transfer_type(0)), {0:0})(x, y) e = Elemwise(scalar.Add(scalar.transfer_type(0)), {0: 0})(x, y)
f = copy(linker).accept(Env([x, y], [e.shape])).make_function() f = copy(linker).accept(Env([x, y], [e.shape])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh)) xv = numpy.asarray(numpy.random.rand(*xsh))
yv = numpy.asarray(numpy.random.rand(*ysh)) yv = numpy.asarray(numpy.random.rand(*ysh))
...@@ -133,7 +141,7 @@ class test_Broadcast(unittest.TestCase): ...@@ -133,7 +141,7 @@ class test_Broadcast(unittest.TestCase):
f(xv, yv) f(xv, yv)
assert xv.shape==zv.shape assert xv.shape == zv.shape
def test_perform(self): def test_perform(self):
self.with_linker(gof.PerformLinker()) self.with_linker(gof.PerformLinker())
...@@ -150,7 +158,7 @@ class test_Broadcast(unittest.TestCase): ...@@ -150,7 +158,7 @@ class test_Broadcast(unittest.TestCase):
def test_fill(self): def test_fill(self):
x = TensorType('float64', [0, 0])('x') x = TensorType('float64', [0, 0])('x')
y = TensorType('float64', [1, 1])('y') y = TensorType('float64', [1, 1])('y')
e = Elemwise(Second(transfer_type(0)), {0:0})(x, y) e = Elemwise(scalar.Second(scalar.transfer_type(0)), {0: 0})(x, y)
f = gof.CLinker().accept(Env([x, y], [e])).make_function() f = gof.CLinker().accept(Env([x, y], [e])).make_function()
xv = numpy.ones((5, 5)) xv = numpy.ones((5, 5))
yv = numpy.random.rand(1, 1) yv = numpy.random.rand(1, 1)
...@@ -160,7 +168,7 @@ class test_Broadcast(unittest.TestCase): ...@@ -160,7 +168,7 @@ class test_Broadcast(unittest.TestCase):
def test_weird_strides(self): def test_weird_strides(self):
x = TensorType('float64', [0, 0, 0, 0, 0])('x') x = TensorType('float64', [0, 0, 0, 0, 0])('x')
y = TensorType('float64', [0, 0, 0, 0, 0])('y') y = TensorType('float64', [0, 0, 0, 0, 0])('y')
e = Elemwise(add)(x, y) e = Elemwise(scalar.add)(x, y)
f = gof.CLinker().accept(Env([x, y], [e])).make_function() f = gof.CLinker().accept(Env([x, y], [e])).make_function()
xv = numpy.random.rand(2, 2, 2, 2, 2) xv = numpy.random.rand(2, 2, 2, 2, 2)
yv = numpy.random.rand(2, 2, 2, 2, 2).transpose(4, 0, 3, 1, 2) yv = numpy.random.rand(2, 2, 2, 2, 2).transpose(4, 0, 3, 1, 2)
...@@ -169,7 +177,7 @@ class test_Broadcast(unittest.TestCase): ...@@ -169,7 +177,7 @@ class test_Broadcast(unittest.TestCase):
def test_same_inputs(self): def test_same_inputs(self):
x = TensorType('float64', [0, 0])('x') x = TensorType('float64', [0, 0])('x')
e = Elemwise(add)(x, x) e = Elemwise(scalar.add)(x, x)
f = gof.CLinker().accept(Env([x], [e])).make_function() f = gof.CLinker().accept(Env([x], [e])).make_function()
xv = numpy.random.rand(2, 2) xv = numpy.random.rand(2, 2)
zv = xv + xv zv = xv + xv
...@@ -180,8 +188,8 @@ class test_CAReduce(unittest.TestCase): ...@@ -180,8 +188,8 @@ class test_CAReduce(unittest.TestCase):
def setUp(self): def setUp(self):
unittest_tools.seed_rng() unittest_tools.seed_rng()
def with_linker(self, linker, scalar_op = add, dtype="floatX", def with_linker(self, linker, scalar_op=scalar.add, dtype="floatX",
test_nan=False): test_nan=False, tensor_op=None):
for xsh, tosum in [((5, 6), None), for xsh, tosum in [((5, 6), None),
((5, 6), (0, 1)), ((5, 6), (0, 1)),
((5, 6), (0, )), ((5, 6), (0, )),
...@@ -200,18 +208,24 @@ class test_CAReduce(unittest.TestCase): ...@@ -200,18 +208,24 @@ class test_CAReduce(unittest.TestCase):
if dtype == "floatX": if dtype == "floatX":
dtype = theano.config.floatX dtype = theano.config.floatX
x = TensorType(dtype, [(entry == 1) for entry in xsh])('x') x = TensorType(dtype, [(entry == 1) for entry in xsh])('x')
e = CAReduce(scalar_op, axis = tosum)(x) if tensor_op is None:
if tosum is None: tosum = range(len(xsh)) e = CAReduce(scalar_op, axis=tosum)(x)
else:
e = tensor_op(x, axis=tosum)
if tosum is None:
tosum = range(len(xsh))
f = copy(linker).accept(Env([x], [e])).make_function() f = copy(linker).accept(Env([x], [e])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh)) xv = numpy.asarray(numpy.random.rand(*xsh))
if not "int" in dtype: if not "int" in dtype:
xv = numpy.asarray(xv,dtype=dtype) xv = numpy.asarray(xv, dtype=dtype)
else: else:
xv = numpy.asarray(xv<0.5,dtype=dtype) xv = numpy.asarray(xv < 0.5, dtype=dtype)
if test_nan and xv.size > 0: if test_nan and xv.size > 0:
if len(xsh)>0: if len(xsh) > 0:
xv = xv.flatten() xv = xv.flatten()
xv[0] = numpy.nan xv[0] = numpy.nan
xv = xv.reshape(*xsh) xv = xv.reshape(*xsh)
...@@ -219,49 +233,63 @@ class test_CAReduce(unittest.TestCase): ...@@ -219,49 +233,63 @@ class test_CAReduce(unittest.TestCase):
xv = numpy.asarray(numpy.nan, dtype=dtype) xv = numpy.asarray(numpy.nan, dtype=dtype)
zv = xv zv = xv
numpy_raised = False numpy_raised = False
if len(tosum)>1 and any([a<0 for a in tosum]): if len(tosum) > 1 and any([a < 0 for a in tosum]):
#In that case, we need to use the good order of axis in the reduction. #In that case, we need to use the good order of axis
#in the reduction.
axis2 = [] axis2 = []
for a in tosum: for a in tosum:
if a<0: axis2.append(a+len(xsh)) if a < 0:
else: axis2.append(a) axis2.append(a + len(xsh))
assert len(axis2)==len(tosum) else:
axis2.append(a)
assert len(axis2) == len(tosum)
tosum = tuple(axis2) tosum = tuple(axis2)
if tensor_op == tensor.all:
if scalar_op == add: for axis in reversed(sorted(tosum)):
zv = numpy.all(zv, axis)
if len(tosum) == 0:
zv = zv != 0
elif tensor_op == tensor.any:
for axis in reversed(sorted(tosum)):
zv = numpy.any(zv, axis)
if len(tosum) == 0:
zv = zv != 0
elif scalar_op == scalar.add:
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
zv = numpy.add.reduce(zv, axis) zv = numpy.add.reduce(zv, axis)
elif scalar_op == mul: elif scalar_op == scalar.mul:
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
zv = numpy.multiply.reduce(zv, axis) zv = numpy.multiply.reduce(zv, axis)
elif scalar_op == maximum: elif scalar_op == scalar.maximum:
try: try:
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
zv = numpy.maximum.reduce(zv, axis) zv = numpy.maximum.reduce(zv, axis)
except ValueError: except ValueError:
numpy_raised=True numpy_raised = True
elif scalar_op == minimum: elif scalar_op == scalar.minimum:
try: try:
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
zv = numpy.minimum.reduce(zv, axis) zv = numpy.minimum.reduce(zv, axis)
except ValueError: except ValueError:
numpy_raised=True numpy_raised = True
elif scalar_op == or_: elif scalar_op == scalar.or_:
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
zv = numpy.bitwise_or.reduce(zv, axis) zv = numpy.bitwise_or.reduce(zv, axis)
elif scalar_op == and_: elif scalar_op == scalar.and_:
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
zv = numpy.bitwise_and.reduce(zv, axis) zv = numpy.bitwise_and.reduce(zv, axis)
elif scalar_op == xor: elif scalar_op == scalar.xor:
# There is no identity value for the xor function # There is no identity value for the xor function
# So we can't support shape of dimensions 0. # So we can't support shape of dimensions 0.
if numpy.prod(zv.shape)==0: if numpy.prod(zv.shape) == 0:
continue continue
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
zv = numpy.bitwise_xor.reduce(zv, axis) zv = numpy.bitwise_xor.reduce(zv, axis)
else: else:
raise Exception("Test for CAReduce with scalar_op %s not implemented"%str(scalar_op)) raise Exception(
if scalar_op in [maximum,minimum] and numpy_raised: "Test for CAReduce with scalar_op %s not implemented" %
str(scalar_op))
if scalar_op in [scalar.maximum, scalar.minimum] and numpy_raised:
try: try:
out = f(xv) out = f(xv)
assert out.dtype == dtype assert out.dtype == dtype
...@@ -271,78 +299,98 @@ class test_CAReduce(unittest.TestCase): ...@@ -271,78 +299,98 @@ class test_CAReduce(unittest.TestCase):
self.fail() self.fail()
else: else:
#numpy.{all,any} return bool type. #numpy.{all,any} return bool type.
if scalar_op in [and_, or_]: if scalar_op in [scalar.and_, scalar.or_]:
zv = numpy.asarray(zv, dtype=dtype) zv = numpy.asarray(zv, dtype=dtype)
if test_nan: if test_nan:
self.assertTrue(theano.tensor.TensorType.values_eq(f(xv), zv), (f(xv), zv)) self.assertTrue(theano.tensor.TensorType.values_eq(f(xv),
zv),
(f(xv), zv))
else: else:
self.assertTrue(numpy.allclose(f(xv), zv), (f(xv), zv)) self.assertTrue(numpy.allclose(f(xv), zv), (f(xv), zv))
#test CAReduce.infer_shape #test CAReduce.infer_shape
#the Shape op don't implement c_code! #the Shape op don't implement c_code!
if isinstance(linker,gof.PerformLinker): if isinstance(linker, gof.PerformLinker):
x = TensorType(dtype, [(entry == 1) for entry in xsh])('x') x = TensorType(dtype, [(entry == 1) for entry in xsh])('x')
e = CAReduce(scalar_op, axis = tosum)(x) if tensor_op is None:
if tosum is None: tosum = range(len(xsh)) e = CAReduce(scalar_op, axis=tosum)(x)
else:
e = tensor_op(x, axis=tosum)
if tosum is None:
tosum = range(len(xsh))
f = copy(linker).accept(Env([x], [e.shape])).make_function() f = copy(linker).accept(Env([x], [e.shape])).make_function()
if not(scalar_op in [maximum,minimum] and ((xsh==() or numpy.prod(xsh)==0))): if not(scalar_op in [scalar.maximum, scalar.minimum] and
assert all(f(xv)== zv.shape) ((xsh == () or numpy.prod(xsh) == 0))):
assert all(f(xv) == zv.shape)
def test_perform(self): def test_perform(self):
for dtype in ["floatX", "complex64", "complex128", "int8", "uint8"]: for dtype in ["floatX", "complex64", "complex128", "int8", "uint8"]:
self.with_linker(gof.PerformLinker(), add, dtype=dtype) self.with_linker(gof.PerformLinker(), scalar.add, dtype=dtype)
self.with_linker(gof.PerformLinker(), mul, dtype=dtype) self.with_linker(gof.PerformLinker(), scalar.mul, dtype=dtype)
self.with_linker(gof.PerformLinker(), maximum, dtype=dtype) self.with_linker(gof.PerformLinker(), scalar.maximum, dtype=dtype)
self.with_linker(gof.PerformLinker(), minimum, dtype=dtype) self.with_linker(gof.PerformLinker(), scalar.minimum, dtype=dtype)
self.with_linker(gof.PerformLinker(), scalar.and_, dtype=dtype,
tensor_op=tensor.all)
self.with_linker(gof.PerformLinker(), scalar.or_, dtype=dtype,
tensor_op=tensor.any)
for dtype in ["int8", "uint8"]: for dtype in ["int8", "uint8"]:
self.with_linker(gof.PerformLinker(), or_, dtype=dtype) self.with_linker(gof.PerformLinker(), scalar.or_, dtype=dtype)
self.with_linker(gof.PerformLinker(), and_, dtype=dtype) self.with_linker(gof.PerformLinker(), scalar.and_, dtype=dtype)
self.with_linker(gof.PerformLinker(), xor, dtype=dtype) self.with_linker(gof.PerformLinker(), scalar.xor, dtype=dtype)
@dec.knownfailureif( @dec.knownfailureif(
True, True,
("When there is nan in the input of CAReduce, we don't have a good output. ")) ("When there is nan in the input of CAReduce,"
" we don't have a good output. "))
def test_perform_nan(self): def test_perform_nan(self):
for dtype in ["floatX", "complex64", "complex128"]: for dtype in ["floatX", "complex64", "complex128"]:
self.with_linker(gof.PerformLinker(), add, dtype=dtype, self.with_linker(gof.PerformLinker(), scalar.add, dtype=dtype,
test_nan=True) test_nan=True)
self.with_linker(gof.PerformLinker(), mul, dtype=dtype, self.with_linker(gof.PerformLinker(), scalar.mul, dtype=dtype,
test_nan=True) test_nan=True)
self.with_linker(gof.PerformLinker(), maximum, dtype=dtype, self.with_linker(gof.PerformLinker(), scalar.maximum, dtype=dtype,
test_nan=True) test_nan=True)
self.with_linker(gof.PerformLinker(), minimum, dtype=dtype, self.with_linker(gof.PerformLinker(), scalar.minimum, dtype=dtype,
test_nan=True) test_nan=True)
self.with_linker(gof.PerformLinker(), or_, dtype=dtype, self.with_linker(gof.PerformLinker(), scalar.or_, dtype=dtype,
test_nan=True) test_nan=True)
self.with_linker(gof.PerformLinker(), and_, dtype=dtype, self.with_linker(gof.PerformLinker(), scalar.and_, dtype=dtype,
test_nan=True) test_nan=True)
self.with_linker(gof.PerformLinker(), or_, dtype=dtype,
test_nan=True, tensor_op=tensor.any)
self.with_linker(gof.PerformLinker(), and_, dtype=dtype,
test_nan=True, tensor_op=tensor.all)
def test_c(self): def test_c(self):
for dtype in ["floatX", "complex64", "complex128", "int8", "uint8"]: for dtype in ["floatX", "complex64", "complex128", "int8", "uint8"]:
self.with_linker(gof.CLinker(), add, dtype=dtype) self.with_linker(gof.CLinker(), scalar.add, dtype=dtype)
self.with_linker(gof.CLinker(), mul, dtype=dtype) self.with_linker(gof.CLinker(), scalar.mul, dtype=dtype)
for dtype in ["floatX", "int8", "uint8"]: for dtype in ["floatX", "int8", "uint8"]:
self.with_linker(gof.CLinker(), minimum, dtype=dtype) self.with_linker(gof.CLinker(), scalar.minimum, dtype=dtype)
self.with_linker(gof.CLinker(), maximum, dtype=dtype) self.with_linker(gof.CLinker(), scalar.maximum, dtype=dtype)
self.with_linker(gof.CLinker(), scalar.and_, dtype=dtype,
tensor_op=tensor.all)
self.with_linker(gof.CLinker(), scalar.or_, dtype=dtype,
tensor_op=tensor.any)
for dtype in ["int8", "uint8"]: for dtype in ["int8", "uint8"]:
self.with_linker(gof.CLinker(), or_, dtype=dtype) self.with_linker(gof.CLinker(), scalar.or_, dtype=dtype)
self.with_linker(gof.CLinker(), and_, dtype=dtype) self.with_linker(gof.CLinker(), scalar.and_, dtype=dtype)
self.with_linker(gof.CLinker(), xor, dtype=dtype) self.with_linker(gof.CLinker(), scalar.xor, dtype=dtype)
@dec.knownfailureif( @dec.knownfailureif(
True, True,
("When there is nan in the input of CAReduce, we don't have a good output. ")) ("When there is nan in the input of CAReduce,"
" we don't have a good output. "))
def test_c_nan(self): def test_c_nan(self):
for dtype in ["floatX", "complex64", "complex128"]: for dtype in ["floatX", "complex64", "complex128"]:
self.with_linker(gof.CLinker(), add, dtype=dtype, self.with_linker(gof.CLinker(), scalar.add, dtype=dtype,
test_nan=True) test_nan=True)
self.with_linker(gof.CLinker(), mul, dtype=dtype, self.with_linker(gof.CLinker(), scalar.mul, dtype=dtype,
test_nan=True) test_nan=True)
for dtype in ["floatX"]: for dtype in ["floatX"]:
self.with_linker(gof.CLinker(), minimum, dtype=dtype, self.with_linker(gof.CLinker(), scalar.minimum, dtype=dtype,
test_nan=True) test_nan=True)
self.with_linker(gof.CLinker(), maximum, dtype=dtype, self.with_linker(gof.CLinker(), scalar.maximum, dtype=dtype,
test_nan=True) test_nan=True)
...@@ -350,7 +398,8 @@ class test_Prod(unittest.TestCase): ...@@ -350,7 +398,8 @@ class test_Prod(unittest.TestCase):
def setUp(self): def setUp(self):
unittest_tools.seed_rng() unittest_tools.seed_rng()
# we want to allow nans in the matrices, so we disable this DEBUG_MODE check # we want to allow nans in the matrices, so we disable this
# DEBUG_MODE check
mode = theano.compile.mode.get_default_mode() mode = theano.compile.mode.get_default_mode()
mode = copy(mode) mode = copy(mode)
mode.check_isfinite = False mode.check_isfinite = False
......
import numpy
from theano import tensor, function
class TestKeepDims:
    """Check that ``keepdims=True`` on Theano reductions matches a
    hand-built DimShuffle that reinserts the reduced axes as
    broadcastable ('x') dimensions."""

    def makeKeepDims_local(self, x, y, axis):
        """Reference keepdims: broadcast *y* (a reduction of *x* over
        *axis*) back up to *x*'s rank by inserting 'x' dimensions."""
        x = tensor.as_tensor_variable(x)
        y = tensor.as_tensor_variable(y)
        if axis is None:
            axis = numpy.arange(x.ndim)
        pattern = []
        kept = 0  # running index into y's surviving dimensions
        for dim, _ in enumerate(x.shape):
            if dim in axis:
                pattern.append('x')
            else:
                pattern.append(kept)
                kept += 1
        return tensor.DimShuffle(y.type.broadcastable, pattern)(y)

    def _assert_matches(self, x, a, expr_keep, expr_synth):
        # Compile both graphs and require identical values and shapes.
        keep_param = function([x], expr_keep)
        keep_synth = function([x], expr_synth)
        assert numpy.allclose(keep_param(a), keep_synth(a))
        assert keep_param(a).shape == keep_synth(a).shape

    def test_keepdims(self):
        x = tensor.dtensor3()
        a = numpy.random.rand(3, 2, 4)

        # 'max_and_argmax' has two outputs and can be specified with
        # either a single axis or every axis:
        op = tensor.max_and_argmax
        for axis in [[0], [1], [2], None, [0, 1, 2]]:
            for out_idx in (0, 1):
                self._assert_matches(
                    x, a,
                    op(x, axis=axis, keepdims=True)[out_idx],
                    self.makeKeepDims_local(
                        x, op(x, axis=axis, keepdims=False)[out_idx],
                        axis))

        # argmax/argmin accept either a single axis or every axis:
        for op in [tensor.argmax, tensor.argmin]:
            for axis in [[0], [1], [2], None, [0, 1, 2]]:
                self._assert_matches(
                    x, a,
                    op(x, axis=axis, keepdims=True),
                    self.makeKeepDims_local(
                        x, op(x, axis=axis, keepdims=False), axis))
            self._assert_matches(
                x, a,
                op(x, axis=None, keepdims=True),
                self.makeKeepDims_local(
                    x, op(x, axis=None, keepdims=False), None))

        # These reductions accept a freely specified axis parameter:
        for op in [tensor.sum, tensor.prod, tensor.mean, tensor.var,
                   tensor.std, tensor.all, tensor.any,
                   tensor.max, tensor.min]:
            for axis in [[0], [1], [2], [0, 1], [1, 2], [0, 1, 2]]:
                self._assert_matches(
                    x, a,
                    op(x, axis=axis, keepdims=True),
                    self.makeKeepDims_local(
                        x, op(x, axis=axis, keepdims=False), axis))
            self._assert_matches(
                x, a,
                op(x, axis=None, keepdims=True),
                self.makeKeepDims_local(
                    x, op(x, axis=None, keepdims=False), None))
# Allow running this test module directly, outside a test runner.
if __name__ == '__main__':
    TestKeepDims().test_keepdims()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论