提交 cee1e02e authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Make sure scalar ops do not compute in float16

Ops that are defined with upcast_to_float should upcast to float32 or float64 minimally. Add tests for the cases where the inputs were int8, which is when float16 values appeared.
上级 5b79ef29
...@@ -2166,7 +2166,7 @@ neg = Neg(same_out, name='neg') ...@@ -2166,7 +2166,7 @@ neg = Neg(same_out, name='neg')
class Inv(UnaryScalarOp): class Inv(UnaryScalarOp):
""" multiplicative inverse. Also called reciprocal""" """ multiplicative inverse. Also called reciprocal"""
def impl(self, x):
    # Divide by a float32 constant rather than a Python float: with a
    # small-int x this yields a float32 result, matching the Op's
    # declared float output instead of silently upcasting to float64.
    return numpy.float32(1.0) / x
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
if x.type in complex_types: if x.type in complex_types:
...@@ -2190,6 +2190,11 @@ class Log(UnaryScalarOp): ...@@ -2190,6 +2190,11 @@ class Log(UnaryScalarOp):
amd_float64 = "amd_vrda_log" amd_float64 = "amd_vrda_log"
def impl(self, x):
    # numpy.log on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.log(x, sig='f')
    return numpy.log(x)
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
...@@ -2219,6 +2224,11 @@ class Log2(UnaryScalarOp): ...@@ -2219,6 +2224,11 @@ class Log2(UnaryScalarOp):
amd_float64 = "amd_vrda_log2" amd_float64 = "amd_vrda_log2"
def impl(self, x):
    # numpy.log2 on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.log2(x, sig='f')
    return numpy.log2(x)
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
...@@ -2245,6 +2255,11 @@ class Log10(UnaryScalarOp): ...@@ -2245,6 +2255,11 @@ class Log10(UnaryScalarOp):
amd_float64 = "amd_vrda_log10" amd_float64 = "amd_vrda_log10"
def impl(self, x):
    # numpy.log10 on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.log10(x, sig='f')
    return numpy.log10(x)
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
...@@ -2268,6 +2283,11 @@ log10 = Log10(upgrade_to_float, name='log10') ...@@ -2268,6 +2283,11 @@ log10 = Log10(upgrade_to_float, name='log10')
class Log1p(UnaryScalarOp): class Log1p(UnaryScalarOp):
""" log(1+x) """ """ log(1+x) """
def impl(self, x):
    # numpy.log1p on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.log1p(x, sig='f')
    return numpy.log1p(x)
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
...@@ -2293,6 +2313,11 @@ class Exp(UnaryScalarOp): ...@@ -2293,6 +2313,11 @@ class Exp(UnaryScalarOp):
amd_float64 = "amd_vrda_exp" amd_float64 = "amd_vrda_exp"
def impl(self, x):
    # numpy.exp on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.exp(x, sig='f')
    return numpy.exp(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -2315,6 +2340,11 @@ exp = Exp(upgrade_to_float, name='exp') ...@@ -2315,6 +2340,11 @@ exp = Exp(upgrade_to_float, name='exp')
class Exp2(UnaryScalarOp): class Exp2(UnaryScalarOp):
def impl(self, x):
    # numpy.exp2 on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.exp2(x, sig='f')
    return numpy.exp2(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -2337,6 +2367,11 @@ exp2 = Exp2(upgrade_to_float, name='exp2') ...@@ -2337,6 +2367,11 @@ exp2 = Exp2(upgrade_to_float, name='exp2')
class Expm1(UnaryScalarOp): class Expm1(UnaryScalarOp):
def impl(self, x):
    # numpy.expm1 on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.expm1(x, sig='f')
    return numpy.expm1(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -2382,6 +2417,11 @@ sqr = Sqr(same_out, name='sqr') ...@@ -2382,6 +2417,11 @@ sqr = Sqr(same_out, name='sqr')
class Sqrt(UnaryScalarOp): class Sqrt(UnaryScalarOp):
def impl(self, x):
    # numpy.sqrt on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.sqrt(x, sig='f')
    return numpy.sqrt(x)
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
...@@ -2404,6 +2444,11 @@ sqrt = Sqrt(upgrade_to_float, name='sqrt') ...@@ -2404,6 +2444,11 @@ sqrt = Sqrt(upgrade_to_float, name='sqrt')
class Deg2Rad(UnaryScalarOp): class Deg2Rad(UnaryScalarOp):
def impl(self, x):
    # numpy.deg2rad on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.deg2rad(x, sig='f')
    return numpy.deg2rad(x)
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
...@@ -2426,6 +2471,11 @@ deg2rad = Deg2Rad(upgrade_to_float, name='deg2rad') ...@@ -2426,6 +2471,11 @@ deg2rad = Deg2Rad(upgrade_to_float, name='deg2rad')
class Rad2Deg(UnaryScalarOp): class Rad2Deg(UnaryScalarOp):
def impl(self, x):
    # numpy.rad2deg on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.rad2deg(x, sig='f')
    return numpy.rad2deg(x)
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
...@@ -2451,6 +2501,11 @@ class Cos(UnaryScalarOp): ...@@ -2451,6 +2501,11 @@ class Cos(UnaryScalarOp):
amd_float64 = "amd_vrda_cos" amd_float64 = "amd_vrda_cos"
def impl(self, x):
    # numpy.cos on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.cos(x, sig='f')
    return numpy.cos(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -2473,6 +2528,11 @@ cos = Cos(upgrade_to_float, name='cos') ...@@ -2473,6 +2528,11 @@ cos = Cos(upgrade_to_float, name='cos')
class ArcCos(UnaryScalarOp): class ArcCos(UnaryScalarOp):
def impl(self, x):
    # numpy.arccos on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.arccos(x, sig='f')
    return numpy.arccos(x)
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
...@@ -2498,6 +2558,11 @@ class Sin(UnaryScalarOp): ...@@ -2498,6 +2558,11 @@ class Sin(UnaryScalarOp):
amd_float64 = "amd_vrda_sin" amd_float64 = "amd_vrda_sin"
def impl(self, x):
    # numpy.sin on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.sin(x, sig='f')
    return numpy.sin(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -2520,6 +2585,11 @@ sin = Sin(upgrade_to_float, name='sin') ...@@ -2520,6 +2585,11 @@ sin = Sin(upgrade_to_float, name='sin')
class ArcSin(UnaryScalarOp): class ArcSin(UnaryScalarOp):
def impl(self, x):
    # numpy.arcsin on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.arcsin(x, sig='f')
    return numpy.arcsin(x)
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
...@@ -2542,6 +2612,11 @@ arcsin = ArcSin(upgrade_to_float, name='arcsin') ...@@ -2542,6 +2612,11 @@ arcsin = ArcSin(upgrade_to_float, name='arcsin')
class Tan(UnaryScalarOp): class Tan(UnaryScalarOp):
def impl(self, x):
    # numpy.tan on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.tan(x, sig='f')
    return numpy.tan(x)
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
...@@ -2564,6 +2639,11 @@ tan = Tan(upgrade_to_float, name='tan') ...@@ -2564,6 +2639,11 @@ tan = Tan(upgrade_to_float, name='tan')
class ArcTan(UnaryScalarOp): class ArcTan(UnaryScalarOp):
def impl(self, x):
    # numpy.arctan on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.arctan(x, sig='f')
    return numpy.arctan(x)
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
...@@ -2586,6 +2666,13 @@ arctan = ArcTan(upgrade_to_float, name='arctan') ...@@ -2586,6 +2666,13 @@ arctan = ArcTan(upgrade_to_float, name='arctan')
class ArcTan2(BinaryScalarOp): class ArcTan2(BinaryScalarOp):
def impl(self, y, x):
    # numpy.arctan2 on two (u)int8 inputs computes in half precision
    # (float16); explicitly request the float32 loop in that case.
    #
    # Bug fix: the original read y's dtype with `getattr(x, ...)`, so the
    # float32 path was taken whenever x was (u)int8 regardless of y's
    # dtype — silently downcasting e.g. a float64 y to float32.
    x_dtype = str(getattr(x, 'dtype', ''))
    y_dtype = str(getattr(y, 'dtype', ''))
    if (x_dtype in ('int8', 'uint8') and
            y_dtype in ('int8', 'uint8')):
        return numpy.arctan2(y, x, sig='f')
    return numpy.arctan2(y, x)
def grad(self, (y, x), (gz,)): def grad(self, (y, x), (gz,)):
...@@ -2621,6 +2708,11 @@ class Cosh(UnaryScalarOp): ...@@ -2621,6 +2708,11 @@ class Cosh(UnaryScalarOp):
cosh(x) = (exp(x) + exp(-x)) / 2 cosh(x) = (exp(x) + exp(-x)) / 2
""" """
def impl(self, x):
    # numpy.cosh on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.cosh(x, sig='f')
    return numpy.cosh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -2643,6 +2735,11 @@ cosh = Cosh(upgrade_to_float, name='cosh') ...@@ -2643,6 +2735,11 @@ cosh = Cosh(upgrade_to_float, name='cosh')
class ArcCosh(UnaryScalarOp): class ArcCosh(UnaryScalarOp):
def impl(self, x):
    # numpy.arccosh on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.arccosh(x, sig='f')
    return numpy.arccosh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -2668,6 +2765,11 @@ class Sinh(UnaryScalarOp): ...@@ -2668,6 +2765,11 @@ class Sinh(UnaryScalarOp):
sinh(x) = (exp(x) - exp(-x)) / 2 sinh(x) = (exp(x) - exp(-x)) / 2
""" """
def impl(self, x):
    # numpy.sinh on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.sinh(x, sig='f')
    return numpy.sinh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -2690,6 +2792,11 @@ sinh = Sinh(upgrade_to_float, name='sinh') ...@@ -2690,6 +2792,11 @@ sinh = Sinh(upgrade_to_float, name='sinh')
class ArcSinh(UnaryScalarOp): class ArcSinh(UnaryScalarOp):
def impl(self, x):
    # numpy.arcsinh on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.arcsinh(x, sig='f')
    return numpy.arcsinh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -2716,6 +2823,11 @@ class Tanh(UnaryScalarOp): ...@@ -2716,6 +2823,11 @@ class Tanh(UnaryScalarOp):
= (exp(2*x) - 1) / (exp(2*x) + 1) = (exp(2*x) - 1) / (exp(2*x) + 1)
""" """
def impl(self, x):
    # numpy.tanh on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.tanh(x, sig='f')
    return numpy.tanh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -2738,6 +2850,11 @@ tanh = Tanh(upgrade_to_float, name='tanh') ...@@ -2738,6 +2850,11 @@ tanh = Tanh(upgrade_to_float, name='tanh')
class ArcTanh(UnaryScalarOp): class ArcTanh(UnaryScalarOp):
def impl(self, x):
    # numpy.arctanh on (u)int8 input selects the float16 loop; explicitly
    # request the float32 loop so no precision is lost.
    if str(getattr(x, 'dtype', '')) in ('int8', 'uint8'):
        return numpy.arctanh(x, sig='f')
    return numpy.arctanh(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
......
...@@ -10,6 +10,7 @@ If you do want to rewrite these tests, bear in mind: ...@@ -10,6 +10,7 @@ If you do want to rewrite these tests, bear in mind:
""" """
import unittest import unittest
import numpy as np
import theano import theano
from theano.gof import FunctionGraph from theano.gof import FunctionGraph
...@@ -20,8 +21,12 @@ from theano.scalar.basic import (floats, float32, float64, ...@@ -20,8 +21,12 @@ from theano.scalar.basic import (floats, float32, float64,
ints, int8, int32, complex64, ints, int8, int32, complex64,
ComplexError, IntDiv, TrueDiv, ComplexError, IntDiv, TrueDiv,
Composite, add, div_proxy, clip, Composite, add, div_proxy, clip,
and_, eq, neq, invert, mul) and_, eq, neq, invert, mul, Scalar)
import numpy from theano.scalar.basic import (
true_div, inv, log, log2, log10, log1p, exp, exp2, expm1, sqrt, deg2rad,
rad2deg, cos, arccos, sin, arcsin, tan, arctan, arctan2, cosh, arccosh,
sinh, arcsinh, tanh, arctanh)
def inputs(): def inputs():
return floats('xyz') return floats('xyz')
...@@ -75,7 +80,7 @@ class test_ScalarOps(unittest.TestCase): ...@@ -75,7 +80,7 @@ class test_ScalarOps(unittest.TestCase):
g3 = theano.gradient.grad(a3, x) g3 = theano.gradient.grad(a3, x)
fn3 = gof.DualLinker().accept(FunctionGraph([x], [g3])).make_function() fn3 = gof.DualLinker().accept(FunctionGraph([x], [g3])).make_function()
rng = numpy.random.RandomState(utt.fetch_seed()) rng = np.random.RandomState(utt.fetch_seed())
ntests = 50 ntests = 50
for i in xrange(ntests): for i in xrange(ntests):
...@@ -235,6 +240,124 @@ class test_logical(unittest.TestCase): ...@@ -235,6 +240,124 @@ class test_logical(unittest.TestCase):
self.assertTrue(fn(a,b) == ~a, (a,)) self.assertTrue(fn(a,b) == ~a, (a,))
#class test_upgrade_to_float(unittest.TestCase):
class test_upgrade_to_float(object):
    """Check Ops whose output must be floating point even for int inputs.

    In particular, when the inputs are int8, the output should be at
    least float32, not float16.
    """

    # (op, list of int8 input values restricted to the op's valid,
    # overflow-free domain).  Built with list(range(...)) so the class
    # also imports cleanly under Python 3 (range + range is py2-only).
    unary_ops_vals = [
        (inv, list(range(-127, 0)) + list(range(1, 127))),
        (sqrt, list(range(0, 128))),
        (log, list(range(1, 128))),
        (log2, list(range(1, 128))),
        (log10, list(range(1, 128))),
        (log1p, list(range(0, 128))),
        (exp, list(range(-127, 89))),
        (exp2, list(range(-127, 89))),
        (expm1, list(range(-127, 89))),
        (deg2rad, list(range(-127, 128))),
        (rad2deg, list(range(-127, 128))),
        (cos, list(range(-127, 128))),
        (arccos, list(range(-1, 2))),
        (cosh, list(range(-89, 90))),
        (arccosh, list(range(1, 128))),
        (sin, list(range(-127, 128))),
        (arcsin, list(range(-1, 2))),
        (sinh, list(range(-89, 90))),
        (arcsinh, list(range(-127, 128))),
        (tan, list(range(-3, 4))),
        (arctan, list(range(-127, 128))),
        (tanh, list(range(-127, 128))),
        (arctanh, [0])]

    binary_ops_vals = [
        (arctan2, list(range(-127, 128)), list(range(-127, 128)))]

    @staticmethod
    def _test_unary(unary_op, x_range):
        """Apply unary_op to int8 and float32 inputs over x_range.

        The int8 results must have the same dtype as the float32 ones
        (i.e. the op upcast to float32, not float16) and match them.
        """
        xi = int8('xi')
        xf = float32('xf')
        ei = unary_op(xi)
        fi = theano.function([xi], ei)
        ef = unary_op(xf)
        ff = theano.function([xf], ef)
        for x_val in x_range:
            outi = fi(x_val)
            outf = ff(x_val)
            assert outi.dtype == outf.dtype, 'incorrect dtype'
            assert np.allclose(outi, outf), 'insufficient precision'

    @staticmethod
    def _test_binary(binary_op, x_range, y_range):
        """Same as _test_unary, for a binary op over x_range x y_range."""
        xi = int8('xi')
        yi = int8('yi')
        xf = float32('xf')
        yf = float32('yf')
        ei = binary_op(xi, yi)
        fi = theano.function([xi, yi], ei)
        ef = binary_op(xf, yf)
        ff = theano.function([xf, yf], ef)
        for x_val in x_range:
            for y_val in y_range:
                outi = fi(x_val, y_val)
                outf = ff(x_val, y_val)
                assert outi.dtype == outf.dtype, 'incorrect dtype'
                assert np.allclose(outi, outf), 'insufficient precision'

    def test_true_div(self):
        # true_div's upcast policy is not exactly "upgrade_to_float"
        # (int8 / int8 goes to floatX), so compare against floatX
        # reference inputs instead of float32.
        x_range = list(range(-127, 128))
        y_range = list(range(-127, 0)) + list(range(1, 127))
        xi = int8('xi')
        yi = int8('yi')
        xf = Scalar(theano.config.floatX)('xf')
        yf = Scalar(theano.config.floatX)('yf')
        ei = true_div(xi, yi)
        fi = theano.function([xi, yi], ei)
        ef = true_div(xf, yf)
        ff = theano.function([xf, yf], ef)
        for x_val in x_range:
            for y_val in y_range:
                outi = fi(x_val, y_val)
                outf = ff(x_val, y_val)
                assert outi.dtype == outf.dtype, 'incorrect dtype'
                assert np.allclose(outi, outf), 'insufficient precision'

    def test_unary(self):
        # Nose-style test generator: yield one named test per unary op.
        for unary_op, x_range in self.unary_ops_vals:
            test_name = 'test_%s' % unary_op.name
            # Bug fix: bind the loop variables as lambda defaults.  A
            # bare `lambda: self._test_unary(unary_op, x_range)` would
            # late-bind and could make every yielded test run the *last*
            # op in the list.
            test = lambda op=unary_op, vals=x_range: (
                self._test_unary(op, vals))
            test.description = test_name
            yield test

    def test_binary(self):
        # Nose-style test generator: yield one named test per binary op.
        for binary_op, x_range, y_range in self.binary_ops_vals:
            test_name = 'test_%s' % binary_op.name
            # Bind loop variables as defaults to avoid late binding
            # (see test_unary above... same pattern, same reason).
            test = lambda op=binary_op, xs=x_range, ys=y_range: (
                self._test_binary(op, xs, ys))
            test.description = test_name
            yield test
class test_complex_mod(unittest.TestCase): class test_complex_mod(unittest.TestCase):
"""Make sure % fails on complex numbers.""" """Make sure % fails on complex numbers."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论