提交 51de50be authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Extract general utility methods from Subtensor class

上级 df2a45d8
......@@ -50,6 +50,7 @@ from aesara.tensor.subtensor import (
Subtensor,
get_canonical_form_slice,
get_idx_list,
get_slice_elements,
set_subtensor,
)
from aesara.tensor.var import TensorConstant, get_unique_value
......@@ -1548,7 +1549,7 @@ def save_mem_new_scan(fgraph, node):
subtens = Subtensor(nw_slice)
# slice inputs
sl_ins = Subtensor.collapse(
sl_ins = get_slice_elements(
nw_slice, lambda entry: isinstance(entry, Variable)
)
new_o = subtens(new_outs[nw_pos], *sl_ins)
......@@ -1598,7 +1599,7 @@ def save_mem_new_scan(fgraph, node):
nw_slice = (sanitize(position),) + tuple(old_slices[1:])
subtens = Subtensor(nw_slice)
sl_ins = Subtensor.collapse(
sl_ins = get_slice_elements(
nw_slice, lambda entry: isinstance(entry, Variable)
)
new_o = subtens(new_outs[nw_pos], *sl_ins)
......
......@@ -417,7 +417,9 @@ def get_scalar_constant_value(
and v.ndim == 0
):
if isinstance(v.owner.inputs[0], TensorConstant):
cdata = tuple(v.owner.op.get_constant_idx(v.owner.inputs))
from aesara.tensor.subtensor import get_constant_idx
cdata = tuple(get_constant_idx(v.owner.op.idx_list, v.owner.inputs))
try:
return v.owner.inputs[0].data.__getitem__(cdata).copy()
except IndexError:
......
......@@ -58,7 +58,12 @@ from aesara.tensor.math import sum as aet_sum
from aesara.tensor.math import tanh, tensordot, true_div
from aesara.tensor.nnet.blocksparse import sparse_block_dot
from aesara.tensor.shape import shape, shape_padleft
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedSubtensor,
Subtensor,
get_constant_idx,
)
from aesara.tensor.type import (
TensorType,
discrete_dtypes,
......@@ -1736,8 +1741,8 @@ def _check_rows_is_arange_len_labels(fgraph, rows, labels):
# ShapeOptimizer, but we keep it if ShapeOptimizer is not present
if isinstance(stop.owner.op, Subtensor):
shape_subtensor = stop.owner
if shape_subtensor.op.get_constant_idx(
shape_subtensor.inputs, allow_partial=True
if get_constant_idx(
shape_subtensor.op.idx_list, shape_subtensor.inputs, allow_partial=True
) == [0]:
shape_var = shape_subtensor.inputs[0]
if shape_var.owner and shape_var.owner.op == shape:
......
......@@ -2,7 +2,7 @@ import logging
import sys
from itertools import chain, groupby
from textwrap import dedent
from typing import Iterable, List, Optional, Tuple, Union
from typing import Callable, Iterable, List, Optional, Tuple, Union
import numpy as np
......@@ -498,184 +498,184 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
return res_shape
class Subtensor(COp):
"""Basic NumPy indexing operator."""
def get_slice_elements(idxs: List, cond: Callable) -> List:
    """Extract slice elements conditional on a given predicate function.

    Slices are flattened into their ``start``/``stop``/``step`` components;
    an entry for which ``cond`` is true is kept whole and not flattened.

    Parameters
    ----------
    idxs : List
        A list of indices or slices.
    cond : Callable
        A callable that returns a bool when applied to an entry (or to a
        slice component).

    Returns
    -------
    list
        ``idxs`` with the slices flattened out into a list.
    """
    ret = []

    def helper(entry):
        if cond(entry):
            ret.append(entry)
        elif isinstance(entry, slice):
            # Flatten the slice into its components and test each one.
            helper(entry.start)
            helper(entry.stop)
            helper(entry.step)
        # Anything else (e.g. ``None`` slice components) is dropped.

    for idx in idxs:
        helper(idx)

    return ret
@staticmethod
def convert(entry, slice_ok=True):
"""
Change references to Variables into references to Types.
def index_vars_to_types(entry, slice_ok=True):
r"""Change references to `Variable`s into references to `Type`s.
The "idx_list" field is unique to each Subtensor instance.
It is not unique to each Apply node, so it should not refer to
specific Variables.
TODO: WRITEME: This method also accepts "entry" already being a Type;
when would that happen?
The `Subtensor.idx_list` field is unique to each `Subtensor` instance. It
is not unique to each `Apply` node, so it should not refer to specific
`Variable`s.
"""
if (
isinstance(entry, (np.ndarray, Variable))
and hasattr(entry, "dtype")
and entry.dtype == "bool"
):
raise AdvancedIndexingError("Invalid index type or slice for Subtensor")
TODO WRITEME: This function also accepts an `entry` already being a `Type`;
when would that happen?
if isinstance(entry, Variable) and (
entry.type in invalid_scal_types or entry.type in invalid_tensor_types
):
raise TypeError("Expected an integer")
"""
if (
isinstance(entry, (np.ndarray, Variable))
and hasattr(entry, "dtype")
and entry.dtype == "bool"
):
raise AdvancedIndexingError("Invalid index type or slice for Subtensor")
if isinstance(entry, Variable) and entry.type in scal_types:
return entry.type
elif isinstance(entry, Type) and entry in scal_types:
return entry
if isinstance(entry, Variable) and (
entry.type in invalid_scal_types or entry.type in invalid_tensor_types
):
raise TypeError("Expected an integer")
if (
isinstance(entry, Variable)
and entry.type in tensor_types
and np.all(entry.type.broadcastable)
):
return aes.get_scalar_type(entry.type.dtype)
elif (
isinstance(entry, Type)
and entry in tensor_types
and np.all(entry.broadcastable)
):
return aes.get_scalar_type(entry.dtype)
elif slice_ok and isinstance(entry, slice):
a = entry.start
b = entry.stop
c = entry.step
if a is not None:
slice_a = Subtensor.convert(a, False)
else:
slice_a = None
if isinstance(entry, Variable) and entry.type in scal_types:
return entry.type
elif isinstance(entry, Type) and entry in scal_types:
return entry
if b is not None and b != sys.maxsize:
# The special "maxsize" case is probably not needed here,
# as slices containing maxsize are not generated by
# __getslice__ anymore.
slice_b = Subtensor.convert(b, False)
else:
slice_b = None
if (
isinstance(entry, Variable)
and entry.type in tensor_types
and np.all(entry.type.broadcastable)
):
return aes.get_scalar_type(entry.type.dtype)
elif (
isinstance(entry, Type)
and entry in tensor_types
and np.all(entry.broadcastable)
):
return aes.get_scalar_type(entry.dtype)
elif slice_ok and isinstance(entry, slice):
a = entry.start
b = entry.stop
c = entry.step
if a is not None:
slice_a = index_vars_to_types(a, False)
else:
slice_a = None
if c is not None:
slice_c = Subtensor.convert(c, False)
else:
slice_c = None
if b is not None and b != sys.maxsize:
# The special "maxsize" case is probably not needed here,
# as slices containing maxsize are not generated by
# __getslice__ anymore.
slice_b = index_vars_to_types(b, False)
else:
slice_b = None
return slice(slice_a, slice_b, slice_c)
elif isinstance(entry, (int, np.integer)):
# Disallow the use of python scalars in idx_list
raise TypeError(
"Python scalar in idx_list." "Please report this error to aesara-dev."
)
if c is not None:
slice_c = index_vars_to_types(c, False)
else:
raise AdvancedIndexingError("Invalid index type or slice for Subtensor")
slice_c = None
def get_constant_idx(
self, inputs, allow_partial=False, only_process_constants=False, elemwise=True
):
"""
Return the idx_list with constant inputs replaced by their
python scalar equivalent.
May raise `NotScalarConstantError` if the idx contains
non-constant entries.
return slice(slice_a, slice_b, slice_c)
elif isinstance(entry, (int, np.integer)):
raise TypeError()
else:
raise AdvancedIndexingError("Invalid index type or slice for Subtensor")
If allow_partial is True, then entries that are not constant will
stay as their input variable rather than raising an exception.
None entries are always left as-is.
def get_constant_idx(
idx_list, inputs, allow_partial=False, only_process_constants=False, elemwise=True
):
r"""Return an `idx_list` with its constant inputs replaced by their Python scalar equivalents.
Parameters
----------
only_process_constants
If True, we only attempt to obtain the value of an index/slice if
it's directly constant and don't try to dig through dimshuffles,
fills, allocs, and other to figure out its value.
Examples
--------
Example usage where v, a are appropriately typed aesara variables :
>>> b = a[v, 1:3]
>>> b.owner.op.idx_list
(Scalar(int64), slice(Scalar(int64), Scalar(int64), None))
>>> b.owner.op.get_constant_idx(b.owner.inputs, allow_partial=True)
[v, slice(1, 3, None)]
>>> b.owner.op.get_constant_idx(b.owner.inputs)
NotScalarConstantError: v
May raise `NotScalarConstantError` if the indices contain non-constant entries.
"""
real_idx = get_idx_list(inputs, self.idx_list)
If `allow_partial` is ``True``, then entries that are not constant will
stay as their input variable rather than raising an exception.
def conv(val):
if val is None:
return None
elif isinstance(val, slice):
return slice(conv(val.start), conv(val.stop), conv(val.step))
else:
try:
return get_scalar_constant_value(
val,
only_process_constants=only_process_constants,
elemwise=elemwise,
)
except NotScalarConstantError:
if allow_partial:
return val
else:
raise
``None`` entries are always left as-is.
return list(map(conv, real_idx))
Parameters
----------
only_process_constants
If ``True``, we only attempt to obtain the value of an index/slice if
it's directly constant and don't try to dig through `DimShuffle`\s,
fills, `Alloc`\s, and other to figure out its value.
def __init__(self, idx_list):
self.idx_list = tuple(map(self.convert, idx_list))
Examples
--------
Example usage where `v` and `a` are appropriately typed Aesara variables :
>>> b = a[v, 1:3]
>>> b.owner.op.idx_list
(Scalar(int64), slice(Scalar(int64), Scalar(int64), None))
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True)
[v, slice(1, 3, None)]
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs)
NotScalarConstantError: v
@staticmethod
def my_as_scalar(a):
# Since aes.as_scalar does not know about tensor types (it would
# create a circular import) , this method converts either a
# TensorVariable or a ScalarVariable to a scalar.
if isinstance(a, Variable) and isinstance(a.type, TensorType):
return aesara.tensor.scalar_from_tensor(a)
"""
real_idx = get_idx_list(inputs, idx_list)
# TODO: Combine this with `as_index_literal`
def conv(val):
if val is None:
return None
elif isinstance(val, slice):
return slice(conv(val.start), conv(val.stop), conv(val.step))
else:
return aes.as_scalar(a)
try:
return get_scalar_constant_value(
val,
only_process_constants=only_process_constants,
elemwise=elemwise,
)
except NotScalarConstantError:
if allow_partial:
return val
else:
raise
return list(map(conv, real_idx))
def as_nontensor_scalar(a: Variable) -> aes.ScalarVariable:
    """Convert a value to a `Scalar` variable.

    `aes.as_scalar` does not know about tensor types (teaching it would
    create a circular import), so this helper converts either a
    `TensorVariable` or a `ScalarVariable` to a scalar.
    """
    if isinstance(a, Variable) and isinstance(a.type, TensorType):
        return aesara.tensor.scalar_from_tensor(a)
    return aes.as_scalar(a)
class Subtensor(COp):
"""Basic NumPy indexing operator."""
check_input = False
view_map = {0: [0]}
_f16_ok = True
__props__ = ("idx_list",)
def __init__(self, idx_list):
    """
    Parameters
    ----------
    idx_list
        A sequence of indices/slices; each entry is converted from a
        `Variable` reference to a `Type` via `index_vars_to_types`.
    """
    # TODO: Provide the type of `self.idx_list`
    # `idx_list` is an `Op` prop (see ``__props__``), so it must hold
    # `Type`s rather than specific `Variable`s.
    self.idx_list = tuple(map(index_vars_to_types, idx_list))
def make_node(self, x, *inputs):
"""
......@@ -688,13 +688,13 @@ class Subtensor(COp):
"""
x = aesara.tensor.as_tensor_variable(x)
inputs = tuple(self.my_as_scalar(a) for a in inputs)
inputs = tuple(as_nontensor_scalar(a) for a in inputs)
idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim:
raise IndexError("too many indices for array")
input_types = Subtensor.collapse(
input_types = get_slice_elements(
idx_list, lambda entry: isinstance(entry, Type)
)
if len(inputs) != len(input_types):
......@@ -709,9 +709,9 @@ class Subtensor(COp):
)
# infer the broadcasting pattern
padded = self.get_constant_idx((None,) + inputs, allow_partial=True) + [
slice(None, None, None)
] * (x.type.ndim - len(idx_list))
padded = get_constant_idx(
self.idx_list, (None,) + inputs, allow_partial=True
) + [slice(None, None, None)] * (x.type.ndim - len(idx_list))
broadcastable = []
for i, (p, bc) in enumerate(zip(padded, x.type.broadcastable)):
if isinstance(p, slice):
......@@ -1435,7 +1435,7 @@ class IncSubtensor(COp):
):
if destroyhandler_tolerate_aliased is None:
destroyhandler_tolerate_aliased = []
self.idx_list = list(map(Subtensor.convert, idx_list))
self.idx_list = list(map(index_vars_to_types, idx_list))
self.inplace = inplace
if inplace:
self.destroy_map = {0: [0]}
......@@ -1483,13 +1483,13 @@ class IncSubtensor(COp):
f"Trying to increment a {int(x.ndim)}-dimensional "
f"subtensor with a {int(y.ndim)}-dimensional value."
)
inputs = tuple(map(Subtensor.my_as_scalar, inputs))
inputs = tuple(map(as_nontensor_scalar, inputs))
idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim:
raise IndexError("too many indices for array")
input_types = Subtensor.collapse(
input_types = get_slice_elements(
idx_list, lambda entry: isinstance(entry, Type)
)
if len(inputs) != len(input_types):
......@@ -1513,17 +1513,17 @@ class IncSubtensor(COp):
x, y = inputs[:2]
indices = list(reversed(inputs[2:]))
def convert(entry):
def _convert(entry):
if isinstance(entry, Type):
return indices.pop()
elif isinstance(entry, slice):
return slice(
convert(entry.start), convert(entry.stop), convert(entry.step)
_convert(entry.start), _convert(entry.stop), _convert(entry.step)
)
else:
return entry
cdata = tuple(map(convert, self.idx_list))
cdata = tuple(map(_convert, self.idx_list))
if len(cdata) == 1:
cdata = cdata[0]
if not self.inplace:
......
......@@ -67,7 +67,9 @@ from aesara.tensor.subtensor import (
as_index_constant,
as_index_literal,
get_canonical_form_slice,
get_constant_idx,
get_idx_list,
get_slice_elements,
inc_subtensor,
)
from aesara.tensor.type import TensorType
......@@ -347,7 +349,7 @@ def local_useless_slice(fgraph, node):
# check if we removed something
if last_slice < len(slices):
subtens = Subtensor(slices[:last_slice])
sl_ins = Subtensor.collapse(
sl_ins = get_slice_elements(
slices[:last_slice], lambda x: isinstance(x, Variable)
)
out = subtens(node.inputs[0], *sl_ins)
......@@ -518,7 +520,7 @@ def local_subtensor_merge(fgraph, node):
merged_slices = tuple(as_index_constant(s) for s in merged_slices)
subtens = Subtensor(merged_slices)
sl_ins = Subtensor.collapse(
sl_ins = get_slice_elements(
merged_slices, lambda x: isinstance(x, Variable)
)
# Do not call make_node for test_value
......@@ -766,7 +768,9 @@ def local_subtensor_make_vector(fgraph, node):
# The index is a slice. If it's a constant slice, we can perform the
# index operation here.
try:
const_slice = node.op.get_constant_idx(node.inputs, allow_partial=False)[0]
const_slice = get_constant_idx(
node.op.idx_list, node.inputs, allow_partial=False
)[0]
ret = make_vector_op(*x.owner.inputs[const_slice])
copy_stack_trace(node.outputs, ret)
ret = patternbroadcast(ret, node.outputs[0].broadcastable)
......@@ -896,8 +900,11 @@ def local_useless_subtensor(fgraph, node):
shape_of = fgraph.shape_feature.shape_of
if isinstance(node.op, Subtensor):
cdata = node.op.get_constant_idx(
node.inputs, allow_partial=True, only_process_constants=True
cdata = get_constant_idx(
node.op.idx_list,
node.inputs,
allow_partial=True,
only_process_constants=True,
)
for pos, idx in enumerate(cdata):
if not isinstance(idx, slice):
......
......@@ -526,8 +526,8 @@ class _tensor_py_operators:
)
# Determine if advanced indexing is needed or not. The logic is
# already in `Subtensor.convert`: if it succeeds, standard indexing is
# used; if it fails with AdvancedIndexingError, advanced indexing is
# already in `index_vars_to_types`: if it succeeds, standard indexing is
# used; if it fails with `AdvancedIndexingError`, advanced indexing is
# used
advanced = False
for i, arg in enumerate(args):
......@@ -537,7 +537,7 @@ class _tensor_py_operators:
if arg is not np.newaxis:
try:
aet.subtensor.Subtensor.convert(arg)
aet.subtensor.index_vars_to_types(arg)
except AdvancedIndexingError:
if advanced:
break
......@@ -589,7 +589,7 @@ class _tensor_py_operators:
else:
return aet.subtensor.Subtensor(args)(
self,
*aet.subtensor.Subtensor.collapse(
*aet.subtensor.get_slice_elements(
args, lambda entry: isinstance(entry, Variable)
),
)
......
......@@ -23,6 +23,7 @@ from aesara.tensor.math import sum as aet_sum
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedIndexingError,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
......@@ -35,6 +36,7 @@ from aesara.tensor.subtensor import (
basic_shape,
get_canonical_form_slice,
inc_subtensor,
index_vars_to_types,
indexed_result_shape,
set_subtensor,
take,
......@@ -2558,3 +2560,16 @@ def test_pprint_IncSubtensor(indices, set_instead_of_inc, exp_res):
z = tensor3("z")
y = inc_subtensor(x[indices], z, set_instead_of_inc=set_instead_of_inc)
assert pprint(y) == exp_res
def test_index_vars_to_types():
    """Check the error and conversion behavior of `index_vars_to_types`."""
    # A boolean array is an advanced index, which `Subtensor` rejects
    bool_idx = aet.as_tensor_variable(np.array([True, False]))
    with pytest.raises(AdvancedIndexingError):
        index_vars_to_types(bool_idx)

    # Raw Python scalars are not allowed in an ``idx_list``
    with pytest.raises(TypeError):
        index_vars_to_types(1)

    # A scalar variable converts to a scalar `Type`
    converted = index_vars_to_types(iscalar)
    assert isinstance(converted, scal.Scalar)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论