提交 c95acebd authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Create HasInnerGraph mixin

上级 8f692472
...@@ -18,7 +18,7 @@ from aesara.graph.basic import ( ...@@ -18,7 +18,7 @@ from aesara.graph.basic import (
) )
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType from aesara.graph.null_type import NullType
from aesara.graph.op import Op, ops_with_inner_function from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.opt import in2out, local_optimizer from aesara.graph.opt import in2out, local_optimizer
from aesara.tensor.basic_opt import ShapeFeature from aesara.tensor.basic_opt import ShapeFeature
...@@ -76,11 +76,11 @@ def infer_shape(outs, inputs, input_shapes): ...@@ -76,11 +76,11 @@ def infer_shape(outs, inputs, input_shapes):
return ret return ret
class OpFromGraph(Op): class OpFromGraph(Op, HasInnerGraph):
r""" r"""
This creates an ``Op`` from inputs and outputs lists of variables. This creates an `Op` from inputs and outputs lists of variables.
The signature is similar to :func:`aesara.function <aesara.function>` The signature is similar to :func:`aesara.function <aesara.function>`
and the resulting ``Op``'s perform will do the same operation as:: and the resulting `Op`'s perform will do the same operation as::
orig_function(inputs, outputs, **kwargs) orig_function(inputs, outputs, **kwargs)
...@@ -139,8 +139,8 @@ class OpFromGraph(Op): ...@@ -139,8 +139,8 @@ class OpFromGraph(Op):
Must return list of :class:`Variable <aesara.graph.basic.Variable>`. Must return list of :class:`Variable <aesara.graph.basic.Variable>`.
Variable : Variable :
``NullType() instance`` : Treat as non-differentiable `NullType` instance: Treat as non-differentiable
``DisconnectedType() instance`` : Treat as disconnected gradient, numerically gives zero `DisconnectedType` instance: Treat as disconnected gradient, numerically gives zero
list: Each OpFromGraph/callable must return a single list: Each OpFromGraph/callable must return a single
:class:`Variable <aesara.graph.basic.Variable>`. Each list element corresponds to gradient of :class:`Variable <aesara.graph.basic.Variable>`. Each list element corresponds to gradient of
...@@ -160,8 +160,8 @@ class OpFromGraph(Op): ...@@ -160,8 +160,8 @@ class OpFromGraph(Op):
Must return list of :class:`Variable <aesara.graph.basic.Variable>`. Must return list of :class:`Variable <aesara.graph.basic.Variable>`.
Variable : Variable :
``NullType() instance`` : Treat as non-differentiable `NullType` instance: Treat as non-differentiable
``DisconnectedType() instance`` : Treat as zero since DisconnectedType is not yet supported in R_op `DisconnectedType` instance: Treat as zero since DisconnectedType is not yet supported in R_op
list: Each OpFromGraph/callable must return a single list: Each OpFromGraph/callable must return a single
:class:`Variable <aesara.graph.basic.Variable>`. Each list element corresponds :class:`Variable <aesara.graph.basic.Variable>`. Each list element corresponds
...@@ -363,8 +363,8 @@ class OpFromGraph(Op): ...@@ -363,8 +363,8 @@ class OpFromGraph(Op):
assert not update_expr assert not update_expr
assert not shared_inputs assert not shared_inputs
self.local_inputs = local_inputs self._inner_inputs = local_inputs
self.local_outputs = local_outputs self._inner_outputs = local_outputs
self.inputs = inputs self.inputs = inputs
self.outputs = outputs self.outputs = outputs
self.kwargs = kwargs self.kwargs = kwargs
...@@ -411,8 +411,8 @@ class OpFromGraph(Op): ...@@ -411,8 +411,8 @@ class OpFromGraph(Op):
converts self._lop_op from user supplied form to type(self) instance converts self._lop_op from user supplied form to type(self) instance
""" """
local_inputs = self.local_inputs local_inputs = self.inner_inputs
local_outputs = self.local_outputs local_outputs = self.inner_outputs
inp_len = len(local_inputs) inp_len = len(local_inputs)
lop_op = self._lop_op lop_op = self._lop_op
...@@ -424,7 +424,7 @@ class OpFromGraph(Op): ...@@ -424,7 +424,7 @@ class OpFromGraph(Op):
) )
if self._lop_type == "grad": if self._lop_type == "grad":
needed_ninps = inp_len + len(local_outputs) needed_ninps = inp_len + len(local_outputs)
ninps = len(lop_op.local_inputs) ninps = len(lop_op.inner_inputs)
if needed_ninps != ninps: if needed_ninps != ninps:
raise ValueError(self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps)) raise ValueError(self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps))
# make a wrapper callable # make a wrapper callable
...@@ -435,7 +435,7 @@ class OpFromGraph(Op): ...@@ -435,7 +435,7 @@ class OpFromGraph(Op):
elif self._lop_type == "lop": elif self._lop_type == "lop":
# OfG can be directly used in L_op format # OfG can be directly used in L_op format
needed_ninps = inp_len + 2 * len(local_outputs) needed_ninps = inp_len + 2 * len(local_outputs)
ninps = len(lop_op.local_inputs) ninps = len(lop_op.inner_inputs)
if needed_ninps != ninps: if needed_ninps != ninps:
raise ValueError(self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps)) raise ValueError(self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps))
self._lop_op_is_cached = True self._lop_op_is_cached = True
...@@ -551,8 +551,8 @@ class OpFromGraph(Op): ...@@ -551,8 +551,8 @@ class OpFromGraph(Op):
converts self._rop_op from user supplied form to type(self) instance converts self._rop_op from user supplied form to type(self) instance
""" """
local_inputs = self.local_inputs local_inputs = self.inner_inputs
local_outputs = self.local_outputs local_outputs = self.inner_outputs
out_len = len(local_outputs) out_len = len(local_outputs)
rop_op = self._rop_op rop_op = self._rop_op
...@@ -728,7 +728,7 @@ class OpFromGraph(Op): ...@@ -728,7 +728,7 @@ class OpFromGraph(Op):
return ret_l return ret_l
def make_node(self, *inputs): def make_node(self, *inputs):
num_expected_inps = len(self.local_inputs) - len(self.shared_inputs) num_expected_inps = len(self.inner_inputs) - len(self.shared_inputs)
if len(inputs) != num_expected_inps: if len(inputs) != num_expected_inps:
raise ValueError( raise ValueError(
f"Expected {int(num_expected_inps)} inputs, got {len(inputs)}" f"Expected {int(num_expected_inps)} inputs, got {len(inputs)}"
...@@ -741,8 +741,6 @@ class OpFromGraph(Op): ...@@ -741,8 +741,6 @@ class OpFromGraph(Op):
list(inputs) + self.shared_inputs, list(inputs) + self.shared_inputs,
[type() for type in self.output_types], [type() for type in self.output_types],
) )
apply_node.local_inputs = self.local_inputs
apply_node.local_outputs = self.local_outputs
return apply_node return apply_node
def connection_pattern(self, node): def connection_pattern(self, node):
...@@ -753,13 +751,13 @@ class OpFromGraph(Op): ...@@ -753,13 +751,13 @@ class OpFromGraph(Op):
if self._connection_pattern is not None: if self._connection_pattern is not None:
return self._connection_pattern return self._connection_pattern
inp_len = len(self.local_inputs) inp_len = len(self.inner_inputs)
out_len = len(self.local_outputs) out_len = len(self.inner_outputs)
cpmat_self = io_connection_pattern(self.local_inputs, self.local_outputs) cpmat_self = io_connection_pattern(self.inner_inputs, self.inner_outputs)
lop_op = self.get_lop_op() lop_op = self.get_lop_op()
cpmat_grad = io_connection_pattern( cpmat_grad = io_connection_pattern(
lop_op.local_inputs[inp_len:], lop_op.local_outputs lop_op.inner_inputs[inp_len:], lop_op.inner_outputs
) )
# cpmat_self |= cpmat_grad.T # cpmat_self |= cpmat_grad.T
...@@ -781,7 +779,7 @@ class OpFromGraph(Op): ...@@ -781,7 +779,7 @@ class OpFromGraph(Op):
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
# TODO: Use `fgraph.shape_feature` to do this instead. # TODO: Use `fgraph.shape_feature` to do this instead.
out_shapes = infer_shape(self.local_outputs, self.local_inputs, shapes) out_shapes = infer_shape(self.inner_outputs, self.inner_inputs, shapes)
# Clone the output shape so that shape are computed from outer inputs. # Clone the output shape so that shape are computed from outer inputs.
# Note: # Note:
...@@ -791,7 +789,7 @@ class OpFromGraph(Op): ...@@ -791,7 +789,7 @@ class OpFromGraph(Op):
# each shape call. Aesara optimizer will clean this up later, but this # each shape call. Aesara optimizer will clean this up later, but this
# will make extra work for the optimizer. # will make extra work for the optimizer.
repl = dict(zip(self.local_inputs, node.inputs)) repl = dict(zip(self.inner_inputs, node.inputs))
clone_out_shapes = [s for s in out_shapes if isinstance(s, tuple)] clone_out_shapes = [s for s in out_shapes if isinstance(s, tuple)]
cloned = clone_replace(sum(clone_out_shapes, ()), replace=repl) cloned = clone_replace(sum(clone_out_shapes, ()), replace=repl)
ret = [] ret = []
...@@ -806,12 +804,24 @@ class OpFromGraph(Op): ...@@ -806,12 +804,24 @@ class OpFromGraph(Op):
return ret return ret
def prepare_node(self, node, storage_map, compute_map, impl): @property
if not hasattr(self, "fn") and impl == "py": def fn(self):
self.fn = orig_function( """Lazily compile the inner function graph."""
self.local_inputs, self.local_outputs, **self.kwargs if getattr(self, "_fn", None) is not None:
) return self._fn
self.fn.trust_input = True
self._fn = orig_function(self.inner_inputs, self.inner_outputs, **self.kwargs)
self._fn.trust_input = True
return self._fn
@property
def inner_inputs(self):
return self._inner_inputs
@property
def inner_outputs(self):
return self._inner_outputs
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
variables = self.fn(*inputs) variables = self.fn(*inputs)
...@@ -833,7 +843,7 @@ def inline_ofg_expansion(fgraph, node): ...@@ -833,7 +843,7 @@ def inline_ofg_expansion(fgraph, node):
if not op.is_inline: if not op.is_inline:
return False return False
return clone_replace( return clone_replace(
op.local_outputs, {u: v for u, v in zip(node.op.local_inputs, node.inputs)} op.inner_outputs, {u: v for u, v in zip(op.inner_inputs, node.inputs)}
) )
...@@ -846,7 +856,3 @@ optdb.register( ...@@ -846,7 +856,3 @@ optdb.register(
"fast_compile", "fast_compile",
"fast_run", "fast_run",
) )
# Since OpFromGraph contains an Aesara compiled function,
# we should let DebugMode know about it
ops_with_inner_function[OpFromGraph] = "fn"
...@@ -32,7 +32,7 @@ from aesara.graph.basic import Variable, graph_inputs, io_toposort ...@@ -32,7 +32,7 @@ from aesara.graph.basic import Variable, graph_inputs, io_toposort
from aesara.graph.destroyhandler import DestroyHandler from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import BadOptimization from aesara.graph.features import BadOptimization
from aesara.graph.fg import InconsistencyError from aesara.graph.fg import InconsistencyError
from aesara.graph.op import COp, Op, ops_with_inner_function from aesara.graph.op import COp, HasInnerGraph, Op
from aesara.graph.utils import MethodNotDefined from aesara.graph.utils import MethodNotDefined
from aesara.link.basic import Container, LocalLinker from aesara.link.basic import Container, LocalLinker
from aesara.link.utils import map_storage, raise_with_op from aesara.link.utils import map_storage, raise_with_op
...@@ -1104,13 +1104,10 @@ def _check_preallocated_output( ...@@ -1104,13 +1104,10 @@ def _check_preallocated_output(
# disable memory checks in that mode, since they were already run. # disable memory checks in that mode, since they were already run.
try: try:
changed_inner_mode = False changed_inner_mode = False
if type(getattr(node, "op", None)) in ops_with_inner_function: if isinstance(getattr(node, "op", None), HasInnerGraph):
fn_attr_name = ops_with_inner_function[type(node.op)] fn = node.op.fn
fn = getattr(node.op, fn_attr_name, None)
if not fn or not hasattr(fn, "maker") or not hasattr(fn.maker, "mode"): if not fn or not hasattr(fn, "maker") or not hasattr(fn.maker, "mode"):
_logger.warning( _logger.warning(f"Expected aesara function not found in {node.op}.fn")
f"Expected aesara function not found in {node.op}.{fn_attr_name}"
)
else: else:
if isinstance(fn.maker.mode, DebugMode): if isinstance(fn.maker.mode, DebugMode):
backup_mode = fn.maker.mode backup_mode = fn.maker.mode
......
...@@ -32,7 +32,7 @@ from aesara.graph.basic import ( ...@@ -32,7 +32,7 @@ from aesara.graph.basic import (
from aesara.graph.destroyhandler import DestroyHandler from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import PreserveVariableAttributes from aesara.graph.features import PreserveVariableAttributes
from aesara.graph.fg import FunctionGraph, InconsistencyError from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.op import ops_with_inner_function from aesara.graph.op import HasInnerGraph
from aesara.graph.opt_utils import is_same_graph from aesara.graph.opt_utils import is_same_graph
from aesara.graph.utils import get_variable_trace_string from aesara.graph.utils import get_variable_trace_string
from aesara.link.basic import Container from aesara.link.basic import Container
...@@ -548,7 +548,7 @@ class Function: ...@@ -548,7 +548,7 @@ class Function:
self.n_returned_outputs -= 1 self.n_returned_outputs -= 1
for node in self.maker.fgraph.apply_nodes: for node in self.maker.fgraph.apply_nodes:
if node.op in ops_with_inner_function: if isinstance(node.op, HasInnerGraph):
self.nodes_with_inner_function.append(node.op) self.nodes_with_inner_function.append(node.op)
def __contains__(self, item): def __contains__(self, item):
...@@ -1099,7 +1099,8 @@ class Function: ...@@ -1099,7 +1099,8 @@ class Function:
self.fn.storage_map[key][0] = None self.fn.storage_map[key][0] = None
for node in self.nodes_with_inner_function: for node in self.nodes_with_inner_function:
ops_with_inner_function[node.op].free() if hasattr(node.fn, "free"):
node.fn.free()
def get_shared(self): def get_shared(self):
""" """
......
...@@ -232,7 +232,6 @@ class PyDotFormatter: ...@@ -232,7 +232,6 @@ class PyDotFormatter:
gf = PyDotFormatter() gf = PyDotFormatter()
# Use different node prefix for sub-graphs # Use different node prefix for sub-graphs
gf.__node_prefix = __node_id gf.__node_prefix = __node_id
node.op.prepare_node(node, None, None, "py")
gf(node.op.fn, subgraph) gf(node.op.fn, subgraph)
graph.add_subgraph(subgraph) graph.add_subgraph(subgraph)
pd_node.get_attributes()["subg"] = subgraph.get_name() pd_node.get_attributes()["subg"] = subgraph.get_name()
...@@ -242,14 +241,14 @@ class PyDotFormatter: ...@@ -242,14 +241,14 @@ class PyDotFormatter:
# Inputs mapping # Inputs mapping
ext_inputs = [self.__node_id(x) for x in node.inputs] ext_inputs = [self.__node_id(x) for x in node.inputs]
int_inputs = [gf.__node_id(x) for x in node.op.local_inputs] int_inputs = [gf.__node_id(x) for x in node.op.inner_inputs]
assert len(ext_inputs) == len(int_inputs) assert len(ext_inputs) == len(int_inputs)
h = format_map(zip(ext_inputs, int_inputs)) h = format_map(zip(ext_inputs, int_inputs))
pd_node.get_attributes()["subg_map_inputs"] = h pd_node.get_attributes()["subg_map_inputs"] = h
# Outputs mapping # Outputs mapping
ext_outputs = [self.__node_id(x) for x in node.outputs] ext_outputs = [self.__node_id(x) for x in node.outputs]
int_outputs = [gf.__node_id(x) for x in node.op.local_outputs] int_outputs = [gf.__node_id(x) for x in node.op.inner_outputs]
assert len(ext_outputs) == len(int_outputs) assert len(ext_outputs) == len(int_outputs)
h = format_map(zip(int_outputs, ext_outputs)) h = format_map(zip(int_outputs, ext_outputs))
pd_node.get_attributes()["subg_map_outputs"] = h pd_node.get_attributes()["subg_map_outputs"] = h
......
...@@ -13,6 +13,7 @@ import sys ...@@ -13,6 +13,7 @@ import sys
import warnings import warnings
from abc import abstractmethod from abc import abstractmethod
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Callable, Callable,
ClassVar, ClassVar,
...@@ -43,6 +44,9 @@ from aesara.graph.utils import ( ...@@ -43,6 +44,9 @@ from aesara.graph.utils import (
from aesara.link.c.interface import CLinkerOp from aesara.link.c.interface import CLinkerOp
if TYPE_CHECKING:
from aesara.compile.function.types import Function
StorageMapType = List[Optional[List[Any]]] StorageMapType = List[Optional[List[Any]]]
ComputeMapType = List[bool] ComputeMapType = List[bool]
OutputStorageType = List[Optional[List[Any]]] OutputStorageType = List[Optional[List[Any]]]
...@@ -574,6 +578,25 @@ class Op(MetaObject): ...@@ -574,6 +578,25 @@ class Op(MetaObject):
return getattr(type(self), "__name__", super().__str__()) return getattr(type(self), "__name__", super().__str__())
class HasInnerGraph:
r"""A mixin for an `Op` that contain an inner graph."""
@property
@abstractmethod
def fn(self) -> "Function":
"""The inner function."""
@property
@abstractmethod
def inner_inputs(self) -> List[Variable]:
"""The inner function's inputs."""
@property
@abstractmethod
def inner_outputs(self) -> List[Variable]:
"""The inner function's outputs."""
class COp(Op, CLinkerOp): class COp(Op, CLinkerOp):
"""An `Op` with a C implementation.""" """An `Op` with a C implementation."""
...@@ -767,22 +790,6 @@ def get_test_values(*args: Variable) -> Union[Any, List[Any]]: ...@@ -767,22 +790,6 @@ def get_test_values(*args: Variable) -> Union[Any, List[Any]]:
return [tuple(rval)] return [tuple(rval)]
ops_with_inner_function: Dict[Op, Text] = {}
r"""
Registry of `Op`\s that have an inner compiled Aesara function.
The keys are `Op` classes (not instances), and values are the name of the
attribute that contains the function. For instance, if the function is
``self.fn``, the value will be ``'fn'``.
We need that to be able not to run debug checks a number of times that is
exponential in the nesting level of those `Op`\s.
For instance, `Scan` will be registered here.
"""
class OpenMPOp(COp): class OpenMPOp(COp):
r"""Base class for `Op`\s using OpenMP. r"""Base class for `Op`\s using OpenMP.
......
...@@ -73,7 +73,7 @@ from aesara.graph.basic import ( ...@@ -73,7 +73,7 @@ from aesara.graph.basic import (
) )
from aesara.graph.features import NoOutputFromInplace from aesara.graph.features import NoOutputFromInplace
from aesara.graph.fg import MissingInputError from aesara.graph.fg import MissingInputError
from aesara.graph.op import Op, ops_with_inner_function from aesara.graph.op import HasInnerGraph, Op
from aesara.link.c.basic import CLinker from aesara.link.c.basic import CLinker
from aesara.link.c.exceptions import MissingGXX from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import raise_with_op from aesara.link.utils import raise_with_op
...@@ -570,7 +570,7 @@ class ScanMethodsMixin: ...@@ -570,7 +570,7 @@ class ScanMethodsMixin:
) )
class Scan(Op, ScanMethodsMixin): class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def __init__( def __init__(
self, self,
inputs: List[Variable], inputs: List[Variable],
...@@ -3126,11 +3126,6 @@ class Scan(Op, ScanMethodsMixin): ...@@ -3126,11 +3126,6 @@ class Scan(Op, ScanMethodsMixin):
return final_outs return final_outs
# Since Scan is an op that contains an Aesara compiled function, it is
# useful to let DebugMode know about it.
ops_with_inner_function[Scan] = "fn"
@register_profiler_printer @register_profiler_printer
def profile_printer( def profile_printer(
message, compile_time, fct_call_time, apply_time, apply_cimpl, outputs_size, file message, compile_time, fct_call_time, apply_time, apply_cimpl, outputs_size, file
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论