提交 0845fa48 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix NoneConst handling in TensorVariable.__getitem__

上级 3500fec8
...@@ -15,6 +15,7 @@ from aesara.scalar import ComplexError, IntegerDivisionError ...@@ -15,6 +15,7 @@ from aesara.scalar import ComplexError, IntegerDivisionError
from aesara.tensor import _get_vector_length, as_tensor_variable from aesara.tensor import _get_vector_length, as_tensor_variable
from aesara.tensor.exceptions import AdvancedIndexingError from aesara.tensor.exceptions import AdvancedIndexingError
from aesara.tensor.type import TensorType from aesara.tensor.type import TensorType
from aesara.tensor.type_other import NoneConst
from aesara.tensor.utils import hash_from_ndarray from aesara.tensor.utils import hash_from_ndarray
...@@ -466,7 +467,7 @@ class _tensor_py_operators: ...@@ -466,7 +467,7 @@ class _tensor_py_operators:
ellipses = [] ellipses = []
index_dim_count = 0 index_dim_count = 0
for i, arg in enumerate(args): for i, arg in enumerate(args):
if arg is np.newaxis: if arg is np.newaxis or arg is NoneConst:
# no increase in index_dim_count # no increase in index_dim_count
pass pass
elif arg is Ellipsis: elif arg is Ellipsis:
...@@ -537,7 +538,7 @@ class _tensor_py_operators: ...@@ -537,7 +538,7 @@ class _tensor_py_operators:
advanced = True advanced = True
break break
if arg is not np.newaxis: if arg is not np.newaxis and arg is not NoneConst:
try: try:
at.subtensor.index_vars_to_types(arg) at.subtensor.index_vars_to_types(arg)
except AdvancedIndexingError: except AdvancedIndexingError:
...@@ -549,7 +550,7 @@ class _tensor_py_operators: ...@@ -549,7 +550,7 @@ class _tensor_py_operators:
if advanced: if advanced:
return at.subtensor.advanced_subtensor(self, *args) return at.subtensor.advanced_subtensor(self, *args)
else: else:
if np.newaxis in args: if np.newaxis in args or NoneConst in args:
# `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new # `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new
# broadcastable dimension at this location". Since Aesara adds # broadcastable dimension at this location". Since Aesara adds
# new broadcastable dimensions via the `DimShuffle` `Op`, the # new broadcastable dimensions via the `DimShuffle` `Op`, the
...@@ -561,7 +562,7 @@ class _tensor_py_operators: ...@@ -561,7 +562,7 @@ class _tensor_py_operators:
pattern = [] pattern = []
new_args = [] new_args = []
for arg in args: for arg in args:
if arg == np.newaxis: if arg is np.newaxis or arg is NoneConst:
pattern.append("x") pattern.append("x")
new_args.append(slice(None, None, None)) new_args.append(slice(None, None, None))
else: else:
...@@ -579,9 +580,9 @@ class _tensor_py_operators: ...@@ -579,9 +580,9 @@ class _tensor_py_operators:
# with some symbolic variable. # with some symbolic variable.
if not ( if not (
isinstance(arg, slice) isinstance(arg, slice)
and arg.start is None and (arg.start is None or arg.start is NoneConst)
and arg.stop is None and (arg.stop is None or arg.stop is NoneConst)
and arg.step is None and (arg.step is None or arg.step is NoneConst)
): ):
full_slices = False full_slices = False
if full_slices: if full_slices:
......
...@@ -26,7 +26,7 @@ from aesara.tensor.type import ( ...@@ -26,7 +26,7 @@ from aesara.tensor.type import (
scalar, scalar,
tensor3, tensor3,
) )
from aesara.tensor.type_other import MakeSlice from aesara.tensor.type_other import MakeSlice, NoneConst
from aesara.tensor.var import ( from aesara.tensor.var import (
DenseTensorConstant, DenseTensorConstant,
DenseTensorVariable, DenseTensorVariable,
...@@ -221,6 +221,7 @@ def test_print_constant(): ...@@ -221,6 +221,7 @@ def test_print_constant():
[ [
(tensor3(), (np.newaxis, slice(None), np.newaxis), ("x", 0, "x", 1, 2)), (tensor3(), (np.newaxis, slice(None), np.newaxis), ("x", 0, "x", 1, 2)),
(cscalar(), (np.newaxis,), ("x",)), (cscalar(), (np.newaxis,), ("x",)),
(cscalar(), (NoneConst,), ("x",)),
(matrix(), (np.newaxis,), ("x", 0, 1)), (matrix(), (np.newaxis,), ("x", 0, 1)),
(matrix(), (np.newaxis, np.newaxis), ("x", "x", 0, 1)), (matrix(), (np.newaxis, np.newaxis), ("x", "x", 0, 1)),
(matrix(), (np.newaxis, slice(None)), ("x", 0, 1)), (matrix(), (np.newaxis, slice(None)), ("x", 0, 1)),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论