提交 77396f7f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add inner-graph cloning to Apply.clone and core cloning functions

上级 c752a8e3
...@@ -800,7 +800,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -800,7 +800,7 @@ class OpFromGraph(Op, HasInnerGraph):
# If the new shared variables are inconsistent with the inner-graph, # If the new shared variables are inconsistent with the inner-graph,
# such errors should arise in this step # such errors should arise in this step
new_inner_outputs = clone_replace( new_inner_outputs = clone_replace(
self.inner_outputs, replace=replace, share_inputs=True self.inner_outputs, replace=replace, copy_inputs_over=True
) )
# It's possible that the new shared variable inputs aren't actually # It's possible that the new shared variable inputs aren't actually
......
...@@ -12,7 +12,7 @@ from aesara.compile.io import In, Out ...@@ -12,7 +12,7 @@ from aesara.compile.io import In, Out
from aesara.compile.profiling import ProfileStats from aesara.compile.profiling import ProfileStats
from aesara.compile.sharedvalue import SharedVariable, shared from aesara.compile.sharedvalue import SharedVariable, shared
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, Variable, clone_node_and_cache
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
...@@ -29,8 +29,9 @@ def rebuild_collect_shared( ...@@ -29,8 +29,9 @@ def rebuild_collect_shared(
rebuild_strict=True, rebuild_strict=True,
copy_inputs_over=True, copy_inputs_over=True,
no_default_updates=False, no_default_updates=False,
clone_inner_graphs=False,
): ):
"""Replace subgraphs of a computational graph. r"""Replace subgraphs of a computational graph.
It returns a set of dictionaries and lists which collect (partial?) It returns a set of dictionaries and lists which collect (partial?)
different information about shared variables. This info is required by different information about shared variables. This info is required by
...@@ -59,6 +60,9 @@ def rebuild_collect_shared( ...@@ -59,6 +60,9 @@ def rebuild_collect_shared(
If False (default), perform them all. If False (default), perform them all.
Else, perform automatic updates on all Variables that are neither in Else, perform automatic updates on all Variables that are neither in
"updates" nor in "no_default_updates". "updates" nor in "no_default_updates".
clone_inner_graphs : bool
If ``True``, clone `Op`\s that are subclasses of `HasInnerGraph` and their
inner-graphs.
""" """
...@@ -89,13 +93,12 @@ def rebuild_collect_shared( ...@@ -89,13 +93,12 @@ def rebuild_collect_shared(
if owner not in clone_d: if owner not in clone_d:
for i in owner.inputs: for i in owner.inputs:
clone_v_get_shared_updates(i, copy_inputs_over) clone_v_get_shared_updates(i, copy_inputs_over)
clone_node_and_cache(
clone_d[owner] = owner.clone_with_new_inputs( owner,
[clone_d[i] for i in owner.inputs], strict=rebuild_strict clone_d,
strict=rebuild_strict,
clone_inner_graphs=clone_inner_graphs,
) )
for old_o, new_o in zip(owner.outputs, clone_d[owner].outputs):
clone_d.setdefault(old_o, new_o)
return clone_d.setdefault(v, v) return clone_d.setdefault(v, v)
elif isinstance(v, SharedVariable): elif isinstance(v, SharedVariable):
if v not in shared_inputs: if v not in shared_inputs:
...@@ -494,6 +497,7 @@ def construct_pfunc_ins_and_outs( ...@@ -494,6 +497,7 @@ def construct_pfunc_ins_and_outs(
rebuild_strict=rebuild_strict, rebuild_strict=rebuild_strict,
copy_inputs_over=True, copy_inputs_over=True,
no_default_updates=no_default_updates, no_default_updates=no_default_updates,
clone_inner_graphs=True,
) )
input_variables, cloned_extended_outputs, other_stuff = output_vars input_variables, cloned_extended_outputs, other_stuff = output_vars
clone_d, update_d, update_expr, shared_inputs = other_stuff clone_d, update_d, update_expr, shared_inputs = other_stuff
......
差异被折叠。
...@@ -18,7 +18,7 @@ from typing_extensions import Literal ...@@ -18,7 +18,7 @@ from typing_extensions import Literal
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, AtomicVariable, Node, Variable, applys_between from aesara.graph.basic import Apply, AtomicVariable, Variable, applys_between
from aesara.graph.basic import as_string as graph_as_string from aesara.graph.basic import as_string as graph_as_string
from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between
from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
...@@ -69,7 +69,7 @@ class FunctionGraph(MetaObject): ...@@ -69,7 +69,7 @@ class FunctionGraph(MetaObject):
features: Optional[Sequence[Feature]] = None, features: Optional[Sequence[Feature]] = None,
clone: bool = True, clone: bool = True,
update_mapping: Optional[Dict[Variable, Variable]] = None, update_mapping: Optional[Dict[Variable, Variable]] = None,
memo: Optional[Dict[Variable, Variable]] = None, memo: Optional[Dict] = None,
copy_inputs: bool = True, copy_inputs: bool = True,
copy_orphans: bool = True, copy_orphans: bool = True,
): ):
...@@ -111,7 +111,7 @@ class FunctionGraph(MetaObject): ...@@ -111,7 +111,7 @@ class FunctionGraph(MetaObject):
outputs, outputs,
copy_inputs=copy_inputs, copy_inputs=copy_inputs,
copy_orphans=copy_orphans, copy_orphans=copy_orphans,
memo=cast(Dict[Node, Node], memo), memo=memo,
) )
outputs = [cast(Variable, _memo[o]) for o in outputs] outputs = [cast(Variable, _memo[o]) for o in outputs]
inputs = [cast(Variable, _memo[i]) for i in inputs] inputs = [cast(Variable, _memo[i]) for i in inputs]
...@@ -869,7 +869,7 @@ class FunctionGraph(MetaObject): ...@@ -869,7 +869,7 @@ class FunctionGraph(MetaObject):
def clone_get_equiv( def clone_get_equiv(
self, check_integrity: bool = True, attach_feature: bool = True self, check_integrity: bool = True, attach_feature: bool = True
) -> Tuple["FunctionGraph", Dict[Node, Node]]: ) -> Tuple["FunctionGraph", Dict]:
"""Clone the graph and return a ``dict`` that maps old nodes to new nodes. """Clone the graph and return a ``dict`` that maps old nodes to new nodes.
Parameters Parameters
......
...@@ -12,6 +12,7 @@ from typing import ( ...@@ -12,6 +12,7 @@ from typing import (
Sequence, Sequence,
Text, Text,
Tuple, Tuple,
TypeVar,
Union, Union,
cast, cast,
) )
...@@ -50,12 +51,14 @@ ThunkCallableType = Callable[ ...@@ -50,12 +51,14 @@ ThunkCallableType = Callable[
[PerformMethodType, StorageMapType, ComputeMapType, Apply], None [PerformMethodType, StorageMapType, ComputeMapType, Apply], None
] ]
C = TypeVar("C", bound=Callable)
class ThunkType(Protocol):
class ThunkType(Protocol[C]):
inputs: List[List[Optional[List[Any]]]] inputs: List[List[Optional[List[Any]]]]
outputs: List[List[Optional[List[Any]]]] outputs: List[List[Optional[List[Any]]]]
lazy: bool lazy: bool
__call__: ThunkCallableType __call__: C
perform: PerformMethodType perform: PerformMethodType
...@@ -132,8 +135,7 @@ def compute_test_value(node: Apply): ...@@ -132,8 +135,7 @@ def compute_test_value(node: Apply):
thunk.inputs = [storage_map[v] for v in node.inputs] thunk.inputs = [storage_map[v] for v in node.inputs]
thunk.outputs = [storage_map[v] for v in node.outputs] thunk.outputs = [storage_map[v] for v in node.outputs]
required = thunk() thunk()
assert not required # We provided all inputs
for output in node.outputs: for output in node.outputs:
# Check that the output has been computed # Check that the output has been computed
...@@ -495,7 +497,7 @@ class Op(MetaObject): ...@@ -495,7 +497,7 @@ class Op(MetaObject):
node: Apply, node: Apply,
storage_map: StorageMapType, storage_map: StorageMapType,
compute_map: ComputeMapType, compute_map: ComputeMapType,
no_recycling: bool, no_recycling: List[Variable],
debug: bool = False, debug: bool = False,
) -> ThunkType: ) -> ThunkType:
"""Make a Python thunk. """Make a Python thunk.
...@@ -506,8 +508,8 @@ class Op(MetaObject): ...@@ -506,8 +508,8 @@ class Op(MetaObject):
node_input_storage = [storage_map[r] for r in node.inputs] node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs] node_output_storage = [storage_map[r] for r in node.outputs]
if debug: if debug and hasattr(self, "debug_perform"):
p = node.op.debug_perform p = node.op.debug_perform # type: ignore
else: else:
p = node.op.perform p = node.op.perform
...@@ -551,7 +553,7 @@ class Op(MetaObject): ...@@ -551,7 +553,7 @@ class Op(MetaObject):
node: Apply, node: Apply,
storage_map: StorageMapType, storage_map: StorageMapType,
compute_map: ComputeMapType, compute_map: ComputeMapType,
no_recycling: bool, no_recycling: List[Variable],
impl: Optional[Text] = None, impl: Optional[Text] = None,
) -> ThunkType: ) -> ThunkType:
r"""Create a thunk. r"""Create a thunk.
......
...@@ -3,7 +3,7 @@ import sys ...@@ -3,7 +3,7 @@ import sys
import traceback import traceback
from abc import ABCMeta from abc import ABCMeta
from io import StringIO from io import StringIO
from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, TypeVar, Union from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, TypeVar, Union
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -282,6 +282,13 @@ class Scratchpad: ...@@ -282,6 +282,13 @@ class Scratchpad:
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
print(f" {k}: {v}") print(f" {k}: {v}")
# These two methods have been added to help Mypy
def __getattribute__(self, name):
return super().__getattribute__(name)
def __setattr__(self, name: str, value: Any) -> None:
self.__dict__[name] = value
class ValidatingScratchpad(Scratchpad): class ValidatingScratchpad(Scratchpad):
"""This `Scratchpad` validates attribute values.""" """This `Scratchpad` validates attribute values."""
......
...@@ -318,8 +318,6 @@ def raise_with_op( ...@@ -318,8 +318,6 @@ def raise_with_op(
raise exc_value.with_traceback(exc_trace) raise exc_value.with_traceback(exc_trace)
trace = getattr(node.outputs[0].tag, "trace", ()) trace = getattr(node.outputs[0].tag, "trace", ())
if not trace and hasattr(node.op, "tag"):
trace = getattr(node.op.tag, "trace", ())
exc_value.__thunk_trace__ = trace exc_value.__thunk_trace__ = trace
exc_value.__op_instance__ = node exc_value.__op_instance__ = node
...@@ -366,8 +364,6 @@ def raise_with_op( ...@@ -366,8 +364,6 @@ def raise_with_op(
detailed_err_msg += "\nInputs type_num: %s" % str( detailed_err_msg += "\nInputs type_num: %s" % str(
[getattr(getattr(i[0], "dtype", ""), "num", "") for i in thunk.inputs] [getattr(getattr(i[0], "dtype", ""), "num", "") for i in thunk.inputs]
) )
if hasattr(node.op, "__input_name__"):
detailed_err_msg += f"\nInputs name: {node.op.__input_name__}\n"
detailed_err_msg += f"\nOutputs clients: {clients}\n" detailed_err_msg += f"\nOutputs clients: {clients}\n"
else: else:
......
...@@ -720,6 +720,8 @@ def push_out_inner_vars( ...@@ -720,6 +720,8 @@ def push_out_inner_vars(
fgraph, old_scan_node, old_scan_args, add_as_nitsots fgraph, old_scan_node, old_scan_args, add_as_nitsots
) )
assert isinstance(new_scan_node.op, Scan)
new_scan_args = ScanArgs( new_scan_args = ScanArgs(
new_scan_node.inputs, new_scan_node.inputs,
new_scan_node.outputs, new_scan_node.outputs,
...@@ -761,6 +763,8 @@ def add_nitsot_outputs( ...@@ -761,6 +763,8 @@ def add_nitsot_outputs(
new_scan_args.inner_out_nit_sot.extend(new_outputs_inner) new_scan_args.inner_out_nit_sot.extend(new_outputs_inner)
new_scan_args.outer_in_nit_sot.extend(new_nitsots_initial_value) new_scan_args.outer_in_nit_sot.extend(new_nitsots_initial_value)
assert isinstance(old_scan_node.op, Scan)
# Create the `Scan` `Op` from the `ScanArgs` # Create the `Scan` `Op` from the `ScanArgs`
new_scan_op = Scan( new_scan_op = Scan(
new_scan_args.inner_inputs, new_scan_args.inner_inputs,
......
...@@ -14,6 +14,7 @@ from aesara.graph.basic import ( ...@@ -14,6 +14,7 @@ from aesara.graph.basic import (
applys_between, applys_between,
as_string, as_string,
clone, clone,
clone_get_equiv,
clone_replace, clone_replace,
equal_computations, equal_computations,
general_toposort, general_toposort,
...@@ -186,6 +187,31 @@ class TestClone(X): ...@@ -186,6 +187,31 @@ class TestClone(X):
i, o = clone([c1], [c1], False, True) i, o = clone([c1], [c1], False, True)
assert i[0] is c1 and o[0] is c1 assert i[0] is c1 and o[0] is c1
def test_clone_inner_graph(self):
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2)
o1.name = "o1"
# Inner graph
igo_in_1 = MyVariable(4)
igo_in_2 = MyVariable(5)
igo_out_1 = MyOp(igo_in_1, igo_in_2)
igo_out_1.name = "igo1"
igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1])
o2 = igo(r3, o1)
o2.name = "o1"
o2_node = o2.owner
o2_node_clone = o2_node.clone(clone_inner_graph=True)
assert o2_node_clone is not o2_node
assert o2_node_clone.op.fgraph is not o2_node.op.fgraph
assert equal_computations(
o2_node_clone.op.fgraph.outputs, o2_node.op.fgraph.outputs
)
def prenode(obj): def prenode(obj):
if isinstance(obj, Variable): if isinstance(obj, Variable):
...@@ -535,7 +561,7 @@ class TestCloneReplace: ...@@ -535,7 +561,7 @@ class TestCloneReplace:
z = shared(0.25) z = shared(0.25)
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, strict=True, share_inputs=True) f2 = clone_replace(f1, replace=None, rebuild_strict=True, copy_inputs_over=True)
f2_inp = graph_inputs([f2]) f2_inp = graph_inputs([f2])
assert z in f2_inp assert z in f2_inp
...@@ -551,7 +577,9 @@ class TestCloneReplace: ...@@ -551,7 +577,9 @@ class TestCloneReplace:
z = shared(0.25) z = shared(0.25)
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, strict=True, share_inputs=False) f2 = clone_replace(
f1, replace=None, rebuild_strict=True, copy_inputs_over=False
)
f2_inp = graph_inputs([f2]) f2_inp = graph_inputs([f2])
assert z not in f2_inp assert z not in f2_inp
...@@ -568,7 +596,9 @@ class TestCloneReplace: ...@@ -568,7 +596,9 @@ class TestCloneReplace:
z = shared(0.25) z = shared(0.25)
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace={y: y2}, strict=True, share_inputs=True) f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=True, copy_inputs_over=True
)
f2_inp = graph_inputs([f2]) f2_inp = graph_inputs([f2])
assert z in f2_inp assert z in f2_inp
assert x in f2_inp assert x in f2_inp
...@@ -584,7 +614,9 @@ class TestCloneReplace: ...@@ -584,7 +614,9 @@ class TestCloneReplace:
z = shared(0.25) z = shared(0.25)
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace={y: y2}, strict=False, share_inputs=True) f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=False, copy_inputs_over=True
)
f2_inp = graph_inputs([f2]) f2_inp = graph_inputs([f2])
assert z in f2_inp assert z in f2_inp
assert x in f2_inp assert x in f2_inp
...@@ -600,7 +632,9 @@ class TestCloneReplace: ...@@ -600,7 +632,9 @@ class TestCloneReplace:
z = shared(0.25) z = shared(0.25)
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=[(y, y2)], strict=True, share_inputs=False) f2 = clone_replace(
f1, replace=[(y, y2)], rebuild_strict=True, copy_inputs_over=False
)
f2_inp = graph_inputs([f2]) f2_inp = graph_inputs([f2])
assert z not in f2_inp assert z not in f2_inp
assert x not in f2_inp assert x not in f2_inp
...@@ -616,7 +650,9 @@ class TestCloneReplace: ...@@ -616,7 +650,9 @@ class TestCloneReplace:
z = shared(0.25) z = shared(0.25)
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=[(y, y2)], strict=False, share_inputs=False) f2 = clone_replace(
f1, replace=[(y, y2)], rebuild_strict=False, copy_inputs_over=False
)
f2_inp = graph_inputs([f2]) f2_inp = graph_inputs([f2])
assert z not in f2_inp assert z not in f2_inp
assert x not in f2_inp assert x not in f2_inp
...@@ -672,6 +708,27 @@ def test_clone_new_inputs(): ...@@ -672,6 +708,27 @@ def test_clone_new_inputs():
assert z_node_new.inputs[1].type.shape == (1,) assert z_node_new.inputs[1].type.shape == (1,)
def test_clone_get_equiv():
x = vector("x")
y = vector("y")
z = vector("z")
a = x * y
a_node = a.owner
b = a + 1.0
memo = {a: z}
_ = clone_get_equiv([x, y], [b], copy_inputs=False, copy_orphans=False, memo=memo)
assert x in memo
assert y in memo
assert memo[a] is z
# All the outputs of `a` already had replacements/clones in the map, so
# there is no need to re-clone it (unless another replacement/clone
# re-introduces `a.owner` somehow).
assert a_node not in memo
assert equal_computations([memo[b]], [z + 1.0])
def test_NominalVariable(): def test_NominalVariable():
type1 = MyType(1) type1 = MyType(1)
......
...@@ -157,3 +157,6 @@ class MyInnerGraphOp(Op, HasInnerGraph): ...@@ -157,3 +157,6 @@ class MyInnerGraphOp(Op, HasInnerGraph):
@property @property
def inner_outputs(self): def inner_outputs(self):
return self.fgraph.outputs return self.fgraph.outputs
def clone(self):
return type(self)(self.fgraph.inputs, self.fgraph.outputs)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论