Commit e08dac2a authored by Brandon T. Willard, committed by Brandon T. Willard

Convert get_vector_length to a dispatch function

Parent 2edc7339
"""Symbolic tensor types and constructor functions.""" """Symbolic tensor types and constructor functions."""
import warnings
from functools import singledispatch from functools import singledispatch
from typing import Any, Callable, NoReturn, Optional from typing import Any, Callable, NoReturn, Optional, Union
from aesara.graph.basic import Constant, Variable
from aesara.graph.op import Op
def as_tensor_variable( def as_tensor_variable(
...@@ -45,10 +47,53 @@ def _as_tensor_variable( ...@@ -45,10 +47,53 @@ def _as_tensor_variable(
raise NotImplementedError(f"Cannot convert {x} to a tensor variable.") raise NotImplementedError(f"Cannot convert {x} to a tensor variable.")
def get_vector_length(v: Any) -> int:
    """Return the run-time length of a symbolic vector, when possible.

    In general this is not possible, but for a number of special cases
    the length can be determined at compile / graph-construction time.
    This function implements these special cases via the `Op`-based
    dispatch function `_get_vector_length`.

    Parameters
    ----------
    v
        A rank-1 `TensorType` variable.

    Raises
    ------
    TypeError
        `v` hasn't the proper type.
    ValueError
        No special case applies, the length is not known.

    """
    v = as_tensor_variable(v)

    if v.type.ndim != 1:
        raise TypeError(f"Argument must be a vector; got {v.type}")

    if v.type.broadcastable[0]:
        # A broadcastable first (and only) dimension always has length 1.
        return 1

    # Dispatch on the owner `Op` when there is one; otherwise dispatch on
    # the variable itself (e.g. `Constant`s and shared variables have
    # registrations of their own).
    return _get_vector_length(getattr(v.owner, "op", v), v)
@singledispatch
def _get_vector_length(op: Union[Op, Variable], var: Variable) -> NoReturn:
    """Fallback for the `Op`-based dispatch behind `get_vector_length`.

    Registered overloads compute the compile-time length of ``var``; this
    default is reached when no special case applies to ``var``'s owner.
    """
    # No handler is registered for this `Op`/variable type, so the length
    # cannot be determined at graph-construction time.
    msg = f"Length of {var} cannot be determined"
    raise ValueError(msg)
@_get_vector_length.register(Constant)
def _get_vector_length_Constant(var_inst, var):
    """Length of a constant vector: just measure its underlying data."""
    data = var.data
    return len(data)
import aesara.tensor.exceptions # noqa
from aesara.gradient import consider_constant, grad, hessian, jacobian # noqa
# adds shared-variable constructors
from aesara.tensor import sharedvar # noqa
from aesara.tensor import ( # noqa
basic_opt, basic_opt,
blas, blas,
blas_c, blas_c,
...@@ -61,14 +106,16 @@ from aesara.tensor import ( ...@@ -61,14 +106,16 @@ from aesara.tensor import (
# isort: off # isort: off
from aesara.tensor import linalg from aesara.tensor import linalg # noqa
from aesara.tensor import nlinalg # For backward compatibility
from aesara.tensor import slinalg # For backward compatibility # For backward compatibility
from aesara.tensor import nlinalg # noqa
from aesara.tensor import slinalg # noqa
# isort: on # isort: on
from aesara.tensor.basic import * from aesara.tensor.basic import * # noqa
from aesara.tensor.blas import batched_dot, batched_tensordot from aesara.tensor.blas import batched_dot, batched_tensordot # noqa
from aesara.tensor.extra_ops import ( from aesara.tensor.extra_ops import ( # noqa
bartlett, bartlett,
bincount, bincount,
broadcast_arrays, broadcast_arrays,
...@@ -86,9 +133,7 @@ from aesara.tensor.extra_ops import ( ...@@ -86,9 +133,7 @@ from aesara.tensor.extra_ops import (
unique, unique,
unravel_index, unravel_index,
) )
from aesara.tensor.io import * from aesara.tensor.shape import ( # noqa
from aesara.tensor.math import *
from aesara.tensor.shape import (
reshape, reshape,
shape, shape,
shape_padaxis, shape_padaxis,
...@@ -97,13 +142,17 @@ from aesara.tensor.shape import ( ...@@ -97,13 +142,17 @@ from aesara.tensor.shape import (
specify_shape, specify_shape,
) )
from aesara.tensor.io import * # noqa
from aesara.tensor.math import * # noqa
# We import as `_shared` instead of `shared` to avoid confusion between # We import as `_shared` instead of `shared` to avoid confusion between
# `aesara.shared` and `tensor._shared`. # `aesara.shared` and `tensor._shared`.
from aesara.tensor.sort import argsort, argtopk, sort, topk, topk_and_argtopk from aesara.tensor.sort import argsort, argtopk, sort, topk, topk_and_argtopk # noqa
from aesara.tensor.subtensor import * from aesara.tensor.subtensor import * # noqa
from aesara.tensor.type import * from aesara.tensor.type import * # noqa
from aesara.tensor.type_other import * from aesara.tensor.type_other import * # noqa
from aesara.tensor.var import TensorConstant, TensorVariable from aesara.tensor.var import TensorConstant, TensorVariable # noqa
__all__ = ["random"] # noqa: F405 __all__ = ["random"] # noqa: F405
...@@ -31,7 +31,12 @@ from aesara.misc.safe_asarray import _asarray ...@@ -31,7 +31,12 @@ from aesara.misc.safe_asarray import _asarray
from aesara.printing import min_informative_str, pprint from aesara.printing import min_informative_str, pprint
from aesara.scalar import int32 from aesara.scalar import int32
from aesara.scalar.basic import ScalarConstant, ScalarVariable from aesara.scalar.basic import ScalarConstant, ScalarVariable
from aesara.tensor import _as_tensor_variable, as_tensor_variable from aesara.tensor import (
_as_tensor_variable,
_get_vector_length,
as_tensor_variable,
get_vector_length,
)
from aesara.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise
from aesara.tensor.exceptions import EmptyConstantError, NotScalarConstantError from aesara.tensor.exceptions import EmptyConstantError, NotScalarConstantError
from aesara.tensor.shape import ( from aesara.tensor.shape import (
...@@ -1743,6 +1748,11 @@ class MakeVector(COp): ...@@ -1743,6 +1748,11 @@ class MakeVector(COp):
make_vector = MakeVector() make_vector = MakeVector()
@_get_vector_length.register(MakeVector)
def _get_vector_length_MakeVector(op, var):
    """A `MakeVector` output has exactly one element per input."""
    node_inputs = var.owner.inputs
    return len(node_inputs)
def transfer(var, target): def transfer(var, target):
""" """
Return a version of `var` transferred to `target`. Return a version of `var` transferred to `target`.
...@@ -2554,6 +2564,17 @@ join_ = Join() ...@@ -2554,6 +2564,17 @@ join_ = Join()
pprint.assign(Join, printing.FunctionPrinter("join")) pprint.assign(Join, printing.FunctionPrinter("join"))
@_get_vector_length.register(Join)
def _get_vector_length_Join(op, var):
    """Length of a `Join` output: the sum of the joined vectors' lengths.

    Only a concatenation of rank-1 arrays along a constant axis 0 has a
    statically determinable length.

    Raises
    ------
    ValueError
        The axis is not a known constant 0, or an input is not rank-1, or
        an input's length cannot be determined.
    """
    axis, *arrays = var.owner.inputs
    try:
        axis = get_scalar_constant_value(axis)
    except NotScalarConstantError:
        raise ValueError(f"Length of {var} cannot be determined")
    # Use an explicit check instead of `assert`: an `assert` here would be
    # stripped under `python -O` and, when it fired, would leak an
    # `AssertionError` past the `except NotScalarConstantError` clause
    # instead of the conventional `ValueError`.
    if axis != 0 or not builtins.all(a.ndim == 1 for a in arrays):
        raise ValueError(f"Length of {var} cannot be determined")
    return builtins.sum(get_vector_length(a) for a in arrays)
def join(axis, *tensors_list): def join(axis, *tensors_list):
r""" r"""
Convenience function to concatenate `TensorType`\s along the given axis. Convenience function to concatenate `TensorType`\s along the given axis.
...@@ -2735,7 +2756,7 @@ def stack(*tensors, **kwargs): ...@@ -2735,7 +2756,7 @@ def stack(*tensors, **kwargs):
# in case there is direct int # in case there is direct int
tensors = list(map(as_tensor_variable, tensors)) tensors = list(map(as_tensor_variable, tensors))
dtype = aes.upcast(*[i.dtype for i in tensors]) dtype = aes.upcast(*[i.dtype for i in tensors])
return aesara.tensor.basic_opt.MakeVector(dtype)(*tensors) return MakeVector(dtype)(*tensors)
return join(axis, *[shape_padaxis(t, axis) for t in tensors]) return join(axis, *[shape_padaxis(t, axis) for t in tensors])
...@@ -2765,96 +2786,6 @@ def concatenate(tensor_list, axis=0): ...@@ -2765,96 +2786,6 @@ def concatenate(tensor_list, axis=0):
return join(axis, *tensor_list) return join(axis, *tensor_list)
def get_vector_length(v):
    """Return the run-time length of a symbolic vector.

    Parameters
    ----------
    v
        A rank-1 TensorType variable.

    Raises
    ------
    TypeError
        `v` hasn't the proper type.
    ValueError
        No special case applies, the length is not known.

    In general this is not possible, but for a number of special cases
    the length can be determined at compile / graph-construction time.
    This function implements these special cases.

    """
    v = as_tensor_variable(v)
    if v.ndim != 1:
        raise TypeError(f"argument must be symbolic vector, got '{v}'")
    if v.type.broadcastable[0]:
        # A broadcastable first (and only) dimension always has length 1.
        return 1
    if isinstance(v, aesara.tensor.sharedvar.TensorSharedVariable) and v.type.ndim == 1:
        # A shared variable's value is available now, so just measure it.
        return len(v.get_value())
    if isinstance(v, Constant) and v.type.ndim == 1:
        # Constants carry their data directly.
        return len(v.data)
    if v.owner and isinstance(v.owner.op, aesara.tensor.basic_opt.MakeVector):
        # `MakeVector` produces one output element per input.
        return len(v.owner.inputs)
    if v.owner and isinstance(v.owner.op, Shape):
        # A shape vector has one entry per dimension of the shaped input.
        return v.owner.inputs[0].type.ndim
    # We can skip `Op`s that don't affect the length, like unary `Elemwise`
    # `Op`s
    if (
        v.owner
        and isinstance(v.owner.op, Elemwise)
        and len(v.owner.inputs) == 1
        and len(v.owner.outputs) == 1
    ):
        return get_vector_length(v.owner.inputs[0])
    if v.owner and isinstance(v.owner.op, Join):
        # A concatenation of rank-1 arrays along axis 0 has the sum of the
        # inputs' lengths — but only when the axis is a known constant.
        axis, *arrays = v.owner.inputs
        try:
            axis = get_scalar_constant_value(axis)
            if axis != 0:
                raise ValueError()
            if not builtins.all(a.ndim == 1 for a in arrays):
                raise ValueError()
            return builtins.sum(get_vector_length(a) for a in arrays)
        except (ValueError, NotScalarConstantError):
            raise ValueError(f"Length of {v} cannot be determined")
    # If we take a slice, we know how many elements it will result in
    # TODO: We can cover more `*Subtensor` cases.
    if (
        v.owner
        and isinstance(v.owner.op, aesara.tensor.subtensor.Subtensor)
        and isinstance(v.owner.op.idx_list[0], slice)
    ):
        try:
            indices = aesara.tensor.subtensor.get_idx_list(
                v.owner.inputs, v.owner.op.idx_list
            )
            # Each slice bound may be symbolic; it only helps if it is a
            # compile-time constant (otherwise `get_scalar_constant_value`
            # raises `NotScalarConstantError`).
            start = (
                None
                if indices[0].start is None
                else get_scalar_constant_value(indices[0].start)
            )
            stop = (
                None
                if indices[0].stop is None
                else get_scalar_constant_value(indices[0].stop)
            )
            step = (
                None
                if indices[0].step is None
                else get_scalar_constant_value(indices[0].step)
            )
            arg_len = get_vector_length(v.owner.inputs[0])
            # Let Python's own slice semantics compute the result length.
            return len(range(*slice(start, stop, step).indices(arg_len)))
        except (ValueError, NotScalarConstantError):
            raise ValueError(f"Length of {v} cannot be determined")
    raise ValueError(f"Length of {v} cannot be determined")
def horizontal_stack(*args): def horizontal_stack(*args):
""" """
Horizontally stack two L{TensorType}s. Horizontally stack two L{TensorType}s.
......
...@@ -20,7 +20,9 @@ from aesara.scalar.basic import Scalar ...@@ -20,7 +20,9 @@ from aesara.scalar.basic import Scalar
from aesara.scalar.basic import bool as scalar_bool from aesara.scalar.basic import bool as scalar_bool
from aesara.scalar.basic import identity as scalar_identity from aesara.scalar.basic import identity as scalar_identity
from aesara.scalar.basic import transfer_type, upcast from aesara.scalar.basic import transfer_type, upcast
from aesara.tensor import _get_vector_length
from aesara.tensor import elemwise_cgen as cgen from aesara.tensor import elemwise_cgen as cgen
from aesara.tensor import get_vector_length
from aesara.tensor.type import ( from aesara.tensor.type import (
TensorType, TensorType,
continuous_dtypes, continuous_dtypes,
...@@ -1842,3 +1844,11 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None): ...@@ -1842,3 +1844,11 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
return construct(symbol[0]) return construct(symbol[0])
else: else:
return construct return construct
@_get_vector_length.register(Elemwise)
def _get_vector_length_Elemwise(op, var):
    """Length of a unary elementwise output: same as its input's length."""
    node = var.owner
    # Only a one-input, one-output elementwise `Op` is guaranteed to
    # preserve the length of its input.
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise ValueError(f"Length of {var} cannot be determined")
    return get_vector_length(node.inputs[0])
...@@ -10,6 +10,7 @@ from aesara.graph.op import COp ...@@ -10,6 +10,7 @@ from aesara.graph.op import COp
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.scalar import int32 from aesara.scalar import int32
from aesara.tensor import _get_vector_length
from aesara.tensor import basic as aet from aesara.tensor import basic as aet
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.type import TensorType, int_dtypes, tensor from aesara.tensor.type import TensorType, int_dtypes, tensor
...@@ -129,6 +130,11 @@ shape = Shape() ...@@ -129,6 +130,11 @@ shape = Shape()
_shape = shape # was used in the past, now use shape directly. _shape = shape # was used in the past, now use shape directly.
@_get_vector_length.register(Shape)
def _get_vector_length_Shape(op, var):
    """Length of a shape vector: one entry per dimension of the input."""
    (shaped_input,) = var.owner.inputs
    return shaped_input.type.ndim
def shape_tuple(x): def shape_tuple(x):
"""Get a tuple of symbolic shape values. """Get a tuple of symbolic shape values.
......
...@@ -4,6 +4,7 @@ import numpy as np ...@@ -4,6 +4,7 @@ import numpy as np
from aesara.compile import SharedVariable, shared_constructor from aesara.compile import SharedVariable, shared_constructor
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.tensor import _get_vector_length
from aesara.tensor.type import TensorType from aesara.tensor.type import TensorType
from aesara.tensor.var import _tensor_py_operators from aesara.tensor.var import _tensor_py_operators
...@@ -23,6 +24,11 @@ class TensorSharedVariable(_tensor_py_operators, SharedVariable): ...@@ -23,6 +24,11 @@ class TensorSharedVariable(_tensor_py_operators, SharedVariable):
pass pass
@_get_vector_length.register(TensorSharedVariable)
def _get_vector_length_TensorSharedVariable(var_inst, var):
    """A shared variable's value is available now, so just measure it."""
    # `borrow=True` avoids copying the underlying value only to take its
    # length.
    current_value = var.get_value(borrow=True)
    return len(current_value)
@shared_constructor @shared_constructor
def tensor_constructor( def tensor_constructor(
value, value,
......
...@@ -18,6 +18,7 @@ from aesara.graph.utils import MethodNotDefined ...@@ -18,6 +18,7 @@ from aesara.graph.utils import MethodNotDefined
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.printing import pprint from aesara.printing import pprint
from aesara.scalar.basic import ScalarConstant from aesara.scalar.basic import ScalarConstant
from aesara.tensor import _get_vector_length, get_vector_length
from aesara.tensor.basic import addbroadcast, alloc, get_scalar_constant_value from aesara.tensor.basic import addbroadcast, alloc, get_scalar_constant_value
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.exceptions import ( from aesara.tensor.exceptions import (
...@@ -2705,6 +2706,39 @@ def take(a, indices, axis=None, mode="raise"): ...@@ -2705,6 +2706,39 @@ def take(a, indices, axis=None, mode="raise"):
return a[full_indices] return a[full_indices]
@_get_vector_length.register(Subtensor)
def _get_vector_length_Subtensor(op, var):
    """Length of a sliced vector, when the slice bounds are constants.

    Raises
    ------
    ValueError
        A slice bound is not a compile-time constant, or the length of the
        sliced input cannot itself be determined.
    """
    # If we take a slice, we know how many elements it will result in
    # TODO: We can cover more `*Subtensor` cases.
    try:
        indices = aesara.tensor.subtensor.get_idx_list(
            var.owner.inputs, var.owner.op.idx_list
        )
        # Each slice bound may be symbolic; it only helps if it is a
        # compile-time constant (otherwise `get_scalar_constant_value`
        # raises `NotScalarConstantError`).
        start = (
            None
            if indices[0].start is None
            else get_scalar_constant_value(indices[0].start)
        )
        stop = (
            None
            if indices[0].stop is None
            else get_scalar_constant_value(indices[0].stop)
        )
        step = (
            None
            if indices[0].step is None
            else get_scalar_constant_value(indices[0].step)
        )

        # An empty slice (e.g. `x[1:1]`) has length 0 even when the length
        # of `x` itself is unknown.  The `start is not None` guard is
        # needed because a full slice (`x[::step]`) has
        # `start == stop == None`, and a bare `start == stop` test would
        # wrongly report length 0 for it.
        if start is not None and start == stop:
            return 0

        arg_len = get_vector_length(var.owner.inputs[0])
        # Let Python's own slice semantics compute the result length.
        return len(range(*slice(start, stop, step).indices(arg_len)))
    except (ValueError, NotScalarConstantError):
        raise ValueError(f"Length of {var} cannot be determined")
__all__ = [ __all__ = [
"take", "take",
"inc_subtensor", "inc_subtensor",
......
...@@ -948,23 +948,45 @@ def test_basic_allclose(): ...@@ -948,23 +948,45 @@ def test_basic_allclose():
def test_get_vector_length(): def test_get_vector_length():
# Test `Constant`s
empty_tuple = as_tensor_variable(())
assert 0 == get_vector_length(empty_tuple)
x = as_tensor_variable((1, 2, 3))
assert 3 == get_vector_length(x)
# Test `TensorSharedVariable`s
x = aesara.shared(np.array((2, 3, 4, 5)))
res = get_vector_length(x)
assert res == 4
# Test `Shape`s
x = aesara.shared(np.zeros((2, 3, 4, 5))) x = aesara.shared(np.zeros((2, 3, 4, 5)))
assert len(list(x.shape)) == 4 res = get_vector_length(x.shape)
assert len(list(x.shape[2:4])) == 2 assert res == 4
assert len(list(x.shape[2:])) == 2
assert len(list(x.shape[1:4])) == 3 # Test `Subtensor`s
assert len(list(x.shape[2:2])) == 0 x = as_tensor_variable(np.arange(4))
assert len(list(x.shape[1:5])) == 3 assert get_vector_length(x[2:4]) == 2
assert len(list(x.shape[1:10])) == 3 assert get_vector_length(x[2:]) == 2
assert get_vector_length(x[1:4]) == 3
assert get_vector_length(x[2:2]) == 0
assert get_vector_length(x[1:10]) == 3
# Test step # Test step
assert len(list(x.shape[1:10:2])) == 2 assert get_vector_length(x[1:10:2]) == 2
# Test neg start # Test neg start
assert len(list(x.shape[-1:4])) == 1 assert get_vector_length(x[-1:4]) == 1
assert len(list(x.shape[-6:4])) == 4 assert get_vector_length(x[-6:4]) == 4
# test neg stop # test neg stop
assert len(list(x.shape[1:-2])) == 1 assert get_vector_length(x[1:-2]) == 1
assert len(list(x.shape[1:-1])) == 2 assert get_vector_length(x[1:-1]) == 2
assert get_vector_length(lvector()[1:1]) == 0
assert get_vector_length(lvector()[-1:-1:3]) == 0
with pytest.raises(ValueError, match="^Length of .*"):
get_vector_length(x[lscalar() :])
# Test `Join`s
z = join(0, as_tensor_variable(1, ndim=1), as_tensor_variable(x.shape[0], ndim=1)) z = join(0, as_tensor_variable(1, ndim=1), as_tensor_variable(x.shape[0], ndim=1))
assert isinstance(z.owner.op, Join) assert isinstance(z.owner.op, Join)
assert get_vector_length(z) == 2 assert get_vector_length(z) == 2
...@@ -975,9 +997,15 @@ def test_get_vector_length(): ...@@ -975,9 +997,15 @@ def test_get_vector_length():
assert isinstance(z.owner.op, Join) assert isinstance(z.owner.op, Join)
assert get_vector_length(z) == 3 assert get_vector_length(z) == 3
empty_tuple = as_tensor_variable(()) z = join(
assert 0 == get_vector_length(empty_tuple) lscalar(),
as_tensor_variable([1, 2], ndim=1),
as_tensor_variable([3, 4], ndim=1),
)
with pytest.raises(ValueError, match="^Length of .*"):
get_vector_length(z)
# Test `MakeVector`s
x = lscalar("x") x = lscalar("x")
y = dscalar("y") y = dscalar("y")
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to comment