提交 0dbd512b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Replace `take` in _tensor_py_operators.__getitem__ with an optimization

上级 47207edb
...@@ -26,9 +26,22 @@ if enable_sparse: ...@@ -26,9 +26,22 @@ if enable_sparse:
.. versionadded:: 0.6rc4 .. versionadded:: 0.6rc4
""" """
from aesara.tensor.subtensor import AdvancedSubtensor1 from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1
assert isinstance(var.owner.op, AdvancedSubtensor1) if var.owner is None or not isinstance(
var.owner.op, (AdvancedSubtensor, AdvancedSubtensor1)
):
raise TypeError(
"Sparse gradient is only implemented for AdvancedSubtensor and AdvancedSubtensor1"
)
ret = var.owner.op.__class__(sparse_grad=True)(*var.owner.inputs) x = var.owner.inputs[0]
indices = var.owner.inputs[1:]
if len(indices) > 1:
raise TypeError(
"Sparse gradient is only implemented for single advanced indexing"
)
ret = AdvancedSubtensor1(sparse_grad=True)(x, indices[0])
return ret return ret
...@@ -58,6 +58,7 @@ from aesara.tensor import ( ...@@ -58,6 +58,7 @@ from aesara.tensor import (
blas_scipy, blas_scipy,
nnet, nnet,
opt_uncanonicalize, opt_uncanonicalize,
subtensor_opt,
xlogx, xlogx,
) )
......
...@@ -2534,12 +2534,16 @@ def local_useless_inc_subtensor(fgraph, node): ...@@ -2534,12 +2534,16 @@ def local_useless_inc_subtensor(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize
@local_optimizer([AdvancedIncSubtensor1]) @local_optimizer([AdvancedIncSubtensor1])
def local_set_to_inc_subtensor(fgraph, node): def local_set_to_inc_subtensor(fgraph, node):
""" r"""
AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) -> AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) ->
AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False) AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False)
TODO FIXME: Why doesn't this apply to all `*IncSubtensor*` `Op`\s? If it
did this wouldn't need to also be included in the "specialize" pass.
""" """
if ( if (
isinstance(node.op, AdvancedIncSubtensor1) isinstance(node.op, AdvancedIncSubtensor1)
...@@ -2567,9 +2571,9 @@ def local_set_to_inc_subtensor(fgraph, node): ...@@ -2567,9 +2571,9 @@ def local_set_to_inc_subtensor(fgraph, node):
if subn.inputs[1] != node.inputs[2] or subn.inputs[0] != node.inputs[0]: if subn.inputs[1] != node.inputs[2] or subn.inputs[0] != node.inputs[0]:
return return
ret = advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2]) ret = advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2])
# Copy over previous output stacktrace
# Julian: I'm not sure about this at all...
copy_stack_trace(node.outputs, ret) copy_stack_trace(node.outputs, ret)
return [ret] return [ret]
...@@ -3448,7 +3452,7 @@ def local_setsubtensor_of_constants(fgraph, node): ...@@ -3448,7 +3452,7 @@ def local_setsubtensor_of_constants(fgraph, node):
@register_canonicalize @register_canonicalize
@register_stabilize @register_specialize
@local_optimizer([AdvancedSubtensor1]) @local_optimizer([AdvancedSubtensor1])
def local_adv_sub1_adv_inc_sub1(fgraph, node): def local_adv_sub1_adv_inc_sub1(fgraph, node):
"""Optimize the possible AdvSub1(AdvSetSub1(...), ...). """Optimize the possible AdvSub1(AdvSetSub1(...), ...).
......
import aesara
from aesara.graph.opt import copy_stack_trace, local_optimizer
from aesara.tensor.basic_opt import register_specialize
from aesara.tensor.shape import shape_tuple
from aesara.tensor.sharedvar import TensorSharedVariable
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedSubtensor,
advanced_subtensor1,
inc_subtensor,
)
from aesara.tensor.type_other import NoneTypeT, SliceConstant, SliceType
from aesara.tensor.var import TensorConstant, TensorVariable
def transform_take(a, indices, axis):
    r"""Transform ``arr[:,:,:,indices,...]``-like operations into single-dimensional, vector index operations.

    This effectively converts certain `AdvancedSubtensor` `Op`\s into a
    combination of `AdvancedSubtensor1`, `Dimshuffle`, and `Reshape` `Op`\s,
    which can be more efficient.

    Parameters
    ----------
    a : TensorVariable
        The source array.
    indices : TensorVariable, ndarray, list, tuple
        The indices of the values to extract.
    axis : int
        The axis over which to select values.

    Returns
    -------
    TensorVariable
        A variable equivalent to ``a[..., indices, ...]`` with `indices`
        applied at `axis`, built from `AdvancedSubtensor1` (plus
        `DimShuffle`/`Reshape` where needed).
    """
    a = aesara.tensor.as_tensor_variable(a)
    indices = aesara.tensor.as_tensor_variable(indices)
    # We can use the more efficient `AdvancedSubtensor1` if `indices` is a vector
    if indices.ndim == 1:
        if axis == 0:
            return advanced_subtensor1(a, indices)
        else:
            # Swap `axis` with dimension 0, index along the (new) first
            # dimension, then swap back.  The swap permutation is its own
            # inverse, so the same `shuffle` list restores the axis order.
            shuffle = list(range(a.ndim))
            shuffle[0] = axis
            shuffle[axis] = 0
            res = advanced_subtensor1(a.dimshuffle(shuffle), indices).dimshuffle(
                shuffle
            )
            return res
    # We can reshape and flatten the indices in order to use an
    # `AdvancedSubtensor1` `Op` per the above
    indices_shape = shape_tuple(indices)
    a_shape = shape_tuple(a)
    # Output shape: leading dims of `a`, then the (multi-dimensional) index
    # shape, then the trailing dims of `a`.
    shape_parts = [
        a_shape[:axis],
        indices_shape,
        a_shape[axis + 1 :],
    ]
    shape_parts = [sp for sp in shape_parts if len(sp) > 0]
    assert len(shape_parts) > 0
    if len(shape_parts) > 1:
        shape = aesara.tensor.concatenate(shape_parts)
    else:
        shape = shape_parts[0]
    ndim = a.ndim + indices.ndim - 1
    # Recurse with the flattened (vector) indices, then reshape the result
    # back to the expected multi-dimensional output shape.
    return transform_take(a, indices.flatten(), axis).reshape(shape, ndim)
