Commit 530999f7, authored by Frederic Bastien

Make GammaLn work on the GPU

Parent commit: f4c7ecfb
...@@ -8,7 +8,7 @@ from six.moves import StringIO, xrange ...@@ -8,7 +8,7 @@ from six.moves import StringIO, xrange
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
from theano.scalar import Scalar, Composite from theano.scalar import Scalar, Composite
from theano.tensor.elemwise import (Elemwise, DimShuffle, CAReduceDtype) from theano.tensor.elemwise import (Elemwise, DimShuffle, CAReduceDtype)
from theano.scalar.basic_scipy import Erfinv, Erfcinv from theano.scalar.basic_scipy import Erfinv, Erfcinv, GammaLn
from theano.scalar.basic import upgrade_to_float_no_complex, complex_types from theano.scalar.basic import upgrade_to_float_no_complex, complex_types
try: try:
...@@ -2493,6 +2493,13 @@ class GpuCAReduceCuda(GpuKernelBase, HideC, CAReduceDtype): ...@@ -2493,6 +2493,13 @@ class GpuCAReduceCuda(GpuKernelBase, HideC, CAReduceDtype):
return kernels return kernels
# GPU specialization of the scalar log-gamma op. It inherits GammaLn's
# C implementation unchanged and only overrides the headers so that the
# device-side lgamma() is available when the code is compiled for CUDA.
class GpuGammaLn(GammaLn):
    def c_headers(self):
        # CUDA's math_functions.h declares lgamma() for device code
        # (the CPU op relies on the host libm header instead).
        return ['math_functions.h']


