提交 5057f618 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Create a better Subtensor-to-indices helper function

上级 77c6e0b1
...@@ -12,7 +12,6 @@ from numpy.random import RandomState ...@@ -12,7 +12,6 @@ from numpy.random import RandomState
from aesara.compile.ops import DeepCopyOp, ViewOp from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.type import CType
from aesara.ifelse import IfElse from aesara.ifelse import IfElse
from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from aesara.scan.op import Scan from aesara.scan.op import Scan
...@@ -58,14 +57,14 @@ from aesara.tensor.nnet.sigm import ScalarSoftplus ...@@ -58,14 +57,14 @@ from aesara.tensor.nnet.sigm import ScalarSoftplus
from aesara.tensor.random.op import RandomVariable from aesara.tensor.random.op import RandomVariable
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.slinalg import Cholesky, Solve from aesara.tensor.slinalg import Cholesky, Solve
from aesara.tensor.subtensor import ( # This is essentially `np.take`; Boolean mask indexing and setting from aesara.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
AdvancedSubtensor, AdvancedSubtensor,
AdvancedSubtensor1, AdvancedSubtensor1,
IncSubtensor, IncSubtensor,
Subtensor, Subtensor,
get_idx_list, indices_from_subtensor,
) )
from aesara.tensor.type_other import MakeSlice from aesara.tensor.type_other import MakeSlice
...@@ -628,20 +627,6 @@ def jax_funcify_IfElse(op): ...@@ -628,20 +627,6 @@ def jax_funcify_IfElse(op):
return ifelse return ifelse
def convert_indices(indices, entry):
    """Reconstruct one ``*Subtensor*`` index entry.

    ``CType`` placeholders consume the next value from ``indices``;
    slices are rebuilt component-wise; anything else passes through
    unchanged.
    """
    if indices and isinstance(entry, CType):
        return indices.pop(0)
    if isinstance(entry, slice):
        # Rebuild the slice from its (possibly symbolic) parts.
        bounds = (entry.start, entry.stop, entry.step)
        return slice(*(convert_indices(indices, b) for b in bounds))
    return entry
