提交 748a3e2a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Implement Subtensor SpecifyShape lift

上级 a8103cfd
...@@ -46,7 +46,13 @@ from aesara.tensor.math import ( ...@@ -46,7 +46,13 @@ from aesara.tensor.math import (
minimum, minimum,
or_, or_,
) )
from aesara.tensor.shape import Shape, shape_padleft, shape_tuple from aesara.tensor.shape import (
Shape,
SpecifyShape,
shape_padleft,
shape_tuple,
specify_shape,
)
from aesara.tensor.sharedvar import TensorSharedVariable from aesara.tensor.sharedvar import TensorSharedVariable
from aesara.tensor.subtensor import ( from aesara.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
...@@ -1614,3 +1620,42 @@ def local_subtensor_shape_constant(fgraph, node): ...@@ -1614,3 +1620,42 @@ def local_subtensor_shape_constant(fgraph, node):
return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)] return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)]
elif shape_parts: elif shape_parts:
return [as_tensor(1, dtype=np.int64)] return [as_tensor(1, dtype=np.int64)]
@register_canonicalize
@local_optimizer([Subtensor])
def local_subtensor_SpecifyShape_lift(fgraph, node):
    """Lift ``specify_shape(x, s)[i]`` to ``specify_shape(x[i], s[i])``.

    Only a single, non-slice (i.e. scalar) index is handled; slices —
    literal or symbolic — and multi-dimensional indexing are left untouched,
    because indexing the shape the same way as the object would not be valid
    for them.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (unused; required by the
        ``local_optimizer`` interface).
    node
        The ``Apply`` node to (potentially) rewrite.

    Returns
    -------
    list or bool
        A one-element list containing the replacement variable, or ``False``
        when the rewrite does not apply.
    """
    if not isinstance(node.op, Subtensor):
        return False

    specify_shape_node = node.inputs[0]

    # The rewrite only applies when the indexed variable is the output of a
    # `SpecifyShape`.
    if not (
        specify_shape_node.owner
        and isinstance(specify_shape_node.owner.op, SpecifyShape)
    ):
        return False

    obj_arg = specify_shape_node.owner.inputs[0]
    shape_arg = specify_shape_node.owner.inputs[1]

    indices = get_idx_list(node.inputs, node.op.idx_list)

    # Exactly one index is supported.
    if len(indices) != 1:
        return False

    # Reject slices, whether they are literal `slice` objects or symbolic
    # `SliceType` variables: a slice changes the indexed dimension's length,
    # so the shape could not simply be indexed the same way.
    if isinstance(indices[0], slice) or isinstance(
        getattr(indices[0], "type", None), SliceType
    ):
        return False

    new_obj_arg = obj_arg[indices]

    if new_obj_arg.ndim == 0:
        # A scalar result has an empty shape.
        new_shape_arg = as_tensor([], dtype=np.int64)
    else:
        new_shape_arg = as_tensor(shape_arg[indices], dtype=np.int64, ndim=1)

    return [specify_shape(new_obj_arg, new_shape_arg)]
...@@ -13,6 +13,7 @@ from aesara.configdefaults import config ...@@ -13,6 +13,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable, ancestors from aesara.graph.basic import Constant, Variable, ancestors
from aesara.graph.opt import check_stack_trace from aesara.graph.opt import check_stack_trace
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import optimize_graph
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.tensor import inplace from aesara.tensor import inplace
from aesara.tensor.basic import ( from aesara.tensor.basic import (
...@@ -57,6 +58,7 @@ from aesara.tensor.type import ( ...@@ -57,6 +58,7 @@ from aesara.tensor.type import (
tensor4, tensor4,
vector, vector,
) )
from aesara.tensor.type_other import slicetype
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.unittest_tools import create_aesara_param from tests.unittest_tools import create_aesara_param
...@@ -2066,3 +2068,71 @@ def test_local_subtensor_shape_constant(): ...@@ -2066,3 +2068,71 @@ def test_local_subtensor_shape_constant():
x = shape(Variable(MyType(), None, None))[0] x = shape(Variable(MyType(), None, None))[0]
assert not local_subtensor_shape_constant.transform(None, x.owner) assert not local_subtensor_shape_constant.transform(None, x.owner)
@pytest.mark.parametrize(
    "x, s, idx, x_val, s_val",
    [
        (
            matrix(),
            (iscalar(), iscalar()),
            (1,),
            np.array([[1, 2], [3, 4]], dtype=config.floatX),
            np.array([2, 2], dtype=np.int64),
        ),
        (
            vector(),
            (iscalar(),),
            (1,),
            np.array([1, 2], dtype=config.floatX),
            np.array([2], dtype=np.int64),
        ),
    ],
)
def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
    """``specify_shape(x, s)[idx]`` canonicalizes to a lifted ``SpecifyShape``."""
    y = specify_shape(x, s)[idx]

    graph_inputs = [x] + list(s)
    input_values = [x_val] + [s_ for s_ in s_val]

    # Reference value from the completely un-optimized graph.
    no_opt_mode = Mode(optimizer=OptimizationQuery(include=[None]))
    ref_fn = function(
        graph_inputs, y, on_unused_input="ignore", mode=no_opt_mode
    )
    ref_val = ref_fn(*input_values)

    # The lift is registered as a canonicalization, so the default
    # optimizations should apply it.
    y_opt = optimize_graph(y, clone=False)
    assert isinstance(y_opt.owner.op, SpecifyShape)

    opt_fn = function(graph_inputs, y_opt, on_unused_input="ignore")
    assert np.allclose(ref_val, opt_fn(*input_values))
@pytest.mark.parametrize(
    "x, s, idx",
    [
        (
            matrix(),
            (iscalar(), iscalar()),
            (slice(1, None),),
        ),
        (
            matrix(),
            (iscalar(), iscalar()),
            (slicetype(),),
        ),
        (
            matrix(),
            (iscalar(), iscalar()),
            (1, 0),
        ),
    ],
)
def test_local_subtensor_SpecifyShape_lift_fail(x, s, idx):
    """Cases the lift must leave alone: slices, symbolic slices, multi-indexing."""
    y = specify_shape(x, s)[idx]

    # Run the canonicalizations and confirm the `SpecifyShape` was *not*
    # lifted past the `Subtensor`.
    assert not isinstance(optimize_graph(y, clone=False).owner.op, SpecifyShape)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论