提交 d543d5d2 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Update after code review:

- use KERNEL macro
- do not use `%(fail)s` on GPU to avoid returning prematurely from kernel
- have special block for y == 0 (and reorder other ones)
- keep calling // 0 or % 0 on GPU, even though cuda will not fail
上级 87be2075
...@@ -1928,14 +1928,16 @@ class IntDiv(BinaryScalarOp): ...@@ -1928,14 +1928,16 @@ class IntDiv(BinaryScalarOp):
t = node.inputs[0].type.upcast(*[i.type for i in node.inputs[1:]]) t = node.inputs[0].type.upcast(*[i.type for i in node.inputs[1:]])
if t in imap(str, discrete_types): if t in imap(str, discrete_types):
# If we are in a gpuarray kernel, %(fail)s exits the kernel,
# and we do not have any error report, and we cannot set
# Python error messages either, so for now we just call the
# cuda function, which returns a binary pattern of all 1s.
check = dedent(''' check = dedent('''
if (%(y)s == 0) { #ifndef KERNEL
// do not set Python error message in gpuarray kernel for now
#ifndef LID_0
PyErr_SetString(PyExc_ZeroDivisionError, "integer division by zero"); PyErr_SetString(PyExc_ZeroDivisionError, "integer division by zero");
#endif
%(fail)s %(fail)s
}''') % locals() #endif
''') % locals()
x_div_y_pp = '(%(x)s / %(y)s)' % locals() x_div_y_pp = '(%(x)s / %(y)s)' % locals()
x_div_y_mp = '((-%(x)s) / %(y)s)' % locals() x_div_y_mp = '((-%(x)s) / %(y)s)' % locals()
x_mod_y_mp = 'THEANO_MACRO_MOD((-%(x)s), %(y)s)' % locals() x_mod_y_mp = 'THEANO_MACRO_MOD((-%(x)s), %(y)s)' % locals()
...@@ -1967,16 +1969,18 @@ class IntDiv(BinaryScalarOp): ...@@ -1967,16 +1969,18 @@ class IntDiv(BinaryScalarOp):
raise NotImplementedError('type not supported', t) raise NotImplementedError('type not supported', t)
return dedent(""" return dedent("""
%(check)s if (%(y)s == 0) {
if (%(x)s < 0) { %(check)s
if (%(y)s < 0) { %(z)s = %(x_div_y_pp)s;
} else if (%(y)s < 0) {
if (%(x)s < 0) {
%(z)s = %(x_div_y_mm)s; %(z)s = %(x_div_y_mm)s;
} else { } else {
%(z)s = - %(x_div_y_mp)s - ((%(x_mod_y_mp)s == 0) ? 0 : 1); %(z)s = - %(x_div_y_pm)s - ((%(x_mod_y_pm)s == 0) ? 0 : 1);
} }
} else { } else {
if (%(y)s < 0) { if (%(x)s < 0) {
%(z)s = - %(x_div_y_pm)s - ((%(x_mod_y_pm)s == 0) ? 0 : 1); %(z)s = - %(x_div_y_mp)s - ((%(x_mod_y_mp)s == 0) ? 0 : 1);
} else { } else {
%(z)s = %(x_div_y_pp)s; %(z)s = %(x_div_y_pp)s;
} }
...@@ -1984,7 +1988,7 @@ class IntDiv(BinaryScalarOp): ...@@ -1984,7 +1988,7 @@ class IntDiv(BinaryScalarOp):
""") % locals() """) % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (4,) return (5,)
def grad(self, inputs, g_output): def grad(self, inputs, g_output):
return [inp.zeros_like(dtype=theano.config.floatX) return [inp.zeros_like(dtype=theano.config.floatX)
...@@ -2016,13 +2020,13 @@ class Mod(BinaryScalarOp): ...@@ -2016,13 +2020,13 @@ class Mod(BinaryScalarOp):
return x % y return x % y
def c_code_cache_version(self): def c_code_cache_version(self):
return (7,) return (8,)
def c_support_code(self): def c_support_code(self):
# We use a macro as python use % as a special string character, # We use a macro as python use % as a special string character,
# and the output of c_code may be run through another level # and the output of c_code may be run through another level
# of string formatting. # of string formatting.
return "#define THEANO_MACRO_MOD(x,y) (x % y)" return "#define THEANO_MACRO_MOD(x, y) (x % y)"
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
""" """
...@@ -2042,14 +2046,16 @@ class Mod(BinaryScalarOp): ...@@ -2042,14 +2046,16 @@ class Mod(BinaryScalarOp):
# keep them out of safety, and verify they are useless with an # keep them out of safety, and verify they are useless with an
# assert. # assert.
assert str(t) in imap(str, discrete_types) assert str(t) in imap(str, discrete_types)
# If we are in a gpuarray kernel, %(fail)s exits the kernel,
# and we do not have any error report, and we cannot set
# Python error messages either, so for now we just call the
# cuda function, returning a binary pattern depending on dtype
check = dedent(''' check = dedent('''
if (%(y)s == 0) { #ifndef KERNEL
// do not set Python error message in gpuarray kernel for now
#ifndef LID_0
PyErr_SetString(PyExc_ZeroDivisionError, "integer modulo by zero"); PyErr_SetString(PyExc_ZeroDivisionError, "integer modulo by zero");
#endif
%(fail)s %(fail)s
}''') % locals() #endif
''') % locals()
x_mod_y = "THEANO_MACRO_MOD(%(x)s, %(y)s)" % locals() x_mod_y = "THEANO_MACRO_MOD(%(x)s, %(y)s)" % locals()
x_mod_ymm = "THEANO_MACRO_MOD(-%(x)s, -%(y)s)" % locals() x_mod_ymm = "THEANO_MACRO_MOD(-%(x)s, -%(y)s)" % locals()
x_mod_ypm = "THEANO_MACRO_MOD(%(x)s, -%(y)s)" % locals() x_mod_ypm = "THEANO_MACRO_MOD(%(x)s, -%(y)s)" % locals()
...@@ -2062,27 +2068,31 @@ class Mod(BinaryScalarOp): ...@@ -2062,27 +2068,31 @@ class Mod(BinaryScalarOp):
# assert. # assert.
assert str(t) in imap(str, float_types) assert str(t) in imap(str, float_types)
check = '' check = ''
x_mod_y = "fmod(%(x)s,%(y)s)" % locals() x_mod_y = "fmod(%(x)s, %(y)s)" % locals()
x_mod_ymm = "fmod(-%(x)s,-%(y)s)" % locals() x_mod_ymm = "fmod(-%(x)s, -%(y)s)" % locals()
x_mod_ypm = "fmod(%(x)s,-%(y)s)" % locals() x_mod_ypm = "fmod(%(x)s, -%(y)s)" % locals()
x_mod_ymp = "fmod(-%(x)s,%(y)s)" % locals() x_mod_ymp = "fmod(-%(x)s, %(y)s)" % locals()
elif str(t) in imap(str, complex_types): elif str(t) in imap(str, complex_types):
raise self.complex_error raise self.complex_error
else: else:
raise NotImplementedError('type not supported', t) raise NotImplementedError('type not supported', t)
return dedent(""" return dedent("""
%(check)s if (%(y)s == 0) {
if (%(x)s < 0){ %(check)s
if (%(y)s < 0){ %(z)s = %(x_mod_y)s;
%(z)s = -(%(x_mod_ymm)s); } else if (%(y)s < 0){
}else{ if (%(x)s < 0){
%(z)s = - %(x_mod_ymp)s + (%(x_mod_ymp)s != 0 ? %(y)s : 0); %(z)s = -(%(x_mod_ymm)s);
} } else {
}else if (%(y)s < 0){ %(z)s = (%(x_mod_ypm)s) + (%(x_mod_ypm)s != 0 ? %(y)s : 0);
%(z)s = (%(x_mod_ypm)s) + (%(x_mod_ypm)s != 0 ? %(y)s : 0); }
}else{ } else {
%(z)s = %(x_mod_y)s; if (%(x)s < 0){
%(z)s = - %(x_mod_ymp)s + (%(x_mod_ymp)s != 0 ? %(y)s : 0);
} else {
%(z)s = %(x_mod_y)s;
}
} }
""") % locals() """) % locals()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论