提交 c752a8e3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a HasInnerGraph.clone method

This method adds the ability to clone a `HasInnerGraph` `Op` and its inner-`FunctionGraph`.
上级 ec12c35c
"""Define new Ops from existing Ops""" """Define new Ops from existing Ops"""
from collections import OrderedDict from collections import OrderedDict
from copy import copy
from functools import partial from functools import partial
from typing import List, Optional, Sequence, cast from typing import List, Optional, Sequence, cast
...@@ -915,6 +916,11 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -915,6 +916,11 @@ class OpFromGraph(Op, HasInnerGraph):
def inner_outputs(self): def inner_outputs(self):
return self.fgraph.outputs return self.fgraph.outputs
def clone(self):
res = copy(self)
res.fgraph = res.fgraph.clone()
return res
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
variables = self.fn(*inputs) variables = self.fn(*inputs)
assert len(variables) == len(outputs) assert len(variables) == len(outputs)
......
import copy import copy
import sys import sys
import warnings import warnings
from abc import abstractmethod from abc import ABC, abstractmethod
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
...@@ -612,7 +612,7 @@ class _NoPythonOp(Op): ...@@ -612,7 +612,7 @@ class _NoPythonOp(Op):
raise NotImplementedError("No Python implementation is provided by this Op.") raise NotImplementedError("No Python implementation is provided by this Op.")
class HasInnerGraph: class HasInnerGraph(ABC):
r"""A mixin for an `Op` that contain an inner graph.""" r"""A mixin for an `Op` that contain an inner graph."""
fgraph: "FunctionGraph" fgraph: "FunctionGraph"
...@@ -621,7 +621,7 @@ class HasInnerGraph: ...@@ -621,7 +621,7 @@ class HasInnerGraph:
@property @property
@abstractmethod @abstractmethod
def fn(self) -> "Function": def fn(self) -> "Function":
"""The inner function.""" """The compiled inner-graph function."""
@property @property
@abstractmethod @abstractmethod
...@@ -633,6 +633,10 @@ class HasInnerGraph: ...@@ -633,6 +633,10 @@ class HasInnerGraph:
def inner_outputs(self) -> List[Variable]: def inner_outputs(self) -> List[Variable]:
"""The inner function's outputs.""" """The inner function's outputs."""
@abstractmethod
def clone(self) -> Op:
"""Clone the `Op` and its inner-graph."""
def get_test_value(v: Any) -> Any: def get_test_value(v: Any) -> Any:
"""Get the test value for `v`. """Get the test value for `v`.
......
...@@ -47,6 +47,7 @@ import dataclasses ...@@ -47,6 +47,7 @@ import dataclasses
import logging import logging
import time import time
from collections import OrderedDict from collections import OrderedDict
from copy import copy
from itertools import chain, product from itertools import chain, product
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
...@@ -1495,6 +1496,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1495,6 +1496,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def inner_outputs(self): def inner_outputs(self):
return self.fgraph.outputs return self.fgraph.outputs
def clone(self):
res = copy(self)
res.fgraph = res.fgraph.clone()
return res
def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None): def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
""" """
......
...@@ -9,6 +9,7 @@ from aesara.compile.builders import OpFromGraph ...@@ -9,6 +9,7 @@ from aesara.compile.builders import OpFromGraph
from aesara.compile.function import function from aesara.compile.function import function
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, Rop, disconnected_type, grad from aesara.gradient import DisconnectedType, Rop, disconnected_type, grad
from aesara.graph.basic import equal_computations
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType from aesara.graph.null_type import NullType
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import optimize_graph
...@@ -52,6 +53,17 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -52,6 +53,17 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
OpFromGraph([x], [x], updates={}) OpFromGraph([x], [x], updates={})
def test_clone(self):
x, y, z = matrices("xyz")
ofg = OpFromGraph([x], [2 * x])
ofg_clone = ofg.clone()
assert ofg_clone.fgraph is not ofg.fgraph
assert ofg_clone.fgraph.outputs != ofg.fgraph.outputs
assert equal_computations(ofg_clone.fgraph.outputs, ofg.fgraph.outputs)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)] "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
) )
......
...@@ -28,7 +28,7 @@ from aesara.compile.monitormode import MonitorMode ...@@ -28,7 +28,7 @@ from aesara.compile.monitormode import MonitorMode
from aesara.compile.sharedvalue import shared from aesara.compile.sharedvalue import shared
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian from aesara.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian
from aesara.graph.basic import Apply, ancestors from aesara.graph.basic import Apply, ancestors, equal_computations
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import MergeOptimizer from aesara.graph.opt import MergeOptimizer
...@@ -291,6 +291,20 @@ class TestScan: ...@@ -291,6 +291,20 @@ class TestScan:
assert len(set(res)) == 4 assert len(set(res)) == 4
assert len(set(z_res)) == 1 assert len(set(z_res)) == 1
def test_clone(self):
a = vector()
output, _ = scan(fn=lambda x: x**2, sequences=[a])
scan_op = output.owner.op
assert isinstance(scan_op, Scan)
scan_op_clone = scan_op.clone()
assert scan_op_clone is not scan_op
assert scan_op_clone.fgraph is not scan_op.fgraph
assert scan_op_clone.fgraph.outputs != scan_op.fgraph.outputs
assert equal_computations(scan_op_clone.fgraph.outputs, scan_op.fgraph.outputs)
@pytest.mark.skipif( @pytest.mark.skipif(
isinstance(get_default_mode(), DebugMode), isinstance(get_default_mode(), DebugMode),
reason="This test fails in DebugMode, because it is not yet picklable.", reason="This test fails in DebugMode, because it is not yet picklable.",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论