提交 d133bf69 authored 作者: notoraptor's avatar notoraptor

Add C code to RoundHalfToEven op.

上级 9b0f65a8
......@@ -347,6 +347,8 @@ class Scalar(Type):
# we declare them here and they will be re-used by TensorType
l.append('<numpy/arrayobject.h>')
l.append('<numpy/arrayscalars.h>')
# npy_math header contains npy_rint() and npy_rintf() declarations for rounding operations.
l.append('<numpy/npy_math.h>')
if config.lib.amdlibm and c_compiler.supports_amdlibm:
l += ['<amdlibm.h>']
return l
......@@ -2506,58 +2508,14 @@ class RoundHalfToEven(UnaryScalarOp):
return [rval]
def c_code___(self, node, name, inputs, outputs, sub):
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
(z,) = outputs
typ = node.outputs[0].type.dtype
if typ not in ['float32', 'float64']:
raise NotImplementedError("The output should be float32 or float64")
return dedent("""
#ifndef ROUNDING_EPSILON
#define ROUNDING_EPSILON 0.0000001
#endif
if (%(x)s < 0.0){
// We implement the else part like that: -else( -%(x)s);
%(typ)s i;
std::modf( -%(x)s, &i );
// If %(x)s is exactly halfway between two integers
if ((-%(x)s -(i +0.5)) < epsilon){
// If 'i' is even then return 'i'
if (std::fmod( i, 2.0 ) < epsilon){
%(z)s = - i;
}else{
// Else return the nearest even integer
%(z)s = - ceil( i +0.5 );
}
}else{
// round to closest
%(z)s = - round(%(x)s+5);
}
}else{
%(typ)s i;
std::modf( %(x)s, &i );
// If %(x)s is exactly halfway between two integers
if ((%(x)s -(i +0.5)) < epsilon){
// If 'i' is even then return 'i'
if (std::fmod( i, 2.0 ) < epsilon){
%(z)s = i;
}else{
// Else return the nearest even integer
%(z)s = ceil( i +0.5 );
}
}else{
// round to closest
%(z)s = round(%(x)s+5);
}
}
#undef ROUNDING_EPSILON
""" % locals())
round_function = 'npy_rint' if typ == 'float64' else 'npy_rintf'
return "%(z)s = %(round_function)s(%(x)s);" % locals()
round_half_to_even = RoundHalfToEven(same_out_float_only)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论