提交 73d798ab authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Add general shape inference for all types of non-boolean indexing

上级 b210efbc
import logging import logging
import sys import sys
import pytest import pytest
import numpy as np import numpy as np
import theano import theano
import theano.scalar as scal import theano.scalar as scal
import theano.tensor as tensor import theano.tensor as tensor
from six import StringIO
from numpy.testing import assert_array_equal from numpy.testing import assert_array_equal
from theano import config
from six import StringIO
from theano import config, change_flags
from theano.compile import DeepCopyOp from theano.compile import DeepCopyOp
from theano.gof.op import get_test_value
from theano.gof.toolbox import is_same_graph from theano.gof.toolbox import is_same_graph
from theano.tensor import ( from theano.tensor import (
_shared, _shared,
...@@ -32,6 +38,8 @@ from theano.tensor import ( ...@@ -32,6 +38,8 @@ from theano.tensor import (
) )
from theano.tensor.basic import DimShuffle from theano.tensor.basic import DimShuffle
from theano.tensor.subtensor import ( from theano.tensor.subtensor import (
basic_shape,
indexed_result_shape,
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
AdvancedSubtensor, AdvancedSubtensor,
...@@ -45,7 +53,7 @@ from theano.tensor.subtensor import ( ...@@ -45,7 +53,7 @@ from theano.tensor.subtensor import (
inc_subtensor, inc_subtensor,
set_subtensor, set_subtensor,
) )
from theano import change_flags from theano.tensor.type_other import make_slice
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.tensor.test_basic import inplace_func, rand, randint_ranged from tests.tensor.test_basic import inplace_func, rand, randint_ranged
...@@ -1815,8 +1823,7 @@ class TestAdvancedSubtensor: ...@@ -1815,8 +1823,7 @@ class TestAdvancedSubtensor:
], ],
), aval ), aval
def test_advanced_indexing(self): def test_2d_3d_tensors(self):
# tests advanced indexing in Theano for 2D and 3D tensors
rng = np.random.RandomState(utt.fetch_seed()) rng = np.random.RandomState(utt.fetch_seed())
a = rng.uniform(size=(3, 3)) a = rng.uniform(size=(3, 3))
b = theano.shared(a) b = theano.shared(a)
...@@ -1920,9 +1927,7 @@ class TestAdvancedSubtensor: ...@@ -1920,9 +1927,7 @@ class TestAdvancedSubtensor:
class TestInferShape(utt.InferShapeTester): class TestInferShape(utt.InferShapeTester):
@pytest.mark.slow def test_IncSubtensor(self):
def test_infer_shape(self):
# IncSubtensor
admat = dmatrix() admat = dmatrix()
bdmat = dmatrix() bdmat = dmatrix()
advec = dvector() advec = dvector()
...@@ -2044,7 +2049,7 @@ class TestInferShape(utt.InferShapeTester): ...@@ -2044,7 +2049,7 @@ class TestInferShape(utt.InferShapeTester):
IncSubtensor, IncSubtensor,
) )
# AdvancedIncSubtensor1 def test_AdvancedIncSubtensor1(self):
admat = dmatrix() admat = dmatrix()
bdmat = dmatrix() bdmat = dmatrix()
advec = dvector() advec = dvector()
...@@ -2074,6 +2079,7 @@ class TestInferShape(utt.InferShapeTester): ...@@ -2074,6 +2079,7 @@ class TestInferShape(utt.InferShapeTester):
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
) )
adtens4 = dtensor4()
bdtens4 = dtensor4() bdtens4 = dtensor4()
adtens4_val = rand(4, 3, 2, 5) adtens4_val = rand(4, 3, 2, 5)
aivec_val = [2, 3] aivec_val = [2, 3]
...@@ -2152,7 +2158,10 @@ class TestInferShape(utt.InferShapeTester): ...@@ -2152,7 +2158,10 @@ class TestInferShape(utt.InferShapeTester):
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
) )
# AdvancedIncSubtensor def test_AdvancedIncSubtensor(self):
admat = dmatrix()
advec = dvector()
admat_val = rand(5, 4)
aivec_val = [1, 3, 2] aivec_val = [1, 3, 2]
bivec_val = [0, 3, 3] bivec_val = [0, 3, 3]
advec_val = [23, 24, 25] advec_val = [23, 24, 25]
...@@ -2163,7 +2172,7 @@ class TestInferShape(utt.InferShapeTester): ...@@ -2163,7 +2172,7 @@ class TestInferShape(utt.InferShapeTester):
AdvancedIncSubtensor, AdvancedIncSubtensor,
) )
def test_adv_sub(self): def test_AdvancedSubtensor(self):
admat = dmatrix() admat = dmatrix()
aivec = lvector() aivec = lvector()
bivec = lvector() bivec = lvector()
...@@ -2177,23 +2186,20 @@ class TestInferShape(utt.InferShapeTester): ...@@ -2177,23 +2186,20 @@ class TestInferShape(utt.InferShapeTester):
[admat_val, aivec_val, bivec_val], [admat_val, aivec_val, bivec_val],
AdvancedSubtensor, AdvancedSubtensor,
) )
# Test case that aren't implemented, but make sure they do not crash.
self._compile_and_check( self._compile_and_check(
[admat, aivec], [admat, aivec],
[admat[aivec, 1:3]], [admat[aivec, 1:3]],
[admat_val, aivec_val], [admat_val, aivec_val],
AdvancedSubtensor, AdvancedSubtensor,
check_topo=False,
) )
self._compile_and_check( self._compile_and_check(
[admat, aivec], [admat, aivec],
[admat[1:3, aivec]], [admat[1:3, aivec]],
[admat_val, aivec_val], [admat_val, aivec_val],
AdvancedSubtensor, AdvancedSubtensor,
check_topo=False,
) )
def test_boolean(self): def test_AdvancedBooleanSubtensor(self):
n = dmatrix() n = dmatrix()
n_val = np.arange(6).reshape((2, 3)) n_val = np.arange(6).reshape((2, 3))
...@@ -2212,3 +2218,97 @@ class TestInferShape(utt.InferShapeTester): ...@@ -2212,3 +2218,97 @@ class TestInferShape(utt.InferShapeTester):
tensor.AdvancedBooleanSubtensor, tensor.AdvancedBooleanSubtensor,
check_topo=False, check_topo=False,
) )
@change_flags(compute_test_value="raise")
def test_basic_shape():
    """Check that `basic_shape` computes the shape of a basic (slice) index.

    Slicing ``[1:3]`` of a dimension of length 5 selects 2 elements, so the
    resulting (symbolic) shape should evaluate to ``(2,)``.
    """
    test_shape = (5, 4)
    # A symbolic slice equivalent to `slice(1, 3, None)`.
    test_indices = (make_slice(1, 3, None),)
    res = basic_shape(test_shape, test_indices)
    # `compute_test_value="raise"` guarantees test values exist here.
    assert get_test_value(res) == (2,)
