提交 3500fec8 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Update support for unsigned integers in aesara.tensor.subtensor

上级 5b935bc6
...@@ -41,6 +41,10 @@ from aesara.tensor.type import ( ...@@ -41,6 +41,10 @@ from aesara.tensor.type import (
iscalar, iscalar,
lscalar, lscalar,
tensor, tensor,
ubscalar,
uiscalar,
ulscalar,
uwscalar,
wscalar, wscalar,
zscalar, zscalar,
) )
...@@ -50,12 +54,25 @@ from aesara.tensor.type_other import NoneConst, NoneTypeT, SliceType, make_slice ...@@ -50,12 +54,25 @@ from aesara.tensor.type_other import NoneConst, NoneTypeT, SliceType, make_slice
_logger = logging.getLogger("aesara.tensor.subtensor") _logger = logging.getLogger("aesara.tensor.subtensor")
invalid_scal_types = (aes.float64, aes.float32, aes.float16) invalid_scal_types = (aes.float64, aes.float32, aes.float16)
scal_types = (aes.int64, aes.int32, aes.int16, aes.int8) scal_types = (
aes.int64,
aes.int32,
aes.int16,
aes.int8,
aes.uint64,
aes.uint32,
aes.uint16,
aes.uint8,
)
tensor_types = ( tensor_types = (
lscalar, lscalar,
iscalar, iscalar,
wscalar, wscalar,
bscalar, bscalar,
ulscalar,
uiscalar,
uwscalar,
ubscalar,
) )
invalid_tensor_types = ( invalid_tensor_types = (
fscalar, fscalar,
...@@ -376,7 +393,7 @@ def slice_len(slc, n): ...@@ -376,7 +393,7 @@ def slice_len(slc, n):
def is_basic_idx(idx): def is_basic_idx(idx):
"""Determine if an index is of the NumPy basic type. """Determine if an index is of the NumPy basic type.
XXX: This only checks a single index, so an integers is *not* considered a XXX: This only checks a single index, so an integer is *not* considered a
basic index, because--depending on the other indices its used with--an basic index, because--depending on the other indices its used with--an
integer can indicate advanced indexing. integer can indicate advanced indexing.
......
...@@ -781,6 +781,10 @@ bscalar = TensorType("int8", ()) ...@@ -781,6 +781,10 @@ bscalar = TensorType("int8", ())
wscalar = TensorType("int16", ()) wscalar = TensorType("int16", ())
iscalar = TensorType("int32", ()) iscalar = TensorType("int32", ())
lscalar = TensorType("int64", ()) lscalar = TensorType("int64", ())
ubscalar = TensorType("uint8", ())
uwscalar = TensorType("uint16", ())
uiscalar = TensorType("uint32", ())
ulscalar = TensorType("uint64", ())
def scalar(name=None, dtype=None): def scalar(name=None, dtype=None):
......
...@@ -515,13 +515,13 @@ class _tensor_py_operators: ...@@ -515,13 +515,13 @@ class _tensor_py_operators:
isinstance(val, np.ndarray) and val.size == 0 isinstance(val, np.ndarray) and val.size == 0
) )
# Force input to be int64 datatype if input is an empty list or tuple # Force input to be an int datatype if input is an empty list or tuple
# Else leave it as is if it is a real number # Else leave it as is if it is a real number
# Convert python literals to aesara constants # Convert python literals to aesara constants
args = tuple( args = tuple(
[ [
at.subtensor.as_index_constant( at.subtensor.as_index_constant(
np.array(inp, dtype=np.int64) if is_empty_array(inp) else inp np.array(inp, dtype=np.uint8) if is_empty_array(inp) else inp
) )
for inp in args for inp in args
] ]
......
...@@ -2615,3 +2615,8 @@ def test_index_vars_to_types(): ...@@ -2615,3 +2615,8 @@ def test_index_vars_to_types():
res = index_vars_to_types(iscalar) res = index_vars_to_types(iscalar)
assert isinstance(res, scal.ScalarType) assert isinstance(res, scal.ScalarType)
x = scal.constant(1, dtype=np.uint8)
assert isinstance(x.type, scal.ScalarType)
res = index_vars_to_types(x)
assert res == x.type
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论