提交 22d7c2fa authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5289 from nouiz/isnan

Remove isnan from the graph for discrete dtype.
...@@ -1334,13 +1334,17 @@ class IsNan(FixedLogicalComparison): ...@@ -1334,13 +1334,17 @@ class IsNan(FixedLogicalComparison):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError() raise NotImplementedError()
# Discrete type can never be nan
if node.inputs[0].type in discrete_types:
return "%(z)s = false;" % locals()
# Windows tries to be different and sometimes return -1, but we want # Windows tries to be different and sometimes return -1, but we want
# to be consistent with numpy (which returns True), hence the "abs". # to be consistent with numpy (which returns True), hence the "abs".
return "%(z)s = abs(isnan(%(x)s));" % locals() return "%(z)s = abs(isnan(%(x)s));" % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
scalarop_version = super(IsNan, self).c_code_cache_version() scalarop_version = super(IsNan, self).c_code_cache_version()
return tuple(scalarop_version) + (2,) return tuple(scalarop_version) + (3,)
isnan = IsNan() isnan = IsNan()
...@@ -1355,10 +1359,18 @@ class IsInf(FixedLogicalComparison): ...@@ -1355,10 +1359,18 @@ class IsInf(FixedLogicalComparison):
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError() raise NotImplementedError()
# Discrete type can never be inf
if node.inputs[0].type in discrete_types:
return "%(z)s = false;" % locals()
# Note that the C isinf returns -1 for -Inf and +1 for +Inf, while # Note that the C isinf returns -1 for -Inf and +1 for +Inf, while
# numpy simply returns True: we mimic numpy's behavior here, thus # numpy simply returns True: we mimic numpy's behavior here, thus
# the absolute value. # the absolute value.
return "%(z)s = abs(isinf(%(x)s));" % locals() return "%(z)s = abs(isinf(%(x)s));" % locals()
def c_code_cache_version(self):
scalarop_version = super(IsInf, self).c_code_cache_version()
return tuple(scalarop_version) + (3,)
isinf = IsInf() isinf = IsInf()
......
...@@ -1825,12 +1825,39 @@ def neq(a, b): ...@@ -1825,12 +1825,39 @@ def neq(a, b):
def isnan(a): def isnan(a):
"""isnan(a)""" """isnan(a)"""
# Rename isnan to isnan_ to allow to bypass it when not needed.
# glibc 2.23 don't allow isnan on int, so we remove it from the graph.
isnan_ = isnan
def isnan(a):
"""isnan(a)"""
a = as_tensor_variable(a)
if a.dtype in discrete_dtypes:
return alloc(numpy.asarray(False, dtype="bool"),
*[a.shape[i] for i in range(a.ndim)])
return isnan_(a)
@_scal_elemwise @_scal_elemwise
def isinf(a): def isinf(a):
"""isinf(a)""" """isinf(a)"""
# Rename isnan to isnan_ to allow to bypass it when not needed.
# glibc 2.23 don't allow isnan on int, so we remove it from the graph.
isinf_ = isinf
def isinf(a):
"""isinf(a)"""
a = as_tensor_variable(a)
if a.dtype in discrete_dtypes:
return alloc(numpy.asarray(False, dtype="bool"),
*[a.shape[i] for i in range(a.ndim)])
return isinf_(a)
def allclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False): def allclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
""" """
Implement Numpy's ``allclose`` on tensors. Implement Numpy's ``allclose`` on tensors.
......
...@@ -2860,6 +2860,21 @@ def test_nan_inf_constant_signature(): ...@@ -2860,6 +2860,21 @@ def test_nan_inf_constant_signature():
assert f(numpy.nan) == 0 assert f(numpy.nan) == 0
def test_isnan():
for x in [tensor.matrix(), tensor.imatrix(), tensor.matrix(dtype='bool')]:
y = tensor.isnan(x)
assert isinstance(y.owner.op, tensor.Elemwise) == (
x.dtype not in tensor.discrete_dtypes)
assert y.dtype == 'bool'
# Test c code generator even for int type.
y = tensor.isnan_(x)
assert isinstance(y.owner.op, tensor.Elemwise)
assert y.dtype == 'bool'
f = theano.function([x], y, allow_input_downcast=True)
f([[0, 1, 2]])
class T_Shape(unittest.TestCase): class T_Shape(unittest.TestCase):
def test_basic0(self): def test_basic0(self):
s = shape(numpy.ones((5, 3))) s = shape(numpy.ones((5, 3)))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论