def is_full_slice(x):
    """Determine if `x` is a ``slice(None)`` or a symbolic equivalent."""
    # A literal Python slice: full iff it equals ``slice(None)``.
    if isinstance(x, slice):
        return x == slice(None)
    # A constant symbolic slice: full iff its value is ``slice(None)``.
    if isinstance(x, SliceConstant):
        return x.value == slice(None)
    # A non-constant symbolic slice: full iff it was built entirely from
    # ``None`` inputs (i.e. ``slice(None, None, None)``).
    if isinstance(getattr(x, "type", None), SliceType) and x.owner is not None:
        return all(
            isinstance(getattr(i, "type", None), NoneTypeT) for i in x.owner.inputs
        )
    return False
def get_advsubtensor_axis(indices):
    """Determine the axis at which an array index is applied.

    This only works for ``take``-like indices: e.g. ``x[:, :, idx, ...]``. For
    the above example, `get_advsubtensor_axis` would return ``2``. If it
    encounters anything other than a set of `indices` containing full slices
    and an array/tensor index, it will return ``None``.

    Parameters
    ----------
    indices : sequence
        The index arguments (slices and/or tensor indices).

    Returns
    -------
    int or None
        The position of the single array/tensor index, or ``None`` when the
        pattern isn't a single tensor index surrounded by full slices.
    """
    found_idx = False
    axis = 0
    for idx in indices:
        if not found_idx and is_full_slice(idx):
            # Preceding full slices
            axis += 1
        elif found_idx and not is_full_slice(idx):
            # We don't handle multiple indices
            return
        elif found_idx and is_full_slice(idx):
            # Trailing full slices
            continue
        else:
            found_idx = True

    # Guard against `indices` being empty or containing only full slices;
    # without it, `indices[axis]` would raise an `IndexError` (since `axis`
    # would equal `len(indices)`) instead of returning `None` as documented.
    if found_idx and isinstance(
        indices[axis], (TensorConstant, TensorVariable, TensorSharedVariable)
    ):
        return axis
