Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ff4b91dd
提交
ff4b91dd
authored
12月 16, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
12月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add type hints and move docstring to OpFromGraph constructor
上级
d7a02fa1
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
96 行增加
和
103 行删除
+96
-103
builders.py
aesara/compile/builders.py
+96
-103
没有找到文件。
aesara/compile/builders.py
浏览文件 @
ff4b91dd
"""Define new Ops from existing Ops"""
from
collections
import
OrderedDict
from
functools
import
partial
from
typing
import
List
,
Optional
import
aesara.tensor
as
aet
from
aesara.compile.function.pfunc
import
rebuild_collect_shared
...
...
@@ -87,101 +88,6 @@ class OpFromGraph(Op, HasInnerGraph):
Currently does not support ``updates`` or ``givens`` argument.
Parameters
----------
inputs: list of :class:`Variable <aesara.graph.basic.Variable>`
outputs: list of :class:`Variable <aesara.graph.basic.Variable>`
inline: bool, optional
Defaults to ``False``
``True`` : Cause the Op's original graph being used during
compilation, the Op will not be visible in the compiled
graph but rather its internal graph.
``False`` : will use a pre-compiled function inside.
grad_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
Defaults to ``'default'``.
This argument is mutually exclusive with lop_overrides.
``'default'`` : Do not override, use default grad() result
OpFromGraph instance : Override with another OpFromGraph, should
accept inputs as the same order and types of ``inputs`` and ``output_grads``
arguments as one would specify in grad() method.
callable : Should take two args: ``inputs`` and ``output_grads``.
Each argument is expected to be a list of :class:`Variable <aesara.graph.basic.Variable>`.
Must return list of :class:`Variable <aesara.graph.basic.Variable>`.
Variable :
``NullType() instance`` : Treat as non-differentiable
``DisconnectedType() instance`` : Treat as disconnected gradient, numerically gives zero
list: Each OpFromGraph/callable must return a single
:class:`Variable <aesara.graph.basic.Variable>`. Each list element corresponds to gradient of
a specific input, length of list must be equal to number of inputs.
lop_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
Defaults to ``'default'``.
This argument is mutually exclusive with ``grad_overrides``.
``'default'`` : Do not override, use default L_op() result
OpFromGraph instance : Override with another OpFromGraph, should
accept inputs as the same order and types of ``inputs``, ``outputs`` and ``output_grads``
arguments as one would specify in grad() method.
callable : Should take three args: ``inputs``, ``outputs`` and ``output_grads``.
Each argument is expected to be a list of :class:`Variable <aesara.graph.basic.Variable>`.
Must return list of :class:`Variable <aesara.graph.basic.Variable>`.
Variable :
`NullType` instance: Treat as non-differentiable
`DisconnectedType` instance: Treat as disconnected gradient, numerically gives zero
list: Each OpFromGraph/callable must return a single
:class:`Variable <aesara.graph.basic.Variable>`. Each list element corresponds to gradient of
a specific input, length of list must be equal to number of inputs.
rop_overrides : single or list of {'default', OpFromGraph, callable, Variable with special type}, optional
Defaults to ``default``.
``'default'`` : Do not override, use default R_op() result
OpFromGraph instance : Override with another OpFromGraph, should
accept inputs as the same order and types of ``inputs`` and ``eval_points``
arguments as one would specify in R_op() method.
callable : Should take two args: ``inputs`` and ``eval_points``.
Each argument is expected to be a list of :class:`Variable <aesara.graph.basic.Variable>`.
Must return list of :class:`Variable <aesara.graph.basic.Variable>`.
Variable :
`NullType` instance: Treat as non-differentiable
`DisconnectedType` instance: Treat as zero since DisconnectedType is not yet supported in R_op
list: Each OpFromGraph/callable must return a single
:class:`Variable <aesara.graph.basic.Variable>`. Each list element corresponds
to a specific output of R_op, length of list must be equal to number of outputs.
connection_pattern : list of list
If not ``None``, this will be used as the connection_pattern
for this op.
name : string, optional
A name for debugging purposes
\*\*kwargs : optional
Check
:func:`orig_function <aesara.compile.function.types.orig_function>`
for more arguments, only works when not inline.
.. TODO:
- examples for a multi-layer mlp. where?
- __hash__, __eq__ otherwise won't merge, try
...
...
@@ -322,16 +228,103 @@ class OpFromGraph(Op, HasInnerGraph):
def
__init__
(
self
,
inputs
,
outputs
,
inline
=
False
,
lop_overrides
=
"default"
,
grad_overrides
=
"default"
,
rop_overrides
=
"default"
,
connection_pattern
=
None
,
name
=
None
,
inputs
:
List
[
Variable
]
,
outputs
:
List
[
Variable
]
,
inline
:
bool
=
False
,
lop_overrides
:
str
=
"default"
,
grad_overrides
:
str
=
"default"
,
rop_overrides
:
str
=
"default"
,
connection_pattern
:
Optional
[
List
[
List
[
bool
]]]
=
None
,
name
:
Optional
[
str
]
=
None
,
**
kwargs
,
):
"""
Parameters
----------
inputs
The inputs to the graph.
outputs
The outputs to the graph.
inline
Defaults to ``False``
``True`` : Cause the :class:`Op`'s original graph being used during
compilation, the :class:`Op` will not be visible in the compiled
graph but rather its internal graph.
``False`` : will use a pre-compiled function inside.
grad_overrides
Defaults to ``'default'``.
This argument is mutually exclusive with ``lop_overrides``.
``'default'`` : Do not override, use default grad() result
`OpFromGraph`: Override with another `OpFromGraph`, should
accept inputs as the same order and types of ``inputs`` and ``output_grads``
arguments as one would specify in :meth:`Op.grad`() method.
`callable`: Should take two args: ``inputs`` and ``output_grads``.
Each argument is expected to be a list of :class:`Variable `.
Must return list of :class:`Variable `.
lop_overrides
Defaults to ``'default'``.
This argument is mutually exclusive with ``grad_overrides``.
These options are similar to the ``grad_overrides`` above, but for
the :meth:`Op.L_op` method.
``'default'``: Do not override, use the default :meth:`Op.L_op` result
`OpFromGraph`: Override with another `OpFromGraph`, should
accept inputs as the same order and types of ``inputs``,
``outputs`` and ``output_grads`` arguments as one would specify in
:meth:`Op.grad` method.
`callable`: Should take three args: ``inputs``, ``outputs`` and ``output_grads``.
Each argument is expected to be a list of :class:`Variable`.
Must return list of :class:`Variable`.
`NullType` instance: Treat as non-differentiable
`DisconnectedType` instance: Treat as disconnected gradient,
numerically gives zero
``list``: Each `OpFromGraph`/callable must return a single
:class:`Variable`. Each list element corresponds to gradient of
a specific input, length of list must be equal to number of inputs.
rop_overrides
One of ``{'default', OpFromGraph, callable, Variable}``.
Defaults to ``'default'``.
``'default'``: Do not override, use the default :meth:`Op.R_op` result
`OpFromGraph`: Override with another `OpFromGraph`, should
accept inputs as the same order and types of ``inputs`` and ``eval_points``
arguments as one would specify in :meth:`Op.R_op` method.
`callable`: Should take two args: ``inputs`` and ``eval_points``.
Each argument is expected to be a list of :class:`Variable`. Must
return list of :class:`Variable`.
`NullType` instance: Treat as non-differentiable `DisconnectedType`
instance: Treat as zero since `DisconnectedType` is not yet supported
in :meth:`Op.R_op`.
``list``:
Each :class:`OpFromGraph`/callable must return a single
:class:`Variable <aesara.graph.basic.Variable>`. Each list element
corresponds to a specific output of :meth:`Op.R_op`, length of list
must be equal to number of outputs. connection_pattern If not
``None``, this will be used as the connection_pattern for this
:class:`Op`.
name
A name for debugging purposes.
kwargs
Check :func:`orig_function` for more arguments, only works when not
inline.
"""
if
not
(
isinstance
(
inputs
,
list
)
and
isinstance
(
outputs
,
list
)):
raise
TypeError
(
"Inputs and outputs must be lists"
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论