提交 c752a8e3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a HasInnerGraph.clone method

This method adds the ability to clone a `HasInnerGraph` `Op` and its inner-`FunctionGraph`.
上级 ec12c35c
"""Define new Ops from existing Ops"""
from collections import OrderedDict
from copy import copy
from functools import partial
from typing import List, Optional, Sequence, cast
......@@ -915,6 +916,11 @@ class OpFromGraph(Op, HasInnerGraph):
def inner_outputs(self):
return self.fgraph.outputs
def clone(self):
res = copy(self)
res.fgraph = res.fgraph.clone()
return res
def perform(self, node, inputs, outputs):
variables = self.fn(*inputs)
assert len(variables) == len(outputs)
......
import copy
import sys
import warnings
from abc import abstractmethod
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
......@@ -612,7 +612,7 @@ class _NoPythonOp(Op):
raise NotImplementedError("No Python implementation is provided by this Op.")
class HasInnerGraph:
class HasInnerGraph(ABC):
r"""A mixin for an `Op` that contain an inner graph."""
fgraph: "FunctionGraph"
......@@ -621,7 +621,7 @@ class HasInnerGraph:
@property
@abstractmethod
def fn(self) -> "Function":
"""The inner function."""
"""The compiled inner-graph function."""
@property
@abstractmethod
......@@ -633,6 +633,10 @@ class HasInnerGraph:
def inner_outputs(self) -> List[Variable]:
"""The inner function's outputs."""
@abstractmethod
def clone(self) -> Op:
"""Clone the `Op` and its inner-graph."""
def get_test_value(v: Any) -> Any:
"""Get the test value for `v`.
......
......@@ -47,6 +47,7 @@ import dataclasses
import logging
import time
from collections import OrderedDict
from copy import copy
from itertools import chain, product
from typing import Callable, List, Optional, Union
......@@ -1495,6 +1496,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def inner_outputs(self):
return self.fgraph.outputs
def clone(self):
res = copy(self)
res.fgraph = res.fgraph.clone()
return res
def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
"""
......
......@@ -9,6 +9,7 @@ from aesara.compile.builders import OpFromGraph
from aesara.compile.function import function
from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, Rop, disconnected_type, grad
from aesara.graph.basic import equal_computations
from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType
from aesara.graph.opt_utils import optimize_graph
......@@ -52,6 +53,17 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
with pytest.raises(NotImplementedError):
OpFromGraph([x], [x], updates={})
def test_clone(self):
x, y, z = matrices("xyz")
ofg = OpFromGraph([x], [2 * x])
ofg_clone = ofg.clone()
assert ofg_clone.fgraph is not ofg.fgraph
assert ofg_clone.fgraph.outputs != ofg.fgraph.outputs
assert equal_computations(ofg_clone.fgraph.outputs, ofg.fgraph.outputs)
@pytest.mark.parametrize(
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
)
......
......@@ -28,7 +28,7 @@ from aesara.compile.monitormode import MonitorMode
from aesara.compile.sharedvalue import shared
from aesara.configdefaults import config
from aesara.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian
from aesara.graph.basic import Apply, ancestors
from aesara.graph.basic import Apply, ancestors, equal_computations
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt import MergeOptimizer
......@@ -291,6 +291,20 @@ class TestScan:
assert len(set(res)) == 4
assert len(set(z_res)) == 1
def test_clone(self):
a = vector()
output, _ = scan(fn=lambda x: x**2, sequences=[a])
scan_op = output.owner.op
assert isinstance(scan_op, Scan)
scan_op_clone = scan_op.clone()
assert scan_op_clone is not scan_op
assert scan_op_clone.fgraph is not scan_op.fgraph
assert scan_op_clone.fgraph.outputs != scan_op.fgraph.outputs
assert equal_computations(scan_op_clone.fgraph.outputs, scan_op.fgraph.outputs)
@pytest.mark.skipif(
isinstance(get_default_mode(), DebugMode),
reason="This test fails in DebugMode, because it is not yet picklable.",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论