提交 249f0373 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

The exception raised when trying to use Mod on complex numbers is now raised in…

The exception raised when trying to use Mod on complex numbers is now raised in more such situations
上级 58512374
...@@ -28,6 +28,10 @@ builtin_int = int ...@@ -28,6 +28,10 @@ builtin_int = int
builtin_float = float builtin_float = float
class ComplexError(Exception):
"""Raised if complex numbers are used in an unsupported operation."""
pass
class IntegerDivisionError(Exception): class IntegerDivisionError(Exception):
"""Raised if someone tries to divide integers with '/' instead of '//'.""" """Raised if someone tries to divide integers with '/' instead of '//'."""
pass pass
...@@ -409,7 +413,7 @@ complex_types = complex64, complex128 ...@@ -409,7 +413,7 @@ complex_types = complex64, complex128
discrete_types = int_types + uint_types discrete_types = int_types + uint_types
continuous_types = float_types + complex_types continuous_types = float_types + complex_types
class _scalar_py_operators: class _scalar_py_operators:
#UNARY #UNARY
...@@ -441,8 +445,8 @@ class _scalar_py_operators: ...@@ -441,8 +445,8 @@ class _scalar_py_operators:
def __sub__(self,other): return sub(self,other) def __sub__(self,other): return sub(self,other)
def __mul__(self,other): return mul(self,other) def __mul__(self,other): return mul(self,other)
def __div__(self,other): return div_proxy(self,other) def __div__(self,other): return div_proxy(self,other)
def __floordiv__(self,other): return int_div(self,other) def __floordiv__(self, other): return int_div(self, other)
def __mod__(self,other): return mod(self,other) def __mod__(self, other): return mod_check(self, other)
def __pow__(self,other): return pow(self,other) def __pow__(self,other): return pow(self,other)
#ARITHMETIC - RIGHT-OPERAND #ARITHMETIC - RIGHT-OPERAND
...@@ -1121,18 +1125,29 @@ int_div = IntDiv(upcast_out, name = 'int_div') ...@@ -1121,18 +1125,29 @@ int_div = IntDiv(upcast_out, name = 'int_div')
floor_div = int_div floor_div = int_div
def raise_complex_error():
raise ComplexError(
"Theano does not support the mod operator (%) on "
"complex numbers, since numpy deprecated it.")
def mod_check(x, y):
if (as_scalar(x).type in complex_types or
as_scalar(y).type in complex_types):
# Currently forbidden.
raise_complex_error()
else:
return mod(x, y)
class Mod(BinaryScalarOp): class Mod(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
if isinstance(x, numpy.complex) or isinstance(y, numpy.complex): if isinstance(x, numpy.complex) or isinstance(y, numpy.complex):
self.raise_complex_error() raise_complex_error()
return x % y return x % y
def raise_complex_error(self):
raise TypeError(
"Theano does not support the mod operator (%) on "
"complex numbers, since numpy deprecated it.")
def c_code_cache_version(self): def c_code_cache_version(self):
return (5,) return (5,)
...@@ -1169,7 +1184,7 @@ class Mod(BinaryScalarOp): ...@@ -1169,7 +1184,7 @@ class Mod(BinaryScalarOp):
x_mod_ypm = "fmod(%(x)s,-%(y)s)"%locals() x_mod_ypm = "fmod(%(x)s,-%(y)s)"%locals()
x_mod_ymp = "fmod(-%(x)s,%(y)s)"%locals() x_mod_ymp = "fmod(-%(x)s,%(y)s)"%locals()
elif str(t) in imap(str, complex_types): elif str(t) in imap(str, complex_types):
self.raise_complex_error() raise_complex_error()
else: else:
raise NotImplementedError('type not supported', type) raise NotImplementedError('type not supported', type)
......
...@@ -182,9 +182,9 @@ class test_complex_mod(unittest.TestCase): ...@@ -182,9 +182,9 @@ class test_complex_mod(unittest.TestCase):
x = complex64() x = complex64()
y = int32() y = int32()
try: try:
theano.function([x, y], x % y) x % y
assert False assert False
except TypeError: except ComplexError:
pass pass
......
...@@ -24,8 +24,8 @@ from theano.gof.python25 import partial, any, all ...@@ -24,8 +24,8 @@ from theano.gof.python25 import partial, any, all
from theano import compile, printing from theano import compile, printing
from theano.printing import pprint from theano.printing import pprint
# We use this exception as well. # We use these exceptions as well.
from theano.scalar import IntegerDivisionError from theano.scalar import ComplexError, IntegerDivisionError
### set up the external interface ### set up the external interface
from elemwise import Elemwise, DimShuffle, CAReduce, Sum from elemwise import Elemwise, DimShuffle, CAReduce, Sum
...@@ -1161,7 +1161,11 @@ class _tensor_py_operators: ...@@ -1161,7 +1161,11 @@ class _tensor_py_operators:
return NotImplemented return NotImplemented
def __mod__(self,other): def __mod__(self,other):
try: try:
return mod(self,other) return mod_check(self, other)
except ComplexError:
# This is to raise the exception that occurs when trying to compute
# x % y with either x or y a complex number.
raise
except Exception, e: except Exception, e:
return NotImplemented return NotImplemented
...@@ -2582,7 +2586,6 @@ def minimum(x,y): ...@@ -2582,7 +2586,6 @@ def minimum(x,y):
""" """
# see decorator for function body # see decorator for function body
def div_proxy(x, y): def div_proxy(x, y):
"""Proxy for either true_div or int_div, depending on types of x, y.""" """Proxy for either true_div or int_div, depending on types of x, y."""
f = eval('%s_div' % scal.int_or_true_div( f = eval('%s_div' % scal.int_or_true_div(
...@@ -2590,7 +2593,6 @@ def div_proxy(x, y): ...@@ -2590,7 +2593,6 @@ def div_proxy(x, y):
as_tensor_variable(y).dtype in discrete_dtypes)) as_tensor_variable(y).dtype in discrete_dtypes))
return f(x, y) return f(x, y)
@_scal_elemwise_with_nfunc('add', 2, 1) @_scal_elemwise_with_nfunc('add', 2, 1)
def add(a, *other_terms): def add(a, *other_terms):
"""elementwise addition""" """elementwise addition"""
...@@ -2621,6 +2623,15 @@ def int_div(a, b): ...@@ -2621,6 +2623,15 @@ def int_div(a, b):
"""elementwise integer-division""" """elementwise integer-division"""
# see decorator for function body # see decorator for function body
def mod_check(x, y):
"""Make sure we do not try to use complex numbers."""
if (as_tensor_variable(x).dtype in complex_dtypes or
as_tensor_variable(y).dtype in complex_dtypes):
# Currently forbidden.
scal.raise_complex_error()
else:
return mod(x, y)
@_scal_elemwise_with_nfunc('mod', 2, 1) @_scal_elemwise_with_nfunc('mod', 2, 1)
def mod(a, b): def mod(a, b):
"""elementwise modulo""" """elementwise modulo"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论