提交 d98f33c5 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add rewrite to lift SpecifyShape through Elemwise Operations

上级 b128827c
......@@ -1026,6 +1026,52 @@ def local_Shape_of_SpecifyShape(fgraph, node):
return [stack(shape).astype(np.int64)]
@register_canonicalize
@register_specialize
@node_rewriter([SpecifyShape])
def local_specify_shape_lift(fgraph, node):
"""Lift SpecifyShape of Elemwise towards the inputs."""
inp, *shape = node.inputs
if inp.owner and isinstance(inp.owner.op, Elemwise):
if len(inp.owner.outputs) != 1:
return None
elem_inps = inp.owner.inputs
if len(elem_inps) == 1:
new_elem_inps = [specify_shape(elem_inps[0], shape)]
else:
# Rewrite does not support case where specify_shape provides new broadcastable information,
# As that may require a specify_shape for each input
out_broadcastable = node.outputs[0].type.broadcastable
if out_broadcastable != inp.type.broadcastable:
return None
# All non-broadcastable dimensions of inputs must match the non-broadcastbale specify_shape dims
# We look for a sufficient input to assign all the specify_shape dims
# We could consider distributing the SpecifyShape across multiple inputs, when none is sufficient
nonbcast_dims = {
i
for i, (dim, bcast) in enumerate(zip(shape, out_broadcastable))
if (not bcast and not NoneConst.equals(dim))
}
new_elem_inps = elem_inps.copy()
for i, elem_inp in enumerate(elem_inps):
if all(
bcast_dim is False
for dim, bcast_dim in enumerate(elem_inp.type.broadcastable)
if dim in nonbcast_dims
):
new_elem_inps[i] = specify_shape(elem_inp, shape)
break
else: # no-break, no sufficient candidate found
return None
new_out = inp.owner.op.make_node(*new_elem_inps).outputs
copy_stack_trace(node.outputs, new_out)
return new_out
@register_useless
@register_canonicalize
@node_rewriter([Shape_i])
......
......@@ -491,6 +491,14 @@ def test_local_Shape_of_SpecifyShape_partial(s1):
assert not any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes)
def test_local_specify_shape_lift():
x = vector("x")
out = specify_shape([1.0] + x, shape=(5,))
new_out = rewrite_graph(out)
assert equal_computations([new_out], [[1.0] + specify_shape(x, shape=(5,))])
def test_local_Shape_i_ground():
x = tensor(dtype=np.float64, shape=(None, 2))
s = Shape_i(1)(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论