Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ac213377
提交
ac213377
authored
7月 15, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
8月 17, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rename EquilibriumOptimizer to EquilibriumGraphRewriter
上级
e6635af8
显示空白字符变更
内嵌
并排
正在显示
12 个修改的文件
包含
266 行增加
和
249 行删除
+266
-249
mode.py
aesara/compile/mode.py
+3
-3
configdefaults.py
aesara/configdefaults.py
+1
-1
opt.py
aesara/graph/opt.py
+165
-160
optdb.py
aesara/graph/optdb.py
+50
-35
opt.py
aesara/scan/opt.py
+3
-8
basic_opt.py
aesara/tensor/basic_opt.py
+5
-5
blas.py
aesara/tensor/blas.py
+2
-2
graph_rewriting.rst
doc/extending/graph_rewriting.rst
+17
-17
test_kanren.py
tests/graph/test_kanren.py
+2
-2
test_opt.py
tests/graph/test_opt.py
+6
-6
test_optdb.py
tests/graph/test_optdb.py
+2
-2
test_opt.py
tests/tensor/random/test_opt.py
+10
-8
没有找到文件。
aesara/compile/mode.py
浏览文件 @
ac213377
...
...
@@ -212,10 +212,10 @@ optdb.register(
"canonicalize_db"
,
position
=
1
,
)
# Register in the canonizer Equilibrium as a clean
up opt the merge opt
.
# Register in the canonizer Equilibrium as a clean
-up rewrite the merge rewrite
.
# Without this, as the equilibrium have ignore_newtrees=False, we
# won't merge all nodes if it is set as a global
optimiz
er with
# final_
opt
=True.
# won't merge all nodes if it is set as a global
rewrit
er with
# final_
rewriter
=True.
# We need a new instance of MergeOptimizer to don't have its name
# changed by other usage of it.
...
...
aesara/configdefaults.py
浏览文件 @
ac213377
...
...
@@ -1107,7 +1107,7 @@ def add_optimizer_configvars():
config
.
add
(
"optdb__max_use_ratio"
,
"A ratio that prevent infinite loop in Equilibrium
Optimiz
er."
,
"A ratio that prevent infinite loop in Equilibrium
GraphRewrit
er."
,
FloatParam
(
8
),
in_c_key
=
False
,
)
...
...
aesara/graph/opt.py
浏览文件 @
ac213377
...
...
@@ -2227,26 +2227,26 @@ def merge_dict(d1, d2):
return
d
class
Equilibrium
Optimiz
er
(
NodeProcessingGraphRewriter
):
"""A
n `Rewriter` that applies an optimization
until a fixed-point/equilibrium is reached."""
class
Equilibrium
GraphRewrit
er
(
NodeProcessingGraphRewriter
):
"""A
`Rewriter` that applies its rewrites
until a fixed-point/equilibrium is reached."""
def
__init__
(
self
,
optimiz
ers
:
Sequence
[
Rewriter
],
rewrit
ers
:
Sequence
[
Rewriter
],
failure_callback
:
Optional
[
FailureCallbackType
]
=
None
,
ignore_newtrees
:
bool
=
True
,
tracks_on_change_inputs
:
bool
=
False
,
max_use_ratio
:
Optional
[
float
]
=
None
,
final_
optimiz
ers
:
Optional
[
Sequence
[
GraphRewriter
]]
=
None
,
cleanup_
optimiz
ers
:
Optional
[
Sequence
[
GraphRewriter
]]
=
None
,
final_
rewrit
ers
:
Optional
[
Sequence
[
GraphRewriter
]]
=
None
,
cleanup_
rewrit
ers
:
Optional
[
Sequence
[
GraphRewriter
]]
=
None
,
):
"""
Parameters
----------
optimiz
ers
rewrit
ers
Node or graph rewriters to apply until equilibrium.
The global
optimiz
er will be run at the start of each iteration before
The global
rewrit
er will be run at the start of each iteration before
the node rewriter.
failure_callback
See :attr:`NodeProcessingGraphRewriter.failure_callback`.
...
...
@@ -2257,9 +2257,9 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
max_use_ratio
Each rewriter can be applied at most ``(size_of_graph * max_use_ratio)``
times.
final_
optimiz
ers
final_
rewrit
ers
Rewriters that will be run after each iteration.
cleanup_
optimiz
ers
cleanup_
rewrit
ers
Rewriters applied after all graph rewriters, then when one
`NodeRewriter` is applied, then after all final rewriters.
They should not traverse the entire graph, since they are called
...
...
@@ -2270,27 +2270,27 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
super
()
.
__init__
(
None
,
ignore_newtrees
=
ignore_newtrees
,
failure_callback
=
failure_callback
)
self
.
global_
optimiz
ers
:
List
[
GraphRewriter
]
=
[]
self
.
global_
rewrit
ers
:
List
[
GraphRewriter
]
=
[]
self
.
tracks_on_change_inputs
=
tracks_on_change_inputs
self
.
node_tracker
=
OpToRewriterTracker
()
for
opt
in
optimiz
ers
:
if
isinstance
(
opt
,
NodeRewriter
):
self
.
node_tracker
.
add_tracker
(
opt
)
for
rewriter
in
rewrit
ers
:
if
isinstance
(
rewriter
,
NodeRewriter
):
self
.
node_tracker
.
add_tracker
(
rewriter
)
else
:
assert
isinstance
(
opt
,
GraphRewriter
)
self
.
global_
optimizers
.
append
(
opt
)
assert
isinstance
(
rewriter
,
GraphRewriter
)
self
.
global_
rewriters
.
append
(
rewriter
)
if
final_
optimiz
ers
:
self
.
final_
optimizers
=
list
(
final_optimiz
ers
)
if
final_
rewrit
ers
:
self
.
final_
rewriters
=
list
(
final_rewrit
ers
)
else
:
self
.
final_
optimiz
ers
=
[]
self
.
final_
rewrit
ers
=
[]
if
cleanup_
optimiz
ers
:
self
.
cleanup_
optimizers
=
list
(
cleanup_optimiz
ers
)
if
cleanup_
rewrit
ers
:
self
.
cleanup_
rewriters
=
list
(
cleanup_rewrit
ers
)
else
:
self
.
cleanup_
optimiz
ers
=
[]
self
.
cleanup_
rewrit
ers
=
[]
self
.
max_use_ratio
=
max_use_ratio
...
...
@@ -2307,14 +2307,14 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
def
add_requirements
(
self
,
fgraph
):
super
()
.
add_requirements
(
fgraph
)
for
opt
in
self
.
get_node_rewriters
():
opt
.
add_requirements
(
fgraph
)
for
opt
in
self
.
global_optimiz
ers
:
opt
.
add_requirements
(
fgraph
)
for
opt
in
self
.
final_optimiz
ers
:
opt
.
add_requirements
(
fgraph
)
for
opt
in
self
.
cleanup_optimiz
ers
:
opt
.
add_requirements
(
fgraph
)
for
rewriter
in
self
.
get_node_rewriters
():
rewriter
.
add_requirements
(
fgraph
)
for
rewriter
in
self
.
global_rewrit
ers
:
rewriter
.
add_requirements
(
fgraph
)
for
rewriter
in
self
.
final_rewrit
ers
:
rewriter
.
add_requirements
(
fgraph
)
for
rewriter
in
self
.
cleanup_rewrit
ers
:
rewriter
.
add_requirements
(
fgraph
)
def
apply
(
self
,
fgraph
,
start_from
=
None
):
change_tracker
=
ChangeTracker
()
...
...
@@ -2327,7 +2327,7 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
changed
=
True
max_use_abort
=
False
opt
_name
=
None
rewriter
_name
=
None
global_process_count
=
{}
start_nb_nodes
=
len
(
fgraph
.
apply_nodes
)
max_nb_nodes
=
len
(
fgraph
.
apply_nodes
)
...
...
@@ -2335,39 +2335,39 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
loop_timing
=
[]
loop_process_count
=
[]
global_
opt
_timing
=
[]
time_
opt
s
=
{}
global_
rewriter
_timing
=
[]
time_
rewriter
s
=
{}
io_toposort_timing
=
[]
nb_nodes
=
[]
node_created
=
{}
global_sub_profs
=
[]
final_sub_profs
=
[]
cleanup_sub_profs
=
[]
for
opt
in
(
self
.
global_
optimiz
ers
for
rewriter
in
(
self
.
global_
rewrit
ers
+
list
(
self
.
get_node_rewriters
())
+
self
.
final_
optimiz
ers
+
self
.
cleanup_
optimiz
ers
+
self
.
final_
rewrit
ers
+
self
.
cleanup_
rewrit
ers
):
global_process_count
.
setdefault
(
opt
,
0
)
time_
opts
.
setdefault
(
opt
,
0
)
node_created
.
setdefault
(
opt
,
0
)
global_process_count
.
setdefault
(
rewriter
,
0
)
time_
rewriters
.
setdefault
(
rewriter
,
0
)
node_created
.
setdefault
(
rewriter
,
0
)
def
apply_cleanup
(
profs_dict
):
changed
=
False
for
c
opt
in
self
.
cleanup_optimiz
ers
:
for
c
rewriter
in
self
.
cleanup_rewrit
ers
:
change_tracker
.
reset
()
nb
=
change_tracker
.
nb_imported
t_
opt
=
time
.
time
()
sub_prof
=
c
opt
.
apply
(
fgraph
)
time_
opts
[
copt
]
+=
time
.
time
()
-
t_opt
profs_dict
[
c
opt
]
.
append
(
sub_prof
)
t_
rewrite
=
time
.
time
()
sub_prof
=
c
rewriter
.
apply
(
fgraph
)
time_
rewriters
[
crewriter
]
+=
time
.
time
()
-
t_rewrite
profs_dict
[
c
rewriter
]
.
append
(
sub_prof
)
if
change_tracker
.
changed
:
process_count
.
setdefault
(
c
opt
,
0
)
process_count
[
c
opt
]
+=
1
global_process_count
[
c
opt
]
+=
1
process_count
.
setdefault
(
c
rewriter
,
0
)
process_count
[
c
rewriter
]
+=
1
global_process_count
[
c
rewriter
]
+=
1
changed
=
True
node_created
[
c
opt
]
+=
change_tracker
.
nb_imported
-
nb
node_created
[
c
rewriter
]
+=
change_tracker
.
nb_imported
-
nb
return
changed
while
changed
and
not
max_use_abort
:
...
...
@@ -2375,32 +2375,32 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
t0
=
time
.
time
()
changed
=
False
iter_cleanup_sub_profs
=
{}
for
c
opt
in
self
.
cleanup_optimiz
ers
:
iter_cleanup_sub_profs
[
c
opt
]
=
[]
for
c
rewrite
in
self
.
cleanup_rewrit
ers
:
iter_cleanup_sub_profs
[
c
rewrite
]
=
[]
#
apply global optimiz
ers
#
Apply global rewrit
ers
sub_profs
=
[]
for
g
opt
in
self
.
global_optimiz
ers
:
for
g
rewrite
in
self
.
global_rewrit
ers
:
change_tracker
.
reset
()
nb
=
change_tracker
.
nb_imported
t_
opt
=
time
.
time
()
sub_prof
=
g
opt
.
apply
(
fgraph
)
time_
opts
[
gopt
]
+=
time
.
time
()
-
t_opt
t_
rewrite
=
time
.
time
()
sub_prof
=
g
rewrite
.
apply
(
fgraph
)
time_
rewriters
[
grewrite
]
+=
time
.
time
()
-
t_rewrite
sub_profs
.
append
(
sub_prof
)
if
change_tracker
.
changed
:
process_count
.
setdefault
(
g
opt
,
0
)
process_count
[
g
opt
]
+=
1
global_process_count
[
g
opt
]
+=
1
process_count
.
setdefault
(
g
rewrite
,
0
)
process_count
[
g
rewrite
]
+=
1
global_process_count
[
g
rewrite
]
+=
1
changed
=
True
node_created
[
g
opt
]
+=
change_tracker
.
nb_imported
-
nb
if
global_process_count
[
g
opt
]
>
max_use
:
node_created
[
g
rewrite
]
+=
change_tracker
.
nb_imported
-
nb
if
global_process_count
[
g
rewrite
]
>
max_use
:
max_use_abort
=
True
opt_name
=
getattr
(
gopt
,
"name"
,
None
)
or
getattr
(
g
opt
,
"__name__"
,
""
rewriter_name
=
getattr
(
grewrite
,
"name"
,
None
)
or
getattr
(
g
rewrite
,
"__name__"
,
""
)
global_sub_profs
.
append
(
sub_profs
)
global_
opt
_timing
.
append
(
float
(
time
.
time
()
-
t0
))
global_
rewriter
_timing
.
append
(
float
(
time
.
time
()
-
t0
))
changed
|=
apply_cleanup
(
iter_cleanup_sub_profs
)
...
...
@@ -2434,11 +2434,11 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
current_node
=
node
for
node_rewriter
in
self
.
node_tracker
.
get_trackers
(
node
.
op
):
nb
=
change_tracker
.
nb_imported
t_
opt
=
time
.
time
()
t_
rewrite
=
time
.
time
()
node_rewriter_change
=
self
.
process_node
(
fgraph
,
node
,
node_rewriter
)
time_
opts
[
node_rewriter
]
+=
time
.
time
()
-
t_opt
time_
rewriters
[
node_rewriter
]
+=
time
.
time
()
-
t_rewrite
if
not
node_rewriter_change
:
continue
process_count
.
setdefault
(
node_rewriter
,
0
)
...
...
@@ -2449,48 +2449,48 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
changed
|=
apply_cleanup
(
iter_cleanup_sub_profs
)
if
global_process_count
[
node_rewriter
]
>
max_use
:
max_use_abort
=
True
opt_name
=
getattr
(
node_rewriter
,
"name"
,
None
)
or
getattr
(
node_rewriter
,
"
__name__"
,
""
)
rewriter_name
=
getattr
(
node_rewriter
,
"
name"
,
None
)
or
getattr
(
node_rewriter
,
"__name__"
,
""
)
if
node
not
in
fgraph
.
apply_nodes
:
# go to next node
break
finally
:
self
.
detach_updater
(
fgraph
,
u
)
# Apply final
optimiz
ers
# Apply final
rewrit
ers
sub_profs
=
[]
t_before_final_
opt
=
time
.
time
()
for
g
opt
in
self
.
final_optimiz
ers
:
t_before_final_
rewrites
=
time
.
time
()
for
g
rewrite
in
self
.
final_rewrit
ers
:
change_tracker
.
reset
()
nb
=
change_tracker
.
nb_imported
t_
opt
=
time
.
time
()
sub_prof
=
g
opt
.
apply
(
fgraph
)
time_
opts
[
gopt
]
+=
time
.
time
()
-
t_opt
t_
rewrite
=
time
.
time
()
sub_prof
=
g
rewrite
.
apply
(
fgraph
)
time_
rewriters
[
grewrite
]
+=
time
.
time
()
-
t_rewrite
sub_profs
.
append
(
sub_prof
)
if
change_tracker
.
changed
:
process_count
.
setdefault
(
g
opt
,
0
)
process_count
[
g
opt
]
+=
1
global_process_count
[
g
opt
]
+=
1
process_count
.
setdefault
(
g
rewrite
,
0
)
process_count
[
g
rewrite
]
+=
1
global_process_count
[
g
rewrite
]
+=
1
changed
=
True
node_created
[
g
opt
]
+=
change_tracker
.
nb_imported
-
nb
if
global_process_count
[
g
opt
]
>
max_use
:
node_created
[
g
rewrite
]
+=
change_tracker
.
nb_imported
-
nb
if
global_process_count
[
g
rewrite
]
>
max_use
:
max_use_abort
=
True
opt_name
=
getattr
(
gopt
,
"name"
,
None
)
or
getattr
(
g
opt
,
"__name__"
,
""
rewriter_name
=
getattr
(
grewrite
,
"name"
,
None
)
or
getattr
(
g
rewrite
,
"__name__"
,
""
)
final_sub_profs
.
append
(
sub_profs
)
global_opt_timing
[
-
1
]
+=
time
.
time
()
-
t_before_final_opt
# apply clean up as final opt can have done changes that
# request that
global_rewriter_timing
[
-
1
]
+=
time
.
time
()
-
t_before_final_rewrites
changed
|=
apply_cleanup
(
iter_cleanup_sub_profs
)
# merge clean up profiles during that iteration.
# Merge clean up profiles during that iteration
c_sub_profs
=
[]
for
c
opt
,
sub_profs
in
iter_cleanup_sub_profs
.
items
():
for
c
rewrite
,
sub_profs
in
iter_cleanup_sub_profs
.
items
():
sub_prof
=
sub_profs
[
0
]
for
s_p
in
sub_profs
[
1
:]:
sub_prof
=
c
opt
.
merge_profile
(
sub_prof
,
s_p
)
sub_prof
=
c
rewrite
.
merge_profile
(
sub_prof
,
s_p
)
c_sub_profs
.
append
(
sub_prof
)
cleanup_sub_profs
.
append
(
c_sub_profs
)
...
...
@@ -2501,9 +2501,9 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
if
max_use_abort
:
msg
=
(
f
"
EquilibriumOptimizer max'ed out by '{opt_name}'
"
+
".
You can safely raise the current threshold of "
+
"{config.optdb__max_use_ratio:f} with the aesara flag 'optdb__max_use_ratio'
."
f
"
{type(self).__name__} max'ed out by {rewriter_name}.
"
"
You can safely raise the current threshold of "
f
"{config.optdb__max_use_ratio} with the option `optdb__max_use_ratio`
."
)
if
config
.
on_opt_error
==
"raise"
:
raise
AssertionError
(
msg
)
...
...
@@ -2511,7 +2511,7 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
_logger
.
error
(
msg
)
fgraph
.
remove_feature
(
change_tracker
)
assert
len
(
loop_process_count
)
==
len
(
loop_timing
)
assert
len
(
loop_process_count
)
==
len
(
global_
opt
_timing
)
assert
len
(
loop_process_count
)
==
len
(
global_
rewriter
_timing
)
assert
len
(
loop_process_count
)
==
len
(
nb_nodes
)
assert
len
(
loop_process_count
)
==
len
(
io_toposort_timing
)
assert
len
(
loop_process_count
)
==
len
(
global_sub_profs
)
...
...
@@ -2522,9 +2522,9 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
loop_timing
,
loop_process_count
,
(
start_nb_nodes
,
end_nb_nodes
,
max_nb_nodes
),
global_
opt
_timing
,
global_
rewriter
_timing
,
nb_nodes
,
time_
opt
s
,
time_
rewriter
s
,
io_toposort_timing
,
node_created
,
global_sub_profs
,
...
...
@@ -2543,16 +2543,16 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
stream
,
level
=
(
level
+
2
),
depth
=
(
depth
-
1
)
)
@
static
method
def
print_profile
(
stream
,
prof
,
level
=
0
):
@
class
method
def
print_profile
(
cls
,
stream
,
prof
,
level
=
0
):
(
opt
,
rewrite
,
loop_timing
,
loop_process_count
,
(
start_nb_nodes
,
end_nb_nodes
,
max_nb_nodes
),
global_
opt
_timing
,
global_
rewrite
_timing
,
nb_nodes
,
time_
opt
s
,
time_
rewrite
s
,
io_toposort_timing
,
node_created
,
global_sub_profs
,
...
...
@@ -2561,8 +2561,12 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
)
=
prof
blanc
=
" "
*
level
print
(
blanc
,
"EquilibriumOptimizer"
,
end
=
" "
,
file
=
stream
)
print
(
blanc
,
getattr
(
opt
,
"name"
,
getattr
(
opt
,
"__name__"
,
""
)),
file
=
stream
)
print
(
blanc
,
cls
.
__name__
,
end
=
" "
,
file
=
stream
)
print
(
blanc
,
getattr
(
rewrite
,
"name"
,
getattr
(
rewrite
,
"__name__"
,
""
)),
file
=
stream
,
)
print
(
blanc
,
f
" time {sum(loop_timing):.3f}s for {len(loop_timing)} passes"
,
...
...
@@ -2574,13 +2578,13 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
file
=
stream
,
)
print
(
blanc
,
f
" time io_toposort {sum(io_toposort_timing):.3f}s"
,
file
=
stream
)
s
=
sum
(
time_
opts
[
o
]
for
o
in
opt
.
get_node_rewriters
())
s
=
sum
(
time_
rewrites
[
o
]
for
o
in
rewrite
.
get_node_rewriters
())
print
(
blanc
,
f
" time in node rewriters {s:.3f}s"
,
file
=
stream
)
s
=
sum
(
time_
opts
[
o
]
for
o
in
opt
.
global_optimiz
ers
)
s
=
sum
(
time_
rewrites
[
o
]
for
o
in
rewrite
.
global_rewrit
ers
)
print
(
blanc
,
f
" time in graph rewriters {s:.3f}s"
,
file
=
stream
)
s
=
sum
(
time_
opts
[
o
]
for
o
in
opt
.
final_optimiz
ers
)
s
=
sum
(
time_
rewrites
[
o
]
for
o
in
rewrite
.
final_rewrit
ers
)
print
(
blanc
,
f
" time in final rewriters {s:.3f}s"
,
file
=
stream
)
s
=
sum
(
time_
opts
[
o
]
for
o
in
opt
.
cleanup_optimiz
ers
)
s
=
sum
(
time_
rewrites
[
o
]
for
o
in
rewrite
.
cleanup_rewrit
ers
)
print
(
blanc
,
f
" time in cleanup rewriters {s:.3f}s"
,
file
=
stream
)
for
i
in
range
(
len
(
loop_timing
)):
loop_times
=
""
...
...
@@ -2594,21 +2598,21 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
print
(
blanc
,
(
f
" {int(i):2d} - {loop_timing[i]:.3f}s {int(sum(loop_process_count[i].values()))} ({global_
opt
_timing[i]:.3f}s in graph rewriters, "
f
" {int(i):2d} - {loop_timing[i]:.3f}s {int(sum(loop_process_count[i].values()))} ({global_
rewrite
_timing[i]:.3f}s in graph rewriters, "
f
"{io_toposort_timing[i]:.3f}s io_toposort) - {int(nb_nodes[i])} nodes - {loop_times}"
),
file
=
stream
,
)
count_
opt
=
[]
count_
rewrite
=
[]
not_used
=
[]
not_used_time
=
0
process_count
=
{}
for
o
in
(
opt
.
global_optimiz
ers
+
list
(
opt
.
get_node_rewriters
())
+
list
(
opt
.
final_optimiz
ers
)
+
list
(
opt
.
cleanup_optimiz
ers
)
rewrite
.
global_rewrit
ers
+
list
(
rewrite
.
get_node_rewriters
())
+
list
(
rewrite
.
final_rewrit
ers
)
+
list
(
rewrite
.
cleanup_rewrit
ers
)
):
process_count
.
setdefault
(
o
,
0
)
for
count
in
loop_process_count
:
...
...
@@ -2616,17 +2620,17 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
process_count
[
o
]
+=
v
for
o
,
count
in
process_count
.
items
():
if
count
>
0
:
count_
opt
.
append
((
time_opt
s
[
o
],
count
,
node_created
[
o
],
o
))
count_
rewrite
.
append
((
time_rewrite
s
[
o
],
count
,
node_created
[
o
],
o
))
else
:
not_used
.
append
((
time_
opt
s
[
o
],
o
))
not_used_time
+=
time_
opt
s
[
o
]
not_used
.
append
((
time_
rewrite
s
[
o
],
o
))
not_used_time
+=
time_
rewrite
s
[
o
]
if
count_
opt
:
if
count_
rewrite
:
print
(
blanc
,
" times - times applied - nb node created - name:"
,
file
=
stream
)
count_
opt
.
sort
()
for
(
t
,
count
,
n_created
,
o
)
in
count_
opt
[::
-
1
]:
count_
rewrite
.
sort
()
for
(
t
,
count
,
n_created
,
o
)
in
count_
rewrite
[::
-
1
]:
print
(
blanc
,
f
" {t:.3f}s - {int(count)} - {int(n_created)} - {o}"
,
...
...
@@ -2634,40 +2638,40 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
)
print
(
blanc
,
f
" {not_used_time:.3f}s - in {len(not_used)}
optimization that were not used (display only those with a runtime > 0
)"
,
f
" {not_used_time:.3f}s - in {len(not_used)}
rewrites that were not used (i.e. those with a run-time of zero
)"
,
file
=
stream
,
)
not_used
.
sort
(
key
=
lambda
nu
:
(
nu
[
0
],
str
(
nu
[
1
])))
for
(
t
,
o
)
in
not_used
[::
-
1
]:
if
t
>
0
:
# Skip
opt that have 0 times, they probably was
n't even tried.
# Skip
rewrites that have no run-times; they probably were
n't even tried.
print
(
blanc
+
" "
,
f
" {t:.3f}s - {o}"
,
file
=
stream
)
print
(
file
=
stream
)
gf_
opt
s
=
[
gf_
rewrite
s
=
[
o
for
o
in
(
opt
.
global_optimizer
s
+
list
(
opt
.
final_optimiz
ers
)
+
list
(
opt
.
cleanup_optimiz
ers
)
rewrite
.
global_rewrite
s
+
list
(
rewrite
.
final_rewrit
ers
)
+
list
(
rewrite
.
cleanup_rewrit
ers
)
)
if
o
.
print_profile
.
__code__
is
not
GraphRewriter
.
print_profile
.
__code__
]
if
not
gf_
opt
s
:
if
not
gf_
rewrite
s
:
return
print
(
blanc
,
"Global, final
and clean up optimiz
ers"
,
file
=
stream
)
print
(
blanc
,
"Global, final
, and clean up rewrit
ers"
,
file
=
stream
)
for
i
in
range
(
len
(
loop_timing
)):
print
(
blanc
,
f
"Iter {int(i)}"
,
file
=
stream
)
for
o
,
prof
in
zip
(
opt
.
global_optimiz
ers
,
global_sub_profs
[
i
]):
for
o
,
prof
in
zip
(
rewrite
.
global_rewrit
ers
,
global_sub_profs
[
i
]):
try
:
o
.
print_profile
(
stream
,
prof
,
level
+
2
)
except
NotImplementedError
:
print
(
blanc
,
"merge not implemented for "
,
o
)
for
o
,
prof
in
zip
(
opt
.
final_optimiz
ers
,
final_sub_profs
[
i
]):
for
o
,
prof
in
zip
(
rewrite
.
final_rewrit
ers
,
final_sub_profs
[
i
]):
try
:
o
.
print_profile
(
stream
,
prof
,
level
+
2
)
except
NotImplementedError
:
print
(
blanc
,
"merge not implemented for "
,
o
)
for
o
,
prof
in
zip
(
opt
.
cleanup_optimiz
ers
,
cleanup_sub_profs
[
i
]):
for
o
,
prof
in
zip
(
rewrite
.
cleanup_rewrit
ers
,
cleanup_sub_profs
[
i
]):
try
:
o
.
print_profile
(
stream
,
prof
,
level
+
2
)
except
NotImplementedError
:
...
...
@@ -2675,25 +2679,23 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
@staticmethod
def
merge_profile
(
prof1
,
prof2
):
# (opt, loop_timing, loop_process_count, max_nb_nodes,
# global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
node_rewriters
=
OrderedSet
(
prof1
[
0
]
.
get_node_rewriters
())
.
union
(
prof2
[
0
]
.
get_node_rewriters
()
)
global_
optimizers
=
OrderedSet
(
prof1
[
0
]
.
global_optimiz
ers
)
.
union
(
prof2
[
0
]
.
global_
optimiz
ers
global_
rewriters
=
OrderedSet
(
prof1
[
0
]
.
global_rewrit
ers
)
.
union
(
prof2
[
0
]
.
global_
rewrit
ers
)
final_
optimiz
ers
=
list
(
OrderedSet
(
prof1
[
0
]
.
final_
optimizers
)
.
union
(
prof2
[
0
]
.
final_optimiz
ers
)
final_
rewrit
ers
=
list
(
OrderedSet
(
prof1
[
0
]
.
final_
rewriters
)
.
union
(
prof2
[
0
]
.
final_rewrit
ers
)
)
cleanup_
optimiz
ers
=
list
(
OrderedSet
(
prof1
[
0
]
.
cleanup_
optimizers
)
.
union
(
prof2
[
0
]
.
cleanup_optimiz
ers
)
cleanup_
rewrit
ers
=
list
(
OrderedSet
(
prof1
[
0
]
.
cleanup_
rewriters
)
.
union
(
prof2
[
0
]
.
cleanup_rewrit
ers
)
)
new_
opt
=
EquilibriumOptimiz
er
(
node_rewriters
.
union
(
global_
optimiz
ers
),
new_
rewriter
=
EquilibriumGraphRewrit
er
(
node_rewriters
.
union
(
global_
rewrit
ers
),
max_use_ratio
=
1
,
final_
optimizers
=
final_optimiz
ers
,
cleanup_
optimizers
=
cleanup_optimiz
ers
,
final_
rewriters
=
final_rewrit
ers
,
cleanup_
rewriters
=
cleanup_rewrit
ers
,
)
def
add_append_list
(
l1
,
l2
):
...
...
@@ -2720,29 +2722,27 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
else
:
process_count
[
process
]
=
count
def
merge
(
opt
s
,
attr
,
idx
):
def
merge
(
rewriter
s
,
attr
,
idx
):
tmp
=
[]
for
opt
in
opt
s
:
for
rewriter
in
rewriter
s
:
o1
=
getattr
(
prof1
[
0
],
attr
)
o2
=
getattr
(
prof2
[
0
],
attr
)
if
opt
in
o1
and
opt
in
o2
:
p1
=
prof1
[
idx
][
i
][
o1
.
index
(
opt
)]
p2
=
prof2
[
idx
][
i
][
o2
.
index
(
opt
)]
if
rewriter
in
o1
and
rewriter
in
o2
:
p1
=
prof1
[
idx
][
i
][
o1
.
index
(
rewriter
)]
p2
=
prof2
[
idx
][
i
][
o2
.
index
(
rewriter
)]
m
=
None
if
hasattr
(
opt
,
"merge_profile"
):
m
=
opt
.
merge_profile
(
p1
,
p2
)
elif
opt
in
o1
:
m
=
prof1
[
idx
][
i
][
o1
.
index
(
opt
)]
if
hasattr
(
rewriter
,
"merge_profile"
):
m
=
rewriter
.
merge_profile
(
p1
,
p2
)
elif
rewriter
in
o1
:
m
=
prof1
[
idx
][
i
][
o1
.
index
(
rewriter
)]
else
:
m
=
prof2
[
idx
][
i
][
o2
.
index
(
opt
)]
m
=
prof2
[
idx
][
i
][
o2
.
index
(
rewriter
)]
tmp
.
append
(
m
)
return
tmp
global_sub_profs
.
append
(
merge
(
global_optimizers
,
"global_optimizers"
,
9
))
final_sub_profs
.
append
(
merge
(
final_optimizers
,
"final_optimizers"
,
10
))
cleanup_sub_profs
.
append
(
merge
(
cleanup_optimizers
,
"cleanup_optimizers"
,
11
)
)
global_sub_profs
.
append
(
merge
(
global_rewriters
,
"global_rewriters"
,
9
))
final_sub_profs
.
append
(
merge
(
final_rewriters
,
"final_rewriters"
,
10
))
cleanup_sub_profs
.
append
(
merge
(
cleanup_rewriters
,
"cleanup_rewriters"
,
11
))
# Add the iteration done by only one of the profile.
loop_process_count
.
extend
(
prof1
[
2
][
len
(
loop_process_count
)
:])
...
...
@@ -2756,15 +2756,15 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
max_nb_nodes
=
max
(
prof1
[
3
],
prof2
[
3
])
global_
opt
_timing
=
add_append_list
(
prof1
[
4
],
prof2
[
4
])
global_
rewrite
_timing
=
add_append_list
(
prof1
[
4
],
prof2
[
4
])
nb_nodes
=
add_append_list
(
prof1
[
5
],
prof2
[
5
])
time_
opt
s
=
merge_dict
(
prof1
[
6
],
prof2
[
6
])
time_
rewrite
s
=
merge_dict
(
prof1
[
6
],
prof2
[
6
])
io_toposort_timing
=
add_append_list
(
prof1
[
7
],
prof2
[
7
])
assert
(
len
(
loop_timing
)
==
len
(
global_
opt
_timing
)
==
len
(
global_
rewrite
_timing
)
==
len
(
global_sub_profs
)
==
len
(
io_toposort_timing
)
==
len
(
nb_nodes
)
...
...
@@ -2773,13 +2773,13 @@ class EquilibriumOptimizer(NodeProcessingGraphRewriter):
node_created
=
merge_dict
(
prof1
[
8
],
prof2
[
8
])
return
(
new_
opt
,
new_
rewriter
,
loop_timing
,
loop_process_count
,
max_nb_nodes
,
global_
opt
_timing
,
global_
rewrite
_timing
,
nb_nodes
,
time_
opt
s
,
time_
rewrite
s
,
io_toposort_timing
,
node_created
,
global_sub_profs
,
...
...
@@ -3235,6 +3235,11 @@ DEPRECATED_NAMES = [
"`OpKeyOptimizer` is deprecated: use `OpKeyGraphRewriter` instead."
,
OpKeyGraphRewriter
,
),
(
"EquilibriumOptimizer"
,
"`EquilibriumOptimizer` is deprecated: use `EquilibriumGraphRewriter` instead."
,
EquilibriumGraphRewriter
,
),
]
...
...
aesara/graph/optdb.py
浏览文件 @
ac213377
...
...
@@ -31,19 +31,18 @@ class OptimizationDatabase:
def
register
(
self
,
name
:
str
,
optimiz
er
:
Union
[
"OptimizationDatabase"
,
OptimizersType
],
rewrit
er
:
Union
[
"OptimizationDatabase"
,
OptimizersType
],
*
tags
:
str
,
use_db_name_as_tag
=
True
,
**
kwargs
,
):
"""Register a new
optimiz
er to the database.
"""Register a new
rewrit
er to the database.
Parameters
----------
name:
Name of the
optimiz
er.
opt
:
The
optimiz
er to register.
Name of the
rewrit
er.
rewriter
:
The
rewrit
er to register.
tags:
Tag name that allow to select the optimizer.
use_db_name_as_tag:
...
...
@@ -58,14 +57,14 @@ class OptimizationDatabase:
"""
if
not
isinstance
(
optimiz
er
,
rewrit
er
,
(
OptimizationDatabase
,
aesara_opt
.
GraphRewriter
,
aesara_opt
.
NodeRewriter
,
),
):
raise
TypeError
(
f
"{
optimiz
er} is not a valid optimizer type."
)
raise
TypeError
(
f
"{
rewrit
er} is not a valid optimizer type."
)
if
name
in
self
.
__db__
:
raise
ValueError
(
f
"The tag '{name}' is already present in the database."
)
...
...
@@ -74,18 +73,18 @@ class OptimizationDatabase:
if
self
.
name
is
not
None
:
tags
=
tags
+
(
self
.
name
,)
optimiz
er
.
name
=
name
rewrit
er
.
name
=
name
# This restriction is there because in many place we suppose that
# something in the OptimizationDatabase is there only once.
if
optimiz
er
.
name
in
self
.
__db__
:
if
rewrit
er
.
name
in
self
.
__db__
:
raise
ValueError
(
f
"Tried to register {
optimiz
er.name} again under the new name {name}. "
f
"Tried to register {
rewrit
er.name} again under the new name {name}. "
"The same optimization cannot be registered multiple times in"
" an ``OptimizationDatabase``; use ProxyDB instead."
)
self
.
__db__
[
name
]
=
OrderedSet
([
optimiz
er
])
self
.
__db__
[
name
]
=
OrderedSet
([
rewrit
er
])
self
.
_names
.
add
(
name
)
self
.
__db__
[
optimizer
.
__class__
.
__name__
]
.
add
(
optimiz
er
)
self
.
__db__
[
rewriter
.
__class__
.
__name__
]
.
add
(
rewrit
er
)
self
.
add_tags
(
name
,
*
tags
)
def
add_tags
(
self
,
name
,
*
tags
):
...
...
@@ -292,11 +291,11 @@ class OptimizationQuery:
class
EquilibriumDB
(
OptimizationDatabase
):
"""A database of rewrites that should be applied until equilibrium is reached.
Canonicalize, Stabilize, and Specialize are all equilibrium
optimization
s.
Canonicalize, Stabilize, and Specialize are all equilibrium
rewriter
s.
Notes
-----
We can use `NodeRewriter` and `GraphRewriter` since `Equilibrium
Optimiz
er`
We can use `NodeRewriter` and `GraphRewriter` since `Equilibrium
GraphRewrit
er`
supports both.
It is probably not a good idea to have both ``ignore_newtrees == False``
...
...
@@ -322,33 +321,47 @@ class EquilibriumDB(OptimizationDatabase):
super
()
.
__init__
()
self
.
ignore_newtrees
=
ignore_newtrees
self
.
tracks_on_change_inputs
=
tracks_on_change_inputs
self
.
__final__
:
Dict
[
str
,
aesara_opt
.
Rewriter
]
=
{}
self
.
__cleanup__
:
Dict
[
str
,
aesara_opt
.
Rewriter
]
=
{}
self
.
__final__
:
Dict
[
str
,
bool
]
=
{}
self
.
__cleanup__
:
Dict
[
str
,
bool
]
=
{}
def
register
(
self
,
name
,
obj
,
*
tags
,
final_opt
=
False
,
cleanup
=
False
,
**
kwargs
):
if
final_opt
and
cleanup
:
raise
ValueError
(
"`final_opt` and `cleanup` cannot both be true."
)
super
()
.
register
(
name
,
obj
,
*
tags
,
**
kwargs
)
self
.
__final__
[
name
]
=
final_opt
def
register
(
self
,
name
:
str
,
rewriter
:
Union
[
"OptimizationDatabase"
,
OptimizersType
],
*
tags
:
str
,
final_rewriter
:
bool
=
False
,
cleanup
:
bool
=
False
,
**
kwargs
,
):
if
final_rewriter
and
cleanup
:
raise
ValueError
(
"`final_rewriter` and `cleanup` cannot both be true."
)
super
()
.
register
(
name
,
rewriter
,
*
tags
,
**
kwargs
)
self
.
__final__
[
name
]
=
final_rewriter
self
.
__cleanup__
[
name
]
=
cleanup
def
query
(
self
,
*
tags
,
**
kwtags
):
_opts
=
super
()
.
query
(
*
tags
,
**
kwtags
)
final_opts
=
[
o
for
o
in
_opts
if
self
.
__final__
.
get
(
o
.
name
,
False
)]
cleanup_opts
=
[
o
for
o
in
_opts
if
self
.
__cleanup__
.
get
(
o
.
name
,
False
)]
opts
=
[
o
for
o
in
_opts
if
o
not
in
final_opts
and
o
not
in
cleanup_opts
]
if
len
(
final_opts
)
==
0
:
final_opts
=
None
if
len
(
cleanup_opts
)
==
0
:
cleanup_opts
=
None
return
aesara_opt
.
EquilibriumOptimizer
(
opts
,
_rewriters
=
super
()
.
query
(
*
tags
,
**
kwtags
)
final_rewriters
=
[
o
for
o
in
_rewriters
if
self
.
__final__
.
get
(
o
.
name
,
False
)]
cleanup_rewriters
=
[
o
for
o
in
_rewriters
if
self
.
__cleanup__
.
get
(
o
.
name
,
False
)
]
rewriters
=
[
o
for
o
in
_rewriters
if
o
not
in
final_rewriters
and
o
not
in
cleanup_rewriters
]
if
len
(
final_rewriters
)
==
0
:
final_rewriters
=
None
if
len
(
cleanup_rewriters
)
==
0
:
cleanup_rewriters
=
None
return
aesara_opt
.
EquilibriumGraphRewriter
(
rewriters
,
max_use_ratio
=
config
.
optdb__max_use_ratio
,
ignore_newtrees
=
self
.
ignore_newtrees
,
tracks_on_change_inputs
=
self
.
tracks_on_change_inputs
,
failure_callback
=
aesara_opt
.
NodeProcessingGraphRewriter
.
warn_inplace
,
final_
optimizers
=
final_opt
s
,
cleanup_
optimizers
=
cleanup_opt
s
,
final_
rewriters
=
final_rewriter
s
,
cleanup_
rewriters
=
cleanup_rewriter
s
,
)
...
...
@@ -372,8 +385,10 @@ class SequenceDB(OptimizationDatabase):
self
.
failure_callback
=
failure_callback
def
register
(
self
,
name
,
obj
,
*
tags
,
**
kwargs
):
super
()
.
register
(
name
,
obj
,
*
tags
,
**
kwargs
)
position
=
kwargs
.
pop
(
"position"
,
"last"
)
super
()
.
register
(
name
,
obj
,
*
tags
,
**
kwargs
)
if
position
==
"last"
:
if
len
(
self
.
__position__
)
==
0
:
self
.
__position__
[
name
]
=
0
...
...
aesara/scan/opt.py
浏览文件 @
ac213377
...
...
@@ -2373,7 +2373,7 @@ optdb.register(
position
=
75
,
)
scan_eqopt1
.
register
(
"all_pushout_opt"
,
scan_seqopt1
,
"fast_run"
,
"scan"
,
position
=
1
)
scan_eqopt1
.
register
(
"all_pushout_opt"
,
scan_seqopt1
,
"fast_run"
,
"scan"
)
scan_seqopt1
.
register
(
...
...
@@ -2419,7 +2419,7 @@ scan_seqopt1.register(
scan_seqopt1
.
register
(
"scan_pushout_add"
,
# TODO: Perhaps this should be an `Equilibrium
Optimiz
er`?
# TODO: Perhaps this should be an `Equilibrium
GraphRewrit
er`?
in2out
(
push_out_add_scan
,
ignore_newtrees
=
False
),
"fast_run"
,
"more_mem"
,
...
...
@@ -2434,7 +2434,6 @@ scan_eqopt2.register(
in2out
(
basic_opt
.
constant_folding
,
ignore_newtrees
=
True
),
"fast_run"
,
"scan"
,
position
=
1
,
)
...
...
@@ -2444,14 +2443,13 @@ scan_eqopt2.register(
"remove_constants_and_unused_inputs_scan"
,
"fast_run"
,
"scan"
,
position
=
2
,
)
# after const merge but before stabilize so that we can have identity
# for equivalent nodes but we still have the chance to hoist stuff out
# of the scan later.
scan_eqopt2
.
register
(
"scan_merge"
,
ScanMerge
(),
"fast_run"
,
"scan"
,
position
=
4
)
scan_eqopt2
.
register
(
"scan_merge"
,
ScanMerge
(),
"fast_run"
,
"scan"
)
# After Merge optimization
scan_eqopt2
.
register
(
...
...
@@ -2460,7 +2458,6 @@ scan_eqopt2.register(
"remove_constants_and_unused_inputs_scan"
,
"fast_run"
,
"scan"
,
position
=
5
,
)
scan_eqopt2
.
register
(
...
...
@@ -2468,7 +2465,6 @@ scan_eqopt2.register(
in2out
(
scan_merge_inouts
,
ignore_newtrees
=
True
),
"fast_run"
,
"scan"
,
position
=
6
,
)
# After everything else
...
...
@@ -2478,5 +2474,4 @@ scan_eqopt2.register(
"remove_constants_and_unused_inputs_scan"
,
"fast_run"
,
"scan"
,
position
=
8
,
)
aesara/tensor/basic_opt.py
浏览文件 @
ac213377
...
...
@@ -2802,10 +2802,10 @@ def constant_folding(fgraph, node):
topo_constant_folding
=
in2out
(
constant_folding
,
ignore_newtrees
=
True
,
name
=
"topo_constant_folding"
)
register_canonicalize
(
topo_constant_folding
,
"fast_compile"
,
final_
opt
=
True
)
register_uncanonicalize
(
topo_constant_folding
,
"fast_compile"
,
final_
opt
=
True
)
register_stabilize
(
topo_constant_folding
,
"fast_compile"
,
final_
opt
=
True
)
register_specialize
(
topo_constant_folding
,
"fast_compile"
,
final_
opt
=
True
)
register_canonicalize
(
topo_constant_folding
,
"fast_compile"
,
final_
rewriter
=
True
)
register_uncanonicalize
(
topo_constant_folding
,
"fast_compile"
,
final_
rewriter
=
True
)
register_stabilize
(
topo_constant_folding
,
"fast_compile"
,
final_
rewriter
=
True
)
register_specialize
(
topo_constant_folding
,
"fast_compile"
,
final_
rewriter
=
True
)
def
local_elemwise_fusion_op
(
op_class
,
max_input_fct
=
lambda
node
:
32
,
maker
=
None
):
...
...
@@ -3096,7 +3096,7 @@ local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fc
class
FusionOptimizer
(
GraphRewriter
):
"""Graph rewriter that simply runs node fusion operations.
TODO: This is basically an `Equilibrium
Optimiz
er`; we should just use that.
TODO: This is basically an `Equilibrium
GraphRewrit
er`; we should just use that.
"""
...
...
aesara/tensor/blas.py
浏览文件 @
ac213377
...
...
@@ -146,7 +146,7 @@ from aesara.graph.basic import Apply, view_roots
from
aesara.graph.features
import
ReplacementDidNotRemoveError
,
ReplaceValidate
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
(
Equilibrium
Optimiz
er
,
Equilibrium
GraphRewrit
er
,
GraphRewriter
,
copy_stack_trace
,
in2out
,
...
...
@@ -1906,7 +1906,7 @@ blas_optdb.register(
blas_optdb
.
register
(
"gemm_optimizer"
,
GemmOptimizer
(),
"fast_run"
,
position
=
10
)
blas_optdb
.
register
(
"local_gemm_to_gemv"
,
Equilibrium
Optimiz
er
(
Equilibrium
GraphRewrit
er
(
[
local_gemm_to_gemv
,
local_gemm_to_ger
,
...
...
doc/extending/graph_rewriting.rst
浏览文件 @
ac213377
...
...
@@ -444,7 +444,7 @@ The following is an example that distributes dot products across additions.
import aesara
import aesara.tensor as at
from aesara.graph.kanren import KanrenRelationSub
from aesara.graph.opt import Equilibrium
Optimiz
er
from aesara.graph.opt import Equilibrium
GraphRewrit
er
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.math import _dot
from etuples import etuple
...
...
@@ -484,7 +484,7 @@ The following is an example that distributes dot products across additions.
)
dot_distribute_opt = Equilibrium
Optimiz
er([KanrenRelationSub(dot_distributeo)], max_use_ratio=10)
dot_distribute_opt = Equilibrium
GraphRewrit
er([KanrenRelationSub(dot_distributeo)], max_use_ratio=10)
Below, we apply `dot_distribute_opt` to a few example graphs. First we create simple test graph:
...
...
@@ -531,7 +531,7 @@ relational properties.
To do that, we will create another :class:`Rewriter` that simply reverses the arguments
to the relation :func:`dot_distributeo` and apply it to the distributed result in ``res``:
>>> dot_gather_opt = Equilibrium
Optimiz
er([KanrenRelationSub(lambda x, y: dot_distributeo(y, x))], max_use_ratio=10)
>>> dot_gather_opt = Equilibrium
GraphRewrit
er([KanrenRelationSub(lambda x, y: dot_distributeo(y, x))], max_use_ratio=10)
>>> rev_res = optimize_graph(res, include=[], custom_opt=dot_gather_opt, clone=False)
>>> print(aesara.pprint(rev_res))
(A @ (x + (y + (B @ (z + w)))))
...
...
@@ -561,7 +561,7 @@ serve as a basis for filtering.
The point of :obj:`optdb` is that you might want to apply many optimizations
to a computation graph in many unique patterns. For example, you might
want to do optimization X, then optimization Y, then optimization Z. And then
maybe optimization Y is an :class:`Equilibrium
Optimiz
er` containing :class:`NodeRewriter`\s A, B
maybe optimization Y is an :class:`Equilibrium
GraphRewrit
er` containing :class:`NodeRewriter`\s A, B
and C which are applied on every node of the graph until they all fail to change
it. If some optimizations act up, we want an easy way to turn them off. Ditto if
some optimizations are very CPU-intensive and we don't want to take the time to
...
...
@@ -599,7 +599,7 @@ optimizers they return will be put in their places.
An :class:`EquilibriumDB` contains :class:`NodeRewriter` or :class:`OptimizationDatabase` objects. Each of them
has a name and an arbitrary number of tags. When a :class:`OptimizationQuery` is applied to
an :class:`EquilibriumDB`, all :class:`NodeRewriter`\s that match the query are
inserted into an :class:`Equilibrium
Optimiz
er`, which is returned. If the
inserted into an :class:`Equilibrium
GraphRewrit
er`, which is returned. If the
:class:`SequenceDB` contains :class:`OptimizationDatabase` instances, the
:class:`OptimizationQuery` will be passed to them as well and the
:class:`NodeRewriter`\s they return will be put in their places
...
...
@@ -859,8 +859,8 @@ This will output something like this:
0.028s for fgraph.validate()
0.131s for callback
time - (name, class, index) - validate time
0.751816s - ('canonicalize', 'Equilibrium
Optimiz
er', 4) - 0.004s
Equilibrium
Optimiz
er canonicalize
0.751816s - ('canonicalize', 'Equilibrium
GraphRewrit
er', 4) - 0.004s
Equilibrium
GraphRewrit
er canonicalize
time 0.751s for 14 passes
nb nodes (start, end, max) 108 81 117
time io_toposort 0.029s
...
...
@@ -974,8 +974,8 @@ This will output something like this:
init io_toposort 0.00171804428101
loop time 0.000502109527588
callback_time 0.0
0.002257s - ('local_gemm_to_gemv', 'Equilibrium
Optimiz
er', 3) - 0.000s
Equilibrium
Optimiz
er local_gemm_to_gemv
0.002257s - ('local_gemm_to_gemv', 'Equilibrium
GraphRewrit
er', 3) - 0.000s
Equilibrium
GraphRewrit
er local_gemm_to_gemv
time 0.002s for 1 passes
nb nodes (start, end, max) 80 80 80
time io_toposort 0.001s
...
...
@@ -994,8 +994,8 @@ This will output something like this:
init io_toposort 0.00138401985168
loop time 0.000202178955078
callback_time 0.0
0.031740s - ('specialize', 'Equilibrium
Optimiz
er', 9) - 0.000s
Equilibrium
Optimiz
er specialize
0.031740s - ('specialize', 'Equilibrium
GraphRewrit
er', 9) - 0.000s
Equilibrium
GraphRewrit
er specialize
time 0.031s for 2 passes
nb nodes (start, end, max) 80 78 80
time io_toposort 0.003s
...
...
@@ -1080,8 +1080,8 @@ To understand this profile here is some explanation of how optimizations work:
.. code-block:: none
0.751816s - ('canonicalize', 'Equilibrium
Optimiz
er', 4) - 0.004s
Equilibrium
Optimiz
er canonicalize
0.751816s - ('canonicalize', 'Equilibrium
GraphRewrit
er', 4) - 0.004s
Equilibrium
GraphRewrit
er canonicalize
time 0.751s for 14 passes
nb nodes (start, end, max) 108 81 117
time io_toposort 0.029s
...
...
@@ -1146,15 +1146,15 @@ To understand this profile here is some explanation of how optimizations work:
0.000s - local_subtensor_of_dot
0.000s - local_subtensor_merge
* ``0.751816s - ('canonicalize', 'Equilibrium
Optimiz
er', 4) - 0.004s``
* ``0.751816s - ('canonicalize', 'Equilibrium
GraphRewrit
er', 4) - 0.004s``
This line is from :class:`SequentialGraphRewriter`, and indicates information related
to a sub-optimizer. It means that this sub-optimizer took
a total of .7s. Its name is ``'canonicalize'``. It is an
:class:`Equilibrium
Optimiz
er`. It was executed at index 4 by the
:class:`Equilibrium
GraphRewrit
er`. It was executed at index 4 by the
:class:`SequentialGraphRewriter`. It spent 0.004s in the *validate* phase.
* All other lines are from the profiler of the :class:`Equilibrium
Optimiz
er`.
* All other lines are from the profiler of the :class:`Equilibrium
GraphRewrit
er`.
* An :class:`Equilibrium
Optimiz
er` does multiple passes on the Apply nodes from
* An :class:`Equilibrium
GraphRewrit
er` does multiple passes on the Apply nodes from
the graph, trying to apply local and global optimizations.
Conceptually, it tries to execute all global optimizations,
and to apply all local optimizations on all
...
...
tests/graph/test_kanren.py
浏览文件 @
ac213377
...
...
@@ -13,7 +13,7 @@ from aesara.graph.basic import Apply
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.kanren
import
KanrenRelationSub
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
Equilibrium
Optimiz
er
from
aesara.graph.opt
import
Equilibrium
GraphRewrit
er
from
aesara.graph.opt_utils
import
optimize_graph
from
aesara.graph.unify
import
eval_if_etuple
from
aesara.tensor.math
import
Dot
,
_dot
...
...
@@ -151,7 +151,7 @@ def test_KanrenRelationSub_dot():
),
)
distribute_opt
=
Equilibrium
Optimiz
er
(
distribute_opt
=
Equilibrium
GraphRewrit
er
(
[
KanrenRelationSub
(
distributes
)],
max_use_ratio
=
10
)
...
...
tests/graph/test_opt.py
浏览文件 @
ac213377
...
...
@@ -6,7 +6,7 @@ from aesara.graph.features import Feature
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
(
Equilibrium
Optimiz
er
,
Equilibrium
GraphRewrit
er
,
MergeOptimizer
,
OpKeyGraphRewriter
,
OpToRewriterTracker
,
...
...
@@ -446,7 +446,7 @@ class TestEquilibrium:
e
=
op3
(
op4
(
x
,
y
))
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
# print g
opt
=
Equilibrium
Optimiz
er
(
opt
=
Equilibrium
GraphRewrit
er
(
[
PatternNodeRewriter
((
op1
,
"x"
,
"y"
),
(
op2
,
"x"
,
"y"
)),
PatternNodeRewriter
((
op4
,
"x"
,
"y"
),
(
op1
,
"x"
,
"y"
)),
...
...
@@ -463,7 +463,7 @@ class TestEquilibrium:
e
=
op1
(
op1
(
op3
(
x
,
y
)))
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
# print g
opt
=
Equilibrium
Optimiz
er
(
opt
=
Equilibrium
GraphRewrit
er
(
[
PatternNodeRewriter
((
op1
,
(
op2
,
"x"
,
"y"
)),
(
op4
,
"x"
,
"y"
)),
PatternNodeRewriter
((
op3
,
"x"
,
"y"
),
(
op4
,
"x"
,
"y"
)),
...
...
@@ -488,7 +488,7 @@ class TestEquilibrium:
oldlevel
=
_logger
.
level
_logger
.
setLevel
(
logging
.
CRITICAL
)
try
:
opt
=
Equilibrium
Optimiz
er
(
opt
=
Equilibrium
GraphRewrit
er
(
[
PatternNodeRewriter
((
op1
,
"x"
,
"y"
),
(
op2
,
"x"
,
"y"
)),
PatternNodeRewriter
((
op4
,
"x"
,
"y"
),
(
op1
,
"x"
,
"y"
)),
...
...
@@ -600,7 +600,7 @@ def test_patternsub_values_eq_approx(out_pattern, tracks):
e
=
op1
(
x
)
fg
=
FunctionGraph
([
x
],
[
e
],
clone
=
False
)
opt
=
Equilibrium
Optimiz
er
(
opt
=
Equilibrium
GraphRewrit
er
(
[
PatternNodeRewriter
(
(
op1
,
"x"
),
...
...
@@ -633,7 +633,7 @@ def test_patternsub_invalid_dtype(out_pattern):
e
=
op_cast_type2
(
x
)
fg
=
FunctionGraph
([
x
],
[
e
])
opt
=
Equilibrium
Optimiz
er
(
opt
=
Equilibrium
GraphRewrit
er
(
[
PatternNodeRewriter
(
(
op_cast_type2
,
"x"
),
...
...
tests/graph/test_optdb.py
浏览文件 @
ac213377
...
...
@@ -45,8 +45,8 @@ class TestDB:
def
test_EquilibriumDB
(
self
):
eq_db
=
EquilibriumDB
()
with
pytest
.
raises
(
ValueError
,
match
=
r"`final_
opt
` and.*"
):
eq_db
.
register
(
"d"
,
TestOpt
(),
final_
opt
=
True
,
cleanup
=
True
)
with
pytest
.
raises
(
ValueError
,
match
=
r"`final_
rewriter
` and.*"
):
eq_db
.
register
(
"d"
,
TestOpt
(),
final_
rewriter
=
True
,
cleanup
=
True
)
def
test_SequenceDB
(
self
):
seq_db
=
SequenceDB
(
failure_callback
=
None
)
...
...
tests/tensor/random/test_opt.py
浏览文件 @
ac213377
...
...
@@ -7,7 +7,7 @@ from aesara.compile.function import function
from
aesara.compile.mode
import
Mode
from
aesara.graph.basic
import
Constant
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.opt
import
Equilibrium
Optimiz
er
from
aesara.graph.opt
import
Equilibrium
GraphRewrit
er
from
aesara.graph.optdb
import
OptimizationQuery
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.random.basic
import
(
...
...
@@ -50,7 +50,7 @@ def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng, name=None
p
for
p
in
dist_params_at
+
size_at
if
not
isinstance
(
p
,
(
slice
,
Constant
))
]
mode
=
Mode
(
"py"
,
Equilibrium
Optimiz
er
([
opt
],
max_use_ratio
=
100
))
mode
=
Mode
(
"py"
,
Equilibrium
GraphRewrit
er
([
opt
],
max_use_ratio
=
100
))
f_opt
=
function
(
f_inputs
,
...
...
@@ -519,7 +519,7 @@ def test_Subtensor_lift_restrictions():
z
=
x
-
y
fg
=
FunctionGraph
([
rng
],
[
z
],
clone
=
False
)
_
=
Equilibrium
Optimiz
er
([
local_subtensor_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
_
=
Equilibrium
GraphRewrit
er
([
local_subtensor_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
subtensor_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
0
]
.
owner
assert
subtensor_node
==
y
.
owner
...
...
@@ -531,7 +531,7 @@ def test_Subtensor_lift_restrictions():
# We add `x` as an output to make sure that `is_rv_used_in_graph` handles
# `"output"` "nodes" correctly.
fg
=
FunctionGraph
([
rng
],
[
z
,
x
],
clone
=
False
)
Equilibrium
Optimiz
er
([
local_subtensor_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
Equilibrium
GraphRewrit
er
([
local_subtensor_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
assert
fg
.
outputs
[
0
]
==
z
assert
fg
.
outputs
[
1
]
==
x
...
...
@@ -539,7 +539,7 @@ def test_Subtensor_lift_restrictions():
# The non-`Subtensor` client doesn't depend on the RNG state, so we can
# perform the lift
fg
=
FunctionGraph
([
rng
],
[
z
],
clone
=
False
)
Equilibrium
Optimiz
er
([
local_subtensor_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
Equilibrium
GraphRewrit
er
([
local_subtensor_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
rv_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
0
]
.
owner
assert
rv_node
.
op
==
normal
...
...
@@ -557,7 +557,9 @@ def test_Dimshuffle_lift_restrictions():
z
=
x
-
y
fg
=
FunctionGraph
([
rng
],
[
z
,
y
],
clone
=
False
)
_
=
EquilibriumOptimizer
([
local_dimshuffle_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
_
=
EquilibriumGraphRewriter
([
local_dimshuffle_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
dimshuffle_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
assert
dimshuffle_node
==
y
.
owner
...
...
@@ -569,7 +571,7 @@ def test_Dimshuffle_lift_restrictions():
# We add `x` as an output to make sure that `is_rv_used_in_graph` handles
# `"output"` "nodes" correctly.
fg
=
FunctionGraph
([
rng
],
[
z
,
x
],
clone
=
False
)
Equilibrium
Optimiz
er
([
local_dimshuffle_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
Equilibrium
GraphRewrit
er
([
local_dimshuffle_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
assert
fg
.
outputs
[
0
]
==
z
assert
fg
.
outputs
[
1
]
==
x
...
...
@@ -577,7 +579,7 @@ def test_Dimshuffle_lift_restrictions():
# The non-`Dimshuffle` client doesn't depend on the RNG state, so we can
# perform the lift
fg
=
FunctionGraph
([
rng
],
[
z
],
clone
=
False
)
Equilibrium
Optimiz
er
([
local_dimshuffle_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
Equilibrium
GraphRewrit
er
([
local_dimshuffle_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
rv_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
assert
rv_node
.
op
==
normal
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论