# Singleton instance registered with the GPU elemwise lifter; mirrors the
# gpu_erfinv / gpu_erfcinv instances defined alongside it.
gpu_gammaln = GpuGammaLn(upgrade_to_float_no_complex, name='gpu_gammaln')
class GpuErfinv(Erfinv): class GpuErfinv(Erfinv):
""" """
Inverse error function for GPU. Inverse error function for GPU.
...@@ -2512,6 +2519,7 @@ class GpuErfinv(Erfinv): ...@@ -2512,6 +2519,7 @@ class GpuErfinv(Erfinv):
# For consistency of CPU and GPU ops, we wrap the CUDA erfinv in the following conditions # For consistency of CPU and GPU ops, we wrap the CUDA erfinv in the following conditions
# to ensure that GPU op returns the same values as CPU op. # to ensure that GPU op returns the same values as CPU op.
return "%(z)s = (%(x)s <= -1) ? erfinv(-1.0): ((%(x)s >= 1) ? erfinv(1.0): erfinv(%(x)s));" % locals() return "%(z)s = (%(x)s <= -1) ? erfinv(-1.0): ((%(x)s >= 1) ? erfinv(1.0): erfinv(%(x)s));" % locals()
gpu_erfinv = GpuErfinv(upgrade_to_float_no_complex, name='gpu_erfinv')
class GpuErfcinv(Erfcinv): class GpuErfcinv(Erfcinv):
...@@ -2533,8 +2541,6 @@ class GpuErfcinv(Erfcinv): ...@@ -2533,8 +2541,6 @@ class GpuErfcinv(Erfcinv):
# For consistency of CPU and GPU ops, we wrap the CUDA erfcinv in the following conditions # For consistency of CPU and GPU ops, we wrap the CUDA erfcinv in the following conditions
# to ensure that GPU op returns the same values as CPU op. # to ensure that GPU op returns the same values as CPU op.
return "%(z)s = (%(x)s <= 0) ? erfcinv(0.0): ((%(x)s >= 2) ? erfcinv(2.0): erfcinv(%(x)s));" % locals() return "%(z)s = (%(x)s <= 0) ? erfcinv(0.0): ((%(x)s >= 2) ? erfcinv(2.0): erfcinv(%(x)s));" % locals()
gpu_erfinv = GpuErfinv(upgrade_to_float_no_complex, name='gpu_erfinv')
gpu_erfcinv = GpuErfcinv(upgrade_to_float_no_complex, name='gpu_erfcinv') gpu_erfcinv = GpuErfcinv(upgrade_to_float_no_complex, name='gpu_erfcinv')
......
...@@ -19,7 +19,7 @@ from theano.ifelse import IfElse ...@@ -19,7 +19,7 @@ from theano.ifelse import IfElse
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
from theano.scalar.basic import Scalar, Pow, Cast from theano.scalar.basic import Scalar, Pow, Cast
from theano.scalar.basic_scipy import Erfinv, Erfcinv from theano.scalar.basic_scipy import Erfinv, Erfcinv, GammaLn
from theano.scan_module import scan_utils, scan_op, scan_opt from theano.scan_module import scan_utils, scan_op, scan_opt
from theano.tensor.nnet import bn from theano.tensor.nnet import bn
...@@ -61,7 +61,7 @@ from .nnet import (gpu_crossentropy_softmax_1hot_with_bias_dx, ...@@ -61,7 +61,7 @@ from .nnet import (gpu_crossentropy_softmax_1hot_with_bias_dx,
gpu_crossentropy_softmax_argmax_1hot_with_bias, gpu_crossentropy_softmax_argmax_1hot_with_bias,
gpu_softmax_with_bias, gpu_softmax) gpu_softmax_with_bias, gpu_softmax)
from .elemwise import (GpuElemwise, GpuDimShuffle, GpuCAReduceCuda, from .elemwise import (GpuElemwise, GpuDimShuffle, GpuCAReduceCuda,
GpuCAReduceCPY, gpu_erfinv, gpu_erfcinv, GpuCAReduceCPY, gpu_erfinv, gpu_erfcinv, gpu_gammaln,
max_inputs_to_GpuElemwise) max_inputs_to_GpuElemwise)
from .subtensor import (GpuIncSubtensor, GpuSubtensor, from .subtensor import (GpuIncSubtensor, GpuSubtensor,
GpuAdvancedSubtensor, GpuAdvancedSubtensor,
...@@ -711,18 +711,16 @@ def local_gpua_elemwise(op, context_name, inputs, outputs): ...@@ -711,18 +711,16 @@ def local_gpua_elemwise(op, context_name, inputs, outputs):
have_opencl = True have_opencl = True
elif kind.startswith(b'cuda'): elif kind.startswith(b'cuda'):
have_cuda = True have_cuda = True
opname = False convert = {Erfinv: gpu_erfinv,
if isinstance(scal_op, Erfinv): Erfcinv: gpu_erfcinv,
opname = 'erfinv' GammaLn: gpu_gammaln}
if have_cuda:
scal_op = gpu_erfinv if scal_op.__class__ in convert:
elif isinstance(scal_op, Erfcinv): scal_op = convert[scal_op.__class__]
opname = 'erfcinv'
if have_cuda:
scal_op = gpu_erfcinv
if opname:
if have_opencl: if have_opencl:
_logger.warning('Function "%s" is not supported with OpenCL. Use "device=cuda" instead.' % opname) _logger.warning(
'Function "%s" is not supported with OpenCL. Use "device=cuda" instead.' %
scal_op)
if not have_cuda: if not have_cuda:
return None return None
res = GpuElemwise(scal_op, name=name, res = GpuElemwise(scal_op, name=name,
......
...@@ -271,11 +271,17 @@ class GammaLn(UnaryScalarOp): ...@@ -271,11 +271,17 @@ class GammaLn(UnaryScalarOp):
z, = out z, = out
# no c code for complex # no c code for complex
# [u]int* will be casted to float64 before computation # [u]int* will be casted to float64 before computation
if x.type in complex_types: if node.inputs[0].type in complex_types:
raise NotImplementedError( raise NotImplementedError(
'gammaln complex c code is not implemented') 'gammaln complex c code is not implemented')
return """%(z)s = # For some reason, on the GPU, uint64 inputs don't get casted
lgamma(%(x)s);""" % locals() # automatically to float64. This make the compilation crash
dtype = ""
if node.outputs[0].dtype == 'float64':
dtype = "(double)"
elif node.outputs[0].dtype == 'float32':
dtype = "(float)"
return """%(z)s = lgamma(%(dtype)s%(x)s);""" % locals()
gammaln = GammaLn(upgrade_to_float, name='gammaln') gammaln = GammaLn(upgrade_to_float, name='gammaln')
......
Markdown formatting is supported.
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Sign up or sign in to post a comment.