Commit 4a2e513e, authored by Arnaud Bergeron

Make GPUA_mrg_uniform work with f16.

Parent commit: a0ca5c47
...@@ -28,6 +28,7 @@ if cuda_available: ...@@ -28,6 +28,7 @@ if cuda_available:
from theano.sandbox.gpuarray.basic_ops import GpuKernelBase, Kernel from theano.sandbox.gpuarray.basic_ops import GpuKernelBase, Kernel
from theano.sandbox.gpuarray.type import GpuArrayType from theano.sandbox.gpuarray.type import GpuArrayType
from theano.sandbox.gpuarray.fp16_help import write_w
def matVecModM(A, s, m): def matVecModM(A, s, m):
...@@ -777,6 +778,7 @@ class GPU_mrg_uniform(mrg_uniform_base, GpuOp): ...@@ -777,6 +778,7 @@ class GPU_mrg_uniform(mrg_uniform_base, GpuOp):
class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
# GpuArray version # GpuArray version
_f16_ok = True
@classmethod @classmethod
def new(cls, rstate, ndim, dtype, size): def new(cls, rstate, ndim, dtype, size):
...@@ -790,14 +792,22 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): ...@@ -790,14 +792,22 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
return super(GPUA_mrg_uniform, self).c_headers() + ['numpy_compat.h'] return super(GPUA_mrg_uniform, self).c_headers() + ['numpy_compat.h']
def gpu_kernels(self, node, name): def gpu_kernels(self, node, name):
if self.output_type.dtype == 'float32': write = write_w(self.output_type.dtype)
if self.output_type.dtype == 'float16':
otype = 'ga_half'
# Same as for float32
NORM = '4.6566126e-10f' # numpy.float32(1.0/(2**31+65))
elif self.output_type.dtype == 'float32':
otype = 'float' otype = 'float'
NORM = '4.6566126e-10f' # numpy.float32(1.0/(2**31+65)) NORM = '4.6566126e-10f' # numpy.float32(1.0/(2**31+65))
# this was determined by finding the biggest number such that # this was determined by finding the biggest number such that
# numpy.float32(number * M1) < 1.0 # numpy.float32(number * M1) < 1.0
else: elif self.output_type.dtype == 'float64':
otype = 'double' otype = 'double'
NORM = '4.656612873077392578125e-10' NORM = '4.656612873077392578125e-10'
else:
raise ValueError('Unsupported data type for output',
self.output_type.dtype)
code = """ code = """
KERNEL void mrg_uniform( KERNEL void mrg_uniform(
GLOBAL_MEM %(otype)s *sample_data, GLOBAL_MEM %(otype)s *sample_data,
...@@ -860,11 +870,11 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): ...@@ -860,11 +870,11 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
x21 = y2; x21 = y2;
if (x11 <= x21) { if (x11 <= x21) {
sample_data[i] = (x11 - x21 + M1) * %(NORM)s; sample_data[i] = %(write)s((x11 - x21 + M1) * %(NORM)s);
} }
else else
{ {
sample_data[i] = (x11 - x21) * %(NORM)s; sample_data[i] = %(write)s((x11 - x21) * %(NORM)s);
} }
} }
...@@ -896,17 +906,9 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): ...@@ -896,17 +906,9 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
o_type_num = numpy.asarray(0, dtype=self.output_type.dtype).dtype.num o_type_num = numpy.asarray(0, dtype=self.output_type.dtype).dtype.num
fail = sub['fail'] fail = sub['fail']
kname = self.gpu_kernels(node, nodename)[0].objvar kname = self.gpu_kernels(node, nodename)[0].objvar
otypecode = str(self.output_type.typecode)
if self.output_type.dtype == 'float32':
otype = 'float'
otypecode = 'GA_FLOAT'
else:
otype = 'double'
otypecode = 'GA_DOUBLE'
return """ return """
//////// <code generated by mrg_uniform>
size_t odims[%(ndim)s]; size_t odims[%(ndim)s];
unsigned int n_elements = 1; unsigned int n_elements = 1;
unsigned int n_streams; unsigned int n_streams;
...@@ -1003,12 +1005,10 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): ...@@ -1003,12 +1005,10 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
%(fail)s %(fail)s
} }
} }
//////// </ code generated by mrg_uniform>
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (3, self.GpuKernelBase_version) return (5, self.GpuKernelBase_version)
def guess_n_streams(size, warn=False): def guess_n_streams(size, warn=False):
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment