提交 116f70fe authored 作者: Frederic's avatar Frederic

pep8

上级 025d484e
import numpy

import theano
import theano.tensor
class ScalarSoftsign(theano.scalar.UnaryScalarOp):
    """Scalar softsign Op: ``x / (1 + |x|)``.

    A smooth, bounded squashing function in (-1, 1).  This scalar Op is
    meant to be wrapped in ``theano.tensor.Elemwise`` to obtain the
    tensor-level ``softsign`` operation.
    """

    @staticmethod
    def static_impl(x):
        # Reference Python implementation, shared by impl() so the
        # formula lives in exactly one place.
        return x / (1.0 + abs(x))

    def impl(self, x):
        # Scalar Op protocol: evaluate on a single Python scalar.
        return ScalarSoftsign.static_impl(x)

    def grad(self, inp, grads):
        """Return the gradient: d/dx [x / (1+|x|)] = 1 / (1+|x|)^2."""
        x, = inp
        gz, = grads
        # NOTE(review): this condition line was lost in the diff
        # rendering of the source; reconstructed to mirror the
        # float32/float64 check used in c_code() below — confirm
        # against upstream Theano.
        if x.type in [theano.scalar.float32, theano.scalar.float64]:
            d = 1.0 + abs(x)
            return [gz / (d * d)]
        else:
            # Gradient is only defined for floating-point inputs.
            return NotImplemented

    def c_code(self, node, name, inp, out, sub):
        """Emit the C implementation; floats only (uses C fabs)."""
        x, = inp
        z, = out
        if node.inputs[0].type in [theano.scalar.float32,
                                   theano.scalar.float64]:
            return "%(z)s = %(x)s / (1.0+fabs(%(x)s));" % locals()
        raise NotImplementedError('only floating point x is implemented')
# Single scalar Op instance; upgrade_to_float promotes integer input
# dtypes to the corresponding float dtype before the Op is applied.
scalar_softsign = ScalarSoftsign(theano.scalar.upgrade_to_float,
                                 name='scalar_softsign')
# Tensor-level elementwise softsign built from the scalar Op.
softsign = theano.tensor.Elemwise(scalar_softsign, name='softsign')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论