Commit 205adfa4 authored by abergeron, committed by GitHub

Merge pull request #5181 from nouiz/fix_gh-5077

[REG, CRASH] fix crash related to float16 introduced in gh-5077
...@@ -49,11 +49,6 @@ class GpuElemwise(HideC, Elemwise): ...@@ -49,11 +49,6 @@ class GpuElemwise(HideC, Elemwise):
nout = property(lambda self: self.scalar_op.nout) nout = property(lambda self: self.scalar_op.nout)
_f16_ok = True _f16_ok = True
def __init__(self, scalar_op, *args, **kwargs):
    """Construct the GPU elemwise op, demoting float16 Composites.

    Composite scalar graphs may contain float16 nodes, which have no C
    type on the GPU; clone them to an equivalent float32 graph before
    delegating to the base-class constructor.
    """
    op = (scalar_op.clone_float32()
          if isinstance(scalar_op, Composite)
          else scalar_op)
    Elemwise.__init__(self, op, *args, **kwargs)
def __str__(self): def __str__(self):
if self.name is not None: if self.name is not None:
return self.name return self.name
...@@ -108,7 +103,14 @@ class GpuElemwise(HideC, Elemwise): ...@@ -108,7 +103,14 @@ class GpuElemwise(HideC, Elemwise):
inps, outs = self._get_vnames(node) inps, outs = self._get_vnames(node)
scal_v_ins = [get_scal(i.dtype)() for i in node.inputs] scal_v_ins = [get_scal(i.dtype)() for i in node.inputs]
fake_node = self.scalar_op.make_node(*scal_v_ins) # As float16 isn't a c type and most GPU don't compute on it,
# We convert the computation to float32, and let libgpuarray
# load in float16 and cast to float32 and do the reverse for
# the output.
scalar_op = self.scalar_op
if isinstance(scalar_op, (scalar.Cast, Composite)):
scalar_op = scalar_op.clone_float32()
fake_node = scalar_op.make_node(*scal_v_ins)
scal_v_out = fake_node.outputs scal_v_out = fake_node.outputs
assert len(scal_v_out) == len(node.outputs) assert len(scal_v_out) == len(node.outputs)
...@@ -116,6 +118,8 @@ class GpuElemwise(HideC, Elemwise): ...@@ -116,6 +118,8 @@ class GpuElemwise(HideC, Elemwise):
inps, outs, inps, outs,
dict(fail='return;')) dict(fail='return;'))
# If the following assert fail, then we need to update the
# code handler above.
assert 'npy_float16' not in kop assert 'npy_float16' not in kop
support_code = "" support_code = ""
......
...@@ -2,7 +2,7 @@ from __future__ import absolute_import, print_function, division ...@@ -2,7 +2,7 @@ from __future__ import absolute_import, print_function, division
import numpy import numpy
import theano import theano
from theano import scalar, gof from theano import scalar, gof, tensor
from theano.tests.unittest_tools import SkipTest, assert_allclose from theano.tests.unittest_tools import SkipTest, assert_allclose
from theano.tensor.tests import test_elemwise from theano.tensor.tests import test_elemwise
...@@ -52,6 +52,61 @@ def test_elemwise_pow(): ...@@ -52,6 +52,61 @@ def test_elemwise_pow():
assert_allclose(out, expected_out) assert_allclose(out, expected_out)
class test_float16():
    """Regression tests for float16 support in GPU elemwise/cast ops.

    These exercise the compilation path only: the fix being tested was a
    crash during function construction, so building the function is the
    assertion for the composite cases.
    """

    def test_composite_elemwise_float16(self):
        # A mixed-dtype graph with float16 intermediates must compile.
        bv = theano.tensor.bvector()
        f16v = theano.tensor.vector(dtype='float16')
        f32v = theano.tensor.fvector()
        t = tensor.tanh(f16v + tensor.cast(f32v, 'float16'))
        expr = (t - t**2
                + tensor.cast(f16v, 'int16')
                + tensor.cast(f16v, 'float32')
                + tensor.cast(bv, 'float16')
                - tensor.constant(numpy.float16(1.0)))
        theano.function([bv, f16v, f32v], expr, mode=mode_with_gpu)

        # A switch over an all-float16 product must also compile.
        cond = theano.tensor.vector(dtype='uint8')
        a, b, c, d = [theano.tensor.vector(dtype='float16')
                      for _ in range(4)]
        out = tensor.switch(cond, tensor.mul(a, b, c), d)
        theano.function([cond, a, b, c, d], out, mode=mode_with_gpu)

    def test_cast_float16(self):
        # Compile every interesting cast to/from float16 in one function
        # and check each output against numpy's own casting.
        f16 = theano.tensor.vector(dtype='float16')
        f32 = theano.tensor.fvector()
        i8 = theano.tensor.bvector()
        casts = [f16.astype('float32'),
                 f32.astype('float16'),
                 f32.astype('float64'),
                 f16.astype('int8'),
                 f32.astype('int8'),
                 i8.astype('float16'),
                 i8.astype('float32')]
        f = theano.function([f16, f32, i8], casts, mode=mode_with_gpu)

        # Distinct lengths so a mixed-up input would be caught by shape.
        d1 = (numpy.random.rand(4) * 10).astype('float16')
        d2 = (numpy.random.rand(5) * 10).astype('float32')
        d3 = (numpy.random.rand(6) * 10).astype('int8')
        by_dtype = {'float16': d1, 'float32': d2}
        results = f(d1, d2, d3)
        for out, res in zip(f.outputs, results):
            target = out.variable.dtype
            assert res.dtype == target
            src = by_dtype.get(out.variable.owner.inputs[0].dtype, d3)
            assert_allclose(src.astype(target), res)