@register_specialize
@local_optimizer([AdvancedSubtensor])
def local_replace_AdvancedSubtensor(fgraph, node):
    r"""
    This rewrite converts expressions like ``X[..., y]`` into ``X.T[y].T``, for
    a vector ``y``, and ``X[z, ...]`` into ``X[z.flatten()].reshape(...)``, for a
    matrix ``z``.

    These rewrites replace `AdvancedSubtensor`\s with the more efficient
    `AdvancedSubtensor1` and `Subtensor` `Op`\s.
    """
    if not isinstance(node.op, AdvancedSubtensor):
        return None

    array, *idxs = node.inputs

    idx_axis = get_advsubtensor_axis(idxs)
    if idx_axis is None:
        # No single ``take``-like tensor index was found
        return None

    if idxs[idx_axis].dtype == "bool":
        # Booleans aren't handled
        return None

    replacement = transform_take(array, idxs[idx_axis], idx_axis)
    # The replacement must produce the same output pattern as the original
    assert replacement.broadcastable == node.outputs[0].broadcastable
    copy_stack_trace(node.outputs[0], replacement)
    return [replacement]
@register_specialize
@local_optimizer([AdvancedIncSubtensor])
def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
    r"""Replace `AdvancedIncSubtensor`\s with `AdvancedIncSubtensor1`\s.

    This is only done when there's a single vector index.
    """
    if not isinstance(node.op, AdvancedIncSubtensor):
        return None

    target, values, *idxs = node.inputs

    idx_axis = get_advsubtensor_axis(idxs)
    if idx_axis is None or idxs[idx_axis].dtype == "bool":
        # Booleans aren't handled
        return None

    # Build the equivalent single-vector-index view of `target`, then
    # increment/set through it, preserving the original `Op`'s settings.
    view = transform_take(target, idxs[idx_axis], idx_axis)
    new_out = inc_subtensor(
        view,
        values,
        inplace=node.op.inplace,
        set_instead_of_inc=node.op.set_instead_of_inc,
    )
    copy_stack_trace(node.outputs[0], new_out)
    return [new_out]
