提交 5057f618 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Create a better Subtensor-to-indices helper function

上级 77c6e0b1
......@@ -12,7 +12,6 @@ from numpy.random import RandomState
from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.graph.type import CType
from aesara.ifelse import IfElse
from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from aesara.scan.op import Scan
......@@ -58,14 +57,14 @@ from aesara.tensor.nnet.sigm import ScalarSoftplus
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.slinalg import Cholesky, Solve
from aesara.tensor.subtensor import ( # This is essentially `np.take`; Boolean mask indexing and setting
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
get_idx_list,
indices_from_subtensor,
)
from aesara.tensor.type_other import MakeSlice
......@@ -628,20 +627,6 @@ def jax_funcify_IfElse(op):
return ifelse
def convert_indices(indices, entry):
if indices and isinstance(entry, CType):
rval = indices.pop(0)
return rval
elif isinstance(entry, slice):
return slice(
convert_indices(indices, entry.start),
convert_indices(indices, entry.stop),
convert_indices(indices, entry.step),
)
else:
return entry
@jax_funcify.register(Subtensor)
def jax_funcify_Subtensor(op):
......@@ -649,15 +634,12 @@ def jax_funcify_Subtensor(op):
def subtensor(x, *ilists):
if idx_list:
cdata = get_idx_list((x,) + ilists, idx_list)
else:
cdata = ilists
indices = indices_from_subtensor(ilists, idx_list)
if len(cdata) == 1:
cdata = cdata[0]
if len(indices) == 1:
indices = indices[0]
return x.__getitem__(cdata)
return x.__getitem__(indices)
return subtensor
......@@ -675,16 +657,11 @@ def jax_funcify_IncSubtensor(op):
jax_fn = jax.ops.index_add
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
_ilist = list(ilist)
cdata = (
tuple(convert_indices(_ilist, idx) for idx in idx_list)
if idx_list
else _ilist
)
if len(cdata) == 1:
cdata = cdata[0]
indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1:
indices = indices[0]
return jax_fn(x, cdata, y)
return jax_fn(x, indices, y)
return incsubtensor
......
......@@ -3,6 +3,7 @@ import sys
import warnings
from itertools import chain, groupby
from textwrap import dedent
from typing import Iterable, List, Tuple, Union
import numpy as np
......@@ -13,10 +14,11 @@ from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Variable
from aesara.graph.op import COp, Op
from aesara.graph.params_type import ParamsType
from aesara.graph.type import CType
from aesara.graph.type import Type
from aesara.graph.utils import MethodNotDefined
from aesara.misc.safe_asarray import _asarray
from aesara.printing import pprint
from aesara.scalar.basic import ScalarConstant
from aesara.tensor.basic import addbroadcast, alloc, get_scalar_constant_value
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.exceptions import (
......@@ -62,6 +64,53 @@ invalid_tensor_types = (
)
def indices_from_subtensor(
op_indices: Iterable[Union[ScalarConstant]],
idx_list: List[Union[Type, slice, Variable]],
) -> Tuple[Union[slice, Variable]]:
"""Recreate the index tuple from which a ``*Subtensor**`` ``Op`` was created.
Parameters
==========
op_indices
The flattened indices obtained from ``x.owner.inputs``, when ``x`` is a
``*Subtensor*`` ``Op``.
idx_list
The values describing the types of each dimension's index. This is
obtained from ``x.owner.inputs``, when ``x`` is a ``*Subtensor*``
``Op``.
Example
=======
array, *op_indices = subtensor_node.inputs
idx_list = getattr(subtensor_node.op, "idx_list", None)
indices = indices_from_subtensor(op_indices, idx_list)
"""
def convert_indices(indices, entry):
"""Reconstruct ``*Subtensor*`` index input parameter entries."""
if indices and isinstance(entry, Type):
rval = indices.pop(0)
return rval
elif isinstance(entry, slice):
return slice(
convert_indices(indices, entry.start),
convert_indices(indices, entry.stop),
convert_indices(indices, entry.step),
)
else:
return entry
op_indices = list(op_indices)
return (
tuple(convert_indices(op_indices, idx) for idx in idx_list)
if idx_list
else tuple(op_indices)
)
def as_index_constant(a):
"""Convert Python literals to Aesara constants--when possible--in Subtensor arguments.
......@@ -83,38 +132,8 @@ def as_index_constant(a):
return a
def get_idx_list(inputs, idx_list, get_count=False):
"""
Given a list of inputs to the subtensor and its idx_list reorders
the inputs according to the idx list to get the right values.
If get_counts=True, instead returns the number of inputs consumed
during this process.
"""
# The number of indices
n = len(inputs) - 1
# The subtensor (or idx_list) does not depend on the inputs.
if n == 0:
return tuple(idx_list)
indices = list(reversed(list(inputs[1:])))
# General case
def convert(entry):
if isinstance(entry, CType):
return indices.pop()
elif isinstance(entry, slice):
return slice(convert(entry.start), convert(entry.stop), convert(entry.step))
else:
return entry
cdata = tuple(map(convert, idx_list))
if get_count:
return n - len(indices)
else:
return cdata
def get_idx_list(inputs, idx_list):
return indices_from_subtensor(inputs[1:], idx_list)
def get_canonical_form_slice(theslice, length):
......@@ -508,7 +527,7 @@ class Subtensor(COp):
if isinstance(entry, Variable) and entry.type in scal_types:
return entry.type
elif isinstance(entry, CType) and entry in scal_types:
elif isinstance(entry, Type) and entry in scal_types:
return entry
if (
......@@ -518,7 +537,7 @@ class Subtensor(COp):
):
return aes.get_scalar_type(entry.type.dtype)
elif (
isinstance(entry, CType)
isinstance(entry, Type)
and entry in tensor_types
and np.all(entry.broadcastable)
):
......@@ -641,7 +660,7 @@ class Subtensor(COp):
raise IndexError("too many indices for array")
input_types = Subtensor.collapse(
idx_list, lambda entry: isinstance(entry, CType)
idx_list, lambda entry: isinstance(entry, Type)
)
if len(inputs) != len(input_types):
raise IndexError(
......@@ -859,7 +878,7 @@ class Subtensor(COp):
inc_spec_pos(1)
if depth == 0:
is_slice.append(0)
elif isinstance(entry, CType):
elif isinstance(entry, Type):
init_cmds.append(
"subtensor_spec[%i] = %s;" % (spec_pos(), inputs[input_pos()])
)
......@@ -1477,7 +1496,7 @@ class IncSubtensor(COp):
raise IndexError("too many indices for array")
input_types = Subtensor.collapse(
idx_list, lambda entry: isinstance(entry, CType)
idx_list, lambda entry: isinstance(entry, Type)
)
if len(inputs) != len(input_types):
raise IndexError(
......@@ -1501,7 +1520,7 @@ class IncSubtensor(COp):
indices = list(reversed(inputs[2:]))
def convert(entry):
if isinstance(entry, CType):
if isinstance(entry, Type):
return indices.pop()
elif isinstance(entry, slice):
return slice(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论