@change_flags(compute_test_value="raise")
def test_indexed_result_shape():
    """Compare `indexed_result_shape` against NumPy's actual indexing results.

    For a range of basic, advanced, and mixed index combinations, the
    symbolic shape computed by `indexed_result_shape` must match the shape
    NumPy produces for ``test_array[test_idx]``.
    """
    # Three broadcastable boolean-derived advanced indices (from `np.ix_`).
    _test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True]))

    test_shape = (5, 6, 7, 8)
    test_array = np.arange(np.prod(test_shape)).reshape(test_shape)

    def idx_as_tensor(x):
        # Leave basic indices (slices/None) untouched; lift everything else
        # to a symbolic tensor.
        if isinstance(x, (slice, type(None))):
            return x
        else:
            return tensor.as_tensor(x)

    def bcast_shape_tuple(x):
        # Convert a tensor's shape to a tuple, replacing broadcastable
        # dimensions with the literal `1` (required by the
        # `indices_are_shapes=True` mode below).
        if not hasattr(x, "shape"):
            return x
        return tuple(
            s if not bcast else 1 for s, bcast in zip(tuple(x.shape), x.broadcastable)
        )

    def compare_index_shapes(test_array, test_idx):
        # Compare the symbolic result shape with NumPy's concrete one.
        res = indexed_result_shape(
            tensor.as_tensor(test_array).shape, [idx_as_tensor(i) for i in test_idx]
        )
        exp_res = test_array[test_idx].shape
        assert np.array_equal(tuple(get_test_value(r) for r in res), exp_res)

        # Test shape-only version
        res = indexed_result_shape(
            tensor.as_tensor(test_array).shape,
            [bcast_shape_tuple(idx_as_tensor(i)) for i in test_idx],
            indices_are_shapes=True,
        )
        exp_res = test_array[test_idx].shape
        assert np.array_equal(tuple(get_test_value(r) for r in res), exp_res)

    # Simple basic indices
    test_idx = (slice(None, None),)
    compare_index_shapes(test_array, test_idx)

    # Advanced indices
    test_idx = (2,)
    compare_index_shapes(test_array, test_idx)

    test_idx = _test_idx[:1]
    compare_index_shapes(test_array, test_idx)

    test_idx = _test_idx[:2]
    compare_index_shapes(test_array, test_idx)

    # A Mix of advanced and basic indices
    test_idx = _test_idx[:2] + (slice(None, None),)
    compare_index_shapes(test_array, test_idx)

    test_idx = (slice(None, None),) + _test_idx[1:]
    compare_index_shapes(test_array, test_idx)

    test_idx = (slice(None, None), None) + _test_idx[1:2]
    compare_index_shapes(test_array, test_idx)

    test_idx = (np.array(1), slice(None, None), None)
    compare_index_shapes(test_array, test_idx)

    test_idx = (slice(None, None), None, np.array(1))
    compare_index_shapes(test_array, test_idx)

    test_idx = _test_idx[:1] + (slice(None, None),) + _test_idx[1:2]
    compare_index_shapes(test_array, test_idx)

    test_idx = (
        _test_idx[:1] + (slice(None, None),) + _test_idx[1:2] + (slice(None, None),)
    )
    compare_index_shapes(test_array, test_idx)

    test_idx = _test_idx[:1] + (None,) + _test_idx[1:2]
    compare_index_shapes(test_array, test_idx)

    # Advanced indices separated by a basic index: NumPy moves the advanced
    # result dimensions to the front in this case.
    test_shape = (5, 4)
    test_array = np.arange(np.prod(test_shape)).reshape(test_shape)
    test_idx = ([1, 3, 2], slice(1, 3))
    compare_index_shapes(test_array, test_idx)

    test_idx = (slice(1, 3), [1, 3, 2])
    compare_index_shapes(test_array, test_idx)
