Commit f5cc20ca authored by Brandon T. Willard, committed by Brandon T. Willard

Add local Subtensor of Shape canonicalization

Parent 73146b4a
......@@ -10,7 +10,7 @@ import aesara
from aesara import scalar as aes
from aesara.configdefaults import config
from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Variable
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import COp, Op
from aesara.graph.params_type import ParamsType
from aesara.graph.type import Type
......@@ -132,6 +132,36 @@ def as_index_constant(a):
return a
def as_index_literal(
    idx: Union[Variable, slice, type(np.newaxis)]
) -> Union[int, slice, type(np.newaxis)]:
    """Convert a symbolic index element to its Python equivalent.

    This is like the inverse of `as_index_constant`.

    Parameters
    ----------
    idx
        A single index element: ``np.newaxis``/``None`` (symbolic or not), a
        `Constant`, a symbolic slice (`SliceType`), or a Python `slice` whose
        components may themselves be symbolic.

    Returns
    -------
    The Python literal equivalent of ``idx`` (an ``int``, a ``slice`` of
    literals, or ``np.newaxis``).

    Raises
    ------
    NotScalarConstantError
        If ``idx`` (or one of its slice components) is not constant.
    """
    # `np.newaxis` is `None`; an identity check avoids invoking `__eq__` on
    # arbitrary objects.  Symbolic `None`s have a `NoneTypeT` type.
    if idx is np.newaxis or isinstance(getattr(idx, "type", None), NoneTypeT):
        return np.newaxis

    if isinstance(idx, Constant):
        # BUG FIX: the ndarray check must inspect `idx.data`, not `idx`
        # itself (a `Constant` is never an `np.ndarray`), so that 0-d array
        # data is unwrapped into the Python scalar promised by the return
        # annotation.
        return idx.data.item() if isinstance(idx.data, np.ndarray) else idx.data

    # A symbolic slice object: rebuild a Python `slice` from its symbolic
    # start/stop/step inputs, then fall through to the component-wise case.
    if isinstance(getattr(idx, "type", None), SliceType):
        idx = slice(*idx.owner.inputs)

    if isinstance(idx, slice):
        return slice(
            as_index_literal(idx.start),
            as_index_literal(idx.stop),
            as_index_literal(idx.step),
        )

    raise NotScalarConstantError()
def get_idx_list(inputs, idx_list):
    # Reconstruct the concrete index entries for a `*Subtensor*` node:
    # `inputs[0]` is the indexed tensor itself, so the symbolic index inputs
    # (`inputs[1:]`) are substituted into the op's static `idx_list` template.
    return indices_from_subtensor(inputs[1:], idx_list)
......
import sys
from collections.abc import Iterable
import numpy as np
......@@ -16,6 +17,7 @@ from aesara.tensor.basic import (
ScalarFromTensor,
TensorFromScalar,
alloc,
as_tensor,
cast,
extract_constant,
get_scalar_constant_value,
......@@ -45,7 +47,7 @@ from aesara.tensor.math import (
minimum,
or_,
)
from aesara.tensor.shape import shape_padleft, shape_tuple
from aesara.tensor.shape import Shape, shape_padleft, shape_tuple
from aesara.tensor.sharedvar import TensorSharedVariable
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
......@@ -58,6 +60,7 @@ from aesara.tensor.subtensor import (
advanced_subtensor,
advanced_subtensor1,
as_index_constant,
as_index_literal,
get_canonical_form_slice,
get_idx_list,
inc_subtensor,
......@@ -1560,3 +1563,58 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
copy_stack_trace(node.outputs, r)
return [r]
@register_specialize
@register_canonicalize
@local_optimizer([Subtensor])
def local_subtensor_shape_constant(fgraph, node):
    r"""Simplify constant `Subtensor`\s on `Shape`\s dimensions that are known.

    We want to convert graphs like

        Subtensor{int64} [id A] ''
         |Shape [id B] ''
         | |<TensorType(float64, row)> [id C]
         |ScalarConstant{0} [id D]

    into

        TensorConstant{1}

    TODO: Something like `local_shape_to_shape_i` should be a general
    canonicalization, and not a `ShapeFeature`-dependent rewrite.  If that were
    the case, we could change this to only operate on `Shape_i`\s.
    Currently, we're not handling them because they should only appear when
    `ShapeFeature` is present, and it will also simplify/remove them.
    """
    if not isinstance(node.op, Subtensor):
        return False

    shape = node.inputs[0]

    # Only indexing into the direct output of a `Shape` op is handled.
    if not (shape.owner and isinstance(shape.owner.op, Shape)):
        return False

    shape_arg = shape.owner.inputs[0]

    # Only single-element index lists are expected here; the unpacking below
    # intentionally fails otherwise.
    (idx,) = get_idx_list(node.inputs, node.op.idx_list)

    try:
        # Convert the (possibly symbolic) index to a Python literal; bail out
        # if it isn't constant.
        idx_val = as_index_literal(idx)
    except NotScalarConstantError:
        return False

    # A `Subtensor` index can never be `np.newaxis`.
    assert idx_val != np.newaxis

    # Non-`TensorType` inputs carry no `broadcastable` information.
    if not isinstance(shape_arg.type, TensorType):
        return False

    # `broadcastable[d]` is True exactly when dimension `d` is statically
    # known to have length 1.  `idx_val` may be an int (scalar result) or a
    # slice (tuple result).
    shape_parts = shape_arg.type.broadcastable[idx_val]

    if isinstance(shape_parts, Iterable):
        # Slice case: only rewrite when *every* selected dimension is known.
        if all(shape_parts):
            return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)]
    elif shape_parts:
        # Scalar case: the indexed dimension is known to be 1.
        return [as_tensor(1, dtype=np.int64)]
