提交 249f0373 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

The exception raised when trying to use Mod on complex numbers is now raised in…

The exception raised when trying to use Mod on complex numbers is now raised in more such situations
上级 58512374
......@@ -28,6 +28,10 @@ builtin_int = int
builtin_float = float
class ComplexError(Exception):
"""Raised if complex numbers are used in an unsupported operation."""
pass
class IntegerDivisionError(Exception):
"""Raised if someone tries to divide integers with '/' instead of '//'."""
pass
......@@ -409,7 +413,7 @@ complex_types = complex64, complex128
discrete_types = int_types + uint_types
continuous_types = float_types + complex_types
class _scalar_py_operators:
#UNARY
......@@ -441,8 +445,8 @@ class _scalar_py_operators:
def __sub__(self,other): return sub(self,other)
def __mul__(self,other): return mul(self,other)
def __div__(self,other): return div_proxy(self,other)
def __floordiv__(self,other): return int_div(self,other)
def __mod__(self,other): return mod(self,other)
def __floordiv__(self, other): return int_div(self, other)
def __mod__(self, other): return mod_check(self, other)
def __pow__(self,other): return pow(self,other)
#ARITHMETIC - RIGHT-OPERAND
......@@ -1121,18 +1125,29 @@ int_div = IntDiv(upcast_out, name = 'int_div')
floor_div = int_div
def raise_complex_error():
raise ComplexError(
"Theano does not support the mod operator (%) on "
"complex numbers, since numpy deprecated it.")
def mod_check(x, y):
if (as_scalar(x).type in complex_types or
as_scalar(y).type in complex_types):
# Currently forbidden.
raise_complex_error()
else:
return mod(x, y)
class Mod(BinaryScalarOp):
def impl(self, x, y):
if isinstance(x, numpy.complex) or isinstance(y, numpy.complex):
self.raise_complex_error()
raise_complex_error()
return x % y
def raise_complex_error(self):
raise TypeError(
"Theano does not support the mod operator (%) on "
"complex numbers, since numpy deprecated it.")
def c_code_cache_version(self):
return (5,)
......@@ -1169,7 +1184,7 @@ class Mod(BinaryScalarOp):
x_mod_ypm = "fmod(%(x)s,-%(y)s)"%locals()
x_mod_ymp = "fmod(-%(x)s,%(y)s)"%locals()
elif str(t) in imap(str, complex_types):
self.raise_complex_error()
raise_complex_error()
else:
raise NotImplementedError('type not supported', type)
......
......@@ -182,9 +182,9 @@ class test_complex_mod(unittest.TestCase):
x = complex64()
y = int32()
try:
theano.function([x, y], x % y)
x % y
assert False
except TypeError:
except ComplexError:
pass
......
......@@ -24,8 +24,8 @@ from theano.gof.python25 import partial, any, all
from theano import compile, printing
from theano.printing import pprint
# We use this exception as well.
from theano.scalar import IntegerDivisionError
# We use these exceptions as well.
from theano.scalar import ComplexError, IntegerDivisionError
### set up the external interface
from elemwise import Elemwise, DimShuffle, CAReduce, Sum
......@@ -1161,7 +1161,11 @@ class _tensor_py_operators:
return NotImplemented
def __mod__(self,other):
try:
return mod(self,other)
return mod_check(self, other)
except ComplexError:
# This is to raise the exception that occurs when trying to compute
# x % y with either x or y a complex number.
raise
except Exception, e:
return NotImplemented
......@@ -2582,7 +2586,6 @@ def minimum(x,y):
"""
# see decorator for function body
def div_proxy(x, y):
"""Proxy for either true_div or int_div, depending on types of x, y."""
f = eval('%s_div' % scal.int_or_true_div(
......@@ -2590,7 +2593,6 @@ def div_proxy(x, y):
as_tensor_variable(y).dtype in discrete_dtypes))
return f(x, y)
@_scal_elemwise_with_nfunc('add', 2, 1)
def add(a, *other_terms):
"""elementwise addition"""
......@@ -2621,6 +2623,15 @@ def int_div(a, b):
"""elementwise integer-division"""
# see decorator for function body
def mod_check(x, y):
"""Make sure we do not try to use complex numbers."""
if (as_tensor_variable(x).dtype in complex_dtypes or
as_tensor_variable(y).dtype in complex_dtypes):
# Currently forbidden.
scal.raise_complex_error()
else:
return mod(x, y)
@_scal_elemwise_with_nfunc('mod', 2, 1)
def mod(a, b):
"""elementwise modulo"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论