import sys import sys
from textwrap import dedent
import warnings import warnings
import logging import logging
import numpy as np import numpy as np
from six import integer_types
import theano import theano
from theano.gradient import DisconnectedType from textwrap import dedent
from theano import gof
from itertools import groupby, chain
from collections.abc import Iterable
from six import integer_types
from theano import gof, scalar as scal, config
from theano.gof import Apply, hashtype, Op, Type, MethodNotDefined, ParamsType from theano.gof import Apply, hashtype, Op, Type, MethodNotDefined, ParamsType
from theano.gradient import DisconnectedType
from theano.printing import pprint from theano.printing import pprint
from theano import scalar as scal
from theano.tensor.basic import alloc
from theano.tensor.basic import ( from theano.tensor.basic import (
alloc,
addbroadcast, addbroadcast,
clip, clip,
get_scalar_constant_value, get_scalar_constant_value,
...@@ -23,17 +26,12 @@ from theano.tensor.basic import ( ...@@ -23,17 +26,12 @@ from theano.tensor.basic import (
NotScalarConstantError, NotScalarConstantError,
) )
from theano.tensor.elemwise import DimShuffle from theano.tensor.elemwise import DimShuffle
from theano.tensor.inc_code import inc_code
from theano.tensor.extra_ops import broadcast_shape
from theano.tensor.type_other import NoneConst, SliceType, NoneTypeT, make_slice from theano.tensor.type_other import NoneConst, SliceType, NoneTypeT, make_slice
from theano import config
from theano.compat import Iterable
from .inc_code import inc_code
_logger = logging.getLogger("theano.tensor.subtensor") _logger = logging.getLogger("theano.tensor.subtensor")
# Do a lazy import of the sparse module
sparse_module_ref = None
class AdvancedIndexingError(TypeError): class AdvancedIndexingError(TypeError):
""" """
...@@ -53,16 +51,8 @@ class AdvancedBooleanIndexingError(TypeError): ...@@ -53,16 +51,8 @@ class AdvancedBooleanIndexingError(TypeError):
pass pass
##########
# Helpful functions to deal with Subtensor and IncSubtensor
##########
def make_constant(args): def make_constant(args):
""" """Convert Python literals to Theano constants in Subtensor arguments."""
Convert python litterals to theano constants in subtensor arguments.
"""
def conv(a): def conv(a):
if a is None: if a is None:
...@@ -72,6 +62,7 @@ def make_constant(args): ...@@ -72,6 +62,7 @@ def make_constant(args):
elif isinstance(a, (integer_types, np.integer)): elif isinstance(a, (integer_types, np.integer)):
return scal.ScalarConstant(scal.int64, a) return scal.ScalarConstant(scal.int64, a)
else: else:
# Use `tensor.scalar_from_tensor`?
return a return a
return tuple(map(conv, args)) return tuple(map(conv, args))
...@@ -112,7 +103,8 @@ def get_idx_list(inputs, idx_list, get_count=False): ...@@ -112,7 +103,8 @@ def get_idx_list(inputs, idx_list, get_count=False):
def get_canonical_form_slice(theslice, length): def get_canonical_form_slice(theslice, length):
""" """Convert slices to canonical form.
Given a slice [start:stop:step] transform it into a canonical form Given a slice [start:stop:step] transform it into a canonical form
that respects the conventions imposed by python and numpy. that respects the conventions imposed by python and numpy.
...@@ -277,6 +269,162 @@ def get_canonical_form_slice(theslice, length): ...@@ -277,6 +269,162 @@ def get_canonical_form_slice(theslice, length):
return value, 1 return value, 1
def range_len(slc):
    """Length of a `range` object.

    Adapted from CPython.

    Builds a symbolic expression for ``len(range(start, stop, step))`` given
    a canonical slice `slc` whose ``start``/``stop``/``step`` may be Python
    integers or symbolic scalars.
    """
    # Local import — presumably to avoid a circular import at module load
    # time; TODO confirm.
    from theano.tensor import switch, and_, lt, gt

    # Lift plain Python integers to symbolic constants so the arithmetic
    # below always builds a graph.
    start, stop, step = make_constant([slc.start, slc.stop, slc.step])
    # CPython's range-length formula, expressed symbolically:
    #   step > 0 and start < stop  ->  1 + (stop - 1 - start) // step
    #   step < 0 and start > stop  ->  1 + (start - 1 - stop) // (-step)
    #   otherwise the range is empty (length 0).
    return switch(
        and_(gt(step, 0), lt(start, stop)),
        1 + (stop - 1 - start) // step,
        switch(
            and_(lt(step, 0), gt(start, stop)),
            1 + (start - 1 - stop) // (-step),
            scal.ScalarConstant(scal.int64, 0),
        ),
    )
def slice_len(slc, n):
    """Compute the length of a slice for an array of a given length.

    We're essentially computing `len(range(*slc.indices(n)))`.

    Parameters
    ----------
    slc: slice
        The slice whose resulting length is computed.
    n: int or Variable
        The length of the dimension being sliced.
    """
    # TODO: Do we need to do this or should we expect `slc` to
    # already be canonicalized?
    canon_slc, _ = get_canonical_form_slice(slc, n)
    return range_len(canon_slc)
def is_basic_idx(idx):
    """Determine if an index is of the NumPy basic type.

    XXX: This only checks a single index, so an integer is *not* considered
    a basic index here: depending on the indices it is combined with, an
    integer can trigger advanced indexing.
    """
    # Plain Python basic indices: a `slice` or `None` (new axis).
    if isinstance(idx, (slice, type(None))):
        return True
    # Symbolic equivalents: variables typed as slices or `None`.
    idx_type = getattr(idx, "type", None)
    return isinstance(idx_type, (SliceType, NoneTypeT))
def basic_shape(shape, indices):
    """Computes the shape resulting from basic NumPy indexing.

    Basic indices are either `slice`s or `None`s.  `Ellipsis` are not
    supported here; convert them to `slice`s first.

    Parameters
    ----------
    shape: Tuple[int]
        The shape of the array being indexed
    indices: Sequence[Or[slice, NoneType]]
        A sequence of basic indices used to index an array.
    """
    dim_lengths = []
    for idx, n in zip(indices, shape):
        idx_type = getattr(idx, "type", None)
        # A new axis (`None`, literal or symbolic) contributes a constant
        # dimension of length one.
        if idx is None or isinstance(idx_type, NoneTypeT):
            dim_lengths.append(scal.ScalarConstant(scal.int64, 1))
        elif isinstance(idx, slice):
            # A literal Python slice.
            dim_lengths.append(slice_len(idx, n))
        elif isinstance(idx_type, SliceType):
            # A symbolic slice: recover its start/stop/step from the owner
            # `Apply` node when present; otherwise treat it as `slice(None)`.
            slc_args = idx.owner.inputs if idx.owner else (None,)
            dim_lengths.append(slice_len(slice(*slc_args), n))
        else:
            raise ValueError("Invalid index type: {}".format(idx))
    return tuple(dim_lengths)
def group_indices(indices):
    """Group indices sequentially by whether or not they're basic or advanced.

    Returns
    -------
    Tuple[Boolean, List[Tuple[Integer, Any]]]
        The boolean indicates whether or not the group is a set of basic
        indices.  The list contains the contiguous set of indices paired with
        their corresponding dimension number in the array being indexed.
    """
    idx_groups = []
    # Dimension counter for the array being indexed; starts at -1 so the
    # first dimension-consuming index gets number 0.
    dim_num = -1
    for basic, grp_indices in groupby(indices, key=is_basic_idx):
        enum_grp_indices = []
        for idx in grp_indices:
            # We "zip" the dimension number to each index, which means we can't
            # count indices that add new axes
            if (idx is not None) and not isinstance(
                getattr(idx, "type", None), NoneTypeT
            ):
                dim_num += 1
            enum_grp_indices.append((dim_num, idx))
        idx_groups.append((basic, enum_grp_indices))
    return idx_groups
def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
    """Compute the symbolic shape resulting from `a[indices]` for `a.shape == array_shape`.

    This function uses NumPy's basic and advanced indexing logic.  It can also
    handle combinations of advanced and basic indices.

    Parameters
    ----------
    array_shape: Tuple[Variable]
        Shape of the array being indexed.
    indices: Sequence[Union[TensorVariable, Tuple[Union[None, slice, Variable]]]]
        Either the indices themselves or the shapes of each index--depending
        on the value of `indices_are_shapes`.
    indices_are_shapes: bool (Optional)
        Indicates whether or not the `indices` contains shape tuples instead of
        the actual index arrays.  If you use this approach, make sure that the
        broadcastable dimensions are (scalar) constants with the value `1`, or `1`
        exactly.
    """
    res_shape = ()
    # Dimensions of the indexed array not consumed by any index; they are
    # appended untouched at the end.
    remaining_dims = range(theano.tensor.basic.get_vector_length(array_shape))
    idx_groups = group_indices(indices)
    # NumPy rule: when advanced-index groups are separated by basic indices,
    # the dimensions produced by the advanced indices move to the front of
    # the result.  We mimic that by sorting the advanced groups first and
    # re-grouping.
    if len(idx_groups) > 2 or len(idx_groups) > 1 and not idx_groups[0][0]:
        # Bring adv. index groups to the front and merge each group
        idx_groups = sorted(idx_groups, key=lambda x: x[0])
        idx_groups = groupby(
            chain.from_iterable(d_idx for _, d_idx in idx_groups),
            key=lambda x: is_basic_idx(x[1]),
        )
    for basic, grp_dim_indices in idx_groups:
        dim_nums, grp_indices = zip(*grp_dim_indices)
        remaining_dims = tuple(dim for dim in remaining_dims if dim not in dim_nums)
        if basic:
            # Basic indices: each slice/None maps to one output dimension.
            grp_shapes = tuple(array_shape[dim] for dim in dim_nums)
            res_shape += basic_shape(grp_shapes, grp_indices)
        else:
            # Advanced indices within a group broadcast against each other.
            res_shape += broadcast_shape(
                *grp_indices, arrays_are_shapes=indices_are_shapes
            )
    res_shape += tuple(array_shape[dim] for dim in remaining_dims)
    return res_shape
class Subtensor(Op): class Subtensor(Op):
""" """
Return a subtensor view. Return a subtensor view.
...@@ -1783,14 +1931,6 @@ def _sum_grad_over_bcasted_dims(x, gx): ...@@ -1783,14 +1931,6 @@ def _sum_grad_over_bcasted_dims(x, gx):
return gx return gx
#########################
# Advanced indexing
#########################
#
# Should reproduce numpy's behaviour, see url:
# docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
class AdvancedSubtensor1(Op): class AdvancedSubtensor1(Op):
""" """
Implement x[ilist] where ilist is a vector of integers. Implement x[ilist] where ilist is a vector of integers.
...@@ -1858,7 +1998,6 @@ class AdvancedSubtensor1(Op): ...@@ -1858,7 +1998,6 @@ class AdvancedSubtensor1(Op):
return rval return rval
def grad(self, inputs, grads): def grad(self, inputs, grads):
global sparse_module_ref
x, ilist = inputs x, ilist = inputs
(gz,) = grads (gz,) = grads
assert len(inputs) == 2 assert len(inputs) == 2
...@@ -1868,10 +2007,8 @@ class AdvancedSubtensor1(Op): ...@@ -1868,10 +2007,8 @@ class AdvancedSubtensor1(Op):
"AdvancedSubtensor1: you can't take the sparse grad" "AdvancedSubtensor1: you can't take the sparse grad"
" from a tensor with ndim != 2. ndim is " + str(x.type.ndim) " from a tensor with ndim != 2. ndim is " + str(x.type.ndim)
) )
if sparse_module_ref is None:
import theano.sparse as sparse_module_ref
rval1 = [sparse_module_ref.construct_sparse_from_list(x, gz, ilist)] rval1 = [theano.sparse.construct_sparse_from_list(x, gz, ilist)]
else: else:
if x.dtype in theano.tensor.discrete_dtypes: if x.dtype in theano.tensor.discrete_dtypes:
# The output dtype is the same as x # The output dtype is the same as x
...@@ -2203,45 +2340,6 @@ def as_index_variable(idx): ...@@ -2203,45 +2340,6 @@ def as_index_variable(idx):
return idx return idx
def adv_index_broadcastable_pattern(a, idx):
"""
This function is only used to determine the broadcast pattern for
AdvancedSubtensor output variable.
For this, we make a fake ndarray and a fake idx and call use ask numpy
the output. From this, we find the output broadcast pattern.
"""
def replace_slice(v):
if isinstance(v, gof.Apply):
if len(v.outputs) != 1:
raise ValueError(
"It is ambiguous which output of a multi-output Op has"
" to be fetched.",
v,
)
else:
v = v.outputs[0]
if NoneConst.equals(v):
return None
if isinstance(v.type, SliceType):
return slice(None, None)
if v.dtype == "bool":
return np.ones((2,) * v.ndim, v.dtype)
else:
return np.zeros((2,) * v.ndim, int)
newidx = tuple(map(replace_slice, idx))
# 2 - True = 1; 2 - False = 2
fakeshape = [2 - bc for bc in a.broadcastable]
retshape = np.empty(fakeshape)[newidx].shape
return tuple([dim == 1 for dim in retshape])
def check_advanced_indexing_dimensions(input, idx_list): def check_advanced_indexing_dimensions(input, idx_list):
""" """
This function checks if the index list in idx_list is correct. This function checks if the index list in idx_list is correct.
...@@ -2288,23 +2386,33 @@ def check_and_reject_bool(args_el): ...@@ -2288,23 +2386,33 @@ def check_and_reject_bool(args_el):
class BaseAdvancedSubtensor(Op): class BaseAdvancedSubtensor(Op):
""" """Abstract base class for AdvancedSubtensor and AdvancedBooleanSubtensor.
Abstract base class for AdvancedSubtensor and AdvancedBooleanSubtensor.
Implements advanced indexing with boolean masks. Implements advanced indexing with boolean masks.
Should be used by __getitem__ and __getslice__, as follows:
- AdvancedSubtensor()(self, *args) or
- AdvancedBooleanSubtensor()(self, *args), if args contain advanced indices
""" """
# Should be used by __getitem__ and __getslice__, as follows:
# AdvancedSubtensor()(self, *args) or
# AdvancedBooleanSubtensor()(self, *args),
# if args contains and advanced indexing pattern
__props__ = () __props__ = ()
def make_node(self, x, *index): def make_node(self, x, *index):
x = theano.tensor.as_tensor_variable(x) x = theano.tensor.as_tensor_variable(x)
index = tuple(map(as_index_variable, index)) index = tuple(map(as_index_variable, index))
bcast = adv_index_broadcastable_pattern(x, index)
# We only want the broadcast information, and we don't need recursive
# `Subtensor` calls, so we create a fake symbolic shape tuple and
# identify the broadcast dimensions from the shape result of this
# entire subtensor operation.
fake_shape = tuple(
theano.tensor.tensor(dtype="int64", broadcastable=()) if not bcast else 1
for bcast in x.broadcastable
)
bcast = [
getattr(i, "value", i) == 1 for i in indexed_result_shape(fake_shape, index)
]
return gof.Apply( return gof.Apply(
self, self,
(x,) + index, (x,) + index,
...@@ -2317,8 +2425,26 @@ class BaseAdvancedSubtensor(Op): ...@@ -2317,8 +2425,26 @@ class BaseAdvancedSubtensor(Op):
return self.make_node(eval_points[0], *inputs[1:]).outputs return self.make_node(eval_points[0], *inputs[1:]).outputs
def infer_shape(self, node, ishapes): def infer_shape(self, node, ishapes):
# Default case, we don't know indices = node.inputs[1:]
raise theano.tensor.basic.ShapeError("case not implemented") index_shapes = list(ishapes[1:])
for i, idx in enumerate(indices):
if (
isinstance(idx, (np.bool_, bool))
or getattr(idx, "dtype", None) == "bool"
):
raise theano.tensor.basic.ShapeError(
"Shape inference for boolean indices is not implemented"
)
# The `ishapes` entries for `SliceType`s will be None, and
# we need to give `indexed_result_shape` the actual slices.
if isinstance(getattr(idx, "type", None), SliceType):
index_shapes[i] = idx
res_shape = indexed_result_shape(
ishapes[0], index_shapes, indices_are_shapes=True
)
assert node.outputs[0].ndim == len(res_shape)
return [[s for s in res_shape]]
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
(out,) = out_ (out,) = out_
...@@ -2348,34 +2474,10 @@ class AdvancedSubtensor(BaseAdvancedSubtensor): ...@@ -2348,34 +2474,10 @@ class AdvancedSubtensor(BaseAdvancedSubtensor):
""" """
# Should be used by __getitem__ and __getslice__, as follows:
# AdvancedSubtensor()(self, *args),
# if args contains and advanced indexing pattern
def make_node(self, x, *index): def make_node(self, x, *index):
check_and_reject_bool(index) check_and_reject_bool(index)
return super(AdvancedSubtensor, self).make_node(x, *index) return super(AdvancedSubtensor, self).make_node(x, *index)
def infer_shape(self, node, ishapes):
# Really special case
if len(ishapes) == 3:
xshp, ind1shp, ind2shp = ishapes
if (
len(xshp) == 2
and ind1shp is not None
and len(ind1shp) == 1
and ind2shp is not None
and len(ind2shp) == 1
):
# if the graph is correct, we can assume ind1shp[0] and
# ind2shp[0] will have the same value.
# Try to return the one closest to the graph input.
if node.inputs[2].owner is None:
return [ind2shp]
else:
return [ind1shp]
return super(AdvancedSubtensor, self).infer_shape(node, ishapes)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(gz,) = grads (gz,) = grads
x = inputs[0] x = inputs[0]
...@@ -2401,10 +2503,6 @@ class AdvancedBooleanSubtensor(BaseAdvancedSubtensor): ...@@ -2401,10 +2503,6 @@ class AdvancedBooleanSubtensor(BaseAdvancedSubtensor):
""" """
# Should be used by __getitem__ and __getslice__, as follows:
# AdvancedBooleanSubtensor()(self, *args),
# if args contains and advanced indexing pattern with boolean masks
def grad(self, inputs, grads): def grad(self, inputs, grads):
(gz,) = grads (gz,) = grads
x = inputs[0] x = inputs[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论