提交 88f7f401 authored 作者: James Bergstra's avatar James Bergstra

revised boundary conditions for c impls of sigmoid and softplus

上级 19580d30
...@@ -32,16 +32,18 @@ class ScalarSigmoid(scalar.UnaryScalarOp): ...@@ -32,16 +32,18 @@ class ScalarSigmoid(scalar.UnaryScalarOp):
y = scalar_sigmoid(x) y = scalar_sigmoid(x)
return [gz * y * (1.0 - y)] return [gz * y * (1.0 - y)]
def c_code(self, node, name, (x,), (z,), sub): def c_code(self, node, name, (x,), (z,), sub):
if node.inputs[0].type in [scalar.float32, scalar.float64]: if node.inputs[0].type == scalar.float32:
return """%(z)s = # These constants were obtained by looking at the output of python commands like:
%(x)s < -30.0 # for i in xrange(750):
? 0.0 # print i, repr( numpy.asarray(1.0, dtype=dt) / (numpy.asarray(1.0, dtype=dt) + numpy.exp(-numpy.asarray([i,-i], dtype=dt))))
: %(x)s > 30.0 # the boundary checks prevent us from generating inf
? 1.0 return """%(z)s = %(x)s < -88.0f ? 0.0 : %(x)s > 15.0f ? 1.0f : 1.0f /(1.0f + exp(-%(x)s));""" % locals()
: 1.0 /(1.0+exp(-%(x)s));""" % locals() elif node.inputs[0].type == scalar.float64:
return """%(z)s = %(x)s < -709.0 ? 0.0 : %(x)s > 19.0 ? 1.0 : 1.0 /(1.0+exp(-%(x)s));""" % locals()
else:
raise NotImplementedError('only floatingpoint is implemented') raise NotImplementedError('only floatingpoint is implemented')
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
scalar_sigmoid = ScalarSigmoid(scalar.upgrade_to_float, name='scalar_sigmoid') scalar_sigmoid = ScalarSigmoid(scalar.upgrade_to_float, name='scalar_sigmoid')
sigmoid = elemwise.Elemwise(scalar_sigmoid, name='sigmoid') sigmoid = elemwise.Elemwise(scalar_sigmoid, name='sigmoid')
...@@ -61,16 +63,18 @@ class ScalarSoftplus(scalar.UnaryScalarOp): ...@@ -61,16 +63,18 @@ class ScalarSoftplus(scalar.UnaryScalarOp):
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
return [gz * scalar_sigmoid(x)] return [gz * scalar_sigmoid(x)]
def c_code(self, node, name, (x,), (z,), sub): def c_code(self, node, name, (x,), (z,), sub):
if node.inputs[0].type in [scalar.float32, scalar.float64]: if node.inputs[0].type == scalar.float32:
return """%(z)s = # These constants were obtained by looking at the output of python commands like:
%(x)s < -30.0 # for i in xrange(750):
? 0.0 # print i, repr( numpy.log1p(numpy.exp(numpy.asarray([i,-i], dtype=dt))))
: %(x)s > 30.0 # the boundary checks prevent us from generating inf
? %(x)s return """%(z)s = %(x)s < -103.0f ? 0.0 : %(x)s > 14.0f ? %(x)s : log1p(exp(%(x)s));""" % locals()
: log1p(exp(%(x)s));""" % locals() elif node.inputs[0].type == scalar.float64:
raise NotImplementedError('only floating point x is implemented') return """%(z)s = %(x)s < -745.0 ? 0.0 : %(x)s > 16.0 ? %(x)s : log1p(exp(%(x)s));""" % locals()
else:
raise NotImplementedError('only floatingpoint is implemented')
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
scalar_softplus = ScalarSoftplus(scalar.upgrade_to_float, name='scalar_softplus') scalar_softplus = ScalarSoftplus(scalar.upgrade_to_float, name='scalar_softplus')
softplus = elemwise.Elemwise(scalar_softplus, name='softplus') softplus = elemwise.Elemwise(scalar_softplus, name='softplus')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论