提交 62cc36f7 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up and fix FunctionGraph cloning

This adds the missing `update_mapping`s step during cloning, as well as a new `Feature` cloning step that prevents issues when features are copied to their clones.
上级 712c53a6
......@@ -2,6 +2,7 @@
import time
from collections import OrderedDict
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
......@@ -26,6 +27,9 @@ from aesara.graph.utils import MetaObject, MissingInputError, TestValueError
from aesara.misc.ordered_set import OrderedSet
if TYPE_CHECKING:
from aesara.graph.op import Op
ApplyOrOutput = Union[Apply, Literal["output"]]
ClientType = Tuple[ApplyOrOutput, int]
......@@ -69,9 +73,7 @@ class FunctionGraph(MetaObject):
features: Optional[Sequence[Feature]] = None,
clone: bool = True,
update_mapping: Optional[Dict[Variable, Variable]] = None,
memo: Optional[Dict] = None,
copy_inputs: bool = True,
copy_orphans: bool = True,
**clone_kwds,
):
"""
Create a `FunctionGraph` which operates on the subgraph between the
......@@ -83,19 +85,15 @@ class FunctionGraph(MetaObject):
Input variables of the graph.
outputs
Output variables of the graph.
clone
If ``True``, the graph will be cloned.
features
A list of features to be added to the `FunctionGraph`.
clone
If ``True``, the graph will be cloned.
update_mapping
Mapping between the `inputs` with updates and the `outputs`
corresponding to their updates.
memo
See :func:`aesara.graph.basic.clone_get_equiv`.
copy_inputs
See :func:`aesara.graph.basic.clone_get_equiv`.
copy_orphans
See :func:`aesara.graph.basic.clone_get_equiv`.
clone_kwds
Keywords passed to `clone_get_equiv` when `clone` is ``True``.
"""
if outputs is None:
raise ValueError("No outputs specified")
......@@ -109,9 +107,7 @@ class FunctionGraph(MetaObject):
_memo = clone_get_equiv(
inputs,
outputs,
copy_inputs=copy_inputs,
copy_orphans=copy_orphans,
memo=memo,
**clone_kwds,
)
outputs = [cast(Variable, _memo[o]) for o in outputs]
inputs = [cast(Variable, _memo[i]) for i in inputs]
......@@ -868,33 +864,36 @@ class FunctionGraph(MetaObject):
return self.clone_get_equiv(check_integrity)[0]
def clone_get_equiv(
self, check_integrity: bool = True, attach_feature: bool = True
) -> Tuple["FunctionGraph", Dict]:
self, check_integrity: bool = True, attach_feature: bool = True, **kwargs
) -> Tuple[
"FunctionGraph",
Dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]],
]:
"""Clone the graph and return a ``dict`` that maps old nodes to new nodes.
Parameters
----------
check_integrity
Whether to check integrity.
Whether or not to check the resulting graph's integrity.
attach_feature
Whether to attach feature of origin graph to cloned graph.
Whether or not to attach `self`'s features to the cloned graph.
Returns
-------
e
Cloned fgraph. Every node in cloned graph is cloned.
The cloned `FunctionGraph`. Every node in the cloned graph is cloned.
equiv
A ``dict`` that maps old nodes to the new nodes.
"""
equiv = clone_get_equiv(self.inputs, self.outputs)
equiv = clone_get_equiv(self.inputs, self.outputs, **kwargs)
if check_integrity:
self.check_integrity()
e = FunctionGraph(
[cast(Variable, equiv[i]) for i in self.inputs],
[cast(Variable, equiv[o]) for o in self.outputs],
clone=False,
update_mapping=self.update_mapping,
)
if check_integrity:
e.check_integrity()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论