提交 9e7aea88 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix TensorVariable.__getitem__ np.newaxis cases

Closes #633
上级 e08dac2a
......@@ -184,8 +184,8 @@ class DimShuffle(ExternalCOp):
else:
# we cannot drop non-broadcastable dimensions
raise ValueError(
"You cannot drop a non-broadcastable dimension.",
(input_broadcastable, new_order),
"You cannot drop a non-broadcastable dimension:",
f" {input_broadcastable}, {new_order}",
)
# this is the list of the original dimensions that we keep
......
......@@ -564,6 +564,9 @@ class _tensor_py_operators:
pattern.append(counter)
counter += 1
new_args.append(arg)
pattern.extend(list(range(counter, self.ndim)))
view = self.dimshuffle(pattern)
full_slices = True
for arg in new_args:
......
......@@ -38,7 +38,6 @@ from aesara.tensor.subtensor import (
from aesara.tensor.type import (
TensorType,
col,
cscalar,
ctensor3,
dmatrix,
dscalar,
......@@ -528,45 +527,6 @@ class TestSubtensor(utt.OptimizationTestMixin):
with pytest.raises(TypeError):
test_array.__getitem__(([0, 1], [0, aesara.shared(True)]))
def test_newaxis(self):
# newaxis support comes from logic in the __getitem__ of TensorType
# Variables, which currently inserts dimshuffle to get the right number
# of dimensions, and adjusts the slice tuple accordingly.
#
# So testing is done via square-bracket notation rather than direct
# interaction with the Subtensor Op (which has no support of its own for
# newaxis).
newaxis = np.newaxis
n = self.shared(np.arange(24, dtype=self.dtype).reshape((2, 3, 4)))
assert n.ndim == 3
n4 = n[newaxis, :, :, :]
assert n4.broadcastable == (True, False, False, False), n4
n4 = n[:, newaxis, :, :]
assert n4.broadcastable == (False, True, False, False), n4
n4 = n[:, :, newaxis, :]
assert n4.broadcastable == (False, False, True, False), n4
n4 = n[:, :, :, newaxis]
assert n4.broadcastable == (False, False, False, True), n4
n3 = n.flatten()[newaxis, :, newaxis]
assert n3.broadcastable == (True, False, True), n3
s = cscalar()
s1 = s[newaxis]
assert s1.broadcastable == (True,), s1
vs1, vn3, vn4 = aesara.function([s], [s1, n3, n4], mode=self.mode)(-2.0)
assert np.all(vs1 == [-2.0])
assert np.all(vn3 == np.arange(24)[newaxis, :, newaxis])
assert np.all(vn4 == np.arange(24).reshape((2, 3, 4))[:, :, :, newaxis])
def test_grad_1d(self):
subi = 0
data = np.asarray(random(2, 3), dtype=self.dtype)
......
......@@ -8,7 +8,16 @@ from aesara.graph.basic import equal_computations
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.math import dot
from aesara.tensor.subtensor import AdvancedSubtensor, Subtensor
from aesara.tensor.type import TensorType, dmatrix, dvector, iscalar, ivector, matrix
from aesara.tensor.type import (
TensorType,
cscalar,
dmatrix,
dvector,
iscalar,
ivector,
matrix,
tensor3,
)
from aesara.tensor.type_other import MakeSlice
from aesara.tensor.var import TensorConstant
......@@ -185,3 +194,29 @@ def test__getitem__AdvancedSubtensor():
z = x[i, None]
op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedSubtensor
@pytest.mark.parametrize(
"x, indices, new_order",
[
(tensor3(), (np.newaxis, slice(None), np.newaxis), ("x", 0, "x", 1, 2)),
(cscalar(), (np.newaxis,), ("x",)),
(matrix(), (np.newaxis,), ("x", 0, 1)),
(matrix(), (np.newaxis, np.newaxis), ("x", "x", 0, 1)),
(matrix(), (np.newaxis, slice(None)), ("x", 0, 1)),
(matrix(), (np.newaxis, slice(None), slice(None)), ("x", 0, 1)),
(matrix(), (np.newaxis, np.newaxis, slice(None)), ("x", "x", 0, 1)),
(matrix(), (slice(None), np.newaxis), (0, "x", 1)),
(matrix(), (slice(None), slice(None), np.newaxis), (0, 1, "x")),
(
matrix(),
(np.newaxis, slice(None), np.newaxis, slice(None), np.newaxis),
("x", 0, "x", 1, "x"),
),
],
)
def test__getitem__newaxis(x, indices, new_order):
res = x[indices]
assert isinstance(res.owner.op, DimShuffle)
assert res.broadcastable == tuple(i == "x" for i in new_order)
assert res.owner.op.new_order == new_order
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论