提交 cb4da6ed authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove unnecessary Subtensors from broadcast_to

上级 605b609f
......@@ -1590,7 +1590,6 @@ class BroadcastTo(Op):
def make_node(self, a, *shape):
a = at.as_tensor_variable(a)
shape = at.as_tensor_variable(shape, ndim=1)
shape, bcast = at.infer_broadcastable(shape)
......@@ -1658,7 +1657,6 @@ def broadcast_to(
"""
x = at.as_tensor(x)
shape = at.as_tensor(shape, ndim=1, dtype="int64")
shape_len = get_vector_length(shape)
if x.ndim == 0 and shape_len == 0:
......
......@@ -1123,6 +1123,14 @@ class TestBroadcastTo(utt.InferShapeTester):
y = broadcast_to(x, ())
assert y is x
def test_avoid_useless_subtensors(self):
x = scalar()
y = broadcast_to(x, (1, 2))
# There shouldn't be any unnecessary `Subtensor` operations
# (e.g. from `at.as_tensor((1, 2))[0]`)
assert y.owner.inputs[1].owner is None
assert y.owner.inputs[2].owner is None
@config.change_flags(compute_test_value="raise")
def test_perform(self):
a = scalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论