提交 c1d3a0d7 authored 作者: Frederic Bastien's avatar Frederic Bastien

Enable gammaln c code (also for GPU support) for [u]int*

上级 bea31470
...@@ -269,10 +269,13 @@ class GammaLn(UnaryScalarOp): ...@@ -269,10 +269,13 @@ class GammaLn(UnaryScalarOp):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x, = inp x, = inp
z, = out z, = out
if node.inputs[0].type in float_types: # no c code for complex
return """%(z)s = # [u]int* will be casted to float64 before computation
lgamma(%(x)s);""" % locals() if x.type in complex_types:
raise NotImplementedError('only floating point is implemented') raise NotImplementedError(
'gammaln complex c code is not implemented')
return """%(z)s =
lgamma(%(x)s);""" % locals()
gammaln = GammaLn(upgrade_to_float, name='gammaln') gammaln = GammaLn(upgrade_to_float, name='gammaln')
......
...@@ -1807,7 +1807,8 @@ _good_broadcast_unary_gammaln = dict( ...@@ -1807,7 +1807,8 @@ _good_broadcast_unary_gammaln = dict(
empty=(np.asarray([], dtype=config.floatX),), empty=(np.asarray([], dtype=config.floatX),),
int=(randint_ranged(1, 10, (2, 3)),), int=(randint_ranged(1, 10, (2, 3)),),
uint8=(randint_ranged(1, 6, (2, 3)).astype('uint8'),), uint8=(randint_ranged(1, 6, (2, 3)).astype('uint8'),),
uint16=(randint_ranged(1, 10, (2, 3)).astype('uint16'),)) uint16=(randint_ranged(1, 10, (2, 3)).astype('uint16'),),
uint64=(randint_ranged(1, 10, (2, 3)).astype('uint64'),))
_grad_broadcast_unary_gammaln = dict( _grad_broadcast_unary_gammaln = dict(
# smaller range as our grad method does not estimate it well enough. # smaller range as our grad method does not estimate it well enough.
normal=(rand_ranged(1e-1, 8, (2, 3)),),) normal=(rand_ranged(1e-1, 8, (2, 3)),),)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论