提交 902c3972 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3402 from abergeron/fix_softmax

Fix GpuSoftmax and GpuSoftmaxWithBias for non-float32 operation.
...@@ -524,7 +524,7 @@ class GpuCrossentropySoftmax1HotWithBiasDx(GpuKernelBase, Op): ...@@ -524,7 +524,7 @@ class GpuCrossentropySoftmax1HotWithBiasDx(GpuKernelBase, Op):
gpu_crossentropy_softmax_1hot_with_bias_dx = GpuCrossentropySoftmax1HotWithBiasDx() gpu_crossentropy_softmax_1hot_with_bias_dx = GpuCrossentropySoftmax1HotWithBiasDx()
class GpuSoftmax (GpuKernelBase, Op): class GpuSoftmax(GpuKernelBase, Op):
""" """
Implement Softmax on the gpu. Implement Softmax on the gpu.
...@@ -541,7 +541,7 @@ class GpuSoftmax (GpuKernelBase, Op): ...@@ -541,7 +541,7 @@ class GpuSoftmax (GpuKernelBase, Op):
return shape return shape
def c_code_cache_version(self): def c_code_cache_version(self):
return (13,) + inline_softmax.code_version return (14,) + inline_softmax.code_version
def c_header_dirs(self): def c_header_dirs(self):
if pygpu.get_default_context().kind == 'opencl': if pygpu.get_default_context().kind == 'opencl':
...@@ -656,7 +656,8 @@ class GpuSoftmax (GpuKernelBase, Op): ...@@ -656,7 +656,8 @@ class GpuSoftmax (GpuKernelBase, Op):
work_sm = work_dtype(dtype_sm) work_sm = work_dtype(dtype_sm)
flags = Kernel.get_flags(dtype_x, dtype_sm) flags = Kernel.get_flags(dtype_x, dtype_sm)
type_x = gpuarray.dtype_to_ctype(dtype_x) type_x = gpuarray.dtype_to_ctype(dtype_x)
type_sm = gpuarray.dtype_to_ctype(work_sm) type_sm = gpuarray.dtype_to_ctype(dtype_sm)
type_acc = gpuarray.dtype_to_ctype(work_sm)
params = [ params = [
'uintp', 'uintp', 'uintp', 'uintp',
gpuarray.GpuArray, 'uintp', 'intp', 'intp', gpuarray.GpuArray, 'uintp', 'intp', 'intp',
...@@ -672,8 +673,8 @@ class GpuSoftmax (GpuKernelBase, Op): ...@@ -672,8 +673,8 @@ class GpuSoftmax (GpuKernelBase, Op):
'%s * sm' % type_sm, 'const ga_size offset_sm', '%s * sm' % type_sm, 'const ga_size offset_sm',
'const ga_ssize sm_s0', 'const ga_ssize sm_s1'], 'const ga_ssize sm_s0', 'const ga_ssize sm_s1'],
body=[ body=[
"extern __shared__ %s buf[]" % type_sm, "extern __shared__ %s buf[]" % type_acc,
"%s * buf2 = buf + N" % type_sm, "%s * buf2 = buf + N" % type_acc,
"x = (const %s *)(((char *)x)+offset_x)" % type_x, "x = (const %s *)(((char *)x)+offset_x)" % type_x,
"sm = (%s *)(((char *)sm)+offset_sm)" % type_sm, "sm = (%s *)(((char *)sm)+offset_sm)" % type_sm,
"for (int blockIDX = blockIdx.x; blockIDX < M;" "for (int blockIDX = blockIdx.x; blockIDX < M;"
...@@ -683,8 +684,8 @@ class GpuSoftmax (GpuKernelBase, Op): ...@@ -683,8 +684,8 @@ class GpuSoftmax (GpuKernelBase, Op):
"buf2[tx] = buf[tx]", "buf2[tx] = buf[tx]",
"}", "}",
"__syncthreads()", "__syncthreads()",
inline_softmax('N', 'buf', 'buf2', inline_softmax('N', 'buf', 'buf2', 'threadIdx.x',
'threadIdx.x', 'blockDim.x', work_sm), 'blockDim.x', dtype=work_sm),
"for (int tx = threadIdx.x; tx< N; tx += blockDim.x){", "for (int tx = threadIdx.x; tx< N; tx += blockDim.x){",
# This set all value correctly # This set all value correctly
"sm[blockIDX * sm_s0 + tx * sm_s1] = %s(buf[tx])" % write_sm, "sm[blockIDX * sm_s0 + tx * sm_s1] = %s(buf[tx])" % write_sm,
...@@ -703,7 +704,7 @@ class GpuSoftmax (GpuKernelBase, Op): ...@@ -703,7 +704,7 @@ class GpuSoftmax (GpuKernelBase, Op):
'%s * sm' % type_sm, 'const ga_size offset_sm', '%s * sm' % type_sm, 'const ga_size offset_sm',
'const ga_ssize sm_s0', 'const ga_ssize sm_s1'], 'const ga_ssize sm_s0', 'const ga_ssize sm_s1'],
body=[ body=[
"extern __shared__ %s buf[]" % type_sm, "extern __shared__ %s buf[]" % type_acc,
"x = (const %s *)(((char *)x)+offset_x)" % type_x, "x = (const %s *)(((char *)x)+offset_x)" % type_x,
"sm = (%s *)(((char *)sm)+offset_sm)" % type_sm, "sm = (%s *)(((char *)sm)+offset_sm)" % type_sm,
"for (int blockIDX = blockIdx.x; blockIDX < M;" "for (int blockIDX = blockIdx.x; blockIDX < M;"
...@@ -745,7 +746,7 @@ class GpuSoftmaxWithBias (GpuKernelBase, Op): ...@@ -745,7 +746,7 @@ class GpuSoftmaxWithBias (GpuKernelBase, Op):
return [shape[0]] return [shape[0]]
def c_code_cache_version(self): def c_code_cache_version(self):
return (12,) + inline_softmax.code_version return (13,) + inline_softmax.code_version
def c_header_dirs(self): def c_header_dirs(self):
if pygpu.get_default_context().kind == 'opencl': if pygpu.get_default_context().kind == 'opencl':
...@@ -880,7 +881,8 @@ class GpuSoftmaxWithBias (GpuKernelBase, Op): ...@@ -880,7 +881,8 @@ class GpuSoftmaxWithBias (GpuKernelBase, Op):
flags = Kernel.get_flags(dtype_x, dtype_b, dtype_sm) flags = Kernel.get_flags(dtype_x, dtype_b, dtype_sm)
type_x = gpuarray.dtype_to_ctype(dtype_x) type_x = gpuarray.dtype_to_ctype(dtype_x)
type_b = gpuarray.dtype_to_ctype(dtype_b) type_b = gpuarray.dtype_to_ctype(dtype_b)
type_sm = gpuarray.dtype_to_ctype(work_sm) type_sm = gpuarray.dtype_to_ctype(dtype_sm)
type_acc = gpuarray.dtype_to_ctype(work_sm)
params = [ params = [
'uintp', 'uintp', 'uintp', 'uintp',
gpuarray.GpuArray, 'uintp', 'intp', 'intp', gpuarray.GpuArray, 'uintp', 'intp', 'intp',
...@@ -899,8 +901,8 @@ class GpuSoftmaxWithBias (GpuKernelBase, Op): ...@@ -899,8 +901,8 @@ class GpuSoftmaxWithBias (GpuKernelBase, Op):
'%s * sm' % type_sm, 'const ga_size offset_sm', '%s * sm' % type_sm, 'const ga_size offset_sm',
'const ga_ssize sm_s0', 'const ga_ssize sm_s1'], 'const ga_ssize sm_s0', 'const ga_ssize sm_s1'],
body=[ body=[
"extern __shared__ %s buf[]" % type_sm, "extern __shared__ %s buf[]" % type_acc,
"%s * buf2 = buf + N" % type_sm, "%s * buf2 = buf + N" % type_acc,
"x = (const %s *)(((char *)x)+offset_x)" % type_x, "x = (const %s *)(((char *)x)+offset_x)" % type_x,
"b = (const %s *)(((char *)b)+offset_b)" % type_b, "b = (const %s *)(((char *)b)+offset_b)" % type_b,
"sm = (%s *)(((char *)sm)+offset_sm)" % type_sm, "sm = (%s *)(((char *)sm)+offset_sm)" % type_sm,
...@@ -933,7 +935,7 @@ class GpuSoftmaxWithBias (GpuKernelBase, Op): ...@@ -933,7 +935,7 @@ class GpuSoftmaxWithBias (GpuKernelBase, Op):
'%s * sm' % type_sm, 'const ga_size offset_sm', '%s * sm' % type_sm, 'const ga_size offset_sm',
'const ga_ssize sm_s0', 'const ga_ssize sm_s1'], 'const ga_ssize sm_s0', 'const ga_ssize sm_s1'],
body=[ body=[
"extern __shared__ %s buf[]" % type_sm, "extern __shared__ %s buf[]" % type_acc,
"x = (const %s *)(((char *)x)+offset_x)" % type_x, "x = (const %s *)(((char *)x)+offset_x)" % type_x,
"b = (const %s *)(((char *)b)+offset_b)" % type_b, "b = (const %s *)(((char *)b)+offset_b)" % type_b,
"sm = (%s *)(((char *)sm)+offset_sm)" % type_sm, "sm = (%s *)(((char *)sm)+offset_sm)" % type_sm,
......
...@@ -85,13 +85,9 @@ def test_GpuCrossentropySoftmaxArgmax1HotWithBias(): ...@@ -85,13 +85,9 @@ def test_GpuCrossentropySoftmaxArgmax1HotWithBias():
gout = classify_gpu(yy, b_values, dot_value) gout = classify_gpu(yy, b_values, dot_value)
assert len(out) == len(gout) == 3 assert len(out) == len(gout) == 3
assert numpy.allclose(out[0], gout[0]) utt.assert_allclose(out[0], gout[0])
assert numpy.allclose(out[2], gout[2], atol=3e-6), numpy.absolute( utt.assert_allclose(out[2], gout[2], atol=3e-6)
gout[2] - out[2]).max() utt.assert_allclose(out[1], gout[1])
assert numpy.allclose(out[1], gout[1]), [(id, out[1][id], gout[1][id], val)
for id, val in enumerate(out[1] -
gout[1])
if val != 0]
def test_GpuCrossentropySoftmax1HotWithBiasDx(): def test_GpuCrossentropySoftmax1HotWithBiasDx():
...@@ -162,6 +158,14 @@ def test_GpuCrossentropySoftmax1HotWithBiasDx(): ...@@ -162,6 +158,14 @@ def test_GpuCrossentropySoftmax1HotWithBiasDx():
rtol, atol) rtol, atol)
def test_softmax_with_bias_float16():
    """
    Run the shared softmax-with-bias test template on every
    input/bias dtype pairing that involves float16.
    """
    dtype_pairs = (
        ('float16', 'float32'),
        ('float16', 'float16'),
        ('float32', 'float16'),
    )
    for input_dtype, bias_dtype in dtype_pairs:
        softmax_with_bias_unittest_template(dtypeInput=input_dtype,
                                            dtypeBias=bias_dtype)
def test_softmax_with_bias_float32(): def test_softmax_with_bias_float32():
softmax_with_bias_unittest_template(dtypeInput='float32', softmax_with_bias_unittest_template(dtypeInput='float32',
dtypeBias='float32') dtypeBias='float32')
...@@ -178,52 +182,36 @@ def test_softmax_with_bias_float64(): ...@@ -178,52 +182,36 @@ def test_softmax_with_bias_float64():
def softmax_with_bias_unittest_template(dtypeInput, dtypeBias): def softmax_with_bias_unittest_template(dtypeInput, dtypeBias):
""" """
This is basic test for GpuSoftmaxWithBias with float64 variables This is a basic test for GpuSoftmaxWithBias.
We check that we loop when their is too much block We check that we loop when there are too many blocks.
TODO: check that we loop when their is too much thread.(THIS IS TODO: check that we loop when there are too many threads. (THIS IS
NOT IMPLEMENTED) NOT IMPLEMENTED)
""" """
assert dtypeInput in ['float32', 'float64'] x = T.matrix('x', dtype=dtypeInput)
assert dtypeBias in ['float32', 'float64'] b = T.vector('b', dtype=dtypeBias)
if dtypeInput == 'float32': z = T.nnet.softmax_with_bias(x, b)
x = T.fmatrix('x')
elif dtypeInput == 'float64':
x = T.dmatrix('x')
# We can't use zeros_like(x[0,::]) as this don't allow to test with
# 0 shape
if dtypeBias == 'float32':
z = T.nnet.softmax_with_bias(x, T.arange(x.shape[1] * 2,
dtype='float32')[::2])
elif dtypeBias == 'float64':
z = T.nnet.softmax_with_bias(x, T.arange(x.shape[1] * 2,
dtype='float64')[::2])
f = theano.function([x], z, mode=mode_without_gpu) f = theano.function([x, b], z, mode=mode_without_gpu)
f_gpu = theano.function([x], z, mode=mode_with_gpu) f_gpu = theano.function([x, b], z, mode=mode_with_gpu)
assert f.maker.fgraph.toposort()[-1].op == T.nnet.softmax_with_bias assert f.maker.fgraph.toposort()[-1].op == T.nnet.softmax_with_bias
assert isinstance(f_gpu.maker.fgraph.toposort()[-2].op, assert isinstance(f_gpu.maker.fgraph.toposort()[-2].op,
GpuSoftmaxWithBias) GpuSoftmaxWithBias)
def cmp(n, m): def cmp(n, m):
# print "test_softmax",n,m data = numpy.random.uniform(1e-7, 1, (n, m)).astype(dtype=dtypeInput)
if dtypeInput == 'float32': b_data = numpy.random.uniform(1e-7, 1, (m,)).astype(dtype=dtypeBias)
data = numpy.arange(n * m, dtype='float32').reshape(n, m)
elif dtypeInput == 'float64':
data = numpy.arange(n * m, dtype='float64').reshape(n, m)
out = f(data) out = f(data, b_data)
gout = f_gpu(data) gout = f_gpu(data, b_data)
assert numpy.allclose(out, gout), numpy.absolute(out - gout) utt.assert_allclose(out, gout)
cmp(2, 5) cmp(2, 5)
# we need to test n>32*1024 to check that we make the block loop. # we need to test n>32*1024 to check that we make the block loop.
cmp(2 << 15, 5) cmp(2 << 15, 5)
cmp(4074, 400) cmp(4074, 400)
cmp(0, 10)
cmp(784, 784) cmp(784, 784)
cmp(4, 1000) cmp(4, 1000)
cmp(4, 1024) cmp(4, 1024)
...@@ -237,51 +225,43 @@ def softmax_with_bias_unittest_template(dtypeInput, dtypeBias): ...@@ -237,51 +225,43 @@ def softmax_with_bias_unittest_template(dtypeInput, dtypeBias):
cmp(128, 64 * 1024) cmp(128, 64 * 1024)
def test_softmax_float16():
    """Run the shared GpuSoftmax test template on float16 input."""
    softmax_unittest_template(dtypeInput='float16')
def test_softmax_float32(): def test_softmax_float32():
softmax_unittest_template('float32') softmax_unittest_template('float32')
def test_softmax_float64(): def test_softmax_float64():
softmax_unittest_template('float64') softmax_unittest_template('float64')
def softmax_unittest_template(dtypeInput): def softmax_unittest_template(dtypeInput):
""" """
This is basic test for GpuSoftmax with float64 variables This is basic test for GpuSoftmax.
We check that we loop when their is too much block We check that we loop when their is too much block
We use slower code when there isn't enough shared memory We use slower code when there isn't enough shared memory
""" """
assert dtypeInput in ['float32', 'float64'] x = T.matrix('x', dtype=dtypeInput)
if dtypeInput == 'float32':
x = T.fmatrix('x')
elif dtypeInput == 'float64':
x = T.dmatrix('x')
z = T.nnet.softmax(x) z = T.nnet.softmax(x)
mode = mode_with_gpu.excluding('cudnn')
f = theano.function([x], z, mode=mode_without_gpu) f = theano.function([x], z, mode=mode_without_gpu)
f_gpu = theano.function([x], z, mode=mode) f_gpu = theano.function([x], z, mode=mode_wo_cudnn)
assert f.maker.fgraph.toposort()[-1].op == T.nnet.softmax_op assert f.maker.fgraph.toposort()[-1].op == T.nnet.softmax_op
assert isinstance(f_gpu.maker.fgraph.toposort()[-2].op, assert isinstance(f_gpu.maker.fgraph.toposort()[-2].op,
GpuSoftmax) GpuSoftmax)
def cmp(n, m): def cmp(n, m):
if dtypeInput == 'float32': data = numpy.random.uniform(0, 1, (n, m)).astype(dtype=dtypeInput)
data = numpy.arange(n * m, dtype='float32').reshape(n, m)
elif dtypeInput == 'float64':
data = numpy.arange(n * m, dtype='float64').reshape(n, m)
out = f(data) out = f(data)
gout = f_gpu(data) gout = f_gpu(data)
assert numpy.allclose(out, gout), numpy.absolute(out - gout) utt.assert_allclose(out, gout)
# we need to test n>32*1024 to check that we make the block loop. # we need to test n>32*1024 to check that we make the block loop.
cmp(2, 5) cmp(2, 5)
cmp(2 << 15, 5) cmp(2 << 15, 5)
cmp(4074, 400) cmp(4074, 400)
cmp(0, 10)
cmp(784, 784) cmp(784, 784)
cmp(4, 1000) cmp(4, 1000)
cmp(4, 1024) cmp(4, 1024)
...@@ -350,7 +330,7 @@ class test_SoftMax(unittest.TestCase): ...@@ -350,7 +330,7 @@ class test_SoftMax(unittest.TestCase):
data = numpy.arange(n * m, dtype='float32').reshape(n, m) data = numpy.arange(n * m, dtype='float32').reshape(n, m)
out = f(data) out = f(data)
gout = f_gpu(data) gout = f_gpu(data)
assert numpy.allclose(out, gout), numpy.absolute(out - gout) utt.assert_allclose(out, gout)
def _check_types(self, graph, graph_gpu, f_type, f_gpu_type): def _check_types(self, graph, graph_gpu, f_type, f_gpu_type):
assert isinstance(graph.maker.fgraph.toposort()[-1].op, f_type) assert isinstance(graph.maker.fgraph.toposort()[-1].op, f_type)
......
...@@ -162,13 +162,7 @@ class GpuArrayType(Type): ...@@ -162,13 +162,7 @@ class GpuArrayType(Type):
return tensor.TensorType.values_eq_approx( return tensor.TensorType.values_eq_approx(
an, bn, allow_remove_inf=allow_remove_inf, an, bn, allow_remove_inf=allow_remove_inf,
allow_remove_nan=allow_remove_nan, rtol=rtol, atol=atol) allow_remove_nan=allow_remove_nan, rtol=rtol, atol=atol)
narrow = 'float32', 'complex64' atol_, rtol_ = theano.tensor.basic._get_atol_rtol(a, b)
if (str(a.dtype) in narrow) or (str(b.dtype) in narrow):
atol_ = theano.tensor.basic.float32_atol
rtol_ = theano.tensor.basic.float32_rtol
else:
atol_ = theano.tensor.basic.float64_atol
rtol_ = theano.tensor.basic.float64_rtol
if rtol is not None: if rtol is not None:
rtol_ = rtol rtol_ = rtol
if atol is not None: if atol is not None:
......
...@@ -459,18 +459,29 @@ if int(config.tensor.cmp_sloppy) > 1: ...@@ -459,18 +459,29 @@ if int(config.tensor.cmp_sloppy) > 1:
# When config.tensor.cmp_sloppy>1 we are even more sloppy. This is # When config.tensor.cmp_sloppy>1 we are even more sloppy. This is
# useful to test the GPU as they don't use extended precision and # useful to test the GPU as they don't use extended precision and
# this cause some difference bigger then the normal sloppy. # this cause some difference bigger then the normal sloppy.
float16_atol = 5e-3
float16_rtol = 1e-2
float32_atol = 5e-4 float32_atol = 5e-4
float32_rtol = 1e-3 float32_rtol = 1e-3
float64_rtol = 1e-4 float64_rtol = 1e-4
float64_atol = 1e-3 float64_atol = 1e-3
elif int(config.tensor.cmp_sloppy): elif int(config.tensor.cmp_sloppy):
float16_atol = 1e-3
float16_rtol = 5e-3
float32_atol = 1e-4 float32_atol = 1e-4
float32_rtol = 1e-3 float32_rtol = 1e-3
float64_rtol = 1e-4 float64_rtol = 1e-4
float64_atol = 1e-3 float64_atol = 1e-3
else: else:
# If you change those value in test don't forget to put them back # If you change those value in test don't forget to put them back
# when the test end. Don't forget the case when the test fail. # when the test end. Don't forget the case when the test fail.
float16_atol = 5e-4
float16_rtol = 5e-4
float32_atol = 1e-5 float32_atol = 1e-5
float32_rtol = 1e-5 float32_rtol = 1e-5
...@@ -481,16 +492,25 @@ else: ...@@ -481,16 +492,25 @@ else:
float64_rtol = 1.0000000000000001e-06 float64_rtol = 1.0000000000000001e-06
def _get_atol_rtol(a, b):
    """
    Pick the default comparison tolerances for a pair of arrays.

    The dtypes of `a` and `b` are inspected by name: if either one is
    float16 the float16 tolerances are used; otherwise, if either one is
    float32 or complex64 the float32 tolerances are used; in every other
    case the float64 tolerances are used.

    Returns the tuple ``(atol, rtol)``.
    """
    def either_is(names):
        # `a` is checked before `b`, and `b` only when `a` does not match.
        return str(a.dtype) in names or str(b.dtype) in names

    # The module-level tolerance values are read when the function is
    # called, so the values in effect at that moment are the ones returned.
    if either_is(('float16',)):
        return float16_atol, float16_rtol
    if either_is(('float32', 'complex64')):
        return float32_atol, float32_rtol
    return float64_atol, float64_rtol
def _allclose(a, b, rtol=None, atol=None): def _allclose(a, b, rtol=None, atol=None):
a = numpy.asarray(a) a = numpy.asarray(a)
b = numpy.asarray(b) b = numpy.asarray(b)
narrow = 'float32', 'complex64' atol_, rtol_ = _get_atol_rtol(a, b)
if (str(a.dtype) in narrow) or (str(b.dtype) in narrow):
atol_ = float32_atol
rtol_ = float32_rtol
else:
atol_ = float64_atol
rtol_ = float64_rtol
if rtol is not None: if rtol is not None:
rtol_ = rtol rtol_ = rtol
if atol is not None: if atol is not None:
......
...@@ -81,10 +81,22 @@ class SoftmaxWithBias(gof.Op): ...@@ -81,10 +81,22 @@ class SoftmaxWithBias(gof.Op):
# sm[i] *= 1.0 / numpy.sum(sm[i]) # sm[i] *= 1.0 / numpy.sum(sm[i])
# output_storage[0][0] = sm # output_storage[0][0] = sm
if x.size == 0:
# Numpy doesn't like the max of a zero-sized object.
output_storage[0][0] = numpy.zeros(x.shape, dtype=x.dtype)
return
x_dtype = x.dtype
# Perform computations in float32 otherwise the result is too imprecise
if x.dtype == 'float16':
x = x.astype('float32')
x_plus_b = x + b[None, :] x_plus_b = x + b[None, :]
e_x = numpy.exp(x_plus_b - x_plus_b.max(axis=1)[:, None]) e_x = numpy.exp(x_plus_b - x_plus_b.max(axis=1)[:, None])
e_x *= 1.0 / e_x.sum(axis=1)[:, None] e_x *= 1.0 / e_x.sum(axis=1)[:, None]
output_storage[0][0] = e_x # default for copy is True and we don't need a copy if the
# data type matches.
output_storage[0][0] = e_x.astype(x_dtype, copy=False)
def grad(self, inp, grads): def grad(self, inp, grads):
x, b = inp x, b = inp
......
...@@ -309,15 +309,7 @@ def str_diagnostic(expected, value, rtol, atol): ...@@ -309,15 +309,7 @@ def str_diagnostic(expected, value, rtol, atol):
print(ssio.getvalue(), file=sio) print(ssio.getvalue(), file=sio)
except Exception: except Exception:
pass pass
# Use the same formula as in _allclose to find the tolerance used atol_, rtol_ = T.basic._get_atol_rtol(expected, value)
narrow = 'float32', 'complex64'
if ((str(expected.dtype) in narrow) or
(str(value.dtype) in narrow)):
atol_ = T.basic.float32_atol
rtol_ = T.basic.float32_rtol
else:
atol_ = T.basic.float64_atol
rtol_ = T.basic.float64_rtol
if rtol is not None: if rtol is not None:
rtol_ = rtol rtol_ = rtol
if atol is not None: if atol is not None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论