Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
92a3b2a6
提交
92a3b2a6
authored
7月 18, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
8月 17, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rename OptimizationQuery to RewriteDatabaseQuery
上级
5dbfd046
显示空白字符变更
内嵌
并排
正在显示
16 个修改的文件
包含
107 行增加
和
102 行删除
+107
-102
mode.py
aesara/compile/mode.py
+13
-13
__init__.py
aesara/graph/__init__.py
+1
-1
opt_utils.py
aesara/graph/opt_utils.py
+3
-3
optdb.py
aesara/graph/optdb.py
+30
-25
graph_rewriting.rst
doc/extending/graph_rewriting.rst
+23
-23
test_mode.py
tests/compile/test_mode.py
+3
-3
test_jax.py
tests/link/test_jax.py
+3
-3
test_numba.py
tests/link/test_numba.py
+2
-2
test_numba_performance.py
tests/link/test_numba_performance.py
+2
-2
test_basic.py
tests/tensor/random/test_basic.py
+2
-2
test_opt.py
tests/tensor/random/test_opt.py
+2
-2
test_utils.py
tests/tensor/random/test_utils.py
+2
-2
test_basic_opt.py
tests/tensor/test_basic_opt.py
+7
-7
test_extra_ops.py
tests/tensor/test_extra_ops.py
+2
-2
test_math_opt.py
tests/tensor/test_math_opt.py
+10
-10
test_subtensor_opt.py
tests/tensor/test_subtensor_opt.py
+2
-2
没有找到文件。
aesara/compile/mode.py
浏览文件 @
92a3b2a6
...
@@ -19,8 +19,8 @@ from aesara.graph.opt import (
...
@@ -19,8 +19,8 @@ from aesara.graph.opt import (
from
aesara.graph.optdb
import
(
from
aesara.graph.optdb
import
(
EquilibriumDB
,
EquilibriumDB
,
LocalGroupDB
,
LocalGroupDB
,
OptimizationQuery
,
RewriteDatabase
,
RewriteDatabase
,
RewriteDatabaseQuery
,
SequenceDB
,
SequenceDB
,
TopoDB
,
TopoDB
,
)
)
...
@@ -64,15 +64,15 @@ def register_linker(name, linker):
...
@@ -64,15 +64,15 @@ def register_linker(name, linker):
exclude
=
[]
exclude
=
[]
if
not
config
.
cxx
:
if
not
config
.
cxx
:
exclude
=
[
"cxx_only"
]
exclude
=
[
"cxx_only"
]
OPT_NONE
=
Optimization
Query
(
include
=
[],
exclude
=
exclude
)
OPT_NONE
=
RewriteDatabase
Query
(
include
=
[],
exclude
=
exclude
)
# Even if multiple merge optimizer call will be there, this shouldn't
# Even if multiple merge optimizer call will be there, this shouldn't
# impact performance.
# impact performance.
OPT_MERGE
=
Optimization
Query
(
include
=
[
"merge"
],
exclude
=
exclude
)
OPT_MERGE
=
RewriteDatabase
Query
(
include
=
[
"merge"
],
exclude
=
exclude
)
OPT_FAST_RUN
=
Optimization
Query
(
include
=
[
"fast_run"
],
exclude
=
exclude
)
OPT_FAST_RUN
=
RewriteDatabase
Query
(
include
=
[
"fast_run"
],
exclude
=
exclude
)
OPT_FAST_RUN_STABLE
=
OPT_FAST_RUN
.
requiring
(
"stable"
)
OPT_FAST_RUN_STABLE
=
OPT_FAST_RUN
.
requiring
(
"stable"
)
OPT_FAST_COMPILE
=
Optimization
Query
(
include
=
[
"fast_compile"
],
exclude
=
exclude
)
OPT_FAST_COMPILE
=
RewriteDatabase
Query
(
include
=
[
"fast_compile"
],
exclude
=
exclude
)
OPT_STABILIZE
=
Optimization
Query
(
include
=
[
"fast_run"
],
exclude
=
exclude
)
OPT_STABILIZE
=
RewriteDatabase
Query
(
include
=
[
"fast_run"
],
exclude
=
exclude
)
OPT_STABILIZE
.
position_cutoff
=
1.5000001
OPT_STABILIZE
.
position_cutoff
=
1.5000001
OPT_NONE
.
name
=
"OPT_NONE"
OPT_NONE
.
name
=
"OPT_NONE"
OPT_MERGE
.
name
=
"OPT_MERGE"
OPT_MERGE
.
name
=
"OPT_MERGE"
...
@@ -302,7 +302,7 @@ class Mode:
...
@@ -302,7 +302,7 @@ class Mode:
def
__init__
(
def
__init__
(
self
,
self
,
linker
:
Optional
[
Union
[
str
,
Linker
]]
=
None
,
linker
:
Optional
[
Union
[
str
,
Linker
]]
=
None
,
optimizer
:
Union
[
str
,
Optimization
Query
]
=
"default"
,
optimizer
:
Union
[
str
,
RewriteDatabase
Query
]
=
"default"
,
db
:
RewriteDatabase
=
None
,
db
:
RewriteDatabase
=
None
,
):
):
if
linker
is
None
:
if
linker
is
None
:
...
@@ -320,7 +320,7 @@ class Mode:
...
@@ -320,7 +320,7 @@ class Mode:
# self.provided_optimizer - typically the `optimizer` arg.
# self.provided_optimizer - typically the `optimizer` arg.
# But if the `optimizer` arg is keyword corresponding to a predefined
# But if the `optimizer` arg is keyword corresponding to a predefined
#
Optimization
Query, then this stores the query
#
RewriteDatabase
Query, then this stores the query
# self._optimizer - typically same as provided_optimizer??
# self._optimizer - typically same as provided_optimizer??
# self.__get_optimizer - returns self._optimizer (possibly querying
# self.__get_optimizer - returns self._optimizer (possibly querying
...
@@ -342,7 +342,7 @@ class Mode:
...
@@ -342,7 +342,7 @@ class Mode:
self
.
linker
=
linker
self
.
linker
=
linker
if
isinstance
(
optimizer
,
str
)
or
optimizer
is
None
:
if
isinstance
(
optimizer
,
str
)
or
optimizer
is
None
:
optimizer
=
predefined_optimizers
[
optimizer
]
optimizer
=
predefined_optimizers
[
optimizer
]
if
isinstance
(
optimizer
,
Optimization
Query
):
if
isinstance
(
optimizer
,
RewriteDatabase
Query
):
self
.
provided_optimizer
=
optimizer
self
.
provided_optimizer
=
optimizer
self
.
_optimizer
=
optimizer
self
.
_optimizer
=
optimizer
self
.
call_time
=
0
self
.
call_time
=
0
...
@@ -357,7 +357,7 @@ class Mode:
...
@@ -357,7 +357,7 @@ class Mode:
)
)
def
__get_optimizer
(
self
):
def
__get_optimizer
(
self
):
if
isinstance
(
self
.
_optimizer
,
Optimization
Query
):
if
isinstance
(
self
.
_optimizer
,
RewriteDatabase
Query
):
return
self
.
optdb
.
query
(
self
.
_optimizer
)
return
self
.
optdb
.
query
(
self
.
_optimizer
)
else
:
else
:
return
self
.
_optimizer
return
self
.
_optimizer
...
@@ -375,7 +375,7 @@ class Mode:
...
@@ -375,7 +375,7 @@ class Mode:
link
,
opt
=
self
.
get_linker_optimizer
(
link
,
opt
=
self
.
get_linker_optimizer
(
self
.
provided_linker
,
self
.
provided_optimizer
self
.
provided_linker
,
self
.
provided_optimizer
)
)
# N.B. opt might be a
Optimization
Query instance, not sure what else it might be...
# N.B. opt might be a
RewriteDatabase
Query instance, not sure what else it might be...
# string? Optimizer? OptDB? who knows???
# string? Optimizer? OptDB? who knows???
return
self
.
clone
(
optimizer
=
opt
.
including
(
*
tags
),
linker
=
link
)
return
self
.
clone
(
optimizer
=
opt
.
including
(
*
tags
),
linker
=
link
)
...
@@ -448,11 +448,11 @@ else:
...
@@ -448,11 +448,11 @@ else:
JAX
=
Mode
(
JAX
=
Mode
(
JAXLinker
(),
JAXLinker
(),
Optimization
Query
(
include
=
[
"fast_run"
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
]),
RewriteDatabase
Query
(
include
=
[
"fast_run"
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
]),
)
)
NUMBA
=
Mode
(
NUMBA
=
Mode
(
NumbaLinker
(),
NumbaLinker
(),
Optimization
Query
(
include
=
[
"fast_run"
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
]),
RewriteDatabase
Query
(
include
=
[
"fast_run"
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
]),
)
)
...
...
aesara/graph/__init__.py
浏览文件 @
92a3b2a6
...
@@ -15,6 +15,6 @@ from aesara.graph.type import Type
...
@@ -15,6 +15,6 @@ from aesara.graph.type import Type
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.opt
import
node_rewriter
,
graph_rewriter
from
aesara.graph.opt
import
node_rewriter
,
graph_rewriter
from
aesara.graph.opt_utils
import
optimize_graph
from
aesara.graph.opt_utils
import
optimize_graph
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.graph.optdb
import
RewriteDatabase
Query
# isort: on
# isort: on
aesara/graph/opt_utils.py
浏览文件 @
92a3b2a6
...
@@ -10,7 +10,7 @@ from aesara.graph.basic import (
...
@@ -10,7 +10,7 @@ from aesara.graph.basic import (
vars_between
,
vars_between
,
)
)
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.graph.optdb
import
RewriteDatabase
Query
def
optimize_graph
(
def
optimize_graph
(
...
@@ -34,7 +34,7 @@ def optimize_graph(
...
@@ -34,7 +34,7 @@ def optimize_graph(
clone:
clone:
Whether or not to clone the input graph before optimizing.
Whether or not to clone the input graph before optimizing.
**kwargs:
**kwargs:
Keyword arguments passed to the ``aesara.graph.optdb.
Optimization
Query`` object.
Keyword arguments passed to the ``aesara.graph.optdb.
RewriteDatabase
Query`` object.
"""
"""
from
aesara.compile
import
optdb
from
aesara.compile
import
optdb
...
@@ -43,7 +43,7 @@ def optimize_graph(
...
@@ -43,7 +43,7 @@ def optimize_graph(
fgraph
=
FunctionGraph
(
outputs
=
[
fgraph
],
clone
=
clone
)
fgraph
=
FunctionGraph
(
outputs
=
[
fgraph
],
clone
=
clone
)
return_only_out
=
True
return_only_out
=
True
canonicalize_opt
=
optdb
.
query
(
Optimization
Query
(
include
=
include
,
**
kwargs
))
canonicalize_opt
=
optdb
.
query
(
RewriteDatabase
Query
(
include
=
include
,
**
kwargs
))
_
=
canonicalize_opt
.
optimize
(
fgraph
)
_
=
canonicalize_opt
.
optimize
(
fgraph
)
if
custom_opt
:
if
custom_opt
:
...
...
aesara/graph/optdb.py
浏览文件 @
92a3b2a6
...
@@ -137,10 +137,10 @@ class RewriteDatabase:
...
@@ -137,10 +137,10 @@ class RewriteDatabase:
return
variables
return
variables
def
query
(
self
,
*
tags
,
**
kwtags
):
def
query
(
self
,
*
tags
,
**
kwtags
):
if
len
(
tags
)
>=
1
and
isinstance
(
tags
[
0
],
Optimization
Query
):
if
len
(
tags
)
>=
1
and
isinstance
(
tags
[
0
],
RewriteDatabase
Query
):
if
len
(
tags
)
>
1
or
kwtags
:
if
len
(
tags
)
>
1
or
kwtags
:
raise
TypeError
(
raise
TypeError
(
"If the first argument to query is an `
Optimization
Query`,"
"If the first argument to query is an `
RewriteDatabase
Query`,"
" there should be no other arguments."
" there should be no other arguments."
)
)
return
self
.
__query__
(
tags
[
0
])
return
self
.
__query__
(
tags
[
0
])
...
@@ -153,7 +153,7 @@ class RewriteDatabase:
...
@@ -153,7 +153,7 @@ class RewriteDatabase:
" characters: '+', '&' or '-'"
" characters: '+', '&' or '-'"
)
)
return
self
.
__query__
(
return
self
.
__query__
(
Optimization
Query
(
RewriteDatabase
Query
(
include
=
include
,
require
=
require
,
exclude
=
exclude
,
subquery
=
kwtags
include
=
include
,
require
=
require
,
exclude
=
exclude
,
subquery
=
kwtags
)
)
)
)
...
@@ -176,7 +176,7 @@ class RewriteDatabase:
...
@@ -176,7 +176,7 @@ class RewriteDatabase:
print
(
" db"
,
self
.
__db__
,
file
=
stream
)
print
(
" db"
,
self
.
__db__
,
file
=
stream
)
class
Optimization
Query
:
class
RewriteDatabase
Query
:
"""An object that specifies a set of optimizations by tag/name."""
"""An object that specifies a set of optimizations by tag/name."""
def
__init__
(
def
__init__
(
...
@@ -184,11 +184,11 @@ class OptimizationQuery:
...
@@ -184,11 +184,11 @@ class OptimizationQuery:
include
:
Iterable
[
str
],
include
:
Iterable
[
str
],
require
:
Optional
[
Union
[
OrderedSet
,
Sequence
[
str
]]]
=
None
,
require
:
Optional
[
Union
[
OrderedSet
,
Sequence
[
str
]]]
=
None
,
exclude
:
Optional
[
Union
[
OrderedSet
,
Sequence
[
str
]]]
=
None
,
exclude
:
Optional
[
Union
[
OrderedSet
,
Sequence
[
str
]]]
=
None
,
subquery
:
Optional
[
Dict
[
str
,
"
Optimization
Query"
]]
=
None
,
subquery
:
Optional
[
Dict
[
str
,
"
RewriteDatabase
Query"
]]
=
None
,
position_cutoff
:
float
=
math
.
inf
,
position_cutoff
:
float
=
math
.
inf
,
extra_optimizations
:
Optional
[
extra_optimizations
:
Optional
[
Sequence
[
Sequence
[
Tuple
[
Union
[
"
Optimization
Query"
,
OptimizersType
],
Union
[
int
,
float
]]
Tuple
[
Union
[
"
RewriteDatabase
Query"
,
OptimizersType
],
Union
[
int
,
float
]]
]
]
]
=
None
,
]
=
None
,
):
):
...
@@ -198,19 +198,19 @@ class OptimizationQuery:
...
@@ -198,19 +198,19 @@ class OptimizationQuery:
==========
==========
include:
include:
A set of tags such that every optimization obtained through this
A set of tags such that every optimization obtained through this
`
Optimization
Query` must have **one** of the tags listed. This
`
RewriteDatabase
Query` must have **one** of the tags listed. This
field is required and basically acts as a starting point for the
field is required and basically acts as a starting point for the
search.
search.
require:
require:
A set of tags such that every optimization obtained through this
A set of tags such that every optimization obtained through this
`
Optimization
Query` must have **all** of these tags.
`
RewriteDatabase
Query` must have **all** of these tags.
exclude:
exclude:
A set of tags such that every optimization obtained through this
A set of tags such that every optimization obtained through this
``
Optimization
Query` must have **none** of these tags.
``
RewriteDatabase
Query` must have **none** of these tags.
subquery:
subquery:
A dictionary mapping the name of a sub-database to a special
A dictionary mapping the name of a sub-database to a special
`
Optimization
Query`. If no subquery is given for a sub-database,
`
RewriteDatabase
Query`. If no subquery is given for a sub-database,
the original `
Optimization
Query` will be used again.
the original `
RewriteDatabase
Query` will be used again.
position_cutoff:
position_cutoff:
Only optimizations with position less than the cutoff are returned.
Only optimizations with position less than the cutoff are returned.
extra_optimizations:
extra_optimizations:
...
@@ -229,7 +229,7 @@ class OptimizationQuery:
...
@@ -229,7 +229,7 @@ class OptimizationQuery:
def
__str__
(
self
):
def
__str__
(
self
):
return
(
return
(
"
Optimization
Query("
"
RewriteDatabase
Query("
+
f
"inc={self.include},ex={self.exclude},"
+
f
"inc={self.include},ex={self.exclude},"
+
f
"require={self.require},subquery={self.subquery},"
+
f
"require={self.require},subquery={self.subquery},"
+
f
"position_cutoff={self.position_cutoff},"
+
f
"position_cutoff={self.position_cutoff},"
...
@@ -241,9 +241,9 @@ class OptimizationQuery:
...
@@ -241,9 +241,9 @@ class OptimizationQuery:
if
not
hasattr
(
self
,
"extra_optimizations"
):
if
not
hasattr
(
self
,
"extra_optimizations"
):
self
.
extra_optimizations
=
[]
self
.
extra_optimizations
=
[]
def
including
(
self
,
*
tags
:
str
)
->
"
Optimization
Query"
:
def
including
(
self
,
*
tags
:
str
)
->
"
RewriteDatabase
Query"
:
"""Add rewrites with the given tags."""
"""Add rewrites with the given tags."""
return
Optimization
Query
(
return
RewriteDatabase
Query
(
self
.
include
.
union
(
tags
),
self
.
include
.
union
(
tags
),
self
.
require
,
self
.
require
,
self
.
exclude
,
self
.
exclude
,
...
@@ -252,9 +252,9 @@ class OptimizationQuery:
...
@@ -252,9 +252,9 @@ class OptimizationQuery:
self
.
extra_optimizations
,
self
.
extra_optimizations
,
)
)
def
excluding
(
self
,
*
tags
:
str
)
->
"
Optimization
Query"
:
def
excluding
(
self
,
*
tags
:
str
)
->
"
RewriteDatabase
Query"
:
"""Remove rewrites with the given tags."""
"""Remove rewrites with the given tags."""
return
Optimization
Query
(
return
RewriteDatabase
Query
(
self
.
include
,
self
.
include
,
self
.
require
,
self
.
require
,
self
.
exclude
.
union
(
tags
),
self
.
exclude
.
union
(
tags
),
...
@@ -263,9 +263,9 @@ class OptimizationQuery:
...
@@ -263,9 +263,9 @@ class OptimizationQuery:
self
.
extra_optimizations
,
self
.
extra_optimizations
,
)
)
def
requiring
(
self
,
*
tags
:
str
)
->
"
Optimization
Query"
:
def
requiring
(
self
,
*
tags
:
str
)
->
"
RewriteDatabase
Query"
:
"""Filter for rewrites with the given tags."""
"""Filter for rewrites with the given tags."""
return
Optimization
Query
(
return
RewriteDatabase
Query
(
self
.
include
,
self
.
include
,
self
.
require
.
union
(
tags
),
self
.
require
.
union
(
tags
),
self
.
exclude
,
self
.
exclude
,
...
@@ -275,10 +275,10 @@ class OptimizationQuery:
...
@@ -275,10 +275,10 @@ class OptimizationQuery:
)
)
def
register
(
def
register
(
self
,
*
optimizations
:
Tuple
[
"
Optimization
Query"
,
Union
[
int
,
float
]]
self
,
*
optimizations
:
Tuple
[
"
RewriteDatabase
Query"
,
Union
[
int
,
float
]]
)
->
"
Optimization
Query"
:
)
->
"
RewriteDatabase
Query"
:
"""Include the given optimizations."""
"""Include the given optimizations."""
return
Optimization
Query
(
return
RewriteDatabase
Query
(
self
.
include
,
self
.
include
,
self
.
require
,
self
.
require
,
self
.
exclude
,
self
.
exclude
,
...
@@ -417,13 +417,13 @@ class SequenceDB(RewriteDatabase):
...
@@ -417,13 +417,13 @@ class SequenceDB(RewriteDatabase):
position_dict
=
self
.
__position__
position_dict
=
self
.
__position__
if
len
(
tags
)
>=
1
and
isinstance
(
tags
[
0
],
Optimization
Query
):
if
len
(
tags
)
>=
1
and
isinstance
(
tags
[
0
],
RewriteDatabase
Query
):
# the call to super should have raise an error with a good message
# the call to super should have raise an error with a good message
assert
len
(
tags
)
==
1
assert
len
(
tags
)
==
1
if
getattr
(
tags
[
0
],
"position_cutoff"
,
None
):
if
getattr
(
tags
[
0
],
"position_cutoff"
,
None
):
position_cutoff
=
tags
[
0
]
.
position_cutoff
position_cutoff
=
tags
[
0
]
.
position_cutoff
# The
Optimization
Query instance might contain extra optimizations which need
# The
RewriteDatabase
Query instance might contain extra optimizations which need
# to be added the the sequence of optimizations (don't alter the
# to be added the the sequence of optimizations (don't alter the
# original dictionary)
# original dictionary)
if
len
(
tags
[
0
]
.
extra_optimizations
)
>
0
:
if
len
(
tags
[
0
]
.
extra_optimizations
)
>
0
:
...
@@ -544,14 +544,19 @@ DEPRECATED_NAMES = [
...
@@ -544,14 +544,19 @@ DEPRECATED_NAMES = [
),
),
(
(
"Query"
,
"Query"
,
"`Query` is deprecated; use `
Optimization
Query` instead."
,
"`Query` is deprecated; use `
RewriteDatabase
Query` instead."
,
Optimization
Query
,
RewriteDatabase
Query
,
),
),
(
(
"OptimizationDatabase"
,
"OptimizationDatabase"
,
"`OptimizationDatabase` is deprecated; use `RewriteDatabase` instead."
,
"`OptimizationDatabase` is deprecated; use `RewriteDatabase` instead."
,
RewriteDatabase
,
RewriteDatabase
,
),
),
(
"OptimizationQuery"
,
"`OptimizationQuery` is deprecated; use `RewriteDatabaseQuery` instead."
,
RewriteDatabaseQuery
,
),
]
]
...
...
doc/extending/graph_rewriting.rst
浏览文件 @
92a3b2a6
...
@@ -585,23 +585,23 @@ Definition of :obj:`optdb`
...
@@ -585,23 +585,23 @@ Definition of :obj:`optdb`
:class:`SequenceDB <optdb.SequenceDB>`,
:class:`SequenceDB <optdb.SequenceDB>`,
itself a subclass of :class:`RewriteDatabase <optdb.RewriteDatabase>`.
itself a subclass of :class:`RewriteDatabase <optdb.RewriteDatabase>`.
There exist (for now) two types of :class:`RewriteDatabase`, :class:`SequenceDB` and :class:`EquilibriumDB`.
There exist (for now) two types of :class:`RewriteDatabase`, :class:`SequenceDB` and :class:`EquilibriumDB`.
When given an appropriate :class:`
Optimization
Query`, :class:`RewriteDatabase` objects build an :class:`Optimizer` matching
When given an appropriate :class:`
RewriteDatabase
Query`, :class:`RewriteDatabase` objects build an :class:`Optimizer` matching
the query.
the query.
A :class:`SequenceDB` contains :class:`Optimizer` or :class:`RewriteDatabase` objects. Each of them
A :class:`SequenceDB` contains :class:`Optimizer` or :class:`RewriteDatabase` objects. Each of them
has a name, an arbitrary number of tags and an integer representing their order
has a name, an arbitrary number of tags and an integer representing their order
in the sequence. When a :class:`
Optimization
Query` is applied to a :class:`SequenceDB`, all :class:`Optimizer`\s whose
in the sequence. When a :class:`
RewriteDatabase
Query` is applied to a :class:`SequenceDB`, all :class:`Optimizer`\s whose
tags match the query are inserted in proper order in a :class:`SequenceOptimizer`, which
tags match the query are inserted in proper order in a :class:`SequenceOptimizer`, which
is returned. If the :class:`SequenceDB` contains :class:`RewriteDatabase`
is returned. If the :class:`SequenceDB` contains :class:`RewriteDatabase`
instances, the :class:`
Optimization
Query` will be passed to them as well and the
instances, the :class:`
RewriteDatabase
Query` will be passed to them as well and the
optimizers they return will be put in their places.
optimizers they return will be put in their places.
An :class:`EquilibriumDB` contains :class:`NodeRewriter` or :class:`RewriteDatabase` objects. Each of them
An :class:`EquilibriumDB` contains :class:`NodeRewriter` or :class:`RewriteDatabase` objects. Each of them
has a name and an arbitrary number of tags. When a :class:`
Optimization
Query` is applied to
has a name and an arbitrary number of tags. When a :class:`
RewriteDatabase
Query` is applied to
an :class:`EquilibriumDB`, all :class:`NodeRewriter`\s that match the query are
an :class:`EquilibriumDB`, all :class:`NodeRewriter`\s that match the query are
inserted into an :class:`EquilibriumGraphRewriter`, which is returned. If the
inserted into an :class:`EquilibriumGraphRewriter`, which is returned. If the
:class:`SequenceDB` contains :class:`RewriteDatabase` instances, the
:class:`SequenceDB` contains :class:`RewriteDatabase` instances, the
:class:`
Optimization
Query` will be passed to them as well and the
:class:`
RewriteDatabase
Query` will be passed to them as well and the
:class:`NodeRewriter`\s they return will be put in their places
:class:`NodeRewriter`\s they return will be put in their places
(note that as of yet no :class:`RewriteDatabase` can produce :class:`NodeRewriter` objects, so this
(note that as of yet no :class:`RewriteDatabase` can produce :class:`NodeRewriter` objects, so this
is a moot point).
is a moot point).
...
@@ -613,68 +613,68 @@ optdb is a :class:`SequenceDB`, so, at the top level, Aesara applies a sequence
...
@@ -613,68 +613,68 @@ optdb is a :class:`SequenceDB`, so, at the top level, Aesara applies a sequence
of global optimizations to the computation graphs.
of global optimizations to the computation graphs.
:class:`
Optimization
Query`
:class:`
RewriteDatabase
Query`
--------------------------
--------------------------
---
A :class:`
Optimization
Query` is built by the following call:
A :class:`
RewriteDatabase
Query` is built by the following call:
.. code-block:: python
.. code-block:: python
aesara.graph.optdb.
Optimization
Query(include, require=None, exclude=None, subquery=None)
aesara.graph.optdb.
RewriteDatabase
Query(include, require=None, exclude=None, subquery=None)
.. class::
Optimization
Query
.. class::
RewriteDatabase
Query
.. attribute:: include
.. attribute:: include
A set of tags (a tag being a string) such that every
A set of tags (a tag being a string) such that every
optimization obtained through this :class:`
Optimization
Query` must have **one** of the tags
optimization obtained through this :class:`
RewriteDatabase
Query` must have **one** of the tags
listed. This field is required and basically acts as a starting point
listed. This field is required and basically acts as a starting point
for the search.
for the search.
.. attribute:: require
.. attribute:: require
A set of tags such that every optimization obtained
A set of tags such that every optimization obtained
through this :class:`
Optimization
Query` must have **all** of these tags.
through this :class:`
RewriteDatabase
Query` must have **all** of these tags.
.. attribute:: exclude
.. attribute:: exclude
A set of tags such that every optimization obtained
A set of tags such that every optimization obtained
through this :class:`
Optimization
Query` must have **none** of these tags.
through this :class:`
RewriteDatabase
Query` must have **none** of these tags.
.. attribute:: subquery
.. attribute:: subquery
:obj:`optdb` can contain sub-databases; subquery is a
:obj:`optdb` can contain sub-databases; subquery is a
dictionary mapping the name of a sub-database to a special :class:`
Optimization
Query`.
dictionary mapping the name of a sub-database to a special :class:`
RewriteDatabase
Query`.
If no subquery is given for a sub-database, the original :class:`
Optimization
Query` will be
If no subquery is given for a sub-database, the original :class:`
RewriteDatabase
Query` will be
used again.
used again.
Furthermore, a :class:`
Optimization
Query` object includes three methods, :meth:`including`,
Furthermore, a :class:`
RewriteDatabase
Query` object includes three methods, :meth:`including`,
:meth:`requiring` and :meth:`excluding`, which each produce a new :class:`
Optimization
Query` object
:meth:`requiring` and :meth:`excluding`, which each produce a new :class:`
RewriteDatabase
Query` object
with the include, require, and exclude sets refined to contain the new entries.
with the include, require, and exclude sets refined to contain the new entries.
Examples
Examples
--------
--------
Here are a few examples of how to use a :class:`
Optimization
Query` on :obj:`optdb` to produce an
Here are a few examples of how to use a :class:`
RewriteDatabase
Query` on :obj:`optdb` to produce an
:class:`Optimizer`:
:class:`Optimizer`:
.. testcode::
.. testcode::
from aesara.graph.optdb import
Optimization
Query
from aesara.graph.optdb import
RewriteDatabase
Query
from aesara.compile import optdb
from aesara.compile import optdb
# This is how the optimizer for the fast_run mode is defined
# This is how the optimizer for the fast_run mode is defined
fast_run = optdb.query(
Optimization
Query(include=['fast_run']))
fast_run = optdb.query(
RewriteDatabase
Query(include=['fast_run']))
# This is how the optimizer for the fast_compile mode is defined
# This is how the optimizer for the fast_compile mode is defined
fast_compile = optdb.query(
Optimization
Query(include=['fast_compile']))
fast_compile = optdb.query(
RewriteDatabase
Query(include=['fast_compile']))
# This is the same as fast_run but no optimizations will replace
# This is the same as fast_run but no optimizations will replace
# any operation by an inplace version. This assumes, of course,
# any operation by an inplace version. This assumes, of course,
# that all inplace operations are tagged as 'inplace' (as they
# that all inplace operations are tagged as 'inplace' (as they
# should!)
# should!)
fast_run_no_inplace = optdb.query(
Optimization
Query(include=['fast_run'],
fast_run_no_inplace = optdb.query(
RewriteDatabase
Query(include=['fast_run'],
exclude=['inplace']))
exclude=['inplace']))
...
@@ -733,7 +733,7 @@ optimizations:
...
@@ -733,7 +733,7 @@ optimizations:
For each group, all optimizations of the group that are selected by
For each group, all optimizations of the group that are selected by
the :class:`
Optimization
Query` will be applied on the graph over and over again until none
the :class:`
RewriteDatabase
Query` will be applied on the graph over and over again until none
of them is applicable, so keep that in mind when designing it: check
of them is applicable, so keep that in mind when designing it: check
carefully that your optimization leads to a fixpoint (a point where it
carefully that your optimization leads to a fixpoint (a point where it
cannot apply anymore) at which point it returns ``False`` to indicate its
cannot apply anymore) at which point it returns ``False`` to indicate its
...
...
tests/compile/test_mode.py
浏览文件 @
92a3b2a6
from
aesara.compile.function
import
function
from
aesara.compile.function
import
function
from
aesara.compile.mode
import
AddFeatureOptimizer
,
Mode
from
aesara.compile.mode
import
AddFeatureOptimizer
,
Mode
from
aesara.graph.features
import
NoOutputFromInplace
from
aesara.graph.features
import
NoOutputFromInplace
from
aesara.graph.optdb
import
Optimization
Query
,
SequenceDB
from
aesara.graph.optdb
import
RewriteDatabase
Query
,
SequenceDB
from
aesara.tensor.math
import
dot
,
tanh
from
aesara.tensor.math
import
dot
,
tanh
from
aesara.tensor.type
import
matrix
from
aesara.tensor.type
import
matrix
def
test_Mode_basic
():
def
test_Mode_basic
():
db
=
SequenceDB
()
db
=
SequenceDB
()
mode
=
Mode
(
linker
=
"py"
,
optimizer
=
Optimization
Query
(
include
=
None
),
db
=
db
)
mode
=
Mode
(
linker
=
"py"
,
optimizer
=
RewriteDatabase
Query
(
include
=
None
),
db
=
db
)
assert
mode
.
optdb
is
db
assert
mode
.
optdb
is
db
assert
str
(
mode
)
.
startswith
(
"Mode(linker=py, optimizer=
Optimization
Query"
)
assert
str
(
mode
)
.
startswith
(
"Mode(linker=py, optimizer=
RewriteDatabase
Query"
)
def
test_NoOutputFromInplace
():
def
test_NoOutputFromInplace
():
...
...
tests/link/test_jax.py
浏览文件 @
92a3b2a6
...
@@ -15,7 +15,7 @@ from aesara.configdefaults import config
...
@@ -15,7 +15,7 @@ from aesara.configdefaults import config
from
aesara.graph.basic
import
Apply
from
aesara.graph.basic
import
Apply
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
Op
,
get_test_value
from
aesara.graph.op
import
Op
,
get_test_value
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.graph.optdb
import
RewriteDatabase
Query
from
aesara.ifelse
import
ifelse
from
aesara.ifelse
import
ifelse
from
aesara.link.jax
import
JAXLinker
from
aesara.link.jax
import
JAXLinker
from
aesara.raise_op
import
assert_op
from
aesara.raise_op
import
assert_op
...
@@ -56,7 +56,7 @@ from aesara.tensor.type import (
...
@@ -56,7 +56,7 @@ from aesara.tensor.type import (
jax
=
pytest
.
importorskip
(
"jax"
)
jax
=
pytest
.
importorskip
(
"jax"
)
opts
=
Optimization
Query
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
opts
=
RewriteDatabase
Query
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
jax_mode
=
Mode
(
JAXLinker
(),
opts
)
jax_mode
=
Mode
(
JAXLinker
(),
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
...
@@ -1142,7 +1142,7 @@ def test_jax_BatchedDot():
...
@@ -1142,7 +1142,7 @@ def test_jax_BatchedDot():
# A dimension mismatch should raise a TypeError for compatibility
# A dimension mismatch should raise a TypeError for compatibility
inputs
=
[
get_test_value
(
a
)[:
-
1
],
get_test_value
(
b
)]
inputs
=
[
get_test_value
(
a
)[:
-
1
],
get_test_value
(
b
)]
opts
=
Optimization
Query
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
opts
=
RewriteDatabase
Query
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
jax_mode
=
Mode
(
JAXLinker
(),
opts
)
jax_mode
=
Mode
(
JAXLinker
(),
opts
)
aesara_jax_fn
=
function
(
fgraph
.
inputs
,
fgraph
.
outputs
,
mode
=
jax_mode
)
aesara_jax_fn
=
function
(
fgraph
.
inputs
,
fgraph
.
outputs
,
mode
=
jax_mode
)
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
...
...
tests/link/test_numba.py
浏览文件 @
92a3b2a6
...
@@ -24,7 +24,7 @@ from aesara.compile.sharedvalue import SharedVariable
...
@@ -24,7 +24,7 @@ from aesara.compile.sharedvalue import SharedVariable
from
aesara.graph.basic
import
Apply
,
Constant
from
aesara.graph.basic
import
Apply
,
Constant
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
Op
,
get_test_value
from
aesara.graph.op
import
Op
,
get_test_value
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.graph.optdb
import
RewriteDatabase
Query
from
aesara.graph.type
import
Type
from
aesara.graph.type
import
Type
from
aesara.ifelse
import
ifelse
from
aesara.ifelse
import
ifelse
from
aesara.link.numba.dispatch
import
basic
as
numba_basic
from
aesara.link.numba.dispatch
import
basic
as
numba_basic
...
@@ -92,7 +92,7 @@ my_multi_out.ufunc = MyMultiOut.impl
...
@@ -92,7 +92,7 @@ my_multi_out.ufunc = MyMultiOut.impl
my_multi_out
.
ufunc
.
nin
=
2
my_multi_out
.
ufunc
.
nin
=
2
my_multi_out
.
ufunc
.
nout
=
2
my_multi_out
.
ufunc
.
nout
=
2
opts
=
Optimization
Query
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
opts
=
RewriteDatabase
Query
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
numba_mode
=
Mode
(
NumbaLinker
(),
opts
)
numba_mode
=
Mode
(
NumbaLinker
(),
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
...
...
tests/link/test_numba_performance.py
浏览文件 @
92a3b2a6
...
@@ -7,12 +7,12 @@ import aesara.tensor as aet
...
@@ -7,12 +7,12 @@ import aesara.tensor as aet
from
aesara
import
config
from
aesara
import
config
from
aesara.compile.function
import
function
from
aesara.compile.function
import
function
from
aesara.compile.mode
import
Mode
from
aesara.compile.mode
import
Mode
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.graph.optdb
import
RewriteDatabase
Query
from
aesara.link.numba.linker
import
NumbaLinker
from
aesara.link.numba.linker
import
NumbaLinker
from
aesara.tensor.math
import
Max
from
aesara.tensor.math
import
Max
opts
=
Optimization
Query
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
opts
=
RewriteDatabase
Query
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
numba_mode
=
Mode
(
NumbaLinker
(),
opts
)
numba_mode
=
Mode
(
NumbaLinker
(),
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
...
...
tests/tensor/random/test_basic.py
浏览文件 @
92a3b2a6
...
@@ -14,7 +14,7 @@ from aesara.configdefaults import config
...
@@ -14,7 +14,7 @@ from aesara.configdefaults import config
from
aesara.graph.basic
import
Constant
,
Variable
,
graph_inputs
from
aesara.graph.basic
import
Constant
,
Variable
,
graph_inputs
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
get_test_value
from
aesara.graph.op
import
get_test_value
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.graph.optdb
import
RewriteDatabase
Query
from
aesara.tensor.basic_opt
import
ShapeFeature
from
aesara.tensor.basic_opt
import
ShapeFeature
from
aesara.tensor.random.basic
import
(
from
aesara.tensor.random.basic
import
(
bernoulli
,
bernoulli
,
...
@@ -60,7 +60,7 @@ from aesara.tensor.type import iscalar, scalar, tensor
...
@@ -60,7 +60,7 @@ from aesara.tensor.type import iscalar, scalar, tensor
from
tests.unittest_tools
import
create_aesara_param
from
tests.unittest_tools
import
create_aesara_param
opts
=
Optimization
Query
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
opts
=
RewriteDatabase
Query
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
py_mode
=
Mode
(
"py"
,
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
...
...
tests/tensor/random/test_opt.py
浏览文件 @
92a3b2a6
...
@@ -8,7 +8,7 @@ from aesara.compile.mode import Mode
...
@@ -8,7 +8,7 @@ from aesara.compile.mode import Mode
from
aesara.graph.basic
import
Constant
from
aesara.graph.basic
import
Constant
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.opt
import
EquilibriumGraphRewriter
from
aesara.graph.opt
import
EquilibriumGraphRewriter
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.graph.optdb
import
RewriteDatabase
Query
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.random.basic
import
(
from
aesara.tensor.random.basic
import
(
dirichlet
,
dirichlet
,
...
@@ -28,7 +28,7 @@ from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subte
...
@@ -28,7 +28,7 @@ from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subte
from
aesara.tensor.type
import
iscalar
,
vector
from
aesara.tensor.type
import
iscalar
,
vector
no_mode
=
Mode
(
"py"
,
Optimization
Query
(
include
=
[],
exclude
=
[]))
no_mode
=
Mode
(
"py"
,
RewriteDatabase
Query
(
include
=
[],
exclude
=
[]))
def
apply_local_opt_to_rv
(
opt
,
op_fn
,
dist_op
,
dist_params
,
size
,
rng
,
name
=
None
):
def
apply_local_opt_to_rv
(
opt
,
op_fn
,
dist_op
,
dist_params
,
size
,
rng
,
name
=
None
):
...
...
tests/tensor/random/test_utils.py
浏览文件 @
92a3b2a6
...
@@ -3,7 +3,7 @@ import pytest
...
@@ -3,7 +3,7 @@ import pytest
from
aesara
import
config
,
function
from
aesara
import
config
,
function
from
aesara.compile.mode
import
Mode
from
aesara.compile.mode
import
Mode
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.graph.optdb
import
RewriteDatabase
Query
from
aesara.tensor.random.utils
import
RandomStream
,
broadcast_params
from
aesara.tensor.random.utils
import
RandomStream
,
broadcast_params
from
aesara.tensor.type
import
matrix
,
tensor
from
aesara.tensor.type
import
matrix
,
tensor
from
tests
import
unittest_tools
as
utt
from
tests
import
unittest_tools
as
utt
...
@@ -11,7 +11,7 @@ from tests import unittest_tools as utt
...
@@ -11,7 +11,7 @@ from tests import unittest_tools as utt
@pytest.fixture
(
scope
=
"module"
,
autouse
=
True
)
@pytest.fixture
(
scope
=
"module"
,
autouse
=
True
)
def
set_aesara_flags
():
def
set_aesara_flags
():
opts
=
Optimization
Query
(
include
=
[
None
],
exclude
=
[])
opts
=
RewriteDatabase
Query
(
include
=
[
None
],
exclude
=
[])
py_mode
=
Mode
(
"py"
,
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
with
config
.
change_flags
(
mode
=
py_mode
,
compute_test_value
=
"warn"
):
with
config
.
change_flags
(
mode
=
py_mode
,
compute_test_value
=
"warn"
):
yield
yield
...
...
tests/tensor/test_basic_opt.py
浏览文件 @
92a3b2a6
...
@@ -18,7 +18,7 @@ from aesara.graph.fg import FunctionGraph
...
@@ -18,7 +18,7 @@ from aesara.graph.fg import FunctionGraph
from
aesara.graph.op
import
Op
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
check_stack_trace
,
node_rewriter
,
out2in
from
aesara.graph.opt
import
check_stack_trace
,
node_rewriter
,
out2in
from
aesara.graph.opt_utils
import
optimize_graph
from
aesara.graph.opt_utils
import
optimize_graph
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.graph.optdb
import
RewriteDatabase
Query
from
aesara.graph.type
import
Type
from
aesara.graph.type
import
Type
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.printing
import
pprint
from
aesara.printing
import
pprint
...
@@ -141,15 +141,15 @@ mode_opt = get_mode(mode_opt)
...
@@ -141,15 +141,15 @@ mode_opt = get_mode(mode_opt)
dimshuffle_lift
=
out2in
(
local_dimshuffle_lift
)
dimshuffle_lift
=
out2in
(
local_dimshuffle_lift
)
_optimizer_stabilize
=
Optimization
Query
(
include
=
[
"fast_run"
])
_optimizer_stabilize
=
RewriteDatabase
Query
(
include
=
[
"fast_run"
])
_optimizer_stabilize
.
position_cutoff
=
1.51
_optimizer_stabilize
.
position_cutoff
=
1.51
_optimizer_stabilize
=
optdb
.
query
(
_optimizer_stabilize
)
_optimizer_stabilize
=
optdb
.
query
(
_optimizer_stabilize
)
_optimizer_specialize
=
Optimization
Query
(
include
=
[
"fast_run"
])
_optimizer_specialize
=
RewriteDatabase
Query
(
include
=
[
"fast_run"
])
_optimizer_specialize
.
position_cutoff
=
2.01
_optimizer_specialize
.
position_cutoff
=
2.01
_optimizer_specialize
=
optdb
.
query
(
_optimizer_specialize
)
_optimizer_specialize
=
optdb
.
query
(
_optimizer_specialize
)
_optimizer_fast_run
=
Optimization
Query
(
include
=
[
"fast_run"
])
_optimizer_fast_run
=
RewriteDatabase
Query
(
include
=
[
"fast_run"
])
_optimizer_fast_run
=
optdb
.
query
(
_optimizer_fast_run
)
_optimizer_fast_run
=
optdb
.
query
(
_optimizer_fast_run
)
...
@@ -352,7 +352,7 @@ def test_local_useless_dimshuffle_in_reshape():
...
@@ -352,7 +352,7 @@ def test_local_useless_dimshuffle_in_reshape():
class
TestFusion
:
class
TestFusion
:
opts
=
Optimization
Query
(
opts
=
RewriteDatabase
Query
(
include
=
[
include
=
[
"local_elemwise_fusion"
,
"local_elemwise_fusion"
,
"composite_elemwise_fusion"
,
"composite_elemwise_fusion"
,
...
@@ -1099,7 +1099,7 @@ class TestFusion:
...
@@ -1099,7 +1099,7 @@ class TestFusion:
def
test_add_mul_fusion_inplace
(
self
):
def
test_add_mul_fusion_inplace
(
self
):
opts
=
Optimization
Query
(
opts
=
RewriteDatabase
Query
(
include
=
[
include
=
[
"local_elemwise_fusion"
,
"local_elemwise_fusion"
,
"composite_elemwise_fusion"
,
"composite_elemwise_fusion"
,
...
@@ -1165,7 +1165,7 @@ class TestFusion:
...
@@ -1165,7 +1165,7 @@ class TestFusion:
"""
"""
opts
=
Optimization
Query
(
opts
=
RewriteDatabase
Query
(
include
=
[
include
=
[
"local_elemwise_fusion"
,
"local_elemwise_fusion"
,
"composite_elemwise_fusion"
,
"composite_elemwise_fusion"
,
...
...
tests/tensor/test_extra_ops.py
浏览文件 @
92a3b2a6
...
@@ -9,7 +9,7 @@ from aesara import tensor as at
...
@@ -9,7 +9,7 @@ from aesara import tensor as at
from
aesara.compile.mode
import
Mode
from
aesara.compile.mode
import
Mode
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Constant
,
applys_between
from
aesara.graph.basic
import
Constant
,
applys_between
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.graph.optdb
import
RewriteDatabase
Query
from
aesara.raise_op
import
Assert
from
aesara.raise_op
import
Assert
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.extra_ops
import
(
from
aesara.tensor.extra_ops
import
(
...
@@ -1285,7 +1285,7 @@ class TestBroadcastTo(utt.InferShapeTester):
...
@@ -1285,7 +1285,7 @@ class TestBroadcastTo(utt.InferShapeTester):
q
=
b
[
np
.
r_
[
0
,
1
,
3
]]
q
=
b
[
np
.
r_
[
0
,
1
,
3
]]
e
=
at
.
set_subtensor
(
q
,
np
.
r_
[
0
,
0
,
0
])
e
=
at
.
set_subtensor
(
q
,
np
.
r_
[
0
,
0
,
0
])
opts
=
Optimization
Query
(
include
=
[
"inplace"
])
opts
=
RewriteDatabase
Query
(
include
=
[
"inplace"
])
py_mode
=
Mode
(
"py"
,
opts
)
py_mode
=
Mode
(
"py"
,
opts
)
e_fn
=
function
([
d
],
e
,
mode
=
py_mode
)
e_fn
=
function
([
d
],
e
,
mode
=
py_mode
)
...
...
tests/tensor/test_math_opt.py
浏览文件 @
92a3b2a6
...
@@ -26,7 +26,7 @@ from aesara.graph.opt import (
...
@@ -26,7 +26,7 @@ from aesara.graph.opt import (
out2in
,
out2in
,
)
)
from
aesara.graph.opt_utils
import
is_same_graph
,
optimize_graph
from
aesara.graph.opt_utils
import
is_same_graph
,
optimize_graph
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.graph.optdb
import
RewriteDatabase
Query
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.tensor
import
inplace
from
aesara.tensor
import
inplace
from
aesara.tensor.basic
import
Alloc
,
join
,
switch
from
aesara.tensor.basic
import
Alloc
,
join
,
switch
...
@@ -132,15 +132,15 @@ mode_opt = get_mode(mode_opt)
...
@@ -132,15 +132,15 @@ mode_opt = get_mode(mode_opt)
dimshuffle_lift
=
out2in
(
local_dimshuffle_lift
)
dimshuffle_lift
=
out2in
(
local_dimshuffle_lift
)
_optimizer_stabilize
=
Optimization
Query
(
include
=
[
"fast_run"
])
_optimizer_stabilize
=
RewriteDatabase
Query
(
include
=
[
"fast_run"
])
_optimizer_stabilize
.
position_cutoff
=
1.51
_optimizer_stabilize
.
position_cutoff
=
1.51
_optimizer_stabilize
=
optdb
.
query
(
_optimizer_stabilize
)
_optimizer_stabilize
=
optdb
.
query
(
_optimizer_stabilize
)
_optimizer_specialize
=
Optimization
Query
(
include
=
[
"fast_run"
])
_optimizer_specialize
=
RewriteDatabase
Query
(
include
=
[
"fast_run"
])
_optimizer_specialize
.
position_cutoff
=
2.01
_optimizer_specialize
.
position_cutoff
=
2.01
_optimizer_specialize
=
optdb
.
query
(
_optimizer_specialize
)
_optimizer_specialize
=
optdb
.
query
(
_optimizer_specialize
)
_optimizer_fast_run
=
Optimization
Query
(
include
=
[
"fast_run"
])
_optimizer_fast_run
=
RewriteDatabase
Query
(
include
=
[
"fast_run"
])
_optimizer_fast_run
=
optdb
.
query
(
_optimizer_fast_run
)
_optimizer_fast_run
=
optdb
.
query
(
_optimizer_fast_run
)
...
@@ -366,7 +366,7 @@ class TestAlgebraicCanonizer:
...
@@ -366,7 +366,7 @@ class TestAlgebraicCanonizer:
# We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode
=
get_default_mode
()
mode
=
get_default_mode
()
opt
=
Optimization
Query
([
"canonicalize"
])
opt
=
RewriteDatabase
Query
([
"canonicalize"
])
opt
=
opt
.
excluding
(
"local_elemwise_fusion"
)
opt
=
opt
.
excluding
(
"local_elemwise_fusion"
)
mode
=
mode
.
__class__
(
linker
=
mode
.
linker
,
optimizer
=
opt
)
mode
=
mode
.
__class__
(
linker
=
mode
.
linker
,
optimizer
=
opt
)
for
id
,
[
g
,
sym_inputs
,
val_inputs
,
nb_elemwise
,
out_dtype
]
in
enumerate
(
cases
):
for
id
,
[
g
,
sym_inputs
,
val_inputs
,
nb_elemwise
,
out_dtype
]
in
enumerate
(
cases
):
...
@@ -500,7 +500,7 @@ class TestAlgebraicCanonizer:
...
@@ -500,7 +500,7 @@ class TestAlgebraicCanonizer:
# We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# We must be sure that the AlgebraicCanonizer is working, but that we don't have other
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode
=
get_default_mode
()
mode
=
get_default_mode
()
mode
.
_optimizer
=
Optimization
Query
([
"canonicalize"
])
mode
.
_optimizer
=
RewriteDatabase
Query
([
"canonicalize"
])
mode
.
_optimizer
=
mode
.
_optimizer
.
excluding
(
"local_elemwise_fusion"
)
mode
.
_optimizer
=
mode
.
_optimizer
.
excluding
(
"local_elemwise_fusion"
)
for
id
,
[
g
,
sym_inputs
,
val_inputs
,
nb_elemwise
,
out_dtype
]
in
enumerate
(
cases
):
for
id
,
[
g
,
sym_inputs
,
val_inputs
,
nb_elemwise
,
out_dtype
]
in
enumerate
(
cases
):
f
=
function
(
f
=
function
(
...
@@ -547,7 +547,7 @@ class TestAlgebraicCanonizer:
...
@@ -547,7 +547,7 @@ class TestAlgebraicCanonizer:
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode
=
get_default_mode
()
mode
=
get_default_mode
()
opt
=
Optimization
Query
([
"canonicalize"
])
opt
=
RewriteDatabase
Query
([
"canonicalize"
])
opt
=
opt
.
including
(
"ShapeOpt"
,
"local_fill_to_alloc"
)
opt
=
opt
.
including
(
"ShapeOpt"
,
"local_fill_to_alloc"
)
opt
=
opt
.
excluding
(
"local_elemwise_fusion"
)
opt
=
opt
.
excluding
(
"local_elemwise_fusion"
)
mode
=
mode
.
__class__
(
linker
=
mode
.
linker
,
optimizer
=
opt
)
mode
=
mode
.
__class__
(
linker
=
mode
.
linker
,
optimizer
=
opt
)
...
@@ -907,7 +907,7 @@ class TestAlgebraicCanonizer:
...
@@ -907,7 +907,7 @@ class TestAlgebraicCanonizer:
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
# optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion
mode
=
get_default_mode
()
mode
=
get_default_mode
()
opt
=
Optimization
Query
([
"canonicalize"
])
opt
=
RewriteDatabase
Query
([
"canonicalize"
])
opt
=
opt
.
excluding
(
"local_elemwise_fusion"
)
opt
=
opt
.
excluding
(
"local_elemwise_fusion"
)
mode
=
mode
.
__class__
(
linker
=
mode
.
linker
,
optimizer
=
opt
)
mode
=
mode
.
__class__
(
linker
=
mode
.
linker
,
optimizer
=
opt
)
# test fail!
# test fail!
...
@@ -1074,7 +1074,7 @@ def test_cast_in_mul_canonizer():
...
@@ -1074,7 +1074,7 @@ def test_cast_in_mul_canonizer():
class
TestFusion
:
class
TestFusion
:
opts
=
Optimization
Query
(
opts
=
RewriteDatabase
Query
(
include
=
[
include
=
[
"local_elemwise_fusion"
,
"local_elemwise_fusion"
,
"composite_elemwise_fusion"
,
"composite_elemwise_fusion"
,
...
@@ -1782,7 +1782,7 @@ class TestFusion:
...
@@ -1782,7 +1782,7 @@ class TestFusion:
def
test_add_mul_fusion_inplace
(
self
):
def
test_add_mul_fusion_inplace
(
self
):
opts
=
Optimization
Query
(
opts
=
RewriteDatabase
Query
(
include
=
[
include
=
[
"local_elemwise_fusion"
,
"local_elemwise_fusion"
,
"composite_elemwise_fusion"
,
"composite_elemwise_fusion"
,
...
...
tests/tensor/test_subtensor_opt.py
浏览文件 @
92a3b2a6
...
@@ -12,7 +12,7 @@ from aesara.configdefaults import config
...
@@ -12,7 +12,7 @@ from aesara.configdefaults import config
from
aesara.graph.basic
import
Constant
,
Variable
,
ancestors
from
aesara.graph.basic
import
Constant
,
Variable
,
ancestors
from
aesara.graph.opt
import
check_stack_trace
from
aesara.graph.opt
import
check_stack_trace
from
aesara.graph.opt_utils
import
optimize_graph
from
aesara.graph.opt_utils
import
optimize_graph
from
aesara.graph.optdb
import
Optimization
Query
from
aesara.graph.optdb
import
RewriteDatabase
Query
from
aesara.graph.type
import
Type
from
aesara.graph.type
import
Type
from
aesara.raise_op
import
Assert
from
aesara.raise_op
import
Assert
from
aesara.tensor
import
inplace
from
aesara.tensor
import
inplace
...
@@ -1994,7 +1994,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
...
@@ -1994,7 +1994,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
y
=
specify_shape
(
x
,
s
)[
idx
]
y
=
specify_shape
(
x
,
s
)[
idx
]
assert
isinstance
(
y
.
owner
.
inputs
[
0
]
.
owner
.
op
,
SpecifyShape
)
assert
isinstance
(
y
.
owner
.
inputs
[
0
]
.
owner
.
op
,
SpecifyShape
)
opts
=
Optimization
Query
(
include
=
[
None
])
opts
=
RewriteDatabase
Query
(
include
=
[
None
])
no_opt_mode
=
Mode
(
optimizer
=
opts
)
no_opt_mode
=
Mode
(
optimizer
=
opts
)
y_val_fn
=
function
([
x
]
+
list
(
s
),
y
,
on_unused_input
=
"ignore"
,
mode
=
no_opt_mode
)
y_val_fn
=
function
([
x
]
+
list
(
s
),
y
,
on_unused_input
=
"ignore"
,
mode
=
no_opt_mode
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论