提交 cd93444e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement Scalar Loop Op

上级 215cecd4
...@@ -1115,12 +1115,12 @@ def truncated_graph_inputs( ...@@ -1115,12 +1115,12 @@ def truncated_graph_inputs(
def clone( def clone(
inputs: List[Variable], inputs: Sequence[Variable],
outputs: List[Variable], outputs: Sequence[Variable],
copy_inputs: bool = True, copy_inputs: bool = True,
copy_orphans: Optional[bool] = None, copy_orphans: Optional[bool] = None,
clone_inner_graphs: bool = False, clone_inner_graphs: bool = False,
) -> Tuple[Collection[Variable], Collection[Variable]]: ) -> Tuple[List[Variable], List[Variable]]:
r"""Copies the sub-graph contained between inputs and outputs. r"""Copies the sub-graph contained between inputs and outputs.
Parameters Parameters
......
import warnings
from copy import copy
from itertools import chain
from textwrap import dedent
from typing import Literal, Optional, Sequence, Tuple
from pytensor.compile import rebuild_collect_shared
from pytensor.graph import Constant, FunctionGraph, Variable, clone
from pytensor.graph.rewriting.basic import MergeOptimizer
from pytensor.scalar.basic import ScalarInnerGraphOp, ScalarOp, as_scalar
class ScalarLoop(ScalarInnerGraphOp):
    """Scalar Op that encapsulates a scalar loop operation.

    This Op can be used for the gradient of other Scalar Ops.
    It is much more restricted than `Scan` in that the entire inner graph
    must be composed of Scalar operations.

    The inner graph maps ``[*init, *constant]`` to ``[*update]`` (plus a
    trailing boolean ``until`` output when a stopping condition is given).
    The outer node takes ``n_steps`` as an extra first input, which is not
    part of the inner graph.
    """

    init_param: Tuple[str, ...] = (
        "init",
        "update",
        "constant",
        "until",
        "until_condition_failed",
    )

    def __init__(
        self,
        init: Sequence[Variable],
        update: Sequence[Variable],
        constant: Optional[Sequence[Variable]] = None,
        until: Optional[Variable] = None,
        until_condition_failed: Literal["ignore", "warn", "raise"] = "warn",
        name="ScalarLoop",
    ):
        """Build the loop Op from its inner graph.

        Parameters
        ----------
        init
            Inner-graph variables holding the carried state at the start of a step.
        update
            Inner-graph expressions giving the carried state after a step;
            one per ``init`` variable.
        constant
            Inner-graph inputs that are not updated between steps.
        until
            Optional boolean inner-graph variable; the loop stops early once it
            evaluates to true.
        until_condition_failed
            What to do when the loop finishes without ``until`` becoming true:
            ``"ignore"``, ``"warn"`` or ``"raise"``.
        name
            Name of the Op, also used in the warning/error messages.

        Raises
        ------
        ValueError
            If ``until_condition_failed`` is invalid or ``init`` and ``update``
            have different lengths.
        """
        if until_condition_failed not in ("ignore", "warn", "raise"):
            raise ValueError(
                f"Invalid until_condition_failed: {until_condition_failed}"
            )
        if constant is None:
            constant = []
        if len(init) != len(update):
            raise ValueError("An update must be given for each init variable")

        # Compare against None explicitly: the truth value of a symbolic
        # variable says nothing about whether a condition was provided.
        has_until = until is not None
        if has_until:
            inputs, (*outputs, until) = clone([*init, *constant], [*update, until])
            self.outputs = copy([*outputs, until])
        else:
            inputs, outputs = clone([*init, *constant], update)
            self.outputs = copy(outputs)
        self.inputs = copy(inputs)

        self.inputs_type = tuple(input.type for input in inputs)
        self.outputs_type = tuple(output.type for output in outputs)
        self.nin = len(inputs) + 1  # n_steps is not part of the inner graph
        self.nout = len(outputs)  # until is not output
        self.is_while = has_until
        self.until_condition_failed = until_condition_failed
        self.name = name
        self._validate_fgraph(FunctionGraph(self.inputs, self.outputs, clone=False))
        super().__init__()

    def output_types(self, input_types):
        """Return the types of the carried outputs (``until`` is excluded)."""
        return self.outputs_type

    def _validate_fgraph(self, fgraph: FunctionGraph) -> None:
        """Check that ``fgraph`` is a valid inner graph for this Op.

        Raises
        ------
        TypeError
            If a node is not a `ScalarOp`, the until condition is not boolean,
            or an init/update pair has mismatching types.
        ValueError
            If a variable is both an init input and an update output.
        """
        for node in fgraph.apply_nodes:
            if not isinstance(node.op, ScalarOp):
                raise TypeError(
                    "The fgraph of ScalarLoop must be composed exclusively of ScalarOp nodes"
                )

        init = fgraph.inputs
        update = fgraph.outputs

        if self.is_while:
            *update, until = update
            if until.type.dtype != "bool":
                raise TypeError(
                    f"Until condition must be boolean, got {until}({until.type.dtype})"
                )

        # zip stops at the shorter sequence, so trailing constants in `init`
        # are not compared against anything.
        for i, u in zip(init, update):
            if i.type != u.type:
                raise TypeError(
                    "Init and update types must be the same: "
                    f"{i}({i.type}) != {u}({u.type})"
                )

        if set(init) & set(update):
            raise ValueError(
                "Some inputs and outputs are the same variable. "
                "If you want to return an output as a lagged input, wrap it in an identity Op."
            )

    @property
    def fgraph(self):
        """Merged and validated `FunctionGraph` of the inner graph (cached)."""
        if hasattr(self, "_fgraph"):
            return self._fgraph

        fgraph = FunctionGraph(self.inputs, self.outputs)
        # TODO: We could convert to TensorVariable, optimize graph,
        # and then convert back to ScalarVariable.
        # This would introduce rewrites like `log(1 + x) -> log1p`.
        MergeOptimizer().rewrite(fgraph)
        self._validate_fgraph(fgraph)

        # Clone identical outputs that have been merged
        if len(set(fgraph.outputs)) != len(self.outputs):
            old_outputs = fgraph.outputs
            new_outputs = []
            for output in old_outputs:
                if output not in new_outputs:
                    new_outputs.append(output)
                else:
                    node = output.owner
                    output_idx = node.outputs.index(output)
                    new_output = node.clone().outputs[output_idx]
                    new_outputs.append(new_output)
            fgraph = FunctionGraph(fgraph.inputs, new_outputs, clone=False)

        self._fgraph = fgraph
        return self._fgraph

    def _init_kwargs(self):
        """Recover the graph-related constructor arguments from the stored inner graph."""
        if self.is_while:
            *update, until = self.outputs
        else:
            update, until = self.outputs, None
        n_update = len(update)
        return dict(
            init=self.inputs[:n_update],
            update=update,
            constant=self.inputs[n_update:],
            until=until,
            until_condition_failed=self.until_condition_failed,
        )

    def clone(self):
        """Return a new `ScalarLoop` built from this Op's inner graph and settings."""
        return ScalarLoop(**self._init_kwargs(), name=self.name)

    @property
    def fn(self):
        raise NotImplementedError

    def make_new_inplace(self, output_types_preference=None, name=None):
        """
        This op.__init__ fct don't have the same parameter as other scalar op.
        This break the insert_inplace_optimizer optimization.
        This fct allow fix patch this.
        """
        # The constructor arguments (`init`, `update`, ...) are not stored as
        # attributes, so they are rebuilt from `self.inputs` / `self.outputs`
        # rather than looked up with `getattr`.
        out = self.__class__(**self._init_kwargs())
        if name:
            out.name = name
        else:
            name = out.name
        super(ScalarLoop, out).__init__(output_types_preference, name)
        return out

    def make_node(self, n_steps, *inputs):
        """Build the Apply node, rebuilding the Op if the input types differ.

        Raises
        ------
        TypeError
            If ``n_steps`` does not have an integer dtype.
        """
        assert len(inputs) == self.nin - 1

        n_steps = as_scalar(n_steps)
        if not n_steps.type.dtype.startswith("int"):
            raise TypeError(
                "The first variable of ScalarLoop (n_steps) must be of integer type. "
                f"Got {n_steps.type.dtype}",
            )

        if self.inputs_type == tuple(i.type for i in inputs):
            return super().make_node(n_steps, *inputs)

        # Make a new op with the right input types.
        res = rebuild_collect_shared(
            self.outputs,
            replace=dict(zip(self.inputs, inputs)),
            rebuild_strict=False,
        )
        if self.is_while:
            *cloned_update, cloned_until = res[1]
        else:
            cloned_update, cloned_until = res[1], None
        cloned_inputs = [res[2][0][i] for i in inputs]
        cloned_init = cloned_inputs[: len(cloned_update)]
        cloned_constant = cloned_inputs[len(cloned_update) :]
        # This will fail if the cloned init have a different dtype than the cloned_update
        op = ScalarLoop(
            init=cloned_init,
            update=cloned_update,
            constant=cloned_constant,
            until=cloned_until,
            until_condition_failed=self.until_condition_failed,
            name=self.name,
        )
        return op.make_node(n_steps, *inputs)

    def perform(self, node, inputs, output_storage):
        """Python implementation: run the inner function up to ``n_steps`` times."""
        n_steps, *inputs = inputs
        n_update = len(self.outputs) - (1 if self.is_while else 0)
        carry, constant = inputs[:n_update], inputs[n_update:]
        inner_fn = self.py_perform_fn

        if self.is_while:
            # Starts as True so that a loop with zero iterations is not
            # reported as a failed condition.
            until = True
            for _ in range(n_steps):
                *carry, until = inner_fn(*carry, *constant)
                if until:
                    break

            if not until:  # no-break
                if self.until_condition_failed == "raise":
                    raise RuntimeError(
                        f"Until condition in ScalarLoop {self.name} not reached!"
                    )
                elif self.until_condition_failed == "warn":
                    warnings.warn(
                        f"Until condition in ScalarLoop {self.name} not reached!",
                        RuntimeWarning,
                    )
        else:
            if n_steps < 0:
                raise ValueError("ScalarLoop does not have a termination condition.")
            for _ in range(n_steps):
                carry = inner_fn(*carry, *constant)

        for storage, out_val in zip(output_storage, carry):
            storage[0] = out_val

    @property
    def c_code_template(self):
        """C source template of the loop, built on first access and cached.

        The returned string still contains ``%(...)s`` placeholders that
        `c_code` fills in.
        """
        from pytensor.link.c.interface import CLinkerType

        if hasattr(self, "_c_code"):
            return self._c_code

        fgraph = self.fgraph

        # The first input is `n_steps` so we skip it in the mapping dictionary
        n_update = len(self.outputs) - (1 if self.is_while else 0)
        carry_subd = {
            c: f"%(i{int(i)})s" for i, c in enumerate(fgraph.inputs[:n_update], start=1)
        }
        constant_subd = {
            c: f"%(i{int(i)})s"
            for i, c in enumerate(fgraph.inputs[n_update:], start=n_update + 1)
        }
        update_subd = {
            u: f"%(o{int(i)})s" for i, u in enumerate(fgraph.outputs[:n_update])
        }
        until_subd = {u: "until" for u in fgraph.outputs[n_update:]}
        subd = {**carry_subd, **constant_subd, **update_subd, **until_subd}

        for var in fgraph.variables:
            if var.owner is None:
                if var not in fgraph.inputs:
                    # This is an orphan
                    if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
                        subd[var] = var.type.c_literal(var.data)
                    else:
                        raise ValueError(
                            "All orphans in the fgraph to ScalarLoop must"
                            " be Constant, CLinkerType instances."
                        )
            elif any(i.dtype == "float16" for i in var.owner.inputs) or any(
                o.dtype == "float16" for o in var.owner.outputs
            ):
                # flag for elemwise ops to check.
                self.inner_float16 = True

        _c_code = "{\n"
        if self.is_while:
            _c_code += "bool until = 1;\n\n"

        # Copy carried inputs
        # Iterate over a snapshot because the dict values are rebound below.
        for i, (var, name) in enumerate(list(carry_subd.items())):
            copy_var_name = f"{name}_copy{i}"
            _c_code += f"{var.type.dtype_specs()[1]} {copy_var_name} = {name};\n"
            carry_subd[var] = copy_var_name
            subd[var] = copy_var_name

        _c_code += "\nfor(%(n_steps_dtype)s i = 0; i < %(n_steps)s; i++){\n"

        toposort = fgraph.toposort()
        self.nodenames = [
            f"%(nodename)s_subnode{int(j)}" for j in range(len(toposort))
        ]

        i = 0
        for j, node in enumerate(toposort):
            # Declare a temporary for every output not already mapped
            for output in node.outputs:
                if output not in subd:
                    i += 1
                    name = f"V%(id)s_tmp{int(i)}"
                    subd[output] = name
                    _c_code += f"{output.type.dtype_specs()[1]} {name};\n"
            s = node.op.c_code(
                node,
                self.nodenames[j],
                # Any node that depended on `init` will depend on `update` instead
                # The initial value of `update` was set to `init` before the loop
                [subd[input] for input in node.inputs],
                [subd[output] for output in node.outputs],
                dict(fail="%(fail)s", id=f"%(id)s_{int(j)}"),
            )
            _c_code += s
            _c_code += "\n"

        # Set the carry variables to the output variables
        _c_code += "\n"
        for init, update in zip(carry_subd.values(), update_subd.values()):
            _c_code += f"{init} = {update};\n"

        if self.is_while:
            _c_code += "\nif(until){break;}\n"

        _c_code += "}\n"

        # End of the loop
        if self.is_while:
            if self.until_condition_failed == "raise":
                _c_code += dedent(
                    f"""
                    if (!until) {{
                        PyErr_SetString(PyExc_RuntimeError, "Until condition in ScalarLoop {self.name} not reached!");
                        %(fail)s
                    }}
                    """
                )
            elif self.until_condition_failed == "warn":
                # NOTE(review): the return value of PyErr_WarnEx is ignored; it
                # reports an error when the warning is turned into an exception.
                _c_code += dedent(
                    f"""
                    if (!until) {{
                        PyErr_WarnEx(PyExc_RuntimeWarning, "Until condition in ScalarLoop {self.name} not reached!", 1);
                    }}
                    """
                )

        _c_code += "}\n"

        self._c_code = _c_code

        return self._c_code

    def c_code(self, node, nodename, inames, onames, sub):
        """Fill the placeholders of `c_code_template` for a concrete node."""
        d = dict(
            chain(
                zip((f"i{int(i)}" for i in range(len(inames))), inames),
                zip((f"o{int(i)}" for i in range(len(onames))), onames),
            ),
            **sub,
        )
        d["nodename"] = nodename
        if "id" not in sub:
            # The use of a dummy id is safe as the code is in a separate block.
            # It won't generate conflicting variable name.
            d["id"] = "_DUMMY_ID_"

        # When called inside Elemwise we don't have access to the dtype
        # via the usual `f"dtype_{inames[i]}"` variable
        d["n_steps"] = inames[0]
        d["n_steps_dtype"] = "npy_" + node.inputs[0].dtype

        return self.c_code_template % d

    def c_code_cache_version_outer(self):
        return (1,)

    def __eq__(self, other):
        return (
            super().__eq__(other)
            and self.until_condition_failed == other.until_condition_failed
        )

    def __hash__(self):
        return hash((super().__hash__(), self.until_condition_failed))
...@@ -22,6 +22,7 @@ from pytensor.graph.rewriting.basic import ( ...@@ -22,6 +22,7 @@ from pytensor.graph.rewriting.basic import (
) )
from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.scalar.loop import ScalarLoop
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
MakeVector, MakeVector,
alloc, alloc,
...@@ -66,9 +67,12 @@ class InplaceElemwiseOptimizer(GraphRewriter): ...@@ -66,9 +67,12 @@ class InplaceElemwiseOptimizer(GraphRewriter):
print(blanc, n, ndim[n], file=stream) print(blanc, n, ndim[n], file=stream)
def candidate_input_idxs(self, node): def candidate_input_idxs(self, node):
if isinstance(node.op.scalar_op, aes.Composite) and len(node.outputs) > 1: # TODO: Implement specialized InplaceCompositeOptimizer with logic
# TODO: Implement specialized InplaceCompositeOptimizer with logic # needed to correctly assign inplace for multi-output Composites
# needed to correctly assign inplace for multi-output Composites # and ScalarLoops
if isinstance(node.op.scalar_op, ScalarLoop):
return []
if isinstance(node.op.scalar_op, aes.Composite) and (len(node.outputs) > 1):
return [] return []
else: else:
return range(len(node.outputs)) return range(len(node.outputs))
......
import re
import numpy as np
import pytest
from pytensor import Mode, function
from pytensor.scalar import (
Composite,
as_scalar,
cos,
exp,
float16,
float32,
float64,
identity,
int64,
sin,
)
from pytensor.scalar.loop import ScalarLoop
from pytensor.tensor import exp as tensor_exp
# Parametrization shared by the tests below: each decorated test runs once per
# Mode, i.e. with the "py" linker and with the "cvm" linker, both using the
# "fast_compile" optimizer.
mode = pytest.mark.parametrize(
    "mode",
    [
        Mode(optimizer="fast_compile", linker="py"),
        Mode(optimizer="fast_compile", linker="cvm"),
    ],
)
@mode
def test_single_output(mode):
    """A loop with one carried variable adds ``const`` to it at every step."""
    steps = int64("n_steps")
    start = float64("x0")
    increment = float64("const")

    loop_op = ScalarLoop(
        init=[start], constant=[increment], update=[start + increment]
    )
    result = loop_op(steps, start, increment)
    compiled = function([steps, start, increment], result, mode=mode)

    # (n_steps, x0, const) -> expected final value
    cases = [
        ((5, 0, 1), 5),
        ((5, 0, 2), 10),
        ((4, 3, -1), -1),
    ]
    for args, expected in cases:
        np.testing.assert_allclose(compiled(*args), expected)
@mode
def test_multiple_output(mode):
    """Two carried variables of different dtypes are updated independently."""
    steps = int64("n_steps")
    x_init = float64("x0")
    y_init = int64("y0")
    increment = float64("const")

    x_next = x_init + increment
    y_next = y_init + 1
    loop_op = ScalarLoop(
        init=[x_init, y_init], constant=[increment], update=[x_next, y_next]
    )
    x_out, y_out = loop_op(steps, x_init, y_init, increment)
    compiled = function(
        [steps, x_init, y_init, increment], [x_out, y_out], mode=mode
    )

    # keyword inputs -> (expected x, expected y)
    cases = [
        (dict(n_steps=5, x0=0, y0=0, const=1), (5, 5)),
        (dict(n_steps=5, x0=0, y0=0, const=2), (10, 5)),
        (dict(n_steps=4, x0=3, y0=2, const=-1), (-1, 6)),
    ]
    for kwargs, (expected_x, expected_y) in cases:
        res_x, res_y = compiled(**kwargs)
        np.testing.assert_allclose(res_x, expected_x)
        np.testing.assert_allclose(res_y, expected_y)
@mode
def test_input_not_aliased_to_update(mode):
    """An update that reads another carried input must see its previous value."""
    steps = int64("n_steps")
    x_init = float64("x0")
    y_init = float64("y0")
    increment = float64("const")

    x_next = x_init + increment
    # y depends on the previous x, so that value must not be overwritten by
    # the new x before y is computed!
    y_next = y_init + x_init

    loop_op = ScalarLoop(
        init=[x_init, y_init], constant=[increment], update=[x_next, y_next]
    )
    _, y_out = loop_op(steps, x_init, y_init, increment)
    compiled = function([steps, x_init, y_init, increment], y_out, mode=mode)

    # With x0 = y0 = 0 and const = 1, y is the running sum 0 + 1 + ... + (n - 1)
    expected_by_step = [0.0, 1.0, 3.0, 6.0, 10.0]
    for n, expected in enumerate(expected_by_step, start=1):
        np.testing.assert_allclose(
            compiled(n_steps=n, x0=0, y0=0, const=1), expected
        )
@mode
def test_until(mode):
    """The until condition stops the loop early; each failure policy is honoured."""
    n_steps = int64("n_steps")
    x0 = float64("x0")
    x = x0 + 1
    until = x >= 10

    def compile_loop(**op_kwargs):
        # Build the loop Op with the requested policy and compile it
        loop_op = ScalarLoop(init=[x0], update=[x], until=until, **op_kwargs)
        return function([n_steps, x0], loop_op(n_steps, x0), mode=mode)

    def check_condition_reached(compiled):
        # With enough steps the loop stops as soon as x reaches 10
        np.testing.assert_allclose(compiled(n_steps=20, x0=0), 10)
        np.testing.assert_allclose(compiled(n_steps=20, x0=1), 10)

    message = "Until condition in ScalarLoop TestLoop not reached!"

    # "ignore": running out of steps silently returns the last value
    fn = compile_loop(until_condition_failed="ignore")
    check_condition_reached(fn)
    np.testing.assert_allclose(fn(n_steps=5, x0=1), 6)

    # "warn": same result, accompanied by a RuntimeWarning
    fn = compile_loop(until_condition_failed="warn", name="TestLoop")
    check_condition_reached(fn)
    with pytest.warns(RuntimeWarning, match=message):
        np.testing.assert_allclose(fn(n_steps=5, x0=1), 6)

    # "raise": running out of steps is an error
    fn = compile_loop(until_condition_failed="raise", name="TestLoop")
    check_condition_reached(fn)
    with pytest.raises(RuntimeError, match=message):
        fn(n_steps=5, x0=1)
def test_update_missing_error():
    """Constructing the Op with fewer updates than init variables is rejected."""
    state = float64("x0")
    fixed = float64("const")

    expected_message = "An update must be given for each init variable"
    with pytest.raises(ValueError, match=expected_message):
        ScalarLoop(init=[state], constant=[fixed], update=[])
def test_init_update_type_error():
    """An update whose type differs from its init variable is rejected."""
    state = float32("x0")
    fixed = float64("const")

    # float32 + float64 upcasts, so the update no longer matches the init type
    upcast_update = state + fixed
    assert upcast_update.type.dtype == "float64"

    expected_message = "Init and update types must be the same"
    with pytest.raises(TypeError, match=expected_message):
        ScalarLoop(init=[state], constant=[fixed], update=[upcast_update])
def test_rebuild_dtype():
    """Calling the Op with new input dtypes rebuilds it, if the types stay consistent."""
    steps = int64("n_steps")
    state = float64("x0")
    fixed = float64("const")
    loop_op = ScalarLoop(init=[state], constant=[fixed], update=[state + fixed])

    # A float32 state combined with a float64 constant yields a float64
    # update, which cannot match the float32 init
    state_f32 = float32("x0_float32")
    with pytest.raises(TypeError, match="Init and update types must be the same"):
        loop_op(steps, state_f32, fixed)

    # With both inputs in float32 the rebuilt graph is consistent again
    fixed_f32 = float32("const_float32")
    rebuilt_out = loop_op(steps, state_f32, fixed_f32)
    assert rebuilt_out.dtype == "float32"
def test_non_scalar_error():
    """An inner graph containing a non-ScalarOp node is rejected."""
    state = float64("x0")
    # The tensor-level exp introduces nodes that are not ScalarOps
    tensor_based_update = as_scalar(tensor_exp(state))

    expected_message = "must be composed exclusively of ScalarOp nodes"
    with pytest.raises(TypeError, match=expected_message):
        ScalarLoop(init=[state], constant=[], update=[tensor_based_update])
def test_n_steps_type_error():
    """A non-integer ``n_steps`` is rejected when the node is built."""
    state = float64("x0")
    fixed = float64("const")
    loop_op = ScalarLoop(init=[state], constant=[fixed], update=[state + fixed])

    # The message contains parentheses, so escape it for the regex match
    expected_message = re.escape("(n_steps) must be of integer type. Got float64")
    with pytest.raises(TypeError, match=expected_message):
        loop_op(float64("n_steps"), state, fixed)
def test_same_out_as_inp_error():
    """Returning an input variable directly as an update is rejected."""
    two_back = float64("xtm2")
    one_back = float64("xtm1")
    current = two_back + one_back

    expected_message = "Some inputs and outputs are the same variable"
    with pytest.raises(ValueError, match=expected_message):
        # `one_back` is used both as an init and as an update
        ScalarLoop(init=[two_back, one_back], update=[one_back, current])
@mode
def test_lags(mode):
    """A lagged input can be carried forward by wrapping it in ``identity``."""
    steps = int64("n_steps")
    two_back = float64("xtm2")
    one_back = float64("xtm1")
    current = two_back + one_back

    loop_op = ScalarLoop(
        init=[two_back, one_back],
        update=[identity(one_back), current],
    )
    _, latest = loop_op(steps, two_back, one_back)
    compiled = function([steps, two_back, one_back], latest, mode=mode)

    # Fibonacci-style recurrence starting from (0, 1): 1, 2, 3, 5, 8
    np.testing.assert_allclose(compiled(n_steps=5, xtm2=0, xtm1=1), 8)
@mode
def test_inner_composite(mode):
    """A Composite inside the loop works, including after a dtype rebuild."""
    steps = int64("n_steps")
    x = float64("x")

    # cos(e^x)^2 + sin(e^x)^2 is identically one
    unit = Composite([x], [cos(exp(x)) ** 2 + sin(exp(x)) ** 2])(x)
    loop_op = ScalarLoop(init=[x], update=[unit + x])

    out64 = loop_op(steps, x)
    compiled64 = function([steps, x], out64, mode=mode)
    np.testing.assert_allclose(compiled64(n_steps=5, x=2.53), 2.53 + 5)

    # Now with a dtype that must be rebuilt
    x16 = float16("x16")
    out16 = loop_op(steps, x16)
    assert out16.type.dtype == "float16"

    compiled16 = function([steps, x16], out16, mode=mode)
    half_precision_input = np.array(4.73, dtype="float16")
    np.testing.assert_allclose(
        compiled16(n_steps=9, x16=half_precision_input),
        4.73 + 9,
        rtol=1e-3,
    )
@mode
def test_inner_loop(mode):
    """A ScalarLoop nested in another ScalarLoop works, also after a dtype rebuild."""
    steps = int64("n_steps")
    x = float64("x")
    inner_x = float64("x_in")

    # The inner loop adds 1 per step; the outer loop runs it n_steps times
    inner_op = ScalarLoop(init=[inner_x], update=[inner_x + 1])
    outer_op = ScalarLoop(
        init=[x],
        update=[inner_op(steps, x)],
        constant=[steps],
    )
    out64 = outer_op(steps, x, steps)
    compiled64 = function([steps, x], out64, mode=mode)

    # keyword inputs -> expected result (n_steps ** 2 + x)
    cases = [
        (dict(n_steps=5, x=0), 5**2),
        (dict(n_steps=7, x=0), 7**2),
        (dict(n_steps=7, x=1), 7**2 + 1),
    ]
    for kwargs, expected in cases:
        np.testing.assert_allclose(compiled64(**kwargs), expected)

    # Now with a dtype that must be rebuilt
    x16 = float16("x16")
    out16 = outer_op(steps, x16, steps)
    assert out16.type.dtype == "float16"

    compiled16 = function([steps, x16], out16, mode=mode)
    half_precision_input = np.array(2.5, dtype="float16")
    np.testing.assert_allclose(
        compiled16(n_steps=3, x16=half_precision_input),
        3**2 + 2.5,
    )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论