提交 62cc36f7 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up and fix FunctionGraph cloning

This adds the missing `update_mapping`s step during cloning, as well as a new `Feature` cloning step that prevents issues when features are copied to their clones.
上级 712c53a6
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import time import time
from collections import OrderedDict from collections import OrderedDict
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Dict, Dict,
Iterable, Iterable,
...@@ -26,6 +27,9 @@ from aesara.graph.utils import MetaObject, MissingInputError, TestValueError ...@@ -26,6 +27,9 @@ from aesara.graph.utils import MetaObject, MissingInputError, TestValueError
from aesara.misc.ordered_set import OrderedSet from aesara.misc.ordered_set import OrderedSet
if TYPE_CHECKING:
from aesara.graph.op import Op
ApplyOrOutput = Union[Apply, Literal["output"]] ApplyOrOutput = Union[Apply, Literal["output"]]
ClientType = Tuple[ApplyOrOutput, int] ClientType = Tuple[ApplyOrOutput, int]
...@@ -69,9 +73,7 @@ class FunctionGraph(MetaObject): ...@@ -69,9 +73,7 @@ class FunctionGraph(MetaObject):
features: Optional[Sequence[Feature]] = None, features: Optional[Sequence[Feature]] = None,
clone: bool = True, clone: bool = True,
update_mapping: Optional[Dict[Variable, Variable]] = None, update_mapping: Optional[Dict[Variable, Variable]] = None,
memo: Optional[Dict] = None, **clone_kwds,
copy_inputs: bool = True,
copy_orphans: bool = True,
): ):
""" """
Create a `FunctionGraph` which operates on the subgraph between the Create a `FunctionGraph` which operates on the subgraph between the
...@@ -83,19 +85,15 @@ class FunctionGraph(MetaObject): ...@@ -83,19 +85,15 @@ class FunctionGraph(MetaObject):
Input variables of the graph. Input variables of the graph.
outputs outputs
Output variables of the graph. Output variables of the graph.
clone
If ``True``, the graph will be cloned.
features features
A list of features to be added to the `FunctionGraph`. A list of features to be added to the `FunctionGraph`.
clone
If ``True``, the graph will be cloned.
update_mapping update_mapping
Mapping between the `inputs` with updates and the `outputs` Mapping between the `inputs` with updates and the `outputs`
corresponding to their updates. corresponding to their updates.
memo clone_kwds
See :func:`aesara.graph.basic.clone_get_equiv`. Keywords passed to `clone_get_equiv` when `clone` is ``True``.
copy_inputs
See :func:`aesara.graph.basic.clone_get_equiv`.
copy_orphans
See :func:`aesara.graph.basic.clone_get_equiv`.
""" """
if outputs is None: if outputs is None:
raise ValueError("No outputs specified") raise ValueError("No outputs specified")
...@@ -109,9 +107,7 @@ class FunctionGraph(MetaObject): ...@@ -109,9 +107,7 @@ class FunctionGraph(MetaObject):
_memo = clone_get_equiv( _memo = clone_get_equiv(
inputs, inputs,
outputs, outputs,
copy_inputs=copy_inputs, **clone_kwds,
copy_orphans=copy_orphans,
memo=memo,
) )
outputs = [cast(Variable, _memo[o]) for o in outputs] outputs = [cast(Variable, _memo[o]) for o in outputs]
inputs = [cast(Variable, _memo[i]) for i in inputs] inputs = [cast(Variable, _memo[i]) for i in inputs]
...@@ -868,33 +864,36 @@ class FunctionGraph(MetaObject): ...@@ -868,33 +864,36 @@ class FunctionGraph(MetaObject):
return self.clone_get_equiv(check_integrity)[0] return self.clone_get_equiv(check_integrity)[0]
def clone_get_equiv( def clone_get_equiv(
self, check_integrity: bool = True, attach_feature: bool = True self, check_integrity: bool = True, attach_feature: bool = True, **kwargs
) -> Tuple["FunctionGraph", Dict]: ) -> Tuple[
"FunctionGraph",
Dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]],
]:
"""Clone the graph and return a ``dict`` that maps old nodes to new nodes. """Clone the graph and return a ``dict`` that maps old nodes to new nodes.
Parameters Parameters
---------- ----------
check_integrity check_integrity
Whether to check integrity. Whether or not to check the resulting graph's integrity.
attach_feature attach_feature
Whether to attach feature of origin graph to cloned graph. Whether or not to attach `self`'s features to the cloned graph.
Returns Returns
------- -------
e e
Cloned fgraph. Every node in cloned graph is cloned. The cloned `FunctionGraph`. Every node in the cloned graph is cloned.
equiv equiv
A ``dict`` that maps old nodes to the new nodes. A ``dict`` that maps old nodes to the new nodes.
""" """
equiv = clone_get_equiv(self.inputs, self.outputs) equiv = clone_get_equiv(self.inputs, self.outputs, **kwargs)
if check_integrity:
self.check_integrity()
e = FunctionGraph( e = FunctionGraph(
[cast(Variable, equiv[i]) for i in self.inputs], [cast(Variable, equiv[i]) for i in self.inputs],
[cast(Variable, equiv[o]) for o in self.outputs], [cast(Variable, equiv[o]) for o in self.outputs],
clone=False, clone=False,
update_mapping=self.update_mapping,
) )
if check_integrity: if check_integrity:
e.check_integrity() e.check_integrity()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论