提交 e6bbc361 authored 作者: Frederic's avatar Frederic

make tensor.{all,any} not crash with dtypes other than uint8 and int8.

上级 1bf87a89
...@@ -1154,7 +1154,10 @@ class CAReduce(Op): ...@@ -1154,7 +1154,10 @@ class CAReduce(Op):
axis2.append(a) axis2.append(a)
assert len(axis) == len(axis2) assert len(axis) == len(axis2)
axis = tuple(axis2) axis = tuple(axis2)
op = self.__class__(self.scalar_op, axis) # We can't call self.__class__() as there are classes that
# inherit from CAReduce and don't have the same signature
op = copy(self)
op.axis = axis
else: else:
op = self op = self
broadcastable = [x for i, x in enumerate(input.type.broadcastable) broadcastable = [x for i, x in enumerate(input.type.broadcastable)
...@@ -1409,6 +1412,12 @@ class All(CAReduce): ...@@ -1409,6 +1412,12 @@ class All(CAReduce):
else: else:
return "All{%s}" % ", ".join(map(str, self.axis)) return "All{%s}" % ", ".join(map(str, self.axis))
def make_node(self, input):
if input.dtype not in ["int8", "uint8"]:
input = theano.tensor.neq(input, 0)
ret = super(All, self).make_node(input)
return ret
class Any(CAReduce): class Any(CAReduce):
""" Applies `bitwise or` to all the values of a tensor along the """ Applies `bitwise or` to all the values of a tensor along the
...@@ -1428,6 +1437,12 @@ class Any(CAReduce): ...@@ -1428,6 +1437,12 @@ class Any(CAReduce):
else: else:
return "Any{%s}" % ", ".join(map(str, self.axis)) return "Any{%s}" % ", ".join(map(str, self.axis))
def make_node(self, input):
if input.dtype not in ["int8", "uint8"]:
input = theano.tensor.neq(input, 0)
ret = super(Any, self).make_node(input)
return ret
class CAReduceDtype(CAReduce): class CAReduceDtype(CAReduce):
""" """
......
...@@ -181,7 +181,7 @@ class test_CAReduce(unittest.TestCase): ...@@ -181,7 +181,7 @@ class test_CAReduce(unittest.TestCase):
unittest_tools.seed_rng() unittest_tools.seed_rng()
def with_linker(self, linker, scalar_op = add, dtype="floatX", def with_linker(self, linker, scalar_op = add, dtype="floatX",
test_nan=False): test_nan=False, tensor_op=None):
for xsh, tosum in [((5, 6), None), for xsh, tosum in [((5, 6), None),
((5, 6), (0, 1)), ((5, 6), (0, 1)),
((5, 6), (0, )), ((5, 6), (0, )),
...@@ -200,7 +200,11 @@ class test_CAReduce(unittest.TestCase): ...@@ -200,7 +200,11 @@ class test_CAReduce(unittest.TestCase):
if dtype == "floatX": if dtype == "floatX":
dtype = theano.config.floatX dtype = theano.config.floatX
x = TensorType(dtype, [(entry == 1) for entry in xsh])('x') x = TensorType(dtype, [(entry == 1) for entry in xsh])('x')
e = CAReduce(scalar_op, axis = tosum)(x) if tensor_op is None:
e = CAReduce(scalar_op, axis = tosum)(x)
else:
e = tensor_op(x, axis=tosum)
if tosum is None: tosum = range(len(xsh)) if tosum is None: tosum = range(len(xsh))
f = copy(linker).accept(Env([x], [e])).make_function() f = copy(linker).accept(Env([x], [e])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh)) xv = numpy.asarray(numpy.random.rand(*xsh))
...@@ -227,8 +231,17 @@ class test_CAReduce(unittest.TestCase): ...@@ -227,8 +231,17 @@ class test_CAReduce(unittest.TestCase):
else: axis2.append(a) else: axis2.append(a)
assert len(axis2)==len(tosum) assert len(axis2)==len(tosum)
tosum = tuple(axis2) tosum = tuple(axis2)
if tensor_op == tensor.all:
if scalar_op == add: for axis in reversed(sorted(tosum)):
zv = numpy.all(zv, axis)
if len(tosum) == 0:
zv = zv != 0
elif tensor_op == tensor.any:
for axis in reversed(sorted(tosum)):
zv = numpy.any(zv, axis)
if len(tosum) == 0:
zv = zv != 0
elif scalar_op == add:
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
zv = numpy.add.reduce(zv, axis) zv = numpy.add.reduce(zv, axis)
elif scalar_op == mul: elif scalar_op == mul:
...@@ -283,7 +296,10 @@ class test_CAReduce(unittest.TestCase): ...@@ -283,7 +296,10 @@ class test_CAReduce(unittest.TestCase):
#the Shape op don't implement c_code! #the Shape op don't implement c_code!
if isinstance(linker,gof.PerformLinker): if isinstance(linker,gof.PerformLinker):
x = TensorType(dtype, [(entry == 1) for entry in xsh])('x') x = TensorType(dtype, [(entry == 1) for entry in xsh])('x')
e = CAReduce(scalar_op, axis = tosum)(x) if tensor_op is None:
e = CAReduce(scalar_op, axis = tosum)(x)
else:
e = tensor_op(x, axis=tosum)
if tosum is None: tosum = range(len(xsh)) if tosum is None: tosum = range(len(xsh))
f = copy(linker).accept(Env([x], [e.shape])).make_function() f = copy(linker).accept(Env([x], [e.shape])).make_function()
if not(scalar_op in [maximum,minimum] and ((xsh==() or numpy.prod(xsh)==0))): if not(scalar_op in [maximum,minimum] and ((xsh==() or numpy.prod(xsh)==0))):
...@@ -295,6 +311,10 @@ class test_CAReduce(unittest.TestCase): ...@@ -295,6 +311,10 @@ class test_CAReduce(unittest.TestCase):
self.with_linker(gof.PerformLinker(), mul, dtype=dtype) self.with_linker(gof.PerformLinker(), mul, dtype=dtype)
self.with_linker(gof.PerformLinker(), maximum, dtype=dtype) self.with_linker(gof.PerformLinker(), maximum, dtype=dtype)
self.with_linker(gof.PerformLinker(), minimum, dtype=dtype) self.with_linker(gof.PerformLinker(), minimum, dtype=dtype)
self.with_linker(gof.PerformLinker(), and_, dtype=dtype,
tensor_op=tensor.all)
self.with_linker(gof.PerformLinker(), or_, dtype=dtype,
tensor_op=tensor.any)
for dtype in ["int8", "uint8"]: for dtype in ["int8", "uint8"]:
self.with_linker(gof.PerformLinker(), or_, dtype=dtype) self.with_linker(gof.PerformLinker(), or_, dtype=dtype)
self.with_linker(gof.PerformLinker(), and_, dtype=dtype) self.with_linker(gof.PerformLinker(), and_, dtype=dtype)
...@@ -317,6 +337,10 @@ class test_CAReduce(unittest.TestCase): ...@@ -317,6 +337,10 @@ class test_CAReduce(unittest.TestCase):
test_nan=True) test_nan=True)
self.with_linker(gof.PerformLinker(), and_, dtype=dtype, self.with_linker(gof.PerformLinker(), and_, dtype=dtype,
test_nan=True) test_nan=True)
self.with_linker(gof.PerformLinker(), or_, dtype=dtype,
test_nan=True, tensor_op=tensor.any)
self.with_linker(gof.PerformLinker(), and_, dtype=dtype,
test_nan=True, tensor_op=tensor.all)
def test_c(self): def test_c(self):
for dtype in ["floatX", "complex64", "complex128", "int8", "uint8"]: for dtype in ["floatX", "complex64", "complex128", "int8", "uint8"]:
...@@ -325,6 +349,11 @@ class test_CAReduce(unittest.TestCase): ...@@ -325,6 +349,11 @@ class test_CAReduce(unittest.TestCase):
for dtype in ["floatX", "int8", "uint8"]: for dtype in ["floatX", "int8", "uint8"]:
self.with_linker(gof.CLinker(), minimum, dtype=dtype) self.with_linker(gof.CLinker(), minimum, dtype=dtype)
self.with_linker(gof.CLinker(), maximum, dtype=dtype) self.with_linker(gof.CLinker(), maximum, dtype=dtype)
# `all` and `any` use neq, which doesn't have C code for complex dtypes
self.with_linker(gof.CLinker(), and_, dtype=dtype,
tensor_op=tensor.all)
self.with_linker(gof.CLinker(), or_, dtype=dtype,
tensor_op=tensor.any)
for dtype in ["int8", "uint8"]: for dtype in ["int8", "uint8"]:
self.with_linker(gof.CLinker(), or_, dtype=dtype) self.with_linker(gof.CLinker(), or_, dtype=dtype)
self.with_linker(gof.CLinker(), and_, dtype=dtype) self.with_linker(gof.CLinker(), and_, dtype=dtype)
......
...@@ -72,7 +72,7 @@ class TestKeepDims: ...@@ -72,7 +72,7 @@ class TestKeepDims:
# the following ops can be specified with a freely specified axis # the following ops can be specified with a freely specified axis
# parameter # parameter
for op in ([tensor.sum, tensor.prod, tensor.mean, tensor.var, for op in ([tensor.sum, tensor.prod, tensor.mean, tensor.var,
tensor.std]): tensor.std, tensor.all, tensor.any]):
# FRED: il faudra ajouter les ops suivantes a la boucle ci-dessus: # FRED: il faudra ajouter les ops suivantes a la boucle ci-dessus:
# tensor.all, tensor.any # tensor.all, tensor.any
# Celles-ci semblent presentement defectueuses puisqu'elles plantent # Celles-ci semblent presentement defectueuses puisqu'elles plantent
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论