提交 e6bbc361 authored 作者: Frederic's avatar Frederic

make tensor.{all,any} don't crash with other dtype then uint8 and int8.

上级 1bf87a89
......@@ -1154,7 +1154,10 @@ class CAReduce(Op):
axis2.append(a)
assert len(axis) == len(axis2)
axis = tuple(axis2)
op = self.__class__(self.scalar_op, axis)
# We can't call self.__class__() as there is class that
# inherit from CAReduce that don't have the same signature
op = copy(self)
op.axis = axis
else:
op = self
broadcastable = [x for i, x in enumerate(input.type.broadcastable)
......@@ -1409,6 +1412,12 @@ class All(CAReduce):
else:
return "All{%s}" % ", ".join(map(str, self.axis))
def make_node(self, input):
if input.dtype not in ["int8", "uint8"]:
input = theano.tensor.neq(input, 0)
ret = super(All, self).make_node(input)
return ret
class Any(CAReduce):
""" Applies `bitwise or` to all the values of a tensor along the
......@@ -1428,6 +1437,12 @@ class Any(CAReduce):
else:
return "Any{%s}" % ", ".join(map(str, self.axis))
def make_node(self, input):
if input.dtype not in ["int8", "uint8"]:
input = theano.tensor.neq(input, 0)
ret = super(Any, self).make_node(input)
return ret
class CAReduceDtype(CAReduce):
"""
......
......@@ -181,7 +181,7 @@ class test_CAReduce(unittest.TestCase):
unittest_tools.seed_rng()
def with_linker(self, linker, scalar_op = add, dtype="floatX",
test_nan=False):
test_nan=False, tensor_op=None):
for xsh, tosum in [((5, 6), None),
((5, 6), (0, 1)),
((5, 6), (0, )),
......@@ -200,7 +200,11 @@ class test_CAReduce(unittest.TestCase):
if dtype == "floatX":
dtype = theano.config.floatX
x = TensorType(dtype, [(entry == 1) for entry in xsh])('x')
e = CAReduce(scalar_op, axis = tosum)(x)
if tensor_op is None:
e = CAReduce(scalar_op, axis = tosum)(x)
else:
e = tensor_op(x, axis=tosum)
if tosum is None: tosum = range(len(xsh))
f = copy(linker).accept(Env([x], [e])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh))
......@@ -227,8 +231,17 @@ class test_CAReduce(unittest.TestCase):
else: axis2.append(a)
assert len(axis2)==len(tosum)
tosum = tuple(axis2)
if scalar_op == add:
if tensor_op == tensor.all:
for axis in reversed(sorted(tosum)):
zv = numpy.all(zv, axis)
if len(tosum) == 0:
zv = zv != 0
elif tensor_op == tensor.any:
for axis in reversed(sorted(tosum)):
zv = numpy.any(zv, axis)
if len(tosum) == 0:
zv = zv != 0
elif scalar_op == add:
for axis in reversed(sorted(tosum)):
zv = numpy.add.reduce(zv, axis)
elif scalar_op == mul:
......@@ -283,7 +296,10 @@ class test_CAReduce(unittest.TestCase):
#the Shape op don't implement c_code!
if isinstance(linker,gof.PerformLinker):
x = TensorType(dtype, [(entry == 1) for entry in xsh])('x')
e = CAReduce(scalar_op, axis = tosum)(x)
if tensor_op is None:
e = CAReduce(scalar_op, axis = tosum)(x)
else:
e = tensor_op(x, axis=tosum)
if tosum is None: tosum = range(len(xsh))
f = copy(linker).accept(Env([x], [e.shape])).make_function()
if not(scalar_op in [maximum,minimum] and ((xsh==() or numpy.prod(xsh)==0))):
......@@ -295,6 +311,10 @@ class test_CAReduce(unittest.TestCase):
self.with_linker(gof.PerformLinker(), mul, dtype=dtype)
self.with_linker(gof.PerformLinker(), maximum, dtype=dtype)
self.with_linker(gof.PerformLinker(), minimum, dtype=dtype)
self.with_linker(gof.PerformLinker(), and_, dtype=dtype,
tensor_op=tensor.all)
self.with_linker(gof.PerformLinker(), or_, dtype=dtype,
tensor_op=tensor.any)
for dtype in ["int8", "uint8"]:
self.with_linker(gof.PerformLinker(), or_, dtype=dtype)
self.with_linker(gof.PerformLinker(), and_, dtype=dtype)
......@@ -317,6 +337,10 @@ class test_CAReduce(unittest.TestCase):
test_nan=True)
self.with_linker(gof.PerformLinker(), and_, dtype=dtype,
test_nan=True)
self.with_linker(gof.PerformLinker(), or_, dtype=dtype,
test_nan=True, tensor_op=tensor.any)
self.with_linker(gof.PerformLinker(), and_, dtype=dtype,
test_nan=True, tensor_op=tensor.all)
def test_c(self):
for dtype in ["floatX", "complex64", "complex128", "int8", "uint8"]:
......@@ -325,6 +349,11 @@ class test_CAReduce(unittest.TestCase):
for dtype in ["floatX", "int8", "uint8"]:
self.with_linker(gof.CLinker(), minimum, dtype=dtype)
self.with_linker(gof.CLinker(), maximum, dtype=dtype)
# all and any use neq that don't have c code for complex
self.with_linker(gof.CLinker(), and_, dtype=dtype,
tensor_op=tensor.all)
self.with_linker(gof.CLinker(), or_, dtype=dtype,
tensor_op=tensor.any)
for dtype in ["int8", "uint8"]:
self.with_linker(gof.CLinker(), or_, dtype=dtype)
self.with_linker(gof.CLinker(), and_, dtype=dtype)
......
......@@ -72,7 +72,7 @@ class TestKeepDims:
# the following ops can be specified with a freely specified axis
# parameter
for op in ([tensor.sum, tensor.prod, tensor.mean, tensor.var,
tensor.std]):
tensor.std, tensor.all, tensor.any]):
# FRED: il faudra ajouter les ops suivantes a la boucle ci-dessus:
# tensor.all, tensor.any
# Celles-ci semblent presentement defectueuses puisqu'elles plantent
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论