Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ac213377
提交
ac213377
authored
7月 15, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
8月 17, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rename EquilibriumOptimizer to EquilibriumGraphRewriter
上级
e6635af8
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
12 个修改的文件
包含
101 行增加
和
89 行删除
+101
-89
mode.py
aesara/compile/mode.py
+3
-3
configdefaults.py
aesara/configdefaults.py
+1
-1
opt.py
aesara/graph/opt.py
+0
-0
optdb.py
aesara/graph/optdb.py
+50
-35
opt.py
aesara/scan/opt.py
+3
-8
basic_opt.py
aesara/tensor/basic_opt.py
+5
-5
blas.py
aesara/tensor/blas.py
+2
-2
graph_rewriting.rst
doc/extending/graph_rewriting.rst
+17
-17
test_kanren.py
tests/graph/test_kanren.py
+2
-2
test_opt.py
tests/graph/test_opt.py
+6
-6
test_optdb.py
tests/graph/test_optdb.py
+2
-2
test_opt.py
tests/tensor/random/test_opt.py
+10
-8
没有找到文件。
aesara/compile/mode.py
浏览文件 @
ac213377
...
...
@@ -212,10 +212,10 @@ optdb.register(
"canonicalize_db"
,
position
=
1
,
)
# Register in the canonizer Equilibrium as a clean
up opt the merge opt
.
# Register in the canonizer Equilibrium as a clean
-up rewrite the merge rewrite
.
# Without this, as the equilibrium have ignore_newtrees=False, we
# won't merge all nodes if it is set as a global
optimiz
er with
# final_
opt
=True.
# won't merge all nodes if it is set as a global
rewrit
er with
# final_
rewriter
=True.
# We need a new instance of MergeOptimizer to don't have its name
# changed by other usage of it.
...
...
aesara/configdefaults.py
浏览文件 @
ac213377
...
...
@@ -1107,7 +1107,7 @@ def add_optimizer_configvars():
config
.
add
(
"optdb__max_use_ratio"
,
"A ratio that prevent infinite loop in Equilibrium
Optimiz
er."
,
"A ratio that prevent infinite loop in Equilibrium
GraphRewrit
er."
,
FloatParam
(
8
),
in_c_key
=
False
,
)
...
...
aesara/graph/opt.py
浏览文件 @
ac213377
差异被折叠。
点击展开。
aesara/graph/optdb.py
浏览文件 @
ac213377
...
...
@@ -31,19 +31,18 @@ class OptimizationDatabase:
def
register
(
self
,
name
:
str
,
optimiz
er
:
Union
[
"OptimizationDatabase"
,
OptimizersType
],
rewrit
er
:
Union
[
"OptimizationDatabase"
,
OptimizersType
],
*
tags
:
str
,
use_db_name_as_tag
=
True
,
**
kwargs
,
):
"""Register a new
optimiz
er to the database.
"""Register a new
rewrit
er to the database.
Parameters
----------
name:
Name of the
optimiz
er.
opt
:
The
optimiz
er to register.
Name of the
rewrit
er.
rewriter
:
The
rewrit
er to register.
tags:
Tag name that allow to select the optimizer.
use_db_name_as_tag:
...
...
@@ -58,14 +57,14 @@ class OptimizationDatabase:
"""
if
not
isinstance
(
optimiz
er
,
rewrit
er
,
(
OptimizationDatabase
,
aesara_opt
.
GraphRewriter
,
aesara_opt
.
NodeRewriter
,
),
):
raise
TypeError
(
f
"{
optimiz
er} is not a valid optimizer type."
)
raise
TypeError
(
f
"{
rewrit
er} is not a valid optimizer type."
)
if
name
in
self
.
__db__
:
raise
ValueError
(
f
"The tag '{name}' is already present in the database."
)
...
...
@@ -74,18 +73,18 @@ class OptimizationDatabase:
if
self
.
name
is
not
None
:
tags
=
tags
+
(
self
.
name
,)
optimiz
er
.
name
=
name
rewrit
er
.
name
=
name
# This restriction is there because in many place we suppose that
# something in the OptimizationDatabase is there only once.
if
optimiz
er
.
name
in
self
.
__db__
:
if
rewrit
er
.
name
in
self
.
__db__
:
raise
ValueError
(
f
"Tried to register {
optimiz
er.name} again under the new name {name}. "
f
"Tried to register {
rewrit
er.name} again under the new name {name}. "
"The same optimization cannot be registered multiple times in"
" an ``OptimizationDatabase``; use ProxyDB instead."
)
self
.
__db__
[
name
]
=
OrderedSet
([
optimiz
er
])
self
.
__db__
[
name
]
=
OrderedSet
([
rewrit
er
])
self
.
_names
.
add
(
name
)
self
.
__db__
[
optimizer
.
__class__
.
__name__
]
.
add
(
optimiz
er
)
self
.
__db__
[
rewriter
.
__class__
.
__name__
]
.
add
(
rewrit
er
)
self
.
add_tags
(
name
,
*
tags
)
def
add_tags
(
self
,
name
,
*
tags
):
...
...
@@ -292,11 +291,11 @@ class OptimizationQuery:
class
EquilibriumDB
(
OptimizationDatabase
):
"""A database of rewrites that should be applied until equilibrium is reached.
Canonicalize, Stabilize, and Specialize are all equilibrium
optimization
s.
Canonicalize, Stabilize, and Specialize are all equilibrium
rewriter
s.
Notes
-----
We can use `NodeRewriter` and `GraphRewriter` since `Equilibrium
Optimiz
er`
We can use `NodeRewriter` and `GraphRewriter` since `Equilibrium
GraphRewrit
er`
supports both.
It is probably not a good idea to have both ``ignore_newtrees == False``
...
...
@@ -322,33 +321,47 @@ class EquilibriumDB(OptimizationDatabase):
super
()
.
__init__
()
self
.
ignore_newtrees
=
ignore_newtrees
self
.
tracks_on_change_inputs
=
tracks_on_change_inputs
self
.
__final__
:
Dict
[
str
,
aesara_opt
.
Rewriter
]
=
{}
self
.
__cleanup__
:
Dict
[
str
,
aesara_opt
.
Rewriter
]
=
{}
self
.
__final__
:
Dict
[
str
,
bool
]
=
{}
self
.
__cleanup__
:
Dict
[
str
,
bool
]
=
{}
def
register
(
self
,
name
,
obj
,
*
tags
,
final_opt
=
False
,
cleanup
=
False
,
**
kwargs
):
if
final_opt
and
cleanup
:
raise
ValueError
(
"`final_opt` and `cleanup` cannot both be true."
)
super
()
.
register
(
name
,
obj
,
*
tags
,
**
kwargs
)
self
.
__final__
[
name
]
=
final_opt
def
register
(
self
,
name
:
str
,
rewriter
:
Union
[
"OptimizationDatabase"
,
OptimizersType
],
*
tags
:
str
,
final_rewriter
:
bool
=
False
,
cleanup
:
bool
=
False
,
**
kwargs
,
):
if
final_rewriter
and
cleanup
:
raise
ValueError
(
"`final_rewriter` and `cleanup` cannot both be true."
)
super
()
.
register
(
name
,
rewriter
,
*
tags
,
**
kwargs
)
self
.
__final__
[
name
]
=
final_rewriter
self
.
__cleanup__
[
name
]
=
cleanup
def
query
(
self
,
*
tags
,
**
kwtags
):
_opts
=
super
()
.
query
(
*
tags
,
**
kwtags
)
final_opts
=
[
o
for
o
in
_opts
if
self
.
__final__
.
get
(
o
.
name
,
False
)]
cleanup_opts
=
[
o
for
o
in
_opts
if
self
.
__cleanup__
.
get
(
o
.
name
,
False
)]
opts
=
[
o
for
o
in
_opts
if
o
not
in
final_opts
and
o
not
in
cleanup_opts
]
if
len
(
final_opts
)
==
0
:
final_opts
=
None
if
len
(
cleanup_opts
)
==
0
:
cleanup_opts
=
None
return
aesara_opt
.
EquilibriumOptimizer
(
opts
,
_rewriters
=
super
()
.
query
(
*
tags
,
**
kwtags
)
final_rewriters
=
[
o
for
o
in
_rewriters
if
self
.
__final__
.
get
(
o
.
name
,
False
)]
cleanup_rewriters
=
[
o
for
o
in
_rewriters
if
self
.
__cleanup__
.
get
(
o
.
name
,
False
)
]
rewriters
=
[
o
for
o
in
_rewriters
if
o
not
in
final_rewriters
and
o
not
in
cleanup_rewriters
]
if
len
(
final_rewriters
)
==
0
:
final_rewriters
=
None
if
len
(
cleanup_rewriters
)
==
0
:
cleanup_rewriters
=
None
return
aesara_opt
.
EquilibriumGraphRewriter
(
rewriters
,
max_use_ratio
=
config
.
optdb__max_use_ratio
,
ignore_newtrees
=
self
.
ignore_newtrees
,
tracks_on_change_inputs
=
self
.
tracks_on_change_inputs
,
failure_callback
=
aesara_opt
.
NodeProcessingGraphRewriter
.
warn_inplace
,
final_
optimizers
=
final_opt
s
,
cleanup_
optimizers
=
cleanup_opt
s
,
final_
rewriters
=
final_rewriter
s
,
cleanup_
rewriters
=
cleanup_rewriter
s
,
)
...
...
@@ -372,8 +385,10 @@ class SequenceDB(OptimizationDatabase):
self
.
failure_callback
=
failure_callback
def
register
(
self
,
name
,
obj
,
*
tags
,
**
kwargs
):
super
()
.
register
(
name
,
obj
,
*
tags
,
**
kwargs
)
position
=
kwargs
.
pop
(
"position"
,
"last"
)
super
()
.
register
(
name
,
obj
,
*
tags
,
**
kwargs
)
if
position
==
"last"
:
if
len
(
self
.
__position__
)
==
0
:
self
.
__position__
[
name
]
=
0
...
...
aesara/scan/opt.py
浏览文件 @
ac213377
...
...
@@ -2373,7 +2373,7 @@ optdb.register(
position
=
75
,
)
scan_eqopt1
.
register
(
"all_pushout_opt"
,
scan_seqopt1
,
"fast_run"
,
"scan"
,
position
=
1
)
scan_eqopt1
.
register
(
"all_pushout_opt"
,
scan_seqopt1
,
"fast_run"
,
"scan"
)
scan_seqopt1
.
register
(
...
...
@@ -2419,7 +2419,7 @@ scan_seqopt1.register(
scan_seqopt1
.
register
(
"scan_pushout_add"
,
# TODO: Perhaps this should be an `Equilibrium
Optimiz
er`?
# TODO: Perhaps this should be an `Equilibrium
GraphRewrit
er`?
in2out
(
push_out_add_scan
,
ignore_newtrees
=
False
),
"fast_run"
,
"more_mem"
,
...
...
@@ -2434,7 +2434,6 @@ scan_eqopt2.register(
in2out
(
basic_opt
.
constant_folding
,
ignore_newtrees
=
True
),
"fast_run"
,
"scan"
,
position
=
1
,
)
...
...
@@ -2444,14 +2443,13 @@ scan_eqopt2.register(
"remove_constants_and_unused_inputs_scan"
,
"fast_run"
,
"scan"
,
position
=
2
,
)
# after const merge but before stabilize so that we can have identity
# for equivalent nodes but we still have the chance to hoist stuff out
# of the scan later.
scan_eqopt2
.
register
(
"scan_merge"
,
ScanMerge
(),
"fast_run"
,
"scan"
,
position
=
4
)
scan_eqopt2
.
register
(
"scan_merge"
,
ScanMerge
(),
"fast_run"
,
"scan"
)
# After Merge optimization
scan_eqopt2
.
register
(
...
...
@@ -2460,7 +2458,6 @@ scan_eqopt2.register(
"remove_constants_and_unused_inputs_scan"
,
"fast_run"
,
"scan"
,
position
=
5
,
)
scan_eqopt2
.
register
(
...
...
@@ -2468,7 +2465,6 @@ scan_eqopt2.register(
in2out
(
scan_merge_inouts
,
ignore_newtrees
=
True
),
"fast_run"
,
"scan"
,
position
=
6
,
)
# After everything else
...
...
@@ -2478,5 +2474,4 @@ scan_eqopt2.register(
"remove_constants_and_unused_inputs_scan"
,
"fast_run"
,
"scan"
,
position
=
8
,
)
aesara/tensor/basic_opt.py
浏览文件 @
ac213377
...
...
@@ -2802,10 +2802,10 @@ def constant_folding(fgraph, node):
topo_constant_folding
=
in2out
(
constant_folding
,
ignore_newtrees
=
True
,
name
=
"topo_constant_folding"
)
register_canonicalize
(
topo_constant_folding
,
"fast_compile"
,
final_
opt
=
True
)
register_uncanonicalize
(
topo_constant_folding
,
"fast_compile"
,
final_
opt
=
True
)
register_stabilize
(
topo_constant_folding
,
"fast_compile"
,
final_
opt
=
True
)
register_specialize
(
topo_constant_folding
,
"fast_compile"
,
final_
opt
=
True
)
register_canonicalize
(
topo_constant_folding
,
"fast_compile"
,
final_
rewriter
=
True
)
register_uncanonicalize
(
topo_constant_folding
,
"fast_compile"
,
final_
rewriter
=
True
)
register_stabilize
(
topo_constant_folding
,
"fast_compile"
,
final_
rewriter
=
True
)
register_specialize
(
topo_constant_folding
,
"fast_compile"
,
final_
rewriter
=
True
)
def
local_elemwise_fusion_op
(
op_class
,
max_input_fct
=
lambda
node
:
32
,
maker
=
None
):
...
...
@@ -3096,7 +3096,7 @@ local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fc
class
FusionOptimizer
(
GraphRewriter
):
"""Graph rewriter that simply runs node fusion operations.
TODO: This is basically an `Equilibrium
Optimiz
er`; we should just use that.
TODO: This is basically an `Equilibrium
GraphRewrit
er`; we should just use that.
"""
...
...
aesara/tensor/blas.py
浏览文件 @
ac213377
...
...
@@ -146,7 +146,7 @@ from aesara.graph.basic import Apply, view_roots
from
aesara.graph.features
import
ReplacementDidNotRemoveError
,
ReplaceValidate
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
(
Equilibrium
Optimiz
er
,
Equilibrium
GraphRewrit
er
,
GraphRewriter
,
copy_stack_trace
,
in2out
,
...
...
@@ -1906,7 +1906,7 @@ blas_optdb.register(
blas_optdb
.
register
(
"gemm_optimizer"
,
GemmOptimizer
(),
"fast_run"
,
position
=
10
)
blas_optdb
.
register
(
"local_gemm_to_gemv"
,
Equilibrium
Optimiz
er
(
Equilibrium
GraphRewrit
er
(
[
local_gemm_to_gemv
,
local_gemm_to_ger
,
...
...
doc/extending/graph_rewriting.rst
浏览文件 @
ac213377
...
...
@@ -444,7 +444,7 @@ The following is an example that distributes dot products across additions.
import aesara
import aesara.tensor as at
from aesara.graph.kanren import KanrenRelationSub
from aesara.graph.opt import Equilibrium
Optimiz
er
from aesara.graph.opt import Equilibrium
GraphRewrit
er
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.math import _dot
from etuples import etuple
...
...
@@ -484,7 +484,7 @@ The following is an example that distributes dot products across additions.
)
dot_distribute_opt = Equilibrium
Optimiz
er([KanrenRelationSub(dot_distributeo)], max_use_ratio=10)
dot_distribute_opt = Equilibrium
GraphRewrit
er([KanrenRelationSub(dot_distributeo)], max_use_ratio=10)
Below, we apply `dot_distribute_opt` to a few example graphs. First we create simple test graph:
...
...
@@ -531,7 +531,7 @@ relational properties.
To do that, we will create another :class:`Rewriter` that simply reverses the arguments
to the relation :func:`dot_distributeo` and apply it to the distributed result in ``res``:
>>> dot_gather_opt = Equilibrium
Optimiz
er([KanrenRelationSub(lambda x, y: dot_distributeo(y, x))], max_use_ratio=10)
>>> dot_gather_opt = Equilibrium
GraphRewrit
er([KanrenRelationSub(lambda x, y: dot_distributeo(y, x))], max_use_ratio=10)
>>> rev_res = optimize_graph(res, include=[], custom_opt=dot_gather_opt, clone=False)
>>> print(aesara.pprint(rev_res))
(A @ (x + (y + (B @ (z + w)))))
...
...
@@ -561,7 +561,7 @@ serve as a basis for filtering.
The point of :obj:`optdb` is that you might want to apply many optimizations
to a computation graph in many unique patterns. For example, you might
want to do optimization X, then optimization Y, then optimization Z. And then
maybe optimization Y is an :class:`Equilibrium
Optimiz
er` containing :class:`NodeRewriter`\s A, B
maybe optimization Y is an :class:`Equilibrium
GraphRewrit
er` containing :class:`NodeRewriter`\s A, B
and C which are applied on every node of the graph until they all fail to change
it. If some optimizations act up, we want an easy way to turn them off. Ditto if
some optimizations are very CPU-intensive and we don't want to take the time to
...
...
@@ -599,7 +599,7 @@ optimizers they return will be put in their places.
An :class:`EquilibriumDB` contains :class:`NodeRewriter` or :class:`OptimizationDatabase` objects. Each of them
has a name and an arbitrary number of tags. When a :class:`OptimizationQuery` is applied to
an :class:`EquilibriumDB`, all :class:`NodeRewriter`\s that match the query are
inserted into an :class:`Equilibrium
Optimiz
er`, which is returned. If the
inserted into an :class:`Equilibrium
GraphRewrit
er`, which is returned. If the
:class:`SequenceDB` contains :class:`OptimizationDatabase` instances, the
:class:`OptimizationQuery` will be passed to them as well and the
:class:`NodeRewriter`\s they return will be put in their places
...
...
@@ -859,8 +859,8 @@ This will output something like this:
0.028s for fgraph.validate()
0.131s for callback
time - (name, class, index) - validate time
0.751816s - ('canonicalize', 'Equilibrium
Optimiz
er', 4) - 0.004s
Equilibrium
Optimiz
er canonicalize
0.751816s - ('canonicalize', 'Equilibrium
GraphRewrit
er', 4) - 0.004s
Equilibrium
GraphRewrit
er canonicalize
time 0.751s for 14 passes
nb nodes (start, end, max) 108 81 117
time io_toposort 0.029s
...
...
@@ -974,8 +974,8 @@ This will output something like this:
init io_toposort 0.00171804428101
loop time 0.000502109527588
callback_time 0.0
0.002257s - ('local_gemm_to_gemv', 'Equilibrium
Optimiz
er', 3) - 0.000s
Equilibrium
Optimiz
er local_gemm_to_gemv
0.002257s - ('local_gemm_to_gemv', 'Equilibrium
GraphRewrit
er', 3) - 0.000s
Equilibrium
GraphRewrit
er local_gemm_to_gemv
time 0.002s for 1 passes
nb nodes (start, end, max) 80 80 80
time io_toposort 0.001s
...
...
@@ -994,8 +994,8 @@ This will output something like this:
init io_toposort 0.00138401985168
loop time 0.000202178955078
callback_time 0.0
0.031740s - ('specialize', 'Equilibrium
Optimiz
er', 9) - 0.000s
Equilibrium
Optimiz
er specialize
0.031740s - ('specialize', 'Equilibrium
GraphRewrit
er', 9) - 0.000s
Equilibrium
GraphRewrit
er specialize
time 0.031s for 2 passes
nb nodes (start, end, max) 80 78 80
time io_toposort 0.003s
...
...
@@ -1080,8 +1080,8 @@ To understand this profile here is some explanation of how optimizations work:
.. code-block:: none
0.751816s - ('canonicalize', 'Equilibrium
Optimiz
er', 4) - 0.004s
Equilibrium
Optimiz
er canonicalize
0.751816s - ('canonicalize', 'Equilibrium
GraphRewrit
er', 4) - 0.004s
Equilibrium
GraphRewrit
er canonicalize
time 0.751s for 14 passes
nb nodes (start, end, max) 108 81 117
time io_toposort 0.029s
...
...
@@ -1146,15 +1146,15 @@ To understand this profile here is some explanation of how optimizations work:
0.000s - local_subtensor_of_dot
0.000s - local_subtensor_merge
* ``0.751816s - ('canonicalize', 'Equilibrium
Optimiz
er', 4) - 0.004s``
* ``0.751816s - ('canonicalize', 'Equilibrium
GraphRewrit
er', 4) - 0.004s``
This line is from :class:`SequentialGraphRewriter`, and indicates information related
to a sub-optimizer. It means that this sub-optimizer took
a total of .7s. Its name is ``'canonicalize'``. It is an
:class:`Equilibrium
Optimiz
er`. It was executed at index 4 by the
:class:`Equilibrium
GraphRewrit
er`. It was executed at index 4 by the
:class:`SequentialGraphRewriter`. It spent 0.004s in the *validate* phase.
* All other lines are from the profiler of the :class:`Equilibrium
Optimiz
er`.
* All other lines are from the profiler of the :class:`Equilibrium
GraphRewrit
er`.
* An :class:`Equilibrium
Optimiz
er` does multiple passes on the Apply nodes from
* An :class:`Equilibrium
GraphRewrit
er` does multiple passes on the Apply nodes from
the graph, trying to apply local and global optimizations.
Conceptually, it tries to execute all global optimizations,
and to apply all local optimizations on all
...
...
tests/graph/test_kanren.py
浏览文件 @
ac213377
...
...
@@ -13,7 +13,7 @@ from aesara.graph.basic import Apply
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.kanren
import
KanrenRelationSub
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
Equilibrium
Optimiz
er
from
aesara.graph.opt
import
Equilibrium
GraphRewrit
er
from
aesara.graph.opt_utils
import
optimize_graph
from
aesara.graph.unify
import
eval_if_etuple
from
aesara.tensor.math
import
Dot
,
_dot
...
...
@@ -151,7 +151,7 @@ def test_KanrenRelationSub_dot():
),
)
distribute_opt
=
Equilibrium
Optimiz
er
(
distribute_opt
=
Equilibrium
GraphRewrit
er
(
[
KanrenRelationSub
(
distributes
)],
max_use_ratio
=
10
)
...
...
tests/graph/test_opt.py
浏览文件 @
ac213377
...
...
@@ -6,7 +6,7 @@ from aesara.graph.features import Feature
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
(
Equilibrium
Optimiz
er
,
Equilibrium
GraphRewrit
er
,
MergeOptimizer
,
OpKeyGraphRewriter
,
OpToRewriterTracker
,
...
...
@@ -446,7 +446,7 @@ class TestEquilibrium:
e
=
op3
(
op4
(
x
,
y
))
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
# print g
opt
=
Equilibrium
Optimiz
er
(
opt
=
Equilibrium
GraphRewrit
er
(
[
PatternNodeRewriter
((
op1
,
"x"
,
"y"
),
(
op2
,
"x"
,
"y"
)),
PatternNodeRewriter
((
op4
,
"x"
,
"y"
),
(
op1
,
"x"
,
"y"
)),
...
...
@@ -463,7 +463,7 @@ class TestEquilibrium:
e
=
op1
(
op1
(
op3
(
x
,
y
)))
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
# print g
opt
=
Equilibrium
Optimiz
er
(
opt
=
Equilibrium
GraphRewrit
er
(
[
PatternNodeRewriter
((
op1
,
(
op2
,
"x"
,
"y"
)),
(
op4
,
"x"
,
"y"
)),
PatternNodeRewriter
((
op3
,
"x"
,
"y"
),
(
op4
,
"x"
,
"y"
)),
...
...
@@ -488,7 +488,7 @@ class TestEquilibrium:
oldlevel
=
_logger
.
level
_logger
.
setLevel
(
logging
.
CRITICAL
)
try
:
opt
=
Equilibrium
Optimiz
er
(
opt
=
Equilibrium
GraphRewrit
er
(
[
PatternNodeRewriter
((
op1
,
"x"
,
"y"
),
(
op2
,
"x"
,
"y"
)),
PatternNodeRewriter
((
op4
,
"x"
,
"y"
),
(
op1
,
"x"
,
"y"
)),
...
...
@@ -600,7 +600,7 @@ def test_patternsub_values_eq_approx(out_pattern, tracks):
e
=
op1
(
x
)
fg
=
FunctionGraph
([
x
],
[
e
],
clone
=
False
)
opt
=
Equilibrium
Optimiz
er
(
opt
=
Equilibrium
GraphRewrit
er
(
[
PatternNodeRewriter
(
(
op1
,
"x"
),
...
...
@@ -633,7 +633,7 @@ def test_patternsub_invalid_dtype(out_pattern):
e
=
op_cast_type2
(
x
)
fg
=
FunctionGraph
([
x
],
[
e
])
opt
=
Equilibrium
Optimiz
er
(
opt
=
Equilibrium
GraphRewrit
er
(
[
PatternNodeRewriter
(
(
op_cast_type2
,
"x"
),
...
...
tests/graph/test_optdb.py
浏览文件 @
ac213377
...
...
@@ -45,8 +45,8 @@ class TestDB:
def
test_EquilibriumDB
(
self
):
eq_db
=
EquilibriumDB
()
with
pytest
.
raises
(
ValueError
,
match
=
r"`final_
opt
` and.*"
):
eq_db
.
register
(
"d"
,
TestOpt
(),
final_
opt
=
True
,
cleanup
=
True
)
with
pytest
.
raises
(
ValueError
,
match
=
r"`final_
rewriter
` and.*"
):
eq_db
.
register
(
"d"
,
TestOpt
(),
final_
rewriter
=
True
,
cleanup
=
True
)
def
test_SequenceDB
(
self
):
seq_db
=
SequenceDB
(
failure_callback
=
None
)
...
...
tests/tensor/random/test_opt.py
浏览文件 @
ac213377
...
...
@@ -7,7 +7,7 @@ from aesara.compile.function import function
from
aesara.compile.mode
import
Mode
from
aesara.graph.basic
import
Constant
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.opt
import
Equilibrium
Optimiz
er
from
aesara.graph.opt
import
Equilibrium
GraphRewrit
er
from
aesara.graph.optdb
import
OptimizationQuery
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.random.basic
import
(
...
...
@@ -50,7 +50,7 @@ def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng, name=None
p
for
p
in
dist_params_at
+
size_at
if
not
isinstance
(
p
,
(
slice
,
Constant
))
]
mode
=
Mode
(
"py"
,
Equilibrium
Optimiz
er
([
opt
],
max_use_ratio
=
100
))
mode
=
Mode
(
"py"
,
Equilibrium
GraphRewrit
er
([
opt
],
max_use_ratio
=
100
))
f_opt
=
function
(
f_inputs
,
...
...
@@ -519,7 +519,7 @@ def test_Subtensor_lift_restrictions():
z
=
x
-
y
fg
=
FunctionGraph
([
rng
],
[
z
],
clone
=
False
)
_
=
Equilibrium
Optimiz
er
([
local_subtensor_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
_
=
Equilibrium
GraphRewrit
er
([
local_subtensor_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
subtensor_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
0
]
.
owner
assert
subtensor_node
==
y
.
owner
...
...
@@ -531,7 +531,7 @@ def test_Subtensor_lift_restrictions():
# We add `x` as an output to make sure that `is_rv_used_in_graph` handles
# `"output"` "nodes" correctly.
fg
=
FunctionGraph
([
rng
],
[
z
,
x
],
clone
=
False
)
Equilibrium
Optimiz
er
([
local_subtensor_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
Equilibrium
GraphRewrit
er
([
local_subtensor_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
assert
fg
.
outputs
[
0
]
==
z
assert
fg
.
outputs
[
1
]
==
x
...
...
@@ -539,7 +539,7 @@ def test_Subtensor_lift_restrictions():
# The non-`Subtensor` client doesn't depend on the RNG state, so we can
# perform the lift
fg
=
FunctionGraph
([
rng
],
[
z
],
clone
=
False
)
Equilibrium
Optimiz
er
([
local_subtensor_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
Equilibrium
GraphRewrit
er
([
local_subtensor_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
rv_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
0
]
.
owner
assert
rv_node
.
op
==
normal
...
...
@@ -557,7 +557,9 @@ def test_Dimshuffle_lift_restrictions():
z
=
x
-
y
fg
=
FunctionGraph
([
rng
],
[
z
,
y
],
clone
=
False
)
_
=
EquilibriumOptimizer
([
local_dimshuffle_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
_
=
EquilibriumGraphRewriter
([
local_dimshuffle_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
dimshuffle_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
assert
dimshuffle_node
==
y
.
owner
...
...
@@ -569,7 +571,7 @@ def test_Dimshuffle_lift_restrictions():
# We add `x` as an output to make sure that `is_rv_used_in_graph` handles
# `"output"` "nodes" correctly.
fg
=
FunctionGraph
([
rng
],
[
z
,
x
],
clone
=
False
)
Equilibrium
Optimiz
er
([
local_dimshuffle_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
Equilibrium
GraphRewrit
er
([
local_dimshuffle_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
assert
fg
.
outputs
[
0
]
==
z
assert
fg
.
outputs
[
1
]
==
x
...
...
@@ -577,7 +579,7 @@ def test_Dimshuffle_lift_restrictions():
# The non-`Dimshuffle` client doesn't depend on the RNG state, so we can
# perform the lift
fg
=
FunctionGraph
([
rng
],
[
z
],
clone
=
False
)
Equilibrium
Optimiz
er
([
local_dimshuffle_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
Equilibrium
GraphRewrit
er
([
local_dimshuffle_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
rv_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
assert
rv_node
.
op
==
normal
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论