提交 c69746ed authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix Erf, Erfc, GammaLn

上级 f0ebe1f8
......@@ -52,7 +52,8 @@ class Erf(UnaryScalarOp):
z, = out
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = erf(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = erf((%(cast)s)%(x)s);" % locals()
erf = Erf(upgrade_to_float, name='erf')
......@@ -83,7 +84,8 @@ class Erfc(UnaryScalarOp):
z, = out
if node.inputs[0].type in complex_types:
raise NotImplementedError('type not supported', type)
return "%(z)s = erfc(%(x)s);" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return "%(z)s = erfc((%(cast)s)%(x)s);" % locals()
# scipy.special.erfc don't support complex. Why?
erfc = Erfc(upgrade_to_float_no_complex, name='erfc')
......@@ -275,11 +277,8 @@ class GammaLn(UnaryScalarOp):
# For some reason, on the GPU, uint64 inputs don't get casted
# automatically to float64. This make the compilation crash
dtype = ""
if node.outputs[0].dtype == 'float64':
dtype = "(double)"
elif node.outputs[0].dtype == 'float32':
dtype = "(float)"
return """%(z)s = lgamma(%(dtype)s%(x)s);""" % locals()
cast = node.outputs[0].type.dtype_specs()[1]
return """%(z)s = lgamma((%(cast)s)%(x)s);""" % locals()
gammaln = GammaLn(upgrade_to_float, name='gammaln')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论