提交 3500fec8 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Update support for unsigned integers in aesara.tensor.subtensor

上级 5b935bc6
......@@ -41,6 +41,10 @@ from aesara.tensor.type import (
iscalar,
lscalar,
tensor,
ubscalar,
uiscalar,
ulscalar,
uwscalar,
wscalar,
zscalar,
)
......@@ -50,12 +54,25 @@ from aesara.tensor.type_other import NoneConst, NoneTypeT, SliceType, make_slice
_logger = logging.getLogger("aesara.tensor.subtensor")
invalid_scal_types = (aes.float64, aes.float32, aes.float16)
scal_types = (aes.int64, aes.int32, aes.int16, aes.int8)
scal_types = (
aes.int64,
aes.int32,
aes.int16,
aes.int8,
aes.uint64,
aes.uint32,
aes.uint16,
aes.uint8,
)
tensor_types = (
lscalar,
iscalar,
wscalar,
bscalar,
ulscalar,
uiscalar,
uwscalar,
ubscalar,
)
invalid_tensor_types = (
fscalar,
......@@ -376,7 +393,7 @@ def slice_len(slc, n):
def is_basic_idx(idx):
"""Determine if an index is of the NumPy basic type.
XXX: This only checks a single index, so an integers is *not* considered a
XXX: This only checks a single index, so an integer is *not* considered a
basic index, because--depending on the other indices its used with--an
integer can indicate advanced indexing.
......
......@@ -781,6 +781,10 @@ bscalar = TensorType("int8", ())
wscalar = TensorType("int16", ())
iscalar = TensorType("int32", ())
lscalar = TensorType("int64", ())
ubscalar = TensorType("uint8", ())
uwscalar = TensorType("uint16", ())
uiscalar = TensorType("uint32", ())
ulscalar = TensorType("uint64", ())
def scalar(name=None, dtype=None):
......
......@@ -515,13 +515,13 @@ class _tensor_py_operators:
isinstance(val, np.ndarray) and val.size == 0
)
# Force input to be int64 datatype if input is an empty list or tuple
# Force input to be an int datatype if input is an empty list or tuple
# Else leave it as is if it is a real number
# Convert python literals to aesara constants
args = tuple(
[
at.subtensor.as_index_constant(
np.array(inp, dtype=np.int64) if is_empty_array(inp) else inp
np.array(inp, dtype=np.uint8) if is_empty_array(inp) else inp
)
for inp in args
]
......
......@@ -2615,3 +2615,8 @@ def test_index_vars_to_types():
res = index_vars_to_types(iscalar)
assert isinstance(res, scal.ScalarType)
x = scal.constant(1, dtype=np.uint8)
assert isinstance(x.type, scal.ScalarType)
res = index_vars_to_types(x)
assert res == x.type
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论