提交 60c75959 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5314 from notoraptor/erfinv

Port erfinv and erfcinv to new backend.
......@@ -8,6 +8,8 @@ from six.moves import StringIO, xrange
from theano.gof.utils import MethodNotDefined
from theano.scalar import Scalar, Composite
from theano.tensor.elemwise import (Elemwise, DimShuffle, CAReduceDtype)
from theano.scalar.basic_scipy import Erfinv, Erfcinv
from theano.scalar.basic import upgrade_to_float_no_complex, complex_types
try:
import pygpu
......@@ -2580,6 +2582,51 @@ class GpuCAReduceCuda(GpuKernelBase, HideC, CAReduceDtype):
return kernels
class GpuErfinv(Erfinv):
    """
    Inverse error function for GPU.

    Generates CUDA C code calling the device-side ``erfinv`` function,
    clamped so that out-of-domain inputs match the CPU (scipy) behavior.
    """

    def c_headers(self):
        # math_functions.h provides the CUDA device erfinv();
        # cublas_v2.h is also required for the generated code to compile.
        return ['math_functions.h', 'cublas_v2.h']

    def c_code(self, node, name, inp, out, sub):
        x, = inp
        z, = out
        if node.inputs[0].type in complex_types:
            # Bug fix: the original raised with the builtin `type` function
            # instead of the actual unsupported input type.
            raise NotImplementedError('type not supported',
                                      node.inputs[0].type)
        # NB: CUDA erfinv function (GPU op) returns NaN if x not in [-1;1],
        # while `scipy.special.erfinv` (CPU op) returns an infinite
        # (-inf if x < -1, +inf if x > 1).
        # For consistency of CPU and GPU ops, we wrap the CUDA erfinv in the
        # following conditions to ensure that GPU op returns the same values
        # as CPU op.
        return "%(z)s = (%(x)s <= -1) ? erfinv(-1.0): ((%(x)s >= 1) ? erfinv(1.0): erfinv(%(x)s));" % locals()
class GpuErfcinv(Erfcinv):
    """
    Inverse complementary error function for GPU.

    Generates CUDA C code calling the device-side ``erfcinv`` function,
    clamped so that out-of-domain inputs match the CPU (scipy) behavior.
    """

    def c_headers(self):
        # math_functions.h provides the CUDA device erfcinv();
        # cublas_v2.h is also required for the generated code to compile.
        return ['math_functions.h', 'cublas_v2.h']

    def c_code(self, node, name, inp, out, sub):
        x, = inp
        z, = out
        if node.inputs[0].type in complex_types:
            # Bug fix: the original raised with the builtin `type` function
            # instead of the actual unsupported input type.
            raise NotImplementedError('type not supported',
                                      node.inputs[0].type)
        # NB: CUDA erfcinv function (GPU op) returns NaN if x not in [0;2],
        # while `scipy.special.erfcinv` (CPU op) returns an infinite
        # (+inf if x < 0, -inf if x > 2).
        # For consistency of CPU and GPU ops, we wrap the CUDA erfcinv in the
        # following conditions to ensure that GPU op returns the same values
        # as CPU op.
        return "%(z)s = (%(x)s <= 0) ? erfcinv(0.0): ((%(x)s >= 2) ? erfcinv(2.0): erfcinv(%(x)s));" % locals()
