Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
c95acebd
提交
c95acebd
authored
9月 15, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
9月 17, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Create HasInnerGraph mixin
上级
8f692472
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
77 行增加
和
72 行删除
+77
-72
builders.py
aesara/compile/builders.py
+42
-36
debugmode.py
aesara/compile/debugmode.py
+4
-7
types.py
aesara/compile/function/types.py
+4
-3
formatting.py
aesara/d3viz/formatting.py
+2
-3
op.py
aesara/graph/op.py
+23
-16
op.py
aesara/scan/op.py
+2
-7
没有找到文件。
aesara/compile/builders.py
浏览文件 @
c95acebd
...
...
@@ -18,7 +18,7 @@ from aesara.graph.basic import (
)
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.null_type
import
NullType
from
aesara.graph.op
import
Op
,
ops_with_inner_function
from
aesara.graph.op
import
HasInnerGraph
,
Op
from
aesara.graph.opt
import
in2out
,
local_optimizer
from
aesara.tensor.basic_opt
import
ShapeFeature
...
...
@@ -76,11 +76,11 @@ def infer_shape(outs, inputs, input_shapes):
return
ret
class
OpFromGraph
(
Op
):
class
OpFromGraph
(
Op
,
HasInnerGraph
):
r"""
This creates an `
`Op`
` from inputs and outputs lists of variables.
This creates an `
Op
` from inputs and outputs lists of variables.
The signature is similar to :func:`aesara.function <aesara.function>`
and the resulting `
`Op`
`'s perform will do the same operation as::
and the resulting `
Op
`'s perform will do the same operation as::
orig_function(inputs, outputs, **kwargs)
...
...
@@ -139,8 +139,8 @@ class OpFromGraph(Op):
Must return list of :class:`Variable <aesara.graph.basic.Variable>`.
Variable :
`
`NullType() instance``
: Treat as non-differentiable
`
`DisconnectedType() instance``
: Treat as disconnected gradient, numerically gives zero
`
NullType` instance
: Treat as non-differentiable
`
DisconnectedType` instance
: Treat as disconnected gradient, numerically gives zero
list: Each OpFromGraph/callable must return a single
:class:`Variable <aesara.graph.basic.Variable>`. Each list element corresponds to gradient of
...
...
@@ -160,8 +160,8 @@ class OpFromGraph(Op):
Must return list of :class:`Variable <aesara.graph.basic.Variable>`.
Variable :
`
`NullType() instance``
: Treat as non-differentiable
`
`DisconnectedType() instance``
: Treat as zero since DisconnectedType is not yet supported in R_op
`
NullType` instance
: Treat as non-differentiable
`
DisconnectedType` instance
: Treat as zero since DisconnectedType is not yet supported in R_op
list: Each OpFromGraph/callable must return a single
:class:`Variable <aesara.graph.basic.Variable>`. Each list element corresponds
...
...
@@ -363,8 +363,8 @@ class OpFromGraph(Op):
assert
not
update_expr
assert
not
shared_inputs
self
.
local
_inputs
=
local_inputs
self
.
local
_outputs
=
local_outputs
self
.
_inner
_inputs
=
local_inputs
self
.
_inner
_outputs
=
local_outputs
self
.
inputs
=
inputs
self
.
outputs
=
outputs
self
.
kwargs
=
kwargs
...
...
@@ -411,8 +411,8 @@ class OpFromGraph(Op):
converts self._lop_op from user supplied form to type(self) instance
"""
local_inputs
=
self
.
local
_inputs
local_outputs
=
self
.
local
_outputs
local_inputs
=
self
.
inner
_inputs
local_outputs
=
self
.
inner
_outputs
inp_len
=
len
(
local_inputs
)
lop_op
=
self
.
_lop_op
...
...
@@ -424,7 +424,7 @@ class OpFromGraph(Op):
)
if
self
.
_lop_type
==
"grad"
:
needed_ninps
=
inp_len
+
len
(
local_outputs
)
ninps
=
len
(
lop_op
.
local
_inputs
)
ninps
=
len
(
lop_op
.
inner
_inputs
)
if
needed_ninps
!=
ninps
:
raise
ValueError
(
self
.
OV_INP_LEN_ERR_MSG
%
(
needed_ninps
,
ninps
))
# make a wrapper callable
...
...
@@ -435,7 +435,7 @@ class OpFromGraph(Op):
elif
self
.
_lop_type
==
"lop"
:
# OfG can be directly used in L_op format
needed_ninps
=
inp_len
+
2
*
len
(
local_outputs
)
ninps
=
len
(
lop_op
.
local
_inputs
)
ninps
=
len
(
lop_op
.
inner
_inputs
)
if
needed_ninps
!=
ninps
:
raise
ValueError
(
self
.
OV_INP_LEN_ERR_MSG
%
(
needed_ninps
,
ninps
))
self
.
_lop_op_is_cached
=
True
...
...
@@ -551,8 +551,8 @@ class OpFromGraph(Op):
converts self._rop_op from user supplied form to type(self) instance
"""
local_inputs
=
self
.
local
_inputs
local_outputs
=
self
.
local
_outputs
local_inputs
=
self
.
inner
_inputs
local_outputs
=
self
.
inner
_outputs
out_len
=
len
(
local_outputs
)
rop_op
=
self
.
_rop_op
...
...
@@ -728,7 +728,7 @@ class OpFromGraph(Op):
return
ret_l
def
make_node
(
self
,
*
inputs
):
num_expected_inps
=
len
(
self
.
local
_inputs
)
-
len
(
self
.
shared_inputs
)
num_expected_inps
=
len
(
self
.
inner
_inputs
)
-
len
(
self
.
shared_inputs
)
if
len
(
inputs
)
!=
num_expected_inps
:
raise
ValueError
(
f
"Expected {int(num_expected_inps)} inputs, got {len(inputs)}"
...
...
@@ -741,8 +741,6 @@ class OpFromGraph(Op):
list
(
inputs
)
+
self
.
shared_inputs
,
[
type
()
for
type
in
self
.
output_types
],
)
apply_node
.
local_inputs
=
self
.
local_inputs
apply_node
.
local_outputs
=
self
.
local_outputs
return
apply_node
def
connection_pattern
(
self
,
node
):
...
...
@@ -753,13 +751,13 @@ class OpFromGraph(Op):
if
self
.
_connection_pattern
is
not
None
:
return
self
.
_connection_pattern
inp_len
=
len
(
self
.
local
_inputs
)
out_len
=
len
(
self
.
local
_outputs
)
cpmat_self
=
io_connection_pattern
(
self
.
local_inputs
,
self
.
local
_outputs
)
inp_len
=
len
(
self
.
inner
_inputs
)
out_len
=
len
(
self
.
inner
_outputs
)
cpmat_self
=
io_connection_pattern
(
self
.
inner_inputs
,
self
.
inner
_outputs
)
lop_op
=
self
.
get_lop_op
()
cpmat_grad
=
io_connection_pattern
(
lop_op
.
local_inputs
[
inp_len
:],
lop_op
.
local
_outputs
lop_op
.
inner_inputs
[
inp_len
:],
lop_op
.
inner
_outputs
)
# cpmat_self |= cpmat_grad.T
...
...
@@ -781,7 +779,7 @@ class OpFromGraph(Op):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
# TODO: Use `fgraph.shape_feature` to do this instead.
out_shapes
=
infer_shape
(
self
.
local_outputs
,
self
.
local
_inputs
,
shapes
)
out_shapes
=
infer_shape
(
self
.
inner_outputs
,
self
.
inner
_inputs
,
shapes
)
# Clone the output shape so that shape are computed from outer inputs.
# Note:
...
...
@@ -791,7 +789,7 @@ class OpFromGraph(Op):
# each shape call. Aesara optimizer will clean this up later, but this
# will make extra work for the optimizer.
repl
=
dict
(
zip
(
self
.
local
_inputs
,
node
.
inputs
))
repl
=
dict
(
zip
(
self
.
inner
_inputs
,
node
.
inputs
))
clone_out_shapes
=
[
s
for
s
in
out_shapes
if
isinstance
(
s
,
tuple
)]
cloned
=
clone_replace
(
sum
(
clone_out_shapes
,
()),
replace
=
repl
)
ret
=
[]
...
...
@@ -806,12 +804,24 @@ class OpFromGraph(Op):
return
ret
def
prepare_node
(
self
,
node
,
storage_map
,
compute_map
,
impl
):
if
not
hasattr
(
self
,
"fn"
)
and
impl
==
"py"
:
self
.
fn
=
orig_function
(
self
.
local_inputs
,
self
.
local_outputs
,
**
self
.
kwargs
)
self
.
fn
.
trust_input
=
True
@property
def
fn
(
self
):
"""Lazily compile the inner function graph."""
if
getattr
(
self
,
"_fn"
,
None
)
is
not
None
:
return
self
.
_fn
self
.
_fn
=
orig_function
(
self
.
inner_inputs
,
self
.
inner_outputs
,
**
self
.
kwargs
)
self
.
_fn
.
trust_input
=
True
return
self
.
_fn
@property
def
inner_inputs
(
self
):
return
self
.
_inner_inputs
@property
def
inner_outputs
(
self
):
return
self
.
_inner_outputs
def
perform
(
self
,
node
,
inputs
,
outputs
):
variables
=
self
.
fn
(
*
inputs
)
...
...
@@ -833,7 +843,7 @@ def inline_ofg_expansion(fgraph, node):
if
not
op
.
is_inline
:
return
False
return
clone_replace
(
op
.
local_outputs
,
{
u
:
v
for
u
,
v
in
zip
(
node
.
op
.
local
_inputs
,
node
.
inputs
)}
op
.
inner_outputs
,
{
u
:
v
for
u
,
v
in
zip
(
op
.
inner
_inputs
,
node
.
inputs
)}
)
...
...
@@ -846,7 +856,3 @@ optdb.register(
"fast_compile"
,
"fast_run"
,
)
# Since OpFromGraph contains an Aesara compiled function,
# we should let DebugMode know about it
ops_with_inner_function
[
OpFromGraph
]
=
"fn"
aesara/compile/debugmode.py
浏览文件 @
c95acebd
...
...
@@ -32,7 +32,7 @@ from aesara.graph.basic import Variable, graph_inputs, io_toposort
from
aesara.graph.destroyhandler
import
DestroyHandler
from
aesara.graph.features
import
BadOptimization
from
aesara.graph.fg
import
InconsistencyError
from
aesara.graph.op
import
COp
,
Op
,
ops_with_inner_function
from
aesara.graph.op
import
COp
,
HasInnerGraph
,
Op
from
aesara.graph.utils
import
MethodNotDefined
from
aesara.link.basic
import
Container
,
LocalLinker
from
aesara.link.utils
import
map_storage
,
raise_with_op
...
...
@@ -1104,13 +1104,10 @@ def _check_preallocated_output(
# disable memory checks in that mode, since they were already run.
try
:
changed_inner_mode
=
False
if
type
(
getattr
(
node
,
"op"
,
None
))
in
ops_with_inner_function
:
fn_attr_name
=
ops_with_inner_function
[
type
(
node
.
op
)]
fn
=
getattr
(
node
.
op
,
fn_attr_name
,
None
)
if
isinstance
(
getattr
(
node
,
"op"
,
None
),
HasInnerGraph
):
fn
=
node
.
op
.
fn
if
not
fn
or
not
hasattr
(
fn
,
"maker"
)
or
not
hasattr
(
fn
.
maker
,
"mode"
):
_logger
.
warning
(
f
"Expected aesara function not found in {node.op}.{fn_attr_name}"
)
_logger
.
warning
(
f
"Expected aesara function not found in {node.op}.fn"
)
else
:
if
isinstance
(
fn
.
maker
.
mode
,
DebugMode
):
backup_mode
=
fn
.
maker
.
mode
...
...
aesara/compile/function/types.py
浏览文件 @
c95acebd
...
...
@@ -32,7 +32,7 @@ from aesara.graph.basic import (
from
aesara.graph.destroyhandler
import
DestroyHandler
from
aesara.graph.features
import
PreserveVariableAttributes
from
aesara.graph.fg
import
FunctionGraph
,
InconsistencyError
from
aesara.graph.op
import
ops_with_inner_function
from
aesara.graph.op
import
HasInnerGraph
from
aesara.graph.opt_utils
import
is_same_graph
from
aesara.graph.utils
import
get_variable_trace_string
from
aesara.link.basic
import
Container
...
...
@@ -548,7 +548,7 @@ class Function:
self
.
n_returned_outputs
-=
1
for
node
in
self
.
maker
.
fgraph
.
apply_nodes
:
if
node
.
op
in
ops_with_inner_function
:
if
isinstance
(
node
.
op
,
HasInnerGraph
)
:
self
.
nodes_with_inner_function
.
append
(
node
.
op
)
def
__contains__
(
self
,
item
):
...
...
@@ -1099,7 +1099,8 @@ class Function:
self
.
fn
.
storage_map
[
key
][
0
]
=
None
for
node
in
self
.
nodes_with_inner_function
:
ops_with_inner_function
[
node
.
op
]
.
free
()
if
hasattr
(
node
.
fn
,
"free"
):
node
.
fn
.
free
()
def
get_shared
(
self
):
"""
...
...
aesara/d3viz/formatting.py
浏览文件 @
c95acebd
...
...
@@ -232,7 +232,6 @@ class PyDotFormatter:
gf
=
PyDotFormatter
()
# Use different node prefix for sub-graphs
gf
.
__node_prefix
=
__node_id
node
.
op
.
prepare_node
(
node
,
None
,
None
,
"py"
)
gf
(
node
.
op
.
fn
,
subgraph
)
graph
.
add_subgraph
(
subgraph
)
pd_node
.
get_attributes
()[
"subg"
]
=
subgraph
.
get_name
()
...
...
@@ -242,14 +241,14 @@ class PyDotFormatter:
# Inputs mapping
ext_inputs
=
[
self
.
__node_id
(
x
)
for
x
in
node
.
inputs
]
int_inputs
=
[
gf
.
__node_id
(
x
)
for
x
in
node
.
op
.
local
_inputs
]
int_inputs
=
[
gf
.
__node_id
(
x
)
for
x
in
node
.
op
.
inner
_inputs
]
assert
len
(
ext_inputs
)
==
len
(
int_inputs
)
h
=
format_map
(
zip
(
ext_inputs
,
int_inputs
))
pd_node
.
get_attributes
()[
"subg_map_inputs"
]
=
h
# Outputs mapping
ext_outputs
=
[
self
.
__node_id
(
x
)
for
x
in
node
.
outputs
]
int_outputs
=
[
gf
.
__node_id
(
x
)
for
x
in
node
.
op
.
local
_outputs
]
int_outputs
=
[
gf
.
__node_id
(
x
)
for
x
in
node
.
op
.
inner
_outputs
]
assert
len
(
ext_outputs
)
==
len
(
int_outputs
)
h
=
format_map
(
zip
(
int_outputs
,
ext_outputs
))
pd_node
.
get_attributes
()[
"subg_map_outputs"
]
=
h
...
...
aesara/graph/op.py
浏览文件 @
c95acebd
...
...
@@ -13,6 +13,7 @@ import sys
import
warnings
from
abc
import
abstractmethod
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
...
...
@@ -43,6 +44,9 @@ from aesara.graph.utils import (
from
aesara.link.c.interface
import
CLinkerOp
if
TYPE_CHECKING
:
from
aesara.compile.function.types
import
Function
StorageMapType
=
List
[
Optional
[
List
[
Any
]]]
ComputeMapType
=
List
[
bool
]
OutputStorageType
=
List
[
Optional
[
List
[
Any
]]]
...
...
@@ -574,6 +578,25 @@ class Op(MetaObject):
return
getattr
(
type
(
self
),
"__name__"
,
super
()
.
__str__
())
class
HasInnerGraph
:
r"""A mixin for an `Op` that contain an inner graph."""
@property
@abstractmethod
def
fn
(
self
)
->
"Function"
:
"""The inner function."""
@property
@abstractmethod
def
inner_inputs
(
self
)
->
List
[
Variable
]:
"""The inner function's inputs."""
@property
@abstractmethod
def
inner_outputs
(
self
)
->
List
[
Variable
]:
"""The inner function's outputs."""
class
COp
(
Op
,
CLinkerOp
):
"""An `Op` with a C implementation."""
...
...
@@ -767,22 +790,6 @@ def get_test_values(*args: Variable) -> Union[Any, List[Any]]:
return
[
tuple
(
rval
)]
ops_with_inner_function
:
Dict
[
Op
,
Text
]
=
{}
r"""
Registry of `Op`\s that have an inner compiled Aesara function.
The keys are `Op` classes (not instances), and values are the name of the
attribute that contains the function. For instance, if the function is
``self.fn``, the value will be ``'fn'``.
We need that to be able not to run debug checks a number of times that is
exponential in the nesting level of those `Op`\s.
For instance, `Scan` will be registered here.
"""
class
OpenMPOp
(
COp
):
r"""Base class for `Op`\s using OpenMP.
...
...
aesara/scan/op.py
浏览文件 @
c95acebd
...
...
@@ -73,7 +73,7 @@ from aesara.graph.basic import (
)
from
aesara.graph.features
import
NoOutputFromInplace
from
aesara.graph.fg
import
MissingInputError
from
aesara.graph.op
import
Op
,
ops_with_inner_function
from
aesara.graph.op
import
HasInnerGraph
,
Op
from
aesara.link.c.basic
import
CLinker
from
aesara.link.c.exceptions
import
MissingGXX
from
aesara.link.utils
import
raise_with_op
...
...
@@ -570,7 +570,7 @@ class ScanMethodsMixin:
)
class
Scan
(
Op
,
ScanMethodsMixin
):
class
Scan
(
Op
,
ScanMethodsMixin
,
HasInnerGraph
):
def
__init__
(
self
,
inputs
:
List
[
Variable
],
...
...
@@ -3126,11 +3126,6 @@ class Scan(Op, ScanMethodsMixin):
return
final_outs
# Since Scan is an op that contains an Aesara compiled function, it is
# useful to let DebugMode know about it.
ops_with_inner_function
[
Scan
]
=
"fn"
@register_profiler_printer
def
profile_printer
(
message
,
compile_time
,
fct_call_time
,
apply_time
,
apply_cimpl
,
outputs_size
,
file
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论