提交 73d798ab authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Add general shape inference for all types of non-boolean indexing

上级 b210efbc
import logging
import sys
import pytest
import numpy as np
import theano
import theano.scalar as scal
import theano.tensor as tensor
from six import StringIO
from numpy.testing import assert_array_equal
from theano import config
from six import StringIO
from theano import config, change_flags
from theano.compile import DeepCopyOp
from theano.gof.op import get_test_value
from theano.gof.toolbox import is_same_graph
from theano.tensor import (
_shared,
......@@ -32,6 +38,8 @@ from theano.tensor import (
)
from theano.tensor.basic import DimShuffle
from theano.tensor.subtensor import (
basic_shape,
indexed_result_shape,
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
......@@ -45,7 +53,7 @@ from theano.tensor.subtensor import (
inc_subtensor,
set_subtensor,
)
from theano import change_flags
from theano.tensor.type_other import make_slice
from tests import unittest_tools as utt
from tests.tensor.test_basic import inplace_func, rand, randint_ranged
......@@ -1815,8 +1823,7 @@ class TestAdvancedSubtensor:
],
), aval
def test_advanced_indexing(self):
# tests advanced indexing in Theano for 2D and 3D tensors
def test_2d_3d_tensors(self):
rng = np.random.RandomState(utt.fetch_seed())
a = rng.uniform(size=(3, 3))
b = theano.shared(a)
......@@ -1920,9 +1927,7 @@ class TestAdvancedSubtensor:
class TestInferShape(utt.InferShapeTester):
@pytest.mark.slow
def test_infer_shape(self):
# IncSubtensor
def test_IncSubtensor(self):
admat = dmatrix()
bdmat = dmatrix()
advec = dvector()
......@@ -2044,7 +2049,7 @@ class TestInferShape(utt.InferShapeTester):
IncSubtensor,
)
# AdvancedIncSubtensor1
def test_AdvancedIncSubtensor1(self):
admat = dmatrix()
bdmat = dmatrix()
advec = dvector()
......@@ -2074,6 +2079,7 @@ class TestInferShape(utt.InferShapeTester):
AdvancedIncSubtensor1,
)
adtens4 = dtensor4()
bdtens4 = dtensor4()
adtens4_val = rand(4, 3, 2, 5)
aivec_val = [2, 3]
......@@ -2152,7 +2158,10 @@ class TestInferShape(utt.InferShapeTester):
AdvancedIncSubtensor1,
)
# AdvancedIncSubtensor
def test_AdvancedIncSubtensor(self):
admat = dmatrix()
advec = dvector()
admat_val = rand(5, 4)
aivec_val = [1, 3, 2]
bivec_val = [0, 3, 3]
advec_val = [23, 24, 25]
......@@ -2163,7 +2172,7 @@ class TestInferShape(utt.InferShapeTester):
AdvancedIncSubtensor,
)
def test_adv_sub(self):
def test_AdvancedSubtensor(self):
admat = dmatrix()
aivec = lvector()
bivec = lvector()
......@@ -2177,23 +2186,20 @@ class TestInferShape(utt.InferShapeTester):
[admat_val, aivec_val, bivec_val],
AdvancedSubtensor,
)
# Test case that aren't implemented, but make sure they do not crash.
self._compile_and_check(
[admat, aivec],
[admat[aivec, 1:3]],
[admat_val, aivec_val],
AdvancedSubtensor,
check_topo=False,
)
self._compile_and_check(
[admat, aivec],
[admat[1:3, aivec]],
[admat_val, aivec_val],
AdvancedSubtensor,
check_topo=False,
)
def test_boolean(self):
def test_AdvancedBooleanSubtensor(self):
n = dmatrix()
n_val = np.arange(6).reshape((2, 3))
......@@ -2212,3 +2218,97 @@ class TestInferShape(utt.InferShapeTester):
tensor.AdvancedBooleanSubtensor,
check_topo=False,
)
@change_flags(compute_test_value="raise")
def test_basic_shape():
test_shape = (5, 4)
test_indices = (make_slice(1, 3, None),)
res = basic_shape(test_shape, test_indices)
assert get_test_value(res) == (2,)
@change_flags(compute_test_value="raise")
def test_indexed_result_shape():
_test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True]))
test_shape = (5, 6, 7, 8)
test_array = np.arange(np.prod(test_shape)).reshape(test_shape)
def idx_as_tensor(x):
if isinstance(x, (slice, type(None))):
return x
else:
return tensor.as_tensor(x)
def bcast_shape_tuple(x):
if not hasattr(x, "shape"):
return x
return tuple(
s if not bcast else 1 for s, bcast in zip(tuple(x.shape), x.broadcastable)
)
def compare_index_shapes(test_array, test_idx):
res = indexed_result_shape(
tensor.as_tensor(test_array).shape, [idx_as_tensor(i) for i in test_idx]
)
exp_res = test_array[test_idx].shape
assert np.array_equal(tuple(get_test_value(r) for r in res), exp_res)
# Test shape-only version
res = indexed_result_shape(
tensor.as_tensor(test_array).shape,
[bcast_shape_tuple(idx_as_tensor(i)) for i in test_idx],
indices_are_shapes=True,
)
exp_res = test_array[test_idx].shape
assert np.array_equal(tuple(get_test_value(r) for r in res), exp_res)
# Simple basic indices
test_idx = (slice(None, None),)
compare_index_shapes(test_array, test_idx)
# Advanced indices
test_idx = (2,)
compare_index_shapes(test_array, test_idx)
test_idx = _test_idx[:1]
compare_index_shapes(test_array, test_idx)
test_idx = _test_idx[:2]
compare_index_shapes(test_array, test_idx)
# A Mix of advanced and basic indices
test_idx = _test_idx[:2] + (slice(None, None),)
compare_index_shapes(test_array, test_idx)
test_idx = (slice(None, None),) + _test_idx[1:]
compare_index_shapes(test_array, test_idx)
test_idx = (slice(None, None), None) + _test_idx[1:2]
compare_index_shapes(test_array, test_idx)
test_idx = (np.array(1), slice(None, None), None)
compare_index_shapes(test_array, test_idx)
test_idx = (slice(None, None), None, np.array(1))
compare_index_shapes(test_array, test_idx)
test_idx = _test_idx[:1] + (slice(None, None),) + _test_idx[1:2]
compare_index_shapes(test_array, test_idx)
test_idx = (
_test_idx[:1] + (slice(None, None),) + _test_idx[1:2] + (slice(None, None),)
)
compare_index_shapes(test_array, test_idx)
test_idx = _test_idx[:1] + (None,) + _test_idx[1:2]
compare_index_shapes(test_array, test_idx)
test_shape = (5, 4)
test_array = np.arange(np.prod(test_shape)).reshape(test_shape)
test_idx = ([1, 3, 2], slice(1, 3))
compare_index_shapes(test_array, test_idx)
test_idx = (slice(1, 3), [1, 3, 2])
compare_index_shapes(test_array, test_idx)
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论