提交 34b91eff authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Allow inplace of Elemwise Composite with multiple outputs

上级 0699b48d
......@@ -4441,16 +4441,12 @@ class Composite(ScalarInnerGraphOp):
if hasattr(self, "_c_code"):
return self._c_code
subd = dict(
chain(
((e, f"%(i{int(i)})s") for i, e in enumerate(self.fgraph.inputs)),
((e, f"%(o{int(i)})s") for i, e in enumerate(self.fgraph.outputs)),
)
)
fg = self.fgraph
subd = {e: f"%(i{int(i)})s" for i, e in enumerate(fg.inputs)}
for var in self.fgraph.variables:
for var in fg.variables:
if var.owner is None:
if var not in self.fgraph.inputs:
if var not in fg.inputs:
# This is an orphan
if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
subd[var] = f"({var.type.c_literal(var.data)})"
......@@ -4465,23 +4461,24 @@ class Composite(ScalarInnerGraphOp):
# flag for elemwise ops to check.
self.inner_float16 = True
_c_code = "{\n"
self.nodenames = [
f"%(nodename)s_subnode{int(j)}"
for j, n in enumerate(self.fgraph.toposort())
]
self.nodenames = nodenames = [] # Used by self.c_support_code_apply
_c_code = "{\n"
i = 0
for j, node in enumerate(self.fgraph.toposort()):
for j, node in enumerate(fg.toposort()):
for output in node.outputs:
if output not in subd:
i += 1
name = f"V%(id)s_tmp{int(i)}"
subd[output] = name
_c_code += f"{output.type.dtype_specs()[1]} {name};\n"
nodename = f"%(nodename)s_subnode{int(j)}"
nodenames.append(nodename)
s = node.op.c_code(
node,
self.nodenames[j],
nodename,
[subd[input] for input in node.inputs],
[subd[output] for output in node.outputs],
dict(fail="%(fail)s", id=f"%(id)s_{int(j)}"),
......@@ -4489,6 +4486,10 @@ class Composite(ScalarInnerGraphOp):
_c_code += s
_c_code += "\n"
# Copy the temporary outputs to the real outputs
for i, output in enumerate(fg.outputs):
_c_code += f"%(o{int(i)})s = {subd[output]};\n"
_c_code += "}\n"
self._c_code = _c_code
......@@ -4512,7 +4513,7 @@ class Composite(ScalarInnerGraphOp):
return self.c_code_template % d
def c_code_cache_version_outer(self) -> tuple[int, ...]:
return (5,)
return (6,)
class Compositef32:
......
......@@ -80,8 +80,6 @@ class InplaceElemwiseOptimizer(GraphRewriter):
# and ScalarLoops
if isinstance(node.op.scalar_op, ScalarLoop):
return []
if isinstance(node.op.scalar_op, ps.Composite) and (len(node.outputs) > 1):
return []
else:
return range(len(node.outputs))
......
......@@ -1104,7 +1104,8 @@ class TestFusion:
np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5))
)
def test_fusion_multiout_inplace(self):
@pytest.mark.parametrize("linker", ["cvm", "py"])
def test_fusion_multiout_inplace(self, linker):
x = vector("x")
# Create Composite where inplacing the first non-constant output would corrupt the second output
......@@ -1118,17 +1119,16 @@ class TestFusion:
f = pytensor.function(
[In(x, mutable=True)],
outs,
mode=self.mode.including("inplace"),
mode=Mode(linker=linker, optimizer=self.rewrites.including("inplace")),
)
(composite_node,) = f.maker.fgraph.apply_nodes
# Destroy map must be None or the last toposorted output
destroy_map = composite_node.op.destroy_map
assert (destroy_map == {}) or (
destroy_map == {1: [composite_node.inputs.index(x)]}
)
assert destroy_map == {0: [0]}
res = f([0, 1, 2])
inp = np.array([0, 1, 2], dtype=config.floatX)
res = f(inp)
assert not np.allclose(inp, [0, 1, 2])
assert np.allclose(res[0], [1, 2, 3])
assert np.allclose(res[1], np.cos([1, 2, 3]) + np.array([0, 1, 2]))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论