# Singleton scalar-op instances used when lifting erfinv/erfcinv Elemwise
# nodes to the GPU (CUDA back-end only); complex inputs are rejected by the
# upgrade_to_float_no_complex output-type policy.
gpu_erfinv = GpuErfinv(upgrade_to_float_no_complex, name='gpu_erfinv')
gpu_erfcinv = GpuErfcinv(upgrade_to_float_no_complex, name='gpu_erfcinv')
# Caching GpuCAReduceCuda
def gpu_ca_reduce_cuda(scalar_op, axis=None, reduce_mask=None, dtype=None, acc_dtype=None,
pre_scalar_op=None):
......
......@@ -19,6 +19,7 @@ from theano.ifelse import IfElse
from theano.misc.ordered_set import OrderedSet
from theano.scalar.basic import Scalar, Pow, Cast
from theano.scalar.basic_scipy import Erfinv, Erfcinv
from theano.scan_module import scan_utils, scan_op, scan_opt
from theano.tensor.nnet.conv import ConvOp
......@@ -60,7 +61,7 @@ from .nnet import (gpu_crossentropy_softmax_1hot_with_bias_dx,
gpu_softmax_with_bias, gpu_softmax)
from .elemwise import (GpuElemwise, GpuDimShuffle, GpuCAReduceCuda,
GpuCAReduceCPY, gpu_ca_reduce_cuda)
GpuCAReduceCPY, gpu_ca_reduce_cuda, gpu_erfinv, gpu_erfcinv)
from .subtensor import (GpuIncSubtensor, GpuSubtensor,
GpuAdvancedSubtensor,
GpuAdvancedSubtensor1,
......@@ -697,6 +698,28 @@ def local_gpua_elemwise(op, context_name, inputs, outputs):
name = 'Gpu' + name
if len(outputs) > 1:
return
have_cuda = False
have_opencl = False
if inputs and isinstance(inputs[0].type, GpuArrayType):
kind = inputs[0].type.context.kind
if kind.startswith(b'opencl'):
have_opencl = True
elif kind.startswith(b'cuda'):
have_cuda = True
opname = False
if isinstance(scal_op, Erfinv):
opname = 'erfinv'
if have_cuda:
scal_op = gpu_erfinv
elif isinstance(scal_op, Erfcinv):
opname = 'erfcinv'
if have_cuda:
scal_op = gpu_erfcinv
if opname:
if have_opencl:
_logger.warning('Function "%s" is not supported with OpenCL. Use "device=cuda" instead.' % opname)
if not have_cuda:
return None
res = GpuElemwise(scal_op, name=name,
inplace_pattern=copy.copy(op.inplace_pattern),
nfunc_spec=op.nfunc_spec)
......
from __future__ import absolute_import, print_function, division
import numpy
import scipy.special
import theano
from theano import scalar, gof, tensor
from unittest import TestCase
from theano.tests.unittest_tools import SkipTest, assert_allclose
from theano.tensor.tests import test_elemwise
from .config import mode_with_gpu, test_ctx_name
from .config import mode_with_gpu, mode_without_gpu, test_ctx_name
from .test_basic_ops import rand_gpuarray
from ..elemwise import (GpuElemwise, GpuDimShuffle,
GpuCAReduceCuda, GpuCAReduceCPY)
GpuCAReduceCuda, GpuCAReduceCPY, GpuErfinv, GpuErfcinv)
from ..type import GpuArrayType, get_context
from pygpu import ndgpuarray as gpuarray
......@@ -52,6 +54,69 @@ def test_elemwise_pow():
assert_allclose(out, expected_out)
class TestMathErrorFunctions(TestCase):
    """
    Check that the GPU erfinv/erfcinv ops agree with the scipy-based CPU
    ops, including for inputs outside of their mathematical domains.
    """

    dtypes = ["float64", "float32", "float16"]

    def setUp(self):
        # NB: erfinv is defined in ]-1;1[, and erfcinv is defined in ]0;2[,
        # so we just take some values in an interval that covers both domains
        # (this will also allow to test some values outside the domains).
        # We take [-5;5[ by default and we concatenate it 1000 times
        # to have the GPU ops run on large data.
        # Initialize per-instance dicts here (the original used class-level
        # mutable dicts, which are shared between test instances).
        self.default_arrays = {}
        self.expected_erfinv_outputs = {}
        self.expected_erfcinv_outputs = {}
        default_array = [x / 10.0 for x in range(-50, 50)] * 1000
        for dtype in self.dtypes:
            numpy_array = numpy.asarray(default_array, dtype=dtype)
            self.default_arrays[dtype] = numpy_array
            self.expected_erfinv_outputs[dtype] = scipy.special.erfinv(numpy_array)
            self.expected_erfcinv_outputs[dtype] = scipy.special.erfcinv(numpy_array)

    def check_gpu_scalar_op(self, theano_function, scalar_optype):
        """Return True if the compiled graph contains a GpuElemwise wrapping
        an instance of `scalar_optype`; print the graph for debugging if not."""
        for node in theano_function.maker.fgraph.apply_nodes:
            if isinstance(node.op, GpuElemwise) and isinstance(node.op.scalar_op, scalar_optype):
                return True
        theano.printing.debugprint(theano_function)
        return False

    def _check_elemwise(self, fn_name, tensor_op, gpu_scalar_optype, expected_outputs):
        # Shared driver for erfinv/erfcinv: compile CPU and GPU functions,
        # check graph contents, then compare outputs for every dtype.
        for dtype in self.dtypes:
            vector = theano.tensor.vector(dtype=dtype)
            output = tensor_op(vector)
            f_host = theano.function([vector], output,
                                     name='HOST/%s/%s' % (fn_name, dtype),
                                     mode=mode_without_gpu)
            f_gpu = theano.function([vector], output,
                                    name='GPU/%s/%s' % (fn_name, dtype),
                                    mode=mode_with_gpu)
            # The CPU graph must not contain any GPU elemwise node.
            assert len([n for n in f_host.maker.fgraph.apply_nodes
                        if isinstance(n.op, GpuElemwise)]) == 0
            # On OpenCL devices the GPU scalar op is not used (CUDA only),
            # so the graph check is skipped there.
            if not theano.config.device.startswith('opencl'):
                assert self.check_gpu_scalar_op(f_gpu, gpu_scalar_optype), \
                    'Function graph does not contains scalar op "%s".' % gpu_scalar_optype.__name__
            vector_val = self.default_arrays[dtype]
            # First calls warm the functions up; results are taken from the
            # second calls (mirrors the original test's call pattern).
            f_host(vector_val)
            f_gpu(vector_val)
            out_host = f_host(vector_val)
            out_gpu = f_gpu(vector_val)
            assert_allclose(out_host, out_gpu)
            assert_allclose(expected_outputs[dtype], out_gpu)

    def test_elemwise_erfinv(self):
        self._check_elemwise('erfinv', theano.tensor.erfinv, GpuErfinv,
                             self.expected_erfinv_outputs)

    def test_elemwise_erfcinv(self):
        self._check_elemwise('erfcinv', theano.tensor.erfcinv, GpuErfcinv,
                             self.expected_erfcinv_outputs)
class test_float16():
def test_composite_elemwise_float16(self):
w = theano.tensor.bvector()
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论