提交 f85a0676 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix remaining dtype warnings in tests.tensor.test_basic

上级 f84d2b00
import itertools import itertools
import warnings
from copy import copy, deepcopy
from functools import partial from functools import partial
from tempfile import mkstemp from tempfile import mkstemp
...@@ -135,7 +133,6 @@ from tests.tensor.utils import ( ...@@ -135,7 +133,6 @@ from tests.tensor.utils import (
_good_broadcast_unary_normal, _good_broadcast_unary_normal,
_grad_broadcast_unary_normal, _grad_broadcast_unary_normal,
eval_outputs, eval_outputs,
get_numeric_types,
inplace_func, inplace_func,
integers, integers,
integers_ranged, integers_ranged,
...@@ -148,6 +145,8 @@ from tests.tensor.utils import ( ...@@ -148,6 +145,8 @@ from tests.tensor.utils import (
) )
pytestmark = pytest.mark.filterwarnings("error")
if config.mode == "FAST_COMPILE": if config.mode == "FAST_COMPILE":
mode_opt = "FAST_RUN" mode_opt = "FAST_RUN"
else: else:
...@@ -866,7 +865,7 @@ class TestTriangle: ...@@ -866,7 +865,7 @@ class TestTriangle:
assert np.allclose(result, np.tri(N, M_, k, dtype=dtype)) assert np.allclose(result, np.tri(N, M_, k, dtype=dtype))
assert result.dtype == np.dtype(dtype) assert result.dtype == np.dtype(dtype)
for dtype in ALL_DTYPES: for dtype in ["int32", "int64", "float32", "float64", "uint16", "complex64"]:
check(dtype, 3) check(dtype, 3)
# M != N, k = 0 # M != N, k = 0
check(dtype, 3, 5) check(dtype, 3, 5)
...@@ -882,6 +881,10 @@ class TestTriangle: ...@@ -882,6 +881,10 @@ class TestTriangle:
check(dtype, 5, 3, -1) check(dtype, 5, 3, -1)
def test_tril_triu(self): def test_tril_triu(self):
"""
TODO FIXME: Parameterize this.
"""
def check_l(m, k=0): def check_l(m, k=0):
m_symb = matrix(dtype=m.dtype) m_symb = matrix(dtype=m.dtype)
k_symb = iscalar() k_symb = iscalar()
...@@ -934,7 +937,7 @@ class TestTriangle: ...@@ -934,7 +937,7 @@ class TestTriangle:
assert np.allclose(result, np.triu(m, k)) assert np.allclose(result, np.triu(m, k))
assert result.dtype == np.dtype(dtype) assert result.dtype == np.dtype(dtype)
for dtype in ALL_DTYPES: for dtype in ["int32", "int64", "float32", "float64", "uint16", "complex64"]:
m = random_of_dtype((10, 10), dtype) m = random_of_dtype((10, 10), dtype)
check_l(m, 0) check_l(m, 0)
check_l(m, 1) check_l(m, 1)
...@@ -1316,18 +1319,10 @@ class TestJoinAndSplit: ...@@ -1316,18 +1319,10 @@ class TestJoinAndSplit:
stack([a, b], 4) stack([a, b], 4)
with pytest.raises(IndexError): with pytest.raises(IndexError):
stack([a, b], -4) stack([a, b], -4)
# Testing depreciation warning # Testing depreciation warning
with warnings.catch_warnings(record=True) as w: with pytest.warns(DeprecationWarning):
s = stack(a, b) s = stack(a, b)
assert len(w) == 1
assert issubclass(w[-1].category, DeprecationWarning)
with warnings.catch_warnings(record=True) as w:
s = stack([a, b])
s = stack([a, b], 1)
s = stack([a, b], axis=1)
s = stack(tensors=[a, b])
s = stack(tensors=[a, b], axis=1)
assert not w
def test_stack_hessian(self): def test_stack_hessian(self):
# Test the gradient of stack when used in hessian, see gh-1589 # Test the gradient of stack when used in hessian, see gh-1589
...@@ -2025,17 +2020,17 @@ def test_ScalarFromTensor(cast_policy): ...@@ -2025,17 +2020,17 @@ def test_ScalarFromTensor(cast_policy):
scalar_from_tensor(vector()) scalar_from_tensor(vector())
class TestOpCache: def test_op_cache():
def test_basic(self): # TODO: What is this actually testing?
# trigger bug in ticket #162 # trigger bug in ticket #162
v = matrix() v = matrix()
v.name = "v" v.name = "v"
gv = fill(v / v, 1.0) / v - (fill(v / v, 1.0) * v) / (v * v) gv = fill(v / v, 1.0) / v - (fill(v / v, 1.0) * v) / (v * v)
fn_py = inplace_func([v], gv) fn_py = inplace_func([v], gv)
fn_c_or_py = inplace_func([v], gv) fn_c_or_py = inplace_func([v], gv)
a = random(5, 2).astype(config.floatX) a = random(5, 2).astype(config.floatX)
assert np.all(fn_py(a) == fn_c_or_py(a)) assert np.all(fn_py(a) == fn_c_or_py(a))
def test_dimshuffle(): def test_dimshuffle():
...@@ -2175,6 +2170,11 @@ def test_is_flat(): ...@@ -2175,6 +2170,11 @@ def test_is_flat():
def test_tile(): def test_tile():
"""
TODO FIXME: Split this apart and parameterize. Also, find out why it's
unreasonably slow.
"""
def run_tile(x, x_, reps, use_symbolic_reps): def run_tile(x, x_, reps, use_symbolic_reps):
if use_symbolic_reps: if use_symbolic_reps:
rep_symbols = [iscalar() for _ in range(len(reps))] rep_symbols = [iscalar() for _ in range(len(reps))]
...@@ -2556,96 +2556,69 @@ class TestARange: ...@@ -2556,96 +2556,69 @@ class TestARange:
fstop_v32 = np.float32(fstop_v) fstop_v32 = np.float32(fstop_v)
assert np.all(ff(fstop_v32) == np.arange(fstop_v)) assert np.all(ff(fstop_v32) == np.arange(fstop_v))
@pytest.mark.parametrize( @config.change_flags(cast_policy="custom")
"cast_policy", def test_upcast_custom(self):
[
"custom",
"numpy+floatX",
],
)
def test_upcast(self, cast_policy):
"""Test that arange computes output type adequately.""" """Test that arange computes output type adequately."""
with config.change_flags(cast_policy=cast_policy): assert arange(iscalar()).dtype == "int64"
if config.cast_policy == "custom": assert arange(fscalar()).dtype == fscalar().dtype
assert arange(iscalar()).dtype == "int64" assert arange(dscalar()).dtype == dscalar().dtype
assert arange(fscalar()).dtype == fscalar().dtype
assert arange(dscalar()).dtype == dscalar().dtype
# int32 + float32 -> float64 # int32 + float32 -> float64
assert arange(iscalar(), fscalar()).dtype == dscalar().dtype assert arange(iscalar(), fscalar()).dtype == dscalar().dtype
assert arange(iscalar(), dscalar()).dtype == dscalar().dtype assert arange(iscalar(), dscalar()).dtype == dscalar().dtype
assert arange(fscalar(), dscalar()).dtype == dscalar().dtype assert arange(fscalar(), dscalar()).dtype == dscalar().dtype
assert arange(iscalar(), fscalar(), dscalar()).dtype == dscalar().dtype assert arange(iscalar(), fscalar(), dscalar()).dtype == dscalar().dtype
elif config.cast_policy == "numpy+floatX":
for dtype in get_numeric_types(): @pytest.mark.parametrize(
# Test with a single argument. "dtype", [dtype for dtype in ALL_DTYPES if not dtype.startswith("complex")]
arange_dtype = arange(scalar(dtype=str(dtype))).dtype )
numpy_dtype = np.arange(np.array(1, dtype=dtype)).dtype @pytest.mark.parametrize(
if ( "stop_dtype", [dtype for dtype in ALL_DTYPES if not dtype.startswith("complex")]
dtype != "float64" )
and numpy_dtype == "float64" @config.change_flags(cast_policy="numpy+floatX")
and config.cast_policy == "numpy+floatX" def test_upcast_numpy(self, dtype, stop_dtype):
and config.floatX == "float32" """Make sure our `ARange` output dtypes match NumPy's under different casting policies."""
): # Test with a single argument.
# We want a float32 arange. arange_dtype = arange(scalar(dtype=str(dtype))).dtype
assert arange_dtype == "float32" numpy_dtype = np.arange(np.array(1, dtype=dtype)).dtype
else: if (
# Follow numpy. dtype != "float64"
assert arange_dtype == numpy_dtype and numpy_dtype == "float64"
and config.cast_policy == "numpy+floatX"
# Test with two arguments. and config.floatX == "float32"
for stop_dtype in get_numeric_types(): ):
arange_dtype = arange( # We want a float32 arange.
start=scalar(dtype=str(dtype)), assert arange_dtype == "float32"
stop=scalar(dtype=str(stop_dtype)), else:
).dtype # Follow numpy.
numpy_dtype = np.arange( assert arange_dtype == numpy_dtype
start=np.array(0, dtype=dtype),
stop=np.array(1, dtype=stop_dtype), # Test with two arguments.
).dtype arange_dtype = arange(
if ( start=scalar(dtype=str(dtype)),
dtype != "float64" stop=scalar(dtype=str(stop_dtype)),
and stop_dtype != "float64" ).dtype
and numpy_dtype == "float64" numpy_dtype = np.arange(
and config.cast_policy == "numpy+floatX" start=np.array(0, dtype=dtype),
and config.floatX == "float32" stop=np.array(1, dtype=stop_dtype),
): ).dtype
# We want a float32 arange.
assert arange_dtype == "float32" if (
else: dtype != "float64"
# Follow numpy. and stop_dtype != "float64"
assert arange_dtype == numpy_dtype and numpy_dtype == "float64"
and config.cast_policy == "numpy+floatX"
# Test with three arguments. and config.floatX == "float32"
for step_dtype in get_numeric_types(): ):
arange_dtype = arange( # We want a float32 arange.
start=scalar(dtype=str(dtype)), assert arange_dtype == "float32"
stop=scalar(dtype=str(stop_dtype)), else:
step=scalar(dtype=str(step_dtype)), # Follow numpy.
).dtype assert arange_dtype == numpy_dtype
numpy_dtype = np.arange(
start=np.array(0, dtype=dtype),
stop=np.array(1, dtype=stop_dtype),
step=np.array(1, dtype=step_dtype),
).dtype
if (
dtype != "float64"
and stop_dtype != "float64"
and step_dtype != "float64"
and numpy_dtype == "float64"
and config.cast_policy == "numpy+floatX"
and config.floatX == "float32"
):
# We want a float32 arange.
assert arange_dtype == "float32"
else:
# Follow numpy.
assert arange_dtype == numpy_dtype
def test_dtype_cache(self): def test_dtype_cache(self):
# Checks that the same Op is returned on repeated calls to arange """Check that the same `Op` is returned on repeated calls to `ARange` using the same dtype."""
# using the same dtype, but not for different dtypes.
start, stop, step = iscalars("start", "stop", "step") start, stop, step = iscalars("start", "stop", "step")
out1 = arange(start, stop, step) out1 = arange(start, stop, step)
...@@ -3002,18 +2975,8 @@ def test_default_state(): ...@@ -3002,18 +2975,8 @@ def test_default_state():
assert np.allclose(f(np.asarray(2.2, dtype=config.floatX)), 7) assert np.allclose(f(np.asarray(2.2, dtype=config.floatX)), 7)
def test_autocast(): @config.change_flags(cast_policy="custom")
# Call test functions for all possible values of `config.cast_policy`. def test_autocast_custom():
for autocast_cfg in (
"custom",
# 'numpy', # Commented out until it is implemented properly.
"numpy+floatX",
):
with config.change_flags(cast_policy=autocast_cfg):
eval("_test_autocast_" + autocast_cfg.replace("+", "_"))()
def _test_autocast_custom():
# Called from `test_autocast`. # Called from `test_autocast`.
assert config.cast_policy == "custom" assert config.cast_policy == "custom"
orig_autocast = autocast_float.dtypes orig_autocast = autocast_float.dtypes
...@@ -3063,10 +3026,10 @@ def _test_autocast_custom(): ...@@ -3063,10 +3026,10 @@ def _test_autocast_custom():
assert (fvector() + 1.0).dtype == "float32" assert (fvector() + 1.0).dtype == "float32"
assert (dvector() + np.float32(1.1)).dtype == "float64" assert (dvector() + np.float32(1.1)).dtype == "float64"
assert (dvector() + np.float64(1.1)).dtype == "float64" assert (dvector() + np.float64(1.1)).dtype == "float64"
assert (dvector() + np.float(1.1)).dtype == "float64" assert (dvector() + float(1.1)).dtype == "float64"
assert (fvector() + np.float32(1.1)).dtype == "float32" assert (fvector() + np.float32(1.1)).dtype == "float32"
assert (fvector() + np.float64(1.1)).dtype == "float64" assert (fvector() + np.float64(1.1)).dtype == "float64"
assert (fvector() + np.float(1.1)).dtype == config.floatX assert (fvector() + float(1.1)).dtype == config.floatX
assert (lvector() + np.int64(1)).dtype == "int64" assert (lvector() + np.int64(1)).dtype == "int64"
assert (lvector() + np.int32(1)).dtype == "int64" assert (lvector() + np.int32(1)).dtype == "int64"
assert (lvector() + np.int16(1)).dtype == "int64" assert (lvector() + np.int16(1)).dtype == "int64"
...@@ -3078,7 +3041,9 @@ def _test_autocast_custom(): ...@@ -3078,7 +3041,9 @@ def _test_autocast_custom():
assert (fvector() + 1.0).dtype == "float64" assert (fvector() + 1.0).dtype == "float64"
def _test_autocast_numpy(): @pytest.mark.skip(reason="Not implemented")
@config.change_flags(cast_policy="numpy")
def test_autocast_numpy():
# Called from `test_autocast`. # Called from `test_autocast`.
assert config.cast_policy == "numpy" assert config.cast_policy == "numpy"
# Go through some typical scalar values. # Go through some typical scalar values.
...@@ -3098,7 +3063,8 @@ def _test_autocast_numpy(): ...@@ -3098,7 +3063,8 @@ def _test_autocast_numpy():
ok(n_x) ok(n_x)
def _test_autocast_numpy_floatX(): @config.change_flags(cast_policy="numpy+floatX")
def test_autocast_numpy_floatX():
# Called from `test_autocast`. # Called from `test_autocast`.
assert config.cast_policy == "numpy+floatX" assert config.cast_policy == "numpy+floatX"
...@@ -3579,30 +3545,6 @@ class TestAllocDiag: ...@@ -3579,30 +3545,6 @@ class TestAllocDiag:
assert np.all(true_grad_input == grad_input) assert np.all(true_grad_input == grad_input)
class TestNumpyAssumptions:
# Verify that some assumptions Aesara makes on Numpy's behavior still hold.
def test_ndarray_copy(self):
# A copy or deepcopy of the ndarray type should not create a new object.
#
# This is because Aesara makes some comparisons of the form:
# if type(x) is np.ndarray
assert copy(np.ndarray) is np.ndarray
assert deepcopy(np.ndarray) is np.ndarray
def test_dtype_equality(self):
# Ensure dtype string comparisons are consistent.
#
# Aesara often uses string representations of dtypes (e.g. 'float32'). We
# need to make sure that comparing the string representations is the same
# as comparing the dtype objects themselves.
dtypes = get_numeric_types(with_complex=True)
# Perform all pairwise comparisons of dtypes, making sure comparing
# their string representation yields the same result.
for dtype1_idx, dtype1 in enumerate(dtypes):
for dtype2 in dtypes[dtype1_idx + 1 :]:
assert (dtype1 == dtype2) == (str(dtype1) == str(dtype2))
def test_transpose(): def test_transpose():
x1 = dvector("x1") x1 = dvector("x1")
x2 = dmatrix("x2") x2 = dmatrix("x2")
......
import builtins import builtins
import operator import operator
import pickle import pickle
import warnings
from copy import copy from copy import copy
from functools import reduce from functools import reduce
from itertools import product from itertools import product
...@@ -144,6 +143,7 @@ from aesara.tensor.type_other import NoneConst ...@@ -144,6 +143,7 @@ from aesara.tensor.type_other import NoneConst
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.link.test_link import make_function from tests.link.test_link import make_function
from tests.tensor.utils import ( from tests.tensor.utils import (
ALL_DTYPES,
_bad_build_broadcast_binary_normal, _bad_build_broadcast_binary_normal,
_bad_runtime_broadcast_binary_normal, _bad_runtime_broadcast_binary_normal,
_bad_runtime_reciprocal, _bad_runtime_reciprocal,
...@@ -177,7 +177,6 @@ from tests.tensor.utils import ( ...@@ -177,7 +177,6 @@ from tests.tensor.utils import (
copymod, copymod,
div_grad_rtol, div_grad_rtol,
eval_outputs, eval_outputs,
get_numeric_types,
ignore_isfinite_mode, ignore_isfinite_mode,
inplace_func, inplace_func,
integers, integers,
...@@ -2225,139 +2224,137 @@ class TestArithmeticCast: ...@@ -2225,139 +2224,137 @@ class TestArithmeticCast:
""" """
def test_arithmetic_cast(self): @pytest.mark.parametrize(
dtypes = get_numeric_types(with_complex=True) "op",
[
operator.add,
operator.sub,
operator.mul,
operator.truediv,
operator.floordiv,
],
)
@pytest.mark.parametrize("a_type", ALL_DTYPES)
@pytest.mark.parametrize("b_type", ALL_DTYPES)
@pytest.mark.parametrize(
"combo",
[
("scalar", "scalar"),
("array", "array"),
("scalar", "array"),
("array", "scalar"),
("i_scalar", "i_scalar"),
],
)
def test_arithmetic_cast(self, op, a_type, b_type, combo):
if op is operator.floordiv and (
a_type.startswith("complex") or b_type.startswith("complex")
):
pytest.skip("Not supported by NumPy")
# Here: # Here:
# scalar == scalar stored as a 0d array # scalar == scalar stored as a 0d array
# array == 1d array # array == 1d array
# i_scalar == scalar type used internally by Aesara # i_scalar == scalar type used internally by Aesara
def Aesara_scalar(dtype): def aesara_scalar(dtype):
return scalar(dtype=str(dtype)) return scalar(dtype=str(dtype))
def numpy_scalar(dtype): def numpy_scalar(dtype):
return np.array(1, dtype=dtype) return np.array(1, dtype=dtype)
def Aesara_array(dtype): def aesara_array(dtype):
return vector(dtype=str(dtype)) return vector(dtype=str(dtype))
def numpy_array(dtype): def numpy_array(dtype):
return np.array([1], dtype=dtype) return np.array([1], dtype=dtype)
def Aesara_i_scalar(dtype): def aesara_i_scalar(dtype):
return aes.ScalarType(str(dtype))() return aes.ScalarType(str(dtype))()
def numpy_i_scalar(dtype): def numpy_i_scalar(dtype):
return numpy_scalar(dtype) return numpy_scalar(dtype)
with warnings.catch_warnings(): with config.change_flags(cast_policy="numpy+floatX"):
# Avoid deprecation warning during tests. # We will test all meaningful combinations of
warnings.simplefilter("ignore", category=DeprecationWarning) # scalar and array operations.
for cfg in ("numpy+floatX",): # Used to test 'numpy' as well. aesara_args = list(map(eval, [f"aesara_{c}" for c in combo]))
with config.change_flags(cast_policy=cfg): numpy_args = list(map(eval, [f"numpy_{c}" for c in combo]))
for op in ( aesara_arg_1 = aesara_args[0](a_type)
operator.add, aesara_arg_2 = aesara_args[1](b_type)
operator.sub, aesara_dtype = op(
operator.mul, aesara_arg_1,
operator.truediv, aesara_arg_2,
operator.floordiv, ).type.dtype
):
for a_type in dtypes: # For numpy we have a problem:
for b_type in dtypes: # http://projects.scipy.org/numpy/ticket/1827
# As a result we only consider the highest data
# We will test all meaningful combinations of # type that numpy may return.
# scalar and array operations. numpy_arg_1 = numpy_args[0](a_type)
for combo in ( numpy_arg_2 = numpy_args[1](b_type)
("scalar", "scalar"), numpy_dtypes = [
("array", "array"), op(numpy_arg_1, numpy_arg_2).dtype,
("scalar", "array"), op(numpy_arg_2, numpy_arg_1).dtype,
("array", "scalar"), ]
("i_scalar", "i_scalar"), numpy_dtype = aes.upcast(*list(map(str, numpy_dtypes)))
):
if numpy_dtype == aesara_dtype:
Aesara_args = list( # Same data type found, all is good!
map(eval, [f"Aesara_{c}" for c in combo]) return
)
numpy_args = list( if (
map(eval, [f"numpy_{c}" for c in combo]) config.floatX == "float32"
) and a_type != "float64"
Aesara_dtype = op( and b_type != "float64"
Aesara_args[0](a_type), and numpy_dtype == "float64"
Aesara_args[1](b_type), ):
).type.dtype # We should keep float32.
assert aesara_dtype == "float32"
# For numpy we have a problem: return
# http://projects.scipy.org/numpy/ticket/1827
# As a result we only consider the highest data if "array" in combo and "scalar" in combo:
# type that numpy may return. # For mixed scalar / array operations,
numpy_dtypes = [ # Aesara may differ from numpy as it does
op( # not try to prevent the scalar from
numpy_args[0](a_type), numpy_args[1](b_type) # upcasting the array.
).dtype, array_type, scalar_type = (
op( (a_type, b_type)[list(combo).index(arg)]
numpy_args[1](b_type), numpy_args[0](a_type) for arg in ("array", "scalar")
).dtype, )
] up_type = aes.upcast(array_type, scalar_type)
numpy_dtype = aes.upcast( if (
*list(map(str, numpy_dtypes)) # The two data types are different.
) scalar_type != array_type
if numpy_dtype == Aesara_dtype: and
# Same data type found, all is good! # The array type is not enough to hold
continue # the scalar type as well.
if ( array_type != up_type
cfg == "numpy+floatX" and
and config.floatX == "float32" # Aesara upcasted the result array.
and a_type != "float64" aesara_dtype == up_type
and b_type != "float64" and
and numpy_dtype == "float64" # But Numpy kept its original type.
): array_type == numpy_dtype
# We should keep float32. ):
assert Aesara_dtype == "float32" # Then we accept this difference in
continue # behavior.
if "array" in combo and "scalar" in combo: return
# For mixed scalar / array operations,
# Aesara may differ from numpy as it does if (
# not try to prevent the scalar from {a_type, b_type} == {"complex128", "float32"}
# upcasting the array. or {a_type, b_type} == {"complex128", "float16"}
array_type, scalar_type = ( and set(combo) == {"scalar", "array"}
(a_type, b_type)[list(combo).index(arg)] and aesara_dtype == "complex128"
for arg in ("array", "scalar") and numpy_dtype == "complex64"
) ):
up_type = aes.upcast(array_type, scalar_type) # In numpy 1.6.x adding a complex128 with
if ( # a float32 may result in a complex64. As
# The two data types are different. # of 1.9.2. this is still the case so it is
scalar_type != array_type # probably by design
and pytest.skip("Known issue with" "numpy see #761")
# The array type is not enough to hold # In any other situation: something wrong is
# the scalar type as well. # going on!
array_type != up_type raise AssertionError()
and
# Aesara upcasted the result array.
Aesara_dtype == up_type
and
# But Numpy kept its original type.
array_type == numpy_dtype
):
# Then we accept this difference in
# behavior.
continue
if (
cfg == "numpy+floatX"
and a_type == "complex128"
and (b_type == "float32" or b_type == "float16")
and combo == ("scalar", "array")
and Aesara_dtype == "complex128"
and numpy_dtype == "complex64"
):
# In numpy 1.6.x adding a complex128 with
# a float32 may result in a complex64. As
# of 1.9.2. this is still the case so it is
# probably by design
pytest.skip("Known issue with" "numpy see #761")
# In any other situation: something wrong is
# going on!
raise AssertionError()
def test_divmod(): def test_divmod():
......
...@@ -109,89 +109,6 @@ def eval_outputs(outputs, ops=(), mode=None): ...@@ -109,89 +109,6 @@ def eval_outputs(outputs, ops=(), mode=None):
return variables return variables
def get_numeric_subclasses(cls=np.number, ignore=None):
"""Return subclasses of `cls` in the numpy scalar hierarchy.
We only return subclasses that correspond to unique data types. The
hierarchy can be seen here:
http://docs.scipy.org/doc/numpy/reference/arrays.scalars.html
"""
if ignore is None:
ignore = []
rval = []
dtype = np.dtype(cls)
dtype_num = dtype.num
if dtype_num not in ignore:
# Safety check: we should be able to represent 0 with this data type.
np.array(0, dtype=dtype)
rval.append(cls)
ignore.append(dtype_num)
for sub_ in cls.__subclasses__():
rval += [c for c in get_numeric_subclasses(sub_, ignore=ignore)]
return rval
def get_numeric_types(
with_int=True, with_float=True, with_complex=False, only_aesara_types=True
):
"""Return NumPy numeric data types.
Parameters
----------
with_int
Whether to include integer types.
with_float
Whether to include floating point types.
with_complex
Whether to include complex types.
only_aesara_types
If ``True``, then numpy numeric data types that are not supported by
Aesara are ignored (i.e. those that are not declared in
``scalar/basic.py``).
Returns
-------
A list of unique data type objects. Note that multiple data types may share
the same string representation, but can be differentiated through their
`num` attribute.
Note that when `only_aesara_types` is True we could simply return the list
of types defined in the `scalar` module. However with this function we can
test more unique dtype objects, and in the future we may use it to
automatically detect new data types introduced in numpy.
"""
if only_aesara_types:
aesara_types = [d.dtype for d in aesara.scalar.all_types]
rval = []
def is_within(cls1, cls2):
# Return True if scalars defined from `cls1` are within the hierarchy
# starting from `cls2`.
# The third test below is to catch for instance the fact that
# one can use ``dtype=numpy.number`` and obtain a float64 scalar, even
# though `numpy.number` is not under `numpy.floating` in the class
# hierarchy.
return (
cls1 is cls2
or issubclass(cls1, cls2)
or isinstance(np.array([0], dtype=cls1)[0], cls2)
)
for cls in get_numeric_subclasses():
dtype = np.dtype(cls)
if (
(not with_complex and is_within(cls, np.complexfloating))
or (not with_int and is_within(cls, np.integer))
or (not with_float and is_within(cls, np.floating))
or (only_aesara_types and dtype not in aesara_types)
):
# Ignore this class.
continue
rval.append([str(dtype), dtype, dtype.num])
# We sort it to be deterministic, then remove the string and num elements.
return [x[1] for x in sorted(rval, key=str)]
def _numpy_checker(x, y): def _numpy_checker(x, y):
"""Checks if `x.data` and `y.data` have the same contents. """Checks if `x.data` and `y.data` have the same contents.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论