提交 8af9aa23 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a rewrite that removes useless scalar BroadcastTo Ops

上级 85326def
......@@ -3668,3 +3668,17 @@ def local_Unique_second(fgraph, node):
old_out = node.outputs[0]
new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
return [new_x]
@register_useless
@register_canonicalize
@local_optimizer([BroadcastTo])
def local_remove_scalar_BroadcastTo(fgraph, node):
bcast_shape = node.inputs[1:]
if not bcast_shape:
bcasted_var = node.inputs[0]
# If this isn't true, the graph is invalid
assert bcasted_var.ndim == 0
return [bcasted_var]
......@@ -3576,3 +3576,16 @@ def test_printing():
mv = MakeVector(config.floatX)
v = mv(a, b)
assert pprint(v) == "[a, b]"
def test_local_remove_scalar_BroadcastTo():
x = dscalar()
y = BroadcastTo()(x, ())
assert isinstance(y.owner.op, BroadcastTo)
res = optimize_graph(
y, clone=False, include=["canonicalize", "local_remove_scalar_BroadcastTo"]
)
assert res is x
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论