提交 0dbd512b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Replace `take` in _tensor_py_operators.__getitem__ with an optimization

上级 47207edb
......@@ -26,9 +26,22 @@ if enable_sparse:
.. versionadded:: 0.6rc4
"""
from aesara.tensor.subtensor import AdvancedSubtensor1
from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1
assert isinstance(var.owner.op, AdvancedSubtensor1)
if var.owner is None or not isinstance(
var.owner.op, (AdvancedSubtensor, AdvancedSubtensor1)
):
raise TypeError(
"Sparse gradient is only implemented for AdvancedSubtensor and AdvancedSubtensor1"
)
ret = var.owner.op.__class__(sparse_grad=True)(*var.owner.inputs)
x = var.owner.inputs[0]
indices = var.owner.inputs[1:]
if len(indices) > 1:
raise TypeError(
"Sparse gradient is only implemented for single advanced indexing"
)
ret = AdvancedSubtensor1(sparse_grad=True)(x, indices[0])
return ret
......@@ -58,6 +58,7 @@ from aesara.tensor import (
blas_scipy,
nnet,
opt_uncanonicalize,
subtensor_opt,
xlogx,
)
......
......@@ -2534,12 +2534,16 @@ def local_useless_inc_subtensor(fgraph, node):
@register_canonicalize
@register_specialize
@local_optimizer([AdvancedIncSubtensor1])
def local_set_to_inc_subtensor(fgraph, node):
"""
r"""
AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) ->
AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False)
TODO FIXME: Why doesn't this apply to all `*IncSubtensor*` `Op`\s? If it
did this wouldn't need to also be included in the "specialize" pass.
"""
if (
isinstance(node.op, AdvancedIncSubtensor1)
......@@ -2567,9 +2571,9 @@ def local_set_to_inc_subtensor(fgraph, node):
if subn.inputs[1] != node.inputs[2] or subn.inputs[0] != node.inputs[0]:
return
ret = advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2])
# Copy over previous output stacktrace
# Julian: I'm not sure about this at all...
copy_stack_trace(node.outputs, ret)
return [ret]
......@@ -3448,7 +3452,7 @@ def local_setsubtensor_of_constants(fgraph, node):
@register_canonicalize
@register_stabilize
@register_specialize
@local_optimizer([AdvancedSubtensor1])
def local_adv_sub1_adv_inc_sub1(fgraph, node):
"""Optimize the possible AdvSub1(AdvSetSub1(...), ...).
......
import aesara
from aesara.graph.opt import copy_stack_trace, local_optimizer
from aesara.tensor.basic_opt import register_specialize
from aesara.tensor.shape import shape_tuple
from aesara.tensor.sharedvar import TensorSharedVariable
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedSubtensor,
advanced_subtensor1,
inc_subtensor,
)
from aesara.tensor.type_other import NoneTypeT, SliceConstant, SliceType
from aesara.tensor.var import TensorConstant, TensorVariable
def transform_take(a, indices, axis):
    r"""Transform ``arr[:,:,:,indices,...]``-like operations into single-dimensional, vector index operations.

    This effectively converts certain `AdvancedSubtensor` `Op`\s into a
    combination of `AdvancedSubtensor1`, `Dimshuffle`, and `Reshape` `Op`\s,
    which can be more efficient.

    Parameters
    ----------
    a : TensorVariable
        The source array.
    indices : TensorVariable, ndarray, list, tuple
        The indices of the values to extract.
    axis : int
        The axis over which to select values.
    """
    a = aesara.tensor.as_tensor_variable(a)
    indices = aesara.tensor.as_tensor_variable(indices)

    if indices.ndim == 1:
        # A vector of indices maps directly onto `AdvancedSubtensor1`, which
        # only indexes along the leading axis.
        if axis == 0:
            return advanced_subtensor1(a, indices)
        # Swap `axis` with the leading axis, index, then swap back.
        perm = list(range(a.ndim))
        perm[0], perm[axis] = axis, 0
        return advanced_subtensor1(a.dimshuffle(perm), indices).dimshuffle(perm)

    # Higher-dimensional indices: flatten them, apply the vector case above
    # via recursion, then reshape the result to the expected output shape.
    indices_shape = shape_tuple(indices)
    a_shape = shape_tuple(a)

    out_shape_parts = [
        part
        for part in (a_shape[:axis], indices_shape, a_shape[axis + 1 :])
        if len(part) > 0
    ]
    assert len(out_shape_parts) > 0

    if len(out_shape_parts) == 1:
        out_shape = out_shape_parts[0]
    else:
        out_shape = aesara.tensor.concatenate(out_shape_parts)

    out_ndim = a.ndim + indices.ndim - 1
    return transform_take(a, indices.flatten(), axis).reshape(out_shape, out_ndim)
def is_full_slice(x):
    """Determine if `x` is a ``slice(None)`` or a symbolic equivalent."""
    # Literal Python full slice.
    if isinstance(x, slice) and x == slice(None):
        return True
    # Constant slice variable whose value is a full slice.
    if isinstance(x, SliceConstant) and x.value == slice(None):
        return True
    # Symbolic slice built entirely from ``None`` inputs.
    if (
        not isinstance(x, SliceConstant)
        and isinstance(getattr(x, "type", None), SliceType)
        and x.owner is not None
        and all(
            isinstance(getattr(i, "type", None), NoneTypeT) for i in x.owner.inputs
        )
    ):
        return True
    return False
def get_advsubtensor_axis(indices):
    """Determine the axis at which an array index is applied.

    This only works for ``take``-like indices: e.g. ``x[:, :, idx, ...]``. For
    the above example, `get_advsubtensor_axis` would return ``2``. If it
    encounters anything other than a set of `indices` containing full slices
    and an array/tensor index, it will return ``None``.

    Parameters
    ----------
    indices : sequence
        The index elements of an advanced-subtensor expression (slices
        and/or tensor indices).

    Returns
    -------
    int or None
        The position of the single array/tensor index, or ``None`` when the
        pattern is not ``take``-like.
    """
    found_idx = False
    axis = 0
    for idx in indices:
        if not found_idx and is_full_slice(idx):
            # Preceding full slices
            axis += 1
        elif found_idx and not is_full_slice(idx):
            # We don't handle multiple indices
            return None
        elif found_idx and is_full_slice(idx):
            # Trailing full slices
            continue
        else:
            found_idx = True

    if not found_idx:
        # All entries were full slices (or `indices` was empty), so there is
        # no array index at all; previously `indices[axis]` below would have
        # raised an `IndexError` in this case.
        return None

    if isinstance(
        indices[axis], (TensorConstant, TensorVariable, TensorSharedVariable)
    ):
        return axis

    return None
@register_specialize
@local_optimizer([AdvancedSubtensor])
def local_replace_AdvancedSubtensor(fgraph, node):
    r"""
    This rewrite converts expressions like ``X[..., y]`` into ``X.T[y].T``, for
    a vector ``y``, and ``X[z, ...]`` into ``X[z.flatten()].reshape(...)``, for a
    matrix ``z``.

    These rewrites replace `AdvancedSubtensor`\s with the more efficient
    `AdvancedSubtensor1` and `Subtensor` `Op`\s.
    """
    if not isinstance(node.op, AdvancedSubtensor):
        return None

    base, *idxs = node.inputs

    axis = get_advsubtensor_axis(idxs)

    if axis is None or idxs[axis].dtype == "bool":
        # Booleans aren't handled
        return None

    new_res = transform_take(base, idxs[axis], axis)

    # The rewrite must not change the static broadcast pattern of the output.
    assert new_res.broadcastable == node.outputs[0].broadcastable

    copy_stack_trace(node.outputs[0], new_res)
    return [new_res]
@register_specialize
@local_optimizer([AdvancedIncSubtensor])
def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
    r"""Replace `AdvancedIncSubtensor`\s with `AdvancedIncSubtensor1`\s.

    This is only done when there's a single vector index.
    """
    if not isinstance(node.op, AdvancedIncSubtensor):
        return None

    target, values, *idxs = node.inputs

    axis = get_advsubtensor_axis(idxs)

    if axis is None or idxs[axis].dtype == "bool":
        # Booleans aren't handled
        return None

    # Build the equivalent `take`-style view and increment/set through it,
    # preserving the original `Op`'s flags.
    view = transform_take(target, idxs[axis], axis)
    new_res = inc_subtensor(
        view,
        values,
        inplace=node.op.inplace,
        set_instead_of_inc=node.op.set_instead_of_inc,
    )

    copy_stack_trace(node.outputs[0], new_res)
    return [new_res]
......@@ -528,11 +528,9 @@ class _tensor_py_operators:
# used; if it fails with AdvancedIndexingError, advanced indexing is
# used
advanced = False
axis = None
for i, arg in enumerate(args):
if includes_bool(arg):
advanced = True
axis = None
break
if arg is not np.newaxis:
......@@ -540,43 +538,12 @@ class _tensor_py_operators:
aet.subtensor.Subtensor.convert(arg)
except AdvancedIndexingError:
if advanced:
axis = None
break
else:
advanced = True
axis = i
if advanced:
if (
axis is not None
and all(isinstance(a, slice) and a == slice(None) for a in args[:axis])
and all(
isinstance(a, slice) and a == slice(None) for a in args[axis + 1 :]
)
# I.e. if the first advanced index is a tensor or NumPy array,
# then it can't be boolean (in order to meet this condition).
# How could this possibly occur; we filter for booleans above,
# right?
# and (not hasattr(args[axis], "dtype") or args[axis].dtype != "bool")
and isinstance(
args[axis],
(
np.ndarray,
list,
TensorVariable,
TensorConstant,
aet.sharedvar.TensorSharedVariable,
),
)
):
# If we're here, it means that an advanced index was found
# (e.g. an array of indices) and it was surrounded by full
# slices--or no slices (e.g. `x[:, :, idx, ...]`). The
# `take` function/`Op` serves exactly this type of indexing,
# so we simply return its result.
return self.take(args[axis], axis)
else:
return aet.subtensor.advanced_subtensor(self, *args)
return aet.subtensor.advanced_subtensor(self, *args)
else:
if np.newaxis in args:
# `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new
......
......@@ -611,7 +611,7 @@ def test_jax_Subtensors():
compare_jax_and_py(out_fg, [])
# Advanced indexing
out_aet = x_aet[[1, 2]]
out_aet = aet_subtensor.advanced_subtensor1(x_aet, [1, 2])
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
......@@ -623,7 +623,7 @@ def test_jax_Subtensors():
# Advanced and basic indexing
out_aet = x_aet[[1, 2], :]
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor1)
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, [])
......
......@@ -410,15 +410,11 @@ def test_Subtensor(x, indices):
"x, indices",
[
(aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)),
(
aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], slice(None)),
),
],
)
def test_AdvancedSubtensor1(x, indices):
"""Test NumPy's advanced indexing in one dimension."""
out_aet = x[[1, 2]]
out_aet = aet_subtensor.advanced_subtensor1(x, *indices)
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_aet])
compare_numba_and_py(out_fg, [])
......@@ -493,26 +489,21 @@ def test_IncSubtensor(x, y, indices):
aet.as_tensor(rng.poisson(size=(2, 4, 5))),
([1, 2],),
),
(
aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
aet.as_tensor(rng.poisson(size=(2, 4, 5))),
([1, 2], slice(None)),
),
],
)
def test_AdvancedIncSubtensor1(x, y, indices):
out_aet = aet.set_subtensor(x[indices], y)
out_aet = aet_subtensor.advanced_set_subtensor1(x, y, *indices)
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1)
out_fg = FunctionGraph([], [out_aet])
compare_numba_and_py(out_fg, [])
out_aet = aet.inc_subtensor(x[indices], y)
out_aet = aet_subtensor.advanced_inc_subtensor1(x, y, *indices)
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1)
out_fg = FunctionGraph([], [out_aet])
compare_numba_and_py(out_fg, [])
x_at = x.type()
out_aet = aet.set_subtensor(x_at[indices], y, inplace=True)
out_aet = aet_subtensor.AdvancedIncSubtensor1(inplace=True)(x_at, y, *indices)
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1)
out_fg = FunctionGraph([x_at], [out_aet])
compare_numba_and_py(out_fg, [x.data])
......
......@@ -88,7 +88,7 @@ from aesara.tensor.basic import MakeVector
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.math import sum as aet_sum
from aesara.tensor.shape import Shape_i
from aesara.tensor.subtensor import AdvancedIncSubtensor1, AdvancedSubtensor1, Subtensor
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor1, Subtensor
from aesara.tensor.type import (
TensorType,
float_dtypes,
......@@ -644,11 +644,19 @@ class TestConstructSparseFromList:
def test_adv_sub1_sparse_grad(self):
v = ivector()
# Assert we don't create a sparse grad by default
m = matrix()
with pytest.raises(TypeError):
aesara.sparse.sparse_grad(v)
with pytest.raises(TypeError):
sub = m[v, v]
aesara.sparse.sparse_grad(sub)
# Assert we don't create a sparse grad by default
sub = m[v]
g = aesara.grad(sub.sum(), m)
assert isinstance(g.owner.op, AdvancedIncSubtensor1)
assert isinstance(g.owner.op, AdvancedIncSubtensor)
# Test that we create a sparse grad when asked
# USER INTERFACE
......@@ -685,7 +693,7 @@ class TestConstructSparseFromList:
# Assert we don't create a sparse grad by default
g = aesara.grad(sub.sum(), t)
assert isinstance(g.owner.op, AdvancedIncSubtensor1)
assert isinstance(g.owner.op, AdvancedIncSubtensor)
# Test that we raise an error, as we can't create a sparse
# grad from tensors that don't have 2 dimensions.
......
......@@ -1672,7 +1672,11 @@ class TestSubtensorIncSubtensor:
@classmethod
def setup_class(cls):
cls.rng = np.random.default_rng(utt.fetch_seed())
cls.mode = get_default_mode().including("local_subtensor_inc_subtensor")
cls.mode = get_default_mode().including(
"local_subtensor_inc_subtensor",
"local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1",
"local_replace_AdvancedSubtensor",
)
@pytest.mark.parametrize(
"val, indices, optype",
......@@ -1685,11 +1689,13 @@ class TestSubtensorIncSubtensor:
def test_inplace(self, val, indices, optype):
x = matrix("x")
y = set_subtensor((2 * x)[indices], val, inplace=False)
assert isinstance(y.owner.op, optype)
assert y.owner.op.inplace is False
f = function(
[x, val] + list(indices), y, mode=get_default_mode().including("inplace")
[x, val] + list(indices),
y,
mode=self.mode.including("inplace"),
)
assert isinstance(f.maker.fgraph.outputs[0].owner.op, optype)
assert f.maker.fgraph.outputs[0].owner.op.inplace is True
def test_basic(self):
......@@ -2602,7 +2608,11 @@ class TestLocalAdvSub1AdvIncSub1:
def setup_method(self):
mode = get_default_mode()
self.mode = mode.including("local_adv_sub1_adv_inc_sub1").excluding("fusion")
self.mode = mode.including(
"local_replace_AdvancedSubtensor",
"local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1",
"local_adv_sub1_adv_inc_sub1",
).excluding("fusion")
self.mode_no_assert = self.mode.including("local_remove_all_assert")
def test_basic(self):
......@@ -2969,8 +2979,13 @@ def test_local_set_to_inc_subtensor():
s = v[[2, 1]]
g = s + 3
r = set_subtensor(s, g)
moder = get_default_mode().excluding("local_set_to_inc_subtensor")
modet = get_default_mode().including("local_set_to_inc_subtensor")
mode = get_default_mode().including(
"local_replace_AdvancedSubtensor",
"local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1",
)
moder = mode.excluding("local_set_to_inc_subtensor")
modet = mode.including("local_set_to_inc_subtensor")
f1 = function([v], r, mode=moder)
f2 = function([v], r, mode=modet)
......@@ -3453,8 +3468,8 @@ class TestLocalUselessIncSubtensorAlloc:
utt.assert_allclose(r1, r2)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f1, ops_to_check=AdvancedIncSubtensor)
assert check_stack_trace(f2, ops_to_check=AdvancedIncSubtensor)
assert check_stack_trace(f1, ops_to_check=AdvancedIncSubtensor1)
assert check_stack_trace(f2, ops_to_check=AdvancedIncSubtensor1)
def test_advanced_inc_subtensor1(self):
x = vector("x")
......
......@@ -44,7 +44,7 @@ from aesara.tensor.extra_ops import (
unravel_index,
)
from aesara.tensor.math import sum as aet_sum
from aesara.tensor.subtensor import AdvancedIncSubtensor1
from aesara.tensor.subtensor import AdvancedIncSubtensor
from aesara.tensor.type import (
TensorType,
dmatrix,
......@@ -1174,7 +1174,7 @@ class TestBroadcastTo(utt.InferShapeTester):
e_fn = function([d], e, mode=py_mode)
advincsub_node = e_fn.maker.fgraph.outputs[0].owner
assert isinstance(advincsub_node.op, AdvancedIncSubtensor1)
assert isinstance(advincsub_node.op, AdvancedIncSubtensor)
assert isinstance(advincsub_node.inputs[0].owner.op, BroadcastTo)
assert advincsub_node.op.inplace is False
......
......@@ -504,3 +504,11 @@ def test_nonstandard_shapes():
none_shape = shape(NoneConst)
assert np.array_equal(none_shape.get_test_value(), [])
def test_shape_i_basics():
    """`Shape_i` should raise `TypeError` for invalid inputs."""
    shape_at_0 = Shape_i(0)
    # A plain Python list is not a valid symbolic input
    with pytest.raises(TypeError):
        shape_at_0([1, 2])
    # A 0-d input has no dimension ``0`` to take the shape of
    with pytest.raises(TypeError):
        shape_at_0(scalar())
......@@ -33,6 +33,7 @@ from aesara.tensor.subtensor import (
inc_subtensor,
indexed_result_shape,
set_subtensor,
take,
)
from aesara.tensor.type import (
TensorType,
......@@ -81,7 +82,11 @@ class TestSubtensor(utt.OptimizationTestMixin):
self.shared = shared
self.dtype = config.floatX
mode = aesara.compile.mode.get_default_mode()
self.mode = mode.including("local_useless_subtensor")
self.mode = mode.including(
"local_replace_AdvancedSubtensor",
"local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1",
"local_useless_subtensor",
)
self.fast_compile = config.mode == "FAST_COMPILE"
def function(
......@@ -332,26 +337,24 @@ class TestSubtensor(utt.OptimizationTestMixin):
numpy_n = np.arange(24, dtype=self.dtype).reshape((2, 3, 4))
n = self.shared(numpy_n)
test_cases = [
(0, Subtensor, Subtensor, np.index_exp[...]),
(1, Subtensor, Subtensor, np.index_exp[..., 1]),
(1, Subtensor, Subtensor, np.index_exp[1, ...]),
(1, Subtensor, Subtensor, np.index_exp[..., 1, 2, 3]),
(1, Subtensor, Subtensor, np.index_exp[1, ..., 2, 3]),
(1, Subtensor, Subtensor, np.index_exp[1, 2, 3, ...]),
(3, DimShuffle, DimShuffle, np.index_exp[..., [0, 2, 3]]),
(1, DimShuffle, DimShuffle, np.index_exp[np.newaxis, ...]),
(0, Subtensor, np.index_exp[...]),
(1, Subtensor, np.index_exp[..., 1]),
(1, Subtensor, np.index_exp[1, ...]),
(1, Subtensor, np.index_exp[..., 1, 2, 3]),
(1, Subtensor, np.index_exp[1, ..., 2, 3]),
(1, Subtensor, np.index_exp[1, 2, 3, ...]),
(3, DimShuffle, np.index_exp[..., [0, 2, 3]]),
(1, DimShuffle, np.index_exp[np.newaxis, ...]),
(
1,
AdvancedSubtensor,
AdvancedSubtensor,
np.index_exp[..., np.newaxis, [1, 2]],
),
]
for length, op_type, op_type_opt, slice_ in test_cases:
for length, op_type_opt, slice_ in test_cases:
numpy_tval = numpy_n[slice_]
t = n[slice_]
assert isinstance(t.owner.op, op_type)
tval = self.eval_output_and_check(t, op_type=op_type_opt, length=length)
assert tval.shape == numpy_tval.shape
assert_array_equal(tval, numpy_tval)
......@@ -641,9 +644,6 @@ class TestSubtensor(utt.OptimizationTestMixin):
n = self.shared(data)
t = n[idx]
# We test again AdvancedSubtensor1 as we transfer data to the cpu.
assert isinstance(t.owner.op, AdvancedSubtensor1)
val = self.eval_output_and_check(t, op_type=AdvancedSubtensor1)
if isinstance(idx, list):
good = data[idx]
......@@ -681,7 +681,6 @@ class TestSubtensor(utt.OptimizationTestMixin):
idx = [2, 2, 0, 0, 1, 1]
n = self.shared(data)
t = n[self.shared(np.asarray(idx).astype("int64"))[::2]]
assert isinstance(t.owner.op, AdvancedSubtensor1)
val = self.eval_output_and_check(t, op_type=AdvancedSubtensor1, length=2)
utt.assert_allclose(data[idx[::2]], val)
......@@ -699,12 +698,9 @@ class TestSubtensor(utt.OptimizationTestMixin):
n = self.shared(np.ones((2, 3), dtype=self.dtype) * 5)
l = lvector()
t = n[l]
# We test again AdvancedSubtensor1 as we transfer data to the cpu.
assert isinstance(t.owner.op, AdvancedSubtensor1)
f = self.function([l], t, op=AdvancedSubtensor1)
# the grad
g = self.function(
[l],
inc_subtensor(t, np.asarray([[1.0]], self.dtype)),
......@@ -722,7 +718,6 @@ class TestSubtensor(utt.OptimizationTestMixin):
n = self.shared(v * 5, broadcastable=(True, False))
idx = lvector()
t = n[idx]
assert isinstance(t.owner.op, AdvancedSubtensor1)
f = self.function([idx], t, op=AdvancedSubtensor1)
topo = f.maker.fgraph.toposort()
......@@ -806,7 +801,6 @@ class TestSubtensor(utt.OptimizationTestMixin):
idx = TensorType(dtype="int64", broadcastable=(True,))()
assert idx.type.broadcastable == (True,)
t = n[idx]
assert isinstance(t.owner.op, AdvancedSubtensor1)
f = self.function([idx], t, op=AdvancedSubtensor1)
topo = f.maker.fgraph.toposort()
......@@ -1435,12 +1429,41 @@ class TestSubtensor(utt.OptimizationTestMixin):
finally:
config.warn__inc_set_subtensor1 = orig_warn
def test_take(self):
a = matrix()
f = aesara.function(
[a], a.take(0, axis=-1), allow_input_downcast=True, mode=self.mode
)
f(np.random.normal(0, 1, (30, 4)))
def test_take_basic():
    """`take` should raise `TypeError` when given a symbolic (non-constant) `axis`."""
    with pytest.raises(TypeError):
        take(matrix(), lvector(), axis=lscalar())
@pytest.mark.parametrize(
    "a, index, axis, mode",
    [
        (matrix(), lvector(), -1, None),
        (matrix(), lvector(), 0, None),
        (matrix(), lvector(), 1, None),
        (matrix(), lvector(), 1, "clip"),
        (matrix(), lvector(), 1, "wrap"),
    ],
)
def test_take_cases(a, index, axis, mode):
    """Compare symbolic `take` results against NumPy's `ndarray.take`."""
    fn_mode = aesara.compile.mode.get_default_mode().including(
        "local_useless_subtensor",
        # "local_replace_AdvancedSubtensor",
    )
    f = aesara.function([a, index], a.take(index, axis=axis, mode=mode), mode=fn_mode)

    a_val = np.arange(3 * 3).reshape((3, 3)).astype(config.floatX)
    # Out-of-range indices exercise the "clip"/"wrap" modes; in-range ones
    # cover the default mode.
    index_val = np.array([0, 1] if mode is None else [-1, 2], dtype=np.int64)

    expected = a_val.take(index_val, axis=axis, mode=mode)
    actual = f(a_val, index_val)
    assert np.array_equal(expected, actual)
class TestIncSubtensor:
......@@ -2237,7 +2260,7 @@ class TestInferShape(utt.InferShapeTester):
aivec_val = [2, 3]
self._compile_and_check(
[admat, bdmat],
[set_subtensor(admat[aivec_val], bdmat)],
[advanced_set_subtensor1(admat, bdmat, aivec_val)],
[admat_val, [[1, 2, 3, 4]]],
AdvancedIncSubtensor1,
)
......@@ -2245,7 +2268,7 @@ class TestInferShape(utt.InferShapeTester):
aivec_val = [1, 3, 2]
self._compile_and_check(
[admat, advec],
[set_subtensor(admat[aivec_val], advec)],
[advanced_set_subtensor1(admat, advec, aivec_val)],
[admat_val, [1, 2, 3, 4]],
AdvancedIncSubtensor1,
)
......@@ -2253,7 +2276,7 @@ class TestInferShape(utt.InferShapeTester):
aivec_val = [0, 3, 0]
self._compile_and_check(
[admat, adscal],
[set_subtensor(admat[aivec_val], adscal)],
[advanced_set_subtensor1(admat, adscal, aivec_val)],
[admat_val, 1],
AdvancedIncSubtensor1,
)
......@@ -2264,7 +2287,7 @@ class TestInferShape(utt.InferShapeTester):
aivec_val = [2, 3]
self._compile_and_check(
[adtens4, bdtens4],
[set_subtensor(adtens4[aivec_val], bdtens4)],
[advanced_set_subtensor1(adtens4, bdtens4, aivec_val)],
[adtens4_val, [[[[1, 2, 3, 4, 5]]]]],
AdvancedIncSubtensor1,
warn=False,
......@@ -2273,7 +2296,7 @@ class TestInferShape(utt.InferShapeTester):
aivec_val = [1, 3, 2]
self._compile_and_check(
[adtens4, advec],
[set_subtensor(adtens4[aivec_val], advec)],
[advanced_set_subtensor1(adtens4, advec, aivec_val)],
[adtens4_val, [1, 2, 3, 4, 5]],
AdvancedIncSubtensor1,
)
......@@ -2281,7 +2304,7 @@ class TestInferShape(utt.InferShapeTester):
aivec_val = [0, 3, 0]
self._compile_and_check(
[adtens4, adscal],
[set_subtensor(adtens4[aivec_val], adscal)],
[advanced_set_subtensor1(adtens4, adscal, aivec_val)],
[adtens4_val, 1],
AdvancedIncSubtensor1,
)
......@@ -2289,7 +2312,7 @@ class TestInferShape(utt.InferShapeTester):
aivec_val = [2, 3]
self._compile_and_check(
[admat, bdmat],
[inc_subtensor(admat[aivec_val], bdmat)],
[advanced_set_subtensor1(admat, bdmat, aivec_val)],
[admat_val, [[1, 2, 3, 4], [5, 6, 7, 8]]],
AdvancedIncSubtensor1,
)
......@@ -2297,7 +2320,7 @@ class TestInferShape(utt.InferShapeTester):
aivec_val = [1, 3, 2]
self._compile_and_check(
[admat, advec],
[inc_subtensor(admat[aivec_val], advec)],
[advanced_set_subtensor1(admat, advec, aivec_val)],
[admat_val, [1, 2, 3, 4]],
AdvancedIncSubtensor1,
)
......@@ -2305,7 +2328,7 @@ class TestInferShape(utt.InferShapeTester):
aivec_val = [0, 3, 0]
self._compile_and_check(
[admat, adscal],
[inc_subtensor(admat[aivec_val], adscal)],
[advanced_set_subtensor1(admat, adscal, aivec_val)],
[admat_val, 1],
AdvancedIncSubtensor1,
)
......@@ -2315,7 +2338,7 @@ class TestInferShape(utt.InferShapeTester):
aivec_val = [2, 3]
self._compile_and_check(
[adtens4, bdtens4],
[inc_subtensor(adtens4[aivec_val], bdtens4)],
[advanced_set_subtensor1(adtens4, bdtens4, aivec_val)],
[adtens4_val, [[[[1, 2, 3, 4, 5]]], [[[6, 7, 8, 9, 10]]]]],
AdvancedIncSubtensor1,
warn=False,
......@@ -2324,7 +2347,7 @@ class TestInferShape(utt.InferShapeTester):
aivec_val = [1, 2, 1]
self._compile_and_check(
[adtens4, advec],
[inc_subtensor(adtens4[aivec_val], advec)],
[advanced_set_subtensor1(adtens4, advec, aivec_val)],
[adtens4_val, [1, 2, 3, 4, 5]],
AdvancedIncSubtensor1,
)
......@@ -2332,7 +2355,7 @@ class TestInferShape(utt.InferShapeTester):
aivec_val = [0, 3, 0]
self._compile_and_check(
[adtens4, adscal],
[inc_subtensor(adtens4[aivec_val], adscal)],
[advanced_set_subtensor1(adtens4, adscal, aivec_val)],
[adtens4_val, 2],
AdvancedIncSubtensor1,
)
......
import numpy as np
import pytest
from aesara.compile.function import function
from aesara.compile.mode import Mode
from aesara.graph.basic import Variable, ancestors
from aesara.tensor.subtensor import AdvancedSubtensor
from aesara.tensor.subtensor_opt import local_replace_AdvancedSubtensor
from aesara.tensor.type import tensor
from tests.unittest_tools import create_aesara_param
# Index parameters shared by the parametrized cases below: `y` is a vector of
# indices into a length-4 dimension, `z` a 2x2 matrix of such indices.
y = create_aesara_param(np.random.randint(0, 4, size=(2,)))
z = create_aesara_param(np.random.randint(0, 4, size=(2, 2)))
@pytest.mark.parametrize(
    ("indices", "is_none"),
    [
        ((slice(None), y, y), True),
        ((y, y, slice(None)), True),
        ((y,), False),
        ((slice(None), y), False),
        ((y, slice(None)), False),
        ((slice(None), y, slice(None)), False),
        ((slice(None), z, slice(None)), False),
        ((slice(None), z), False),
        ((z, slice(None)), False),
        ((slice(None), z, slice(None)), False),
    ],
)
def test_local_replace_AdvancedSubtensor(indices, is_none):
    """Check that `local_replace_AdvancedSubtensor` rewrites (or refuses to
    rewrite) ``X[indices]`` and that the rewritten graph computes the same
    values as the original."""
    X_val = np.random.normal(size=(4, 4, 4))
    X = tensor(np.float64, [False, False, False], name="X")
    X.tag.test_value = X_val

    Y = X[indices]

    res_at = local_replace_AdvancedSubtensor.transform(None, Y.owner)

    if is_none:
        # Patterns with multiple array indices must be left untouched.
        assert res_at is None
        return

    (replaced,) = res_at

    # No `AdvancedSubtensor` may remain anywhere in the rewritten graph.
    assert all(
        not isinstance(v.owner.op, AdvancedSubtensor)
        for v in ancestors([replaced])
        if v.owner
    )

    fn_inputs = [X] + [i for i in indices if isinstance(i, Variable)]

    rewritten_fn = function(fn_inputs, replaced, mode=Mode("py", None, None))
    expected_fn = function(fn_inputs, Y, mode=Mode("py", None, None))

    # Sanity check: the reference graph really does use an `AdvancedSubtensor`.
    assert any(
        isinstance(v.owner.op, AdvancedSubtensor)
        for v in expected_fn.maker.fgraph.variables
        if v.owner
    )

    arg_vals = [i.tag.test_value for i in fn_inputs]
    assert np.array_equal(rewritten_fn(*arg_vals), expected_fn(*arg_vals))
......@@ -5,7 +5,7 @@ from numpy.testing import assert_equal, assert_string_equal
import aesara
import tests.unittest_tools as utt
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
from aesara.tensor.subtensor import AdvancedSubtensor, Subtensor
from aesara.tensor.type import TensorType, dmatrix, iscalar, ivector, matrix
from aesara.tensor.type_other import MakeSlice
from aesara.tensor.var import TensorConstant
......@@ -149,19 +149,18 @@ def test__getitem__AdvancedSubtensor():
# This is a `__getitem__` call that's redirected to `_tensor_py_operators.take`
z = x[i]
op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedSubtensor1
assert op_types[-1] == AdvancedSubtensor
# This should index nothing (i.e. return an empty copy of `x`)
# We check that the index is empty
z = x[[]]
op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])]
assert op_types == [AdvancedSubtensor1]
assert op_types == [AdvancedSubtensor]
assert isinstance(z.owner.inputs[1], TensorConstant)
# This is also a `__getitem__` call that's redirected to `_tensor_py_operators.take`
z = x[:, i]
op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])]
assert op_types == [DimShuffle, AdvancedSubtensor1, DimShuffle]
assert op_types == [MakeSlice, AdvancedSubtensor]
z = x[..., i, None]
op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论