Commit 8ec11777 authored by Brandon T. Willard, committed by Brandon T. Willard

Add local Shape of SpecifyShape canonicalization

Parent f5cc20ca
@@ -76,7 +76,7 @@ from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
 from aesara.tensor.extra_ops import broadcast_shape
 from aesara.tensor.math import all as at_all
 from aesara.tensor.math import eq
-from aesara.tensor.shape import Reshape, Shape, Shape_i, shape_padleft
+from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, shape_padleft
 from aesara.tensor.sort import TopKOp
 from aesara.tensor.subtensor import Subtensor, get_idx_list
 from aesara.tensor.type import discrete_dtypes, integer_dtypes, lscalar
@@ -3536,3 +3536,20 @@ def local_useless_topk(fgraph, node):
     )(x, k)
     copy_stack_trace(node.outputs[0], new_output)
     return {old_output: new_output}
+
+
+@register_useless
+@register_canonicalize
+@local_optimizer([Shape])
+def local_Shape_of_SpecifyShape(fgraph, node):
+    """Replace ``specify_shape(x, s).shape`` with ``s``."""
+
+    if not isinstance(node.op, Shape):
+        return False
+
+    specified_shape = node.inputs[0]
+
+    if not isinstance(getattr(specified_shape.owner, "op", None), SpecifyShape):
+        return False
+
+    return [specified_shape.owner.inputs[1].astype(np.int64)]
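An illustrative sketch (not part of the patch): the snippet below builds the pattern this rewrite matches, a Shape node fed by a SpecifyShape node, and shows the user-visible effect after ordinary compilation. Variable names are made up for illustration, and it assumes an Aesara build that includes this change, since the rewrite is registered under the "useless" and "canonicalize" databases and therefore runs during normal function compilation.

    # Sketch only: inspect the graph before and after optimization.
    import aesara
    import aesara.tensor as at
    from aesara.printing import debugprint

    x = at.vector("x")
    s = at.lscalar("s")

    # `specify_shape(x, s).shape` creates Shape(SpecifyShape(x, s)).
    out = at.specify_shape(x, s).shape
    debugprint(out)

    # After canonicalization the Shape output is replaced with (a cast of) `s`,
    # so the optimized graph no longer needs `x` to compute `out`.
    f = aesara.function([x, s], out, on_unused_input="ignore")
    debugprint(f)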
@@ -19,6 +19,7 @@ from aesara.graph.basic import Apply, Constant
 from aesara.graph.fg import FunctionGraph
 from aesara.graph.op import Op
 from aesara.graph.opt import check_stack_trace, local_optimizer, out2in
+from aesara.graph.opt_utils import optimize_graph
 from aesara.graph.optdb import OptimizationQuery
 from aesara.misc.safe_asarray import _asarray
 from aesara.tensor.basic import (
@@ -82,7 +83,7 @@ from aesara.tensor.math import sin, sinh, softplus, sqr, sqrt, sub
 from aesara.tensor.math import sum as aet_sum
 from aesara.tensor.math import tan, tanh, true_div, xor
 from aesara.tensor.math_opt import local_lift_transpose_through_dot
-from aesara.tensor.shape import Reshape, Shape_i, reshape
+from aesara.tensor.shape import Reshape, Shape_i, reshape, specify_shape
 from aesara.tensor.subtensor import (
     AdvancedIncSubtensor1,
     Subtensor,
@@ -3192,6 +3193,21 @@ class TestShapeFeature:
         shape_feature.same_shape(x, o, 0, 1)
 
 
+@pytest.mark.parametrize(
+    "shape",
+    [lscalar(), iscalar()],
+)
+def test_local_Shape_of_SpecifyShape(shape):
+    x = vector()
+    s = specify_shape(x, shape).shape
+
+    fgraph = FunctionGraph(outputs=[s], clone=False)
+    _ = optimize_graph(fgraph, clone=False)
+
+    assert x not in fgraph.variables
+    assert shape in fgraph.variables
+
+
 def test_assert_op_gradient():
     x = vector("x")
     assert_op = Assert()
...
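A related usage note (again not part of the patch): the rewrite can also be exercised in isolation with the `out2in` helper this test module already imports, instead of running the full `optimize_graph` pass. The import path of `local_Shape_of_SpecifyShape` below is an assumption based on the surrounding context; adjust it if the rewrite lives in a different module.

    # Minimal sketch: apply only local_Shape_of_SpecifyShape to a graph.
    # Assumption: the rewrite is importable from aesara.tensor.basic_opt;
    # the rest mirrors the test above.
    from aesara.graph.fg import FunctionGraph
    from aesara.graph.opt import out2in
    from aesara.tensor import lscalar, specify_shape, vector
    from aesara.tensor.basic_opt import local_Shape_of_SpecifyShape

    x = vector("x")
    shape = lscalar("shape")
    fgraph = FunctionGraph(outputs=[specify_shape(x, shape).shape], clone=False)

    # Rewrite Shape(SpecifyShape(x, shape)) so the output reads `shape` directly;
    # once that happens, `x` drops out of the graph entirely.
    out2in(local_Shape_of_SpecifyShape).optimize(fgraph)

    assert x not in fgraph.variables
    assert shape in fgraph.variables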