提交 9354d3b8 authored 作者: Frederic Bastien's avatar Frederic Bastien

reimplemented the c code for the Module operator. It now always behave like python

上级 2cf31a7f
...@@ -843,42 +843,40 @@ class Mod(BinaryScalarOp): ...@@ -843,42 +843,40 @@ class Mod(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x % y return x % y
def c_code_cache_version(self): def c_code_cache_version(self):
return (3,) return (4,)
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
""" """
We want the result to have the same sign as python, not the other implementaiton of mod. We want the result to have the same sign as python, not the other implementaiton of mod.
""" """
raise NotImplementedError()
#raise NotImplementedError("Unlike Python, C's modulo returns negative modulo on negative dividend (to implement)") #raise NotImplementedError("Unlike Python, C's modulo returns negative modulo on negative dividend (to implement)")
t = node.inputs[0].type.upcast(*[ i.type for i in node.inputs[1:]]) t = node.inputs[0].type.upcast(*[ i.type for i in node.inputs[1:]])
if t in int_types or t in ['uint8','int8','uint16','int16','uint32','int32','uint64','int64']: if t in int_types or t in ['uint8','int8','uint16','int16','uint32','int32','uint64','int64']:
x_mod_y = "%(x)s %% %(y)s"%locals() x_mod_y = "(%(x)s %% %(y)s)"%locals()
x_mod_ymm = "(-%(x)s %% -%(y)s)"%locals()
x_mod_ypm = "(%(x)s %% -%(y)s)"%locals()
x_mod_ymp = "(-%(x)s %% %(y)s)"%locals()
elif t in float_types or t in ['float32','float64']: elif t in float_types or t in ['float32','float64']:
x_mod_y = "fmod(%(x)s,%(y)s)"%locals() x_mod_y = "fmod(%(x)s,%(y)s)"%locals()
x_mod_ymm = "fmod(-%(x)s,-%(y)s)"%locals()
x_mod_ypm = "fmod(%(x)s,-%(y)s)"%locals()
x_mod_ymp = "fmod(-%(x)s,%(y)s)"%locals()
else: else:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
return """ return """
if (%(x)s == 0 || %(y)s == 0) { if (%(x)s < 0){
if (%(y)s == 0) %(z)s = %(x_mod_y)s; if (%(y)s < 0){
%(z)s = 0; %(z)s = -(%(x_mod_ymm)s);
} }else{
//was #if @neg@, I suspect @neg@ to be platform dependant. %(z)s = - %(x_mod_ymp)s + (%(x_mod_ymp)s != 0 ? %(y)s : 0);
//should be true under X86, but could be false for other architecture!
#if 1
else if ((%(x)s > 0) == (%(y)s > 0)) {
%(z)s = %(x_mod_y)s;
}
else { /* handled like Python does */
%(z)s = %(x_mod_y)s;
if (%(z)s) %(z)s += %(y)s;
} }
#else }else if (%(y)s < 0){
else %(z)s = (%(x_mod_ypm)s) + (%(x_mod_ypm)s != 0 ? %(y)s : 0);
}else{
%(z)s = %(x_mod_y)s; %(z)s = %(x_mod_y)s;
#endif }
"""%locals() """%locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return None, None return None, None
mod = Mod(upcast_out, name = 'mod') mod = Mod(upcast_out, name = 'mod')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论