提交 ac8ce341 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make Any and All return bools.

上级 7563a320
......@@ -1703,18 +1703,16 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
class All(CAReduce):
""" Applies `bitwise and` to all the values of a tensor along the
""" Applies `logical and` to all the values of a tensor along the
specified axis(es).
Equivalent to `CAReduce(scalar.and\_, axis=axis)`.
"""
def __init__(self, axis=None):
CAReduce.__init__(self, scalar.and_, axis)
def _output_dtype(self, idtype):
return "int8"
return "bool"
def __str__(self):
if self.axis is None:
......@@ -1724,7 +1722,7 @@ class All(CAReduce):
def make_node(self, input):
input = as_tensor_variable(input)
if input.dtype not in ["int8", "uint8"]:
if input.dtype is not "bool":
input = theano.tensor.neq(input, 0)
ret = super(All, self).make_node(input)
return ret
......@@ -1738,15 +1736,13 @@ class Any(CAReduce):
""" Applies `bitwise or` to all the values of a tensor along the
specified axis(es).
Equivalent to `CAReduce(scalar.or\_, axis=axis)`.
"""
def __init__(self, axis=None):
CAReduce.__init__(self, scalar.or_, axis)
def _output_dtype(self, idtype):
return "int8"
return "bool"
def __str__(self):
if self.axis is None:
......@@ -1756,7 +1752,7 @@ class Any(CAReduce):
def make_node(self, input):
input = as_tensor_variable(input)
if input.dtype not in ["int8", "uint8"]:
if input.dtype is not "bool":
input = theano.tensor.neq(input, 0)
ret = super(Any, self).make_node(input)
return ret
......@@ -1985,7 +1981,7 @@ class Sum(CAReduceDtype):
out = self(*inp)
if out.dtype.find('int') != -1:
if out.dtype not in theano.tensor.continuous_dtypes:
return [x.zeros_like(dtype=theano.config.floatX)]
gz, = grads
......
......@@ -8178,7 +8178,7 @@ def test_composite_neg_bool():
# `-numpy.bool_(True)` is False and `-numpy.bool_(False)` is True.
x = theano.tensor.vector()
f = theano.function([x], - (x > 0), mode=theano.Mode(linker='py'))
utt.assert_allclose(f([-1, 0, 1]), [0, 0, -1])
utt.assert_allclose(f([-1, 0, 1]), [False, False, True])
"""
......
......@@ -452,10 +452,6 @@ class test_CAReduce(unittest_tools.InferShapeTester):
else:
self.fail()
else:
# numpy.{all,any} return bool type,
# but theano ops return an int8 array instead
if scalar_op in [scalar.and_, scalar.or_]:
zv = numpy.asarray(zv, dtype='int8')
if test_nan:
try:
self.assertTrue(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论