提交 43ddbad3 authored 作者: Frederic Bastien's avatar Frederic Bastien

Do the same fix for scalar

上级 f6e913a9
......@@ -1334,13 +1334,17 @@ class IsNan(FixedLogicalComparison):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError()
# Discrete type can never be nan
if node.inputs[0].type in discrete_types:
return "%(z)s = false;" % locals()
# Windows tries to be different and sometimes return -1, but we want
# to be consistent with numpy (which returns True), hence the "abs".
return "%(z)s = abs(isnan(%(x)s));" % locals()
def c_code_cache_version(self):
scalarop_version = super(IsNan, self).c_code_cache_version()
return tuple(scalarop_version) + (2,)
return tuple(scalarop_version) + (3,)
isnan = IsNan()
......@@ -1355,10 +1359,18 @@ class IsInf(FixedLogicalComparison):
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError()
# Discrete type can never be inf
if node.inputs[0].type in discrete_types:
return "%(z)s = false;" % locals()
# Note that the C isinf returns -1 for -Inf and +1 for +Inf, while
# numpy simply returns True: we mimic numpy's behavior here, thus
# the absolute value.
return "%(z)s = abs(isinf(%(x)s));" % locals()
def c_code_cache_version(self):
scalarop_version = super(IsInf, self).c_code_cache_version()
return tuple(scalarop_version) + (3,)
isinf = IsInf()
......
......@@ -2867,6 +2867,13 @@ def test_isnan():
x.dtype not in tensor.discrete_dtypes)
assert y.dtype == 'bool'
# Test c code generator even for int type.
y = tensor.isnan_(x)
assert isinstance(y.owner.op, tensor.Elemwise)
assert y.dtype == 'bool'
f = theano.function([x], y, allow_input_downcast=True)
f([[0, 1, 2]])
class T_Shape(unittest.TestCase):
def test_basic0(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论