提交 afb65d0b 作者: Arnaud Bergeron

Limit the amount of entropy we use for generating float16 numbers to avoid oversampling 0.0.

We are effectively "wasting" 16 bits of entropy per sample taken, but this should be otherwise fair.
上级 2708c77e
...@@ -788,15 +788,20 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): ...@@ -788,15 +788,20 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
write = write_w(self.output_type.dtype) write = write_w(self.output_type.dtype)
if self.output_type.dtype == 'float16': if self.output_type.dtype == 'float16':
otype = 'ga_half' otype = 'ga_half'
# Same as for float32 # limit the values of the state that we use.
NORM = '4.6566126e-10f' # numpy.float32(1.0/(2**31+65)) mask = '& 0xffff'
NORM = '1.5199e-05f' # numpy.float16(1.0/(2**16+130))
# this was determined by finding the biggest number such that
# numpy.float16(number * (M1 & 0xffff)) < 1.0
elif self.output_type.dtype == 'float32': elif self.output_type.dtype == 'float32':
otype = 'float' otype = 'float'
mask = ''
NORM = '4.6566126e-10f' # numpy.float32(1.0/(2**31+65)) NORM = '4.6566126e-10f' # numpy.float32(1.0/(2**31+65))
# this was determined by finding the biggest number such that # this was determined by finding the biggest number such that
# numpy.float32(number * M1) < 1.0 # numpy.float32(number * M1) < 1.0
elif self.output_type.dtype == 'float64': elif self.output_type.dtype == 'float64':
otype = 'double' otype = 'double'
mask = ''
NORM = '4.656612873077392578125e-10' NORM = '4.656612873077392578125e-10'
else: else:
raise ValueError('Unsupported data type for output', raise ValueError('Unsupported data type for output',
...@@ -863,11 +868,11 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): ...@@ -863,11 +868,11 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
x21 = y2; x21 = y2;
if (x11 <= x21) { if (x11 <= x21) {
sample_data[i] = %(write)s((x11 - x21 + M1) * %(NORM)s); sample_data[i] = %(write)s(((x11 - x21 + M1) %(mask)s) * %(NORM)s);
} }
else else
{ {
sample_data[i] = %(write)s((x11 - x21) * %(NORM)s); sample_data[i] = %(write)s(((x11 - x21) %(mask)s) * %(NORM)s);
} }
} }
...@@ -1001,7 +1006,7 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): ...@@ -1001,7 +1006,7 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (5, self.GpuKernelBase_version) return (6, self.GpuKernelBase_version)
def guess_n_streams(size, warn=False): def guess_n_streams(size, warn=False):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论