Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d5013456
提交
d5013456
authored
7月 15, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
8月 17, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rename NavigatorOptimizer to NodeProcessingGraphRewriter
上级
6302cef1
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
42 行增加
和
31 行删除
+42
-31
mode.py
aesara/compile/mode.py
+2
-2
opt.py
aesara/graph/opt.py
+29
-18
optdb.py
aesara/graph/optdb.py
+1
-1
graph_rewriting.rst
doc/extending/graph_rewriting.rst
+7
-7
test_destroyhandler.py
tests/graph/test_destroyhandler.py
+2
-2
test_opt.py
tests/graph/test_opt.py
+1
-1
没有找到文件。
aesara/compile/mode.py
浏览文件 @
d5013456
...
@@ -14,7 +14,7 @@ from aesara.graph.opt import (
...
@@ -14,7 +14,7 @@ from aesara.graph.opt import (
CheckStackTraceOptimization
,
CheckStackTraceOptimization
,
GraphRewriter
,
GraphRewriter
,
MergeOptimizer
,
MergeOptimizer
,
N
avigatorOptimiz
er
,
N
odeProcessingGraphRewrit
er
,
)
)
from
aesara.graph.optdb
import
(
from
aesara.graph.optdb
import
(
EquilibriumDB
,
EquilibriumDB
,
...
@@ -193,7 +193,7 @@ optdb.register(
...
@@ -193,7 +193,7 @@ optdb.register(
local_useless
=
LocalGroupDB
(
apply_all_opts
=
True
,
profile
=
True
)
local_useless
=
LocalGroupDB
(
apply_all_opts
=
True
,
profile
=
True
)
optdb
.
register
(
optdb
.
register
(
"useless"
,
"useless"
,
TopoDB
(
local_useless
,
failure_callback
=
N
avigatorOptimiz
er
.
warn_inplace
),
TopoDB
(
local_useless
,
failure_callback
=
N
odeProcessingGraphRewrit
er
.
warn_inplace
),
"fast_run"
,
"fast_run"
,
"fast_compile"
,
"fast_compile"
,
position
=
0.6
,
position
=
0.6
,
...
...
aesara/graph/opt.py
浏览文件 @
d5013456
...
@@ -48,7 +48,7 @@ _logger = logging.getLogger("aesara.graph.opt")
...
@@ -48,7 +48,7 @@ _logger = logging.getLogger("aesara.graph.opt")
FailureCallbackType
=
Callable
[
FailureCallbackType
=
Callable
[
[
[
Exception
,
Exception
,
"N
avigatorOptimiz
er"
,
"N
odeProcessingGraphRewrit
er"
,
List
[
Tuple
[
Variable
,
None
]],
List
[
Tuple
[
Variable
,
None
]],
"NodeRewriter"
,
"NodeRewriter"
,
Apply
,
Apply
,
...
@@ -1210,7 +1210,7 @@ class SequentialNodeRewriter(NodeRewriter):
...
@@ -1210,7 +1210,7 @@ class SequentialNodeRewriter(NodeRewriter):
Attributes
Attributes
----------
----------
reentrant : bool
reentrant : bool
Some global optimizers, like `N
avigatorOptimiz
er`, use this value to
Some global optimizers, like `N
odeProcessingGraphRewrit
er`, use this value to
determine if they should ignore new nodes.
determine if they should ignore new nodes.
retains_inputs : bool
retains_inputs : bool
States whether or not the inputs of a transformed node are transferred
States whether or not the inputs of a transformed node are transferred
...
@@ -1724,13 +1724,17 @@ class Updater(Feature):
...
@@ -1724,13 +1724,17 @@ class Updater(Feature):
self
.
chin
=
None
self
.
chin
=
None
class
N
avigatorOptimiz
er
(
GraphRewriter
):
class
N
odeProcessingGraphRewrit
er
(
GraphRewriter
):
r"""A
n optimiz
er that applies a `NodeRewriter` with considerations for the new nodes it creates.
r"""A
rewrit
er that applies a `NodeRewriter` with considerations for the new nodes it creates.
The results of successful rewrites are considered for rewriting based on
the values of `NodeProcessingGraphRewriter.ignore_newtrees` and/or
`NodeRewriter.reentrant`.
This optimizer also allows the `NodeRewriter` to use a special ``"remove"`` value
This rewriter accepts ``dict`` values from `NodeRewriter.transform`.
in the ``dict``\s returned by :meth:`NodeRewriter`. `Variable`\s mapped to this
Entries in these ``dict``\s can be `Variable`\s and their new values.
value are removed from the `FunctionGraph`.
It also accepts a special ``"remove"`` key. A sequence of `Variable`\s
mapped to the key ``"remove"`` are removed from the `FunctionGraph`.
"""
"""
...
@@ -1759,7 +1763,9 @@ class NavigatorOptimizer(GraphRewriter):
...
@@ -1759,7 +1763,9 @@ class NavigatorOptimizer(GraphRewriter):
"""
"""
if
isinstance
(
exc
,
InconsistencyError
):
if
isinstance
(
exc
,
InconsistencyError
):
return
return
return
NavigatorOptimizer
.
warn
(
exc
,
nav
,
repl_pairs
,
node_rewriter
,
node
)
return
NodeProcessingGraphRewriter
.
warn
(
exc
,
nav
,
repl_pairs
,
node_rewriter
,
node
)
@staticmethod
@staticmethod
def
warn_ignore
(
exc
,
nav
,
repl_pairs
,
node_rewriter
,
node
):
def
warn_ignore
(
exc
,
nav
,
repl_pairs
,
node_rewriter
,
node
):
...
@@ -1778,10 +1784,10 @@ class NavigatorOptimizer(GraphRewriter):
...
@@ -1778,10 +1784,10 @@ class NavigatorOptimizer(GraphRewriter):
node_rewriter
node_rewriter
A `NodeRewriter` to apply over a `FunctionGraph` (or ``None``).
A `NodeRewriter` to apply over a `FunctionGraph` (or ``None``).
ignore_newtrees
ignore_newtrees
- ``True``: new subgraphs returned by an
optimization
are not a
- ``True``: new subgraphs returned by an
`NodeRewriter`
are not a
candidate for
optimization
.
candidate for
rewriting
.
- ``False``: new subgraphs returned by an
optimization
is a
- ``False``: new subgraphs returned by an
`NodeRewriter`
is a
candidate for
optimization
.
candidate for
rewriting
.
- ``'auto'``: let the `node_rewriter` set this parameter via its
- ``'auto'``: let the `node_rewriter` set this parameter via its
:attr:`reentrant` attribute.
:attr:`reentrant` attribute.
failure_callback
failure_callback
...
@@ -1970,7 +1976,7 @@ class NavigatorOptimizer(GraphRewriter):
...
@@ -1970,7 +1976,7 @@ class NavigatorOptimizer(GraphRewriter):
)
)
class
TopoOptimizer
(
N
avigatorOptimiz
er
):
class
TopoOptimizer
(
N
odeProcessingGraphRewrit
er
):
"""An optimizer that applies a single `NodeRewriter` to each node in topological order (or reverse)."""
"""An optimizer that applies a single `NodeRewriter` to each node in topological order (or reverse)."""
def
__init__
(
def
__init__
(
...
@@ -2116,7 +2122,7 @@ in2out = partial(topogroup_optimizer, "in_to_out")
...
@@ -2116,7 +2122,7 @@ in2out = partial(topogroup_optimizer, "in_to_out")
out2in
=
partial
(
topogroup_optimizer
,
"out_to_in"
)
out2in
=
partial
(
topogroup_optimizer
,
"out_to_in"
)
class
OpKeyOptimizer
(
N
avigatorOptimiz
er
):
class
OpKeyOptimizer
(
N
odeProcessingGraphRewrit
er
):
r"""An optimizer that applies a `NodeRewriter` to specific `Op`\s.
r"""An optimizer that applies a `NodeRewriter` to specific `Op`\s.
The `Op`\s are provided by a :meth:`NodeRewriter.op_key` method (either
The `Op`\s are provided by a :meth:`NodeRewriter.op_key` method (either
...
@@ -2200,7 +2206,7 @@ def merge_dict(d1, d2):
...
@@ -2200,7 +2206,7 @@ def merge_dict(d1, d2):
return
d
return
d
class
EquilibriumOptimizer
(
N
avigatorOptimiz
er
):
class
EquilibriumOptimizer
(
N
odeProcessingGraphRewrit
er
):
"""An `Rewriter` that applies an optimization until a fixed-point/equilibrium is reached."""
"""An `Rewriter` that applies an optimization until a fixed-point/equilibrium is reached."""
def
__init__
(
def
__init__
(
...
@@ -2222,11 +2228,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -2222,11 +2228,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
The global optimizer will be run at the start of each iteration before
The global optimizer will be run at the start of each iteration before
the node rewriter.
the node rewriter.
failure_callback
failure_callback
See :attr:`N
avigatorOptimiz
er.failure_callback`.
See :attr:`N
odeProcessingGraphRewrit
er.failure_callback`.
ignore_newtrees
ignore_newtrees
See :attr:`N
avigatorOptimiz
er.ignore_newtrees`.
See :attr:`N
odeProcessingGraphRewrit
er.ignore_newtrees`.
tracks_on_change_inputs
tracks_on_change_inputs
See :attr:`N
avigatorOptimiz
er.tracks_on_change_inputs`.
See :attr:`N
odeProcessingGraphRewrit
er.tracks_on_change_inputs`.
max_use_ratio
max_use_ratio
Each rewriter can be applied at most ``(size_of_graph * max_use_ratio)``
Each rewriter can be applied at most ``(size_of_graph * max_use_ratio)``
times.
times.
...
@@ -3188,6 +3194,11 @@ DEPRECATED_NAMES = [
...
@@ -3188,6 +3194,11 @@ DEPRECATED_NAMES = [
"`PatternSub` is deprecated: use `PatternNodeRewriter` instead."
,
"`PatternSub` is deprecated: use `PatternNodeRewriter` instead."
,
PatternNodeRewriter
,
PatternNodeRewriter
,
),
),
(
"NavigatorOptimizer"
,
"`NavigatorOptimizer` is deprecated: use `NodeProcessingGraphRewriter` instead."
,
NodeProcessingGraphRewriter
,
),
]
]
...
...
aesara/graph/optdb.py
浏览文件 @
d5013456
...
@@ -346,7 +346,7 @@ class EquilibriumDB(OptimizationDatabase):
...
@@ -346,7 +346,7 @@ class EquilibriumDB(OptimizationDatabase):
max_use_ratio
=
config
.
optdb__max_use_ratio
,
max_use_ratio
=
config
.
optdb__max_use_ratio
,
ignore_newtrees
=
self
.
ignore_newtrees
,
ignore_newtrees
=
self
.
ignore_newtrees
,
tracks_on_change_inputs
=
self
.
tracks_on_change_inputs
,
tracks_on_change_inputs
=
self
.
tracks_on_change_inputs
,
failure_callback
=
aesara_opt
.
N
avigatorOptimiz
er
.
warn_inplace
,
failure_callback
=
aesara_opt
.
N
odeProcessingGraphRewrit
er
.
warn_inplace
,
final_optimizers
=
final_opts
,
final_optimizers
=
final_opts
,
cleanup_optimizers
=
cleanup_opts
,
cleanup_optimizers
=
cleanup_opts
,
)
)
...
...
doc/extending/graph_rewriting.rst
浏览文件 @
d5013456
...
@@ -74,7 +74,7 @@ A local optimization is an object which defines the following methods:
...
@@ -74,7 +74,7 @@ A local optimization is an object which defines the following methods:
This method takes a :class:`FunctionGraph` and an :class:`Apply` node and
This method takes a :class:`FunctionGraph` and an :class:`Apply` node and
returns either ``False`` to signify that no changes are to be done or a
returns either ``False`` to signify that no changes are to be done or a
list of :class:`Variable`\s which matches the length of the node's ``outputs``
list of :class:`Variable`\s which matches the length of the node's ``outputs``
list. When the :class:`NodeRewriter` is applied by a :class:`N
avigatorOptimiz
er`, the outputs
list. When the :class:`NodeRewriter` is applied by a :class:`N
odeProcessingGraphRewrit
er`, the outputs
of the node passed as argument to the :class:`NodeRewriter` will be replaced by
of the node passed as argument to the :class:`NodeRewriter` will be replaced by
the list returned.
the list returned.
...
@@ -89,7 +89,7 @@ For starters, let's define the following simplification:
...
@@ -89,7 +89,7 @@ For starters, let's define the following simplification:
\frac{xy}{y} = x
\frac{xy}{y} = x
We will implement it in three ways: using a global optimization, a
We will implement it in three ways: using a global optimization, a
local optimization with a :class:`N
avigatorOptimiz
er` and then using the :class:`PatternNodeRewriter`
local optimization with a :class:`N
odeProcessingGraphRewrit
er` and then using the :class:`PatternNodeRewriter`
facility.
facility.
Global optimization
Global optimization
...
@@ -253,7 +253,7 @@ outputs are returned. This list must have the same length as
...
@@ -253,7 +253,7 @@ outputs are returned. This list must have the same length as
you can put ``None`` in the returned list to remove it.
you can put ``None`` in the returned list to remove it.
In order to apply the local optimizer we can use it in conjunction
In order to apply the local optimizer we can use it in conjunction
with a :class:`N
avigatorOptimizer`. Basically, a :class:`NavigatorOptimiz
er` is
with a :class:`N
odeProcessingGraphRewriter`. Basically, a :class:`NodeProcessingGraphRewrit
er` is
a global optimizer that loops through all nodes in the graph (or a well-defined
a global optimizer that loops through all nodes in the graph (or a well-defined
subset of them) and applies one or several local optimizers.
subset of them) and applies one or several local optimizers.
...
@@ -315,7 +315,7 @@ Aesara defines some shortcuts to make :class:`NodeRewriter`\s:
...
@@ -315,7 +315,7 @@ Aesara defines some shortcuts to make :class:`NodeRewriter`\s:
:class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter` and :class:`PatternNodeRewriter` produce local optimizers, which
:class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter` and :class:`PatternNodeRewriter` produce local optimizers, which
means that everything we said previously about local optimizers
means that everything we said previously about local optimizers
apply (e.g. they need to be wrapped in a :class:`N
avigatorOptimiz
er`, etc.)
apply (e.g. they need to be wrapped in a :class:`N
odeProcessingGraphRewrit
er`, etc.)
When an optimization can be naturally expressed using :class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter`
When an optimization can be naturally expressed using :class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter`
...
@@ -702,7 +702,7 @@ Registering a :class:`NodeRewriter`
...
@@ -702,7 +702,7 @@ Registering a :class:`NodeRewriter`
:class:`NodeRewriter`\s may be registered in two ways:
:class:`NodeRewriter`\s may be registered in two ways:
* Wrap them in a :class:`N
avigatorOptimiz
er` and insert them like a global optimizer
* Wrap them in a :class:`N
odeProcessingGraphRewrit
er` and insert them like a global optimizer
(see previous section).
(see previous section).
* Put them in an :class:`EquilibriumDB`.
* Put them in an :class:`EquilibriumDB`.
...
@@ -795,8 +795,8 @@ under the assumption there are no inplace operations.
...
@@ -795,8 +795,8 @@ under the assumption there are no inplace operations.
.. _navigator:
.. _navigator:
:class:`N
avigatorOptimiz
er`
:class:`N
odeProcessingGraphRewrit
er`
---------------------------
---------------------------
---------
WRITEME
WRITEME
...
...
tests/graph/test_destroyhandler.py
浏览文件 @
d5013456
...
@@ -9,7 +9,7 @@ from aesara.graph.features import ReplaceValidate
...
@@ -9,7 +9,7 @@ from aesara.graph.features import ReplaceValidate
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
Op
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
(
from
aesara.graph.opt
import
(
N
avigatorOptimiz
er
,
N
odeProcessingGraphRewrit
er
,
OpKeyOptimizer
,
OpKeyOptimizer
,
PatternNodeRewriter
,
PatternNodeRewriter
,
SubstitutionNodeRewriter
,
SubstitutionNodeRewriter
,
...
@@ -25,7 +25,7 @@ def PatternOptimizer(p1, p2, ign=True):
...
@@ -25,7 +25,7 @@ def PatternOptimizer(p1, p2, ign=True):
def
TopoSubstitutionNodeRewriter
(
def
TopoSubstitutionNodeRewriter
(
op1
,
op2
,
fail
=
N
avigatorOptimiz
er
.
warn_ignore
,
ign
=
True
op1
,
op2
,
fail
=
N
odeProcessingGraphRewrit
er
.
warn_ignore
,
ign
=
True
):
):
return
TopoOptimizer
(
return
TopoOptimizer
(
SubstitutionNodeRewriter
(
op1
,
op2
),
ignore_newtrees
=
ign
,
failure_callback
=
fail
SubstitutionNodeRewriter
(
op1
,
op2
),
ignore_newtrees
=
ign
,
failure_callback
=
fail
...
...
tests/graph/test_opt.py
浏览文件 @
d5013456
...
@@ -150,7 +150,7 @@ class TestPatternOptimizer:
...
@@ -150,7 +150,7 @@ class TestPatternOptimizer:
def
test_ambiguous
(
self
):
def
test_ambiguous
(
self
):
# this test should always work with TopoOptimizer and the
# this test should always work with TopoOptimizer and the
# ignore_newtrees flag set to False. Behavior with ignore_newtrees
# ignore_newtrees flag set to False. Behavior with ignore_newtrees
# = True or with other N
avigatorOptimiz
ers may differ.
# = True or with other N
odeProcessingGraphRewrit
ers may differ.
x
,
y
,
z
=
MyVariable
(
"x"
),
MyVariable
(
"y"
),
MyVariable
(
"z"
)
x
,
y
,
z
=
MyVariable
(
"x"
),
MyVariable
(
"y"
),
MyVariable
(
"z"
)
e
=
op1
(
op1
(
op1
(
op1
(
op1
(
x
)))))
e
=
op1
(
op1
(
op1
(
op1
(
op1
(
x
)))))
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论