Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
f8d06511
提交
f8d06511
authored
12月 05, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
12月 06, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use inheritance in local optimizer Op tracking
This commit introduces a `LocalOptTracker` object that performs an MRO-based lookup of `LocalOptimizer`s that track `Op` types.
上级
d475421b
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
152 行增加
和
52 行删除
+152
-52
optdb.py
aesara/gpuarray/optdb.py
+2
-2
opt.py
aesara/graph/opt.py
+93
-50
test_opt.py
tests/graph/test_opt.py
+57
-0
没有找到文件。
aesara/gpuarray/optdb.py
浏览文件 @
f8d06511
...
...
@@ -29,8 +29,8 @@ class GraphToGPULocalOptGroup(LocalOptGroup):
def
transform
(
self
,
fgraph
,
op
,
context_name
,
inputs
,
outputs
):
if
len
(
self
.
opts
)
==
0
:
return
opts
=
self
.
track_map
[
type
(
op
)]
+
self
.
track_map
[
op
]
+
self
.
track_map
[
None
]
for
opt
in
opts
:
for
opt
in
self
.
tracker
.
get_trackers
(
op
)
:
opt_start
=
time
.
time
()
new_repl
=
opt
.
transform
(
fgraph
,
op
,
context_name
,
inputs
,
outputs
)
opt_finish
=
time
.
time
()
...
...
aesara/graph/opt.py
浏览文件 @
f8d06511
...
...
@@ -6,6 +6,7 @@ amount of useful generic optimization tools.
import
abc
import
contextlib
import
copy
import
functools
import
inspect
import
logging
import
pdb
...
...
@@ -13,9 +14,10 @@ import sys
import
time
import
traceback
import
warnings
from
collections
import
OrderedDict
,
UserList
,
defaultdict
,
deque
from
collections
import
UserList
,
defaultdict
,
deque
from
collections.abc
import
Iterable
from
functools
import
partial
,
reduce
from
itertools
import
chain
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
...
...
@@ -1269,32 +1271,90 @@ def local_optimizer(
return
decorator
class
LocalOptTracker
:
r"""A container that maps rewrites to `Op` instances and `Op`-type inheritance."""
def
__init__
(
self
):
self
.
tracked_instances
=
{}
self
.
tracked_types
=
{}
self
.
untracked_opts
=
[]
def
add_tracker
(
self
,
rw
:
LocalOptimizer
):
"""Add a `LocalOptimizer` to be keyed by its `LocalOptimizer.tracks` or applied generally."""
tracks
=
rw
.
tracks
()
if
tracks
is
None
:
self
.
untracked_opts
.
append
(
rw
)
else
:
for
c
in
tracks
:
if
isinstance
(
c
,
type
):
self
.
tracked_types
.
setdefault
(
c
,
[])
.
append
(
rw
)
else
:
self
.
tracked_instances
.
setdefault
(
c
,
[])
.
append
(
rw
)
def
_find_impl
(
self
,
cls
):
r"""Returns the `LocalOptimizer`\s that apply to `cls` based on inheritance.
This based on `functools._find_impl`.
"""
mro
=
functools
.
_compose_mro
(
cls
,
self
.
tracked_types
.
keys
())
matches
=
[]
for
t
in
mro
:
match
=
self
.
tracked_types
.
get
(
t
,
None
)
if
match
:
matches
.
extend
(
match
)
return
matches
@functools.lru_cache
()
def
get_trackers
(
self
,
op
:
Op
)
->
List
[
LocalOptimizer
]:
"""Get all the rewrites applicable to `op`."""
return
(
self
.
_find_impl
(
type
(
op
))
+
self
.
tracked_instances
.
get
(
op
,
[])
+
self
.
untracked_opts
)
def
get_rewriters
(
self
):
return
chain
(
chain
.
from_iterable
(
chain
(
self
.
tracked_types
.
values
(),
self
.
tracked_instances
.
values
())
),
self
.
untracked_opts
,
)
class
LocalOptGroup
(
LocalOptimizer
):
r"""An optimizer that applies a list of `LocalOptimizer`\s to a node.
Parameters
----------
optimizers :
A list of optimizers to be applied to nodes.
apply_all_opts : bool (Default False)
If ``False``, it will return after the new node after the first optimizer
applied. Otherwise, it will start again with the new node until no new
optimization apply.
profile :
Whether or not to profile the optimizations.
Attributes
----------
reentrant : bool
Some global optimizer like `NavigatorOptimizer` can use this value to
determine if it ignore new nodes during a pass on the nodes. Sometimes,
``ignore_newtrees`` is not reentrant.
Some global optimizers, like `NavigatorOptimizer`, use this value to
determine if they should ignore new nodes.
retains_inputs : bool
States whether or not the inputs of a transformed node are transferred
to the outputs.
"""
def
__init__
(
self
,
*
optimizers
,
apply_all_opts
=
False
,
profile
=
False
):
def
__init__
(
self
,
*
optimizers
,
apply_all_opts
:
bool
=
False
,
profile
:
bool
=
False
):
"""
Parameters
----------
optimizers
A list of optimizers to be applied to nodes.
apply_all_opts
If ``False``, it will return after the first successfully applied
rewrite; otherwise, it will apply every applicable rewrite
incrementally.
profile
Whether or not to profile the optimizations.
"""
super
()
.
__init__
()
if
len
(
optimizers
)
==
1
and
isinstance
(
optimizers
[
0
],
list
):
# This happen when created by LocalGroupDB.
optimizers
=
tuple
(
optimizers
[
0
])
...
...
@@ -1307,26 +1367,25 @@ class LocalOptGroup(LocalOptimizer):
)
self
.
apply_all_opts
=
apply_all_opts
self
.
profile
=
profile
self
.
track_map
=
defaultdict
(
lambda
:
[])
if
self
.
profile
:
self
.
time_opts
=
{}
self
.
process_count
=
{}
self
.
applied_true
=
{}
self
.
node_created
=
{}
self
.
tracker
=
LocalOptTracker
()
for
o
in
self
.
opts
:
self
.
tracker
.
add_tracker
(
o
)
if
self
.
profile
:
self
.
time_opts
.
setdefault
(
o
,
0
)
self
.
process_count
.
setdefault
(
o
,
0
)
self
.
applied_true
.
setdefault
(
o
,
0
)
self
.
node_created
.
setdefault
(
o
,
0
)
tracks
=
o
.
tracks
()
if
tracks
is
None
:
self
.
track_map
[
None
]
.
append
(
o
)
else
:
for
c
in
tracks
:
self
.
track_map
[
c
]
.
append
(
o
)
def
__str__
(
self
):
return
getattr
(
...
...
@@ -1346,13 +1405,12 @@ class LocalOptGroup(LocalOptimizer):
def
transform
(
self
,
fgraph
,
node
):
if
len
(
self
.
opts
)
==
0
:
return
repl
=
None
while
True
:
opts
=
(
self
.
track_map
[
type
(
node
.
op
)]
+
self
.
track_map
[
node
.
op
]
+
self
.
track_map
[
None
]
)
opts
=
self
.
tracker
.
get_trackers
(
node
.
op
)
new_repl
=
None
for
opt
in
opts
:
opt_start
=
time
.
time
()
...
...
@@ -2333,38 +2391,27 @@ class EquilibriumOptimizer(NavigatorOptimizer):
super
()
.
__init__
(
None
,
ignore_newtrees
=
ignore_newtrees
,
failure_callback
=
failure_callback
)
self
.
local_optimizers_map
=
OrderedDict
()
self
.
local_optimizers_all
=
[]
self
.
global_optimizers
=
[]
self
.
final_optimizers
=
[]
self
.
cleanup_optimizers
=
[]
self
.
tracks_on_change_inputs
=
tracks_on_change_inputs
self
.
local_tracker
=
LocalOptTracker
()
for
opt
in
optimizers
:
if
isinstance
(
opt
,
LocalOptimizer
):
if
opt
.
tracks
()
is
None
:
self
.
local_optimizers_all
.
append
(
opt
)
else
:
for
c
in
opt
.
tracks
():
self
.
local_optimizers_map
.
setdefault
(
c
,
[])
.
append
(
opt
)
self
.
local_tracker
.
add_tracker
(
opt
)
else
:
self
.
global_optimizers
.
append
(
opt
)
if
final_optimizers
:
self
.
final_optimizers
=
final_optimizers
if
cleanup_optimizers
:
self
.
cleanup_optimizers
=
cleanup_optimizers
self
.
max_use_ratio
=
max_use_ratio
assert
self
.
max_use_ratio
is
not
None
,
"max_use_ratio has to be a number"
def
get_local_optimizers
(
self
):
for
opt
in
self
.
local_optimizers_all
:
yield
opt
# if repeat is not a problem we can drop the set
s
=
set
()
for
lopt
in
self
.
local_optimizers_map
.
values
():
for
opt
in
lopt
:
if
opt
not
in
s
:
yield
opt
s
.
add
(
opt
)
yield
from
self
.
local_tracker
.
get_rewriters
()
def
add_requirements
(
self
,
fgraph
):
super
()
.
add_requirements
(
fgraph
)
...
...
@@ -2496,11 +2543,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
if
node
not
in
fgraph
.
apply_nodes
:
continue
current_node
=
node
for
lopt
in
(
self
.
local_optimizers_all
+
self
.
local_optimizers_map
.
get
(
type
(
node
.
op
),
[])
+
self
.
local_optimizers_map
.
get
(
node
.
op
,
[])
):
for
lopt
in
self
.
local_tracker
.
get_trackers
(
node
.
op
):
nb
=
change_tracker
.
nb_imported
t_opt
=
time
.
time
()
lopt_change
=
self
.
process_node
(
fgraph
,
node
,
lopt
)
...
...
tests/graph/test_opt.py
浏览文件 @
f8d06511
...
...
@@ -8,6 +8,7 @@ from aesara.graph.op import Op
from
aesara.graph.opt
import
(
EquilibriumOptimizer
,
LocalOptGroup
,
LocalOptTracker
,
MergeOptimizer
,
OpKeyOptimizer
,
OpSub
,
...
...
@@ -755,3 +756,59 @@ def test_local_optimizer():
# This is not allowed by `tracks`
local_opt_1
.
transform
(
fgraph
,
fgraph
.
outputs
[
2
]
.
owner
)
assert
hits
[
0
]
==
2
def
test_TrackingLocalOptimizer
():
@local_optimizer
(
None
)
def
local_opt_1
(
fgraph
,
node
):
pass
@local_optimizer
([
op1
])
def
local_opt_2
(
fgraph
,
node
):
pass
@local_optimizer
([
Op
])
def
local_opt_3
(
fgraph
,
node
):
pass
@local_optimizer
([
MyOp
])
def
local_opt_4
(
fgraph
,
node
):
pass
@local_optimizer
([
MyOp
])
def
local_opt_5
(
fgraph
,
node
):
pass
tracker
=
LocalOptTracker
()
tracker
.
add_tracker
(
local_opt_1
)
tracker
.
add_tracker
(
local_opt_2
)
tracker
.
add_tracker
(
local_opt_3
)
tracker
.
add_tracker
(
local_opt_4
)
tracker
.
add_tracker
(
local_opt_5
)
assert
tracker
.
tracked_instances
==
{
op1
:
[
local_opt_2
]}
assert
tracker
.
tracked_types
==
{
Op
:
[
local_opt_3
],
MyOp
:
[
local_opt_4
,
local_opt_5
],
}
assert
tracker
.
untracked_opts
==
[
local_opt_1
]
res
=
tracker
.
get_trackers
(
op1
)
assert
res
==
[
local_opt_4
,
local_opt_5
,
local_opt_3
,
local_opt_2
,
local_opt_1
]
class
MyNewOp
(
Op
):
def
perform
(
self
,
*
args
):
pass
new_op
=
MyNewOp
()
res
=
tracker
.
get_trackers
(
new_op
)
assert
res
==
[
local_opt_3
,
local_opt_1
]
assert
list
(
tracker
.
get_rewriters
())
==
[
local_opt_3
,
local_opt_4
,
local_opt_5
,
local_opt_2
,
local_opt_1
,
]
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论