Commit 9e7aea88, authored by Brandon T. Willard, committed by Brandon T. Willard

Fix TensorVariable.__getitem__ np.newaxis cases

Closes #633
Parent: e08dac2a
...@@ -184,8 +184,8 @@ class DimShuffle(ExternalCOp): ...@@ -184,8 +184,8 @@ class DimShuffle(ExternalCOp):
else: else:
# we cannot drop non-broadcastable dimensions # we cannot drop non-broadcastable dimensions
raise ValueError( raise ValueError(
"You cannot drop a non-broadcastable dimension.", "You cannot drop a non-broadcastable dimension:",
(input_broadcastable, new_order), f" {input_broadcastable}, {new_order}",
) )
# this is the list of the original dimensions that we keep # this is the list of the original dimensions that we keep
......
...@@ -564,6 +564,9 @@ class _tensor_py_operators: ...@@ -564,6 +564,9 @@ class _tensor_py_operators:
pattern.append(counter) pattern.append(counter)
counter += 1 counter += 1
new_args.append(arg) new_args.append(arg)
pattern.extend(list(range(counter, self.ndim)))
view = self.dimshuffle(pattern) view = self.dimshuffle(pattern)
full_slices = True full_slices = True
for arg in new_args: for arg in new_args:
......
...@@ -38,7 +38,6 @@ from aesara.tensor.subtensor import ( ...@@ -38,7 +38,6 @@ from aesara.tensor.subtensor import (
from aesara.tensor.type import ( from aesara.tensor.type import (
TensorType, TensorType,
col, col,
cscalar,
ctensor3, ctensor3,
dmatrix, dmatrix,
dscalar, dscalar,
...@@ -528,45 +527,6 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -528,45 +527,6 @@ class TestSubtensor(utt.OptimizationTestMixin):
with pytest.raises(TypeError): with pytest.raises(TypeError):
test_array.__getitem__(([0, 1], [0, aesara.shared(True)])) test_array.__getitem__(([0, 1], [0, aesara.shared(True)]))
def test_newaxis(self):
    """Check ``np.newaxis`` handling in square-bracket indexing.

    ``newaxis`` support comes from logic in ``__getitem__`` of ``TensorType``
    variables, which inserts a ``dimshuffle`` to get the right number of
    dimensions and adjusts the slice tuple accordingly.  The ``Subtensor``
    ``Op`` itself has no ``newaxis`` support, so everything here goes through
    square-bracket notation.
    """
    data = np.arange(24, dtype=self.dtype)
    n = self.shared(data.reshape((2, 3, 4)))
    assert n.ndim == 3

    # Insert a new axis at each of the four possible positions; exactly
    # that position must become broadcastable.
    for axis in range(4):
        index = [slice(None)] * 3
        index.insert(axis, np.newaxis)
        n4 = n[tuple(index)]
        expected = tuple(pos == axis for pos in range(4))
        assert n4.broadcastable == expected, n4

    # Multiple newaxis entries in one index expression.
    n3 = n.flatten()[np.newaxis, :, np.newaxis]
    assert n3.broadcastable == (True, False, True), n3

    # newaxis on a 0-d variable.
    s = cscalar()
    s1 = s[np.newaxis]
    assert s1.broadcastable == (True,), s1

    # Compile and evaluate; results must match NumPy's newaxis semantics.
    # NOTE: n4 here is the last loop value, i.e. n[:, :, :, np.newaxis].
    vs1, vn3, vn4 = aesara.function([s], [s1, n3, n4], mode=self.mode)(-2.0)
    assert np.all(vs1 == [-2.0])
    assert np.all(vn3 == data[np.newaxis, :, np.newaxis])
    assert np.all(vn4 == data.reshape((2, 3, 4))[:, :, :, np.newaxis])
def test_grad_1d(self): def test_grad_1d(self):
subi = 0 subi = 0
data = np.asarray(random(2, 3), dtype=self.dtype) data = np.asarray(random(2, 3), dtype=self.dtype)
......
...@@ -8,7 +8,16 @@ from aesara.graph.basic import equal_computations ...@@ -8,7 +8,16 @@ from aesara.graph.basic import equal_computations
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.math import dot from aesara.tensor.math import dot
from aesara.tensor.subtensor import AdvancedSubtensor, Subtensor from aesara.tensor.subtensor import AdvancedSubtensor, Subtensor
from aesara.tensor.type import TensorType, dmatrix, dvector, iscalar, ivector, matrix from aesara.tensor.type import (
TensorType,
cscalar,
dmatrix,
dvector,
iscalar,
ivector,
matrix,
tensor3,
)
from aesara.tensor.type_other import MakeSlice from aesara.tensor.type_other import MakeSlice
from aesara.tensor.var import TensorConstant from aesara.tensor.var import TensorConstant
...@@ -185,3 +194,29 @@ def test__getitem__AdvancedSubtensor(): ...@@ -185,3 +194,29 @@ def test__getitem__AdvancedSubtensor():
z = x[i, None] z = x[i, None]
op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])] op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedSubtensor assert op_types[-1] == AdvancedSubtensor
@pytest.mark.parametrize(
    "x, indices, new_order",
    [
        (tensor3(), (np.newaxis, slice(None), np.newaxis), ("x", 0, "x", 1, 2)),
        (cscalar(), (np.newaxis,), ("x",)),
        (matrix(), (np.newaxis,), ("x", 0, 1)),
        (matrix(), (np.newaxis, np.newaxis), ("x", "x", 0, 1)),
        (matrix(), (np.newaxis, slice(None)), ("x", 0, 1)),
        (matrix(), (np.newaxis, slice(None), slice(None)), ("x", 0, 1)),
        (matrix(), (np.newaxis, np.newaxis, slice(None)), ("x", "x", 0, 1)),
        (matrix(), (slice(None), np.newaxis), (0, "x", 1)),
        (matrix(), (slice(None), slice(None), np.newaxis), (0, 1, "x")),
        (
            matrix(),
            (np.newaxis, slice(None), np.newaxis, slice(None), np.newaxis),
            ("x", 0, "x", 1, "x"),
        ),
    ],
)
def test__getitem__newaxis(x, indices, new_order):
    """Indexing with ``np.newaxis`` should reduce to a single ``DimShuffle``
    whose ``new_order`` matches the expected axis pattern."""
    result = x[indices]
    applied_op = result.owner.op
    assert isinstance(applied_op, DimShuffle)
    assert applied_op.new_order == new_order
    # Broadcastable dimensions are exactly the inserted ("x") axes.
    assert result.broadcastable == tuple(axis == "x" for axis in new_order)
Markdown is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment