提交 77396f7f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add inner-graph cloning to Apply.clone and core cloning functions

上级 c752a8e3
......@@ -800,7 +800,7 @@ class OpFromGraph(Op, HasInnerGraph):
# If the new shared variables are inconsistent with the inner-graph,
# such errors should arise in this step
new_inner_outputs = clone_replace(
self.inner_outputs, replace=replace, share_inputs=True
self.inner_outputs, replace=replace, copy_inputs_over=True
)
# It's possible that the new shared variable inputs aren't actually
......
......@@ -12,7 +12,7 @@ from aesara.compile.io import In, Out
from aesara.compile.profiling import ProfileStats
from aesara.compile.sharedvalue import SharedVariable, shared
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
from aesara.graph.basic import Constant, Variable, clone_node_and_cache
from aesara.graph.fg import FunctionGraph
......@@ -29,8 +29,9 @@ def rebuild_collect_shared(
rebuild_strict=True,
copy_inputs_over=True,
no_default_updates=False,
clone_inner_graphs=False,
):
"""Replace subgraphs of a computational graph.
r"""Replace subgraphs of a computational graph.
It returns a set of dictionaries and lists which collect (partial?)
different information about shared variables. This info is required by
......@@ -59,6 +60,9 @@ def rebuild_collect_shared(
If False (default), perform them all.
Else, perform automatic updates on all Variables that are neither in
"updates" nor in "no_default_updates".
clone_inner_graphs : bool
If ``True``, clone `Op`\s that are subclasses of `HasInnerGraph` and their
inner-graphs.
"""
......@@ -89,13 +93,12 @@ def rebuild_collect_shared(
if owner not in clone_d:
for i in owner.inputs:
clone_v_get_shared_updates(i, copy_inputs_over)
clone_d[owner] = owner.clone_with_new_inputs(
[clone_d[i] for i in owner.inputs], strict=rebuild_strict
clone_node_and_cache(
owner,
clone_d,
strict=rebuild_strict,
clone_inner_graphs=clone_inner_graphs,
)
for old_o, new_o in zip(owner.outputs, clone_d[owner].outputs):
clone_d.setdefault(old_o, new_o)
return clone_d.setdefault(v, v)
elif isinstance(v, SharedVariable):
if v not in shared_inputs:
......@@ -494,6 +497,7 @@ def construct_pfunc_ins_and_outs(
rebuild_strict=rebuild_strict,
copy_inputs_over=True,
no_default_updates=no_default_updates,
clone_inner_graphs=True,
)
input_variables, cloned_extended_outputs, other_stuff = output_vars
clone_d, update_d, update_expr, shared_inputs = other_stuff
......
差异被折叠。
......@@ -18,7 +18,7 @@ from typing_extensions import Literal
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Apply, AtomicVariable, Node, Variable, applys_between
from aesara.graph.basic import Apply, AtomicVariable, Variable, applys_between
from aesara.graph.basic import as_string as graph_as_string
from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between
from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
......@@ -69,7 +69,7 @@ class FunctionGraph(MetaObject):
features: Optional[Sequence[Feature]] = None,
clone: bool = True,
update_mapping: Optional[Dict[Variable, Variable]] = None,
memo: Optional[Dict[Variable, Variable]] = None,
memo: Optional[Dict] = None,
copy_inputs: bool = True,
copy_orphans: bool = True,
):
......@@ -111,7 +111,7 @@ class FunctionGraph(MetaObject):
outputs,
copy_inputs=copy_inputs,
copy_orphans=copy_orphans,
memo=cast(Dict[Node, Node], memo),
memo=memo,
)
outputs = [cast(Variable, _memo[o]) for o in outputs]
inputs = [cast(Variable, _memo[i]) for i in inputs]
......@@ -869,7 +869,7 @@ class FunctionGraph(MetaObject):
def clone_get_equiv(
self, check_integrity: bool = True, attach_feature: bool = True
) -> Tuple["FunctionGraph", Dict[Node, Node]]:
) -> Tuple["FunctionGraph", Dict]:
"""Clone the graph and return a ``dict`` that maps old nodes to new nodes.
Parameters
......
......@@ -12,6 +12,7 @@ from typing import (
Sequence,
Text,
Tuple,
TypeVar,
Union,
cast,
)
......@@ -50,12 +51,14 @@ ThunkCallableType = Callable[
[PerformMethodType, StorageMapType, ComputeMapType, Apply], None
]
C = TypeVar("C", bound=Callable)
class ThunkType(Protocol):
class ThunkType(Protocol[C]):
inputs: List[List[Optional[List[Any]]]]
outputs: List[List[Optional[List[Any]]]]
lazy: bool
__call__: ThunkCallableType
__call__: C
perform: PerformMethodType
......@@ -132,8 +135,7 @@ def compute_test_value(node: Apply):
thunk.inputs = [storage_map[v] for v in node.inputs]
thunk.outputs = [storage_map[v] for v in node.outputs]
required = thunk()
assert not required # We provided all inputs
thunk()
for output in node.outputs:
# Check that the output has been computed
......@@ -495,7 +497,7 @@ class Op(MetaObject):
node: Apply,
storage_map: StorageMapType,
compute_map: ComputeMapType,
no_recycling: bool,
no_recycling: List[Variable],
debug: bool = False,
) -> ThunkType:
"""Make a Python thunk.
......@@ -506,8 +508,8 @@ class Op(MetaObject):
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
if debug:
p = node.op.debug_perform
if debug and hasattr(self, "debug_perform"):
p = node.op.debug_perform # type: ignore
else:
p = node.op.perform
......@@ -551,7 +553,7 @@ class Op(MetaObject):
node: Apply,
storage_map: StorageMapType,
compute_map: ComputeMapType,
no_recycling: bool,
no_recycling: List[Variable],
impl: Optional[Text] = None,
) -> ThunkType:
r"""Create a thunk.
......
......@@ -3,7 +3,7 @@ import sys
import traceback
from abc import ABCMeta
from io import StringIO
from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, TypeVar, Union
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, TypeVar, Union
if TYPE_CHECKING:
......@@ -282,6 +282,13 @@ class Scratchpad:
for k, v in self.__dict__.items():
print(f" {k}: {v}")
# These two methods have been added to help Mypy
def __getattribute__(self, name):
    # Intentional no-op override (defers straight to ``object``): per the
    # note above, it exists only so Mypy treats attribute access on this
    # class as dynamically typed instead of erroring on unknown attributes.
    return super().__getattribute__(name)