...@@ -528,11 +528,9 @@ class _tensor_py_operators: ...@@ -528,11 +528,9 @@ class _tensor_py_operators:
# used; if it fails with AdvancedIndexingError, advanced indexing is # used; if it fails with AdvancedIndexingError, advanced indexing is
# used # used
advanced = False advanced = False
axis = None
for i, arg in enumerate(args): for i, arg in enumerate(args):
if includes_bool(arg): if includes_bool(arg):
advanced = True advanced = True
axis = None
break break
if arg is not np.newaxis: if arg is not np.newaxis:
...@@ -540,42 +538,11 @@ class _tensor_py_operators: ...@@ -540,42 +538,11 @@ class _tensor_py_operators:
aet.subtensor.Subtensor.convert(arg) aet.subtensor.Subtensor.convert(arg)
except AdvancedIndexingError: except AdvancedIndexingError:
if advanced: if advanced:
axis = None
break break
else: else:
advanced = True advanced = True
axis = i
if advanced: if advanced:
if (
axis is not None
and all(isinstance(a, slice) and a == slice(None) for a in args[:axis])
and all(
isinstance(a, slice) and a == slice(None) for a in args[axis + 1 :]
)
# I.e. if the first advanced index is a tensor or NumPy array,
# then it can't be boolean (in order to meet this condition).
# How could this possibly occur; we filter for booleans above,
# right?
# and (not hasattr(args[axis], "dtype") or args[axis].dtype != "bool")
and isinstance(
args[axis],
(
np.ndarray,
list,
TensorVariable,
TensorConstant,
aet.sharedvar.TensorSharedVariable,
),
)
):
# If we're here, it means that an advanced index was found
# (e.g. an array of indices) and it was surrounded by full
# slices--or no slices (e.g. `x[:, :, idx, ...]`). The
# `take` function/`Op` serves exactly this type of indexing,
# so we simply return its result.
return self.take(args[axis], axis)
else:
return aet.subtensor.advanced_subtensor(self, *args) return aet.subtensor.advanced_subtensor(self, *args)
else: else:
if np.newaxis in args: if np.newaxis in args:
......
...@@ -611,7 +611,7 @@ def test_jax_Subtensors(): ...@@ -611,7 +611,7 @@ def test_jax_Subtensors():
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# Advanced indexing # Advanced indexing
out_aet = x_aet[[1, 2]] out_aet = aet_subtensor.advanced_subtensor1(x_aet, [1, 2])
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor1) assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_aet]) out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
...@@ -623,7 +623,7 @@ def test_jax_Subtensors(): ...@@ -623,7 +623,7 @@ def test_jax_Subtensors():
# Advanced and basic indexing # Advanced and basic indexing
out_aet = x_aet[[1, 2], :] out_aet = x_aet[[1, 2], :]
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor1) assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_aet]) out_fg = FunctionGraph([], [out_aet])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
......
...@@ -410,15 +410,11 @@ def test_Subtensor(x, indices): ...@@ -410,15 +410,11 @@ def test_Subtensor(x, indices):
"x, indices", "x, indices",
[ [
(aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)), (aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)),
(
aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([1, 2], slice(None)),
),
], ],
) )
def test_AdvancedSubtensor1(x, indices): def test_AdvancedSubtensor1(x, indices):
"""Test NumPy's advanced indexing in one dimension.""" """Test NumPy's advanced indexing in one dimension."""
out_aet = x[[1, 2]] out_aet = aet_subtensor.advanced_subtensor1(x, *indices)
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor1) assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_aet]) out_fg = FunctionGraph([], [out_aet])
compare_numba_and_py(out_fg, []) compare_numba_and_py(out_fg, [])
...@@ -493,26 +489,21 @@ def test_IncSubtensor(x, y, indices): ...@@ -493,26 +489,21 @@ def test_IncSubtensor(x, y, indices):
aet.as_tensor(rng.poisson(size=(2, 4, 5))), aet.as_tensor(rng.poisson(size=(2, 4, 5))),
([1, 2],), ([1, 2],),
), ),
(
aet.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
aet.as_tensor(rng.poisson(size=(2, 4, 5))),
([1, 2], slice(None)),
),
], ],
) )
def test_AdvancedIncSubtensor1(x, y, indices): def test_AdvancedIncSubtensor1(x, y, indices):
out_aet = aet.set_subtensor(x[indices], y) out_aet = aet_subtensor.advanced_set_subtensor1(x, y, *indices)
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1) assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1)
out_fg = FunctionGraph([], [out_aet]) out_fg = FunctionGraph([], [out_aet])
compare_numba_and_py(out_fg, []) compare_numba_and_py(out_fg, [])
out_aet = aet.inc_subtensor(x[indices], y) out_aet = aet_subtensor.advanced_inc_subtensor1(x, y, *indices)
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1) assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1)
out_fg = FunctionGraph([], [out_aet]) out_fg = FunctionGraph([], [out_aet])
compare_numba_and_py(out_fg, []) compare_numba_and_py(out_fg, [])
x_at = x.type() x_at = x.type()
out_aet = aet.set_subtensor(x_at[indices], y, inplace=True) out_aet = aet_subtensor.AdvancedIncSubtensor1(inplace=True)(x_at, y, *indices)
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1) assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1)
out_fg = FunctionGraph([x_at], [out_aet]) out_fg = FunctionGraph([x_at], [out_aet])
compare_numba_and_py(out_fg, [x.data]) compare_numba_and_py(out_fg, [x.data])
......
...@@ -88,7 +88,7 @@ from aesara.tensor.basic import MakeVector ...@@ -88,7 +88,7 @@ from aesara.tensor.basic import MakeVector
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.shape import Shape_i from aesara.tensor.shape import Shape_i
from aesara.tensor.subtensor import AdvancedIncSubtensor1, AdvancedSubtensor1, Subtensor from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor1, Subtensor
from aesara.tensor.type import ( from aesara.tensor.type import (
TensorType, TensorType,
float_dtypes, float_dtypes,
...@@ -644,11 +644,19 @@ class TestConstructSparseFromList: ...@@ -644,11 +644,19 @@ class TestConstructSparseFromList:
def test_adv_sub1_sparse_grad(self): def test_adv_sub1_sparse_grad(self):
v = ivector() v = ivector()
# Assert we don't create a sparse grad by default
m = matrix() m = matrix()
with pytest.raises(TypeError):
aesara.sparse.sparse_grad(v)
with pytest.raises(TypeError):
sub = m[v, v]
aesara.sparse.sparse_grad(sub)
# Assert we don't create a sparse grad by default
sub = m[v] sub = m[v]
g = aesara.grad(sub.sum(), m) g = aesara.grad(sub.sum(), m)
assert isinstance(g.owner.op, AdvancedIncSubtensor1) assert isinstance(g.owner.op, AdvancedIncSubtensor)
# Test that we create a sparse grad when asked # Test that we create a sparse grad when asked
# USER INTERFACE # USER INTERFACE
...@@ -685,7 +693,7 @@ class TestConstructSparseFromList: ...@@ -685,7 +693,7 @@ class TestConstructSparseFromList:
# Assert we don't create a sparse grad by default # Assert we don't create a sparse grad by default
g = aesara.grad(sub.sum(), t) g = aesara.grad(sub.sum(), t)
assert isinstance(g.owner.op, AdvancedIncSubtensor1) assert isinstance(g.owner.op, AdvancedIncSubtensor)
# Test that we raise an error, as we can't create a sparse # Test that we raise an error, as we can't create a sparse
# grad from tensors that don't have 2 dimensions. # grad from tensors that don't have 2 dimensions.
......
...@@ -1672,7 +1672,11 @@ class TestSubtensorIncSubtensor: ...@@ -1672,7 +1672,11 @@ class TestSubtensorIncSubtensor:
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
cls.rng = np.random.default_rng(utt.fetch_seed()) cls.rng = np.random.default_rng(utt.fetch_seed())
cls.mode = get_default_mode().including("local_subtensor_inc_subtensor") cls.mode = get_default_mode().including(
"local_subtensor_inc_subtensor",
"local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1",
"local_replace_AdvancedSubtensor",
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"val, indices, optype", "val, indices, optype",
...@@ -1685,11 +1689,13 @@ class TestSubtensorIncSubtensor: ...@@ -1685,11 +1689,13 @@ class TestSubtensorIncSubtensor:
def test_inplace(self, val, indices, optype): def test_inplace(self, val, indices, optype):
x = matrix("x") x = matrix("x")
y = set_subtensor((2 * x)[indices], val, inplace=False) y = set_subtensor((2 * x)[indices], val, inplace=False)
assert isinstance(y.owner.op, optype)
assert y.owner.op.inplace is False assert y.owner.op.inplace is False
f = function( f = function(
[x, val] + list(indices), y, mode=get_default_mode().including("inplace") [x, val] + list(indices),
y,
mode=self.mode.including("inplace"),
) )
assert isinstance(f.maker.fgraph.outputs[0].owner.op, optype)
assert f.maker.fgraph.outputs[0].owner.op.inplace is True assert f.maker.fgraph.outputs[0].owner.op.inplace is True
def test_basic(self): def test_basic(self):
...@@ -2602,7 +2608,11 @@ class TestLocalAdvSub1AdvIncSub1: ...@@ -2602,7 +2608,11 @@ class TestLocalAdvSub1AdvIncSub1:
def setup_method(self): def setup_method(self):
mode = get_default_mode() mode = get_default_mode()
self.mode = mode.including("local_adv_sub1_adv_inc_sub1").excluding("fusion") self.mode = mode.including(
"local_replace_AdvancedSubtensor",
"local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1",
"local_adv_sub1_adv_inc_sub1",
).excluding("fusion")
self.mode_no_assert = self.mode.including("local_remove_all_assert") self.mode_no_assert = self.mode.including("local_remove_all_assert")
def test_basic(self): def test_basic(self):
...@@ -2969,8 +2979,13 @@ def test_local_set_to_inc_subtensor(): ...@@ -2969,8 +2979,13 @@ def test_local_set_to_inc_subtensor():
s = v[[2, 1]] s = v[[2, 1]]
g = s + 3 g = s + 3
r = set_subtensor(s, g) r = set_subtensor(s, g)
moder = get_default_mode().excluding("local_set_to_inc_subtensor")
modet = get_default_mode().including("local_set_to_inc_subtensor") mode = get_default_mode().including(
"local_replace_AdvancedSubtensor",
"local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1",
)
moder = mode.excluding("local_set_to_inc_subtensor")
modet = mode.including("local_set_to_inc_subtensor")
f1 = function([v], r, mode=moder) f1 = function([v], r, mode=moder)
f2 = function([v], r, mode=modet) f2 = function([v], r, mode=modet)
...@@ -3453,8 +3468,8 @@ class TestLocalUselessIncSubtensorAlloc: ...@@ -3453,8 +3468,8 @@ class TestLocalUselessIncSubtensorAlloc:
utt.assert_allclose(r1, r2) utt.assert_allclose(r1, r2)
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f1, ops_to_check=AdvancedIncSubtensor) assert check_stack_trace(f1, ops_to_check=AdvancedIncSubtensor1)
assert check_stack_trace(f2, ops_to_check=AdvancedIncSubtensor) assert check_stack_trace(f2, ops_to_check=AdvancedIncSubtensor1)
def test_advanced_inc_subtensor1(self): def test_advanced_inc_subtensor1(self):
x = vector("x") x = vector("x")
......
...@@ -44,7 +44,7 @@ from aesara.tensor.extra_ops import ( ...@@ -44,7 +44,7 @@ from aesara.tensor.extra_ops import (
unravel_index, unravel_index,
) )
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.subtensor import AdvancedIncSubtensor1 from aesara.tensor.subtensor import AdvancedIncSubtensor
from aesara.tensor.type import ( from aesara.tensor.type import (
TensorType, TensorType,
dmatrix, dmatrix,
...@@ -1174,7 +1174,7 @@ class TestBroadcastTo(utt.InferShapeTester): ...@@ -1174,7 +1174,7 @@ class TestBroadcastTo(utt.InferShapeTester):
e_fn = function([d], e, mode=py_mode) e_fn = function([d], e, mode=py_mode)
advincsub_node = e_fn.maker.fgraph.outputs[0].owner advincsub_node = e_fn.maker.fgraph.outputs[0].owner
assert isinstance(advincsub_node.op, AdvancedIncSubtensor1) assert isinstance(advincsub_node.op, AdvancedIncSubtensor)
assert isinstance(advincsub_node.inputs[0].owner.op, BroadcastTo) assert isinstance(advincsub_node.inputs[0].owner.op, BroadcastTo)
assert advincsub_node.op.inplace is False assert advincsub_node.op.inplace is False
......
...@@ -504,3 +504,11 @@ def test_nonstandard_shapes(): ...@@ -504,3 +504,11 @@ def test_nonstandard_shapes():
none_shape = shape(NoneConst) none_shape = shape(NoneConst)
assert np.array_equal(none_shape.get_test_value(), []) assert np.array_equal(none_shape.get_test_value(), [])
def test_shape_i_basics():
    """Check that `Shape_i` rejects invalid inputs with a `TypeError`."""
    # A plain Python list is not a symbolic variable — presumably rejected
    # during input conversion; TODO confirm the exact failure point.
    with pytest.raises(TypeError):
        Shape_i(0)([1, 2])
    # A scalar has no dimension 0, so `Shape_i(0)` cannot apply.
    with pytest.raises(TypeError):
        Shape_i(0)(scalar())
import numpy as np
import pytest
from aesara.compile.function import function
from aesara.compile.mode import Mode
from aesara.graph.basic import Variable, ancestors
from aesara.tensor.subtensor import AdvancedSubtensor
from aesara.tensor.subtensor_opt import local_replace_AdvancedSubtensor
from aesara.tensor.type import tensor
from tests.unittest_tools import create_aesara_param
# Shared advanced-index parameters for the tests below: `y` is a vector
# index and `z` a matrix index, both with values in ``[0, 4)`` so they are
# valid for axes of length 4.
y = create_aesara_param(np.random.randint(0, 4, size=(2,)))
z = create_aesara_param(np.random.randint(0, 4, size=(2, 2)))
@pytest.mark.parametrize(
    ("indices", "is_none"),
    [
        # Multiple tensor indices: the rewrite should decline
        ((slice(None), y, y), True),
        ((y, y, slice(None)), True),
        # Single vector index, with/without surrounding full slices
        ((y,), False),
        ((slice(None), y), False),
        ((y, slice(None)), False),
        ((slice(None), y, slice(None)), False),
        # Single matrix index, with/without surrounding full slices
        # (the duplicate ``(slice(None), z, slice(None))`` entry was removed)
        ((slice(None), z, slice(None)), False),
        ((slice(None), z), False),
        ((z, slice(None)), False),
    ],
)
def test_local_replace_AdvancedSubtensor(indices, is_none):
    """Check `local_replace_AdvancedSubtensor` on ``take``-like index patterns.

    Parameters
    ----------
    indices
        The index arguments applied to a 3D tensor ``X``.
    is_none
        Whether the rewrite is expected to decline (i.e. return ``None``).
    """
    X_val = np.random.normal(size=(4, 4, 4))
    X = tensor(np.float64, [False, False, False], name="X")
    X.tag.test_value = X_val

    Y = X[indices]

    res_at = local_replace_AdvancedSubtensor.transform(None, Y.owner)

    if is_none:
        assert res_at is None
    else:
        (res_at,) = res_at

        # The rewritten graph should no longer contain an `AdvancedSubtensor`
        assert not any(
            isinstance(v.owner.op, AdvancedSubtensor)
            for v in ancestors([res_at])
            if v.owner
        )

        inputs = [X] + [i for i in indices if isinstance(i, Variable)]

        res_fn = function(inputs, res_at, mode=Mode("py", None, None))
        exp_res_fn = function(inputs, Y, mode=Mode("py", None, None))

        # Make sure that the expected result graph has an `AdvancedSubtensor`
        assert any(
            isinstance(v.owner.op, AdvancedSubtensor)
            for v in exp_res_fn.maker.fgraph.variables
            if v.owner
        )

        # The rewritten graph must compute the same values as the original
        res_val = res_fn(*[i.tag.test_value for i in inputs])
        exp_res_val = exp_res_fn(*[i.tag.test_value for i in inputs])

        assert np.array_equal(res_val, exp_res_val)
...@@ -5,7 +5,7 @@ from numpy.testing import assert_equal, assert_string_equal ...@@ -5,7 +5,7 @@ from numpy.testing import assert_equal, assert_string_equal
import aesara import aesara
import tests.unittest_tools as utt import tests.unittest_tools as utt
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor from aesara.tensor.subtensor import AdvancedSubtensor, Subtensor
from aesara.tensor.type import TensorType, dmatrix, iscalar, ivector, matrix from aesara.tensor.type import TensorType, dmatrix, iscalar, ivector, matrix
from aesara.tensor.type_other import MakeSlice from aesara.tensor.type_other import MakeSlice
from aesara.tensor.var import TensorConstant from aesara.tensor.var import TensorConstant
...@@ -149,19 +149,18 @@ def test__getitem__AdvancedSubtensor(): ...@@ -149,19 +149,18 @@ def test__getitem__AdvancedSubtensor():
# This is a `__getitem__` call that's redirected to `_tensor_py_operators.take` # This is a `__getitem__` call that's redirected to `_tensor_py_operators.take`
z = x[i] z = x[i]
op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])] op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])]
assert op_types[-1] == AdvancedSubtensor1 assert op_types[-1] == AdvancedSubtensor
# This should index nothing (i.e. return an empty copy of `x`) # This should index nothing (i.e. return an empty copy of `x`)
# We check that the index is empty # We check that the index is empty
z = x[[]] z = x[[]]
op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])] op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])]
assert op_types == [AdvancedSubtensor1] assert op_types == [AdvancedSubtensor]
assert isinstance(z.owner.inputs[1], TensorConstant) assert isinstance(z.owner.inputs[1], TensorConstant)
# This is also a `__getitem__` call that's redirected to `_tensor_py_operators.take`
z = x[:, i] z = x[:, i]
op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])] op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])]
assert op_types == [DimShuffle, AdvancedSubtensor1, DimShuffle] assert op_types == [MakeSlice, AdvancedSubtensor]
z = x[..., i, None] z = x[..., i, None]
op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])] op_types = [type(node.op) for node in aesara.graph.basic.io_toposort([x, i], [z])]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论