提交 a1ec4942 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Remove goto in OpenMP Elemwise fail code

上级 70340839
...@@ -1928,22 +1928,24 @@ class IntDiv(BinaryScalarOp): ...@@ -1928,22 +1928,24 @@ class IntDiv(BinaryScalarOp):
t = node.inputs[0].type.upcast(*[i.type for i in node.inputs[1:]]) t = node.inputs[0].type.upcast(*[i.type for i in node.inputs[1:]])
if t in imap(str, discrete_types): if t in imap(str, discrete_types):
x_div_y_pp = '(%(x)s / %(y)s)' % locals()
x_div_y_mp = '((-%(x)s) / %(y)s)' % locals()
x_mod_y_mp = 'THEANO_MACRO_MOD((-%(x)s), %(y)s)' % locals()
x_div_y_pm = '(%(x)s / (-%(y)s))' % locals()
x_mod_y_pm = 'THEANO_MACRO_MOD(%(x)s, (-%(y)s))' % locals()
x_div_y_mm = '((-%(x)s) / (-%(y)s))' % locals()
# If we are in a gpuarray kernel, %(fail)s exits the kernel, # If we are in a gpuarray kernel, %(fail)s exits the kernel,
# and we do not have any error report, and we cannot set # and we do not have any error report, and we cannot set
# Python error messages either, so for now we just call the # Python error messages either, so for now we just call the
# cuda function, which return a binary pattern of all 1s. # cuda function, which return a binary pattern of all 1s.
check = dedent(''' div_zero = dedent('''
#ifndef KERNEL #ifdef KERNEL
%(z)s = %(x_div_y_pp)s;
#else
PyErr_SetString(PyExc_ZeroDivisionError, "integer division by zero"); PyErr_SetString(PyExc_ZeroDivisionError, "integer division by zero");
%(fail)s %(fail)s
#endif #endif
''') % locals() ''') % locals()
x_div_y_pp = '(%(x)s / %(y)s)' % locals()
x_div_y_mp = '((-%(x)s) / %(y)s)' % locals()
x_mod_y_mp = 'THEANO_MACRO_MOD((-%(x)s), %(y)s)' % locals()
x_div_y_pm = '(%(x)s / (-%(y)s))' % locals()
x_mod_y_pm = 'THEANO_MACRO_MOD(%(x)s, (-%(y)s))' % locals()
x_div_y_mm = '((-%(x)s) / (-%(y)s))' % locals()
elif t in imap(str, float_types): elif t in imap(str, float_types):
# We need to call different functions of math.h # We need to call different functions of math.h
# depending on the type # depending on the type
...@@ -1956,13 +1958,13 @@ class IntDiv(BinaryScalarOp): ...@@ -1956,13 +1958,13 @@ class IntDiv(BinaryScalarOp):
else: else:
raise NotImplementedError('type not supported', t) raise NotImplementedError('type not supported', t)
check = ''
x_div_y_pp = '%(floor)s(%(x)s / %(y)s)' % locals() x_div_y_pp = '%(floor)s(%(x)s / %(y)s)' % locals()
x_div_y_mp = '%(floor)s((-%(x)s) / %(y)s)' % locals() x_div_y_mp = '%(floor)s((-%(x)s) / %(y)s)' % locals()
x_mod_y_mp = '%(fmod)s((-%(x)s), %(y)s)' % locals() x_mod_y_mp = '%(fmod)s((-%(x)s), %(y)s)' % locals()
x_div_y_pm = '%(floor)s(%(x)s / (-%(y)s))' % locals() x_div_y_pm = '%(floor)s(%(x)s / (-%(y)s))' % locals()
x_mod_y_pm = '%(fmod)s(%(x)s, (-%(y)s))' % locals() x_mod_y_pm = '%(fmod)s(%(x)s, (-%(y)s))' % locals()
x_div_y_mm = '%(floor)s((-%(x)s) / (-%(y)s))' % locals() x_div_y_mm = '%(floor)s((-%(x)s) / (-%(y)s))' % locals()
div_zero = '%(z)s = %(x_div_y_pp)s;' % locals()
elif t in complex_types: elif t in complex_types:
raise self.complex_error raise self.complex_error
else: else:
...@@ -1970,8 +1972,7 @@ class IntDiv(BinaryScalarOp): ...@@ -1970,8 +1972,7 @@ class IntDiv(BinaryScalarOp):
return dedent(""" return dedent("""
if (%(y)s == 0) { if (%(y)s == 0) {
%(check)s %(div_zero)s;
%(z)s = %(x_div_y_pp)s;
} else if (%(y)s < 0) { } else if (%(y)s < 0) {
if (%(x)s < 0) { if (%(x)s < 0) {
%(z)s = %(x_div_y_mm)s; %(z)s = %(x_div_y_mm)s;
...@@ -1988,7 +1989,7 @@ class IntDiv(BinaryScalarOp): ...@@ -1988,7 +1989,7 @@ class IntDiv(BinaryScalarOp):
""") % locals() """) % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (5,) return (6,)
def grad(self, inputs, g_output): def grad(self, inputs, g_output):
return [inp.zeros_like(dtype=theano.config.floatX) return [inp.zeros_like(dtype=theano.config.floatX)
...@@ -2020,7 +2021,7 @@ class Mod(BinaryScalarOp): ...@@ -2020,7 +2021,7 @@ class Mod(BinaryScalarOp):
return x % y return x % y
def c_code_cache_version(self): def c_code_cache_version(self):
return (8,) return (9,)
def c_support_code(self): def c_support_code(self):
# We use a macro as python use % as a special string character, # We use a macro as python use % as a special string character,
...@@ -2046,20 +2047,22 @@ class Mod(BinaryScalarOp): ...@@ -2046,20 +2047,22 @@ class Mod(BinaryScalarOp):
# keep them out of safety, and verify they are useless with an # keep them out of safety, and verify they are useless with an
# assert. # assert.
assert str(t) in imap(str, discrete_types) assert str(t) in imap(str, discrete_types)
x_mod_y = "THEANO_MACRO_MOD(%(x)s, %(y)s)" % locals()
x_mod_ymm = "THEANO_MACRO_MOD(-%(x)s, -%(y)s)" % locals()
x_mod_ypm = "THEANO_MACRO_MOD(%(x)s, -%(y)s)" % locals()
x_mod_ymp = "THEANO_MACRO_MOD(-%(x)s, %(y)s)" % locals()
# If we are in a gpuarray kernel, %(fail)s exits the kernel, # If we are in a gpuarray kernel, %(fail)s exits the kernel,
# and we do not have any error report, and we cannot set # and we do not have any error report, and we cannot set
# Python error messages either, so for now we just call the # Python error messages either, so for now we just call the
# cuda function, returning a binary pattern depending on dtype # cuda function, returning a binary pattern depending on dtype
check = dedent(''' mod_zero = dedent('''
#ifndef KERNEL #ifdef KERNEL
%(z)s = %(x_mod_y)s;
#else
PyErr_SetString(PyExc_ZeroDivisionError, "integer modulo by zero"); PyErr_SetString(PyExc_ZeroDivisionError, "integer modulo by zero");
%(fail)s %(fail)s
#endif #endif
''') % locals() ''') % locals()
x_mod_y = "THEANO_MACRO_MOD(%(x)s, %(y)s)" % locals()
x_mod_ymm = "THEANO_MACRO_MOD(-%(x)s, -%(y)s)" % locals()
x_mod_ypm = "THEANO_MACRO_MOD(%(x)s, -%(y)s)" % locals()
x_mod_ymp = "THEANO_MACRO_MOD(-%(x)s, %(y)s)" % locals()
elif (str(t) in imap(str, float_types) or elif (str(t) in imap(str, float_types) or
t in ['float32', 'float64'] or t in ['float32', 'float64'] or
t in float_types): t in float_types):
...@@ -2067,11 +2070,11 @@ class Mod(BinaryScalarOp): ...@@ -2067,11 +2070,11 @@ class Mod(BinaryScalarOp):
# keep them out of safety, and verify they are useless with an # keep them out of safety, and verify they are useless with an
# assert. # assert.
assert str(t) in imap(str, float_types) assert str(t) in imap(str, float_types)
check = ''
x_mod_y = "fmod(%(x)s, %(y)s)" % locals() x_mod_y = "fmod(%(x)s, %(y)s)" % locals()
x_mod_ymm = "fmod(-%(x)s, -%(y)s)" % locals() x_mod_ymm = "fmod(-%(x)s, -%(y)s)" % locals()
x_mod_ypm = "fmod(%(x)s, -%(y)s)" % locals() x_mod_ypm = "fmod(%(x)s, -%(y)s)" % locals()
x_mod_ymp = "fmod(-%(x)s, %(y)s)" % locals() x_mod_ymp = "fmod(-%(x)s, %(y)s)" % locals()
mod_zero = "%(z)s = %(x_mod_y)s;" % locals()
elif str(t) in imap(str, complex_types): elif str(t) in imap(str, complex_types):
raise self.complex_error raise self.complex_error
else: else:
...@@ -2079,8 +2082,7 @@ class Mod(BinaryScalarOp): ...@@ -2079,8 +2082,7 @@ class Mod(BinaryScalarOp):
return dedent(""" return dedent("""
if (%(y)s == 0) { if (%(y)s == 0) {
%(check)s %(mod_zero)s;
%(z)s = %(x_mod_y)s;
} else if (%(y)s < 0){ } else if (%(y)s < 0){
if (%(x)s < 0){ if (%(x)s < 0){
%(z)s = -(%(x_mod_ymm)s); %(z)s = -(%(x_mod_ymm)s);
......
...@@ -1037,12 +1037,26 @@ second dimension ...@@ -1037,12 +1037,26 @@ second dimension
# the index of the last of these aliased outputs. # the index of the last of these aliased outputs.
# We generate the C code of the inner loop using the scalar op # We generate the C code of the inner loop using the scalar op
if self.openmp:
# If we are using openmp, we need to get rid of the "goto"
# statement in sub['fail']. For now we recreate it here.
fail = '''
{
%(failure_var)s = %(id)s;
if (!PyErr_Occurred()) {
PyErr_SetString(PyExc_RuntimeError,
"Unexpected error in an Op's C code. "
"No Python exception was set.");
}
}''' % sub
else:
fail = sub['fail']
task_code = self.scalar_op.c_code( task_code = self.scalar_op.c_code(
node.tag.fake_node, node.tag.fake_node,
nodename + '_scalar_', nodename + '_scalar_',
["%s_i" % s for s in _inames], ["%s_i" % s for s in _inames],
["%s_i" % s for s in onames], ["%s_i" % s for s in onames],
sub) dict(sub, fail=fail))
code = """ code = """
{ {
%(defines)s %(defines)s
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论