提交 ff4b91dd authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add type hints and move docstring to OpFromGraph constructor

上级 d7a02fa1
"""Define new Ops from existing Ops"""
from collections import OrderedDict
from functools import partial
from typing import List, Optional
import aesara.tensor as aet
from aesara.compile.function.pfunc import rebuild_collect_shared
......@@ -87,101 +88,6 @@ class OpFromGraph(Op, HasInnerGraph):
Currently does not support ``updates`` or ``givens`` argument.
Parameters
----------
inputs: list of :class:`Variable <aesara.graph.basic.Variable>`
outputs: list of :class:`Variable <aesara.graph.basic.Variable>`
inline: bool, optional
Defaults to ``False``
``True`` : Cause the Op's original graph being used during
compilation, the Op will not be visible in the compiled
graph but rather its internal graph.
``False`` : will use a pre-compiled function inside.
grad_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
Defaults to ``'default'``.
This argument is mutually exclusive with lop_overrides.
``'default'`` : Do not override, use default grad() result
OpFromGraph instance : Override with another OpFromGraph, should
accept inputs as the same order and types of ``inputs`` and ``output_grads``
arguments as one would specify in grad() method.
callable : Should take two args: ``inputs`` and ``output_grads``.
Each argument is expected to be a list of :class:`Variable <aesara.graph.basic.Variable>`.
Must return list of :class:`Variable <aesara.graph.basic.Variable>`.
Variable :
``NullType() instance`` : Treat as non-differentiable
``DisconnectedType() instance`` : Treat as disconnected gradient, numerically gives zero
list: Each OpFromGraph/callable must return a single
:class:`Variable <aesara.graph.basic.Variable>`. Each list element corresponds to gradient of
a specific input, length of list must be equal to number of inputs.
lop_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
Defaults to ``'default'``.
This argument is mutually exclusive with ``grad_overrides``.
``'default'`` : Do not override, use default L_op() result
OpFromGraph instance : Override with another OpFromGraph, should
accept inputs as the same order and types of ``inputs``, ``outputs`` and ``output_grads``
arguments as one would specify in grad() method.
callable : Should take three args: ``inputs``, ``outputs`` and ``output_grads``.
Each argument is expected to be a list of :class:`Variable <aesara.graph.basic.Variable>`.
Must return list of :class:`Variable <aesara.graph.basic.Variable>`.
Variable :
`NullType` instance: Treat as non-differentiable
`DisconnectedType` instance: Treat as disconnected gradient, numerically gives zero
list: Each OpFromGraph/callable must return a single
:class:`Variable <aesara.graph.basic.Variable>`. Each list element corresponds to gradient of
a specific input, length of list must be equal to number of inputs.
rop_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
Defaults to ``default``.
``'default'`` : Do not override, use default R_op() result
OpFromGraph instance : Override with another OpFromGraph, should
accept inputs as the same order and types of ``inputs`` and ``eval_points``
arguments as one would specify in R_op() method.
callable : Should take two args: ``inputs`` and ``eval_points``.
Each argument is expected to be a list of :class:`Variable <aesara.graph.basic.Variable>`.
Must return list of :class:`Variable <aesara.graph.basic.Variable>`.
Variable :
`NullType` instance: Treat as non-differentiable
`DisconnectedType` instance: Treat as zero since DisconnectedType is not yet supported in R_op
list: Each OpFromGraph/callable must return a single
:class:`Variable <aesara.graph.basic.Variable>`. Each list element corresponds
to a specific output of R_op, length of list must be equal to number of outputs.
connection_pattern : list of list
If not ``None``, this will be used as the connection_pattern
for this op.
name : string, optional
A name for debugging purposes
\*\*kwargs : optional
Check
:func:`orig_function <aesara.compile.function.types.orig_function>`
for more arguments, only works when not inline.
.. TODO:
- examples for a multi-layer mlp. where?
- __hash__, __eq__ otherwise won't merge, try
......@@ -322,16 +228,103 @@ class OpFromGraph(Op, HasInnerGraph):
def __init__(
self,
inputs,
outputs,
inline=False,
lop_overrides="default",
grad_overrides="default",
rop_overrides="default",
connection_pattern=None,
name=None,
inputs: List[Variable],
outputs: List[Variable],
inline: bool = False,
lop_overrides: str = "default",
grad_overrides: str = "default",
rop_overrides: str = "default",
connection_pattern: Optional[List[List[bool]]] = None,
name: Optional[str] = None,
**kwargs,
):
"""
Parameters
----------
inputs
The inputs to the graph.
outputs
The outputs to the graph.
inline
Defaults to ``False``
``True`` : Cause the :class:`Op`'s original graph being used during
compilation, the :class:`Op` will not be visible in the compiled
graph but rather its internal graph.
``False`` : will use a pre-compiled function inside.
grad_overrides
Defaults to ``'default'``.
This argument is mutually exclusive with ``lop_overrides``.
``'default'`` : Do not override, use default grad() result
`OpFromGraph`: Override with another `OpFromGraph`, should
accept inputs as the same order and types of ``inputs`` and ``output_grads``
arguments as one would specify in :meth:`Op.grad`() method.
`callable`: Should take two args: ``inputs`` and ``output_grads``.
Each argument is expected to be a list of :class:`Variable `.
Must return list of :class:`Variable `.
lop_overrides
Defaults to ``'default'``.
This argument is mutually exclusive with ``grad_overrides``.
These options are similar to the ``grad_overrides`` above, but for
the :meth:`Op.L_op` method.
``'default'``: Do not override, use the default :meth:`Op.L_op` result
`OpFromGraph`: Override with another `OpFromGraph`, should
accept inputs as the same order and types of ``inputs``,
``outputs`` and ``output_grads`` arguments as one would specify in
:meth:`Op.grad` method.
`callable`: Should take three args: ``inputs``, ``outputs`` and ``output_grads``.
Each argument is expected to be a list of :class:`Variable`.
Must return list of :class:`Variable`.
`NullType` instance: Treat as non-differentiable
`DisconnectedType` instance: Treat as disconnected gradient,
numerically gives zero
``list``: Each `OpFromGraph`/callable must return a single
:class:`Variable`. Each list element corresponds to gradient of
a specific input, length of list must be equal to number of inputs.
rop_overrides
One of ``{'default', OpFromGraph, callable, Variable}``.
Defaults to ``'default'``.
``'default'``: Do not override, use the default :meth:`Op.R_op` result
`OpFromGraph`: Override with another `OpFromGraph`, should
accept inputs as the same order and types of ``inputs`` and ``eval_points``
arguments as one would specify in :meth:`Op.R_op` method.
`callable`: Should take two args: ``inputs`` and ``eval_points``.
Each argument is expected to be a list of :class:`Variable`. Must
return list of :class:`Variable`.
`NullType` instance: Treat as non-differentiable `DisconnectedType`
instance: Treat as zero since `DisconnectedType` is not yet supported
in :meth:`Op.R_op`.
``list``:
Each :class:`OpFromGraph`/callable must return a single
:class:`Variable <aesara.graph.basic.Variable>`. Each list element
corresponds to a specific output of :meth:`Op.R_op`, length of list
must be equal to number of outputs. connection_pattern If not
``None``, this will be used as the connection_pattern for this
:class:`Op`.
name
A name for debugging purposes.
kwargs
Check :func:`orig_function` for more arguments, only works when not
inline.
"""
if not (isinstance(inputs, list) and isinstance(outputs, list)):
raise TypeError("Inputs and outputs must be lists")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论