提交 28e4c452 authored 作者: John Salvatier's avatar John Salvatier

fixed scalar bugs

上级 6b3cbbd5
...@@ -76,11 +76,11 @@ class GammaLn(UnaryScalarOp): ...@@ -76,11 +76,11 @@ class GammaLn(UnaryScalarOp):
def grad(self, inp, grads): def grad(self, inp, grads):
x, = inp x, = inp
gz, = grads gz, = grads
return [gz * scalar_psi(x)] return [gz * psi(x)]
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, = inp x, = inp
z, = out z, = out
if node.inputs[0].type in [scalar.float32, scalar.float64]: if node.inputs[0].type in float_types:
return """%(z)s = return """%(z)s =
lgamma(%(x)s);""" % locals() lgamma(%(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented') raise NotImplementedError('only floatingpoint is implemented')
...@@ -146,7 +146,7 @@ double _psi(double x){ ...@@ -146,7 +146,7 @@ double _psi(double x){
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, = inp x, = inp
z, = out z, = out
if node.inputs[0].type in [scalar.float32, scalar.float64]: if node.inputs[0].type in float_types:
return """%(z)s = return """%(z)s =
_psi(%(x)s);""" % locals() _psi(%(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented') raise NotImplementedError('only floatingpoint is implemented')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论