@jax_funcify.register(Subtensor) @jax_funcify.register(Subtensor)
def jax_funcify_Subtensor(op): def jax_funcify_Subtensor(op):
...@@ -649,15 +634,12 @@ def jax_funcify_Subtensor(op): ...@@ -649,15 +634,12 @@ def jax_funcify_Subtensor(op):
def subtensor(x, *ilists): def subtensor(x, *ilists):
if idx_list: indices = indices_from_subtensor(ilists, idx_list)
cdata = get_idx_list((x,) + ilists, idx_list)
else:
cdata = ilists
if len(cdata) == 1: if len(indices) == 1:
cdata = cdata[0] indices = indices[0]
return x.__getitem__(cdata) return x.__getitem__(indices)
return subtensor return subtensor
...@@ -675,16 +657,11 @@ def jax_funcify_IncSubtensor(op): ...@@ -675,16 +657,11 @@ def jax_funcify_IncSubtensor(op):
jax_fn = jax.ops.index_add jax_fn = jax.ops.index_add
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list): def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
_ilist = list(ilist) indices = indices_from_subtensor(ilist, idx_list)
cdata = ( if len(indices) == 1:
tuple(convert_indices(_ilist, idx) for idx in idx_list) indices = indices[0]
if idx_list
else _ilist
)
if len(cdata) == 1:
cdata = cdata[0]
return jax_fn(x, cdata, y) return jax_fn(x, indices, y)
return incsubtensor return incsubtensor
......
...@@ -3,6 +3,7 @@ import sys ...@@ -3,6 +3,7 @@ import sys
import warnings import warnings
from itertools import chain, groupby from itertools import chain, groupby
from textwrap import dedent from textwrap import dedent
from typing import Iterable, List, Tuple, Union
import numpy as np import numpy as np
...@@ -13,10 +14,11 @@ from aesara.gradient import DisconnectedType ...@@ -13,10 +14,11 @@ from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Variable
from aesara.graph.op import COp, Op from aesara.graph.op import COp, Op
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.type import CType from aesara.graph.type import Type
from aesara.graph.utils import MethodNotDefined from aesara.graph.utils import MethodNotDefined
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.printing import pprint from aesara.printing import pprint
from aesara.scalar.basic import ScalarConstant
from aesara.tensor.basic import addbroadcast, alloc, get_scalar_constant_value from aesara.tensor.basic import addbroadcast, alloc, get_scalar_constant_value
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.exceptions import ( from aesara.tensor.exceptions import (
...@@ -62,6 +64,53 @@ invalid_tensor_types = ( ...@@ -62,6 +64,53 @@ invalid_tensor_types = (
) )
def indices_from_subtensor(
    op_indices: "Iterable[Variable]",
    idx_list: "Union[List[Union[Type, slice, Variable]], None]",
) -> "Tuple[Union[slice, Variable], ...]":
    """Recreate the index tuple from which a ``*Subtensor*`` ``Op`` was created.

    Parameters
    ----------
    op_indices
        The flattened indices obtained from ``x.owner.inputs``, when ``x`` is
        the output of a ``*Subtensor*`` ``Op``.
    idx_list
        The values describing the types of each dimension's index.  This is
        obtained from ``x.owner.op.idx_list``, when ``x`` is the output of a
        ``*Subtensor*`` ``Op``.  May be ``None`` or empty, in which case
        ``op_indices`` is returned unchanged (as a tuple).

    Examples
    --------
    .. code-block:: python

        array, *op_indices = subtensor_node.inputs
        idx_list = getattr(subtensor_node.op, "idx_list", None)
        indices = indices_from_subtensor(op_indices, idx_list)

    """

    def convert_indices(indices, entry):
        """Reconstruct ``*Subtensor*`` index input parameter entries."""
        # A `Type` entry is a placeholder for a value held in `indices`.
        if indices and isinstance(entry, Type):
            rval = indices.pop(0)
            return rval
        elif isinstance(entry, slice):
            # Rebuild the slice from its (possibly symbolic) components.
            return slice(
                convert_indices(indices, entry.start),
                convert_indices(indices, entry.stop),
                convert_indices(indices, entry.step),
            )
        else:
            # Constants (e.g. ints and `None`) pass through unchanged.
            return entry

    # Copy to a `list` so `convert_indices` can destructively `pop` entries.
    op_indices = list(op_indices)
    return (
        tuple(convert_indices(op_indices, idx) for idx in idx_list)
        if idx_list
        else tuple(op_indices)
    )
def as_index_constant(a): def as_index_constant(a):
"""Convert Python literals to Aesara constants--when possible--in Subtensor arguments. """Convert Python literals to Aesara constants--when possible--in Subtensor arguments.
...@@ -83,38 +132,8 @@ def as_index_constant(a): ...@@ -83,38 +132,8 @@ def as_index_constant(a):
return a return a
def get_idx_list(inputs, idx_list, get_count=False): def get_idx_list(inputs, idx_list):
""" return indices_from_subtensor(inputs[1:], idx_list)
Given a list of inputs to the subtensor and its idx_list reorders
the inputs according to the idx list to get the right values.
If get_counts=True, instead returns the number of inputs consumed
during this process.
"""
# The number of indices
n = len(inputs) - 1
# The subtensor (or idx_list) does not depend on the inputs.
if n == 0:
return tuple(idx_list)
indices = list(reversed(list(inputs[1:])))
# General case
def convert(entry):
if isinstance(entry, CType):
return indices.pop()
elif isinstance(entry, slice):
return slice(convert(entry.start), convert(entry.stop), convert(entry.step))
else:
return entry
cdata = tuple(map(convert, idx_list))
if get_count:
return n - len(indices)
else:
return cdata
def get_canonical_form_slice(theslice, length): def get_canonical_form_slice(theslice, length):
...@@ -508,7 +527,7 @@ class Subtensor(COp): ...@@ -508,7 +527,7 @@ class Subtensor(COp):
if isinstance(entry, Variable) and entry.type in scal_types: if isinstance(entry, Variable) and entry.type in scal_types:
return entry.type return entry.type
elif isinstance(entry, CType) and entry in scal_types: elif isinstance(entry, Type) and entry in scal_types:
return entry return entry
if ( if (
...@@ -518,7 +537,7 @@ class Subtensor(COp): ...@@ -518,7 +537,7 @@ class Subtensor(COp):
): ):
return aes.get_scalar_type(entry.type.dtype) return aes.get_scalar_type(entry.type.dtype)
elif ( elif (
isinstance(entry, CType) isinstance(entry, Type)
and entry in tensor_types and entry in tensor_types
and np.all(entry.broadcastable) and np.all(entry.broadcastable)
): ):
...@@ -641,7 +660,7 @@ class Subtensor(COp): ...@@ -641,7 +660,7 @@ class Subtensor(COp):
raise IndexError("too many indices for array") raise IndexError("too many indices for array")
input_types = Subtensor.collapse( input_types = Subtensor.collapse(
idx_list, lambda entry: isinstance(entry, CType) idx_list, lambda entry: isinstance(entry, Type)
) )
if len(inputs) != len(input_types): if len(inputs) != len(input_types):
raise IndexError( raise IndexError(
...@@ -859,7 +878,7 @@ class Subtensor(COp): ...@@ -859,7 +878,7 @@ class Subtensor(COp):
inc_spec_pos(1) inc_spec_pos(1)
if depth == 0: if depth == 0:
is_slice.append(0) is_slice.append(0)
elif isinstance(entry, CType): elif isinstance(entry, Type):
init_cmds.append( init_cmds.append(
"subtensor_spec[%i] = %s;" % (spec_pos(), inputs[input_pos()]) "subtensor_spec[%i] = %s;" % (spec_pos(), inputs[input_pos()])
) )
...@@ -1477,7 +1496,7 @@ class IncSubtensor(COp): ...@@ -1477,7 +1496,7 @@ class IncSubtensor(COp):
raise IndexError("too many indices for array") raise IndexError("too many indices for array")
input_types = Subtensor.collapse( input_types = Subtensor.collapse(
idx_list, lambda entry: isinstance(entry, CType) idx_list, lambda entry: isinstance(entry, Type)
) )
if len(inputs) != len(input_types): if len(inputs) != len(input_types):
raise IndexError( raise IndexError(
...@@ -1501,7 +1520,7 @@ class IncSubtensor(COp): ...@@ -1501,7 +1520,7 @@ class IncSubtensor(COp):
indices = list(reversed(inputs[2:])) indices = list(reversed(inputs[2:]))
def convert(entry): def convert(entry):
if isinstance(entry, CType): if isinstance(entry, Type):
return indices.pop() return indices.pop()
elif isinstance(entry, slice): elif isinstance(entry, slice):
return slice( return slice(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论