提交 2114f1d0 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Using a custom exception for integer divisions (safer), and fixed crash due to…

Using a custom exception for integer divisions (safer), and fixed crash due to this newly raised exception
上级 875aa3d7
...@@ -27,6 +27,11 @@ builtin_int = int ...@@ -27,6 +27,11 @@ builtin_int = int
builtin_float = float builtin_float = float
class IntegerDivisionError(Exception):
"""Raised if someone tries to divide integers with '/' instead of '//'."""
pass
def upcast(dtype, *dtypes): def upcast(dtype, *dtypes):
# Should we try to keep float32 instead of float64? This is used so that # Should we try to keep float32 instead of float64? This is used so that
# for instance mixing int64 with float32 yields float32 instead of float64. # for instance mixing int64 with float32 yields float32 instead of float64.
...@@ -1028,7 +1033,7 @@ def div_proxy(x, y): ...@@ -1028,7 +1033,7 @@ def div_proxy(x, y):
# Following discussion on theano-dev ("Inconsistent behavior in integer # Following discussion on theano-dev ("Inconsistent behavior in integer
# division"), we will change the semantics of "/" on integer types in # division"), we will change the semantics of "/" on integer types in
# Theano 0.4. Until then, it is forbidden to use "/" on integers. # Theano 0.4. Until then, it is forbidden to use "/" on integers.
raise NotImplementedError( raise IntegerDivisionError(
"Dividing two integers with '/' is forbidden until Theano v0.4" "Dividing two integers with '/' is forbidden until Theano v0.4"
" is released (where the result will be a floating point " " is released (where the result will be a floating point "
"number). In the meantime, please either use '//' for integer " "number). In the meantime, please either use '//' for integer "
......
...@@ -7,6 +7,7 @@ import sys # for sys.maxint ...@@ -7,6 +7,7 @@ import sys # for sys.maxint
from theano.configparser import config, AddConfigVar, BoolParam from theano.configparser import config, AddConfigVar, BoolParam
import traceback #for overriding Op.__call__ import traceback #for overriding Op.__call__
import warnings import warnings
from itertools import izip
import numpy, theano import numpy, theano
#from copy import copy as python_copy #from copy import copy as python_copy
...@@ -23,6 +24,9 @@ from theano.gof.python25 import partial, any, all ...@@ -23,6 +24,9 @@ from theano.gof.python25 import partial, any, all
from theano import compile, printing from theano import compile, printing
from theano.printing import pprint from theano.printing import pprint
# We use this exception as well.
from theano.scalar import IntegerDivisionError
### set up the external interface ### set up the external interface
from elemwise import Elemwise, DimShuffle, CAReduce, Sum from elemwise import Elemwise, DimShuffle, CAReduce, Sum
...@@ -1138,7 +1142,7 @@ class _tensor_py_operators: ...@@ -1138,7 +1142,7 @@ class _tensor_py_operators:
def __div__(self,other): def __div__(self,other):
try: try:
return div_proxy(self,other) return div_proxy(self,other)
except NotImplementedError: except IntegerDivisionError:
# This is to raise the exception that occurs when trying to divide # This is to raise the exception that occurs when trying to divide
# two integer arrays (currently forbidden). # two integer arrays (currently forbidden).
raise raise
...@@ -2579,7 +2583,7 @@ def div_proxy(x, y): ...@@ -2579,7 +2583,7 @@ def div_proxy(x, y):
if (as_tensor_variable(x).dtype in discrete_dtypes and if (as_tensor_variable(x).dtype in discrete_dtypes and
as_tensor_variable(y).dtype in discrete_dtypes): as_tensor_variable(y).dtype in discrete_dtypes):
# See the same in scalar/basic.py # See the same in scalar/basic.py
raise NotImplementedError( raise IntegerDivisionError(
"Dividing two integer arrays with '/' is forbidden until " "Dividing two integer arrays with '/' is forbidden until "
"Theano v0.4 is released (where the result will be a floating " "Theano v0.4 is released (where the result will be a floating "
"point number). In the meantime, please either use '//' for " "point number). In the meantime, please either use '//' for "
...@@ -2921,7 +2925,7 @@ class Subtensor(Op): ...@@ -2921,7 +2925,7 @@ class Subtensor(Op):
padded = ( actual_idx_list + padded = ( actual_idx_list +
[slice(None, None, None)]*(len(xshp)-len(self.idx_list))) [slice(None, None, None)]*(len(xshp)-len(self.idx_list)))
i = 0 i = 0
for idx, xl in zip(padded, xshp): for idx, xl in izip(padded, xshp):
if isinstance(idx, slice): if isinstance(idx, slice):
# If it is the default (None, None, None) slice, or a variant, # If it is the default (None, None, None) slice, or a variant,
# the shape will be xl # the shape will be xl
...@@ -2931,7 +2935,7 @@ class Subtensor(Op): ...@@ -2931,7 +2935,7 @@ class Subtensor(Op):
outshp.append(xl) outshp.append(xl)
else: else:
cnf = get_canonical_form_slice(idx, xl) cnf = get_canonical_form_slice(idx, xl)
length = (cnf[0].stop - cnf[0].start -1)/cnf[0].step + 1 length = (cnf[0].stop - cnf[0].start -1) // cnf[0].step + 1
length = switch(lt(length,0), 0, length) length = switch(lt(length,0), 0, length)
outshp.append(length) outshp.append(length)
i += 1 i += 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论