提交 4cda2e52 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use the indestructible tag to prevent in-place updates to BroadcastTo output

上级 9655d978
...@@ -1575,6 +1575,9 @@ class BroadcastTo(Op): ...@@ -1575,6 +1575,9 @@ class BroadcastTo(Op):
out = type(a.type)(dtype=a.type.dtype, broadcastable=bcast)() out = type(a.type)(dtype=a.type.dtype, broadcastable=bcast)()
# Attempt to prevent in-place operations on this view-based output
out.tag.indestructible = True
return Apply(self, [a] + shape, [out]) return Apply(self, [a] + shape, [out])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
......
...@@ -5,9 +5,11 @@ import aesara ...@@ -5,9 +5,11 @@ import aesara
from aesara import function from aesara import function
from aesara import tensor as aet from aesara import tensor as aet
from aesara.assert_op import Assert from aesara.assert_op import Assert
from aesara.compile.mode import Mode
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import grad from aesara.gradient import grad
from aesara.graph.basic import applys_between from aesara.graph.basic import applys_between
from aesara.graph.optdb import Query
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.extra_ops import ( from aesara.tensor.extra_ops import (
Bartlett, Bartlett,
...@@ -41,6 +43,7 @@ from aesara.tensor.extra_ops import ( ...@@ -41,6 +43,7 @@ from aesara.tensor.extra_ops import (
unravel_index, unravel_index,
) )
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.subtensor import AdvancedIncSubtensor1
from aesara.tensor.type import ( from aesara.tensor.type import (
TensorType, TensorType,
dmatrix, dmatrix,
...@@ -1155,3 +1158,22 @@ class TestBroadcastTo(utt.InferShapeTester): ...@@ -1155,3 +1158,22 @@ class TestBroadcastTo(utt.InferShapeTester):
[np.random.rand(2, 1, 3).astype(config.floatX), 6, 2, 5, 3], [np.random.rand(2, 1, 3).astype(config.floatX), 6, 2, 5, 3],
self.op_class, self.op_class,
) )
def test_inplace(self):
"""Make sure that in-place optimizations are *not* performed on the output of a ``BroadcastTo``."""
a = aet.zeros((5,))
d = aet.vector("d")
c = aet.set_subtensor(a[np.r_[0, 1, 3]], d)
b = broadcast_to(c, (5,))
q = b[np.r_[0, 1, 3]]
e = aet.set_subtensor(q, np.r_[0, 0, 0])
opts = Query(include=["inplace"])
py_mode = Mode("py", opts)
e_fn = function([d], e, mode=py_mode)
advincsub_node = e_fn.maker.fgraph.outputs[0].owner
assert isinstance(advincsub_node.op, AdvancedIncSubtensor1)
assert isinstance(advincsub_node.inputs[0].owner.op, BroadcastTo)
assert advincsub_node.op.inplace is False
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论