提交 6bb459ad authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a rewrite for useless SpecifyShapes

上级 748a3e2a
......@@ -3462,6 +3462,25 @@ def local_useless_topk(fgraph, node):
return {old_output: new_output}
@register_useless
@register_canonicalize
@local_optimizer([SpecifyShape])
def local_useless_SpecifyShape(fgraph, node):
"""Replace ``specify_shape(specify_shape(x, s1), s2)`` with ``specify_shape(x, s1)``."""
if not isinstance(node.op, SpecifyShape):
return False
obj = node.inputs[0]
if not (obj.owner and isinstance(obj.owner.op, SpecifyShape)):
return False
# TODO: We could make sure that the shapes of the two `SpecifyShape`s are
# the same.
return [obj]
@register_useless
@register_canonicalize
@local_optimizer([Shape])
......
......@@ -93,7 +93,7 @@ from aesara.tensor.math import sin, sinh, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as aet_sum
from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.math_opt import local_lift_transpose_through_dot
from aesara.tensor.shape import Reshape, Shape_i, reshape, specify_shape
from aesara.tensor.shape import Reshape, Shape_i, SpecifyShape, reshape, specify_shape
from aesara.tensor.subtensor import (
AdvancedIncSubtensor1,
Subtensor,
......@@ -3602,3 +3602,18 @@ def test_local_Unique_second(
y_exp_val, y_val = y_fn(x_val)
assert np.array_equal(y_exp_val, y_val)
def test_local_useless_SpecifyShape():
x = matrix()
s = aet.as_tensor([iscalar(), iscalar()])
y = specify_shape(specify_shape(x, s), s)
y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
y_opt_fg = optimize_graph(
y_fg, clone=False, include=["canonicalize", "local_useless_SpecifyShape"]
)
y_opt = y_opt_fg.outputs[0]
assert isinstance(y_opt.owner.op, SpecifyShape)
assert y_opt.owner.inputs[0] == x
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论