提交 68b11b4f authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

Fix bug from gammaln and psi function

上级 23886e9c
...@@ -81,7 +81,7 @@ class GammaLn(UnaryScalarOp): ...@@ -81,7 +81,7 @@ class GammaLn(UnaryScalarOp):
""" """
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
return special.gammaln(x) return scipy.special.gammaln(x)
def impl(self, x): def impl(self, x):
return GammaLn.st_impl(x) return GammaLn.st_impl(x)
...@@ -97,7 +97,7 @@ class GammaLn(UnaryScalarOp): ...@@ -97,7 +97,7 @@ class GammaLn(UnaryScalarOp):
if node.inputs[0].type in float_types: if node.inputs[0].type in float_types:
return """%(z)s = return """%(z)s =
lgamma(%(x)s);""" % locals() lgamma(%(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented') raise NotImplementedError('only floating point is implemented')
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
...@@ -113,12 +113,14 @@ class Psi(UnaryScalarOp): ...@@ -113,12 +113,14 @@ class Psi(UnaryScalarOp):
""" """
@staticmethod @staticmethod
def st_impl(x): def st_impl(x):
return special.psi(x) return scipy.special.psi(x)
def impl(self, x): def impl(self, x):
return Psi.st_impl(x) return Psi.st_impl(x)
#def grad() no gradient now def grad(self, inputs, outputs_gradients):
raise NotImplementedError()
return [None]
def c_support_code(self): def c_support_code(self):
return ( return (
...@@ -167,7 +169,7 @@ double _psi(double x){ ...@@ -167,7 +169,7 @@ double _psi(double x){
if node.inputs[0].type in float_types: if node.inputs[0].type in float_types:
return """%(z)s = return """%(z)s =
_psi(%(x)s);""" % locals() _psi(%(x)s);""" % locals()
raise NotImplementedError('only floatingpoint is implemented') raise NotImplementedError('only floating point is implemented')
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
......
...@@ -202,11 +202,11 @@ def erf_inplace(a): ...@@ -202,11 +202,11 @@ def erf_inplace(a):
@_scal_inplace @_scal_inplace
def erfc_inplace(a): def erfc_inplace(a):
"""complementary error function""" """complementary error function"""
@_scal_inplace @_scal_inplace
def gammaln_inplace(a): def gammaln_inplace(a):
"""log gamma function""" """log gamma function"""
@_scal_inplace @_scal_inplace
def psi_inplace(a): def psi_inplace(a):
"""derivative of log gamma function""" """derivative of log gamma function"""
......
...@@ -1282,37 +1282,45 @@ ErfcInplaceTester = makeBroadcastTester( ...@@ -1282,37 +1282,45 @@ ErfcInplaceTester = makeBroadcastTester(
inplace=True, inplace=True,
skip=skip_scipy) skip=skip_scipy)
_good_broadcast_unary_gammaln = dict(
normal=(rand_ranged(-1 + 1e-2, 10, (2, 3)),),
empty=(numpy.asarray([]),),)
_grad_broadcast_unary_gammaln = dict(
normal=(rand_ranged(1e-8, 10, (2, 3)),),)
GammaLnTester = makeBroadcastTester( GammaLnTester = makeBroadcastTester(
op=tensor.gammaln, op=tensor.gammaln,
expected=expected_gammaln, expected=expected_gammaln,
good=_good_broadcast_unary_normal_no_int_no_complex, good=_good_broadcast_unary_gammaln,
grad=_grad_broadcast_unary_normal, grad=_grad_broadcast_unary_gammaln,
eps=2e-10, eps=2e-10,
mode=mode_no_scipy, mode=mode_no_scipy,
skip=skip_scipy) skip=skip_scipy)
GammaLnInplaceTester = makeBroadcastTester( GammaLnInplaceTester = makeBroadcastTester(
op=inplace.gammaln_inplace, op=inplace.gammaln_inplace,
expected=expected_erfc, expected=expected_gammaln,
good=_good_broadcast_unary_normal_no_int_no_complex, good=_good_broadcast_unary_gammaln,
grad=_grad_broadcast_unary_normal, grad=_grad_broadcast_unary_gammaln,
eps=2e-10, eps=2e-10,
mode=mode_no_scipy, mode=mode_no_scipy,
inplace=True, inplace=True,
skip=skip_scipy) skip=skip_scipy)
_good_broadcast_unary_psi = dict(
normal=(rand_ranged(1, 10, (2, 3)),),
empty=(numpy.asarray([]),),)
PsiTester = makeBroadcastTester( PsiTester = makeBroadcastTester(
op=tensor.psi, op=tensor.psi,
expected=expected_psi, expected=expected_psi,
good=_good_broadcast_unary_normal_no_int_no_complex, good=_good_broadcast_unary_psi,
grad=_grad_broadcast_unary_normal,
eps=2e-10, eps=2e-10,
mode=mode_no_scipy, mode=mode_no_scipy,
skip=skip_scipy) skip=skip_scipy)
PsiInplaceTester = makeBroadcastTester( PsiInplaceTester = makeBroadcastTester(
op=inplace.psi_inplace, op=inplace.psi_inplace,
expected=expected_psi, expected=expected_psi,
good=_good_broadcast_unary_normal_no_int_no_complex, good=_good_broadcast_unary_psi,
grad=_grad_broadcast_unary_normal,
eps=2e-10, eps=2e-10,
mode=mode_no_scipy, mode=mode_no_scipy,
inplace=True, inplace=True,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论