def __setattr__(self, name: str, value: Any) -> None:
    # Store directly into the instance ``__dict__``, matching default
    # object behavior for a plain attribute bag; per the note above, the
    # explicit override exists to help Mypy accept dynamic assignment.
    self.__dict__[name] = value
class ValidatingScratchpad(Scratchpad):
"""This `Scratchpad` validates attribute values."""
......
......@@ -318,8 +318,6 @@ def raise_with_op(
raise exc_value.with_traceback(exc_trace)
trace = getattr(node.outputs[0].tag, "trace", ())
if not trace and hasattr(node.op, "tag"):
trace = getattr(node.op.tag, "trace", ())
exc_value.__thunk_trace__ = trace
exc_value.__op_instance__ = node
......@@ -366,8 +364,6 @@ def raise_with_op(
detailed_err_msg += "\nInputs type_num: %s" % str(
[getattr(getattr(i[0], "dtype", ""), "num", "") for i in thunk.inputs]
)
if hasattr(node.op, "__input_name__"):
detailed_err_msg += f"\nInputs name: {node.op.__input_name__}\n"
detailed_err_msg += f"\nOutputs clients: {clients}\n"
else:
......
......@@ -720,6 +720,8 @@ def push_out_inner_vars(
fgraph, old_scan_node, old_scan_args, add_as_nitsots
)
assert isinstance(new_scan_node.op, Scan)
new_scan_args = ScanArgs(
new_scan_node.inputs,
new_scan_node.outputs,
......@@ -761,6 +763,8 @@ def add_nitsot_outputs(
new_scan_args.inner_out_nit_sot.extend(new_outputs_inner)
new_scan_args.outer_in_nit_sot.extend(new_nitsots_initial_value)
assert isinstance(old_scan_node.op, Scan)
# Create the `Scan` `Op` from the `ScanArgs`
new_scan_op = Scan(
new_scan_args.inner_inputs,
......
......@@ -14,6 +14,7 @@ from aesara.graph.basic import (
applys_between,
as_string,
clone,
clone_get_equiv,
clone_replace,
equal_computations,
general_toposort,
......@@ -186,6 +187,31 @@ class TestClone(X):
i, o = clone([c1], [c1], False, True)
assert i[0] is c1 and o[0] is c1
def test_clone_inner_graph(self):
    """Cloning an `Apply` with ``clone_inner_graph=True`` must also clone the `Op`'s inner graph.

    The cloned node's ``op.fgraph`` has to be a distinct object that still
    computes the same thing as the original.
    """
    r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
    o1 = MyOp(r1, r2)
    o1.name = "o1"

    # Build the inner graph wrapped by the `HasInnerGraph` op
    igo_in_1 = MyVariable(4)
    igo_in_2 = MyVariable(5)
    igo_out_1 = MyOp(igo_in_1, igo_in_2)
    igo_out_1.name = "igo1"

    igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1])

    o2 = igo(r3, o1)
    # Fixed copy/paste typo: this variable is `o2`, not `o1` (which was
    # already named above). The name is debug-only metadata.
    o2.name = "o2"

    o2_node = o2.owner
    o2_node_clone = o2_node.clone(clone_inner_graph=True)

    # The clone and its inner-graph must be new objects...
    assert o2_node_clone is not o2_node
    assert o2_node_clone.op.fgraph is not o2_node.op.fgraph
    # ...but the cloned inner graph must be computationally equivalent
    assert equal_computations(
        o2_node_clone.op.fgraph.outputs, o2_node.op.fgraph.outputs
    )
def prenode(obj):
if isinstance(obj, Variable):
......@@ -535,7 +561,7 @@ class TestCloneReplace:
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, strict=True, share_inputs=True)
f2 = clone_replace(f1, replace=None, rebuild_strict=True, copy_inputs_over=True)
f2_inp = graph_inputs([f2])
assert z in f2_inp
......@@ -551,7 +577,9 @@ class TestCloneReplace:
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=None, strict=True, share_inputs=False)
f2 = clone_replace(
f1, replace=None, rebuild_strict=True, copy_inputs_over=False
)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
......@@ -568,7 +596,9 @@ class TestCloneReplace:
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace={y: y2}, strict=True, share_inputs=True)
f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=True, copy_inputs_over=True
)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
......@@ -584,7 +614,9 @@ class TestCloneReplace:
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace={y: y2}, strict=False, share_inputs=True)
f2 = clone_replace(
f1, replace={y: y2}, rebuild_strict=False, copy_inputs_over=True
)
f2_inp = graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
......@@ -600,7 +632,9 @@ class TestCloneReplace:
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=[(y, y2)], strict=True, share_inputs=False)
f2 = clone_replace(
f1, replace=[(y, y2)], rebuild_strict=True, copy_inputs_over=False
)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
......@@ -616,7 +650,9 @@ class TestCloneReplace:
z = shared(0.25)
f1 = z * (x + y) ** 2 + 5
f2 = clone_replace(f1, replace=[(y, y2)], strict=False, share_inputs=False)
f2 = clone_replace(
f1, replace=[(y, y2)], rebuild_strict=False, copy_inputs_over=False
)
f2_inp = graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
......@@ -672,6 +708,27 @@ def test_clone_new_inputs():
assert z_node_new.inputs[1].type.shape == (1,)
def test_clone_get_equiv():
    """`clone_get_equiv` must honor replacements pre-seeded in ``memo``."""
    x, y, z = (vector(nm) for nm in "xyz")

    prod = x * y
    prod_node = prod.owner
    out = prod + 1.0

    # Pre-seed the memo so that `prod` is replaced by `z` during cloning.
    memo = {prod: z}
    clone_get_equiv([x, y], [out], copy_inputs=False, copy_orphans=False, memo=memo)

    assert x in memo and y in memo
    assert memo[prod] is z
    # Every output of `prod_node` already had a replacement/clone in the
    # map, so there is no need to re-clone the node itself (unless another
    # replacement/clone were to re-introduce `prod.owner` somehow).
    assert prod_node not in memo
    assert equal_computations([memo[out]], [z + 1.0])
def test_NominalVariable():
type1 = MyType(1)
......
......@@ -157,3 +157,6 @@ class MyInnerGraphOp(Op, HasInnerGraph):
@property
def inner_outputs(self):
    # Output variables of the wrapped inner-graph `FunctionGraph`
    # (required by the `HasInnerGraph` interface).
    return self.fgraph.outputs
def clone(self):
    # Build a new op of the same class from this op's inner inputs/outputs.
    # NOTE(review): whether this yields a *distinct* inner `FunctionGraph`
    # depends on the constructor (not visible here) — confirm it re-wraps
    # rather than shares `self.fgraph`.
    return type(self)(self.fgraph.inputs, self.fgraph.outputs)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论