Unverified 提交 14d2454c authored 作者: Maxim Kochurov's avatar Maxim Kochurov 提交者: GitHub

Unify signatures of `graph_replace` and `clone_replace` (#398)

* more type hints
上级 e9a7d7ce
......@@ -4,7 +4,7 @@ Provide a simple user friendly API.
"""
from copy import copy
from typing import Optional
from typing import Optional, Sequence, Union, overload
from pytensor.compile.function.types import Function, UnusedInputError, orig_function
from pytensor.compile.io import In, Out
......@@ -15,8 +15,9 @@ from pytensor.graph.basic import Constant, Variable, clone_node_and_cache
from pytensor.graph.fg import FunctionGraph
@overload
def rebuild_collect_shared(
outputs,
outputs: Variable,
inputs=None,
replace=None,
updates=None,
......@@ -24,7 +25,107 @@ def rebuild_collect_shared(
copy_inputs_over=True,
no_default_updates=False,
clone_inner_graphs=False,
):
) -> tuple[
list[Variable],
Variable,
tuple[
dict[Variable, Variable],
dict[SharedVariable, Variable],
list[Variable],
list[SharedVariable],
],
]:
...
@overload
def rebuild_collect_shared(
outputs: Sequence[Variable],
inputs=None,
replace=None,
updates=None,
rebuild_strict=True,
copy_inputs_over=True,
no_default_updates=False,
clone_inner_graphs=False,
) -> tuple[
list[Variable],
list[Variable],
tuple[
dict[Variable, Variable],
dict[SharedVariable, Variable],
list[Variable],
list[SharedVariable],
],
]:
...
@overload
def rebuild_collect_shared(
outputs: Out,
inputs=None,
replace=None,
updates=None,
rebuild_strict=True,
copy_inputs_over=True,
no_default_updates=False,
clone_inner_graphs=False,
) -> tuple[
list[Variable],
Out,
tuple[
dict[Variable, Variable],
dict[SharedVariable, Variable],
list[Variable],
list[SharedVariable],
],
]:
...
@overload
def rebuild_collect_shared(
outputs: Sequence[Out],
inputs=None,
replace=None,
updates=None,
rebuild_strict=True,
copy_inputs_over=True,
no_default_updates=False,
clone_inner_graphs=False,
) -> tuple[
list[Variable],
list[Out],
tuple[
dict[Variable, Variable],
dict[SharedVariable, Variable],
list[Variable],
list[SharedVariable],
],
]:
...
def rebuild_collect_shared(
outputs: Union[Sequence[Variable], Variable, Out, Sequence[Out]],
inputs=None,
replace=None,
updates=None,
rebuild_strict=True,
copy_inputs_over=True,
no_default_updates=False,
clone_inner_graphs=False,
) -> tuple[
list[Variable],
Union[list[Variable], Variable, Out, list[Out]],
tuple[
dict[Variable, Variable],
dict[SharedVariable, Variable],
list[Variable],
list[SharedVariable],
],
]:
r"""Replace subgraphs of a computational graph.
It returns a set of dictionaries and lists which collect (partial?)
......@@ -260,7 +361,7 @@ def rebuild_collect_shared(
return (
input_variables,
cloned_outputs,
[clone_d, update_d, update_expr, shared_inputs],
(clone_d, update_d, update_expr, shared_inputs),
)
......
from functools import partial
from typing import (
Collection,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
cast,
)
from pytensor.graph.basic import Constant, Variable, truncated_graph_inputs
from typing import Iterable, Optional, Sequence, Union, cast, overload
from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs
from pytensor.graph.fg import FunctionGraph
ReplaceTypes = Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]]
def _format_replace(replace: Optional[ReplaceTypes] = None) -> dict[Variable, Variable]:
items: dict[Variable, Variable]
if isinstance(replace, dict):
# PyLance has issues with type resolution
items = cast(dict[Variable, Variable], replace)
elif isinstance(replace, Iterable):
items = dict(replace)
elif replace is None:
items = {}
else:
raise ValueError(
"replace is neither a dictionary, list, "
f"tuple or None ! The value provided is {replace},"
f"of type {type(replace)}"
)
return items
@overload
def clone_replace(
output: Sequence[Variable],
replace: Optional[ReplaceTypes] = None,
**rebuild_kwds,
) -> list[Variable]:
...
@overload
def clone_replace(
output: Collection[Variable],
output: Variable,
replace: Optional[
Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]]
Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]]
] = None,
**rebuild_kwds,
) -> List[Variable]:
) -> Variable:
...
def clone_replace(
output: Union[Sequence[Variable], Variable],
replace: Optional[ReplaceTypes] = None,
**rebuild_kwds,
) -> Union[list[Variable], Variable]:
"""Clone a graph and replace subgraphs within it.
It returns a copy of the initial subgraph with the corresponding
......@@ -39,19 +68,8 @@ def clone_replace(
"""
from pytensor.compile.function.pfunc import rebuild_collect_shared
items: Union[List[Tuple[Variable, Variable]], Tuple[Tuple[Variable, Variable], ...]]
if isinstance(replace, dict):
items = list(replace.items())
elif isinstance(replace, (list, tuple)):
items = replace
elif replace is None:
items = []
else:
raise ValueError(
"replace is neither a dictionary, list, "
f"tuple or None ! The value provided is {replace},"
f"of type {type(replace)}"
)
items = list(_format_replace(replace).items())
tmp_replace = [(x, x.type()) for x, y in items]
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
_, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
......@@ -59,20 +77,40 @@ def clone_replace(
# TODO Explain why we call it twice ?!
_, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
return cast(List[Variable], outs)
return outs
@overload
def graph_replace(
outputs: Variable,
replace: Optional[ReplaceTypes] = None,
*,
strict=True,
) -> Variable:
...
@overload
def graph_replace(
outputs: Sequence[Variable],
replace: Dict[Variable, Variable],
replace: Optional[ReplaceTypes] = None,
*,
strict=True,
) -> list[Variable]:
...
def graph_replace(
outputs: Union[Sequence[Variable], Variable],
replace: Optional[ReplaceTypes] = None,
*,
strict=True,
) -> List[Variable]:
) -> Union[list[Variable], Variable]:
"""Replace variables in ``outputs`` by ``replace``.
Parameters
----------
outputs: Sequence[Variable]
outputs: Union[Sequence[Variable], Variable]
Output graph
replace: Dict[Variable, Variable]
Replace mapping
......@@ -83,20 +121,26 @@ def graph_replace(
Returns
-------
List[Variable]
Output graph with subgraphs replaced
Union[Variable, List[Variable]]
Output graph with subgraphs replaced, see function overload for the exact type
Raises
------
ValueError
If some replacemens could not be applied and strict is True
If some replacements could not be applied and strict is True
"""
as_list = False
if not isinstance(outputs, Sequence):
outputs = [outputs]
else:
as_list = True
replace_dict = _format_replace(replace)
# collect minimum graph inputs which is required to compute outputs
# and depend on replacements
# additionally remove constants, they do not matter in clone get equiv
conditions = [
c
for c in truncated_graph_inputs(outputs, replace)
for c in truncated_graph_inputs(outputs, replace_dict)
if not isinstance(c, Constant)
]
# for the function graph we need the clean graph where
......@@ -117,7 +161,7 @@ def graph_replace(
# replace the conditions back
fg_replace = {equiv[c]: c for c in conditions}
# add the replacements on top of input mappings
fg_replace.update({equiv[r]: v for r, v in replace.items() if r in equiv})
fg_replace.update({equiv[r]: v for r, v in replace_dict.items() if r in equiv})
# replacements have to be done in reverse topological order so that nested
# expressions get recursively replaced correctly
......@@ -126,12 +170,14 @@ def graph_replace(
# So far FunctionGraph does these replacements inplace it is thus unsafe
# apply them using fg.replace, it may change the original graph
if strict:
non_fg_replace = {r: v for r, v in replace.items() if r not in equiv}
non_fg_replace = {r: v for r, v in replace_dict.items() if r not in equiv}
if non_fg_replace:
raise ValueError(f"Some replacements were not used: {non_fg_replace}")
toposort = fg.toposort()
def toposort_key(fg: FunctionGraph, ts, pair):
def toposort_key(
fg: FunctionGraph, ts: list[Apply], pair: tuple[Variable, Variable]
) -> int:
key, _ = pair
if key.owner is not None:
return ts.index(key.owner)
......@@ -148,4 +194,7 @@ def graph_replace(
reverse=True,
)
fg.replace_all(sorted_replacements, import_missing=True)
return list(fg.outputs)
if as_list:
return list(fg.outputs)
else:
return fg.outputs[0]
......@@ -169,6 +169,17 @@ class TestGraphReplace:
# the old reference is still kept
assert oc.owner.inputs[0].owner.inputs[1] is w
def test_non_list_input(self):
x = MyVariable("x")
y = MyVariable("y")
o = MyOp("xyop")(x, y)
new_x = x.clone(name="x_new")
new_y = y.clone(name="y2_new")
# test non list inputs as well
oc = graph_replace(o, {x: new_x, y: new_y})
assert oc.owner.inputs[1] is new_y
assert oc.owner.inputs[0] is new_x
def test_graph_replace_advanced(self):
x = MyVariable("x")
y = MyVariable("y")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论