提交 7ad1eb6d authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Implement integer division in C code.

上级 0c1dcabd
...@@ -1305,15 +1305,69 @@ true_div = TrueDiv(upcast_out, name='true_div') ...@@ -1305,15 +1305,69 @@ true_div = TrueDiv(upcast_out, name='true_div')
class IntDiv(BinaryScalarOp): class IntDiv(BinaryScalarOp):
complex_error = ComplexError(
"Theano does not support integer division (//) on "
"complex numbers, since numpy deprecated it.")
def impl(self, x, y): def impl(self, x, y):
return x // y return x // y
def c_support_code(self):
# We use a macro as python use % as a special string character,
# and the output of c_code may be run through another level
# of string formatting.
return "#define THEANO_MACRO_MOD(x,y) (x % y)"
def c_code(self, node, name, (x, y), (z,), sub): def c_code(self, node, name, (x, y), (z,), sub):
raise NotImplementedError("For integer arguments the behavior of" t = node.inputs[0].type.upcast(*[i.type for i in node.inputs[1:]])
" division in C and in Python [can] differ" if t in imap(str, discrete_types):
" when the quotient is negative. C actually" x_div_y_pp = '(%(x)s / %(y)s)' % locals()
" does not even specify a correct behaviour" x_div_y_mp = '((-%(x)s) / %(y)s)' % locals()
" in this case, it is up to the chip.") x_mod_y_mp = 'THEANO_MACRO_MOD((-%(x)s), %(y)s)' % locals()
x_div_y_pm = '(%(x)s / (-%(y)s))' % locals()
x_mod_y_pm = 'THEANO_MACRO_MOD(%(x)s, (-%(y)s))' % locals()
x_div_y_mm = '((-%(x)s) / (-%(y)s))' % locals()
elif t in imap(str, float_types):
# We need to call different functions of math.h
# depending on the type
if t == 'float32':
floor = 'floorf'
fmod = 'fmodf'
elif t == 'float64':
floor = 'floor'
fmod = 'fmod'
else:
raise NotImplementedError('type not supported', t)
x_div_y_pp = '%(floor)s(%(x)s / %(y)s)' % locals()
x_div_y_mp = '%(floor)s((-%(x)s) / %(y)s)' % locals()
x_mod_y_mp = '%(fmod)s((-%(x)s), %(y)s)' % locals()
x_div_y_pm = '%(floor)s(%(x)s / (-%(y)s))' % locals()
x_mod_y_pm = '%(fmod)s(%(x)s, (-%(y)s))' % locals()
x_div_y_mm = '%(floor)s((-%(x)s) / (-%(y)s))' % locals()
elif t in complex_types:
raise self.complex_error
else:
raise NotImplementedError('type not supported', t)
return """
if (%(x)s < 0) {
if (%(y)s < 0) {
%(z)s = %(x_div_y_mm)s;
} else {
%(z)s = - %(x_div_y_mp)s - ((%(x_mod_y_mp)s == 0) ? 0 : 1);
}
} else {
if (%(y)s < 0) {
%(z)s = - %(x_div_y_pm)s - ((%(x_mod_y_pm)s == 0) ? 0 : 1);
} else {
%(z)s = %(x_div_y_mm)s;
}
}
""" % locals()
def c_code_cache_version(self):
return (1,)
def grad(self, inputs, g_output): def grad(self, inputs, g_output):
return [None] * len(inputs) return [None] * len(inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论