提交 9f56e418 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove redundant AdvancedBoolean*Subtensor classes

This change also fixes the `set_subtensor` boolean gradient bug in #105.
上级 9733a2ab
......@@ -11,7 +11,6 @@ from theano.gpuarray.subtensor import (
GpuSubtensor,
GpuAdvancedSubtensor1,
GpuAdvancedSubtensor,
GpuAdvancedBooleanSubtensor,
GpuAdvancedIncSubtensor,
GpuAdvancedIncSubtensor1,
GpuAdvancedIncSubtensor1_dev20,
......@@ -37,7 +36,6 @@ class TestGPUSubtensor(TestSubtensor):
self.adv_sub1 = GpuAdvancedSubtensor1
self.adv_incsub1 = GpuAdvancedIncSubtensor1
self.adv_sub = GpuAdvancedSubtensor
self.adv_bool_sub = GpuAdvancedBooleanSubtensor
self.dimshuffle = GpuDimShuffle
self.mode = mode_with_gpu
# avoid errors with limited devices
......@@ -60,7 +58,6 @@ class TestGPUSubtensorF16(TestSubtensor):
self.adv_sub1 = GpuAdvancedSubtensor1
self.adv_incsub1 = GpuAdvancedIncSubtensor1
self.adv_sub = GpuAdvancedSubtensor
self.adv_bool_sub = GpuAdvancedBooleanSubtensor
self.dimshuffle = GpuDimShuffle
self.mode = mode_with_gpu
# avoid errors with limited devices
......
......@@ -46,8 +46,6 @@ from theano.tensor.subtensor import (
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
AdvancedBooleanSubtensor,
AdvancedBooleanIncSubtensor,
advanced_inc_subtensor,
advanced_inc_subtensor1,
advanced_set_subtensor,
......@@ -67,8 +65,6 @@ subtensor_ops = (
IncSubtensor,
AdvancedSubtensor1,
AdvancedIncSubtensor1,
AdvancedBooleanSubtensor,
AdvancedBooleanIncSubtensor,
)
......@@ -368,12 +364,10 @@ class TestSubtensor(utt.OptimizationTestMixin):
# indexing with a mask for some dimensions
mask = np.array([True, False])
val = self.eval_output_and_check(
test_array[mask], op_type=AdvancedBooleanSubtensor
)
val = self.eval_output_and_check(test_array[mask], op_type=AdvancedSubtensor)
assert_array_equal(test_array_np[mask], val)
val = self.eval_output_and_check(
inc_subtensor(test_array[mask], 1), op_type=AdvancedBooleanIncSubtensor
inc_subtensor(test_array[mask], 1), op_type=AdvancedIncSubtensor
)
assert_array_equal(numpy_inc_subtensor(test_array_np, mask, 1), val)
assert_array_equal(
......@@ -580,8 +574,8 @@ class TestSubtensor(utt.OptimizationTestMixin):
topo_ = [node for node in topo if not isinstance(node.op, DeepCopyOp)]
if not self.fast_compile:
assert len(topo_) == 6
assert np.sum([isinstance(node.op, IncSubtensor) for node in topo_]) == 1
assert np.sum([isinstance(node.op, Subtensor) for node in topo_]) == 1
assert any(isinstance(node.op, IncSubtensor) for node in topo_)
assert any(isinstance(node.op, Subtensor) for node in topo_)
gval = f()
good = np.zeros_like(data)
......@@ -1161,47 +1155,6 @@ class TestSubtensor(utt.OptimizationTestMixin):
val = f()
assert np.allclose(val, data[idx].shape)
def test_grad_advanced_inc_subtensor(self):
def inc_slice(*s):
def just_numeric_args(a, b):
cost = (a[s] + b).sum()
cost_wrt_a = tensor.grad(cost, a)
cost_wrt_b = tensor.grad(cost, b)
grads = cost_wrt_a.sum() + cost_wrt_b.sum()
return grads
return just_numeric_args
# vector
utt.verify_grad(
inc_slice(slice(2, 4, None)),
(
np.asarray([0, 1, 2, 3, 4, 5.0]),
np.asarray([9, 9.0]),
),
mode=self.mode,
)
# matrix
utt.verify_grad(
inc_slice(slice(1, 2, None), slice(None, None, None)),
(
np.asarray([[0, 1], [2, 3], [4, 5.0]]),
np.asarray([[9, 9.0]]),
),
mode=self.mode,
)
# single element
utt.verify_grad(
inc_slice(2, 1),
(
np.asarray([[0, 1], [2, 3], [4, 5.0]]),
np.asarray(9.0),
),
mode=self.mode,
)
def test_inc_and_set_subtensor(self):
# Test increment and set with broadcast
......@@ -1323,21 +1276,6 @@ class TestSubtensor(utt.OptimizationTestMixin):
all_params.append(
(set_instead_of_inc, inplace, data_shape, inc_shape)
)
if False: # Enable for debugging purpose.
f = self.function(
[data_var, idx_var, inc_var],
output,
accept_inplace=inplace,
op=AdvancedIncSubtensor1,
)
if inplace:
# Ensure calling `f` will not alter `data_num`.
data_num = data_num.copy()
f_out = f(data_num.copy(), idx_num, inc_num)
assert np.allclose(f_out, data_copy)
if not inplace:
# Sanity check: `data_num` should be intact.
assert (data_num == data_num_init).all()
# Actual test (we compile a single Theano function to make it faster).
orig_warn = theano.config.warn.gpu_set_subtensor1
......@@ -1647,18 +1585,18 @@ class TestAdvancedSubtensor:
rep[idx] += y_val
check(idx, y_val, x_val, rep)
def eval_output_and_check(self, t):
def eval_output_and_check(self, t, op):
f = inplace_func([], t, mode=self.mode)
topo = f.maker.fgraph.toposort()
topo_ = [node for node in topo if not isinstance(node.op, DeepCopyOp)]
assert len(topo_) == 1
assert isinstance(topo_[0].op, AdvancedSubtensor)
assert isinstance(topo_[0].op, op)
tval = f()
return tval
def test_cant_adv_idx_into_scalar(self):
with pytest.raises(IndexError):
(lambda: self.s[self.ix1])()
self.s[self.ix1]
def test_index_into_vec_w_vec(self):
a = self.v[self.ix1]
......@@ -1698,7 +1636,7 @@ class TestAdvancedSubtensor:
assert isinstance(t.owner.op, AdvancedSubtensor)
val = self.eval_output_and_check(t)
val = self.eval_output_and_check(t, AdvancedSubtensor)
if isinstance(idx, list):
good = data[0, idx]
else:
......@@ -1942,6 +1880,35 @@ class TestAdvancedSubtensor:
mode=self.mode,
)
# Test boolean gradients
def fun(x, y):
return advanced_inc_subtensor(
x, y, tensor.as_tensor(np.array([[True, False], [False, True]]))
)
utt.verify_grad(
fun,
[
np.random.rand(2, 2).astype(self.dtype),
np.random.rand(2).astype(self.dtype),
],
mode=self.mode,
)
def fun(x, y):
return advanced_set_subtensor(
x, y, tensor.as_tensor(np.array([[True, False], [False, True]]))
)
utt.verify_grad(
fun,
[
np.random.rand(2, 2).astype(self.dtype),
np.random.rand(2).astype(self.dtype),
],
mode=self.mode,
)
class TestInferShape(utt.InferShapeTester):
def test_IncSubtensor(self):
......@@ -2216,7 +2183,7 @@ class TestInferShape(utt.InferShapeTester):
AdvancedSubtensor,
)
def test_AdvancedBooleanSubtensor(self):
def test_AdvancedSubtensor_bool(self):
n = dmatrix()
n_val = np.arange(6).reshape((2, 3))
......@@ -2225,14 +2192,14 @@ class TestInferShape(utt.InferShapeTester):
[n],
[n[n[:, 0] > 2, n[0, :] > 2]],
[n_val],
AdvancedBooleanSubtensor,
AdvancedSubtensor,
check_topo=False,
)
self._compile_and_check(
[n],
[n[n[:, 0] > 2]],
[n_val],
AdvancedBooleanSubtensor,
AdvancedSubtensor,
check_topo=False,
)
......
......@@ -11,7 +11,6 @@ from theano.tensor.var import TensorConstant
from theano.tensor.subtensor import (
Subtensor,
AdvancedSubtensor,
AdvancedBooleanSubtensor,
AdvancedSubtensor1,
)
from theano.tensor.elemwise import DimShuffle
......@@ -124,31 +123,30 @@ def test__getitem__Subtensor():
assert op_types[-1] == Subtensor
def test__getitem__AdvancedBooleanSubtensor():
# Make sure we get `AdvancedBooleanSubtensor`s for boolean-mask indexing operations
def test__getitem__AdvancedSubtensor_bool():
x = tt.matrix("x")
i = tt.type.TensorType("bool", (False, False))("i")
z = x[i]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedBooleanSubtensor
assert op_types[-1] == AdvancedSubtensor
i = tt.type.TensorType("bool", (False,))("i")
z = x[:, i]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedBooleanSubtensor
assert op_types[-1] == AdvancedSubtensor
i = tt.type.TensorType("bool", (False,))("i")
z = x[..., i]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedBooleanSubtensor
assert op_types[-1] == AdvancedSubtensor
with pytest.raises(TypeError):
z = x[[True, False], i]
z = x[tt.ivector("b"), i]
op_types = [type(node.op) for node in theano.gof.graph.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedBooleanSubtensor
assert op_types[-1] == AdvancedSubtensor
def test__getitem__AdvancedSubtensor():
......
......@@ -130,11 +130,9 @@ from .subtensor import (
GpuSubtensor,
GpuAdvancedSubtensor,
GpuAdvancedSubtensor1,
GpuAdvancedBooleanSubtensor,
GpuAdvancedIncSubtensor,
GpuAdvancedIncSubtensor1,
GpuAdvancedIncSubtensor1_dev20,
GpuAdvancedBooleanIncSubtensor,
GpuAllocDiag,
GpuExtractDiag,
)
......@@ -1291,13 +1289,6 @@ def local_gpua_advanced_subtensor(op, context_name, inputs, outputs):
return GpuAdvancedSubtensor()
@register_opt("fast_compile")
@op_lifter([tensor.AdvancedBooleanSubtensor])
@register_opt2([tensor.AdvancedBooleanSubtensor], "fast_compile")
def local_gpua_advanced_boolean_subtensor(op, context_name, inputs, outputs):
return GpuAdvancedBooleanSubtensor()
@register_opt("fast_compile")
@op_lifter([tensor.AdvancedIncSubtensor1])
@register_opt2([tensor.AdvancedIncSubtensor1], "fast_compile")
......@@ -1342,20 +1333,6 @@ def local_gpua_advanced_incsubtensor(op, context_name, inputs, outputs):
return False
# Do not register this optimization for now, as it slows down the
# execution by a lot in important cases.
# @register_opt('fast_compile')
# @op_lifter([tensor.AdvancedBooleanIncSubtensor])
# @register_opt2([tensor.AdvancedBooleanIncSubtensor], 'fast_compile')
def local_gpua_advanced_boolean_incsubtensor(op, context_name, inputs, outputs):
# GpuAdvancedIncSubtensor only works with a single boolean mask,
# but not with fancy combinations.
if not op.set_instead_of_inc and len(inputs) == 3:
return GpuAdvancedBooleanIncSubtensor()
else:
return False
@register_inplace()
@local_optimizer([GpuAdvancedIncSubtensor1, GpuAdvancedIncSubtensor1_dev20])
def local_advincsub1_gpua_inplace(node):
......
......@@ -683,9 +683,6 @@ class GpuAdvancedSubtensor(HideC, BaseGpuAdvancedSubtensor, tensor.AdvancedSubte
def make_node(self, x, *inputs):
ctx_name = infer_context_name(x)
# This method relies on AdvancedSubtensor.make_node to
# call tensor.subtensor.check_and_reject_bool(inputs),
# which raises an IndexError if there are any boolean indices.
rval = tensor.AdvancedSubtensor.make_node(self, x, *inputs)
otype = GpuArrayType(
dtype=rval.outputs[0].type.dtype,
......@@ -696,25 +693,6 @@ class GpuAdvancedSubtensor(HideC, BaseGpuAdvancedSubtensor, tensor.AdvancedSubte
return gof.Apply(self, [x] + rval.inputs[1:], [otype()])
class GpuAdvancedBooleanSubtensor(
HideC, BaseGpuAdvancedSubtensor, tensor.AdvancedBooleanSubtensor
):
"""
AdvancedBooleanSubtensor on the GPU.
"""
def make_node(self, x, *inputs):
ctx_name = infer_context_name(x)
rval = tensor.AdvancedBooleanSubtensor.make_node(self, x, *inputs)
otype = GpuArrayType(
dtype=rval.outputs[0].type.dtype,
broadcastable=rval.outputs[0].type.broadcastable,
context_name=ctx_name,
)
x = as_gpuarray_variable(x, ctx_name)
return gof.Apply(self, [x] + rval.inputs[1:], [otype()])
class BaseGpuAdvancedIncSubtensor(object):
def perform(self, node, inp, out_):
(out,) = out_
......@@ -852,27 +830,6 @@ class GpuAdvancedIncSubtensor(
return gof.Apply(self, [x, y] + rval.inputs[2:], [otype()])
class GpuAdvancedBooleanIncSubtensor(
HideC, BaseGpuAdvancedIncSubtensor, tensor.AdvancedBooleanIncSubtensor
):
"""
Implement AdvancedBooleanIncSubtensor on the gpu.
"""
def make_node(self, x, y, *inputs):
ctx_name = infer_context_name(x, y)
rval = tensor.AdvancedBooleanIncSubtensor.make_node(self, x, y, *inputs)
otype = GpuArrayType(
dtype=rval.outputs[0].type.dtype,
broadcastable=rval.outputs[0].type.broadcastable,
context_name=ctx_name,
)
x = as_gpuarray_variable(x, ctx_name)
y = as_gpuarray_variable(y, ctx_name)
return gof.Apply(self, [x, y] + rval.inputs[2:], [otype()])
class GpuAdvancedIncSubtensor1(Op):
"""
Implement AdvancedIncSubtensor1 on the gpu.
......
......@@ -21,8 +21,8 @@ from theano.tensor.subtensor import (
AdvancedSubtensor1,
AdvancedIncSubtensor1,
# Boolean mask indexing and setting
BaseAdvancedSubtensor,
BaseAdvancedIncSubtensor,
AdvancedSubtensor,
AdvancedIncSubtensor,
)
from theano.scan_module.scan_op import Scan
from theano.scan_module.scan_utils import scan_args as ScanArgs
......@@ -97,7 +97,7 @@ try:
except AttributeError:
pass
subtensor_ops = (Subtensor, AdvancedSubtensor1, BaseAdvancedSubtensor)
subtensor_ops = (Subtensor, AdvancedSubtensor1, AdvancedSubtensor)
incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1)
......@@ -588,18 +588,18 @@ def jax_funcify_IncSubtensor(op):
_ = [jax_funcify.register(op, jax_funcify_IncSubtensor) for op in incsubtensor_ops]
@jax_funcify.register(BaseAdvancedIncSubtensor)
def jax_funcify_BaseAdvancedIncSubtensor(op):
@jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op):
if getattr(op, "set_instead_of_inc", False):
jax_fn = jax.ops.index_update
else:
jax_fn = jax.ops.index_add
def baseadvancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
return jax_fn(x, ilist, y)
return baseadvancedincsubtensor
return advancedincsubtensor
@jax_funcify.register(FunctionGraph)
......
......@@ -9,7 +9,6 @@ import theano
from textwrap import dedent
from itertools import groupby, chain
from collections.abc import Iterable
from six import integer_types
......@@ -42,15 +41,6 @@ class AdvancedIndexingError(TypeError):
pass
class AdvancedBooleanIndexingError(TypeError):
"""
Raised when Subtensor is asked to perform advanced indexing with boolean masks.
"""
pass
def as_index_constant(a):
"""Convert Python literals to Theano constants--when possible--in Subtensor arguments.
......@@ -501,9 +491,7 @@ class Subtensor(Op):
and hasattr(entry, "dtype")
and entry.dtype == "bool"
):
raise AdvancedBooleanIndexingError(
"Invalid index type or slice for Subtensor"
)
raise AdvancedIndexingError("Invalid index type or slice for Subtensor")
if isinstance(entry, gof.Variable) and (
entry.type in invalid_scal_types or entry.type in invalid_tensor_types
......@@ -1305,17 +1293,8 @@ def inc_subtensor(
elif isinstance(x.owner.op, AdvancedSubtensor):
real_x = x.owner.inputs[0]
ilist = x.owner.inputs[1:]
the_op = AdvancedIncSubtensor(inplace, set_instead_of_inc=set_instead_of_inc)
return the_op(real_x, y, *ilist)
elif isinstance(x.owner.op, AdvancedBooleanSubtensor):
real_x = x.owner.inputs[0]
ilist = x.owner.inputs[1:]
the_op = AdvancedBooleanIncSubtensor(
inplace, set_instead_of_inc=set_instead_of_inc
)
return the_op(real_x, y, *ilist)
elif isinstance(x.owner.op, DimShuffle):
inner_x = x.owner.inputs[0]
# In the dimshuffle case, there are in fact two dimshuffles:
......@@ -2329,38 +2308,12 @@ def check_advanced_indexing_dimensions(input, idx_list):
dim_seen += 1
def check_and_reject_bool(args_el):
try:
if isinstance(args_el, (np.bool_, bool)) or args_el.dtype == "bool":
raise TypeError(
"AdvancedSubtensor does not support boolean "
"masks for indexing. Use AdvancedBooleanSubtensor "
"instead. "
)
except AttributeError:
pass
if not isinstance(args_el, theano.tensor.Variable) and isinstance(
args_el, Iterable
):
for el in args_el:
check_and_reject_bool(el)
class BaseAdvancedSubtensor(Op):
"""Abstract base class for AdvancedSubtensor and AdvancedBooleanSubtensor.
Implements advanced indexing with boolean masks.
Should be used by __getitem__ and __getslice__, as follows:
- AdvancedSubtensor()(self, *args) or
- AdvancedBooleanSubtensor()(self, *args), if args contain advanced indices
"""
class AdvancedSubtensor(Op):
"""Implements NumPy's advanced indexing."""
__props__ = ()
def make_node(self, x, *index, is_boolean=False):
def make_node(self, x, *index):
x = theano.tensor.as_tensor_variable(x)
index = tuple(map(as_index_variable, index))
......@@ -2373,16 +2326,14 @@ class BaseAdvancedSubtensor(Op):
for bcast in x.broadcastable
)
bcast_index = index
if is_boolean:
bcast_index = tuple(
chain.from_iterable(
theano.tensor.basic.nonzero(idx)
if getattr(idx, "ndim", 0) > 0
else (idx,)
for idx in bcast_index
)
bcast_index = tuple(
chain.from_iterable(
theano.tensor.basic.nonzero(idx)
if getattr(idx, "ndim", 0) > 0 and getattr(idx, "dtype", None) == "bool"
else (idx,)
for idx in index
)
)
bcast = [
getattr(i, "value", i) == 1
......@@ -2443,17 +2394,6 @@ class BaseAdvancedSubtensor(Op):
return rval
class AdvancedSubtensor(BaseAdvancedSubtensor):
"""
Return a subtensor copy, using advanced indexing.
"""
def make_node(self, x, *index):
check_and_reject_bool(index)
return super(AdvancedSubtensor, self).make_node(x, *index)
def grad(self, inputs, grads):
(gz,) = grads
x = inputs[0]
......@@ -2473,39 +2413,8 @@ class AdvancedSubtensor(BaseAdvancedSubtensor):
advanced_subtensor = AdvancedSubtensor()
class AdvancedBooleanSubtensor(BaseAdvancedSubtensor):
"""
Return a subtensor copy, using advanced indexing with boolean masks.
"""
def make_node(self, x, *index):
return super().make_node(x, *index, is_boolean=True)
def grad(self, inputs, grads):
(gz,) = grads
x = inputs[0]
if x.dtype in theano.tensor.discrete_dtypes:
# The output dtype is the same as x
gx = x.zeros_like(dtype=theano.config.floatX)
elif x.dtype in theano.tensor.complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else:
gx = x.zeros_like()
rest = inputs[1:]
return [advanced_boolean_inc_subtensor(gx, gz, *rest)] + [
DisconnectedType()()
] * len(rest)
advanced_boolean_subtensor = AdvancedBooleanSubtensor()
class BaseAdvancedIncSubtensor(Op):
"""
Base class for AdvancedIncSubtensor and AdvancedBooleanIncSubtensor.
Increments a subtensor using advanced indexing.
"""
class AdvancedIncSubtensor(Op):
"""Increments a subtensor using advanced indexing."""
__props__ = ("inplace", "set_instead_of_inc")
......@@ -2578,16 +2487,6 @@ class BaseAdvancedIncSubtensor(Op):
return [None]
return self.make_node(eval_points[0], eval_points[1], *inputs[2:]).outputs
class AdvancedIncSubtensor(BaseAdvancedIncSubtensor):
"""
Increments a subtensor using advanced indexing.
"""
def make_node(self, x, y, *inputs):
check_and_reject_bool(inputs)
return super(AdvancedIncSubtensor, self).make_node(x, y, *inputs)
def grad(self, inpt, output_gradients):
x, y = inpt[:2]
idxs = inpt[2:]
......@@ -2617,40 +2516,6 @@ advanced_inc_subtensor = AdvancedIncSubtensor()
advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True)
class AdvancedBooleanIncSubtensor(BaseAdvancedIncSubtensor):
"""
Increments a subtensor using advanced indexing with boolean masks.
"""
def grad(self, inpt, output_gradients):
x, y = inpt[:2]
idxs = inpt[2:]
(outgrad,) = output_gradients
if x.dtype in theano.tensor.discrete_dtypes:
# The output dtype is the same as x
gx = x.zeros_like(dtype=theano.config.floatX)
if y.dtype in theano.tensor.discrete_dtypes:
gy = y.zeros_like(dtype=theano.config.floatX)
else:
gy = y.zeros_like()
elif x.dtype in theano.tensor.complex_dtypes:
raise NotImplementedError("No support for complex grad yet")
else:
if self.set_instead_of_inc:
gx = advanced_set_subtensor(outgrad, y.zeros_like(), *idxs)
else:
gx = outgrad
gy = advanced_boolean_subtensor(outgrad, *idxs)
# Make sure to sum gy over the dimensions of y that have been
# added or broadcasted
gy = _sum_grad_over_bcasted_dims(y, gy)
return [gx, gy] + [DisconnectedType()() for _ in idxs]
advanced_boolean_inc_subtensor = AdvancedBooleanIncSubtensor()
advanced_boolean_set_subtensor = AdvancedBooleanIncSubtensor(set_instead_of_inc=True)
def take(a, indices, axis=None, mode="raise"):
"""Take elements from an array along an axis.
......
......@@ -523,33 +523,30 @@ class _tensor_py_operators(object):
]
)
# Determine if advanced indexing is needed or not
# The logic is already in Subtensor.convert: if it succeeds,
# standard indexing is used; if it fails with
# AdvancedIndexingError, advanced indexing, or
# AdvancedBooleanIndexingError, advanced indexing with boolean masks
# Determine if advanced indexing is needed or not. The logic is
# already in `Subtensor.convert`: if it succeeds, standard indexing is
# used; if it fails with AdvancedIndexingError, advanced indexing is
# used
advanced = False
advanced_boolean = False
axis = None
for i, arg in enumerate(args):
try:
if arg is not np.newaxis:
theano.tensor.subtensor.Subtensor.convert(arg)
except theano.tensor.subtensor.AdvancedIndexingError:
if advanced:
axis = None
break
else:
advanced = True
axis = i
except theano.tensor.subtensor.AdvancedBooleanIndexingError:
advanced = False
advanced_boolean = True
if includes_bool(arg):
advanced = True
axis = None
break
if advanced_boolean:
return theano.tensor.subtensor.advanced_boolean_subtensor(self, *args)
elif advanced:
if arg is not np.newaxis:
try:
theano.tensor.subtensor.Subtensor.convert(arg)
except theano.tensor.subtensor.AdvancedIndexingError:
if advanced:
axis = None
break
else:
advanced = True
axis = i
if advanced:
if (
axis is not None
and all(isinstance(a, slice) and a == slice(None) for a in args[:axis])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论