提交 b2f33c98 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fixed some test with numpy 1.5.1 under Windows

The problem was that numpy added new data types, like complex192 or float96, that are currently not supported by Theano. The tests would try to test Theano behavior on these new datatypes. Also in this commit: removed a warning in tests when config.int_division is set to 'int'.
上级 b9e51b92
......@@ -72,7 +72,7 @@ def get_numeric_subclasses(cls=numpy.number, ignore=None):
def get_numeric_types(with_int=True, with_float=True, with_complex=False,
with_128_bit=False):
only_theano_types=True):
"""
Return numpy numeric data types.
......@@ -82,17 +82,21 @@ def get_numeric_types(with_int=True, with_float=True, with_complex=False,
:param with_complex: Whether to include complex types.
:param with_128_bit: Whether to include 128/256-bit types.
:param only_theano_types: If True, then numpy numeric data types that are
not supported by Theano are ignored (i.e. those that are not declared in
scalar/basic.py).
:returns: A list of unique data type objects. Note that multiple data types
may share the same string representation, but can be differentiated through
their `num` attribute.
Note that we could probably rely on the lists of types defined in the
`scalar` module. However with this function we can test more unique dtype
objects, and possibly detect defects in dtypes that may be introduced in
numpy in the future.
Note that when `only_theano_types` is True we could simply return the list
of types defined in the `scalar` module. However with this function we can
test more unique dtype objects, and in the future we may use it to
automatically detect new data types introduced in numpy.
"""
if only_theano_types:
theano_types = [d.dtype for d in theano.scalar.all_types]
rval = []
def is_within(cls1, cls2):
# Return True if scalars defined from `cls1` are within the hierarchy
......@@ -109,8 +113,7 @@ def get_numeric_types(with_int=True, with_float=True, with_complex=False,
if ((not with_complex and is_within(cls, numpy.complexfloating)) or
(not with_int and is_within(cls, numpy.integer)) or
(not with_float and is_within(cls, numpy.floating)) or
(not with_128_bit and ('128' in str(dtype) or
'256' in str(dtype)))):
(only_theano_types and dtype not in theano_types)):
# Ignore this class.
continue
rval.append([str(dtype), dtype, dtype.num])
......@@ -4566,6 +4569,10 @@ class test_arithmetic_cast(unittest.TestCase):
numpy_array = lambda dtype: numpy.array([1], dtype=dtype)
theano_i_scalar = lambda dtype: theano.scalar.Scalar(str(dtype))()
numpy_i_scalar = numpy_scalar
if config.int_division == 'int':
# Avoid deprecation warning during tests.
warnings.filterwarnings('ignore', message='Division of two integer',
category=DeprecationWarning)
try:
for cfg in ('numpy+floatX', ): # Used to test 'numpy' as well.
config.cast_policy = cfg
......@@ -4666,7 +4673,13 @@ class test_arithmetic_cast(unittest.TestCase):
assert False
finally:
config.cast_policy = backup_config
if config.int_division == 'int':
# Restore default deprecation warning behavior.
warnings.filterwarnings(
'default',
message='Division of two integer',
category=DeprecationWarning)
class test_broadcast(unittest.TestCase):
def test_broadcast_bigdim(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论