提交 f85a0676 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix remaining dtype warnings in tests.tensor.test_basic

上级 f84d2b00
差异被折叠。
import builtins import builtins
import operator import operator
import pickle import pickle
import warnings
from copy import copy from copy import copy
from functools import reduce from functools import reduce
from itertools import product from itertools import product
...@@ -144,6 +143,7 @@ from aesara.tensor.type_other import NoneConst ...@@ -144,6 +143,7 @@ from aesara.tensor.type_other import NoneConst
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.link.test_link import make_function from tests.link.test_link import make_function
from tests.tensor.utils import ( from tests.tensor.utils import (
ALL_DTYPES,
_bad_build_broadcast_binary_normal, _bad_build_broadcast_binary_normal,
_bad_runtime_broadcast_binary_normal, _bad_runtime_broadcast_binary_normal,
_bad_runtime_reciprocal, _bad_runtime_reciprocal,
...@@ -177,7 +177,6 @@ from tests.tensor.utils import ( ...@@ -177,7 +177,6 @@ from tests.tensor.utils import (
copymod, copymod,
div_grad_rtol, div_grad_rtol,
eval_outputs, eval_outputs,
get_numeric_types,
ignore_isfinite_mode, ignore_isfinite_mode,
inplace_func, inplace_func,
integers, integers,
...@@ -2225,95 +2224,94 @@ class TestArithmeticCast: ...@@ -2225,95 +2224,94 @@ class TestArithmeticCast:
""" """
def test_arithmetic_cast(self): @pytest.mark.parametrize(
dtypes = get_numeric_types(with_complex=True) "op",
[
operator.add,
operator.sub,
operator.mul,
operator.truediv,
operator.floordiv,
],
)
@pytest.mark.parametrize("a_type", ALL_DTYPES)
@pytest.mark.parametrize("b_type", ALL_DTYPES)
@pytest.mark.parametrize(
"combo",
[
("scalar", "scalar"),
("array", "array"),
("scalar", "array"),
("array", "scalar"),
("i_scalar", "i_scalar"),
],
)
def test_arithmetic_cast(self, op, a_type, b_type, combo):
if op is operator.floordiv and (
a_type.startswith("complex") or b_type.startswith("complex")
):
pytest.skip("Not supported by NumPy")
# Here: # Here:
# scalar == scalar stored as a 0d array # scalar == scalar stored as a 0d array
# array == 1d array # array == 1d array
# i_scalar == scalar type used internally by Aesara # i_scalar == scalar type used internally by Aesara
def Aesara_scalar(dtype): def aesara_scalar(dtype):
return scalar(dtype=str(dtype)) return scalar(dtype=str(dtype))
def numpy_scalar(dtype): def numpy_scalar(dtype):
return np.array(1, dtype=dtype) return np.array(1, dtype=dtype)
def Aesara_array(dtype): def aesara_array(dtype):
return vector(dtype=str(dtype)) return vector(dtype=str(dtype))
def numpy_array(dtype): def numpy_array(dtype):
return np.array([1], dtype=dtype) return np.array([1], dtype=dtype)
def Aesara_i_scalar(dtype): def aesara_i_scalar(dtype):
return aes.ScalarType(str(dtype))() return aes.ScalarType(str(dtype))()
def numpy_i_scalar(dtype): def numpy_i_scalar(dtype):
return numpy_scalar(dtype) return numpy_scalar(dtype)
with warnings.catch_warnings(): with config.change_flags(cast_policy="numpy+floatX"):
# Avoid deprecation warning during tests.
warnings.simplefilter("ignore", category=DeprecationWarning)
for cfg in ("numpy+floatX",): # Used to test 'numpy' as well.
with config.change_flags(cast_policy=cfg):
for op in (
operator.add,
operator.sub,
operator.mul,
operator.truediv,
operator.floordiv,
):
for a_type in dtypes:
for b_type in dtypes:
# We will test all meaningful combinations of # We will test all meaningful combinations of
# scalar and array operations. # scalar and array operations.
for combo in ( aesara_args = list(map(eval, [f"aesara_{c}" for c in combo]))
("scalar", "scalar"), numpy_args = list(map(eval, [f"numpy_{c}" for c in combo]))
("array", "array"), aesara_arg_1 = aesara_args[0](a_type)
("scalar", "array"), aesara_arg_2 = aesara_args[1](b_type)
("array", "scalar"), aesara_dtype = op(
("i_scalar", "i_scalar"), aesara_arg_1,
): aesara_arg_2,
Aesara_args = list(
map(eval, [f"Aesara_{c}" for c in combo])
)
numpy_args = list(
map(eval, [f"numpy_{c}" for c in combo])
)
Aesara_dtype = op(
Aesara_args[0](a_type),
Aesara_args[1](b_type),
).type.dtype ).type.dtype
# For numpy we have a problem: # For numpy we have a problem:
# http://projects.scipy.org/numpy/ticket/1827 # http://projects.scipy.org/numpy/ticket/1827
# As a result we only consider the highest data # As a result we only consider the highest data
# type that numpy may return. # type that numpy may return.
numpy_arg_1 = numpy_args[0](a_type)
numpy_arg_2 = numpy_args[1](b_type)
numpy_dtypes = [ numpy_dtypes = [
op( op(numpy_arg_1, numpy_arg_2).dtype,
numpy_args[0](a_type), numpy_args[1](b_type) op(numpy_arg_2, numpy_arg_1).dtype,
).dtype,
op(
numpy_args[1](b_type), numpy_args[0](a_type)
).dtype,
] ]
numpy_dtype = aes.upcast( numpy_dtype = aes.upcast(*list(map(str, numpy_dtypes)))
*list(map(str, numpy_dtypes))
) if numpy_dtype == aesara_dtype:
if numpy_dtype == Aesara_dtype:
# Same data type found, all is good! # Same data type found, all is good!
continue return
if ( if (
cfg == "numpy+floatX" config.floatX == "float32"
and config.floatX == "float32"
and a_type != "float64" and a_type != "float64"
and b_type != "float64" and b_type != "float64"
and numpy_dtype == "float64" and numpy_dtype == "float64"
): ):
# We should keep float32. # We should keep float32.
assert Aesara_dtype == "float32" assert aesara_dtype == "float32"
continue return
if "array" in combo and "scalar" in combo: if "array" in combo and "scalar" in combo:
# For mixed scalar / array operations, # For mixed scalar / array operations,
# Aesara may differ from numpy as it does # Aesara may differ from numpy as it does
...@@ -2333,21 +2331,20 @@ class TestArithmeticCast: ...@@ -2333,21 +2331,20 @@ class TestArithmeticCast:
array_type != up_type array_type != up_type
and and
# Aesara upcasted the result array. # Aesara upcasted the result array.
Aesara_dtype == up_type aesara_dtype == up_type
and and
# But Numpy kept its original type. # But Numpy kept its original type.
array_type == numpy_dtype array_type == numpy_dtype
): ):
# Then we accept this difference in # Then we accept this difference in
# behavior. # behavior.
continue return
if ( if (
cfg == "numpy+floatX" {a_type, b_type} == {"complex128", "float32"}
and a_type == "complex128" or {a_type, b_type} == {"complex128", "float16"}
and (b_type == "float32" or b_type == "float16") and set(combo) == {"scalar", "array"}
and combo == ("scalar", "array") and aesara_dtype == "complex128"
and Aesara_dtype == "complex128"
and numpy_dtype == "complex64" and numpy_dtype == "complex64"
): ):
# In numpy 1.6.x adding a complex128 with # In numpy 1.6.x adding a complex128 with
......
...@@ -109,89 +109,6 @@ def eval_outputs(outputs, ops=(), mode=None): ...@@ -109,89 +109,6 @@ def eval_outputs(outputs, ops=(), mode=None):
return variables return variables
def get_numeric_subclasses(cls=np.number, ignore=None):
"""Return subclasses of `cls` in the numpy scalar hierarchy.
We only return subclasses that correspond to unique data types. The
hierarchy can be seen here:
http://docs.scipy.org/doc/numpy/reference/arrays.scalars.html
"""
if ignore is None:
ignore = []
rval = []
dtype = np.dtype(cls)
dtype_num = dtype.num
if dtype_num not in ignore:
# Safety check: we should be able to represent 0 with this data type.
np.array(0, dtype=dtype)
rval.append(cls)
ignore.append(dtype_num)
for sub_ in cls.__subclasses__():
rval += [c for c in get_numeric_subclasses(sub_, ignore=ignore)]
return rval
def get_numeric_types(
with_int=True, with_float=True, with_complex=False, only_aesara_types=True
):
"""Return NumPy numeric data types.
Parameters
----------
with_int
Whether to include integer types.
with_float
Whether to include floating point types.
with_complex
Whether to include complex types.
only_aesara_types
If ``True``, then numpy numeric data types that are not supported by
Aesara are ignored (i.e. those that are not declared in
``scalar/basic.py``).
Returns
-------
A list of unique data type objects. Note that multiple data types may share
the same string representation, but can be differentiated through their
`num` attribute.
Note that when `only_aesara_types` is True we could simply return the list
of types defined in the `scalar` module. However with this function we can
test more unique dtype objects, and in the future we may use it to
automatically detect new data types introduced in numpy.
"""
if only_aesara_types:
aesara_types = [d.dtype for d in aesara.scalar.all_types]
rval = []
def is_within(cls1, cls2):
# Return True if scalars defined from `cls1` are within the hierarchy
# starting from `cls2`.
# The third test below is to catch for instance the fact that
# one can use ``dtype=numpy.number`` and obtain a float64 scalar, even
# though `numpy.number` is not under `numpy.floating` in the class
# hierarchy.
return (
cls1 is cls2
or issubclass(cls1, cls2)
or isinstance(np.array([0], dtype=cls1)[0], cls2)
)
for cls in get_numeric_subclasses():
dtype = np.dtype(cls)
if (
(not with_complex and is_within(cls, np.complexfloating))
or (not with_int and is_within(cls, np.integer))
or (not with_float and is_within(cls, np.floating))
or (only_aesara_types and dtype not in aesara_types)
):
# Ignore this class.
continue
rval.append([str(dtype), dtype, dtype.num])
# We sort it to be deterministic, then remove the string and num elements.
return [x[1] for x in sorted(rval, key=str)]
def _numpy_checker(x, y): def _numpy_checker(x, y):
"""Checks if `x.data` and `y.data` have the same contents. """Checks if `x.data` and `y.data` have the same contents.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论