提交 82c2f79d authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Add direct tests for _tensor_py_operators.__getitem__ in test_var.py

上级 7bccd317
......@@ -7,14 +7,15 @@ import theano.tensor as tt
from numpy.testing import assert_equal, assert_string_equal
from theano.tensor import (
from theano.tensor.var import TensorConstant
from theano.tensor.subtensor import (
Subtensor,
AdvancedSubtensor,
AdvancedBooleanSubtensor,
AdvancedSubtensor1,
IncSubtensor,
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
)
from theano.tensor.elemwise import DimShuffle
from theano.tensor.type_other import MakeSlice
import tests.unittest_tools as utt
......@@ -79,37 +80,103 @@ def test_copy():
assert_string_equal(y.name, "y")
def test_None_dimShuffle_replace():
# tests replacing None usage in subtensor with dimshuffle
#
# tests whenever None is used in subtensor to reshape a variable, it is
# replaced by dimshuffle. If the replacement is done properly, Subtensor op
# (or any of its variants) should not be used anymore.
def test__getitem__Subtensor():
# Make sure we get `Subtensor`s for basic indexing operations
x = tt.matrix("x")
i = tt.iscalar("i")
x = tt.dmatrix("x")
y = x[:, None, :]
f = theano.function([x], y)
for elem in f.maker.fgraph.toposort():
assert type(elem.op) not in [
Subtensor,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
]
x = tt.tensor3("x")
y1 = x[:, :, None, :]
y2 = x[None, :, :, None, :]
y3 = x[:, :, None, :, None, None]
f = theano.function([x], [y1, y2, y3])
for elem in f.maker.fgraph.toposort():
assert type(elem.op) not in [
Subtensor,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
]
z = x[i]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == Subtensor
# This should ultimately do nothing (i.e. just return `x`)
z = x[()]
assert len(z.owner.op.idx_list) == 0
# assert z is x
# This is a poorly placed optimization that produces a `DimShuffle`
# It lands in the `full_slices` condition in
# `_tensor_py_operators.__getitem__`
z = x[..., None]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert all(op_type == DimShuffle for op_type in op_types)
z = x[None, :, None, :]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert all(op_type == DimShuffle for op_type in op_types)
# This one lands in the non-`full_slices` condition in
# `_tensor_py_operators.__getitem__`
z = x[:i, :, None]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[1:] == [DimShuffle, Subtensor]
z = x[:]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == Subtensor
z = x[..., :]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == Subtensor
z = x[..., i, :]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == Subtensor
def test__getitem__AdvancedBooleanSubtensor():
# Make sure we get `AdvancedBooleanSubtensor`s for basic indexing operations
x = tt.matrix("x")
i = tt.type.TensorType("bool", (False, False))("i")
z = x[i]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedBooleanSubtensor
i = tt.type.TensorType("bool", (False,))("i")
z = x[:, i]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedBooleanSubtensor
i = tt.type.TensorType("bool", (False,))("i")
z = x[..., i]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedBooleanSubtensor
with pytest.raises(TypeError):
z = x[[True, False], i]
z = x[tt.ivector("b"), i]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedBooleanSubtensor
def test__getitem__AdvancedSubtensor():
# Make sure we get `AdvancedSubtensor`s for basic indexing operations
x = tt.matrix("x")
i = tt.ivector("i")
# This is a `__getitem__` call that's redirected to `_tensor_py_operators.take`
z = x[i]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedSubtensor1
# This should index nothing (i.e. return an empty copy of `x`)
# We check that the index is empty
z = x[[]]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types == [AdvancedSubtensor1]
assert isinstance(z.owner.inputs[1], TensorConstant)
# This is also a `__getitem__` call that's redirected to `_tensor_py_operators.take`
z = x[:, i]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types == [DimShuffle, AdvancedSubtensor1, DimShuffle]
z = x[..., i, None]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types == [MakeSlice, AdvancedSubtensor]
z = x[i, None]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedSubtensor
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论