提交 316cfd1d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Disable invalid inplace logic for multiple-output Composites

上级 390a8d67
......@@ -59,6 +59,14 @@ class InplaceElemwiseOptimizer(GraphRewriter):
for n in sorted(ndim.keys()):
print(blanc, n, ndim[n], file=stream)
def candidate_input_idxs(self, node):
if isinstance(node.op.scalar_op, aes.Composite) and len(node.outputs) > 1:
# TODO: Implement specialized InplaceCompositeOptimizer with logic
# needed to correctly assign inplace for multi-output Composites
return []
else:
return range(len(node.outputs))
def apply(self, fgraph):
r"""
......@@ -149,7 +157,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
baseline = op.inplace_pattern
candidate_outputs = [
i for i in range(len(node.outputs)) if i not in baseline
i for i in self.candidate_input_idxs(node) if i not in baseline
]
# node inputs that are Constant, already destroyed,
# or fgraph protected inputs and fgraph outputs can't be used as
......@@ -167,7 +175,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
]
else:
baseline = []
candidate_outputs = list(range(len(node.outputs)))
candidate_outputs = self.candidate_input_idxs(node)
# node inputs that are Constant, already destroyed,
# fgraph protected inputs and fgraph outputs can't be used as inplace
# target.
......
......@@ -4,6 +4,7 @@ import numpy as np
import pytest
import pytensor
from pytensor import In
from pytensor import scalar as aes
from pytensor import shared
from pytensor import tensor as at
......@@ -1024,6 +1025,34 @@ class TestFusion:
np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5))
)
def test_fusion_multiout_inplace(self):
x = vector("x")
# Create Composite where inplacing the first non-constant output would corrupt the second output
xs = aes.float64("xs")
outs = (
Elemwise(Composite([xs], [xs + 1, aes.cos(xs + 1) + xs]))
.make_node(x)
.outputs
)
f = pytensor.function(
[In(x, mutable=True)],
outs,
mode=self.mode.including("inplace"),
)
(composite_node,) = f.maker.fgraph.apply_nodes
# Destroy map must be None or the last toposorted output
destroy_map = composite_node.op.destroy_map
assert (destroy_map == {}) or (
destroy_map == {1: [composite_node.inputs.index(x)]}
)
res = f([0, 1, 2])
assert np.allclose(res[0], [1, 2, 3])
assert np.allclose(res[1], np.cos([1, 2, 3]) + np.array([0, 1, 2]))
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
def test_no_c_code(self):
r"""Make sure we avoid fusions for `Op`\s without C code implementations."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论