Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
c2672deb
提交
c2672deb
authored
5月 23, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
5月 23, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rename Query to OptimizationQuery and DB to OptimizationDatabase,
上级
1c11dd44
隐藏空白字符变更
内嵌
并排
正在显示
13 个修改的文件
包含
146 行增加
和
110 行删除
+146
-110
mode.py
aesara/compile/mode.py
+24
-12
optdb.py
aesara/gpuarray/optdb.py
+8
-3
opt_utils.py
aesara/graph/opt_utils.py
+3
-3
optdb.py
aesara/graph/optdb.py
+48
-30
optimization.txt
doc/extending/optimization.txt
+31
-32
test_optdb.py
tests/graph/test_optdb.py
+2
-2
test_jax.py
tests/link/test_jax.py
+3
-3
test_numba.py
tests/link/test_numba.py
+2
-2
test_opt.py
tests/tensor/random/test_opt.py
+5
-3
test_utils.py
tests/tensor/random/test_utils.py
+2
-2
test_basic_opt.py
tests/tensor/test_basic_opt.py
+6
-6
test_extra_ops.py
tests/tensor/test_extra_ops.py
+2
-2
test_math_opt.py
tests/tensor/test_math_opt.py
+10
-10
没有找到文件。
aesara/compile/mode.py
浏览文件 @
c2672deb
...
@@ -17,7 +17,13 @@ from aesara.graph.opt import (
...
@@ -17,7 +17,13 @@ from aesara.graph.opt import (
MergeOptimizer
,
MergeOptimizer
,
NavigatorOptimizer
,
NavigatorOptimizer
,
)
)
from
aesara.graph.optdb
import
EquilibriumDB
,
LocalGroupDB
,
Query
,
SequenceDB
,
TopoDB
from
aesara.graph.optdb
import
(
EquilibriumDB
,
LocalGroupDB
,
OptimizationQuery
,
SequenceDB
,
TopoDB
,
)
from
aesara.link.basic
import
PerformLinker
from
aesara.link.basic
import
PerformLinker
from
aesara.link.c.basic
import
CLinker
,
OpWiseCLinker
from
aesara.link.c.basic
import
CLinker
,
OpWiseCLinker
from
aesara.link.jax.linker
import
JAXLinker
from
aesara.link.jax.linker
import
JAXLinker
...
@@ -58,19 +64,21 @@ def register_linker(name, linker):
...
@@ -58,19 +64,21 @@ def register_linker(name, linker):
exclude
=
[]
exclude
=
[]
if
not
config
.
cxx
:
if
not
config
.
cxx
:
exclude
=
[
"cxx_only"
]
exclude
=
[
"cxx_only"
]
OPT_NONE
=
Query
(
include
=
[],
exclude
=
exclude
)
OPT_NONE
=
Optimization
Query
(
include
=
[],
exclude
=
exclude
)
# Even if multiple merge optimizer call will be there, this shouldn't
# Even if multiple merge optimizer call will be there, this shouldn't
# impact performance.
# impact performance.
OPT_MERGE
=
Query
(
include
=
[
"merge"
],
exclude
=
exclude
)
OPT_MERGE
=
Optimization
Query
(
include
=
[
"merge"
],
exclude
=
exclude
)
OPT_FAST_RUN
=
Query
(
include
=
[
"fast_run"
],
exclude
=
exclude
)
OPT_FAST_RUN
=
Optimization
Query
(
include
=
[
"fast_run"
],
exclude
=
exclude
)
OPT_FAST_RUN_STABLE
=
OPT_FAST_RUN
.
requiring
(
"stable"
)
OPT_FAST_RUN_STABLE
=
OPT_FAST_RUN
.
requiring
(
"stable"
)
# We need fast_compile_gpu here. As on the GPU, we don't have all
# We need fast_compile_gpu here. As on the GPU, we don't have all
# operation that exist in fast_compile, but have some that get
# operation that exist in fast_compile, but have some that get
# introduced in fast_run, we want those optimization to also run in
# introduced in fast_run, we want those optimization to also run in
# fast_compile+gpu. We can't tag them just as 'gpu', as this would
# fast_compile+gpu. We can't tag them just as 'gpu', as this would
# exclude them if we exclude 'gpu'.
# exclude them if we exclude 'gpu'.
OPT_FAST_COMPILE
=
Query
(
include
=
[
"fast_compile"
,
"fast_compile_gpu"
],
exclude
=
exclude
)
OPT_FAST_COMPILE
=
OptimizationQuery
(
OPT_STABILIZE
=
Query
(
include
=
[
"fast_run"
],
exclude
=
exclude
)
include
=
[
"fast_compile"
,
"fast_compile_gpu"
],
exclude
=
exclude
)
OPT_STABILIZE
=
OptimizationQuery
(
include
=
[
"fast_run"
],
exclude
=
exclude
)
OPT_STABILIZE
.
position_cutoff
=
1.5000001
OPT_STABILIZE
.
position_cutoff
=
1.5000001
OPT_NONE
.
name
=
"OPT_NONE"
OPT_NONE
.
name
=
"OPT_NONE"
OPT_MERGE
.
name
=
"OPT_MERGE"
OPT_MERGE
.
name
=
"OPT_MERGE"
...
@@ -297,7 +305,7 @@ class Mode:
...
@@ -297,7 +305,7 @@ class Mode:
# self.provided_optimizer - typically the `optimizer` arg.
# self.provided_optimizer - typically the `optimizer` arg.
# But if the `optimizer` arg is keyword corresponding to a predefined
# But if the `optimizer` arg is keyword corresponding to a predefined
# Query, then this stores the query
#
Optimization
Query, then this stores the query
# self._optimizer - typically same as provided_optimizer??
# self._optimizer - typically same as provided_optimizer??
# self.__get_optimizer - returns self._optimizer (possibly querying
# self.__get_optimizer - returns self._optimizer (possibly querying
...
@@ -316,7 +324,7 @@ class Mode:
...
@@ -316,7 +324,7 @@ class Mode:
self
.
linker
=
linker
self
.
linker
=
linker
if
isinstance
(
optimizer
,
str
)
or
optimizer
is
None
:
if
isinstance
(
optimizer
,
str
)
or
optimizer
is
None
:
optimizer
=
predefined_optimizers
[
optimizer
]
optimizer
=
predefined_optimizers
[
optimizer
]
if
isinstance
(
optimizer
,
Query
):
if
isinstance
(
optimizer
,
Optimization
Query
):
self
.
provided_optimizer
=
optimizer
self
.
provided_optimizer
=
optimizer
self
.
_optimizer
=
optimizer
self
.
_optimizer
=
optimizer
self
.
call_time
=
0
self
.
call_time
=
0
...
@@ -330,7 +338,7 @@ class Mode:
...
@@ -330,7 +338,7 @@ class Mode:
)
)
def
__get_optimizer
(
self
):
def
__get_optimizer
(
self
):
if
isinstance
(
self
.
_optimizer
,
Query
):
if
isinstance
(
self
.
_optimizer
,
Optimization
Query
):
return
optdb
.
query
(
self
.
_optimizer
)
return
optdb
.
query
(
self
.
_optimizer
)
else
:
else
:
return
self
.
_optimizer
return
self
.
_optimizer
...
@@ -348,7 +356,7 @@ class Mode:
...
@@ -348,7 +356,7 @@ class Mode:
link
,
opt
=
self
.
get_linker_optimizer
(
link
,
opt
=
self
.
get_linker_optimizer
(
self
.
provided_linker
,
self
.
provided_optimizer
self
.
provided_linker
,
self
.
provided_optimizer
)
)
# N.B. opt might be a Query instance, not sure what else it might be...
# N.B. opt might be a
Optimization
Query instance, not sure what else it might be...
# string? Optimizer? OptDB? who knows???
# string? Optimizer? OptDB? who knows???
return
self
.
clone
(
optimizer
=
opt
.
including
(
*
tags
),
linker
=
link
)
return
self
.
clone
(
optimizer
=
opt
.
including
(
*
tags
),
linker
=
link
)
...
@@ -421,9 +429,13 @@ if config.cxx:
...
@@ -421,9 +429,13 @@ if config.cxx:
else
:
else
:
FAST_RUN
=
Mode
(
"vm"
,
"fast_run"
)
FAST_RUN
=
Mode
(
"vm"
,
"fast_run"
)
JAX
=
Mode
(
JAXLinker
(),
Query
(
include
=
[
"fast_run"
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
]))
JAX
=
Mode
(
JAXLinker
(),
OptimizationQuery
(
include
=
[
"fast_run"
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
]),
)
NUMBA
=
Mode
(
NUMBA
=
Mode
(
NumbaLinker
(),
Query
(
include
=
[
"fast_run"
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
NumbaLinker
(),
OptimizationQuery
(
include
=
[
"fast_run"
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
]),
)
)
...
...
aesara/gpuarray/optdb.py
浏览文件 @
c2672deb
from
aesara.compile
import
optdb
from
aesara.compile
import
optdb
from
aesara.graph.opt
import
GraphToGPULocalOptGroup
,
TopoOptimizer
,
local_optimizer
from
aesara.graph.opt
import
GraphToGPULocalOptGroup
,
TopoOptimizer
,
local_optimizer
from
aesara.graph.optdb
import
DB
,
EquilibriumDB
,
LocalGroupDB
,
SequenceDB
from
aesara.graph.optdb
import
(
EquilibriumDB
,
LocalGroupDB
,
OptimizationDatabase
,
SequenceDB
,
)
gpu_optimizer
=
EquilibriumDB
()
gpu_optimizer
=
EquilibriumDB
()
...
@@ -62,7 +67,7 @@ def register_opt2(tracks, *tags, **kwargs):
...
@@ -62,7 +67,7 @@ def register_opt2(tracks, *tags, **kwargs):
def
f
(
local_opt
):
def
f
(
local_opt
):
name
=
(
kwargs
and
kwargs
.
pop
(
"name"
))
or
local_opt
.
__name__
name
=
(
kwargs
and
kwargs
.
pop
(
"name"
))
or
local_opt
.
__name__
if
isinstance
(
local_opt
,
DB
):
if
isinstance
(
local_opt
,
OptimizationDatabase
):
opt
=
local_opt
opt
=
local_opt
else
:
else
:
opt
=
local_optimizer
(
tracks
)(
local_opt
)
opt
=
local_optimizer
(
tracks
)(
local_opt
)
...
@@ -97,7 +102,7 @@ abstractconv_groupopt.__name__ = "gpuarray_abstractconv_opts"
...
@@ -97,7 +102,7 @@ abstractconv_groupopt.__name__ = "gpuarray_abstractconv_opts"
register_opt
(
"fast_compile"
)(
abstractconv_groupopt
)
register_opt
(
"fast_compile"
)(
abstractconv_groupopt
)
class
GraphToGPUDB
(
DB
):
class
GraphToGPUDB
(
OptimizationDatabase
):
"""
"""
Retrieves the list local optimizers based on the optimizer flag's value
Retrieves the list local optimizers based on the optimizer flag's value
from EquilibriumOptimizer by calling the method query.
from EquilibriumOptimizer by calling the method query.
...
...
aesara/graph/opt_utils.py
浏览文件 @
c2672deb
...
@@ -4,7 +4,7 @@ from typing import Sequence, Union
...
@@ -4,7 +4,7 @@ from typing import Sequence, Union
import
aesara
import
aesara
from
aesara.graph.basic
import
Variable
,
equal_computations
,
graph_inputs
,
vars_between
from
aesara.graph.basic
import
Variable
,
equal_computations
,
graph_inputs
,
vars_between
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.optdb
import
Query
from
aesara.graph.optdb
import
Optimization
Query
def
optimize_graph
(
def
optimize_graph
(
...
@@ -28,7 +28,7 @@ def optimize_graph(
...
@@ -28,7 +28,7 @@ def optimize_graph(
clone:
clone:
Whether or not to clone the input graph before optimizing.
Whether or not to clone the input graph before optimizing.
**kwargs:
**kwargs:
Keyword arguments passed to the ``aesara.graph.optdb.Query`` object.
Keyword arguments passed to the ``aesara.graph.optdb.
Optimization
Query`` object.
"""
"""
from
aesara.compile
import
optdb
from
aesara.compile
import
optdb
...
@@ -37,7 +37,7 @@ def optimize_graph(
...
@@ -37,7 +37,7 @@ def optimize_graph(
fgraph
=
FunctionGraph
(
outputs
=
[
fgraph
],
clone
=
clone
)
fgraph
=
FunctionGraph
(
outputs
=
[
fgraph
],
clone
=
clone
)
return_only_out
=
True
return_only_out
=
True
canonicalize_opt
=
optdb
.
query
(
Query
(
include
=
include
,
**
kwargs
))
canonicalize_opt
=
optdb
.
query
(
Optimization
Query
(
include
=
include
,
**
kwargs
))
_
=
canonicalize_opt
.
optimize
(
fgraph
)
_
=
canonicalize_opt
.
optimize
(
fgraph
)
if
custom_opt
:
if
custom_opt
:
...
...
aesara/graph/optdb.py
浏览文件 @
c2672deb
...
@@ -10,7 +10,13 @@ from aesara.misc.ordered_set import OrderedSet
...
@@ -10,7 +10,13 @@ from aesara.misc.ordered_set import OrderedSet
from
aesara.utils
import
DefaultOrderedDict
from
aesara.utils
import
DefaultOrderedDict
class
DB
:
class
OptimizationDatabase
:
"""A class that represents a collection/database of optimizations.
These databases can be used to logically organize sets of
(i.e. ``GlobalOptimizer``s and ``LocalOptimizer``)
"""
def
__hash__
(
self
):
def
__hash__
(
self
):
if
not
hasattr
(
self
,
"_optimizer_idx"
):
if
not
hasattr
(
self
,
"_optimizer_idx"
):
self
.
_optimizer_idx
=
opt
.
_optimizer_idx
[
0
]
self
.
_optimizer_idx
=
opt
.
_optimizer_idx
[
0
]
...
@@ -24,7 +30,7 @@ class DB:
...
@@ -24,7 +30,7 @@ class DB:
# (via obj.name by the thing doing the registering)
# (via obj.name by the thing doing the registering)
def
register
(
self
,
name
,
obj
,
*
tags
,
**
kwargs
):
def
register
(
self
,
name
,
obj
,
*
tags
,
**
kwargs
):
"""
"""
Register a new optimizer to the database.
Parameters
Parameters
----------
----------
...
@@ -35,19 +41,21 @@ class DB:
...
@@ -35,19 +41,21 @@ class DB:
tags
tags
Tag name that allow to select the optimizer.
Tag name that allow to select the optimizer.
kwargs
kwargs
If non empty, should contain only
use_db_name_as_tag=False.
If non empty, should contain only
``use_db_name_as_tag=False``. By
By default, all optimizations registered in EquilibriumDB
default, all optimizations registered in ``EquilibriumDB`` are
are selected when the EquilibriumDB name is used as a
selected when the ``EquilibriumDB`` name is used as a tag. We do
tag. We do
not want this behavior for some optimizer like
not want this behavior for some optimizer like
local_remove_all_assert. use_db_name_as_tag=False remove
``local_remove_all_assert``. ``use_db_name_as_tag=False`` removes
that behavior. This mean only the optimizer name and the
that behavior. This mean only the optimizer name and the
tags
tags
specified will enable that optimization.
specified will enable that optimization.
"""
"""
# N.B. obj is not an instance of class `GlobalOptimizer`.
# N.B. obj is not an instance of class `GlobalOptimizer`.
# It is an instance of a DB.In the tests for example,
# It is an instance of a DB.In the tests for example,
# this is not always the case.
# this is not always the case.
if
not
isinstance
(
obj
,
(
DB
,
opt
.
GlobalOptimizer
,
opt
.
LocalOptimizer
)):
if
not
isinstance
(
obj
,
(
OptimizationDatabase
,
opt
.
GlobalOptimizer
,
opt
.
LocalOptimizer
)
):
raise
TypeError
(
"Object cannot be registered in OptDB"
,
obj
)
raise
TypeError
(
"Object cannot be registered in OptDB"
,
obj
)
if
name
in
self
.
__db__
:
if
name
in
self
.
__db__
:
raise
ValueError
(
raise
ValueError
(
...
@@ -99,8 +107,8 @@ class DB:
...
@@ -99,8 +107,8 @@ class DB:
self
.
__db__
[
tag
]
.
remove
(
obj
)
self
.
__db__
[
tag
]
.
remove
(
obj
)
def
__query__
(
self
,
q
):
def
__query__
(
self
,
q
):
if
not
isinstance
(
q
,
Query
):
if
not
isinstance
(
q
,
Optimization
Query
):
raise
TypeError
(
"Expected a Query."
,
q
)
raise
TypeError
(
"Expected a
Optimization
Query."
,
q
)
# The ordered set is needed for deterministic optimization.
# The ordered set is needed for deterministic optimization.
variables
=
OrderedSet
()
variables
=
OrderedSet
()
for
tag
in
q
.
include
:
for
tag
in
q
.
include
:
...
@@ -112,7 +120,7 @@ class DB:
...
@@ -112,7 +120,7 @@ class DB:
remove
=
OrderedSet
()
remove
=
OrderedSet
()
add
=
OrderedSet
()
add
=
OrderedSet
()
for
obj
in
variables
:
for
obj
in
variables
:
if
isinstance
(
obj
,
DB
):
if
isinstance
(
obj
,
OptimizationDatabase
):
def_sub_query
=
q
def_sub_query
=
q
if
q
.
extra_optimizations
:
if
q
.
extra_optimizations
:
def_sub_query
=
copy
.
copy
(
q
)
def_sub_query
=
copy
.
copy
(
q
)
...
@@ -128,10 +136,10 @@ class DB:
...
@@ -128,10 +136,10 @@ class DB:
return
variables
return
variables
def
query
(
self
,
*
tags
,
**
kwtags
):
def
query
(
self
,
*
tags
,
**
kwtags
):
if
len
(
tags
)
>=
1
and
isinstance
(
tags
[
0
],
Query
):
if
len
(
tags
)
>=
1
and
isinstance
(
tags
[
0
],
Optimization
Query
):
if
len
(
tags
)
>
1
or
kwtags
:
if
len
(
tags
)
>
1
or
kwtags
:
raise
TypeError
(
raise
TypeError
(
"If the first argument to query is a Query,"
"If the first argument to query is a
Optimization
Query,"
" there should be no other arguments."
,
" there should be no other arguments."
,
tags
,
tags
,
kwtags
,
kwtags
,
...
@@ -147,7 +155,9 @@ class DB:
...
@@ -147,7 +155,9 @@ class DB:
tags
,
tags
,
)
)
return
self
.
__query__
(
return
self
.
__query__
(
Query
(
include
=
include
,
require
=
require
,
exclude
=
exclude
,
subquery
=
kwtags
)
OptimizationQuery
(
include
=
include
,
require
=
require
,
exclude
=
exclude
,
subquery
=
kwtags
)
)
)
def
__getitem__
(
self
,
name
):
def
__getitem__
(
self
,
name
):
...
@@ -168,7 +178,11 @@ class DB:
...
@@ -168,7 +178,11 @@ class DB:
print
(
" db"
,
self
.
__db__
,
file
=
stream
)
print
(
" db"
,
self
.
__db__
,
file
=
stream
)
class
Query
:
# This is deprecated and will be removed.
DB
=
OptimizationDatabase
class
OptimizationQuery
:
"""
"""
Parameters
Parameters
...
@@ -204,7 +218,7 @@ class Query:
...
@@ -204,7 +218,7 @@ class Query:
def
__str__
(
self
):
def
__str__
(
self
):
return
(
return
(
"Query{inc=
%
s,ex=
%
s,require=
%
s,subquery=
%
s,"
"
Optimization
Query{inc=
%
s,ex=
%
s,require=
%
s,subquery=
%
s,"
"position_cutoff=
%
f,extra_opts=
%
s}"
"position_cutoff=
%
f,extra_opts=
%
s}"
%
(
%
(
self
.
include
,
self
.
include
,
...
@@ -223,7 +237,7 @@ class Query:
...
@@ -223,7 +237,7 @@ class Query:
# add all opt with this tag
# add all opt with this tag
def
including
(
self
,
*
tags
):
def
including
(
self
,
*
tags
):
return
Query
(
return
Optimization
Query
(
self
.
include
.
union
(
tags
),
self
.
include
.
union
(
tags
),
self
.
require
,
self
.
require
,
self
.
exclude
,
self
.
exclude
,
...
@@ -234,7 +248,7 @@ class Query:
...
@@ -234,7 +248,7 @@ class Query:
# remove all opt with this tag
# remove all opt with this tag
def
excluding
(
self
,
*
tags
):
def
excluding
(
self
,
*
tags
):
return
Query
(
return
Optimization
Query
(
self
.
include
,
self
.
include
,
self
.
require
,
self
.
require
,
self
.
exclude
.
union
(
tags
),
self
.
exclude
.
union
(
tags
),
...
@@ -245,7 +259,7 @@ class Query:
...
@@ -245,7 +259,7 @@ class Query:
# keep only opt with this tag.
# keep only opt with this tag.
def
requiring
(
self
,
*
tags
):
def
requiring
(
self
,
*
tags
):
return
Query
(
return
Optimization
Query
(
self
.
include
,
self
.
include
,
self
.
require
.
union
(
tags
),
self
.
require
.
union
(
tags
),
self
.
exclude
,
self
.
exclude
,
...
@@ -255,7 +269,7 @@ class Query:
...
@@ -255,7 +269,7 @@ class Query:
)
)
def
register
(
self
,
*
optimizations
):
def
register
(
self
,
*
optimizations
):
return
Query
(
return
Optimization
Query
(
self
.
include
,
self
.
include
,
self
.
require
,
self
.
require
,
self
.
exclude
,
self
.
exclude
,
...
@@ -265,7 +279,11 @@ class Query:
...
@@ -265,7 +279,11 @@ class Query:
)
)
class
EquilibriumDB
(
DB
):
# This is deprecated and will be removed.
Query
=
OptimizationQuery
class
EquilibriumDB
(
OptimizationDatabase
):
"""
"""
A set of potential optimizations which should be applied in an arbitrary
A set of potential optimizations which should be applied in an arbitrary
order until equilibrium is reached.
order until equilibrium is reached.
...
@@ -331,7 +349,7 @@ class EquilibriumDB(DB):
...
@@ -331,7 +349,7 @@ class EquilibriumDB(DB):
)
)
class
SequenceDB
(
DB
):
class
SequenceDB
(
OptimizationDatabase
):
"""
"""
A sequence of potential optimizations.
A sequence of potential optimizations.
...
@@ -378,13 +396,13 @@ class SequenceDB(DB):
...
@@ -378,13 +396,13 @@ class SequenceDB(DB):
position_cutoff
=
kwtags
.
pop
(
"position_cutoff"
,
config
.
optdb__position_cutoff
)
position_cutoff
=
kwtags
.
pop
(
"position_cutoff"
,
config
.
optdb__position_cutoff
)
position_dict
=
self
.
__position__
position_dict
=
self
.
__position__
if
len
(
tags
)
>=
1
and
isinstance
(
tags
[
0
],
Query
):
if
len
(
tags
)
>=
1
and
isinstance
(
tags
[
0
],
Optimization
Query
):
# the call to super should have raise an error with a good message
# the call to super should have raise an error with a good message
assert
len
(
tags
)
==
1
assert
len
(
tags
)
==
1
if
getattr
(
tags
[
0
],
"position_cutoff"
,
None
):
if
getattr
(
tags
[
0
],
"position_cutoff"
,
None
):
position_cutoff
=
tags
[
0
]
.
position_cutoff
position_cutoff
=
tags
[
0
]
.
position_cutoff
# The Query instance might contain extra optimizations which need
# The
Optimization
Query instance might contain extra optimizations which need
# to be added the the sequence of optimizations (don't alter the
# to be added the the sequence of optimizations (don't alter the
# original dictionary)
# original dictionary)
if
len
(
tags
[
0
]
.
extra_optimizations
)
>
0
:
if
len
(
tags
[
0
]
.
extra_optimizations
)
>
0
:
...
@@ -430,7 +448,7 @@ class SequenceDB(DB):
...
@@ -430,7 +448,7 @@ class SequenceDB(DB):
return
sio
.
getvalue
()
return
sio
.
getvalue
()
class
LocalGroupDB
(
DB
):
class
LocalGroupDB
(
OptimizationDatabase
):
"""
"""
Generate a local optimizer of type LocalOptGroup instead
Generate a local optimizer of type LocalOptGroup instead
of a global optimizer.
of a global optimizer.
...
@@ -476,7 +494,7 @@ class LocalGroupDB(DB):
...
@@ -476,7 +494,7 @@ class LocalGroupDB(DB):
return
ret
return
ret
class
TopoDB
(
DB
):
class
TopoDB
(
OptimizationDatabase
):
"""
"""
Generate a `GlobalOptimizer` of type TopoOptimizer.
Generate a `GlobalOptimizer` of type TopoOptimizer.
...
@@ -501,7 +519,7 @@ class TopoDB(DB):
...
@@ -501,7 +519,7 @@ class TopoDB(DB):
)
)
class
ProxyDB
(
DB
):
class
ProxyDB
(
OptimizationDatabase
):
"""
"""
Wrap an existing proxy.
Wrap an existing proxy.
...
@@ -511,7 +529,7 @@ class ProxyDB(DB):
...
@@ -511,7 +529,7 @@ class ProxyDB(DB):
"""
"""
def
__init__
(
self
,
db
):
def
__init__
(
self
,
db
):
assert
isinstance
(
db
,
DB
),
""
assert
isinstance
(
db
,
OptimizationDatabase
),
""
self
.
db
=
db
self
.
db
=
db
def
query
(
self
,
*
tags
,
**
kwtags
):
def
query
(
self
,
*
tags
,
**
kwtags
):
...
...
doc/extending/optimization.txt
浏览文件 @
c2672deb
...
@@ -395,97 +395,96 @@ Definition of optdb
...
@@ -395,97 +395,96 @@ Definition of optdb
optdb is an object which is an instance of
optdb is an object which is an instance of
:class:`SequenceDB <optdb.SequenceDB>`,
:class:`SequenceDB <optdb.SequenceDB>`,
itself a subclass of :class:`
DB <optdb.DB
>`.
itself a subclass of :class:`
OptimizationDatabase <optdb.OptimizationDatabase
>`.
There exist (for now) two types of
DB
, SequenceDB and EquilibriumDB.
There exist (for now) two types of
OptimizationDatabase
, SequenceDB and EquilibriumDB.
When given an appropriate
Query, DB
objects build an Optimizer matching
When given an appropriate
OptimizationQuery, OptimizationDatabase
objects build an Optimizer matching
the query.
the query.
A SequenceDB contains Optimizer or DB objects. Each of them has a
A SequenceDB contains Optimizer or OptimizationDatabase objects. Each of them
name, an arbitrary number of tags and an integer representing their
has a name, an arbitrary number of tags and an integer representing their order
order in the sequence. When a Query is applied to a SequenceDB, all
in the sequence. When a OptimizationQuery is applied to a SequenceDB, all Optimizers whose
Optimizers whose tags match the query are inserted in proper order in
tags match the query are inserted in proper order in a SequenceOptimizer, which
a SequenceOptimizer, which is returned. If the SequenceDB contains DB
is returned. If the SequenceDB contains OptimizationDatabase instances, the OptimizationQuery will be passed
instances, the Query will be passed to them as well and the optimizers
to them as well and the optimizers they return will be put in their places.
they return will be put in their places.
An EquilibriumDB contains LocalOptimizer or
DB
objects. Each of them
An EquilibriumDB contains LocalOptimizer or
OptimizationDatabase
objects. Each of them
has a name and an arbitrary number of tags. When a Query is applied to
has a name and an arbitrary number of tags. When a
Optimization
Query is applied to
an EquilibriumDB, all LocalOptimizers that match the query are
an EquilibriumDB, all LocalOptimizers that match the query are
inserted into an EquilibriumOptimizer, which is returned. If the
inserted into an EquilibriumOptimizer, which is returned. If the
SequenceDB contains
DB instances, the
Query will be passed to them as
SequenceDB contains
OptimizationDatabase instances, the Optimization
Query will be passed to them as
well and the LocalOptimizers they return will be put in their places
well and the LocalOptimizers they return will be put in their places
(note that as of yet no
DB
can produce LocalOptimizer objects, so this
(note that as of yet no
OptimizationDatabase
can produce LocalOptimizer objects, so this
is a moot point).
is a moot point).
Aesara contains one principal
DB
object, :class:`optdb`, which
Aesara contains one principal
OptimizationDatabase
object, :class:`optdb`, which
contains all of Aesara's optimizers with proper tags. It is
contains all of Aesara's optimizers with proper tags. It is
recommended to insert new Optimizers in it. As mentioned previously,
recommended to insert new Optimizers in it. As mentioned previously,
optdb is a SequenceDB, so, at the top level, Aesara applies a sequence
optdb is a SequenceDB, so, at the top level, Aesara applies a sequence
of global optimizations to the computation graphs.
of global optimizations to the computation graphs.
Query
Optimization
Query
-----
-----
A Query is built by the following call:
A
Optimization
Query is built by the following call:
.. code-block:: python
.. code-block:: python
aesara.graph.optdb.Query(include, require=None, exclude=None, subquery=None)
aesara.graph.optdb.
Optimization
Query(include, require=None, exclude=None, subquery=None)
.. class:: Query
.. class::
Optimization
Query
.. attribute:: include
.. attribute:: include
A set of tags (a tag being a string) such that every
A set of tags (a tag being a string) such that every
optimization obtained through this Query must have **one** of the tags
optimization obtained through this
Optimization
Query must have **one** of the tags
listed. This field is required and basically acts as a starting point
listed. This field is required and basically acts as a starting point
for the search.
for the search.
.. attribute:: require
.. attribute:: require
A set of tags such that every optimization obtained
A set of tags such that every optimization obtained
through this Query must have **all** of these tags.
through this
Optimization
Query must have **all** of these tags.
.. attribute:: exclude
.. attribute:: exclude
A set of tags such that every optimization obtained
A set of tags such that every optimization obtained
through this Query must have **none** of these tags.
through this
Optimization
Query must have **none** of these tags.
.. attribute:: subquery
.. attribute:: subquery
optdb can contain sub-databases; subquery is a
optdb can contain sub-databases; subquery is a
dictionary mapping the name of a sub-database to a special Query.
dictionary mapping the name of a sub-database to a special
Optimization
Query.
If no subquery is given for a sub-database, the original Query will be
If no subquery is given for a sub-database, the original
Optimization
Query will be
used again.
used again.
Furthermore, a Query object includes three methods, ``including``,
Furthermore, a
Optimization
Query object includes three methods, ``including``,
``requiring`` and ``excluding`` which each produce a new Query object
``requiring`` and ``excluding`` which each produce a new
Optimization
Query object
with include, require and exclude sets refined to contain the new [WRITEME]
with include, require and exclude sets refined to contain the new [WRITEME]
Examples
Examples
--------
--------
Here are a few examples of how to use a Query on optdb to produce an
Here are a few examples of how to use a
Optimization
Query on optdb to produce an
Optimizer:
Optimizer:
.. testcode::
.. testcode::
from aesara.graph.optdb import Query
from aesara.graph.optdb import
Optimization
Query
from aesara.compile import optdb
from aesara.compile import optdb
# This is how the optimizer for the fast_run mode is defined
# This is how the optimizer for the fast_run mode is defined
fast_run = optdb.query(Query(include=['fast_run']))
fast_run = optdb.query(
Optimization
Query(include=['fast_run']))
# This is how the optimizer for the fast_compile mode is defined
# This is how the optimizer for the fast_compile mode is defined
fast_compile = optdb.query(Query(include=['fast_compile']))
fast_compile = optdb.query(
Optimization
Query(include=['fast_compile']))
# This is the same as fast_run but no optimizations will replace
# This is the same as fast_run but no optimizations will replace
# any operation by an inplace version. This assumes, of course,
# any operation by an inplace version. This assumes, of course,
# that all inplace operations are tagged as 'inplace' (as they
# that all inplace operations are tagged as 'inplace' (as they
# should!)
# should!)
fast_run_no_inplace = optdb.query(Query(include=['fast_run'],
fast_run_no_inplace = optdb.query(
Optimization
Query(include=['fast_run'],
exclude=['inplace']))
exclude=['inplace']))
...
@@ -544,7 +543,7 @@ optimizations:
...
@@ -544,7 +543,7 @@ optimizations:
For each group, all optimizations of the group that are selected by
For each group, all optimizations of the group that are selected by
the Query will be applied on the graph over and over again until none
the
Optimization
Query will be applied on the graph over and over again until none
of them is applicable, so keep that in mind when designing it: check
of them is applicable, so keep that in mind when designing it: check
carefully that your optimization leads to a fixpoint (a point where it
carefully that your optimization leads to a fixpoint (a point where it
cannot apply anymore) at which point it returns ``False`` to indicate its
cannot apply anymore) at which point it returns ``False`` to indicate its
...
...
tests/graph/test_optdb.py
浏览文件 @
c2672deb
import
pytest
import
pytest
from
aesara.graph.optdb
import
DB
,
opt
from
aesara.graph.optdb
import
OptimizationDatabase
,
opt
class
TestDB
:
class
TestDB
:
...
@@ -11,7 +11,7 @@ class TestDB:
...
@@ -11,7 +11,7 @@ class TestDB:
def
apply
(
self
,
fgraph
):
def
apply
(
self
,
fgraph
):
pass
pass
db
=
DB
()
db
=
OptimizationDatabase
()
db
.
register
(
"a"
,
Opt
())
db
.
register
(
"a"
,
Opt
())
db
.
register
(
"b"
,
Opt
())
db
.
register
(
"b"
,
Opt
())
...
...
tests/link/test_jax.py
浏览文件 @
c2672deb
...
@@ -13,7 +13,7 @@ from aesara.configdefaults import config
...
@@ -13,7 +13,7 @@ from aesara.configdefaults import config
from
aesara.graph.basic
import
Apply
from
aesara.graph.basic
import
Apply
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
Op
,
get_test_value
from
aesara.graph.op
import
Op
,
get_test_value
from
aesara.graph.optdb
import
Query
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.ifelse
import
ifelse
from
aesara.ifelse
import
ifelse
from
aesara.link.jax
import
JAXLinker
from
aesara.link.jax
import
JAXLinker
from
aesara.scalar.basic
import
Composite
from
aesara.scalar.basic
import
Composite
...
@@ -52,7 +52,7 @@ from aesara.tensor.type import (
...
@@ -52,7 +52,7 @@ from aesara.tensor.type import (
jax
=
pytest
.
importorskip
(
"jax"
)
jax
=
pytest
.
importorskip
(
"jax"
)
opts
=
Query
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
opts
=
Optimization
Query
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
jax_mode
=
Mode
(
JAXLinker
(),
opts
)
jax_mode
=
Mode
(
JAXLinker
(),
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
...
@@ -1057,7 +1057,7 @@ def test_jax_BatchedDot():
...
@@ -1057,7 +1057,7 @@ def test_jax_BatchedDot():
# A dimension mismatch should raise a TypeError for compatibility
# A dimension mismatch should raise a TypeError for compatibility
inputs
=
[
get_test_value
(
a
)[:
-
1
],
get_test_value
(
b
)]
inputs
=
[
get_test_value
(
a
)[:
-
1
],
get_test_value
(
b
)]
opts
=
Query
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
opts
=
Optimization
Query
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
jax_mode
=
Mode
(
JAXLinker
(),
opts
)
jax_mode
=
Mode
(
JAXLinker
(),
opts
)
aesara_jax_fn
=
function
(
fgraph
.
inputs
,
fgraph
.
outputs
,
mode
=
jax_mode
)
aesara_jax_fn
=
function
(
fgraph
.
inputs
,
fgraph
.
outputs
,
mode
=
jax_mode
)
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
...
...
tests/link/test_numba.py
浏览文件 @
c2672deb
...
@@ -21,7 +21,7 @@ from aesara.compile.sharedvalue import SharedVariable
...
@@ -21,7 +21,7 @@ from aesara.compile.sharedvalue import SharedVariable
from
aesara.graph.basic
import
Apply
,
Constant
from
aesara.graph.basic
import
Apply
,
Constant
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
Op
from
aesara.graph.op
import
Op
from
aesara.graph.optdb
import
Query
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.graph.type
import
Type
from
aesara.graph.type
import
Type
from
aesara.link.numba.dispatch
import
create_numba_signature
,
get_numba_type
from
aesara.link.numba.dispatch
import
create_numba_signature
,
get_numba_type
from
aesara.link.numba.linker
import
NumbaLinker
from
aesara.link.numba.linker
import
NumbaLinker
...
@@ -70,7 +70,7 @@ class MyMultiOut(Op):
...
@@ -70,7 +70,7 @@ class MyMultiOut(Op):
outputs
[
1
][
0
]
=
res2
outputs
[
1
][
0
]
=
res2
opts
=
Query
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
opts
=
Optimization
Query
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
numba_mode
=
Mode
(
NumbaLinker
(),
opts
)
numba_mode
=
Mode
(
NumbaLinker
(),
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
...
...
tests/tensor/random/test_opt.py
浏览文件 @
c2672deb
...
@@ -8,7 +8,7 @@ from aesara.compile.mode import Mode
...
@@ -8,7 +8,7 @@ from aesara.compile.mode import Mode
from
aesara.graph.basic
import
Constant
from
aesara.graph.basic
import
Constant
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.opt
import
EquilibriumOptimizer
from
aesara.graph.opt
import
EquilibriumOptimizer
from
aesara.graph.optdb
import
Query
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.random.basic
import
(
from
aesara.tensor.random.basic
import
(
dirichlet
,
dirichlet
,
...
@@ -27,8 +27,10 @@ from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subte
...
@@ -27,8 +27,10 @@ from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subte
from
aesara.tensor.type
import
iscalar
,
vector
from
aesara.tensor.type
import
iscalar
,
vector
inplace_mode
=
Mode
(
"py"
,
Query
(
include
=
[
"random_make_inplace"
],
exclude
=
[]))
inplace_mode
=
Mode
(
no_mode
=
Mode
(
"py"
,
Query
(
include
=
[],
exclude
=
[]))
"py"
,
OptimizationQuery
(
include
=
[
"random_make_inplace"
],
exclude
=
[])
)
no_mode
=
Mode
(
"py"
,
OptimizationQuery
(
include
=
[],
exclude
=
[]))
def
test_inplace_optimization
():
def
test_inplace_optimization
():
...
...
tests/tensor/random/test_utils.py
浏览文件 @
c2672deb
...
@@ -3,7 +3,7 @@ import pytest
...
@@ -3,7 +3,7 @@ import pytest
from
aesara
import
config
,
function
from
aesara
import
config
,
function
from
aesara.compile.mode
import
Mode
from
aesara.compile.mode
import
Mode
from
aesara.graph.optdb
import
Query
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.tensor.random.utils
import
RandomStream
,
broadcast_params
from
aesara.tensor.random.utils
import
RandomStream
,
broadcast_params
from
aesara.tensor.type
import
matrix
,
tensor
from
aesara.tensor.type
import
matrix
,
tensor
from
tests
import
unittest_tools
as
utt
from
tests
import
unittest_tools
as
utt
...
@@ -11,7 +11,7 @@ from tests import unittest_tools as utt
...
@@ -11,7 +11,7 @@ from tests import unittest_tools as utt
@pytest.fixture
(
scope
=
"module"
,
autouse
=
True
)
@pytest.fixture
(
scope
=
"module"
,
autouse
=
True
)
def
set_aesara_flags
():
def
set_aesara_flags
():
opts
=
Query
(
include
=
[
None
],
exclude
=
[])
opts
=
Optimization
Query
(
include
=
[
None
],
exclude
=
[])
py_mode
=
Mode
(
"py"
,
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
with
config
.
change_flags
(
mode
=
py_mode
,
compute_test_value
=
"warn"
):
with
config
.
change_flags
(
mode
=
py_mode
,
compute_test_value
=
"warn"
):
yield
yield
...
...
tests/tensor/test_basic_opt.py
浏览文件 @
c2672deb
...
@@ -19,7 +19,7 @@ from aesara.graph.basic import Apply, Constant
...
@@ -19,7 +19,7 @@ from aesara.graph.basic import Apply, Constant
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
Op
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
check_stack_trace
,
local_optimizer
,
out2in
from
aesara.graph.opt
import
check_stack_trace
,
local_optimizer
,
out2in
from
aesara.graph.optdb
import
Query
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.tensor
import
inplace
from
aesara.tensor
import
inplace
from
aesara.tensor.basic
import
(
from
aesara.tensor.basic
import
(
...
@@ -140,15 +140,15 @@ mode_opt = get_mode(mode_opt)
...
@@ -140,15 +140,15 @@ mode_opt = get_mode(mode_opt)
dimshuffle_lift
=
out2in
(
local_dimshuffle_lift
)
dimshuffle_lift
=
out2in
(
local_dimshuffle_lift
)
_optimizer_stabilize
=
Query
(
include
=
[
"fast_run"
])
_optimizer_stabilize
=
Optimization
Query
(
include
=
[
"fast_run"
])
_optimizer_stabilize
.
position_cutoff
=
1.51
_optimizer_stabilize
.
position_cutoff
=
1.51
_optimizer_stabilize
=
optdb
.
query
(
_optimizer_stabilize
)
_optimizer_stabilize
=
optdb
.
query
(
_optimizer_stabilize
)
_optimizer_specialize
=
Query
(
include
=
[
"fast_run"
])
_optimizer_specialize
=
Optimization
Query
(
include
=
[
"fast_run"
])
_optimizer_specialize
.
position_cutoff
=
2.01
_optimizer_specialize
.
position_cutoff
=
2.01
_optimizer_specialize
=
optdb
.
query
(
_optimizer_specialize
)
_optimizer_specialize
=
optdb
.
query
(
_optimizer_specialize
)
_optimizer_fast_run
=
Query
(
include
=
[
"fast_run"
])
_optimizer_fast_run
=
Optimization
Query
(
include
=
[
"fast_run"
])
_optimizer_fast_run
=
optdb
.
query
(
_optimizer_fast_run
)
_optimizer_fast_run
=
optdb
.
query
(
_optimizer_fast_run
)
...
@@ -351,7 +351,7 @@ def test_local_useless_dimshuffle_in_reshape():
...
@@ -351,7 +351,7 @@ def test_local_useless_dimshuffle_in_reshape():
class
TestFusion
:
class
TestFusion
:
opts
=
Query
(
opts
=
Optimization
Query
(
include
=
[
include
=
[
"local_elemwise_fusion"
,
"local_elemwise_fusion"
,
"composite_elemwise_fusion"
,
"composite_elemwise_fusion"
,
...
@@ -1125,7 +1125,7 @@ class TestFusion:
...
@@ -1125,7 +1125,7 @@ class TestFusion:
def
test_add_mul_fusion_inplace
(
self
):
def
test_add_mul_fusion_inplace
(
self
):
opts
=
Query
(
opts
=
Optimization
Query
(
include
=
[
include
=
[
"local_elemwise_fusion"
,
"local_elemwise_fusion"
,
"composite_elemwise_fusion"
,
"composite_elemwise_fusion"
,
...
...
tests/tensor/test_extra_ops.py
浏览文件 @
c2672deb
...
@@ -9,7 +9,7 @@ from aesara.compile.mode import Mode
...
@@ -9,7 +9,7 @@ from aesara.compile.mode import Mode
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.gradient
import
grad
from
aesara.gradient
import
grad
from
aesara.graph.basic
import
applys_between
from
aesara.graph.basic
import
applys_between
from
aesara.graph.optdb
import
Query
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.extra_ops
import
(
from
aesara.tensor.extra_ops
import
(
Bartlett
,
Bartlett
,
...
@@ -1169,7 +1169,7 @@ class TestBroadcastTo(utt.InferShapeTester):
...
@@ -1169,7 +1169,7 @@ class TestBroadcastTo(utt.InferShapeTester):
q
=
b
[
np
.
r_
[
0
,
1
,
3
]]
q
=
b
[
np
.
r_
[
0
,
1
,
3
]]
e
=
aet
.
set_subtensor
(
q
,
np
.
r_
[
0
,
0
,
0
])
e
=
aet
.
set_subtensor
(
q
,
np
.
r_
[
0
,
0
,
0
])
opts
=
Query
(
include
=
[
"inplace"
])
opts
=
Optimization
Query
(
include
=
[
"inplace"
])
py_mode
=
Mode
(
"py"
,
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
e_fn
=
function
([
d
],
e
,
mode
=
py_mode
)
e_fn
=
function
([
d
],
e
,
mode
=
py_mode
)
...
...
tests/tensor/test_math_opt.py
浏览文件 @
c2672deb
...
@@ -20,7 +20,7 @@ from aesara.graph.basic import Constant
...
@@ -20,7 +20,7 @@ from aesara.graph.basic import Constant
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.opt
import
LocalOptGroup
,
TopoOptimizer
,
check_stack_trace
,
out2in
from
aesara.graph.opt
import
LocalOptGroup
,
TopoOptimizer
,
check_stack_trace
,
out2in
from
aesara.graph.opt_utils
import
is_same_graph
from
aesara.graph.opt_utils
import
is_same_graph
from
aesara.graph.optdb
import
Query
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.tensor
import
inplace
from
aesara.tensor
import
inplace
from
aesara.tensor.basic
import
Alloc
,
join
,
switch
from
aesara.tensor.basic
import
Alloc
,
join
,
switch
...
@@ -124,15 +124,15 @@ mode_opt = get_mode(mode_opt)
...
@@ -124,15 +124,15 @@ mode_opt = get_mode(mode_opt)
dimshuffle_lift
=
out2in
(
local_dimshuffle_lift
)
dimshuffle_lift
=
out2in
(
local_dimshuffle_lift
)
_optimizer_stabilize
=
Query
(
include
=
[
"fast_run"
])
_optimizer_stabilize
=
Optimization
Query
(
include
=
[
"fast_run"
])
_optimizer_stabilize
.
position_cutoff
=
1.51
_optimizer_stabilize
.
position_cutoff
=
1.51
_optimizer_stabilize
=
optdb
.
query
(
_optimizer_stabilize
)
_optimizer_stabilize
=
optdb
.
query
(
_optimizer_stabilize
)
_optimizer_specialize
=
Query
(
include
=
[
"fast_run"
])
_optimizer_specialize
=
Optimization
Query
(
include
=
[
"fast_run"
])
_optimizer_specialize
.
position_cutoff
=
2.01
_optimizer_specialize
.
position_cutoff
=
2.01
_optimizer_specialize
=
optdb
.
query
(
_optimizer_specialize
)
_optimizer_specialize
=
optdb
.
query
(
_optimizer_specialize
)
_optimizer_fast_run
=
Query
(
include
=
[
"fast_run"
])
_optimizer_fast_run
=
Optimization
Query
(
include
=
[
"fast_run"
])
_optimizer_fast_run
=
optdb
.
query
(
_optimizer_fast_run
)
_optimizer_fast_run
=
optdb
.
query
(
_optimizer_fast_run
)
...
@@ -351,7 +351,7 @@ class TestAlgebraicCanonize:
...
@@ -351,7 +351,7 @@ class TestAlgebraicCanonize:
# We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode
=
get_default_mode
()
mode
=
get_default_mode
()
opt
=
Query
([
"canonicalize"
])
opt
=
Optimization
Query
([
"canonicalize"
])
opt
=
opt
.
excluding
(
"local_elemwise_fusion"
)
opt
=
opt
.
excluding
(
"local_elemwise_fusion"
)
mode
=
mode
.
__class__
(
linker
=
mode
.
linker
,
optimizer
=
opt
)
mode
=
mode
.
__class__
(
linker
=
mode
.
linker
,
optimizer
=
opt
)
for
id
,
[
g
,
sym_inputs
,
val_inputs
,
nb_elemwise
,
out_dtype
]
in
enumerate
(
cases
):
for
id
,
[
g
,
sym_inputs
,
val_inputs
,
nb_elemwise
,
out_dtype
]
in
enumerate
(
cases
):
...
@@ -486,7 +486,7 @@ class TestAlgebraicCanonize:
...
@@ -486,7 +486,7 @@ class TestAlgebraicCanonize:
# We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode
=
get_default_mode
()
mode
=
get_default_mode
()
mode
.
_optimizer
=
Query
([
"canonicalize"
])
mode
.
_optimizer
=
Optimization
Query
([
"canonicalize"
])
mode
.
_optimizer
=
mode
.
_optimizer
.
excluding
(
"local_elemwise_fusion"
)
mode
.
_optimizer
=
mode
.
_optimizer
.
excluding
(
"local_elemwise_fusion"
)
for
id
,
[
g
,
sym_inputs
,
val_inputs
,
nb_elemwise
,
out_dtype
]
in
enumerate
(
cases
):
for
id
,
[
g
,
sym_inputs
,
val_inputs
,
nb_elemwise
,
out_dtype
]
in
enumerate
(
cases
):
f
=
function
(
f
=
function
(
...
@@ -534,7 +534,7 @@ class TestAlgebraicCanonize:
...
@@ -534,7 +534,7 @@ class TestAlgebraicCanonize:
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode
=
get_default_mode
()
mode
=
get_default_mode
()
opt
=
Query
([
"canonicalize"
])
opt
=
Optimization
Query
([
"canonicalize"
])
opt
=
opt
.
including
(
"ShapeOpt"
,
"local_fill_to_alloc"
)
opt
=
opt
.
including
(
"ShapeOpt"
,
"local_fill_to_alloc"
)
opt
=
opt
.
excluding
(
"local_elemwise_fusion"
)
opt
=
opt
.
excluding
(
"local_elemwise_fusion"
)
mode
=
mode
.
__class__
(
linker
=
mode
.
linker
,
optimizer
=
opt
)
mode
=
mode
.
__class__
(
linker
=
mode
.
linker
,
optimizer
=
opt
)
...
@@ -897,7 +897,7 @@ class TestAlgebraicCanonize:
...
@@ -897,7 +897,7 @@ class TestAlgebraicCanonize:
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode
=
get_default_mode
()
mode
=
get_default_mode
()
opt
=
Query
([
"canonicalize"
])
opt
=
Optimization
Query
([
"canonicalize"
])
opt
=
opt
.
excluding
(
"local_elemwise_fusion"
)
opt
=
opt
.
excluding
(
"local_elemwise_fusion"
)
mode
=
mode
.
__class__
(
linker
=
mode
.
linker
,
optimizer
=
opt
)
mode
=
mode
.
__class__
(
linker
=
mode
.
linker
,
optimizer
=
opt
)
# test fail!
# test fail!
...
@@ -1051,7 +1051,7 @@ def test_cast_in_mul_canonizer():
...
@@ -1051,7 +1051,7 @@ def test_cast_in_mul_canonizer():
class
TestFusion
:
class
TestFusion
:
opts
=
Query
(
opts
=
Optimization
Query
(
include
=
[
include
=
[
"local_elemwise_fusion"
,
"local_elemwise_fusion"
,
"composite_elemwise_fusion"
,
"composite_elemwise_fusion"
,
...
@@ -1762,7 +1762,7 @@ class TestFusion:
...
@@ -1762,7 +1762,7 @@ class TestFusion:
def
test_add_mul_fusion_inplace
(
self
):
def
test_add_mul_fusion_inplace
(
self
):
opts
=
Query
(
opts
=
Optimization
Query
(
include
=
[
include
=
[
"local_elemwise_fusion"
,
"local_elemwise_fusion"
,
"composite_elemwise_fusion"
,
"composite_elemwise_fusion"
,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论