added type checks to scalar.Abs.c_impl

上级 aeca4a25
......@@ -308,7 +308,12 @@ class Abs(UnaryScalarOp):
def grad(self, (x, ), (gz, )):
return gz * sgn(x),
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = abs(%(x)s);" % locals()
dtype = self.inputs[0].dtype
if dtype in ('int32', 'int64'):
return "%(z)s = abs(%(x)s);" % locals()
if dtype in ('float32', 'float64'):
return "%(z)s = fabs(%(x)s);" % locals()
raise NotImplementedError('Abs not implemented for dtype', dtype)
class Sgn(UnaryScalarOp):
def impl(self, x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论