Unverified 提交 14d2454c authored 作者: Maxim Kochurov's avatar Maxim Kochurov 提交者: GitHub

Unify signatures of `graph_replace` and `clone_replace` (#398)

* more type hints
上级 e9a7d7ce
...@@ -4,7 +4,7 @@ Provide a simple user friendly API. ...@@ -4,7 +4,7 @@ Provide a simple user friendly API.
""" """
from copy import copy from copy import copy
from typing import Optional from typing import Optional, Sequence, Union, overload
from pytensor.compile.function.types import Function, UnusedInputError, orig_function from pytensor.compile.function.types import Function, UnusedInputError, orig_function
from pytensor.compile.io import In, Out from pytensor.compile.io import In, Out
...@@ -15,8 +15,9 @@ from pytensor.graph.basic import Constant, Variable, clone_node_and_cache ...@@ -15,8 +15,9 @@ from pytensor.graph.basic import Constant, Variable, clone_node_and_cache
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
@overload
def rebuild_collect_shared( def rebuild_collect_shared(
outputs, outputs: Variable,
inputs=None, inputs=None,
replace=None, replace=None,
updates=None, updates=None,
...@@ -24,7 +25,107 @@ def rebuild_collect_shared( ...@@ -24,7 +25,107 @@ def rebuild_collect_shared(
copy_inputs_over=True, copy_inputs_over=True,
no_default_updates=False, no_default_updates=False,
clone_inner_graphs=False, clone_inner_graphs=False,
): ) -> tuple[
list[Variable],
Variable,
tuple[
dict[Variable, Variable],
dict[SharedVariable, Variable],
list[Variable],
list[SharedVariable],
],
]:
...
@overload
def rebuild_collect_shared(
outputs: Sequence[Variable],
inputs=None,
replace=None,
updates=None,
rebuild_strict=True,
copy_inputs_over=True,
no_default_updates=False,
clone_inner_graphs=False,
) -> tuple[
list[Variable],
list[Variable],
tuple[
dict[Variable, Variable],
dict[SharedVariable, Variable],
list[Variable],
list[SharedVariable],
],
]:
...
@overload
def rebuild_collect_shared(
outputs: Out,
inputs=None,
replace=None,
updates=None,
rebuild_strict=True,
copy_inputs_over=True,
no_default_updates=False,
clone_inner_graphs=False,
) -> tuple[
list[Variable],
Out,
tuple[
dict[Variable, Variable],
dict[SharedVariable, Variable],
list[Variable],
list[SharedVariable],
],
]:
...
@overload
def rebuild_collect_shared(
outputs: Sequence[Out],
inputs=None,
replace=None,
updates=None,
rebuild_strict=True,
copy_inputs_over=True,
no_default_updates=False,
clone_inner_graphs=False,
) -> tuple[
list[Variable],
list[Out],
tuple[
dict[Variable, Variable],
dict[SharedVariable, Variable],
list[Variable],
list[SharedVariable],
],
]:
...
def rebuild_collect_shared(
outputs: Union[Sequence[Variable], Variable, Out, Sequence[Out]],
inputs=None,
replace=None,
updates=None,
rebuild_strict=True,
copy_inputs_over=True,
no_default_updates=False,
clone_inner_graphs=False,
) -> tuple[
list[Variable],
Union[list[Variable], Variable, Out, list[Out]],
tuple[
dict[Variable, Variable],
dict[SharedVariable, Variable],
list[Variable],
list[SharedVariable],
],
]:
r"""Replace subgraphs of a computational graph. r"""Replace subgraphs of a computational graph.
It returns a set of dictionaries and lists which collect (partial?) It returns a set of dictionaries and lists which collect (partial?)
...@@ -260,7 +361,7 @@ def rebuild_collect_shared( ...@@ -260,7 +361,7 @@ def rebuild_collect_shared(
return ( return (
input_variables, input_variables,
cloned_outputs, cloned_outputs,
[clone_d, update_d, update_expr, shared_inputs], (clone_d, update_d, update_expr, shared_inputs),
) )
......
from functools import partial from functools import partial
from typing import ( from typing import Iterable, Optional, Sequence, Union, cast, overload
Collection,
Dict, from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
cast,
)
from pytensor.graph.basic import Constant, Variable, truncated_graph_inputs
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
ReplaceTypes = Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]]
def _format_replace(replace: Optional[ReplaceTypes] = None) -> dict[Variable, Variable]:
items: dict[Variable, Variable]
if isinstance(replace, dict):
# PyLance has issues with type resolution
items = cast(dict[Variable, Variable], replace)
elif isinstance(replace, Iterable):
items = dict(replace)
elif replace is None:
items = {}
else:
raise ValueError(
"replace is neither a dictionary, list, "
f"tuple or None ! The value provided is {replace},"
f"of type {type(replace)}"
)
return items
@overload
def clone_replace(
output: Sequence[Variable],
replace: Optional[ReplaceTypes] = None,
**rebuild_kwds,
) -> list[Variable]:
...
@overload
def clone_replace( def clone_replace(
output: Collection[Variable], output: Variable,
replace: Optional[ replace: Optional[
Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]] Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]]
] = None, ] = None,
**rebuild_kwds, **rebuild_kwds,
) -> List[Variable]: ) -> Variable:
...
def clone_replace(
output: Union[Sequence[Variable], Variable],
replace: Optional[ReplaceTypes] = None,
**rebuild_kwds,
) -> Union[list[Variable], Variable]:
"""Clone a graph and replace subgraphs within it. """Clone a graph and replace subgraphs within it.
It returns a copy of the initial subgraph with the corresponding It returns a copy of the initial subgraph with the corresponding
...@@ -39,19 +68,8 @@ def clone_replace( ...@@ -39,19 +68,8 @@ def clone_replace(
""" """
from pytensor.compile.function.pfunc import rebuild_collect_shared from pytensor.compile.function.pfunc import rebuild_collect_shared
items: Union[List[Tuple[Variable, Variable]], Tuple[Tuple[Variable, Variable], ...]] items = list(_format_replace(replace).items())
if isinstance(replace, dict):
items = list(replace.items())
elif isinstance(replace, (list, tuple)):
items = replace
elif replace is None:
items = []
else:
raise ValueError(
"replace is neither a dictionary, list, "
f"tuple or None ! The value provided is {replace},"
f"of type {type(replace)}"
)
tmp_replace = [(x, x.type()) for x, y in items] tmp_replace = [(x, x.type()) for x, y in items]
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)] new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
_, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds) _, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
...@@ -59,20 +77,40 @@ def clone_replace( ...@@ -59,20 +77,40 @@ def clone_replace(
# TODO Explain why we call it twice ?! # TODO Explain why we call it twice ?!
_, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds) _, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
return cast(List[Variable], outs) return outs
@overload
def graph_replace(
outputs: Variable,
replace: Optional[ReplaceTypes] = None,
*,
strict=True,
) -> Variable:
...
@overload
def graph_replace( def graph_replace(
outputs: Sequence[Variable], outputs: Sequence[Variable],
replace: Dict[Variable, Variable], replace: Optional[ReplaceTypes] = None,
*,
strict=True,
) -> list[Variable]:
...
def graph_replace(
outputs: Union[Sequence[Variable], Variable],
replace: Optional[ReplaceTypes] = None,
*, *,
strict=True, strict=True,
) -> List[Variable]: ) -> Union[list[Variable], Variable]:
"""Replace variables in ``outputs`` by ``replace``. """Replace variables in ``outputs`` by ``replace``.
Parameters Parameters
---------- ----------
outputs: Sequence[Variable] outputs: Union[Sequence[Variable], Variable]
Output graph Output graph
replace: Dict[Variable, Variable] replace: Dict[Variable, Variable]
Replace mapping Replace mapping
...@@ -83,20 +121,26 @@ def graph_replace( ...@@ -83,20 +121,26 @@ def graph_replace(
Returns Returns
------- -------
List[Variable] Union[Variable, List[Variable]]
Output graph with subgraphs replaced Output graph with subgraphs replaced, see function overload for the exact type
Raises Raises
------ ------
ValueError ValueError
If some replacemens could not be applied and strict is True If some replacements could not be applied and strict is True
""" """
as_list = False
if not isinstance(outputs, Sequence):
outputs = [outputs]
else:
as_list = True
replace_dict = _format_replace(replace)
# collect minimum graph inputs which is required to compute outputs # collect minimum graph inputs which is required to compute outputs
# and depend on replacements # and depend on replacements
# additionally remove constants, they do not matter in clone get equiv # additionally remove constants, they do not matter in clone get equiv
conditions = [ conditions = [
c c
for c in truncated_graph_inputs(outputs, replace) for c in truncated_graph_inputs(outputs, replace_dict)
if not isinstance(c, Constant) if not isinstance(c, Constant)
] ]
# for the function graph we need the clean graph where # for the function graph we need the clean graph where
...@@ -117,7 +161,7 @@ def graph_replace( ...@@ -117,7 +161,7 @@ def graph_replace(
# replace the conditions back # replace the conditions back
fg_replace = {equiv[c]: c for c in conditions} fg_replace = {equiv[c]: c for c in conditions}
# add the replacements on top of input mappings # add the replacements on top of input mappings
fg_replace.update({equiv[r]: v for r, v in replace.items() if r in equiv}) fg_replace.update({equiv[r]: v for r, v in replace_dict.items() if r in equiv})
# replacements have to be done in reverse topological order so that nested # replacements have to be done in reverse topological order so that nested
# expressions get recursively replaced correctly # expressions get recursively replaced correctly
...@@ -126,12 +170,14 @@ def graph_replace( ...@@ -126,12 +170,14 @@ def graph_replace(
# So far FunctionGraph does these replacements inplace it is thus unsafe # So far FunctionGraph does these replacements inplace it is thus unsafe
# apply them using fg.replace, it may change the original graph # apply them using fg.replace, it may change the original graph
if strict: if strict:
non_fg_replace = {r: v for r, v in replace.items() if r not in equiv} non_fg_replace = {r: v for r, v in replace_dict.items() if r not in equiv}
if non_fg_replace: if non_fg_replace:
raise ValueError(f"Some replacements were not used: {non_fg_replace}") raise ValueError(f"Some replacements were not used: {non_fg_replace}")
toposort = fg.toposort() toposort = fg.toposort()
def toposort_key(fg: FunctionGraph, ts, pair): def toposort_key(
fg: FunctionGraph, ts: list[Apply], pair: tuple[Variable, Variable]
) -> int:
key, _ = pair key, _ = pair
if key.owner is not None: if key.owner is not None:
return ts.index(key.owner) return ts.index(key.owner)
...@@ -148,4 +194,7 @@ def graph_replace( ...@@ -148,4 +194,7 @@ def graph_replace(
reverse=True, reverse=True,
) )
fg.replace_all(sorted_replacements, import_missing=True) fg.replace_all(sorted_replacements, import_missing=True)
if as_list:
return list(fg.outputs) return list(fg.outputs)
else:
return fg.outputs[0]
...@@ -169,6 +169,17 @@ class TestGraphReplace: ...@@ -169,6 +169,17 @@ class TestGraphReplace:
# the old reference is still kept # the old reference is still kept
assert oc.owner.inputs[0].owner.inputs[1] is w assert oc.owner.inputs[0].owner.inputs[1] is w
def test_non_list_input(self):
x = MyVariable("x")
y = MyVariable("y")
o = MyOp("xyop")(x, y)
new_x = x.clone(name="x_new")
new_y = y.clone(name="y2_new")
# test non list inputs as well
oc = graph_replace(o, {x: new_x, y: new_y})
assert oc.owner.inputs[1] is new_y
assert oc.owner.inputs[0] is new_x
def test_graph_replace_advanced(self): def test_graph_replace_advanced(self):
x = MyVariable("x") x = MyVariable("x")
y = MyVariable("y") y = MyVariable("y")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论