提交 0845fa48 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix NoneConst handling in TensorVariable.__getitem__

上级 3500fec8
......@@ -15,6 +15,7 @@ from aesara.scalar import ComplexError, IntegerDivisionError
from aesara.tensor import _get_vector_length, as_tensor_variable
from aesara.tensor.exceptions import AdvancedIndexingError
from aesara.tensor.type import TensorType
from aesara.tensor.type_other import NoneConst
from aesara.tensor.utils import hash_from_ndarray
......@@ -466,7 +467,7 @@ class _tensor_py_operators:
ellipses = []
index_dim_count = 0
for i, arg in enumerate(args):
if arg is np.newaxis:
if arg is np.newaxis or arg is NoneConst:
# no increase in index_dim_count
pass
elif arg is Ellipsis:
......@@ -537,7 +538,7 @@ class _tensor_py_operators:
advanced = True
break
if arg is not np.newaxis:
if arg is not np.newaxis and arg is not NoneConst:
try:
at.subtensor.index_vars_to_types(arg)
except AdvancedIndexingError:
......@@ -549,7 +550,7 @@ class _tensor_py_operators:
if advanced:
return at.subtensor.advanced_subtensor(self, *args)
else:
if np.newaxis in args:
if np.newaxis in args or NoneConst in args:
# `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new
# broadcastable dimension at this location". Since Aesara adds
# new broadcastable dimensions via the `DimShuffle` `Op`, the
......@@ -561,7 +562,7 @@ class _tensor_py_operators:
pattern = []
new_args = []
for arg in args:
if arg == np.newaxis:
if arg is np.newaxis or arg is NoneConst:
pattern.append("x")
new_args.append(slice(None, None, None))
else:
......@@ -579,9 +580,9 @@ class _tensor_py_operators:
# with some symbolic variable.
if not (
isinstance(arg, slice)
and arg.start is None
and arg.stop is None
and arg.step is None
and (arg.start is None or arg.start is NoneConst)
and (arg.stop is None or arg.stop is NoneConst)
and (arg.step is None or arg.step is NoneConst)
):
full_slices = False
if full_slices:
......
......@@ -26,7 +26,7 @@ from aesara.tensor.type import (
scalar,
tensor3,
)
from aesara.tensor.type_other import MakeSlice
from aesara.tensor.type_other import MakeSlice, NoneConst
from aesara.tensor.var import (
DenseTensorConstant,
DenseTensorVariable,
......@@ -221,6 +221,7 @@ def test_print_constant():
[
(tensor3(), (np.newaxis, slice(None), np.newaxis), ("x", 0, "x", 1, 2)),
(cscalar(), (np.newaxis,), ("x",)),
(cscalar(), (NoneConst,), ("x",)),
(matrix(), (np.newaxis,), ("x", 0, 1)),
(matrix(), (np.newaxis, np.newaxis), ("x", "x", 0, 1)),
(matrix(), (np.newaxis, slice(None)), ("x", 0, 1)),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论