提交 034a4081 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix crash introduced by gh-5077 by keeping GpuElemwise as original and only…

Fix crash introduced by gh-5077 by keeping GpuElemwise as original and only casting to float16 in the code generator.
上级 fe1083e8
...@@ -49,11 +49,6 @@ class GpuElemwise(HideC, Elemwise): ...@@ -49,11 +49,6 @@ class GpuElemwise(HideC, Elemwise):
nout = property(lambda self: self.scalar_op.nout) nout = property(lambda self: self.scalar_op.nout)
_f16_ok = True _f16_ok = True
def __init__(self, scalar_op, *args, **kwargs):
if isinstance(scalar_op, Composite):
scalar_op = scalar_op.clone_float32()
Elemwise.__init__(self, scalar_op, *args, **kwargs)
def __str__(self): def __str__(self):
if self.name is not None: if self.name is not None:
return self.name return self.name
...@@ -108,7 +103,15 @@ class GpuElemwise(HideC, Elemwise): ...@@ -108,7 +103,15 @@ class GpuElemwise(HideC, Elemwise):
inps, outs = self._get_vnames(node) inps, outs = self._get_vnames(node)
scal_v_ins = [get_scal(i.dtype)() for i in node.inputs] scal_v_ins = [get_scal(i.dtype)() for i in node.inputs]
fake_node = self.scalar_op.make_node(*scal_v_ins) # As float16 isn't a c type and most GPU don't compute on it,
# We convert the computation to float32, and let libgpuarray
# load in float16 and cast to float32 and do the reverse for
# the output.
scalar_op = self.scalar_op
if isinstance(scalar_op, Composite):
s = scalar_op.clone_float32()
scalar_op = s
fake_node = scalar_op.make_node(*scal_v_ins)
scal_v_out = fake_node.outputs scal_v_out = fake_node.outputs
assert len(scal_v_out) == len(node.outputs) assert len(scal_v_out) == len(node.outputs)
......
...@@ -2,7 +2,7 @@ from __future__ import absolute_import, print_function, division ...@@ -2,7 +2,7 @@ from __future__ import absolute_import, print_function, division
import numpy import numpy
import theano import theano
from theano import scalar, gof from theano import scalar, gof, tensor
from theano.tests.unittest_tools import SkipTest, assert_allclose from theano.tests.unittest_tools import SkipTest, assert_allclose
from theano.tensor.tests import test_elemwise from theano.tensor.tests import test_elemwise
...@@ -52,6 +52,31 @@ def test_elemwise_pow(): ...@@ -52,6 +52,31 @@ def test_elemwise_pow():
assert_allclose(out, expected_out) assert_allclose(out, expected_out)
def test_composite_elemwise_float16():
w = theano.tensor.bvector()
x = theano.tensor.vector(dtype='float16')
y = theano.tensor.fvector()
cz = tensor.tanh(x + tensor.cast(y, 'float16'))
o = (cz - cz**2 +
tensor.cast(x, 'int16') + tensor.cast(x, 'float32') +
tensor.cast(w, 'float16') -
tensor.constant(numpy.float16(1.0)))
f = theano.function([w, x, y], o, mode=mode_with_gpu)
theano.printing.debugprint(f)
v = theano.tensor.vector(dtype='uint8')
w = theano.tensor.vector(dtype='float16')
x = theano.tensor.vector(dtype='float16')
y = theano.tensor.vector(dtype='float16')
z = theano.tensor.vector(dtype='float16')
o = tensor.switch(v, tensor.mul(w, x, y), z)
f = theano.function([v, w, x, y, z], o, mode=mode_with_gpu)
theano.printing.debugprint(f)
class test_GpuDimShuffle(test_elemwise.test_DimShuffle): class test_GpuDimShuffle(test_elemwise.test_DimShuffle):
op = GpuDimShuffle op = GpuDimShuffle
......
...@@ -21,11 +21,11 @@ from theano import gof ...@@ -21,11 +21,11 @@ from theano import gof
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.scalar.basic import (floats, float16, float32, float64, from theano.scalar.basic import (floats, float16, float32, float64,
ints, int8, int32, complex64, ints, int8, int32, complex64, uint8,
ComplexError, IntDiv, TrueDiv, ComplexError, IntDiv, TrueDiv,
Composite, add, div_proxy, Composite, add, div_proxy,
and_, eq, neq, invert, mul, Scalar, InRange, and_, eq, neq, invert, mul, Scalar, InRange,
cast, constant) cast, constant, switch)
from theano.scalar.basic import ( from theano.scalar.basic import (
true_div, inv, log, log2, log10, log1p, exp, exp2, expm1, sqrt, deg2rad, true_div, inv, log, log2, log10, log1p, exp, exp2, expm1, sqrt, deg2rad,
rad2deg, cos, arccos, sin, arcsin, tan, arctan, arctan2, cosh, arccosh, rad2deg, cos, arccos, sin, arcsin, tan, arctan, arctan2, cosh, arccosh,
...@@ -87,6 +87,22 @@ class test_composite(unittest.TestCase): ...@@ -87,6 +87,22 @@ class test_composite(unittest.TestCase):
nc = c.clone_float32() nc = c.clone_float32()
assert not has_f16(nc) assert not has_f16(nc)
v = uint8()
w = float16()
x = float16()
y = float16()
z = float16()
o = switch(v, mul(w, x, y), z)
cz = Composite([x, y], [tanh(x + cast(y, 'float16'))])
c = Composite([w, x, y], [cz(x, y) - cz(x, y)**2 +
cast(x, 'int16') + cast(x, 'float32') +
cast(w, 'float16') -
constant(np.float16(1.0))])
assert has_f16(c)
nc = c.clone_float32()
assert not has_f16(nc)
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div_proxy(x, y)) e = mul(add(x, y), div_proxy(x, y))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论