提交 0e82b608 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix output dtype of Ops in tensor.nnet

上级 d75c09d6
......@@ -31,6 +31,11 @@ class ScalarSigmoid(scalar.UnaryScalarOp):
return 0.0
if x > 30.0:
return 1.0
# If x is an int8 or uint8, numpy.exp will compute the result in
# half-precision (float16), where we want float32.
x_dtype = str(getattr(x, 'dtype', ''))
if x_dtype in ('int8', 'uint8'):
return 1.0 / (1.0 + numpy.exp(-x, sig='f'))
return 1.0 / (1.0 + numpy.exp(-x))
def impl(self, x):
......@@ -268,8 +273,11 @@ def hard_sigmoid(x):
Removing the slope and shift does not make it faster.
"""
slope = 0.2
shift = 0.5
# Use the same dtype as determined by "upgrade_to_float",
# and perform computation in that dtype.
out_dtype = scalar.upgrade_to_float(scalar.Scalar(dtype=x.dtype))[0].dtype
slope = tensor.constant(0.2, dtype=out_dtype)
shift = tensor.constant(0.5, dtype=out_dtype)
x = (x * slope) + shift
x = tensor.clip(x, 0, 1)
return x
......@@ -300,6 +308,11 @@ class ScalarSoftplus(scalar.UnaryScalarOp):
return 0.0
if x > 30.0:
return x
# If x is an int8 or uint8, numpy.exp will compute the result in
# half-precision (float16), where we want float32.
x_dtype = str(getattr(x, 'dtype', ''))
if x_dtype in ('int8', 'uint8'):
return numpy.log1p(numpy.exp(x, sig='f'))
return numpy.log1p(numpy.exp(x))
def impl(self, x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论