class test_GpuDimShuffle(test_elemwise.test_DimShuffle): class test_GpuDimShuffle(test_elemwise.test_DimShuffle):
op = GpuDimShuffle op = GpuDimShuffle
......
...@@ -2220,6 +2220,11 @@ class Cast(UnaryScalarOp): ...@@ -2220,6 +2220,11 @@ class Cast(UnaryScalarOp):
def __str__(self): def __str__(self):
return '%s{%s}' % (self.__class__.__name__, self.o_type.dtype) return '%s{%s}' % (self.__class__.__name__, self.o_type.dtype)
def clone_float32(self):
    """Return a float32-producing variant of this cast.

    float16 is not a C type, so a cast whose output type is float16 is
    replaced by the shared float32 cast op; every other cast already has
    a C implementation and is returned unchanged.
    """
    return convert_to_float32 if self.o_type == float16 else self
def make_new_inplace(self, output_types_preference=None, name=None): def make_new_inplace(self, output_types_preference=None, name=None):
""" """
This op.__init__ fct don't have the same parameter as other scalar op. This op.__init__ fct don't have the same parameter as other scalar op.
......
...@@ -21,11 +21,11 @@ from theano import gof ...@@ -21,11 +21,11 @@ from theano import gof
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.scalar.basic import (floats, float16, float32, float64, from theano.scalar.basic import (floats, float16, float32, float64,
ints, int8, int32, complex64, ints, int8, int32, complex64, uint8,
ComplexError, IntDiv, TrueDiv, ComplexError, IntDiv, TrueDiv,
Composite, add, div_proxy, Composite, add, div_proxy,
and_, eq, neq, invert, mul, Scalar, InRange, and_, eq, neq, invert, mul, Scalar, InRange,
cast, constant) cast, constant, switch)
from theano.scalar.basic import ( from theano.scalar.basic import (
true_div, inv, log, log2, log10, log1p, exp, exp2, expm1, sqrt, deg2rad, true_div, inv, log, log2, log10, log1p, exp, exp2, expm1, sqrt, deg2rad,
rad2deg, cos, arccos, sin, arcsin, tan, arctan, arctan2, cosh, arccosh, rad2deg, cos, arccos, sin, arcsin, tan, arctan, arctan2, cosh, arccosh,
...@@ -87,6 +87,18 @@ class test_composite(unittest.TestCase): ...@@ -87,6 +87,18 @@ class test_composite(unittest.TestCase):
nc = c.clone_float32() nc = c.clone_float32()
assert not has_f16(nc) assert not has_f16(nc)
v = uint8()
w = float16()
x = float16()
y = float16()
z = float16()
c = Composite([v, w, x, y, z], [switch(v, mul(w, x, y), z)])
assert has_f16(c)
nc = c.clone_float32()
assert not has_f16(nc)
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div_proxy(x, y)) e = mul(add(x, y), div_proxy(x, y))
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment