提交 88b49770 authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #6025 from nouiz/nanguardmode_int

[ENH] Speed up nanguardmode by not checking *int* dtype
...@@ -106,6 +106,8 @@ def contains_nan(arr, node=None, var=None): ...@@ -106,6 +106,8 @@ def contains_nan(arr, node=None, var=None):
""" """
if not _is_numeric_value(arr, var): if not _is_numeric_value(arr, var):
return False return False
elif getattr(arr, 'dtype', '') in T.discrete_dtypes:
return False
elif pygpu_available and isinstance(arr, GpuArray): elif pygpu_available and isinstance(arr, GpuArray):
return np.isnan(f_gpua_min(arr.reshape(arr.size))) return np.isnan(f_gpua_min(arr.reshape(arr.size)))
...@@ -139,6 +141,8 @@ def contains_inf(arr, node=None, var=None): ...@@ -139,6 +141,8 @@ def contains_inf(arr, node=None, var=None):
""" """
if not _is_numeric_value(arr, var): if not _is_numeric_value(arr, var):
return False return False
elif getattr(arr, 'dtype', '') in T.discrete_dtypes:
return False
elif pygpu_available and isinstance(arr, GpuArray): elif pygpu_available and isinstance(arr, GpuArray):
return (np.isinf(f_gpua_min(arr.reshape(arr.size))) or return (np.isinf(f_gpua_min(arr.reshape(arr.size))) or
np.isinf(f_gpua_max(arr.reshape(arr.size)))) np.isinf(f_gpua_max(arr.reshape(arr.size))))
......
# Test that normaly could be outside gpuarray, to have all gpuarray
# tests in the same directory, we put them here.
from __future__ import absolute_import, print_function, division
import numpy as np
import theano
from theano import tensor
from theano.compile.nanguardmode import NanGuardMode
from .config import mode_with_gpu
def test_nan_guard_mode():
# Also test that abs uint* and bool have c code.
for dtype in ['uint8', 'int64', 'bool']:
x = tensor.vector(dtype=dtype)
y = x + 1
mode = NanGuardMode(nan_is_error=True,
optimizer=mode_with_gpu.optimizer)
f = theano.function([x], y, mode=mode)
d = np.asarray([23, 7]).astype(dtype)
assert np.allclose(f(d), d + 1)
...@@ -2420,6 +2420,11 @@ class Abs(UnaryScalarOp): ...@@ -2420,6 +2420,11 @@ class Abs(UnaryScalarOp):
return "%(z)s = fabs(%(x)s);" % locals() return "%(z)s = fabs(%(x)s);" % locals()
if type in complex_types: if type in complex_types:
return "%(z)s = sqrt(%(x)s.real*%(x)s.real + %(x)s.imag*%(x)s.imag);" % locals() return "%(z)s = sqrt(%(x)s.real*%(x)s.real + %(x)s.imag*%(x)s.imag);" % locals()
if node.outputs[0].type == bool:
return "%(z)s = (%(x)s) ? 1 : 0;" % locals()
if type in uint_types:
# uint are always already absolute value.
return "%(z)s = %(x)s;" % locals()
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
abs_ = Abs(same_out) abs_ = Abs(same_out)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论