Commit 477fbafb authored by ricardoV94, committed by Ricardo Vieira

Allow inplace of Elemwise ScalarLoop

Parent 34b91eff
......@@ -1302,18 +1302,6 @@ class ScalarOp(COp):
def __str__(self):
# Prefer an explicitly assigned, non-empty `name` as the display string.
if hasattr(self, "name") and self.name:
return self.name
else:
# No usable name: build a repr from the instance attributes, skipping
# internal bookkeeping fields that carry no user-facing information.
param = [
(k, v)
for k, v in self.__dict__.items()
if k
not in ("name", "_op_use_c_code", "bool", "output_types_preference")
]
if param:
# Render as ClassName{attr1=value1, attr2=value2}.
classname = self.__class__.__name__
args = ", ".join(f"{k}={v}" for k, v in param)
return f"{classname}{{{args}}}"
else:
# No distinguishing attributes: the class name alone suffices.
return self.__class__.__name__
def c_code_cache_version(self):
......@@ -4102,6 +4090,7 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
def __init__(self, *args, **kwargs):
# Tracks which (node, impl) combinations have already been prepared, so
# prepare_node work is not repeated; populated elsewhere in the class.
self.prepare_node_called = set()
super().__init__(*args, **kwargs)
def _cleanup_graph(self, inputs, outputs):
# TODO: We could convert to TensorVariable, optimize graph,
......
......@@ -55,6 +55,7 @@ class ScalarLoop(ScalarInnerGraphOp):
constant: Sequence[Variable] | None = None,
until: Variable | None = None,
name="ScalarLoop",
**kwargs,
):
if constant is None:
constant = []
......@@ -75,7 +76,7 @@ class ScalarLoop(ScalarInnerGraphOp):
self.nout = len(self.outputs)
self.name = name
super().__init__()
super().__init__(**kwargs)
def output_types(self, input_types):
return self.outputs_type
......@@ -115,7 +116,7 @@ class ScalarLoop(ScalarInnerGraphOp):
self._fgraph = fgraph
return self._fgraph
def clone(self):
def clone(self, name=None, **kwargs):
if self.is_while:
*update, until = self.outputs
else:
......@@ -127,7 +128,8 @@ class ScalarLoop(ScalarInnerGraphOp):
update=update,
constant=constant,
until=until,
name=self.name,
name=self.name if name is None else name,
**kwargs,
)
@property
......@@ -135,20 +137,7 @@ class ScalarLoop(ScalarInnerGraphOp):
raise NotImplementedError
def make_new_inplace(self, output_types_preference=None, name=None):
"""
This op.__init__ fct don't have the same parameter as other scalar op.
This break the insert_inplace_optimizer optimization.
This fct allow fix patch this.
"""
d = {k: getattr(self, k) for k in self.init_param}
out = self.__class__(**d)
if name:
out.name = name
else:
name = out.name
super(ScalarLoop, out).__init__(output_types_preference, name)
return out
return self.clone(output_types_preference=output_types_preference, name=name)
def make_node(self, n_steps, *inputs):
assert len(inputs) == self.nin - 1
......@@ -229,11 +218,11 @@ class ScalarLoop(ScalarInnerGraphOp):
c: f"%(i{int(i)})s"
for i, c in enumerate(fgraph.inputs[n_update:], start=n_update + 1)
}
update_subd = {
out_subd = {
u: f"%(o{int(i)})s" for i, u in enumerate(fgraph.outputs[:n_update])
}
until_subd = {u: "until" for u in fgraph.outputs[n_update:]}
subd = {**carry_subd, **constant_subd, **update_subd, **until_subd}
subd = {**carry_subd, **constant_subd, **until_subd}
for var in fgraph.variables:
if var.owner is None:
......@@ -257,11 +246,11 @@ class ScalarLoop(ScalarInnerGraphOp):
_c_code += "bool until = 1;\n\n"
# Copy carried inputs
for i, (var, name) in enumerate(carry_subd.items()):
copy_var_name = f"{name}_copy{i}"
_c_code += f"{var.type.dtype_specs()[1]} {copy_var_name} = {name};\n"
carry_subd[var] = copy_var_name
subd[var] = copy_var_name
for i, (var, name) in enumerate(carry_subd.items(), start=1):
carry_var_name = f"{name}_carry{i}"
_c_code += f"{var.type.dtype_specs()[1]} {carry_var_name} = {name};\n"
carry_subd[var] = carry_var_name
subd[var] = carry_var_name
# _c_code += 'printf("inputs=[");'
# for i in range(1, len(fgraph.inputs)):
......@@ -270,9 +259,8 @@ class ScalarLoop(ScalarInnerGraphOp):
_c_code += "\nfor(%(n_steps_dtype)s i = 0; i < %(n_steps)s; i++){\n"
self.nodenames = [
f"%(nodename)s_subnode{int(j)}" for j, n in enumerate(fgraph.toposort())
]
# Used by self.c_support_code_apply
self.nodenames = nodenames = []
i = 0
for j, node in enumerate(fgraph.toposort()):
......@@ -282,9 +270,13 @@ class ScalarLoop(ScalarInnerGraphOp):
name = f"V%(id)s_tmp{int(i)}"
subd[output] = name
_c_code += f"{output.type.dtype_specs()[1]} {name};\n"
nodename = f"%(nodename)s_subnode{int(j)}"
nodenames.append(nodename)
s = node.op.c_code(
node,
self.nodenames[j],
nodename,
# Any node that depended on `init` will depend on `update` instead
# The initial value of `update` was set to `init` before the loop
[subd[input] for input in node.inputs],
......@@ -294,10 +286,12 @@ class ScalarLoop(ScalarInnerGraphOp):
_c_code += s
_c_code += "\n"
# Set the carry variables to the output variables
# Update the carry variables to the output variables
_c_code += "\n"
for init, update in zip(carry_subd.values(), update_subd.values(), strict=True):
_c_code += f"{init} = {update};\n"
for carry, out in zip(
carry_subd.values(), fgraph.outputs[:n_update], strict=True
):
_c_code += f"{carry} = {subd[out]};\n"
# _c_code += 'printf("%%ld\\n", i);\n'
# for carry in range(1, 10):
......@@ -309,6 +303,10 @@ class ScalarLoop(ScalarInnerGraphOp):
# End of the loop
_c_code += "}\n"
# Assign the carry variables to the outputs
for out, carry in zip(out_subd.values(), carry_subd.values(), strict=True):
_c_code += f"{out} = {carry};\n"
# Output until flag
if self.is_while:
_c_code += f"%(o{len(fgraph.outputs)-1})s = until;\n"
......@@ -343,4 +341,4 @@ class ScalarLoop(ScalarInnerGraphOp):
return res
def c_code_cache_version_outer(self):
return (3,)
return (4,)
......@@ -24,7 +24,6 @@ from pytensor.graph.rewriting.basic import (
)
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.scalar.loop import ScalarLoop
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import (
MakeVector,
......@@ -74,15 +73,6 @@ class InplaceElemwiseOptimizer(GraphRewriter):
for n in sorted(ndim):
print(blanc, n, ndim[n], file=stream)
def candidate_input_idxs(self, node):
"""Return the output indices of `node` eligible for inplace rewriting.

ScalarLoop-backed Elemwise nodes are excluded entirely (empty list),
because the rewriter cannot yet assign inplace patterns correctly for
them; every other node offers all of its outputs as candidates.
"""
# TODO: Implement specialized InplaceCompositeOptimizer with logic
# needed to correctly assign inplace for multi-output Composites
# and ScalarLoops
if isinstance(node.op.scalar_op, ScalarLoop):
return []
else:
return range(len(node.outputs))
def apply(self, fgraph):
r"""
......@@ -173,7 +163,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
baseline = op.inplace_pattern
candidate_outputs = [
i for i in self.candidate_input_idxs(node) if i not in baseline
i for i in range(len(node.outputs)) if i not in baseline
]
# node inputs that are Constant, already destroyed,
# or fgraph protected inputs and fgraph outputs can't be used as
......@@ -190,7 +180,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
]
else:
baseline = []
candidate_outputs = self.candidate_input_idxs(node)
candidate_outputs = range(len(node.outputs))
# node inputs that are Constant, already destroyed,
# fgraph protected inputs and fgraph outputs can't be used as inplace
# target.
......
......@@ -3,7 +3,8 @@ import re
import numpy as np
import pytest
from pytensor import Mode, function
from pytensor import In, Mode, function
from pytensor.compile import get_default_mode
from pytensor.scalar import (
Composite,
as_scalar,
......@@ -18,6 +19,8 @@ from pytensor.scalar import (
)
from pytensor.scalar.loop import ScalarLoop
from pytensor.tensor import exp as tensor_exp
from pytensor.tensor import lvector
from pytensor.tensor.elemwise import Elemwise
mode = pytest.mark.parametrize(
......@@ -255,3 +258,46 @@ def test_inner_loop(mode):
out16,
3**2 + 2.5,
)
@pytest.mark.parametrize("mutate_arg_idx", (0, 1, 2, 3))
def test_elemwise_inplace(mutate_arg_idx):
    """Elemwise-wrapped ScalarLoop destroys exactly the one mutable input.

    Builds a two-carry loop ``x, y = x - y + c, y - x + c``, wraps it in an
    Elemwise, and compiles with exactly one of the four inputs flagged as
    mutable. The inplace rewrite must then record that single input in
    ``destroy_map`` and reuse its buffer for the first output.
    """
    x0 = int64("x0")
    y0 = int64("y0")
    c = int64("c")
    x = x0 - y0 + c
    y = y0 - x0 + c
    op = Elemwise(ScalarLoop(init=[x0, y0], constant=[c], update=[x, y]))

    n_steps = lvector("n_steps")
    x0v = lvector("x0")
    y0v = lvector("y0")
    cv = lvector("c")
    xv, yv = op(n_steps, x0v, y0v, cv)

    # Mark exactly one input as mutable; only that one may be destroyed.
    inputs = [
        In(inp, mutable=i == mutate_arg_idx)
        for i, inp in enumerate([n_steps, x0v, y0v, cv])
    ]
    fn = function(
        inputs,
        [xv, yv],
        mode=get_default_mode().including("inplace"),
    )

    # Split asserts so a failure pinpoints which property broke.
    elem_op = fn.maker.fgraph.outputs[0].owner.op
    assert isinstance(elem_op, Elemwise)
    assert isinstance(elem_op.scalar_op, ScalarLoop)
    # The rewrite should map output 0 onto the single mutable input.
    assert elem_op.destroy_map == {0: [mutate_arg_idx]}

    n_test = np.array([1, 4, 8], dtype="int64")
    x0v_test = np.array([0, 0, 0], dtype="int64")
    y0v_test = np.array([1, 1, 1], dtype="int64")
    cv_test = np.array([0, 0, 0], dtype="int64")

    xv_res, yv_res = fn(n_test, x0v_test, y0v_test, cv_test)
    # The first output must alias the buffer of the destroyed input.
    assert xv_res is (n_test, x0v_test, y0v_test, cv_test)[mutate_arg_idx]
    np.testing.assert_allclose(xv_res, [-1, -8, -128])
    np.testing.assert_allclose(yv_res, [1, 8, 128])
Markdown format
0%
You are adding 0 people to this discussion. Please proceed carefully.
Please finish editing this comment first!
Register or sign in to comment