......@@ -28,6 +28,7 @@ from aesara.tensor.subtensor import (
advanced_inc_subtensor1,
advanced_set_subtensor,
advanced_set_subtensor1,
as_index_literal,
basic_shape,
get_canonical_form_slice,
inc_subtensor,
......@@ -59,7 +60,7 @@ from aesara.tensor.type import (
tensor4,
vector,
)
from aesara.tensor.type_other import make_slice, slicetype
from aesara.tensor.type_other import NoneConst, SliceConstant, make_slice, slicetype
from tests import unittest_tools as utt
from tests.tensor.utils import inplace_func, integers_ranged, random
......@@ -72,6 +73,29 @@ subtensor_ops = (
)
def test_as_index_literal():
    """`as_index_literal` maps symbolic index elements to Python literals."""
    # Python slices with symbolic components are converted component-wise.
    assert as_index_literal(slice(None, aet.as_tensor(1))) == slice(None, 1)
    assert as_index_literal(slice(aet.as_tensor(1), None)) == slice(1, None)
    assert as_index_literal(slice(None, None, aet.as_tensor(2))) == slice(
        None, None, 2
    )

    # Symbolic slice objects (`SliceType`) are unpacked and converted too.
    assert as_index_literal(SliceConstant(slicetype, slice(None))) == slice(None)
    assert as_index_literal(make_slice(None, aet.as_tensor(1))) == slice(None, 1)

    # Scalar constants become Python scalars.
    assert as_index_literal(aet.as_tensor(2)) == 2

    # `np.newaxis` and symbolic `None`s all map to `np.newaxis`.
    assert as_index_literal(np.newaxis) is np.newaxis
    assert as_index_literal(NoneConst) is np.newaxis
    assert as_index_literal(NoneConst.clone()) is np.newaxis
class TestSubtensor(utt.OptimizationTestMixin):
"""
This is designed to be sub-classed (e.g. by the GPU tests).
......
......@@ -10,8 +10,10 @@ from aesara.compile.function import function
from aesara.compile.mode import Mode, get_default_mode, get_mode
from aesara.compile.ops import DeepCopyOp
from aesara.configdefaults import config
from aesara.graph.basic import Variable, ancestors
from aesara.graph.basic import Constant, Variable, ancestors
from aesara.graph.opt import check_stack_trace
from aesara.graph.opt_utils import optimize_graph
from aesara.graph.type import Type
from aesara.tensor import inplace
from aesara.tensor.basic import (
Alloc,
......@@ -22,7 +24,7 @@ from aesara.tensor.basic import (
)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.math import Dot, add, dot, exp, sqr
from aesara.tensor.shape import SpecifyShape, specify_shape
from aesara.tensor.shape import SpecifyShape, shape, specify_shape
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
......@@ -33,7 +35,10 @@ from aesara.tensor.subtensor import (
inc_subtensor,
set_subtensor,
)
from aesara.tensor.subtensor_opt import local_replace_AdvancedSubtensor
from aesara.tensor.subtensor_opt import (
local_replace_AdvancedSubtensor,
local_subtensor_shape_constant,
)
from aesara.tensor.type import (
bmatrix,
col,
......@@ -41,6 +46,7 @@ from aesara.tensor.type import (
fmatrix,
iscalar,
ivector,
lscalar,
lscalars,
matrix,
row,
......@@ -1036,14 +1042,14 @@ class TestLocalSubtensorMerge:
]
x = matrix("x")
for shape, sl1, sl2 in cases:
for s, sl1, sl2 in cases:
z = x[slice(*sl1)][slice(*sl2)]
f = function([x], z, mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f, ops_to_check=Subtensor)
x_val = self.rng.uniform(size=shape).astype(config.floatX)
x_val = self.rng.uniform(size=s).astype(config.floatX)
f(x_val)
def test_scalar5(self):
......@@ -1950,11 +1956,11 @@ def test_local_subtensor_of_alloc():
# DebugMode should detect if something goes wrong.
# test shape combination of odd and event shape.
for shape in [(3, 5), (4, 6), (3, 8), (4, 7), (1, 5), (5, 1)]:
x = tensor(dtype=config.floatX, broadcastable=(shape[0] == 1, shape[1] == 1))
for s in [(3, 5), (4, 6), (3, 8), (4, 7), (1, 5), (5, 1)]:
x = tensor(dtype=config.floatX, broadcastable=(s[0] == 1, s[1] == 1))
xval = np.zeros(shape, dtype=config.floatX)
yval = np.arange(shape[1], dtype=config.floatX)
xval = np.zeros(s, dtype=config.floatX)
yval = np.arange(s[1], dtype=config.floatX)
for y in [shared(yval), aet.constant([1.0])]:
......@@ -1970,9 +1976,9 @@ def test_local_subtensor_of_alloc():
assert z_vec.ndim == 1
# results are vector
slicess = []
if shape[0] != 1:
if s[0] != 1:
slicess.append((2, slice(None)))
if shape[1] != 1:
if s[1] != 1:
slicess.append((slice(None), 3))
# results are matrix
......@@ -1992,3 +1998,46 @@ def test_local_subtensor_of_alloc():
assert not isinstance(f.maker.fgraph.toposort()[-1].op, Subtensor)
val = f(xval)
assert xval.__getitem__(slices).shape == val.shape
def test_local_subtensor_shape_constant():
    """Known (broadcastable) `Shape` dimensions are rewritten to constants."""

    def apply_rewrite(graph):
        # Run the rewrite directly on the graph's apply node.
        return local_subtensor_shape_constant.transform(None, graph.owner)

    # A constant index into a known-broadcastable dimension becomes `1`.
    x = tensor(np.float64, [True, False]).shape[0]
    (res,) = apply_rewrite(x)
    assert isinstance(res, Constant)
    assert res.data == 1

    # Make sure it's part of the canonicalizations
    res = optimize_graph(x)
    assert isinstance(res, Constant)
    assert res.data == 1

    # Symbolic indices and slices over unknown dimensions are left alone.
    assert not apply_rewrite(tensor(np.float64, [True, False]).shape[lscalar()])
    assert not apply_rewrite(tensor(np.float64, [True, False]).shape[0:])
    assert not apply_rewrite(tensor(np.float64, [True, False]).shape[lscalar() :])

    # Slices over all-broadcastable dimensions become constant vectors of ones.
    (res,) = apply_rewrite(tensor(np.float64, [True, True]).shape[1:])
    assert isinstance(res, Constant)
    assert np.array_equal(res.data, [1])

    (res,) = apply_rewrite(tensor(np.float64, [False, True, True]).shape[1:])
    assert isinstance(res, Constant)
    assert np.array_equal(res.data, [1, 1])

    # A test for a non-`TensorType`
    class MyType(Type):
        def filter(self, *args, **kwargs):
            raise NotImplementedError()

        def __eq__(self, other):
            return isinstance(other, MyType) and other.thingy == self.thingy

    assert not apply_rewrite(shape(Variable(MyType(), None, None))[0])
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment