提交 15d5b260 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Avoid "floating point exception" crash

Raise a Python exception for integer division by 0
上级 a5c029dc
...@@ -1924,8 +1924,11 @@ class IntDiv(BinaryScalarOp): ...@@ -1924,8 +1924,11 @@ class IntDiv(BinaryScalarOp):
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x, y) = inputs (x, y) = inputs
(z,) = outputs (z,) = outputs
fail = sub['fail']
t = node.inputs[0].type.upcast(*[i.type for i in node.inputs[1:]]) t = node.inputs[0].type.upcast(*[i.type for i in node.inputs[1:]])
if t in imap(str, discrete_types): if t in imap(str, discrete_types):
check = 'if (%(y)s == 0) {PyErr_SetString(PyExc_ZeroDivisionError, "integer division by zero"); %(fail)s}' % locals()
x_div_y_pp = '(%(x)s / %(y)s)' % locals() x_div_y_pp = '(%(x)s / %(y)s)' % locals()
x_div_y_mp = '((-%(x)s) / %(y)s)' % locals() x_div_y_mp = '((-%(x)s) / %(y)s)' % locals()
x_mod_y_mp = 'THEANO_MACRO_MOD((-%(x)s), %(y)s)' % locals() x_mod_y_mp = 'THEANO_MACRO_MOD((-%(x)s), %(y)s)' % locals()
...@@ -1944,6 +1947,7 @@ class IntDiv(BinaryScalarOp): ...@@ -1944,6 +1947,7 @@ class IntDiv(BinaryScalarOp):
else: else:
raise NotImplementedError('type not supported', t) raise NotImplementedError('type not supported', t)
check = ''
x_div_y_pp = '%(floor)s(%(x)s / %(y)s)' % locals() x_div_y_pp = '%(floor)s(%(x)s / %(y)s)' % locals()
x_div_y_mp = '%(floor)s((-%(x)s) / %(y)s)' % locals() x_div_y_mp = '%(floor)s((-%(x)s) / %(y)s)' % locals()
x_mod_y_mp = '%(fmod)s((-%(x)s), %(y)s)' % locals() x_mod_y_mp = '%(fmod)s((-%(x)s), %(y)s)' % locals()
...@@ -1956,6 +1960,7 @@ class IntDiv(BinaryScalarOp): ...@@ -1956,6 +1960,7 @@ class IntDiv(BinaryScalarOp):
raise NotImplementedError('type not supported', t) raise NotImplementedError('type not supported', t)
return dedent(""" return dedent("""
%(check)s
if (%(x)s < 0) { if (%(x)s < 0) {
if (%(y)s < 0) { if (%(y)s < 0) {
%(z)s = %(x_div_y_mm)s; %(z)s = %(x_div_y_mm)s;
...@@ -1972,7 +1977,7 @@ class IntDiv(BinaryScalarOp): ...@@ -1972,7 +1977,7 @@ class IntDiv(BinaryScalarOp):
""") % locals() """) % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (3,)
def grad(self, inputs, g_output): def grad(self, inputs, g_output):
return [inp.zeros_like(dtype=theano.config.floatX) return [inp.zeros_like(dtype=theano.config.floatX)
...@@ -2004,7 +2009,7 @@ class Mod(BinaryScalarOp): ...@@ -2004,7 +2009,7 @@ class Mod(BinaryScalarOp):
return x % y return x % y
def c_code_cache_version(self): def c_code_cache_version(self):
return (5,) return (6,)
def c_support_code(self): def c_support_code(self):
# We use a macro as python use % as a special string character, # We use a macro as python use % as a special string character,
...@@ -2020,6 +2025,7 @@ class Mod(BinaryScalarOp): ...@@ -2020,6 +2025,7 @@ class Mod(BinaryScalarOp):
""" """
(x, y) = inputs (x, y) = inputs
(z,) = outputs (z,) = outputs
fail = sub['fail']
t = node.inputs[0].type.upcast(*[i.type for i in node.inputs[1:]]) t = node.inputs[0].type.upcast(*[i.type for i in node.inputs[1:]])
if (str(t) in imap(str, discrete_types) or if (str(t) in imap(str, discrete_types) or
t in ['uint8', 'int8', 'uint16', 'int16'] or t in ['uint8', 'int8', 'uint16', 'int16'] or
...@@ -2029,6 +2035,7 @@ class Mod(BinaryScalarOp): ...@@ -2029,6 +2035,7 @@ class Mod(BinaryScalarOp):
# keep them out of safety, and verify they are useless with an # keep them out of safety, and verify they are useless with an
# assert. # assert.
assert str(t) in imap(str, discrete_types) assert str(t) in imap(str, discrete_types)
check = 'if (%(y)s == 0) {PyErr_SetString(PyExc_ZeroDivisionError, "integer modulo by zero"); %(fail)s}' % locals()
x_mod_y = "THEANO_MACRO_MOD(%(x)s, %(y)s)" % locals() x_mod_y = "THEANO_MACRO_MOD(%(x)s, %(y)s)" % locals()
x_mod_ymm = "THEANO_MACRO_MOD(-%(x)s, -%(y)s)" % locals() x_mod_ymm = "THEANO_MACRO_MOD(-%(x)s, -%(y)s)" % locals()
x_mod_ypm = "THEANO_MACRO_MOD(%(x)s, -%(y)s)" % locals() x_mod_ypm = "THEANO_MACRO_MOD(%(x)s, -%(y)s)" % locals()
...@@ -2040,6 +2047,7 @@ class Mod(BinaryScalarOp): ...@@ -2040,6 +2047,7 @@ class Mod(BinaryScalarOp):
# keep them out of safety, and verify they are useless with an # keep them out of safety, and verify they are useless with an
# assert. # assert.
assert str(t) in imap(str, float_types) assert str(t) in imap(str, float_types)
check = ''
x_mod_y = "fmod(%(x)s,%(y)s)" % locals() x_mod_y = "fmod(%(x)s,%(y)s)" % locals()
x_mod_ymm = "fmod(-%(x)s,-%(y)s)" % locals() x_mod_ymm = "fmod(-%(x)s,-%(y)s)" % locals()
x_mod_ypm = "fmod(%(x)s,-%(y)s)" % locals() x_mod_ypm = "fmod(%(x)s,-%(y)s)" % locals()
...@@ -2050,6 +2058,7 @@ class Mod(BinaryScalarOp): ...@@ -2050,6 +2058,7 @@ class Mod(BinaryScalarOp):
raise NotImplementedError('type not supported', t) raise NotImplementedError('type not supported', t)
return dedent(""" return dedent("""
%(check)s
if (%(x)s < 0){ if (%(x)s < 0){
if (%(y)s < 0){ if (%(y)s < 0){
%(z)s = -(%(x_mod_ymm)s); %(z)s = -(%(x_mod_ymm)s);
...@@ -3696,8 +3705,9 @@ class Composite(ScalarOp): ...@@ -3696,8 +3705,9 @@ class Composite(ScalarOp):
def init_c_code(self): def init_c_code(self):
""" """
Return the C code for this Composite Op. Assemble the C code for this Composite Op.
The result is assigned to `self._c_code`.
""" """
# It was already called # It was already called
if hasattr(self, '_c_code'): if hasattr(self, '_c_code'):
......
...@@ -7154,7 +7154,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, ...@@ -7154,7 +7154,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
"test_presence_of_c_code", "test_presence_of_c_code",
["x" for x in i.owner.inputs], ["x" for x in i.owner.inputs],
["z" for z in i.owner.outputs], ["z" for z in i.owner.outputs],
{}) {"fail": "%(fail)s"})
except MethodNotDefined: except MethodNotDefined:
catch = True catch = True
except NotImplementedError: except NotImplementedError:
...@@ -7218,7 +7218,8 @@ your code will run correctly, but may be slower.""") ...@@ -7218,7 +7218,8 @@ your code will run correctly, but may be slower.""")
s_new_out[0].owner.op.c_code(s_new_out[0].owner, s_new_out[0].owner.op.c_code(s_new_out[0].owner,
"test_presence_of_c_code", "test_presence_of_c_code",
["x" for x in s_g], ["x" for x in s_g],
["z" for x in s_new_out], {}) ["z" for x in s_new_out],
{"fail": "%(fail)s"})
except MethodNotDefined: except MethodNotDefined:
_logger.info(("%s does not implement the c_code function." _logger.info(("%s does not implement the c_code function."
" As well as being potentially slow, this disables " " As well as being potentially slow, this disables "
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论