提交 f462a1c0 authored 作者: lamblin's avatar lamblin

Merge pull request #1291 from gdesjardins/erfinv_to_erfinvgpu_rebased

make erfinv work on the gpu.
......@@ -9,6 +9,8 @@ import copy, logging, StringIO, sys
import numpy
from theano.scalar.basic import upgrade_to_float_no_complex, complex_types
from theano.scalar.basic_scipy import Erfinv
from theano import Apply, Constant, Op, Type, Variable
from theano import gof, scalar, tensor
......@@ -1021,3 +1023,25 @@ nd_collapse_[i]=0;
#print sio.getvalue()
return sio.getvalue()
class ErfinvGPU(Erfinv):
"""
Provides a c-code implementation of the inverse error function for GPU.
Note: We do not add this c_code to theano.scalar.basic_scipy.Erfinv, as we
currently rely on Nvidia's cublas library to provide the erfinv
c-implementation (which requires different c_headers). As it stands,
theano.scalar.basic_scipy.Erfinv does not have c_code as scipy does not
export the required C function
"""
def c_headers(self):
return ['math_functions.h', 'cublas_v2.h']
def c_code(self, node, name, inp, out, sub):
x, = inp
z, = out
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = erfinv(%(x)s);" % locals()
erfinv_gpu = ErfinvGPU(upgrade_to_float_no_complex, name='erfinv_gpu')
......@@ -33,6 +33,8 @@ from theano.sandbox.cuda.nnet import (
GpuCrossentropySoftmax1HotWithBiasDx,
GpuSoftmax, GpuSoftmaxWithBias)
from theano.sandbox.cuda.elemwise import SupportCodeError
from theano.scalar.basic_scipy import Erfinv
from theano.sandbox.cuda.elemwise import ErfinvGPU, erfinv_gpu
from theano.sandbox.cuda.var import CudaNdarrayConstant
from theano.scan_module import scan_utils, scan_op
from theano.tensor.blas import _is_real_vector, _is_real_matrix
......@@ -177,6 +179,10 @@ def local_gpu_elemwise_0(node):
if numpy.all([o.type.dtype == 'float32' for o in node.outputs]):
# Don't set any inplace pattern.
# gpu_inplace_elemwise_optimizer will do it later
if isinstance(node.op.scalar_op, Erfinv):
new_op = GpuElemwise(erfinv_gpu)
else:
try:
new_op = GpuElemwise(node.op.scalar_op)
except SupportCodeError:
......@@ -234,11 +240,16 @@ def local_gpu_elemwise_1(node):
elemwise_node = host_i.owner
# Don't set any inplace pattern.
# gpu_inplace_elemwise_optimizer will do it later
if isinstance(node.op.scalar_op, Erfinv):
new_op = GpuElemwise(erfinv_gpu)
else:
try:
new_op = GpuElemwise(elemwise_node.op.scalar_op)
except SupportCodeError:
# This happens when scalar_op requires support code
return False
if all([i.dtype == 'float32' for i in elemwise_node.inputs]):
gpu_elemwise = new_op(*[gpu_from_host(i)
for i in elemwise_node.inputs])
......
......@@ -17,6 +17,7 @@ if cuda.cuda_available == False:
from theano.sandbox.cuda import basic_ops
from theano.sandbox.cuda.type import CudaNdarrayType
from theano.scalar.basic_scipy import erfinv
if theano.config.mode=='FAST_COMPILE':
mode_with_gpu = theano.compile.mode.get_mode('FAST_RUN').including('gpu')
......@@ -368,6 +369,18 @@ def test_incsubtensor_mixed():
client, idx = packed
assert isinstance(client.op, cuda.GpuFromHost)
def test_erfinvgpu():
""" Test that local_gpu_elemwise_0 replaces Erfinv with ErfinvGPU """
x = tensor.fmatrix()
f = theano.function([x], tensor.Elemwise(erfinv)(x), mode=mode_with_gpu)
f2 = theano.function([x], tensor.Elemwise(erfinv)(x), mode=mode_without_gpu)
assert isinstance(f.maker.fgraph.toposort()[1].op, cuda.GpuElemwise)
assert isinstance(f.maker.fgraph.toposort()[1].op.scalar_op, cuda.elemwise.ErfinvGPU)
xv=numpy.random.rand(7,8).astype('float32')
assert numpy.allclose(f(xv),f2(xv))
if __name__ == '__main__':
test_gpualloc()
test_opt_gpujoin_onlyajoin()
......
......@@ -78,6 +78,15 @@ erfc = Erfc(upgrade_to_float_no_complex, name='erfc')
class Erfinv(UnaryScalarOp):
"""
Implements the inverse error function.
Note: This op can still be executed on GPU, despite not having c_code. When
running on GPU, sandbox.cuda.opt.local_gpu_elemwise_[0,1] replaces this op
with sandbox.cuda.elemwise.ErfinvGPU.
(TODO) Find a C implementation of erfinv for CPU.
"""
def impl(self, x):
if imported_scipy_special:
return scipy.special.erfinv(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论