提交 ac8ce341 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make Any and All return bools.

上级 7563a320
...@@ -1703,18 +1703,16 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){ ...@@ -1703,18 +1703,16 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
class All(CAReduce): class All(CAReduce):
""" Applies `bitwise and` to all the values of a tensor along the """ Applies `logical and` to all the values of a tensor along the
specified axis(es). specified axis(es).
Equivalent to `CAReduce(scalar.and\_, axis=axis)`.
""" """
def __init__(self, axis=None): def __init__(self, axis=None):
CAReduce.__init__(self, scalar.and_, axis) CAReduce.__init__(self, scalar.and_, axis)
def _output_dtype(self, idtype): def _output_dtype(self, idtype):
return "int8" return "bool"
def __str__(self): def __str__(self):
if self.axis is None: if self.axis is None:
...@@ -1724,7 +1722,7 @@ class All(CAReduce): ...@@ -1724,7 +1722,7 @@ class All(CAReduce):
def make_node(self, input): def make_node(self, input):
input = as_tensor_variable(input) input = as_tensor_variable(input)
if input.dtype not in ["int8", "uint8"]: if input.dtype is not "bool":
input = theano.tensor.neq(input, 0) input = theano.tensor.neq(input, 0)
ret = super(All, self).make_node(input) ret = super(All, self).make_node(input)
return ret return ret
...@@ -1738,15 +1736,13 @@ class Any(CAReduce): ...@@ -1738,15 +1736,13 @@ class Any(CAReduce):
""" Applies `bitwise or` to all the values of a tensor along the """ Applies `bitwise or` to all the values of a tensor along the
specified axis(es). specified axis(es).
Equivalent to `CAReduce(scalar.or\_, axis=axis)`.
""" """
def __init__(self, axis=None): def __init__(self, axis=None):
CAReduce.__init__(self, scalar.or_, axis) CAReduce.__init__(self, scalar.or_, axis)
def _output_dtype(self, idtype): def _output_dtype(self, idtype):
return "int8" return "bool"
def __str__(self): def __str__(self):
if self.axis is None: if self.axis is None:
...@@ -1756,7 +1752,7 @@ class Any(CAReduce): ...@@ -1756,7 +1752,7 @@ class Any(CAReduce):
def make_node(self, input): def make_node(self, input):
input = as_tensor_variable(input) input = as_tensor_variable(input)
if input.dtype not in ["int8", "uint8"]: if input.dtype is not "bool":
input = theano.tensor.neq(input, 0) input = theano.tensor.neq(input, 0)
ret = super(Any, self).make_node(input) ret = super(Any, self).make_node(input)
return ret return ret
...@@ -1985,7 +1981,7 @@ class Sum(CAReduceDtype): ...@@ -1985,7 +1981,7 @@ class Sum(CAReduceDtype):
out = self(*inp) out = self(*inp)
if out.dtype.find('int') != -1: if out.dtype not in theano.tensor.continuous_dtypes:
return [x.zeros_like(dtype=theano.config.floatX)] return [x.zeros_like(dtype=theano.config.floatX)]
gz, = grads gz, = grads
......
...@@ -8178,7 +8178,7 @@ def test_composite_neg_bool(): ...@@ -8178,7 +8178,7 @@ def test_composite_neg_bool():
# `-numpy.bool_(True)` is False and `-numpy.bool_(False)` is True. # `-numpy.bool_(True)` is False and `-numpy.bool_(False)` is True.
x = theano.tensor.vector() x = theano.tensor.vector()
f = theano.function([x], - (x > 0), mode=theano.Mode(linker='py')) f = theano.function([x], - (x > 0), mode=theano.Mode(linker='py'))
utt.assert_allclose(f([-1, 0, 1]), [0, 0, -1]) utt.assert_allclose(f([-1, 0, 1]), [False, False, True])
""" """
......
...@@ -452,10 +452,6 @@ class test_CAReduce(unittest_tools.InferShapeTester): ...@@ -452,10 +452,6 @@ class test_CAReduce(unittest_tools.InferShapeTester):
else: else:
self.fail() self.fail()
else: else:
# numpy.{all,any} return bool type,
# but theano ops return an int8 array instead
if scalar_op in [scalar.and_, scalar.or_]:
zv = numpy.asarray(zv, dtype='int8')
if test_nan: if test_nan:
try: try:
self.assertTrue( self.assertTrue(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论