提交 f33a3315 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5441 from notoraptor/round-half-to-even-c-code

Add C code to RoundHalfToEven op.
...@@ -2506,58 +2506,42 @@ class RoundHalfToEven(UnaryScalarOp): ...@@ -2506,58 +2506,42 @@ class RoundHalfToEven(UnaryScalarOp):
return [rval] return [rval]
def c_code___(self, node, name, inputs, outputs, sub): def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs (x,) = inputs
(z,) = outputs (z,) = outputs
typ = node.outputs[0].type.dtype typ = node.outputs[0].type.dtype
if typ not in ['float32', 'float64']: if typ not in ['float32', 'float64']:
raise NotImplementedError("The output should be float32 or float64") raise NotImplementedError("The output should be float32 or float64")
if typ == 'float32':
return dedent(""" ctype = 'float'
#ifndef ROUNDING_EPSILON floor_function = 'floorf'
#define ROUNDING_EPSILON 0.0000001 else:
#endif ctype = 'double'
floor_function = 'floor'
if (%(x)s < 0.0){ return """
// We implement the else part like that: -else( -%(x)s); /* Code inspired from NumPy npy_rint implementation. */
%(typ)s i; {
std::modf( -%(x)s, &i ); %(ctype)s y, r;
y = %(floor_function)s(%(x)s);
// If %(x)s is exactly halfway between two integers r = %(x)s - y;
if ((-%(x)s -(i +0.5)) < epsilon){ if(r > 0.5) {
// If 'i' is even then return 'i' y += 1;
if (std::fmod( i, 2.0 ) < epsilon){ } else if(r == 0.5) {
%(z)s = - i; r = y - 2.0*%(floor_function)s(0.5*y);
}else{ /*
// Else return the nearest even integer If y is even, then r == 0
%(z)s = - ceil( i +0.5 ); If y is odd, then r == 1
} So we can just add r to y, so that
}else{ y will be incremented only if he's odd.
// round to closest */
%(z)s = - round(%(x)s+5); y += (int)r;
}
}else{
%(typ)s i;
std::modf( %(x)s, &i );
// If %(x)s is exactly halfway between two integers
if ((%(x)s -(i +0.5)) < epsilon){
// If 'i' is even then return 'i'
if (std::fmod( i, 2.0 ) < epsilon){
%(z)s = i;
}else{
// Else return the nearest even integer
%(z)s = ceil( i +0.5 );
}
}else{
// round to closest
%(z)s = round(%(x)s+5);
}
} }
%(z)s = y;
#undef ROUNDING_EPSILON }
""" % locals()
""" % locals())
round_half_to_even = RoundHalfToEven(same_out_float_only) round_half_to_even = RoundHalfToEven(same_out_float_only)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论