Commit c95acebd, authored by Brandon T. Willard, committed by Brandon T. Willard

Create HasInnerGraph mixin

Parent 8f692472
......@@ -18,7 +18,7 @@ from aesara.graph.basic import (
)
from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType
from aesara.graph.op import Op, ops_with_inner_function
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.opt import in2out, local_optimizer
from aesara.tensor.basic_opt import ShapeFeature
......@@ -76,11 +76,11 @@ def infer_shape(outs, inputs, input_shapes):
return ret
class OpFromGraph(Op):
class OpFromGraph(Op, HasInnerGraph):
r"""
This creates an ``Op`` from inputs and outputs lists of variables.
This creates an `Op` from inputs and outputs lists of variables.
The signature is similar to :func:`aesara.function <aesara.function>`
and the resulting ``Op``'s perform will do the same operation as::
and the resulting `Op`'s perform will do the same operation as::
orig_function(inputs, outputs, **kwargs)
......@@ -139,8 +139,8 @@ class OpFromGraph(Op):
Must return list of :class:`Variable <aesara.graph.basic.Variable>`.
Variable :
``NullType() instance`` : Treat as non-differentiable
``DisconnectedType() instance`` : Treat as disconnected gradient, numerically gives zero
`NullType` instance: Treat as non-differentiable
`DisconnectedType` instance: Treat as disconnected gradient, numerically gives zero
list: Each OpFromGraph/callable must return a single
:class:`Variable <aesara.graph.basic.Variable>`. Each list element corresponds to gradient of
......@@ -160,8 +160,8 @@ class OpFromGraph(Op):
Must return list of :class:`Variable <aesara.graph.basic.Variable>`.
Variable :
``NullType() instance`` : Treat as non-differentiable
``DisconnectedType() instance`` : Treat as zero since DisconnectedType is not yet supported in R_op
`NullType` instance: Treat as non-differentiable
`DisconnectedType` instance: Treat as zero since DisconnectedType is not yet supported in R_op
list: Each OpFromGraph/callable must return a single
:class:`Variable <aesara.graph.basic.Variable>`. Each list element corresponds
......@@ -363,8 +363,8 @@ class OpFromGraph(Op):
assert not update_expr
assert not shared_inputs
self.local_inputs = local_inputs
self.local_outputs = local_outputs
self._inner_inputs = local_inputs
self._inner_outputs = local_outputs
self.inputs = inputs
self.outputs = outputs
self.kwargs = kwargs
......@@ -411,8 +411,8 @@ class OpFromGraph(Op):
converts self._lop_op from user supplied form to type(self) instance
"""
local_inputs = self.local_inputs
local_outputs = self.local_outputs
local_inputs = self.inner_inputs
local_outputs = self.inner_outputs
inp_len = len(local_inputs)
lop_op = self._lop_op
......@@ -424,7 +424,7 @@ class OpFromGraph(Op):
)
if self._lop_type == "grad":
needed_ninps = inp_len + len(local_outputs)
ninps = len(lop_op.local_inputs)
ninps = len(lop_op.inner_inputs)
if needed_ninps != ninps:
raise ValueError(self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps))
# make a wrapper callable
......@@ -435,7 +435,7 @@ class OpFromGraph(Op):
elif self._lop_type == "lop":
# OfG can be directly used in L_op format
needed_ninps = inp_len + 2 * len(local_outputs)
ninps = len(lop_op.local_inputs)
ninps = len(lop_op.inner_inputs)
if needed_ninps != ninps:
raise ValueError(self.OV_INP_LEN_ERR_MSG % (needed_ninps, ninps))
self._lop_op_is_cached = True
......@@ -551,8 +551,8 @@ class OpFromGraph(Op):
converts self._rop_op from user supplied form to type(self) instance
"""
local_inputs = self.local_inputs
local_outputs = self.local_outputs
local_inputs = self.inner_inputs
local_outputs = self.inner_outputs
out_len = len(local_outputs)
rop_op = self._rop_op
......@@ -728,7 +728,7 @@ class OpFromGraph(Op):
return ret_l
def make_node(self, *inputs):
num_expected_inps = len(self.local_inputs) - len(self.shared_inputs)
num_expected_inps = len(self.inner_inputs) - len(self.shared_inputs)
if len(inputs) != num_expected_inps:
raise ValueError(
f"Expected {int(num_expected_inps)} inputs, got {len(inputs)}"
......@@ -741,8 +741,6 @@ class OpFromGraph(Op):
list(inputs) + self.shared_inputs,
[type() for type in self.output_types],
)
apply_node.local_inputs = self.local_inputs
apply_node.local_outputs = self.local_outputs
return apply_node
def connection_pattern(self, node):
......@@ -753,13 +751,13 @@ class OpFromGraph(Op):
if self._connection_pattern is not None:
return self._connection_pattern
inp_len = len(self.local_inputs)
out_len = len(self.local_outputs)
cpmat_self = io_connection_pattern(self.local_inputs, self.local_outputs)
inp_len = len(self.inner_inputs)
out_len = len(self.inner_outputs)
cpmat_self = io_connection_pattern(self.inner_inputs, self.inner_outputs)
lop_op = self.get_lop_op()
cpmat_grad = io_connection_pattern(
lop_op.local_inputs[inp_len:], lop_op.local_outputs
lop_op.inner_inputs[inp_len:], lop_op.inner_outputs
)
# cpmat_self |= cpmat_grad.T
......@@ -781,7 +779,7 @@ class OpFromGraph(Op):
def infer_shape(self, fgraph, node, shapes):
# TODO: Use `fgraph.shape_feature` to do this instead.
out_shapes = infer_shape(self.local_outputs, self.local_inputs, shapes)
out_shapes = infer_shape(self.inner_outputs, self.inner_inputs, shapes)
# Clone the output shape so that shape are computed from outer inputs.
# Note:
......@@ -791,7 +789,7 @@ class OpFromGraph(Op):
# each shape call. Aesara optimizer will clean this up later, but this
# will make extra work for the optimizer.
repl = dict(zip(self.local_inputs, node.inputs))
repl = dict(zip(self.inner_inputs, node.inputs))
clone_out_shapes = [s for s in out_shapes if isinstance(s, tuple)]
cloned = clone_replace(sum(clone_out_shapes, ()), replace=repl)
ret = []
......@@ -806,12 +804,24 @@ class OpFromGraph(Op):
return ret
def prepare_node(self, node, storage_map, compute_map, impl):
if not hasattr(self, "fn") and impl == "py":
self.fn = orig_function(
self.local_inputs, self.local_outputs, **self.kwargs
)
self.fn.trust_input = True
@property
def fn(self):
    """Lazily compile and cache the inner function graph.

    The compiled function is built from ``self.inner_inputs`` and
    ``self.inner_outputs`` on first access, stored on ``self._fn``, and
    returned directly on every subsequent access.
    """
    cached = getattr(self, "_fn", None)
    if cached is None:
        cached = orig_function(self.inner_inputs, self.inner_outputs, **self.kwargs)
        # Inputs come straight from our own graph, so skip input validation.
        cached.trust_input = True
        self._fn = cached
    return cached
@property
def inner_inputs(self):
    """The input variables of the inner graph (read-only view of ``_inner_inputs``)."""
    return self._inner_inputs
@property
def inner_outputs(self):
    """The output variables of the inner graph (read-only view of ``_inner_outputs``)."""
    return self._inner_outputs
def perform(self, node, inputs, outputs):
variables = self.fn(*inputs)
......@@ -833,7 +843,7 @@ def inline_ofg_expansion(fgraph, node):
if not op.is_inline:
return False
return clone_replace(
op.local_outputs, {u: v for u, v in zip(node.op.local_inputs, node.inputs)}
op.inner_outputs, {u: v for u, v in zip(op.inner_inputs, node.inputs)}
)
......@@ -846,7 +856,3 @@ optdb.register(
"fast_compile",
"fast_run",
)
# Since OpFromGraph contains an Aesara compiled function,
# we should let DebugMode know about it
ops_with_inner_function[OpFromGraph] = "fn"
......@@ -32,7 +32,7 @@ from aesara.graph.basic import Variable, graph_inputs, io_toposort
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import BadOptimization
from aesara.graph.fg import InconsistencyError
from aesara.graph.op import COp, Op, ops_with_inner_function
from aesara.graph.op import COp, HasInnerGraph, Op
from aesara.graph.utils import MethodNotDefined
from aesara.link.basic import Container, LocalLinker
from aesara.link.utils import map_storage, raise_with_op
......@@ -1104,13 +1104,10 @@ def _check_preallocated_output(
# disable memory checks in that mode, since they were already run.
try:
changed_inner_mode = False
if type(getattr(node, "op", None)) in ops_with_inner_function:
fn_attr_name = ops_with_inner_function[type(node.op)]
fn = getattr(node.op, fn_attr_name, None)
if isinstance(getattr(node, "op", None), HasInnerGraph):
fn = node.op.fn
if not fn or not hasattr(fn, "maker") or not hasattr(fn.maker, "mode"):
_logger.warning(
f"Expected aesara function not found in {node.op}.{fn_attr_name}"
)
_logger.warning(f"Expected aesara function not found in {node.op}.fn")
else:
if isinstance(fn.maker.mode, DebugMode):
backup_mode = fn.maker.mode
......
......@@ -32,7 +32,7 @@ from aesara.graph.basic import (
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import PreserveVariableAttributes
from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.op import ops_with_inner_function
from aesara.graph.op import HasInnerGraph
from aesara.graph.opt_utils import is_same_graph
from aesara.graph.utils import get_variable_trace_string
from aesara.link.basic import Container
......@@ -548,7 +548,7 @@ class Function:
self.n_returned_outputs -= 1
for node in self.maker.fgraph.apply_nodes:
if node.op in ops_with_inner_function:
if isinstance(node.op, HasInnerGraph):
self.nodes_with_inner_function.append(node.op)
def __contains__(self, item):
......@@ -1099,7 +1099,8 @@ class Function:
self.fn.storage_map[key][0] = None
for node in self.nodes_with_inner_function:
ops_with_inner_function[node.op].free()
if hasattr(node.fn, "free"):
node.fn.free()
def get_shared(self):
"""
......
......@@ -232,7 +232,6 @@ class PyDotFormatter:
gf = PyDotFormatter()
# Use different node prefix for sub-graphs
gf.__node_prefix = __node_id
node.op.prepare_node(node, None, None, "py")
gf(node.op.fn, subgraph)
graph.add_subgraph(subgraph)
pd_node.get_attributes()["subg"] = subgraph.get_name()
......@@ -242,14 +241,14 @@ class PyDotFormatter:
# Inputs mapping
ext_inputs = [self.__node_id(x) for x in node.inputs]
int_inputs = [gf.__node_id(x) for x in node.op.local_inputs]
int_inputs = [gf.__node_id(x) for x in node.op.inner_inputs]
assert len(ext_inputs) == len(int_inputs)
h = format_map(zip(ext_inputs, int_inputs))
pd_node.get_attributes()["subg_map_inputs"] = h
# Outputs mapping
ext_outputs = [self.__node_id(x) for x in node.outputs]
int_outputs = [gf.__node_id(x) for x in node.op.local_outputs]
int_outputs = [gf.__node_id(x) for x in node.op.inner_outputs]
assert len(ext_outputs) == len(int_outputs)
h = format_map(zip(int_outputs, ext_outputs))
pd_node.get_attributes()["subg_map_outputs"] = h
......
......@@ -13,6 +13,7 @@ import sys
import warnings
from abc import abstractmethod
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
......@@ -43,6 +44,9 @@ from aesara.graph.utils import (
from aesara.link.c.interface import CLinkerOp
if TYPE_CHECKING:
from aesara.compile.function.types import Function
StorageMapType = List[Optional[List[Any]]]
ComputeMapType = List[bool]
OutputStorageType = List[Optional[List[Any]]]
......@@ -574,6 +578,25 @@ class Op(MetaObject):
return getattr(type(self), "__name__", super().__str__())
class HasInnerGraph:
    r"""A mixin for an `Op` that contains an inner graph.

    Implementors must provide a compiled inner function (`fn`) plus the
    inner graph's input and output variables (`inner_inputs`,
    `inner_outputs`).
    """
    # NOTE(review): this mixin does not inherit from `abc.ABC`, so
    # `@abstractmethod` is only enforced if the concrete subclass's
    # metaclass derives from `ABCMeta` — confirm against `Op`'s metaclass.

    @property
    @abstractmethod
    def fn(self) -> "Function":
        """The compiled inner function."""

    @property
    @abstractmethod
    def inner_inputs(self) -> List[Variable]:
        """The inner function's input variables."""

    @property
    @abstractmethod
    def inner_outputs(self) -> List[Variable]:
        """The inner function's output variables."""
class COp(Op, CLinkerOp):
"""An `Op` with a C implementation."""
......@@ -767,22 +790,6 @@ def get_test_values(*args: Variable) -> Union[Any, List[Any]]:
return [tuple(rval)]
ops_with_inner_function: Dict[Op, Text] = {}
r"""
Registry of `Op`\s that have an inner compiled Aesara function.
The keys are `Op` classes (not instances), and values are the name of the
attribute that contains the function. For instance, if the function is
``self.fn``, the value will be ``'fn'``.
We need that to be able not to run debug checks a number of times that is
exponential in the nesting level of those `Op`\s.
For instance, `Scan` will be registered here.
"""
class OpenMPOp(COp):
r"""Base class for `Op`\s using OpenMP.
......
......@@ -73,7 +73,7 @@ from aesara.graph.basic import (
)
from aesara.graph.features import NoOutputFromInplace
from aesara.graph.fg import MissingInputError
from aesara.graph.op import Op, ops_with_inner_function
from aesara.graph.op import HasInnerGraph, Op
from aesara.link.c.basic import CLinker
from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import raise_with_op
......@@ -570,7 +570,7 @@ class ScanMethodsMixin:
)
class Scan(Op, ScanMethodsMixin):
class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def __init__(
self,
inputs: List[Variable],
......@@ -3126,11 +3126,6 @@ class Scan(Op, ScanMethodsMixin):
return final_outs
# Since Scan is an op that contains an Aesara compiled function, it is
# useful to let DebugMode know about it.
ops_with_inner_function[Scan] = "fn"
@register_profiler_printer
def profile_printer(
message, compile_time, fct_call_time, apply_time, apply_cimpl, outputs_size, file
......
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment