提交 316cfd1d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Disable invalid inplace logic for multiple-output Composites

上级 390a8d67
...@@ -59,6 +59,14 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -59,6 +59,14 @@ class InplaceElemwiseOptimizer(GraphRewriter):
for n in sorted(ndim.keys()): for n in sorted(ndim.keys()):
print(blanc, n, ndim[n], file=stream) print(blanc, n, ndim[n], file=stream)
def candidate_input_idxs(self, node):
if isinstance(node.op.scalar_op, aes.Composite) and len(node.outputs) > 1:
# TODO: Implement specialized InplaceCompositeOptimizer with logic
# needed to correctly assign inplace for multi-output Composites
return []
else:
return range(len(node.outputs))
def apply(self, fgraph): def apply(self, fgraph):
r""" r"""
...@@ -149,7 +157,7 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -149,7 +157,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
baseline = op.inplace_pattern baseline = op.inplace_pattern
candidate_outputs = [ candidate_outputs = [
i for i in range(len(node.outputs)) if i not in baseline i for i in self.candidate_input_idxs(node) if i not in baseline
] ]
# node inputs that are Constant, already destroyed, # node inputs that are Constant, already destroyed,
# or fgraph protected inputs and fgraph outputs can't be used as # or fgraph protected inputs and fgraph outputs can't be used as
...@@ -167,7 +175,7 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -167,7 +175,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
] ]
else: else:
baseline = [] baseline = []
candidate_outputs = list(range(len(node.outputs))) candidate_outputs = self.candidate_input_idxs(node)
# node inputs that are Constant, already destroyed, # node inputs that are Constant, already destroyed,
# fgraph protected inputs and fgraph outputs can't be used as inplace # fgraph protected inputs and fgraph outputs can't be used as inplace
# target. # target.
......
...@@ -4,6 +4,7 @@ import numpy as np ...@@ -4,6 +4,7 @@ import numpy as np
import pytest import pytest
import pytensor import pytensor
from pytensor import In
from pytensor import scalar as aes from pytensor import scalar as aes
from pytensor import shared from pytensor import shared
from pytensor import tensor as at from pytensor import tensor as at
...@@ -1024,6 +1025,34 @@ class TestFusion: ...@@ -1024,6 +1025,34 @@ class TestFusion:
np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5)) np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5))
) )
def test_fusion_multiout_inplace(self):
x = vector("x")
# Create Composite where inplacing the first non-constant output would corrupt the second output
xs = aes.float64("xs")
outs = (
Elemwise(Composite([xs], [xs + 1, aes.cos(xs + 1) + xs]))
.make_node(x)
.outputs
)
f = pytensor.function(
[In(x, mutable=True)],
outs,
mode=self.mode.including("inplace"),
)
(composite_node,) = f.maker.fgraph.apply_nodes
# Destroy map must be None or the last toposorted output
destroy_map = composite_node.op.destroy_map
assert (destroy_map == {}) or (
destroy_map == {1: [composite_node.inputs.index(x)]}
)
res = f([0, 1, 2])
assert np.allclose(res[0], [1, 2, 3])
assert np.allclose(res[1], np.cos([1, 2, 3]) + np.array([0, 1, 2]))
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler") @pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
def test_no_c_code(self): def test_no_c_code(self):
r"""Make sure we avoid fusions for `Op`\s without C code implementations.""" r"""Make sure we avoid fusions for `Op`\s without C code implementations."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论