Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
66974655
提交
66974655
authored
7月 15, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
8月 17, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rename LocalOptGroup to SequentialNodeRewriter
上级
40296322
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
26 行增加
和
21 行删除
+26
-21
opt.py
aesara/graph/opt.py
+14
-9
optdb.py
aesara/graph/optdb.py
+2
-2
math_opt.py
aesara/tensor/math_opt.py
+3
-3
test_opt.py
tests/graph/test_opt.py
+3
-3
test_math_opt.py
tests/tensor/test_math_opt.py
+4
-4
没有找到文件。
aesara/graph/opt.py
浏览文件 @
66974655
...
@@ -1204,7 +1204,7 @@ class OpToRewriterTracker:
...
@@ -1204,7 +1204,7 @@ class OpToRewriterTracker:
)
)
class
LocalOptGroup
(
NodeRewriter
):
class
SequentialNodeRewriter
(
NodeRewriter
):
r"""An optimizer that applies a list of `NodeRewriter`\s to a node.
r"""An optimizer that applies a list of `NodeRewriter`\s to a node.
Attributes
Attributes
...
@@ -1272,7 +1272,7 @@ class LocalOptGroup(NodeRewriter):
...
@@ -1272,7 +1272,7 @@ class LocalOptGroup(NodeRewriter):
return
getattr
(
return
getattr
(
self
,
self
,
"__name__"
,
"__name__"
,
f
"
LocalOptGroup
({','.join([str(o) for o in self.opts])})"
,
f
"
{type(self).__name__}
({','.join([str(o) for o in self.opts])})"
,
)
)
def
tracks
(
self
):
def
tracks
(
self
):
...
@@ -1332,15 +1332,15 @@ class LocalOptGroup(NodeRewriter):
...
@@ -1332,15 +1332,15 @@ class LocalOptGroup(NodeRewriter):
repl
=
new_repl
repl
=
new_repl
node
=
new_vars
[
0
]
.
owner
node
=
new_vars
[
0
]
.
owner
@
static
method
@
class
method
def
print_profile
(
stream
,
prof
,
level
=
0
):
def
print_profile
(
cls
,
stream
,
prof
,
level
=
0
):
(
time_opts
,
process_count
,
applied_true
,
node_created
,
profile
)
=
prof
(
time_opts
,
process_count
,
applied_true
,
node_created
,
profile
)
=
prof
if
not
profile
:
if
not
profile
:
return
return
blanc
=
" "
*
int
(
level
)
blanc
=
" "
*
int
(
level
)
print
(
blanc
,
"LocalOptGroup
"
,
file
=
stream
)
print
(
blanc
,
f
"{cls.__name__}
"
,
file
=
stream
)
print
(
blanc
,
"---------------------"
,
file
=
stream
)
print
(
blanc
,
"---------------------"
,
file
=
stream
)
count_opt
=
[]
count_opt
=
[]
not_used
=
[]
not_used
=
[]
...
@@ -2064,7 +2064,7 @@ class TopoOptimizer(NavigatorOptimizer):
...
@@ -2064,7 +2064,7 @@ class TopoOptimizer(NavigatorOptimizer):
print
(
blanc
,
" init io_toposort"
,
io_t
,
file
=
stream
)
print
(
blanc
,
" init io_toposort"
,
io_t
,
file
=
stream
)
print
(
blanc
,
" loop time"
,
loop_t
,
file
=
stream
)
print
(
blanc
,
" loop time"
,
loop_t
,
file
=
stream
)
print
(
blanc
,
" callback_time"
,
callback_time
,
file
=
stream
)
print
(
blanc
,
" callback_time"
,
callback_time
,
file
=
stream
)
if
isinstance
(
node_rewriter
,
LocalOptGroup
):
if
isinstance
(
node_rewriter
,
SequentialNodeRewriter
):
if
node_rewriter
.
profile
:
if
node_rewriter
.
profile
:
node_rewriter
.
print_profile
(
node_rewriter
.
print_profile
(
stream
,
stream
,
...
@@ -2089,14 +2089,14 @@ def topogroup_optimizer(
...
@@ -2089,14 +2089,14 @@ def topogroup_optimizer(
failure_callback
=
TopoOptimizer
.
warn_inplace
,
failure_callback
=
TopoOptimizer
.
warn_inplace
,
**
kwargs
,
**
kwargs
,
):
):
"""Apply `node_rewriters` from the input/output nodes to the output/input nodes of a graph.
r
"""Apply `node_rewriters` from the input/output nodes to the output/input nodes of a graph.
This constructs `TopoOptimizer`
s, and uses a `LocalOptGroup
` when there's
This constructs `TopoOptimizer`
\s, and uses a `SequentialNodeRewriter
` when there's
more than one entry in `node_rewriters`.
more than one entry in `node_rewriters`.
"""
"""
if
len
(
node_rewriters
)
>
1
:
if
len
(
node_rewriters
)
>
1
:
# Don't wrap it uselessly if their is only 1 optimization.
# Don't wrap it uselessly if their is only 1 optimization.
node_rewriters
=
LocalOptGroup
(
*
node_rewriters
)
node_rewriters
=
SequentialNodeRewriter
(
*
node_rewriters
)
else
:
else
:
(
node_rewriters
,)
=
node_rewriters
(
node_rewriters
,)
=
node_rewriters
if
not
name
:
if
not
name
:
...
@@ -3168,6 +3168,11 @@ DEPRECATED_NAMES = [
...
@@ -3168,6 +3168,11 @@ DEPRECATED_NAMES = [
"`LocalOptTracker` is deprecated: use `OpToRewriterTracker` instead."
,
"`LocalOptTracker` is deprecated: use `OpToRewriterTracker` instead."
,
OpToRewriterTracker
,
OpToRewriterTracker
,
),
),
(
"LocalOptGroup"
,
"`LocalOptGroup` is deprecated: use `SequentialNodeRewriter` instead."
,
SequentialNodeRewriter
,
),
]
]
...
...
aesara/graph/optdb.py
浏览文件 @
66974655
...
@@ -457,13 +457,13 @@ class SequenceDB(OptimizationDatabase):
...
@@ -457,13 +457,13 @@ class SequenceDB(OptimizationDatabase):
class
LocalGroupDB
(
SequenceDB
):
class
LocalGroupDB
(
SequenceDB
):
r"""A database that generates `NodeRewriter`\s of type `
LocalOptGroup
`."""
r"""A database that generates `NodeRewriter`\s of type `
SequentialNodeRewriter
`."""
def
__init__
(
def
__init__
(
self
,
self
,
apply_all_opts
:
bool
=
False
,
apply_all_opts
:
bool
=
False
,
profile
:
bool
=
False
,
profile
:
bool
=
False
,
node_rewriter
=
aesara_opt
.
LocalOptGroup
,
node_rewriter
=
aesara_opt
.
SequentialNodeRewriter
,
):
):
super
()
.
__init__
(
failure_callback
=
None
)
super
()
.
__init__
(
failure_callback
=
None
)
self
.
apply_all_opts
=
apply_all_opts
self
.
apply_all_opts
=
apply_all_opts
...
...
aesara/tensor/math_opt.py
浏览文件 @
66974655
...
@@ -10,9 +10,9 @@ import aesara.scalar.basic as aes
...
@@ -10,9 +10,9 @@ import aesara.scalar.basic as aes
import
aesara.scalar.math
as
aes_math
import
aesara.scalar.math
as
aes_math
from
aesara.graph.basic
import
Constant
,
Variable
from
aesara.graph.basic
import
Constant
,
Variable
from
aesara.graph.opt
import
(
from
aesara.graph.opt
import
(
LocalOptGroup
,
NodeRewriter
,
NodeRewriter
,
PatternSub
,
PatternSub
,
SequentialNodeRewriter
,
copy_stack_trace
,
copy_stack_trace
,
in2out
,
in2out
,
node_rewriter
,
node_rewriter
,
...
@@ -2117,7 +2117,7 @@ def local_add_specialize(fgraph, node):
...
@@ -2117,7 +2117,7 @@ def local_add_specialize(fgraph, node):
mul_canonizer
=
in2out
(
mul_canonizer
=
in2out
(
LocalOptGroup
(
local_mul_canonizer
,
local_fill_sink
,
apply_all_opts
=
True
),
SequentialNodeRewriter
(
local_mul_canonizer
,
local_fill_sink
,
apply_all_opts
=
True
),
name
=
"mul_canonizer_groups"
,
name
=
"mul_canonizer_groups"
,
)
)
...
@@ -2344,7 +2344,7 @@ def add_calculate(num, denum, aslist=False, out_type=None):
...
@@ -2344,7 +2344,7 @@ def add_calculate(num, denum, aslist=False, out_type=None):
local_add_canonizer
=
AlgebraicCanonizer
(
add
,
sub
,
neg
,
add_calculate
)
local_add_canonizer
=
AlgebraicCanonizer
(
add
,
sub
,
neg
,
add_calculate
)
add_canonizer
=
in2out
(
add_canonizer
=
in2out
(
LocalOptGroup
(
local_add_canonizer
,
local_fill_sink
,
apply_all_opts
=
True
),
SequentialNodeRewriter
(
local_add_canonizer
,
local_fill_sink
,
apply_all_opts
=
True
),
name
=
"add_canonizer_group"
,
name
=
"add_canonizer_group"
,
)
)
...
...
tests/graph/test_opt.py
浏览文件 @
66974655
...
@@ -7,12 +7,12 @@ from aesara.graph.fg import FunctionGraph
...
@@ -7,12 +7,12 @@ from aesara.graph.fg import FunctionGraph
from
aesara.graph.op
import
Op
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
(
from
aesara.graph.opt
import
(
EquilibriumOptimizer
,
EquilibriumOptimizer
,
LocalOptGroup
,
MergeOptimizer
,
MergeOptimizer
,
OpKeyOptimizer
,
OpKeyOptimizer
,
OpSub
,
OpSub
,
OpToRewriterTracker
,
OpToRewriterTracker
,
PatternSub
,
PatternSub
,
SequentialNodeRewriter
,
TopoOptimizer
,
TopoOptimizer
,
in2out
,
in2out
,
logging
,
logging
,
...
@@ -664,7 +664,7 @@ def test_patternsub_different_output_lengths():
...
@@ -664,7 +664,7 @@ def test_patternsub_different_output_lengths():
assert
fgraph
.
outputs
[
0
]
.
owner
.
op
==
op1
assert
fgraph
.
outputs
[
0
]
.
owner
.
op
==
op1
class
Test
LocalOptGroup
:
class
Test
SequentialNodeRewriter
:
def
test_optimizer_verbose
(
self
,
capsys
):
def
test_optimizer_verbose
(
self
,
capsys
):
x
=
MyVariable
(
"x"
)
x
=
MyVariable
(
"x"
)
...
@@ -685,7 +685,7 @@ class TestLocalOptGroup:
...
@@ -685,7 +685,7 @@ class TestLocalOptGroup:
res
=
op2
(
x
,
*
node
.
inputs
[
1
:])
res
=
op2
(
x
,
*
node
.
inputs
[
1
:])
return
[
res
]
return
[
res
]
opt_group
=
LocalOptGroup
(
local_opt_1
,
local_opt_2
)
opt_group
=
SequentialNodeRewriter
(
local_opt_1
,
local_opt_2
)
with
config
.
change_flags
(
optimizer_verbose
=
True
):
with
config
.
change_flags
(
optimizer_verbose
=
True
):
(
new_res
,)
=
opt_group
.
transform
(
fgraph
,
o1
.
owner
)
(
new_res
,)
=
opt_group
.
transform
(
fgraph
,
o1
.
owner
)
...
...
tests/tensor/test_math_opt.py
浏览文件 @
66974655
...
@@ -19,7 +19,7 @@ from aesara.configdefaults import config
...
@@ -19,7 +19,7 @@ from aesara.configdefaults import config
from
aesara.graph.basic
import
Apply
,
Constant
,
equal_computations
from
aesara.graph.basic
import
Apply
,
Constant
,
equal_computations
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.opt
import
(
from
aesara.graph.opt
import
(
LocalOptGroup
,
SequentialNodeRewriter
,
TopoOptimizer
,
TopoOptimizer
,
check_stack_trace
,
check_stack_trace
,
in2out
,
in2out
,
...
@@ -191,7 +191,7 @@ class TestGreedyDistribute:
...
@@ -191,7 +191,7 @@ class TestGreedyDistribute:
g
=
FunctionGraph
([
a
,
b
,
c
,
d
,
x
,
y
,
z
],
[
e
])
g
=
FunctionGraph
([
a
,
b
,
c
,
d
,
x
,
y
,
z
],
[
e
])
mul_canonizer
.
optimize
(
g
)
mul_canonizer
.
optimize
(
g
)
TopoOptimizer
(
TopoOptimizer
(
LocalOptGroup
(
local_greedy_distributor
),
order
=
"out_to_in"
SequentialNodeRewriter
(
local_greedy_distributor
),
order
=
"out_to_in"
)
.
optimize
(
g
)
)
.
optimize
(
g
)
assert
str
(
pprint
(
g
.
outputs
[
0
]))
==
"((a * x) + (b * z))"
assert
str
(
pprint
(
g
.
outputs
[
0
]))
==
"((a * x) + (b * z))"
...
@@ -200,7 +200,7 @@ class TestGreedyDistribute:
...
@@ -200,7 +200,7 @@ class TestGreedyDistribute:
g
=
FunctionGraph
([
a
,
b
,
x
],
[
e
])
g
=
FunctionGraph
([
a
,
b
,
x
],
[
e
])
mul_canonizer
.
optimize
(
g
)
mul_canonizer
.
optimize
(
g
)
TopoOptimizer
(
TopoOptimizer
(
LocalOptGroup
(
local_greedy_distributor
),
order
=
"out_to_in"
SequentialNodeRewriter
(
local_greedy_distributor
),
order
=
"out_to_in"
)
.
optimize
(
g
)
)
.
optimize
(
g
)
assert
str
(
pprint
(
g
.
outputs
[
0
]))
==
"(a + (b * x))"
assert
str
(
pprint
(
g
.
outputs
[
0
]))
==
"(a + (b * x))"
...
@@ -3053,7 +3053,7 @@ class TestLocalErfc:
...
@@ -3053,7 +3053,7 @@ class TestLocalErfc:
fg
=
FunctionGraph
(
inputs
,
[
no_match
],
clone
=
False
)
fg
=
FunctionGraph
(
inputs
,
[
no_match
],
clone
=
False
)
TopoOptimizer
(
TopoOptimizer
(
LocalOptGroup
(
local_grad_log_erfc_neg
),
order
=
"out_to_in"
SequentialNodeRewriter
(
local_grad_log_erfc_neg
),
order
=
"out_to_in"
)
.
optimize
(
fg
)
)
.
optimize
(
fg
)
# Make sure that the graph hasn't been changed
# Make sure that the graph hasn't been changed
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论