提交 88c9397d authored 作者: Frederic Bastien's avatar Frederic Bastien

Add c code to abs for uint* and bool

上级 14c79a2a
# Test that normaly could be outside gpuarray, to have all gpuarray # Test that normaly could be outside gpuarray, to have all gpuarray
# tests in the same directory, we put them here. # tests in the same directory, we put them here.
import numpy as np
import theano import theano
from theano import tensor from theano import tensor
...@@ -9,9 +10,12 @@ from .config import mode_with_gpu ...@@ -9,9 +10,12 @@ from .config import mode_with_gpu
def test_nan_guard_mode(): def test_nan_guard_mode():
x = tensor.vector(dtype='int64') # Also test that abs uint* and bool have c code.
for dtype in ['uint8', 'int64', 'bool']:
x = tensor.vector(dtype=dtype)
y = x + 1 y = x + 1
mode = NanGuardMode(nan_is_error=True, optimizer=mode_with_gpu.optimizer) mode = NanGuardMode(nan_is_error=True,
optimizer=mode_with_gpu.optimizer)
f = theano.function([x], y, mode=mode) f = theano.function([x], y, mode=mode)
theano.printing.debugprint(f) d = np.asarray([23, 7]).astype(dtype)
f([23, 7]) assert np.allclose(f(d), d + 1)
...@@ -2420,6 +2420,11 @@ class Abs(UnaryScalarOp): ...@@ -2420,6 +2420,11 @@ class Abs(UnaryScalarOp):
return "%(z)s = fabs(%(x)s);" % locals() return "%(z)s = fabs(%(x)s);" % locals()
if type in complex_types: if type in complex_types:
return "%(z)s = sqrt(%(x)s.real*%(x)s.real + %(x)s.imag*%(x)s.imag);" % locals() return "%(z)s = sqrt(%(x)s.real*%(x)s.real + %(x)s.imag*%(x)s.imag);" % locals()
if node.outputs[0].type == bool:
return "%(z)s = (%(x)s) ? 1 : 0;" % locals()
if type in uint_types:
# uint are always already absolute value.
return "%(z)s = %(x)s;" % locals()
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
abs_ = Abs(same_out) abs_ = Abs(same_out)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论