提交 477fbafb authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Allow inplace of Elemwise ScalarLoop

上级 34b91eff
...@@ -1302,18 +1302,6 @@ class ScalarOp(COp): ...@@ -1302,18 +1302,6 @@ class ScalarOp(COp):
def __str__(self): def __str__(self):
if hasattr(self, "name") and self.name: if hasattr(self, "name") and self.name:
return self.name return self.name
else:
param = [
(k, v)
for k, v in self.__dict__.items()
if k
not in ("name", "_op_use_c_code", "bool", "output_types_preference")
]
if param:
classname = self.__class__.__name__
args = ", ".join(f"{k}={v}" for k, v in param)
return f"{classname}{{{args}}}"
else:
return self.__class__.__name__ return self.__class__.__name__
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -4102,6 +4090,7 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph): ...@@ -4102,6 +4090,7 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.prepare_node_called = set() self.prepare_node_called = set()
super().__init__(*args, **kwargs)
def _cleanup_graph(self, inputs, outputs): def _cleanup_graph(self, inputs, outputs):
# TODO: We could convert to TensorVariable, optimize graph, # TODO: We could convert to TensorVariable, optimize graph,
......
...@@ -55,6 +55,7 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -55,6 +55,7 @@ class ScalarLoop(ScalarInnerGraphOp):
constant: Sequence[Variable] | None = None, constant: Sequence[Variable] | None = None,
until: Variable | None = None, until: Variable | None = None,
name="ScalarLoop", name="ScalarLoop",
**kwargs,
): ):
if constant is None: if constant is None:
constant = [] constant = []
...@@ -75,7 +76,7 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -75,7 +76,7 @@ class ScalarLoop(ScalarInnerGraphOp):
self.nout = len(self.outputs) self.nout = len(self.outputs)
self.name = name self.name = name
super().__init__() super().__init__(**kwargs)
def output_types(self, input_types): def output_types(self, input_types):
return self.outputs_type return self.outputs_type
...@@ -115,7 +116,7 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -115,7 +116,7 @@ class ScalarLoop(ScalarInnerGraphOp):
self._fgraph = fgraph self._fgraph = fgraph
return self._fgraph return self._fgraph
def clone(self): def clone(self, name=None, **kwargs):
if self.is_while: if self.is_while:
*update, until = self.outputs *update, until = self.outputs
else: else:
...@@ -127,7 +128,8 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -127,7 +128,8 @@ class ScalarLoop(ScalarInnerGraphOp):
update=update, update=update,
constant=constant, constant=constant,
until=until, until=until,
name=self.name, name=self.name if name is None else name,
**kwargs,
) )
@property @property
...@@ -135,20 +137,7 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -135,20 +137,7 @@ class ScalarLoop(ScalarInnerGraphOp):
raise NotImplementedError raise NotImplementedError
def make_new_inplace(self, output_types_preference=None, name=None): def make_new_inplace(self, output_types_preference=None, name=None):
""" return self.clone(output_types_preference=output_types_preference, name=name)
This op.__init__ fct don't have the same parameter as other scalar op.
This break the insert_inplace_optimizer optimization.
This fct allow fix patch this.
"""
d = {k: getattr(self, k) for k in self.init_param}
out = self.__class__(**d)
if name:
out.name = name
else:
name = out.name
super(ScalarLoop, out).__init__(output_types_preference, name)
return out
def make_node(self, n_steps, *inputs): def make_node(self, n_steps, *inputs):
assert len(inputs) == self.nin - 1 assert len(inputs) == self.nin - 1
...@@ -229,11 +218,11 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -229,11 +218,11 @@ class ScalarLoop(ScalarInnerGraphOp):
c: f"%(i{int(i)})s" c: f"%(i{int(i)})s"
for i, c in enumerate(fgraph.inputs[n_update:], start=n_update + 1) for i, c in enumerate(fgraph.inputs[n_update:], start=n_update + 1)
} }
update_subd = { out_subd = {
u: f"%(o{int(i)})s" for i, u in enumerate(fgraph.outputs[:n_update]) u: f"%(o{int(i)})s" for i, u in enumerate(fgraph.outputs[:n_update])
} }
until_subd = {u: "until" for u in fgraph.outputs[n_update:]} until_subd = {u: "until" for u in fgraph.outputs[n_update:]}
subd = {**carry_subd, **constant_subd, **update_subd, **until_subd} subd = {**carry_subd, **constant_subd, **until_subd}
for var in fgraph.variables: for var in fgraph.variables:
if var.owner is None: if var.owner is None:
...@@ -257,11 +246,11 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -257,11 +246,11 @@ class ScalarLoop(ScalarInnerGraphOp):
_c_code += "bool until = 1;\n\n" _c_code += "bool until = 1;\n\n"
# Copy carried inputs # Copy carried inputs
for i, (var, name) in enumerate(carry_subd.items()): for i, (var, name) in enumerate(carry_subd.items(), start=1):
copy_var_name = f"{name}_copy{i}" carry_var_name = f"{name}_carry{i}"
_c_code += f"{var.type.dtype_specs()[1]} {copy_var_name} = {name};\n" _c_code += f"{var.type.dtype_specs()[1]} {carry_var_name} = {name};\n"
carry_subd[var] = copy_var_name carry_subd[var] = carry_var_name
subd[var] = copy_var_name subd[var] = carry_var_name
# _c_code += 'printf("inputs=[");' # _c_code += 'printf("inputs=[");'
# for i in range(1, len(fgraph.inputs)): # for i in range(1, len(fgraph.inputs)):
...@@ -270,9 +259,8 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -270,9 +259,8 @@ class ScalarLoop(ScalarInnerGraphOp):
_c_code += "\nfor(%(n_steps_dtype)s i = 0; i < %(n_steps)s; i++){\n" _c_code += "\nfor(%(n_steps_dtype)s i = 0; i < %(n_steps)s; i++){\n"
self.nodenames = [ # Used by self.c_support_code_apply
f"%(nodename)s_subnode{int(j)}" for j, n in enumerate(fgraph.toposort()) self.nodenames = nodenames = []
]
i = 0 i = 0
for j, node in enumerate(fgraph.toposort()): for j, node in enumerate(fgraph.toposort()):
...@@ -282,9 +270,13 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -282,9 +270,13 @@ class ScalarLoop(ScalarInnerGraphOp):
name = f"V%(id)s_tmp{int(i)}" name = f"V%(id)s_tmp{int(i)}"
subd[output] = name subd[output] = name
_c_code += f"{output.type.dtype_specs()[1]} {name};\n" _c_code += f"{output.type.dtype_specs()[1]} {name};\n"
nodename = f"%(nodename)s_subnode{int(j)}"
nodenames.append(nodename)
s = node.op.c_code( s = node.op.c_code(
node, node,
self.nodenames[j], nodename,
# Any node that depended on `init` will depend on `update` instead # Any node that depended on `init` will depend on `update` instead
# The initial value of `update` was set to `init` before the loop # The initial value of `update` was set to `init` before the loop
[subd[input] for input in node.inputs], [subd[input] for input in node.inputs],
...@@ -294,10 +286,12 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -294,10 +286,12 @@ class ScalarLoop(ScalarInnerGraphOp):
_c_code += s _c_code += s
_c_code += "\n" _c_code += "\n"
# Set the carry variables to the output variables # Update the carry variables to the output variables
_c_code += "\n" _c_code += "\n"
for init, update in zip(carry_subd.values(), update_subd.values(), strict=True): for carry, out in zip(
_c_code += f"{init} = {update};\n" carry_subd.values(), fgraph.outputs[:n_update], strict=True
):
_c_code += f"{carry} = {subd[out]};\n"
# _c_code += 'printf("%%ld\\n", i);\n' # _c_code += 'printf("%%ld\\n", i);\n'
# for carry in range(1, 10): # for carry in range(1, 10):
...@@ -309,6 +303,10 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -309,6 +303,10 @@ class ScalarLoop(ScalarInnerGraphOp):
# End of the loop # End of the loop
_c_code += "}\n" _c_code += "}\n"
# Assign the carry variables to the outputs
for out, carry in zip(out_subd.values(), carry_subd.values(), strict=True):
_c_code += f"{out} = {carry};\n"
# Output until flag # Output until flag
if self.is_while: if self.is_while:
_c_code += f"%(o{len(fgraph.outputs)-1})s = until;\n" _c_code += f"%(o{len(fgraph.outputs)-1})s = until;\n"
...@@ -343,4 +341,4 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -343,4 +341,4 @@ class ScalarLoop(ScalarInnerGraphOp):
return res return res
def c_code_cache_version_outer(self): def c_code_cache_version_outer(self):
return (3,) return (4,)
...@@ -24,7 +24,6 @@ from pytensor.graph.rewriting.basic import ( ...@@ -24,7 +24,6 @@ from pytensor.graph.rewriting.basic import (
) )
from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.scalar.loop import ScalarLoop
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
MakeVector, MakeVector,
...@@ -74,15 +73,6 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -74,15 +73,6 @@ class InplaceElemwiseOptimizer(GraphRewriter):
for n in sorted(ndim): for n in sorted(ndim):
print(blanc, n, ndim[n], file=stream) print(blanc, n, ndim[n], file=stream)
def candidate_input_idxs(self, node):
# TODO: Implement specialized InplaceCompositeOptimizer with logic
# needed to correctly assign inplace for multi-output Composites
# and ScalarLoops
if isinstance(node.op.scalar_op, ScalarLoop):
return []
else:
return range(len(node.outputs))
def apply(self, fgraph): def apply(self, fgraph):
r""" r"""
...@@ -173,7 +163,7 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -173,7 +163,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
baseline = op.inplace_pattern baseline = op.inplace_pattern
candidate_outputs = [ candidate_outputs = [
i for i in self.candidate_input_idxs(node) if i not in baseline i for i in range(len(node.outputs)) if i not in baseline
] ]
# node inputs that are Constant, already destroyed, # node inputs that are Constant, already destroyed,
# or fgraph protected inputs and fgraph outputs can't be used as # or fgraph protected inputs and fgraph outputs can't be used as
...@@ -190,7 +180,7 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -190,7 +180,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
] ]
else: else:
baseline = [] baseline = []
candidate_outputs = self.candidate_input_idxs(node) candidate_outputs = range(len(node.outputs))
# node inputs that are Constant, already destroyed, # node inputs that are Constant, already destroyed,
# fgraph protected inputs and fgraph outputs can't be used as inplace # fgraph protected inputs and fgraph outputs can't be used as inplace
# target. # target.
......
...@@ -3,7 +3,8 @@ import re ...@@ -3,7 +3,8 @@ import re
import numpy as np import numpy as np
import pytest import pytest
from pytensor import Mode, function from pytensor import In, Mode, function
from pytensor.compile import get_default_mode
from pytensor.scalar import ( from pytensor.scalar import (
Composite, Composite,
as_scalar, as_scalar,
...@@ -18,6 +19,8 @@ from pytensor.scalar import ( ...@@ -18,6 +19,8 @@ from pytensor.scalar import (
) )
from pytensor.scalar.loop import ScalarLoop from pytensor.scalar.loop import ScalarLoop
from pytensor.tensor import exp as tensor_exp from pytensor.tensor import exp as tensor_exp
from pytensor.tensor import lvector
from pytensor.tensor.elemwise import Elemwise
mode = pytest.mark.parametrize( mode = pytest.mark.parametrize(
...@@ -255,3 +258,46 @@ def test_inner_loop(mode): ...@@ -255,3 +258,46 @@ def test_inner_loop(mode):
out16, out16,
3**2 + 2.5, 3**2 + 2.5,
) )
@pytest.mark.parametrize("mutate_arg_idx", (0, 1, 2, 3))
def test_elemwise_inplace(mutate_arg_idx):
    """Check in-place (destructive) Elemwise of a ScalarLoop.

    The loop updates ``x <- x - y + c`` and ``y <- y - x + c`` elementwise.
    The test is parametrized over which of the four inputs
    (``n_steps``, ``x0``, ``y0``, ``c``) is flagged mutable, and checks that
    the inplace rewrite destroys exactly that input's buffer while still
    producing correct results.
    """
    x0 = int64("x0")
    y0 = int64("y0")
    c = int64("c")
    x = x0 - y0 + c
    y = y0 - x0 + c
    op = Elemwise(ScalarLoop(init=[x0, y0], constant=[c], update=[x, y]))

    n_steps = lvector("n_steps")
    x0v = lvector("x0")
    y0v = lvector("y0")
    cv = lvector("c")
    xv, yv = op(n_steps, x0v, y0v, cv)

    # Only the parametrized input is mutable, so the inplace rewrite may
    # reuse at most that one buffer as an output.
    inputs = [
        In(inp, mutable=i == mutate_arg_idx)
        for i, inp in enumerate([n_steps, x0v, y0v, cv])
    ]
    fn = function(
        inputs,
        [xv, yv],
        mode=get_default_mode().including("inplace"),
    )

    elem_op = fn.maker.fgraph.outputs[0].owner.op
    # Split asserts so a pytest failure pinpoints which check broke.
    assert isinstance(elem_op, Elemwise)
    assert isinstance(elem_op.scalar_op, ScalarLoop)
    # The first output must destroy exactly the input flagged as mutable.
    assert elem_op.destroy_map == {0: [mutate_arg_idx]}

    n_test = np.array([1, 4, 8], dtype="int64")
    x0v_test = np.array([0, 0, 0], dtype="int64")
    y0v_test = np.array([1, 1, 1], dtype="int64")
    cv_test = np.array([0, 0, 0], dtype="int64")
    xv_res, yv_res = fn(n_test, x0v_test, y0v_test, cv_test)
    # Check the first output was written into the destroyed input's buffer
    # (object identity, not just value equality).
    assert xv_res is (n_test, x0v_test, y0v_test, cv_test)[mutate_arg_idx]
    np.testing.assert_allclose(xv_res, [-1, -8, -128])
    np.testing.assert_allclose(yv_res, [1, 8, 128])
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论