提交 837d9f9a authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Fix local_useless_SpecifyShape rewrite

This rewrite predated partial shape specification in SpecifyShape, and as such ignored possible shape refinement over consecutive SpecifyShapes. It now merges information across consecutive SpecifyShape. Unlike the original, preference is given to the outer SpecifyShape, similar to what local_rebroadcast_lift does.
上级 a9f1e284
...@@ -81,6 +81,7 @@ from aesara.tensor.shape import ( ...@@ -81,6 +81,7 @@ from aesara.tensor.shape import (
SpecifyShape, SpecifyShape,
shape_i, shape_i,
shape_padleft, shape_padleft,
specify_shape,
) )
from aesara.tensor.sort import TopKOp from aesara.tensor.sort import TopKOp
from aesara.tensor.subtensor import Subtensor, get_idx_list from aesara.tensor.subtensor import Subtensor, get_idx_list
...@@ -3425,8 +3426,10 @@ def local_useless_topk(fgraph, node): ...@@ -3425,8 +3426,10 @@ def local_useless_topk(fgraph, node):
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@local_optimizer([SpecifyShape]) @local_optimizer([SpecifyShape])
def local_useless_SpecifyShape(fgraph, node): def local_merge_consecutive_specify_shape(fgraph, node):
"""Replace ``specify_shape(specify_shape(x, s1), s2)`` with ``specify_shape(x, s1)``.""" """Replace ``specify_shape(specify_shape(x, s1), s2)`` with ``specify_shape(x, s3)``,
where s3 is the union of specified dimensions in s1 and s2, with preference given to s2.
"""
if not isinstance(node.op, SpecifyShape): if not isinstance(node.op, SpecifyShape):
return False return False
...@@ -3435,10 +3438,15 @@ def local_useless_SpecifyShape(fgraph, node): ...@@ -3435,10 +3438,15 @@ def local_useless_SpecifyShape(fgraph, node):
if not (obj.owner and isinstance(obj.owner.op, SpecifyShape)): if not (obj.owner and isinstance(obj.owner.op, SpecifyShape)):
return False return False
# TODO: We could make sure that the shapes of the two `SpecifyShape`s are inner_obj, *shape = obj.owner.inputs
for dim, sh in enumerate(node.inputs[1:]):
if not NoneConst.equals(sh):
shape[dim] = sh
# TODO: We could make sure that the overlapping shapes of the two `SpecifyShape`s are
# the same. # the same.
return [obj] return [specify_shape(inner_obj, shape)]
@register_useless @register_useless
......
...@@ -115,6 +115,7 @@ from aesara.tensor.type import ( ...@@ -115,6 +115,7 @@ from aesara.tensor.type import (
fvector, fvector,
imatrices, imatrices,
iscalar, iscalar,
iscalars,
ivector, ivector,
lscalar, lscalar,
lvector, lvector,
...@@ -3491,14 +3492,16 @@ def test_local_Unique_second( ...@@ -3491,14 +3492,16 @@ def test_local_Unique_second(
assert np.array_equal(y_exp_val, y_val) assert np.array_equal(y_exp_val, y_val)
def test_local_useless_SpecifyShape(): def test_local_merge_consecutive_specify_shape():
x = matrix() x = matrix()
s = at.as_tensor([iscalar(), iscalar()]) s = at.as_tensor([iscalar(), iscalar()])
y = specify_shape(specify_shape(x, s), s) y = specify_shape(specify_shape(x, s), s)
y_fg = FunctionGraph(outputs=[y], copy_inputs=False) y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
y_opt_fg = optimize_graph( y_opt_fg = optimize_graph(
y_fg, clone=False, include=["canonicalize", "local_useless_SpecifyShape"] y_fg,
clone=False,
include=["canonicalize", "local_merge_consecutive_specify_shape"],
) )
y_opt = y_opt_fg.outputs[0] y_opt = y_opt_fg.outputs[0]
...@@ -3506,6 +3509,23 @@ def test_local_useless_SpecifyShape(): ...@@ -3506,6 +3509,23 @@ def test_local_useless_SpecifyShape():
assert y_opt.owner.inputs[0] == x assert y_opt.owner.inputs[0] == x
def test_local_merge_consecutive_specify_shape2():
x = tensor3()
s1, s2, s3, s4 = iscalars("s1", "s2", "s3", "s4")
y = specify_shape(specify_shape(x, [s1, s2, None]), [None, s3, s4])
y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
y_opt_fg = optimize_graph(
y_fg,
clone=False,
include=["canonicalize", "local_merge_consecutive_specify_shape"],
)
y_opt = y_opt_fg.outputs[0]
assert isinstance(y_opt.owner.op, SpecifyShape)
assert tuple(y_opt.owner.inputs) == (x, s1, s3, s4)
def test_printing(): def test_printing():
a, b = scalars("ab") a, b = scalars("ab")
mv = MakeVector(config.floatX) mv = MakeVector(config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论