提交 61c1b955 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Added tensor.isnan and tensor.isinf. Also:

- Added some doc about the identity, commutative and associative fields - Fixed crash in filter function when using NaN values
上级 ef5ce2e0
...@@ -672,6 +672,12 @@ class UnaryScalarOp(ScalarOp): ...@@ -672,6 +672,12 @@ class UnaryScalarOp(ScalarOp):
nin = 1 nin = 1
class BinaryScalarOp(ScalarOp): class BinaryScalarOp(ScalarOp):
# One may define in subclasses the following fields:
# - `identity`: for an associative operation, identity corresponds to
# the neutral element. For instance, it will be 0 for addition, 1 for
# multiplication, True for "and", False for "or".
# - `commutative`: whether op(a, b) == op(b, a)
# - `associative`: whether op(op(a, b), c) == op(a, op(b, c))
nin = 2 nin = 2
...@@ -685,6 +691,15 @@ class LogicalComparison(BinaryScalarOp): ...@@ -685,6 +691,15 @@ class LogicalComparison(BinaryScalarOp):
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
return [None, None] return [None, None]
class FixedLogicalComparison(UnaryScalarOp):
"""
Comparison to a fixed value.
"""
def output_types(self, *input_dtypes):
return [int8]
def grad(self, inputs, output_gradients):
return [None]
class LT(LogicalComparison): class LT(LogicalComparison):
identity = False identity = False
commutative = False commutative = False
...@@ -749,6 +764,7 @@ class EQ(LogicalComparison): ...@@ -749,6 +764,7 @@ class EQ(LogicalComparison):
return "%(z)s = (%(x)s == %(y)s);" % locals() return "%(z)s = (%(x)s == %(y)s);" % locals()
eq = EQ() eq = EQ()
class NEQ(LogicalComparison): class NEQ(LogicalComparison):
identity = False identity = False
commutative = True commutative = True
...@@ -761,6 +777,30 @@ class NEQ(LogicalComparison): ...@@ -761,6 +777,30 @@ class NEQ(LogicalComparison):
return "%(z)s = (%(x)s != %(y)s);" % locals() return "%(z)s = (%(x)s != %(y)s);" % locals()
neq = NEQ() neq = NEQ()
class IsNan(FixedLogicalComparison):
def impl(self, x):
return theano._asarray(numpy.isnan(x), dtype='int8')
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError()
return "%(z)s = isnan(%(x)s);" % locals()
isnan = IsNan()
class IsInf(FixedLogicalComparison):
def impl(self, x):
return theano._asarray(numpy.isinf(x), dtype='int8')
def c_code(self, node, name, (x, ), (z, ), sub):
if node.inputs[0].type in complex_types:
raise NotImplementedError()
# Note that the C isinf returns -1 for -Inf and +1 for +Inf, while
# numpy simply returns True: we mimic numpy's behavior here, thus
# the absolute value.
return "%(z)s = abs(isinf(%(x)s));" % locals()
isinf = IsInf()
class InRange(LogicalComparison): class InRange(LogicalComparison):
nin = 3 nin = 3
def __init__(self, openlow, openhi): def __init__(self, openlow, openhi):
......
...@@ -581,7 +581,10 @@ class TensorType(Type): ...@@ -581,7 +581,10 @@ class TensorType(Type):
# data has to be converted. # data has to be converted.
# Check that this conversion is lossless # Check that this conversion is lossless
converted_data = theano._asarray(data, self.dtype) converted_data = theano._asarray(data, self.dtype)
if numpy.all(data == converted_data): # We use the `values_eq` static function from TensorType
# to handle NaN values.
if TensorType.values_eq(data, converted_data,
force_same_dtype=False):
data = converted_data data = converted_data
else: else:
# Do not print a too long description of data # Do not print a too long description of data
...@@ -661,12 +664,12 @@ class TensorType(Type): ...@@ -661,12 +664,12 @@ class TensorType(Type):
return False return False
@staticmethod @staticmethod
def values_eq(a, b): def values_eq(a, b, force_same_dtype=True):
#TODO: check to see if the dtype and shapes must match #TODO: check to see if the dtype and shapes must match
# for now, we err on safe side... # for now, we err on safe side...
if a.shape != b.shape: if a.shape != b.shape:
return False return False
if a.dtype != b.dtype: if force_same_dtype and a.dtype != b.dtype:
return False return False
a_eq_b = (a==b) a_eq_b = (a==b)
r = numpy.all(a_eq_b) r = numpy.all(a_eq_b)
...@@ -2050,6 +2053,14 @@ def eq(a, b): ...@@ -2050,6 +2053,14 @@ def eq(a, b):
def neq(a, b): def neq(a, b):
"""a != b""" """a != b"""
@_scal_elemwise_with_nfunc('isnan', 1, 1)
def isnan(a):
"""isnan(a)"""
@_scal_elemwise_with_nfunc('isinf', 1, 1)
def isinf(a):
"""isinf(a)"""
########################## ##########################
# Condition # Condition
......
...@@ -6,6 +6,7 @@ from theano import gof ...@@ -6,6 +6,7 @@ from theano import gof
from theano.scalar import * from theano.scalar import *
from theano import tensor from theano import tensor
from theano.compile.mode import get_default_mode
from theano.tensor.elemwise import * from theano.tensor.elemwise import *
from theano.tests import unittest_tools from theano.tests import unittest_tools
...@@ -405,6 +406,45 @@ class test_Prod(unittest.TestCase): ...@@ -405,6 +406,45 @@ class test_Prod(unittest.TestCase):
cPickle.dumps(o) cPickle.dumps(o)
class test_IsInf_IsNan(unittest.TestCase):
def setUp(self):
self.test_vals = map(numpy.array, [
0,
1,
numpy.nan,
numpy.inf,
-numpy.inf,
[numpy.nan, numpy.inf, -numpy.inf, 0, 1, -1],
])
self.scalar = tensor.scalar()
self.vector = tensor.vector()
self.mode = get_default_mode()
if isinstance(self.mode, theano.compile.debugmode.DebugMode):
# Disable the check preventing usage of NaN / Inf values.
self.mode = copy(self.mode)
self.mode.check_isfinite = False
def run_test(self, isfunc):
for input in (self.scalar, self.vector):
theano_isfunc = theano.function([input],
getattr(tensor, isfunc)(input),
mode=self.mode)
numpy_isfunc = getattr(numpy, isfunc)
for x in self.test_vals:
if ((x.ndim == 0 and input is not self.scalar) or
(x.ndim == 1 and input is not self.vector)):
# We only test with the appropriate input type.
continue
assert (theano_isfunc(x) == numpy_isfunc(x)).all()
def test_isinf(self):
return self.run_test('isinf')
def test_isnan(self):
return self.run_test('isnan')
if __name__ == '__main__': if __name__ == '__main__':
#unittest.main() #unittest.main()
suite = unittest.TestSuite([test_Prod('test_mul_without_zeros_zeros')]) suite = unittest.TestSuite([test_Prod('test_mul_without_zeros_zeros')])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论