提交 77396f7f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add inner-graph cloning to Apply.clone and core cloning functions

上级 c752a8e3
...@@ -800,7 +800,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -800,7 +800,7 @@ class OpFromGraph(Op, HasInnerGraph):
# If the new shared variables are inconsistent with the inner-graph, # If the new shared variables are inconsistent with the inner-graph,
# such errors should arise in this step # such errors should arise in this step
new_inner_outputs = clone_replace( new_inner_outputs = clone_replace(
self.inner_outputs, replace=replace, share_inputs=True self.inner_outputs, replace=replace, copy_inputs_over=True
) )
# It's possible that the new shared variable inputs aren't actually # It's possible that the new shared variable inputs aren't actually
......
...@@ -12,7 +12,7 @@ from aesara.compile.io import In, Out ...@@ -12,7 +12,7 @@ from aesara.compile.io import In, Out
from aesara.compile.profiling import ProfileStats from aesara.compile.profiling import ProfileStats
from aesara.compile.sharedvalue import SharedVariable, shared from aesara.compile.sharedvalue import SharedVariable, shared
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, Variable, clone_node_and_cache
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
...@@ -29,8 +29,9 @@ def rebuild_collect_shared( ...@@ -29,8 +29,9 @@ def rebuild_collect_shared(
rebuild_strict=True, rebuild_strict=True,
copy_inputs_over=True, copy_inputs_over=True,
no_default_updates=False, no_default_updates=False,
clone_inner_graphs=False,
): ):
"""Replace subgraphs of a computational graph. r"""Replace subgraphs of a computational graph.
It returns a set of dictionaries and lists which collect (partial?) It returns a set of dictionaries and lists which collect (partial?)
different information about shared variables. This info is required by different information about shared variables. This info is required by
...@@ -59,6 +60,9 @@ def rebuild_collect_shared( ...@@ -59,6 +60,9 @@ def rebuild_collect_shared(
If False (default), perform them all. If False (default), perform them all.
Else, perform automatic updates on all Variables that are neither in Else, perform automatic updates on all Variables that are neither in
"updates" nor in "no_default_updates". "updates" nor in "no_default_updates".
clone_inner_graphs : bool
If ``True``, clone `Op`\s that are subclasses of `HasInnerGraph` and their
inner-graphs.
""" """
...@@ -89,13 +93,12 @@ def rebuild_collect_shared( ...@@ -89,13 +93,12 @@ def rebuild_collect_shared(
if owner not in clone_d: if owner not in clone_d:
for i in owner.inputs: for i in owner.inputs:
clone_v_get_shared_updates(i, copy_inputs_over) clone_v_get_shared_updates(i, copy_inputs_over)
clone_node_and_cache(
clone_d[owner] = owner.clone_with_new_inputs( owner,
[clone_d[i] for i in owner.inputs], strict=rebuild_strict clone_d,
strict=rebuild_strict,
clone_inner_graphs=clone_inner_graphs,
) )
for old_o, new_o in zip(owner.outputs, clone_d[owner].outputs):
clone_d.setdefault(old_o, new_o)
return clone_d.setdefault(v, v) return clone_d.setdefault(v, v)
elif isinstance(v, SharedVariable): elif isinstance(v, SharedVariable):
if v not in shared_inputs: if v not in shared_inputs:
...@@ -494,6 +497,7 @@ def construct_pfunc_ins_and_outs( ...@@ -494,6 +497,7 @@ def construct_pfunc_ins_and_outs(
rebuild_strict=rebuild_strict, rebuild_strict=rebuild_strict,
copy_inputs_over=True, copy_inputs_over=True,
no_default_updates=no_default_updates, no_default_updates=no_default_updates,
clone_inner_graphs=True,
) )
input_variables, cloned_extended_outputs, other_stuff = output_vars input_variables, cloned_extended_outputs, other_stuff = output_vars
clone_d, update_d, update_expr, shared_inputs = other_stuff clone_d, update_d, update_expr, shared_inputs = other_stuff
......
...@@ -41,6 +41,7 @@ from aesara.misc.ordered_set import OrderedSet ...@@ -41,6 +41,7 @@ from aesara.misc.ordered_set import OrderedSet
if TYPE_CHECKING: if TYPE_CHECKING:
from aesara.graph.op import Op
from aesara.graph.type import Type from aesara.graph.type import Type
...@@ -96,34 +97,28 @@ class Apply(Node): ...@@ -96,34 +97,28 @@ class Apply(Node):
Attributes Attributes
---------- ----------
op : Op op
The operation that produces `outputs` given `inputs`. The operation that produces `outputs` given `inputs`.
inputs : List[Variable] inputs
The arguments of the expression modeled by the `Apply` node. The arguments of the expression modeled by the `Apply` node.
outputs : List[Variable] outputs
The outputs of the expression modeled by the `Apply` node. The outputs of the expression modeled by the `Apply` node.
""" """
def __init__(self, op, inputs, outputs): def __init__(
""" self, op: "Op", inputs: Sequence["Variable"], outputs: Sequence["Variable"]
Parameters ):
----------
op : Op
inputs : List[Variable]
outputs : List[Variable]
"""
self.op = op
self.inputs: List[Variable] = []
self.tag = Scratchpad()
if not isinstance(inputs, (list, tuple)): if not isinstance(inputs, (list, tuple)):
raise TypeError("The inputs of an Apply must be a list or tuple") raise TypeError("The inputs of an Apply must be a list or tuple")
if not isinstance(outputs, (list, tuple)): if not isinstance(outputs, (list, tuple)):
raise TypeError("The output of an Apply must be a list or tuple") raise TypeError("The output of an Apply must be a list or tuple")
self.op = op
self.inputs: List[Variable] = []
self.tag = Scratchpad()
# filter inputs to make sure each element is a Variable # filter inputs to make sure each element is a Variable
for input in inputs: for input in inputs:
if isinstance(input, Variable): if isinstance(input, Variable):
...@@ -202,28 +197,40 @@ class Apply(Node): ...@@ -202,28 +197,40 @@ class Apply(Node):
def __repr__(self): def __repr__(self):
return str(self) return str(self)
def clone(self): def clone(self, clone_inner_graph: bool = False) -> "Apply":
""" r"""Clone this `Apply` instance.
Duplicate this Apply instance with inputs = self.inputs.
Parameters
----------
clone_inner_graph
If ``True``, clone `HasInnerGraph` `Op`\s and their inner-graphs.
Returns Returns
------- -------
object A new `Apply` instance with new outputs.
A new Apply instance (or subclass instance) with new outputs.
Notes Notes
----- -----
Tags are copied from self to the returned instance. Tags are copied from `self` to the returned instance.
""" """
from aesara.graph.op import HasInnerGraph
new_op = self.op
if isinstance(new_op, HasInnerGraph) and clone_inner_graph:
new_op = new_op.clone()
cp = self.__class__( cp = self.__class__(
self.op, self.inputs, [output.clone() for output in self.outputs] new_op, self.inputs, [output.clone() for output in self.outputs]
) )
cp.tag = copy(self.tag) cp.tag = copy(self.tag)
return cp return cp
def clone_with_new_inputs(self, inputs, strict=True): def clone_with_new_inputs(
"""Duplicate this `Apply` instance in a new graph. self, inputs: Sequence["Variable"], strict=True, clone_inner_graph=False
) -> "Apply":
r"""Duplicate this `Apply` instance in a new graph.
Parameters Parameters
---------- ----------
...@@ -238,6 +245,8 @@ class Apply(Node): ...@@ -238,6 +245,8 @@ class Apply(Node):
``self.outputs``. If ``False``, then there's no guarantee that the ``self.outputs``. If ``False``, then there's no guarantee that the
clone's outputs will have the same types as ``self.outputs``, clone's outputs will have the same types as ``self.outputs``,
and cloning may not even be possible (it depends on the `Op`). and cloning may not even be possible (it depends on the `Op`).
clone_inner_graph : bool
If ``True``, clone `HasInnerGraph` `Op`\s and their inner-graphs.
Returns Returns
------- -------
...@@ -245,9 +254,11 @@ class Apply(Node): ...@@ -245,9 +254,11 @@ class Apply(Node):
An `Apply` instance with the same `Op` but different outputs. An `Apply` instance with the same `Op` but different outputs.
""" """
from aesara.graph.op import HasInnerGraph
assert isinstance(inputs, (list, tuple)) assert isinstance(inputs, (list, tuple))
remake_node = False remake_node = False
new_inputs = inputs[:] new_inputs: List["Variable"] = list(inputs)
for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)): for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)):
if curr.type != new.type: if curr.type != new.type:
if strict: if strict:
...@@ -260,10 +271,15 @@ class Apply(Node): ...@@ -260,10 +271,15 @@ class Apply(Node):
remake_node = True remake_node = True
if remake_node: if remake_node:
new_node = self.op.make_node(*new_inputs) new_op = self.op
if isinstance(new_op, HasInnerGraph) and clone_inner_graph:
new_op = new_op.clone()
new_node = new_op.make_node(*new_inputs)
new_node.tag = copy(self.tag).__update__(new_node.tag) new_node.tag = copy(self.tag).__update__(new_node.tag)
else: else:
new_node = self.clone() new_node = self.clone(clone_inner_graph=clone_inner_graph)
new_node.inputs = new_inputs new_node.inputs = new_inputs
return new_node return new_node
...@@ -485,19 +501,16 @@ class Variable(Node): ...@@ -485,19 +501,16 @@ class Variable(Node):
return "\n".join(to_print) return "\n".join(to_print)
def clone(self): def clone(self):
"""Return a new `Variable` like `self`. """Return a new, un-owned `Variable` like `self`.
Returns Returns
------- -------
Variable instance Variable instance
A new `Variable` instance (or subclass instance) with no owner or A new `Variable` instance with no owner or index.
index.
Notes Notes
----- -----
Tags are copied to the returned instance. Tags and names are copied to the returned instance.
Name is copied to the returned instance.
""" """
# return copy(self) # return copy(self)
...@@ -941,6 +954,7 @@ def clone( ...@@ -941,6 +954,7 @@ def clone(
outputs: List[Variable], outputs: List[Variable],
copy_inputs: bool = True, copy_inputs: bool = True,
copy_orphans: Optional[bool] = None, copy_orphans: Optional[bool] = None,
clone_inner_graphs: bool = False,
) -> Tuple[Collection[Variable], Collection[Variable]]: ) -> Tuple[Collection[Variable], Collection[Variable]]:
r"""Copies the sub-graph contained between inputs and outputs. r"""Copies the sub-graph contained between inputs and outputs.
...@@ -956,6 +970,8 @@ def clone( ...@@ -956,6 +970,8 @@ def clone(
When ``None``, use the `copy_inputs` value. When ``None``, use the `copy_inputs` value.
When ``True``, new orphans nodes are created. When ``True``, new orphans nodes are created.
When ``False``, original orphans nodes are reused in the new graph. When ``False``, original orphans nodes are reused in the new graph.
clone_inner_graphs : bool
If ``True``, clone `HasInnerGraph` `Op`\s and their inner-graphs.
Returns Returns
------- -------
...@@ -971,20 +987,81 @@ def clone( ...@@ -971,20 +987,81 @@ def clone(
""" """
if copy_orphans is None: if copy_orphans is None:
copy_orphans = copy_inputs copy_orphans = copy_inputs
equiv = clone_get_equiv(inputs, outputs, copy_inputs, copy_orphans) equiv = clone_get_equiv(
inputs,
outputs,
copy_inputs=copy_inputs,
copy_orphans=copy_orphans,
clone_inner_graphs=clone_inner_graphs,
)
return [cast(Variable, equiv[input]) for input in inputs], [ return [cast(Variable, equiv[input]) for input in inputs], [
cast(Variable, equiv[output]) for output in outputs cast(Variable, equiv[output]) for output in outputs
] ]
def clone_node_and_cache(
node: Apply,
clone_d: Dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]],
clone_inner_graphs=False,
**kwargs,
) -> Optional[Apply]:
"""Clone an `Apply` node and cache the results in `clone_d`.
This function handles `Op` clones that are generated by inner-graph
cloning.
Returns
-------
``None`` if all of `node`'s outputs are already in `clone_d`; otherwise,
return the clone of `node`.
"""
if all(out in clone_d for out in node.outputs):
# If all of `node`'s outputs already have replacements or clones in
# `clone_d`, then there's likely no need to clone it
return None
# Use a cached `Op` clone when available
new_op: Optional["Op"] = cast(Optional["Op"], clone_d.get(node.op))
cloned_inputs: List[Variable] = [cast(Variable, clone_d[i]) for i in node.inputs]
new_node = node.clone_with_new_inputs(
cloned_inputs,
# Only clone inner-graph `Op`s when there isn't a cached clone (and
# when `clone_inner_graphs` is enabled)
clone_inner_graph=clone_inner_graphs if new_op is None else False,
**kwargs,
)
if new_op:
# If we didn't clone the inner-graph `Op` above, because
# there was a cached version, set the cloned `Apply` to use
# the cached clone `Op`
new_node.op = new_op
clone_d[node] = new_node
if new_node.op is not node.op:
clone_d.setdefault(node.op, new_node.op)
for old_o, new_o in zip(node.outputs, new_node.outputs):
clone_d.setdefault(old_o, new_o)
return new_node
def clone_get_equiv( def clone_get_equiv(
inputs: Sequence[Variable], inputs: Sequence[Variable],
outputs: Sequence[Variable], outputs: Sequence[Variable],
copy_inputs: bool = True, copy_inputs: bool = True,
copy_orphans: bool = True, copy_orphans: bool = True,
memo: Optional[Dict[Node, Node]] = None, memo: Optional[
) -> Dict[Node, Node]: Dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]
""" ] = None,
clone_inner_graphs: bool = False,
) -> Dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]:
r"""
Return a dictionary that maps from `Variable` and `Apply` nodes in the Return a dictionary that maps from `Variable` and `Apply` nodes in the
original graph to a new node (a clone) in a new graph. original graph to a new node (a clone) in a new graph.
...@@ -993,20 +1070,22 @@ def clone_get_equiv( ...@@ -993,20 +1070,22 @@ def clone_get_equiv(
Parameters Parameters
---------- ----------
inputs : a list of Variables inputs
outputs : a list of Variables outputs
copy_inputs : bool copy_inputs
True means to create the cloned graph from new input True means to create the cloned graph from new input
nodes (the bottom of a feed-upward graph). nodes (the bottom of a feed-upward graph).
False means to clone a graph that is rooted at the original input False means to clone a graph that is rooted at the original input
nodes. nodes.
copy_orphans : copy_orphans
When ``True``, new constant nodes are created. When ``False``, original When ``True``, new constant nodes are created. When ``False``, original
constant nodes are reused in the new graph. constant nodes are reused in the new graph.
memo : None or dict memo
Optionally start with a partly-filled dictionary for the return value. Optionally start with a partly-filled dictionary for the return value.
If a dictionary is passed, this function will work in-place on that If a dictionary is passed, this function will work in-place on that
dictionary and return it. dictionary and return it.
clone_inner_graphs
If ``True``, clone `HasInnerGraph` `Op`\s and their inner-graphs.
""" """
if memo is None: if memo is None:
...@@ -1032,10 +1111,7 @@ def clone_get_equiv( ...@@ -1032,10 +1111,7 @@ def clone_get_equiv(
else: else:
memo[input] = input memo[input] = input
new_apply = apply.clone_with_new_inputs([memo[i] for i in apply.inputs]) clone_node_and_cache(apply, memo, clone_inner_graphs=clone_inner_graphs)
memo.setdefault(apply, new_apply)
for output, new_output in zip(apply.outputs, new_apply.outputs):
memo.setdefault(output, new_output)
# finish up by cloning any remaining outputs (it can happen) # finish up by cloning any remaining outputs (it can happen)
for output in outputs: for output in outputs:
...@@ -1046,12 +1122,11 @@ def clone_get_equiv( ...@@ -1046,12 +1122,11 @@ def clone_get_equiv(
def clone_replace( def clone_replace(
output: List[Variable], output: Collection[Variable],
replace: Optional[ replace: Optional[
Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]] Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]]
] = None, ] = None,
strict: bool = True, **rebuild_kwds,
share_inputs: bool = True,
) -> List[Variable]: ) -> List[Variable]:
"""Clone a graph and replace subgraphs within it. """Clone a graph and replace subgraphs within it.
...@@ -1064,11 +1139,8 @@ def clone_replace( ...@@ -1064,11 +1139,8 @@ def clone_replace(
Aesara expression that represents the computational graph. Aesara expression that represents the computational graph.
replace : dict replace : dict
Dictionary describing which subgraphs should be replaced by what. Dictionary describing which subgraphs should be replaced by what.
share_inputs : bool rebuild_kwds
If ``True``, use the same inputs (and shared variables) as the original Keywords to `rebuild_collect_shared`.
graph. If ``False``, clone them. Note that cloned shared variables still
use the same underlying storage, so they will always have the same
value.
""" """
from aesara.compile.function.pfunc import rebuild_collect_shared from aesara.compile.function.pfunc import rebuild_collect_shared
...@@ -1090,14 +1162,10 @@ def clone_replace( ...@@ -1090,14 +1162,10 @@ def clone_replace(
) )
tmp_replace = [(x, x.type()) for x, y in items] tmp_replace = [(x, x.type()) for x, y in items]
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)] new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
_, _outs, _ = rebuild_collect_shared( _, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
output, [], tmp_replace, [], strict, share_inputs
)
# TODO Explain why we call it twice ?! # TODO Explain why we call it twice ?!
_, outs, _ = rebuild_collect_shared( _, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
_outs, [], new_replace, [], strict, share_inputs
)
return cast(List[Variable], outs) return cast(List[Variable], outs)
...@@ -1473,13 +1541,12 @@ def view_roots(node: Variable) -> List[Variable]: ...@@ -1473,13 +1541,12 @@ def view_roots(node: Variable) -> List[Variable]:
owner = node.owner owner = node.owner
if owner is not None: if owner is not None:
try: try:
view_map = owner.op.view_map vars_to_views = {owner.outputs[o]: i for o, i in owner.op.view_map.items()}
view_map = {owner.outputs[o]: i for o, i in view_map.items()}
except AttributeError: except AttributeError:
return [node] return [node]
if node in view_map: if node in vars_to_views:
answer = [] answer = []
for i in view_map[node]: for i in vars_to_views[node]:
answer += view_roots(owner.inputs[i]) answer += view_roots(owner.inputs[i])
return answer return answer
else: else:
......
...@@ -18,7 +18,7 @@ from typing_extensions import Literal ...@@ -18,7 +18,7 @@ from typing_extensions import Literal
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, AtomicVariable, Node, Variable, applys_between from aesara.graph.basic import Apply, AtomicVariable, Variable, applys_between
from aesara.graph.basic import as_string as graph_as_string from aesara.graph.basic import as_string as graph_as_string
from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between
from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
...@@ -69,7 +69,7 @@ class FunctionGraph(MetaObject): ...@@ -69,7 +69,7 @@ class FunctionGraph(MetaObject):
features: Optional[Sequence[Feature]] = None, features: Optional[Sequence[Feature]] = None,
clone: bool = True, clone: bool = True,
update_mapping: Optional[Dict[Variable, Variable]] = None, update_mapping: Optional[Dict[Variable, Variable]] = None,
memo: Optional[Dict[Variable, Variable]] = None, memo: Optional[Dict] = None,
copy_inputs: bool = True, copy_inputs: bool = True,
copy_orphans: bool = True, copy_orphans: bool = True,
): ):
...@@ -111,7 +111,7 @@ class FunctionGraph(MetaObject): ...@@ -111,7 +111,7 @@ class FunctionGraph(MetaObject):
outputs, outputs,
copy_inputs=copy_inputs, copy_inputs=copy_inputs,
copy_orphans=copy_orphans, copy_orphans=copy_orphans,
memo=cast(Dict[Node, Node], memo), memo=memo,
) )
outputs = [cast(Variable, _memo[o]) for o in outputs] outputs = [cast(Variable, _memo[o]) for o in outputs]
inputs = [cast(Variable, _memo[i]) for i in inputs] inputs = [cast(Variable, _memo[i]) for i in inputs]
...@@ -869,7 +869,7 @@ class FunctionGraph(MetaObject): ...@@ -869,7 +869,7 @@ class FunctionGraph(MetaObject):
def clone_get_equiv( def clone_get_equiv(
self, check_integrity: bool = True, attach_feature: bool = True self, check_integrity: bool = True, attach_feature: bool = True
) -> Tuple["FunctionGraph", Dict[Node, Node]]: ) -> Tuple["FunctionGraph", Dict]:
"""Clone the graph and return a ``dict`` that maps old nodes to new nodes. """Clone the graph and return a ``dict`` that maps old nodes to new nodes.
Parameters Parameters
......
...@@ -12,6 +12,7 @@ from typing import ( ...@@ -12,6 +12,7 @@ from typing import (
Sequence, Sequence,
Text, Text,
Tuple, Tuple,
TypeVar,
Union, Union,
cast, cast,
) )
...@@ -50,12 +51,14 @@ ThunkCallableType = Callable[ ...@@ -50,12 +51,14 @@ ThunkCallableType = Callable[
[PerformMethodType, StorageMapType, ComputeMapType, Apply], None [PerformMethodType, StorageMapType, ComputeMapType, Apply], None
] ]
C = TypeVar("C", bound=Callable)
class ThunkType(Protocol):
class ThunkType(Protocol[C]):
inputs: List[List[Optional[List[Any]]]] inputs: List[List[Optional[List[Any]]]]
outputs: List[List[Optional[List[Any]]]] outputs: List[List[Optional[List[Any]]]]
lazy: bool lazy: bool
__call__: ThunkCallableType __call__: C
perform: PerformMethodType perform: PerformMethodType
...@@ -132,8 +135,7 @@ def compute_test_value(node: Apply): ...@@ -132,8 +135,7 @@ def compute_test_value(node: Apply):
thunk.inputs = [storage_map[v] for v in node.inputs] thunk.inputs = [storage_map[v] for v in node.inputs]
thunk.outputs = [storage_map[v] for v in node.outputs] thunk.outputs = [storage_map[v] for v in node.outputs]
required = thunk() thunk()
assert not required # We provided all inputs
for output in node.outputs: for output in node.outputs:
# Check that the output has been computed # Check that the output has been computed
...@@ -495,7 +497,7 @@ class Op(MetaObject): ...@@ -495,7 +497,7 @@ class Op(MetaObject):
node: Apply, node: Apply,
storage_map: StorageMapType, storage_map: StorageMapType,
compute_map: ComputeMapType, compute_map: ComputeMapType,
no_recycling: bool, no_recycling: List[Variable],
debug: bool = False, debug: bool = False,
) -> ThunkType: ) -> ThunkType:
"""Make a Python thunk. """Make a Python thunk.
...@@ -506,8 +508,8 @@ class Op(MetaObject): ...@@ -506,8 +508,8 @@ class Op(MetaObject):
node_input_storage = [storage_map[r] for r in node.inputs] node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs] node_output_storage = [storage_map[r] for r in node.outputs]
if debug: if debug and hasattr(self, "debug_perform"):
p = node.op.debug_perform p = node.op.debug_perform # type: ignore
else: else:
p = node.op.perform p = node.op.perform
...@@ -551,7 +553,7 @@ class Op(MetaObject): ...@@ -551,7 +553,7 @@ class Op(MetaObject):
node: Apply, node: Apply,
storage_map: StorageMapType, storage_map: StorageMapType,
compute_map: ComputeMapType, compute_map: ComputeMapType,
no_recycling: bool, no_recycling: List[Variable],
impl: Optional[Text] = None, impl: Optional[Text] = None,
) -> ThunkType: ) -> ThunkType:
r"""Create a thunk. r"""Create a thunk.
......
...@@ -3,7 +3,7 @@ import sys ...@@ -3,7 +3,7 @@ import sys
import traceback import traceback
from abc import ABCMeta from abc import ABCMeta
from io import StringIO from io import StringIO
from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, TypeVar, Union from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, TypeVar, Union
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -282,6 +282,13 @@ class Scratchpad: ...@@ -282,6 +282,13 @@ class Scratchpad:
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
print(f" {k}: {v}") print(f" {k}: {v}")
# These two methods have been added to help Mypy
def __getattribute__(self, name):
return super().__getattribute__(name)
def __setattr__(self, name: str, value: Any) -> None:
self.__dict__[name] = value
class ValidatingScratchpad(Scratchpad): class ValidatingScratchpad(Scratchpad):
"""This `Scratchpad` validates attribute values.""" """This `Scratchpad` validates attribute values."""
......
...@@ -318,8 +318,6 @@ def raise_with_op( ...@@ -318,8 +318,6 @@ def raise_with_op(
raise exc_value.with_traceback(exc_trace) raise exc_value.with_traceback(exc_trace)
trace = getattr(node.outputs[0].tag, "trace", ()) trace = getattr(node.outputs[0].tag, "trace", ())
if not trace and hasattr(node.op, "tag"):
trace = getattr(node.op.tag, "trace", ())
exc_value.__thunk_trace__ = trace exc_value.__thunk_trace__ = trace
exc_value.__op_instance__ = node exc_value.__op_instance__ = node
...@@ -366,8 +364,6 @@ def raise_with_op( ...@@ -366,8 +364,6 @@ def raise_with_op(
detailed_err_msg += "\nInputs type_num: %s" % str( detailed_err_msg += "\nInputs type_num: %s" % str(
[getattr(getattr(i[0], "dtype", ""), "num", "") for i in thunk.inputs] [getattr(getattr(i[0], "dtype", ""), "num", "") for i in thunk.inputs]
) )
if hasattr(node.op, "__input_name__"):
detailed_err_msg += f"\nInputs name: {node.op.__input_name__}\n"
detailed_err_msg += f"\nOutputs clients: {clients}\n" detailed_err_msg += f"\nOutputs clients: {clients}\n"
else: else:
......
...@@ -720,6 +720,8 @@ def push_out_inner_vars( ...@@ -720,6 +720,8 @@ def push_out_inner_vars(
fgraph, old_scan_node, old_scan_args, add_as_nitsots fgraph, old_scan_node, old_scan_args, add_as_nitsots
) )
assert isinstance(new_scan_node.op, Scan)
new_scan_args = ScanArgs( new_scan_args = ScanArgs(
new_scan_node.inputs, new_scan_node.inputs,
new_scan_node.outputs, new_scan_node.outputs,
...@@ -761,6 +763,8 @@ def add_nitsot_outputs( ...@@ -761,6 +763,8 @@ def add_nitsot_outputs(
new_scan_args.inner_out_nit_sot.extend(new_outputs_inner) new_scan_args.inner_out_nit_sot.extend(new_outputs_inner)
new_scan_args.outer_in_nit_sot.extend(new_nitsots_initial_value) new_scan_args.outer_in_nit_sot.extend(new_nitsots_initial_value)
assert isinstance(old_scan_node.op, Scan)
# Create the `Scan` `Op` from the `ScanArgs` # Create the `Scan` `Op` from the `ScanArgs`
new_scan_op = Scan( new_scan_op = Scan(
new_scan_args.inner_inputs, new_scan_args.inner_inputs,
......
...@@ -14,6 +14,7 @@ from aesara.graph.basic import ( ...@@ -14,6 +14,7 @@ from aesara.graph.basic import (
applys_between, applys_between,
as_string, as_string,
clone, clone,
clone_get_equiv,
clone_replace, clone_replace,
equal_computations, equal_computations,
general_toposort, general_toposort,
...@@ -186,6 +187,31 @@ class TestClone(X): ...@@ -186,6 +187,31 @@ class TestClone(X):
i, o = clone([c1], [c1], False, True) i, o = clone([c1], [c1], False, True)
assert i[0] is c1 and o[0] is c1 assert i[0] is c1 and o[0] is c1
def test_clone_inner_graph(self):
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2)
o1.name = "o1"
# Inner graph
igo_in_1 = MyVariable(4)
igo_in_2 = MyVariable(5)
igo_out_1 = MyOp(igo_in_1, igo_in_2)
igo_out_1.name = "igo1"
igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1])
o2 = igo(r3, o1)
o2.name = "o1"
o2_node = o2.owner
o2_node_clone = o2_node.clone(clone_inner_graph=True)
assert o2_node_clone is not o2_node
assert o2_node_clone.op.fgraph is not o2_node.op.fgraph
assert equal_computations(
o2_node_clone.op.fgraph.outputs, o2_node.op.fgraph.outputs
)
def prenode(obj): def prenode(obj):
if isinstance(obj, Variable): if isinstance(obj, Variable):
...@@ -535,7 +561,7 @@ class TestCloneReplace: ...@@ -535,7 +561,7 @@ class TestCloneReplace:
z = shared(0.25) z = shared(0.25)
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, strict=True, share_inputs=True) f2 = clone_replace(f1, replace=None, rebuild_strict=True, copy_inputs_over=True)
f2_inp = graph_inputs([f2]) f2_inp = graph_inputs([f2])
assert z in f2_inp assert z in f2_inp
...@@ -551,7 +577,9 @@ class TestCloneReplace: ...@@ -551,7 +577,9 @@ class TestCloneReplace:
z = shared(0.25) z = shared(0.25)
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, strict=True, share_inputs=False) f2 = clone_replace(
f1, replace=None, rebuild_strict=True, copy_inputs_over=False
)
f2_inp = graph_inputs([f2]) f2_inp = graph_inputs([f2])
assert z not in f2_inp assert z not in f2_inp
...@@ -568,7 +596,9 @@ class TestCloneReplace: ...@@ -568,7 +596,9 @@ class TestCloneReplace:
z = shared(0.25) z = shared(0.25)
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace={y: y2}, strict=True, share_inputs=True) f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=True, copy_inputs_over=True
)
f2_inp = graph_inputs([f2]) f2_inp = graph_inputs([f2])
assert z in f2_inp assert z in f2_inp
assert x in f2_inp assert x in f2_inp
...@@ -584,7 +614,9 @@ class TestCloneReplace: ...@@ -584,7 +614,9 @@ class TestCloneReplace:
z = shared(0.25) z = shared(0.25)
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace={y: y2}, strict=False, share_inputs=True) f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=False, copy_inputs_over=True
)
f2_inp = graph_inputs([f2]) f2_inp = graph_inputs([f2])
assert z in f2_inp assert z in f2_inp
assert x in f2_inp assert x in f2_inp
...@@ -600,7 +632,9 @@ class TestCloneReplace: ...@@ -600,7 +632,9 @@ class TestCloneReplace:
z = shared(0.25) z = shared(0.25)
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=[(y, y2)], strict=True, share_inputs=False) f2 = clone_replace(
f1, replace=[(y, y2)], rebuild_strict=True, copy_inputs_over=False
)
f2_inp = graph_inputs([f2]) f2_inp = graph_inputs([f2])
assert z not in f2_inp assert z not in f2_inp
assert x not in f2_inp assert x not in f2_inp
...@@ -616,7 +650,9 @@ class TestCloneReplace: ...@@ -616,7 +650,9 @@ class TestCloneReplace:
z = shared(0.25) z = shared(0.25)
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=[(y, y2)], strict=False, share_inputs=False) f2 = clone_replace(
f1, replace=[(y, y2)], rebuild_strict=False, copy_inputs_over=False
)
f2_inp = graph_inputs([f2]) f2_inp = graph_inputs([f2])
assert z not in f2_inp assert z not in f2_inp
assert x not in f2_inp assert x not in f2_inp
...@@ -672,6 +708,27 @@ def test_clone_new_inputs(): ...@@ -672,6 +708,27 @@ def test_clone_new_inputs():
assert z_node_new.inputs[1].type.shape == (1,) assert z_node_new.inputs[1].type.shape == (1,)
def test_clone_get_equiv():
x = vector("x")
y = vector("y")
z = vector("z")
a = x * y
a_node = a.owner
b = a + 1.0
memo = {a: z}
_ = clone_get_equiv([x, y], [b], copy_inputs=False, copy_orphans=False, memo=memo)
assert x in memo
assert y in memo
assert memo[a] is z
# All the outputs of `a` already had replacements/clones in the map, so
# there is no need to re-clone it (unless another replacement/clone
# re-introduces `a.owner` somehow).
assert a_node not in memo
assert equal_computations([memo[b]], [z + 1.0])
def test_NominalVariable(): def test_NominalVariable():
type1 = MyType(1) type1 = MyType(1)
......
...@@ -157,3 +157,6 @@ class MyInnerGraphOp(Op, HasInnerGraph): ...@@ -157,3 +157,6 @@ class MyInnerGraphOp(Op, HasInnerGraph):
@property @property
def inner_outputs(self): def inner_outputs(self):
return self.fgraph.outputs return self.fgraph.outputs
def clone(self):
return type(self)(self.fgraph.inputs, self.fgraph.outputs)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论