提交 77396f7f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add inner-graph cloning to Apply.clone and core cloning functions

上级 c752a8e3
......@@ -800,7 +800,7 @@ class OpFromGraph(Op, HasInnerGraph):
# If the new shared variables are inconsistent with the inner-graph,
# such errors should arise in this step
new_inner_outputs = clone_replace(
self.inner_outputs, replace=replace, share_inputs=True
self.inner_outputs, replace=replace, copy_inputs_over=True
)
# It's possible that the new shared variable inputs aren't actually
......
......@@ -12,7 +12,7 @@ from aesara.compile.io import In, Out
from aesara.compile.profiling import ProfileStats
from aesara.compile.sharedvalue import SharedVariable, shared
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
from aesara.graph.basic import Constant, Variable, clone_node_and_cache
from aesara.graph.fg import FunctionGraph
......@@ -29,8 +29,9 @@ def rebuild_collect_shared(
rebuild_strict=True,
copy_inputs_over=True,
no_default_updates=False,
clone_inner_graphs=False,
):
"""Replace subgraphs of a computational graph.
r"""Replace subgraphs of a computational graph.
It returns a set of dictionaries and lists which collect (partial?)
different information about shared variables. This info is required by
......@@ -59,6 +60,9 @@ def rebuild_collect_shared(
If False (default), perform them all.
Else, perform automatic updates on all Variables that are neither in
"updates" nor in "no_default_updates".
clone_inner_graphs : bool
If ``True``, clone `Op`\s that are subclasses of `HasInnerGraph` and their
inner-graphs.
"""
......@@ -89,13 +93,12 @@ def rebuild_collect_shared(
if owner not in clone_d:
for i in owner.inputs:
clone_v_get_shared_updates(i, copy_inputs_over)
clone_d[owner] = owner.clone_with_new_inputs(
[clone_d[i] for i in owner.inputs], strict=rebuild_strict
clone_node_and_cache(
owner,
clone_d,
strict=rebuild_strict,
clone_inner_graphs=clone_inner_graphs,
)
for old_o, new_o in zip(owner.outputs, clone_d[owner].outputs):
clone_d.setdefault(old_o, new_o)
return clone_d.setdefault(v, v)
elif isinstance(v, SharedVariable):
if v not in shared_inputs:
......@@ -494,6 +497,7 @@ def construct_pfunc_ins_and_outs(
rebuild_strict=rebuild_strict,
copy_inputs_over=True,
no_default_updates=no_default_updates,
clone_inner_graphs=True,
)
input_variables, cloned_extended_outputs, other_stuff = output_vars
clone_d, update_d, update_expr, shared_inputs = other_stuff
......
......@@ -41,6 +41,7 @@ from aesara.misc.ordered_set import OrderedSet
if TYPE_CHECKING:
from aesara.graph.op import Op
from aesara.graph.type import Type
......@@ -96,34 +97,28 @@ class Apply(Node):
Attributes
----------
op : Op
op
The operation that produces `outputs` given `inputs`.
inputs : List[Variable]
inputs
The arguments of the expression modeled by the `Apply` node.
outputs : List[Variable]
outputs
The outputs of the expression modeled by the `Apply` node.
"""
def __init__(self, op, inputs, outputs):
"""
Parameters
----------
op : Op
inputs : List[Variable]
outputs : List[Variable]
"""
self.op = op
self.inputs: List[Variable] = []
self.tag = Scratchpad()
def __init__(
self, op: "Op", inputs: Sequence["Variable"], outputs: Sequence["Variable"]
):
if not isinstance(inputs, (list, tuple)):
raise TypeError("The inputs of an Apply must be a list or tuple")
if not isinstance(outputs, (list, tuple)):
raise TypeError("The output of an Apply must be a list or tuple")
self.op = op
self.inputs: List[Variable] = []
self.tag = Scratchpad()
# filter inputs to make sure each element is a Variable
for input in inputs:
if isinstance(input, Variable):
......@@ -202,28 +197,40 @@ class Apply(Node):
def __repr__(self):
return str(self)
def clone(self):
"""
Duplicate this Apply instance with inputs = self.inputs.
def clone(self, clone_inner_graph: bool = False) -> "Apply":
r"""Clone this `Apply` instance.
Parameters
----------
clone_inner_graph
If ``True``, clone `HasInnerGraph` `Op`\s and their inner-graphs.
Returns
-------
object
A new Apply instance (or subclass instance) with new outputs.
A new `Apply` instance with new outputs.
Notes
-----
Tags are copied from self to the returned instance.
Tags are copied from `self` to the returned instance.
"""
from aesara.graph.op import HasInnerGraph
new_op = self.op
if isinstance(new_op, HasInnerGraph) and clone_inner_graph:
new_op = new_op.clone()
cp = self.__class__(
self.op, self.inputs, [output.clone() for output in self.outputs]
new_op, self.inputs, [output.clone() for output in self.outputs]
)
cp.tag = copy(self.tag)
return cp
def clone_with_new_inputs(self, inputs, strict=True):
"""Duplicate this `Apply` instance in a new graph.
def clone_with_new_inputs(
self, inputs: Sequence["Variable"], strict=True, clone_inner_graph=False
) -> "Apply":
r"""Duplicate this `Apply` instance in a new graph.
Parameters
----------
......@@ -238,6 +245,8 @@ class Apply(Node):
``self.outputs``. If ``False``, then there's no guarantee that the
clone's outputs will have the same types as ``self.outputs``,
and cloning may not even be possible (it depends on the `Op`).
clone_inner_graph : bool
If ``True``, clone `HasInnerGraph` `Op`\s and their inner-graphs.
Returns
-------
......@@ -245,9 +254,11 @@ class Apply(Node):
An `Apply` instance with the same `Op` but different outputs.
"""
from aesara.graph.op import HasInnerGraph
assert isinstance(inputs, (list, tuple))
remake_node = False
new_inputs = inputs[:]
new_inputs: List["Variable"] = list(inputs)
for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)):
if curr.type != new.type:
if strict:
......@@ -260,10 +271,15 @@ class Apply(Node):
remake_node = True
if remake_node:
new_node = self.op.make_node(*new_inputs)
new_op = self.op
if isinstance(new_op, HasInnerGraph) and clone_inner_graph:
new_op = new_op.clone()
new_node = new_op.make_node(*new_inputs)
new_node.tag = copy(self.tag).__update__(new_node.tag)
else:
new_node = self.clone()
new_node = self.clone(clone_inner_graph=clone_inner_graph)
new_node.inputs = new_inputs
return new_node
......@@ -485,19 +501,16 @@ class Variable(Node):
return "\n".join(to_print)
def clone(self):
"""Return a new `Variable` like `self`.
"""Return a new, un-owned `Variable` like `self`.
Returns
-------
Variable instance
A new `Variable` instance (or subclass instance) with no owner or
index.
A new `Variable` instance with no owner or index.
Notes
-----
Tags are copied to the returned instance.
Name is copied to the returned instance.
Tags and names are copied to the returned instance.
"""
# return copy(self)
......@@ -941,6 +954,7 @@ def clone(
outputs: List[Variable],
copy_inputs: bool = True,
copy_orphans: Optional[bool] = None,
clone_inner_graphs: bool = False,
) -> Tuple[Collection[Variable], Collection[Variable]]:
r"""Copies the sub-graph contained between inputs and outputs.
......@@ -956,6 +970,8 @@ def clone(
When ``None``, use the `copy_inputs` value.
When ``True``, new orphans nodes are created.
When ``False``, original orphans nodes are reused in the new graph.
clone_inner_graphs : bool
If ``True``, clone `HasInnerGraph` `Op`\s and their inner-graphs.
Returns
-------
......@@ -971,20 +987,81 @@ def clone(
"""
if copy_orphans is None:
copy_orphans = copy_inputs
equiv = clone_get_equiv(inputs, outputs, copy_inputs, copy_orphans)
equiv = clone_get_equiv(
inputs,
outputs,
copy_inputs=copy_inputs,
copy_orphans=copy_orphans,
clone_inner_graphs=clone_inner_graphs,
)
return [cast(Variable, equiv[input]) for input in inputs], [
cast(Variable, equiv[output]) for output in outputs
]
def clone_node_and_cache(
    node: Apply,
    clone_d: Dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]],
    clone_inner_graphs=False,
    **kwargs,
) -> Optional[Apply]:
    """Clone an `Apply` node, caching the clones in `clone_d`.

    `Op` clones produced by inner-graph cloning are cached as well, so that
    a given inner-graph `Op` is only ever cloned once.

    Returns
    -------
    ``None`` when every output of `node` already has an entry in `clone_d`;
    otherwise, the newly created clone of `node`.
    """
    # Nothing to do when replacements/clones already exist for all of the
    # node's outputs
    if not any(out not in clone_d for out in node.outputs):
        return None

    # Reuse a previously cloned `Op` when one is cached
    cached_op: Optional["Op"] = cast(Optional["Op"], clone_d.get(node.op))

    replaced_inputs: List[Variable] = [cast(Variable, clone_d[i]) for i in node.inputs]

    # Inner-graphs are only cloned here when no cached `Op` clone exists
    # (and inner-graph cloning was requested)
    node_clone = node.clone_with_new_inputs(
        replaced_inputs,
        clone_inner_graph=clone_inner_graphs and cached_op is None,
        **kwargs,
    )

    if cached_op:
        # A cached `Op` clone was found above, so make the cloned `Apply`
        # use it instead of the original (un-cloned) `Op`
        node_clone.op = cached_op

    clone_d[node] = node_clone

    if node_clone.op is not node.op:
        # Cache the freshly cloned `Op` for later reuse
        clone_d.setdefault(node.op, node_clone.op)

    for old_out, new_out in zip(node.outputs, node_clone.outputs):
        clone_d.setdefault(old_out, new_out)

    return node_clone
def clone_get_equiv(
inputs: Sequence[Variable],
outputs: Sequence[Variable],
copy_inputs: bool = True,
copy_orphans: bool = True,
memo: Optional[Dict[Node, Node]] = None,
) -> Dict[Node, Node]:
"""
memo: Optional[
Dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]
] = None,
clone_inner_graphs: bool = False,
) -> Dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]:
r"""
Return a dictionary that maps from `Variable` and `Apply` nodes in the
original graph to a new node (a clone) in a new graph.
......@@ -993,20 +1070,22 @@ def clone_get_equiv(
Parameters
----------
inputs : a list of Variables
outputs : a list of Variables
copy_inputs : bool
inputs
outputs
copy_inputs
True means to create the cloned graph from new input
nodes (the bottom of a feed-upward graph).
False means to clone a graph that is rooted at the original input
nodes.
copy_orphans :
copy_orphans
When ``True``, new constant nodes are created. When ``False``, original
constant nodes are reused in the new graph.
memo : None or dict
memo
Optionally start with a partly-filled dictionary for the return value.
If a dictionary is passed, this function will work in-place on that
dictionary and return it.
clone_inner_graphs
If ``True``, clone `HasInnerGraph` `Op`\s and their inner-graphs.
"""
if memo is None:
......@@ -1032,10 +1111,7 @@ def clone_get_equiv(
else:
memo[input] = input
new_apply = apply.clone_with_new_inputs([memo[i] for i in apply.inputs])
memo.setdefault(apply, new_apply)
for output, new_output in zip(apply.outputs, new_apply.outputs):
memo.setdefault(output, new_output)
clone_node_and_cache(apply, memo, clone_inner_graphs=clone_inner_graphs)
# finish up by cloning any remaining outputs (it can happen)
for output in outputs:
......@@ -1046,12 +1122,11 @@ def clone_get_equiv(
def clone_replace(
output: List[Variable],
output: Collection[Variable],
replace: Optional[
Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]]
] = None,
strict: bool = True,
share_inputs: bool = True,
**rebuild_kwds,
) -> List[Variable]:
"""Clone a graph and replace subgraphs within it.
......@@ -1064,11 +1139,8 @@ def clone_replace(
Aesara expression that represents the computational graph.
replace : dict
Dictionary describing which subgraphs should be replaced by what.
share_inputs : bool
If ``True``, use the same inputs (and shared variables) as the original
graph. If ``False``, clone them. Note that cloned shared variables still
use the same underlying storage, so they will always have the same
value.
rebuild_kwds
Keywords to `rebuild_collect_shared`.
"""
from aesara.compile.function.pfunc import rebuild_collect_shared
......@@ -1090,14 +1162,10 @@ def clone_replace(
)
tmp_replace = [(x, x.type()) for x, y in items]
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
_, _outs, _ = rebuild_collect_shared(
output, [], tmp_replace, [], strict, share_inputs
)
_, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
# TODO Explain why we call it twice ?!
_, outs, _ = rebuild_collect_shared(
_outs, [], new_replace, [], strict, share_inputs
)
_, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
return cast(List[Variable], outs)
......@@ -1473,13 +1541,12 @@ def view_roots(node: Variable) -> List[Variable]:
owner = node.owner
if owner is not None:
try:
view_map = owner.op.view_map
view_map = {owner.outputs[o]: i for o, i in view_map.items()}
vars_to_views = {owner.outputs[o]: i for o, i in owner.op.view_map.items()}
except AttributeError:
return [node]
if node in view_map:
if node in vars_to_views:
answer = []
for i in view_map[node]:
for i in vars_to_views[node]:
answer += view_roots(owner.inputs[i])
return answer
else:
......
......@@ -18,7 +18,7 @@ from typing_extensions import Literal
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Apply, AtomicVariable, Node, Variable, applys_between
from aesara.graph.basic import Apply, AtomicVariable, Variable, applys_between
from aesara.graph.basic import as_string as graph_as_string
from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between
from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
......@@ -69,7 +69,7 @@ class FunctionGraph(MetaObject):
features: Optional[Sequence[Feature]] = None,
clone: bool = True,
update_mapping: Optional[Dict[Variable, Variable]] = None,
memo: Optional[Dict[Variable, Variable]] = None,
memo: Optional[Dict] = None,
copy_inputs: bool = True,
copy_orphans: bool = True,
):
......@@ -111,7 +111,7 @@ class FunctionGraph(MetaObject):
outputs,
copy_inputs=copy_inputs,
copy_orphans=copy_orphans,
memo=cast(Dict[Node, Node], memo),
memo=memo,
)
outputs = [cast(Variable, _memo[o]) for o in outputs]
inputs = [cast(Variable, _memo[i]) for i in inputs]
......@@ -869,7 +869,7 @@ class FunctionGraph(MetaObject):
def clone_get_equiv(
self, check_integrity: bool = True, attach_feature: bool = True
) -> Tuple["FunctionGraph", Dict[Node, Node]]:
) -> Tuple["FunctionGraph", Dict]:
"""Clone the graph and return a ``dict`` that maps old nodes to new nodes.
Parameters
......
......@@ -12,6 +12,7 @@ from typing import (
Sequence,
Text,
Tuple,
TypeVar,
Union,
cast,
)
......@@ -50,12 +51,14 @@ ThunkCallableType = Callable[
[PerformMethodType, StorageMapType, ComputeMapType, Apply], None
]
C = TypeVar("C", bound=Callable)
class ThunkType(Protocol):
class ThunkType(Protocol[C]):
inputs: List[List[Optional[List[Any]]]]
outputs: List[List[Optional[List[Any]]]]
lazy: bool
__call__: ThunkCallableType
__call__: C
perform: PerformMethodType
......@@ -132,8 +135,7 @@ def compute_test_value(node: Apply):
thunk.inputs = [storage_map[v] for v in node.inputs]
thunk.outputs = [storage_map[v] for v in node.outputs]
required = thunk()
assert not required # We provided all inputs
thunk()
for output in node.outputs:
# Check that the output has been computed
......@@ -495,7 +497,7 @@ class Op(MetaObject):
node: Apply,
storage_map: StorageMapType,
compute_map: ComputeMapType,
no_recycling: bool,
no_recycling: List[Variable],
debug: bool = False,
) -> ThunkType:
"""Make a Python thunk.
......@@ -506,8 +508,8 @@ class Op(MetaObject):
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
if debug:
p = node.op.debug_perform
if debug and hasattr(self, "debug_perform"):
p = node.op.debug_perform # type: ignore
else:
p = node.op.perform
......@@ -551,7 +553,7 @@ class Op(MetaObject):
node: Apply,
storage_map: StorageMapType,
compute_map: ComputeMapType,
no_recycling: bool,
no_recycling: List[Variable],
impl: Optional[Text] = None,
) -> ThunkType:
r"""Create a thunk.
......
......@@ -3,7 +3,7 @@ import sys
import traceback
from abc import ABCMeta
from io import StringIO
from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, TypeVar, Union
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, TypeVar, Union
if TYPE_CHECKING:
......@@ -282,6 +282,13 @@ class Scratchpad:
for k, v in self.__dict__.items():
print(f" {k}: {v}")
# These two methods have been added to help Mypy
def __getattribute__(self, name):
return super().__getattribute__(name)
def __setattr__(self, name: str, value: Any) -> None:
self.__dict__[name] = value
class ValidatingScratchpad(Scratchpad):
"""This `Scratchpad` validates attribute values."""
......
......@@ -318,8 +318,6 @@ def raise_with_op(
raise exc_value.with_traceback(exc_trace)
trace = getattr(node.outputs[0].tag, "trace", ())
if not trace and hasattr(node.op, "tag"):
trace = getattr(node.op.tag, "trace", ())
exc_value.__thunk_trace__ = trace
exc_value.__op_instance__ = node
......@@ -366,8 +364,6 @@ def raise_with_op(
detailed_err_msg += "\nInputs type_num: %s" % str(
[getattr(getattr(i[0], "dtype", ""), "num", "") for i in thunk.inputs]
)
if hasattr(node.op, "__input_name__"):
detailed_err_msg += f"\nInputs name: {node.op.__input_name__}\n"
detailed_err_msg += f"\nOutputs clients: {clients}\n"
else:
......
......@@ -720,6 +720,8 @@ def push_out_inner_vars(
fgraph, old_scan_node, old_scan_args, add_as_nitsots
)
assert isinstance(new_scan_node.op, Scan)
new_scan_args = ScanArgs(
new_scan_node.inputs,
new_scan_node.outputs,
......@@ -761,6 +763,8 @@ def add_nitsot_outputs(
new_scan_args.inner_out_nit_sot.extend(new_outputs_inner)
new_scan_args.outer_in_nit_sot.extend(new_nitsots_initial_value)
assert isinstance(old_scan_node.op, Scan)
# Create the `Scan` `Op` from the `ScanArgs`
new_scan_op = Scan(
new_scan_args.inner_inputs,
......
......@@ -14,6 +14,7 @@ from aesara.graph.basic import (
applys_between,
as_string,
clone,
clone_get_equiv,
clone_replace,
equal_computations,
general_toposort,
......@@ -186,6 +187,31 @@ class TestClone(X):
i, o = clone([c1], [c1], False, True)
assert i[0] is c1 and o[0] is c1
def test_clone_inner_graph(self):
    """Check that `Apply.clone` with ``clone_inner_graph=True`` clones an inner-graph `Op`."""
    r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
    o1 = MyOp(r1, r2)
    o1.name = "o1"

    # Inner graph
    igo_in_1 = MyVariable(4)
    igo_in_2 = MyVariable(5)
    igo_out_1 = MyOp(igo_in_1, igo_in_2)
    igo_out_1.name = "igo1"

    igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1])

    o2 = igo(r3, o1)
    # Fixed a copy-paste typo: this variable was mislabeled "o1"
    o2.name = "o2"

    o2_node = o2.owner
    o2_node_clone = o2_node.clone(clone_inner_graph=True)

    # The clone must be a distinct `Apply` with a distinct inner
    # `FunctionGraph`, but the inner-graphs must compute the same thing
    assert o2_node_clone is not o2_node
    assert o2_node_clone.op.fgraph is not o2_node.op.fgraph
    assert equal_computations(
        o2_node_clone.op.fgraph.outputs, o2_node.op.fgraph.outputs
    )
def prenode(obj):
if isinstance(obj, Variable):
......@@ -535,7 +561,7 @@ class TestCloneReplace:
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, strict=True, share_inputs=True)
f2 = clone_replace(f1, replace=None, rebuild_strict=True, copy_inputs_over=True)
f2_inp = graph_inputs([f2])
assert z in f2_inp
......@@ -551,7 +577,9 @@ class TestCloneReplace:
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, strict=True, share_inputs=False)
f2 = clone_replace(
f1, replace=None, rebuild_strict=True, copy_inputs_over=False
)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
......@@ -568,7 +596,9 @@ class TestCloneReplace:
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace={y: y2}, strict=True, share_inputs=True)
f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=True, copy_inputs_over=True
)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
......@@ -584,7 +614,9 @@ class TestCloneReplace:
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace={y: y2}, strict=False, share_inputs=True)
f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=False, copy_inputs_over=True
)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
......@@ -600,7 +632,9 @@ class TestCloneReplace:
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=[(y, y2)], strict=True, share_inputs=False)
f2 = clone_replace(
f1, replace=[(y, y2)], rebuild_strict=True, copy_inputs_over=False
)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
......@@ -616,7 +650,9 @@ class TestCloneReplace:
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=[(y, y2)], strict=False, share_inputs=False)
f2 = clone_replace(
f1, replace=[(y, y2)], rebuild_strict=False, copy_inputs_over=False
)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
......@@ -672,6 +708,27 @@ def test_clone_new_inputs():
assert z_node_new.inputs[1].type.shape == (1,)
def test_clone_get_equiv():
    """Check that a pre-populated ``memo`` short-circuits cloning in `clone_get_equiv`."""
    in1 = vector("x")
    in2 = vector("y")
    repl = vector("z")

    prod = in1 * in2
    prod_node = prod.owner
    out = prod + 1.0

    # Seed the memo so that `prod` is replaced by `repl` up front
    memo = {prod: repl}
    clone_get_equiv([in1, in2], [out], copy_inputs=False, copy_orphans=False, memo=memo)

    assert in1 in memo
    assert in2 in memo
    assert memo[prod] is repl
    # Every output of `prod` already had a replacement/clone in the map, so
    # its `Apply` node never needed to be cloned (unless some other
    # replacement/clone were to re-introduce `prod.owner` somehow)
    assert prod_node not in memo
    assert equal_computations([memo[out]], [repl + 1.0])
def test_NominalVariable():
type1 = MyType(1)
......
......@@ -157,3 +157,6 @@ class MyInnerGraphOp(Op, HasInnerGraph):
@property
def inner_outputs(self):
return self.fgraph.outputs
def clone(self):
    # Rebuild a new instance of this `Op` around the same inner-graph
    # inputs/outputs; the constructor presumably wraps them in a fresh
    # `FunctionGraph` (exposed as `self.fgraph`) — NOTE(review): the
    # constructor is not visible here, confirm it copies rather than
    # aliases the inner-graph.
    return type(self)(self.fgraph.inputs, self.fgraph.outputs)
Markdown 格式
0%
您在此讨论中提到了 0 人。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论