testgroup / pytensor · Commit 0ce6eceb

Authored Jul 14, 2022 by Brandon T. Willard
Committed by Brandon T. Willard on Aug 17, 2022

Refactor old global and local optimizers references and type hints

Parent: 550a6e98

Showing 5 changed files with 286 additions and 283 deletions (+286 -283)
aesara/graph/kanren.py        +2   -2
aesara/graph/opt.py           +94  -69
aesara/graph/optdb.py         +17  -32
aesara/tensor/basic_opt.py    +169 -161
aesara/tensor/nnet/basic.py   +4   -19
aesara/graph/kanren.py

@@ -11,7 +11,7 @@ from aesara.graph.unify import eval_if_etuple
 class KanrenRelationSub(NodeRewriter):
-    r"""A local optimizer that uses `kanren` to match and replace terms.
+    r"""A rewriter that uses `kanren` to match and replace terms.
 
     See `kanren <https://github.com/pythological/kanren>`__ for more information
     miniKanren and the API for constructing `kanren` goals.
@@ -56,7 +56,7 @@ class KanrenRelationSub(NodeRewriter):
         A function that takes an input graph and an output logic variable and
         returns a `kanren` goal.
     results_filter
-        A function that takes the direct output of `kanren.run(None, ...)`
+        A function that takes the direct output of ``kanren.run(None, ...)``
         and returns a single result. The default implementation returns
         the first result.
     node_filter
aesara/graph/opt.py

@@ -17,7 +17,7 @@ from collections import UserList, defaultdict, deque
 from collections.abc import Iterable
 from functools import _compose_mro, partial, reduce  # type: ignore
 from itertools import chain
-from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
 
 from typing_extensions import Literal
@@ -156,15 +156,20 @@ class NodeRewriter(Rewriter):
 
     @abc.abstractmethod
     def transform(
         self, fgraph: FunctionGraph, node: Apply, *args, **kwargs
-    ) -> Union[bool, List[Variable], Dict[Variable, Variable]]:
-        r"""Transform a subgraph whose output is `node`.
+    ) -> Union[
+        bool,
+        Sequence[Variable],
+        Dict[Union[Variable, Literal["remove"]], Union[Variable, Sequence[Variable]]],
+    ]:
+        r"""Rewrite the sub-graph given by `node`.
 
         Subclasses should implement this function so that it returns one of the
         following:
 
         - ``False`` to indicate that this rewrite cannot be applied to `node`
         - A list of `Variable`\s to use in place of the `node`'s current outputs
-        - A ``dict`` mapping old `Variable`\s to new `Variable`\s
+        - A ``dict`` mapping old `Variable`\s to `Variable`\s, or the key
+          ``"remove"`` mapping to a list of `Variable`\s to be removed
 
         Parameters
         ----------
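The widened return type above admits three result shapes: ``False``/``None``, a sequence of replacement outputs, or a ``dict`` that may carry a ``"remove"`` key. A rough, standalone sketch of how a caller can normalize these shapes (plain Python, no aesara imports; `Variable` here is just an illustrative placeholder):

```python
from typing import List

# Placeholder for aesara's Variable type; purely illustrative.
Variable = str


def unpack_transform_result(result, node_outputs):
    """Normalize a transform-style result (False/None, a sequence, or a
    dict that may carry a "remove" key) into (old_vars, new_vars, remove)."""
    if result is False or result is None:
        return None
    remove: List[Variable] = []
    if isinstance(result, dict):
        result = dict(result)  # avoid mutating the caller's dict
        if "remove" in result:
            remove = list(result.pop("remove"))
        old_vars = list(result.keys())
        new_vars = list(result.values())
    elif not isinstance(result, (tuple, list)):
        raise TypeError(f"wrong type of replacement: {result}")
    else:
        old_vars = list(node_outputs)
        new_vars = list(result)
    if len(old_vars) != len(new_vars):
        raise ValueError("wrong number of replacements")
    return old_vars, new_vars, remove
```

This mirrors, in simplified form, the dict handling that `process_node` performs later in this diff.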
@@ -1850,10 +1855,15 @@ class NavigatorOptimizer(GraphRewriter):
         if u is not None:
             fgraph.remove_feature(u)
 
-    def process_node(self, fgraph, node, lopt=None):
-        r"""Apply `lopt` to `node`.
+    def process_node(
+        self,
+        fgraph: FunctionGraph,
+        node: Apply,
+        node_rewriter: Optional[NodeRewriter] = None,
+    ):
+        r"""Apply `node_rewriter` to `node`.
 
-        The :meth:`lopt.transform` method will return either ``False`` or a
+        The :meth:`node_rewriter.transform` method will return either ``False`` or a
         list of `Variable`\s that are intended to replace :attr:`node.outputs`.
 
         If the `fgraph` accepts the replacement, then the optimization is
@@ -1864,11 +1874,11 @@ class NavigatorOptimizer(GraphRewriter):
 
         Parameters
         ----------
-        fgraph :
+        fgraph
             A `FunctionGraph`.
-        node :
+        node
             An `Apply` instance in `fgraph`
-        lopt :
+        node_rewriter
             A `NodeRewriter` instance that may have a better idea for
             how to compute node's outputs.
@@ -1878,13 +1888,15 @@ class NavigatorOptimizer(GraphRewriter):
             ``True`` iff the `node`'s outputs were replaced in the `fgraph`.
 
         """
-        lopt = lopt or self.node_rewriter
+        node_rewriter = node_rewriter or self.node_rewriter
+        # TODO FIXME: This class's interface is broken
+        assert node_rewriter is not None
         try:
-            replacements = lopt.transform(fgraph, node)
+            replacements = node_rewriter.transform(fgraph, node)
         except Exception as e:
             if self.failure_callback is not None:
                 self.failure_callback(
-                    e, self, [(x, None) for x in node.outputs], lopt, node
+                    e, self, [(x, None) for x in node.outputs], node_rewriter, node
                 )
                 return False
             else:
@@ -1892,25 +1904,27 @@ class NavigatorOptimizer(GraphRewriter):
         if replacements is False or replacements is None:
             return False
         old_vars = node.outputs
-        remove = []
+        remove: List[Variable] = []
         if isinstance(replacements, dict):
             if "remove" in replacements:
-                remove = replacements.pop("remove")
-            old_vars = list(replacements.keys())
-            replacements = list(replacements.values())
+                remove = list(cast(Sequence[Variable], replacements.pop("remove")))
+            old_vars = list(cast(Sequence[Variable], replacements.keys()))
+            replacements = list(cast(Sequence[Variable], replacements.values()))
         elif not isinstance(replacements, (tuple, list)):
             raise TypeError(
-                f"Node rewriter {lopt} gave wrong type of replacement. "
+                f"Node rewriter {node_rewriter} gave wrong type of replacement. "
                 f"Expected list or tuple; got {replacements}"
             )
         if len(old_vars) != len(replacements):
-            raise ValueError(f"Node rewriter {lopt} gave wrong number of replacements")
+            raise ValueError(
+                f"Node rewriter {node_rewriter} gave wrong number of replacements"
+            )
         # None in the replacement mean that this variable isn't used
         # and we want to remove it
         for r, rnew in zip(old_vars, replacements):
             if rnew is None and len(fgraph.clients[r]) > 0:
                 raise ValueError(
-                    f"Node rewriter {lopt} tried to remove a variable"
+                    f"Node rewriter {node_rewriter} tried to remove a variable"
                     f" that is being used: {r}"
                 )
         # If an output would be replaced by itself, no need to perform
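The `cast(...)` wrappers added in this hunk exist purely for the type checker: `typing.cast` returns its second argument unchanged at runtime, so the rewritten lines behave exactly like the originals. A quick self-contained check (plain Python, stand-in strings for variables):

```python
from typing import Dict, List, Sequence, Union, cast

replacements: Dict[str, Union[str, List[str]]] = {
    "old_var": "new_var",
    "remove": ["dead_var"],
}

# typing.cast is a runtime no-op; it only informs static analysis.
remove = list(cast(Sequence[str], replacements.pop("remove")))
old_vars = list(cast(Sequence[str], replacements.keys()))
new_vars = list(cast(Sequence[str], replacements.values()))
```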
@@ -1924,7 +1938,9 @@ class NavigatorOptimizer(GraphRewriter):
         if len(repl_pairs) == 0:
             return False
         try:
-            fgraph.replace_all_validate_remove(repl_pairs, reason=lopt, remove=remove)
+            fgraph.replace_all_validate_remove(  # type: ignore
+                repl_pairs, reason=node_rewriter, remove=remove
+            )
             return True
         except Exception as e:
             # This means the replacements were rejected by the fgraph.
@@ -1932,7 +1948,7 @@ class NavigatorOptimizer(GraphRewriter):
             # This is not supposed to happen. The default failure_callback
             # will print a traceback as a warning.
             if self.failure_callback is not None:
-                self.failure_callback(e, self, repl_pairs, lopt, node)
+                self.failure_callback(e, self, repl_pairs, node_rewriter, node)
                 return False
             else:
                 raise
@@ -2027,7 +2043,7 @@ class TopoOptimizer(NavigatorOptimizer):
             io_t,
             loop_t,
             callback_time,
-            lopt,
+            node_rewriter,
         ) = prof
 
         print(
@@ -2046,16 +2062,16 @@ class TopoOptimizer(NavigatorOptimizer):
         print(blanc, "  init io_toposort", io_t, file=stream)
         print(blanc, "  loop time", loop_t, file=stream)
         print(blanc, "  callback_time", callback_time, file=stream)
-        if isinstance(lopt, LocalOptGroup):
-            if lopt.profile:
-                lopt.print_profile(
+        if isinstance(node_rewriter, LocalOptGroup):
+            if node_rewriter.profile:
+                node_rewriter.print_profile(
                     stream,
                     (
-                        lopt.time_opts,
-                        lopt.process_count,
-                        lopt.applied_true,
-                        lopt.node_created,
-                        lopt.profile,
+                        node_rewriter.time_opts,
+                        node_rewriter.process_count,
+                        node_rewriter.applied_true,
+                        node_rewriter.node_created,
+                        node_rewriter.profile,
                     ),
                     level=level + 1,
                 )
@@ -2228,11 +2244,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
         self.global_optimizers: List[GraphRewriter] = []
         self.tracks_on_change_inputs = tracks_on_change_inputs
 
-        self.local_tracker = LocalOptTracker()
+        self.node_tracker = LocalOptTracker()
 
         for opt in optimizers:
             if isinstance(opt, NodeRewriter):
-                self.local_tracker.add_tracker(opt)
+                self.node_tracker.add_tracker(opt)
             else:
                 assert isinstance(opt, GraphRewriter)
                 self.global_optimizers.append(opt)
@@ -2250,7 +2266,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
         self.max_use_ratio = max_use_ratio
 
     def get_node_rewriters(self):
-        yield from self.local_tracker.get_rewriters()
+        yield from self.node_tracker.get_rewriters()
 
     def get_local_optimizers(self):
         warnings.warn(
@@ -2357,11 +2373,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
 
             global_opt_timing.append(float(time.time() - t0))
 
-            # apply clean up as global opt can have done changes that
-            # request that
             changed |= apply_cleanup(iter_cleanup_sub_profs)
 
-            # apply local optimizer
             topo_t0 = time.time()
             q = deque(io_toposort(fgraph.inputs, start_from))
             io_toposort_timing.append(time.time() - topo_t0)
@@ -2390,23 +2403,25 @@ class EquilibriumOptimizer(NavigatorOptimizer):
                 if node not in fgraph.apply_nodes:
                     continue
                 current_node = node
-                for lopt in self.local_tracker.get_trackers(node.op):
+                for node_rewriter in self.node_tracker.get_trackers(node.op):
                     nb = change_tracker.nb_imported
                     t_opt = time.time()
-                    lopt_change = self.process_node(fgraph, node, lopt)
-                    time_opts[lopt] += time.time() - t_opt
-                    if not lopt_change:
+                    node_rewriter_change = self.process_node(
+                        fgraph, node, node_rewriter
+                    )
+                    time_opts[node_rewriter] += time.time() - t_opt
+                    if not node_rewriter_change:
                         continue
-                    process_count.setdefault(lopt, 0)
-                    process_count[lopt] += 1
-                    global_process_count[lopt] += 1
+                    process_count.setdefault(node_rewriter, 0)
+                    process_count[node_rewriter] += 1
+                    global_process_count[node_rewriter] += 1
                     changed = True
-                    node_created[lopt] += change_tracker.nb_imported - nb
+                    node_created[node_rewriter] += change_tracker.nb_imported - nb
                     changed |= apply_cleanup(iter_cleanup_sub_profs)
-                    if global_process_count[lopt] > max_use:
+                    if global_process_count[node_rewriter] > max_use:
                         max_use_abort = True
-                        opt_name = getattr(lopt, "name", None) or getattr(
-                            lopt, "__name__", ""
+                        opt_name = getattr(node_rewriter, "name", None) or getattr(
+                            node_rewriter, "__name__", ""
                         )
 
                 if node not in fgraph.apply_nodes:
                     # go to next node
@@ -2494,8 +2509,10 @@ class EquilibriumOptimizer(NavigatorOptimizer):
             f"{' ' * level}{self.__class__.__name__} {name} id={id(self)}",
             file=stream,
         )
         if depth != 0:
-            for lopt in self.get_node_rewriters():
-                lopt.print_summary(stream, level=(level + 2), depth=(depth - 1))
+            for node_rewriter in self.get_node_rewriters():
+                node_rewriter.print_summary(
+                    stream, level=(level + 2), depth=(depth - 1)
+                )
 
     @staticmethod
     def print_profile(stream, prof, level=0):
@@ -2529,27 +2546,27 @@ class EquilibriumOptimizer(NavigatorOptimizer):
         )
         print(blanc, f"  time io_toposort {sum(io_toposort_timing):.3f}s", file=stream)
         s = sum(time_opts[o] for o in opt.get_node_rewriters())
-        print(blanc, f"  time in local optimizers {s:.3f}s", file=stream)
+        print(blanc, f"  time in node rewriters {s:.3f}s", file=stream)
         s = sum(time_opts[o] for o in opt.global_optimizers)
-        print(blanc, f"  time in global optimizers {s:.3f}s", file=stream)
+        print(blanc, f"  time in graph rewriters {s:.3f}s", file=stream)
         s = sum(time_opts[o] for o in opt.final_optimizers)
-        print(blanc, f"  time in final optimizers {s:.3f}s", file=stream)
+        print(blanc, f"  time in final rewriters {s:.3f}s", file=stream)
         s = sum(time_opts[o] for o in opt.cleanup_optimizers)
-        print(blanc, f"  time in cleanup optimizers {s:.3f}s", file=stream)
+        print(blanc, f"  time in cleanup rewriters {s:.3f}s", file=stream)
         for i in range(len(loop_timing)):
-            lopt = ""
+            loop_times = ""
             if loop_process_count[i]:
                 d = list(
                     reversed(sorted(loop_process_count[i].items(), key=lambda a: a[1]))
                 )
-                lopt = " ".join([str((str(k), v)) for k, v in d[:5]])
+                loop_times = " ".join([str((str(k), v)) for k, v in d[:5]])
                 if len(d) > 5:
-                    lopt += " ..."
+                    loop_times += " ..."
             print(
                 blanc,
                 (
-                    f"  {int(i):2d} - {loop_timing[i]:.3f}s {int(sum(loop_process_count[i].values()))} ({global_opt_timing[i]:.3f}s in global opts, "
-                    f"{io_toposort_timing[i]:.3f}s io_toposort) - {int(nb_nodes[i])} nodes - {lopt}"
+                    f"  {int(i):2d} - {loop_timing[i]:.3f}s {int(sum(loop_process_count[i].values()))} ({global_opt_timing[i]:.3f}s in graph rewriters, "
+                    f"{io_toposort_timing[i]:.3f}s io_toposort) - {int(nb_nodes[i])} nodes - {loop_times}"
                 ),
                 file=stream,
             )
@@ -2784,8 +2801,10 @@ def check_chain(r, *chain):
     return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain)))
 
 
-def pre_greedy_node_rewriter(fgraph, optimizations, out):
-    """Apply local optimizations to a graph.
+def pre_greedy_node_rewriter(
+    fgraph: FunctionGraph, optimizations: Sequence[NodeRewriter], out: Variable
+) -> Variable:
+    """Apply node rewriters throughout a graph in a greedy, pre-traversal way.
 
     This function traverses the computation graph in the graph before the
     variable `out` but that are not in the `fgraph`. It applies
@@ -2796,7 +2815,7 @@ def pre_greedy_node_rewriter(fgraph, optimizations, out):
     This changes the nodes in a graph in-place.
 
     Its main use is to apply locally constant folding when generating
-    the graph of the indices of a subtensor.
+    the graph of the indices of a `Subtensor`.
 
     Changes should not be applied to nodes that are in an `fgraph`,
     so we use `fgraph` to prevent that.
@@ -2810,16 +2829,21 @@ def pre_greedy_node_rewriter(fgraph, optimizations, out):
 
     Parameters
     ----------
-    fgraph : FunctionGraph
+    fgraph
         The graph used to avoid/filter nodes.
-    optimizations : list of NodeRewriter
-        The list of local optimizations to apply
-    out : Variable
-        A `Variable` specifying the graph to optimize.
+    optimizations
+        A sequence of rewrites to apply.
+    out
+        The graph to optimize.
 
     """
 
-    def local_recursive_function(list_opt, out, optimized_vars, depth):
+    def local_recursive_function(
+        list_opt: Sequence[NodeRewriter],
+        out: Variable,
+        optimized_vars: Dict[Variable, Variable],
+        depth: int,
+    ) -> Tuple[List[Variable], Dict[Variable, Variable]]:
         if not getattr(out, "owner", None):
             return [out], optimized_vars
         node = out.owner
@@ -2852,6 +2876,7 @@ def pre_greedy_node_rewriter(fgraph, optimizations, out):
         for opt in list_opt:
             ret = opt.transform(fgraph, node)
             if ret is not False and ret is not None:
+                assert isinstance(ret, Sequence)
                 assert len(ret) == len(node.outputs), opt
                 for k, v in zip(node.outputs, ret):
                     optimized_vars[k] = v
@@ -2864,7 +2889,7 @@ def pre_greedy_node_rewriter(fgraph, optimizations, out):
         return results, optimized_vars
 
     if out.owner:
-        out_index = out.owner.outputs.index(out)
+        out_index: int = out.owner.outputs.index(out)
     else:
         out_index = 0
aesara/graph/optdb.py

@@ -290,55 +290,40 @@ class OptimizationQuery:
 
 class EquilibriumDB(OptimizationDatabase):
-    """
-    A set of potential optimizations which should be applied in an arbitrary
-    order until equilibrium is reached.
+    """A database of rewrites that should be applied until equilibrium is reached.
 
     Canonicalize, Stabilize, and Specialize are all equilibrium optimizations.
 
-    Parameters
-    ----------
-    ignore_newtrees
-        If False, we will apply local opt on new node introduced during local
-        optimization application. This could result in less fgraph iterations,
-        but this doesn't mean it will be faster globally.
-    tracks_on_change_inputs
-        If True, we will re-apply local opt on nodes whose inputs
-        changed during local optimization application. This could
-        result in less fgraph iterations, but this doesn't mean it
-        will be faster globally.
-
     Notes
     -----
     We can use `NodeRewriter` and `GraphRewriter` since `EquilibriumOptimizer`
     supports both.
 
-    It is probably not a good idea to have ignore_newtrees=False and
-    tracks_on_change_inputs=True
+    It is probably not a good idea to have both ``ignore_newtrees == False``
+    and ``tracks_on_change_inputs == True``.
 
     """
 
-    def __init__(self, ignore_newtrees=True, tracks_on_change_inputs=False):
+    def __init__(
+        self, ignore_newtrees: bool = True, tracks_on_change_inputs: bool = False
+    ):
         """
 
         Parameters
-        ==========
-        ignore_newtrees:
-            If False, we will apply local opt on new node introduced during local
-            optimization application. This could result in less fgraph iterations,
-            but this doesn't mean it will be faster globally.
-
-        tracks_on_change_inputs:
-            If True, we will re-apply local opt on nodes whose inputs
-            changed during local optimization application. This could
-            result in less fgraph iterations, but this doesn't mean it
-            will be faster globally.
+        ----------
+        ignore_newtrees
+            If ``False``, apply rewrites to new nodes introduced during
+            rewriting.
+        tracks_on_change_inputs
+            If ``True``, re-apply rewrites on nodes with changed inputs.
 
         """
         super().__init__()
         self.ignore_newtrees = ignore_newtrees
         self.tracks_on_change_inputs = tracks_on_change_inputs
-        self.__final__ = {}
-        self.__cleanup__ = {}
+        self.__final__: Dict[str, aesara_opt.Rewriter] = {}
+        self.__cleanup__: Dict[str, aesara_opt.Rewriter] = {}
 
     def register(self, name, obj, *tags, final_opt=False, cleanup=False, **kwargs):
         if final_opt and cleanup:
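The contract that `EquilibriumDB`'s new docstring describes, applying a set of rewrites in arbitrary order until a full pass changes nothing, can be illustrated with a toy fixpoint loop (plain Python; the function name and the string-based "rewrites" are invented for the example and have nothing to do with aesara's actual graph machinery):

```python
def rewrite_to_equilibrium(expr, rewrites, max_iters=100):
    """Apply each rewrite repeatedly until one full pass changes nothing."""
    for _ in range(max_iters):
        changed = False
        for rewrite in rewrites:
            new_expr = rewrite(expr)
            if new_expr is not None and new_expr != expr:
                expr, changed = new_expr, True
        if not changed:
            return expr  # equilibrium reached
    raise RuntimeError("no equilibrium within max_iters")


# A toy "rewrite" over strings standing in for graphs: collapse "aa" -> "a".
collapse = lambda s: s.replace("aa", "a") if "aa" in s else None
```

`rewrite_to_equilibrium("aaaa", [collapse])` converges after a few passes, which is the behavior the canonicalize/stabilize/specialize databases rely on.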
aesara/tensor/basic_opt.py

@@ -6,7 +6,7 @@ import time
 import traceback
 from collections import defaultdict
 from io import StringIO
-from typing import Optional
+from typing import Optional, Union
 
 import numpy as np
@@ -28,13 +28,15 @@ from aesara.graph.fg import FunctionGraph
 from aesara.graph.op import compute_test_value, get_test_value
 from aesara.graph.opt import (
     GraphRewriter,
+    NodeRewriter,
     OpRemove,
+    Rewriter,
     check_chain,
     copy_stack_trace,
     in2out,
     node_rewriter,
 )
-from aesara.graph.optdb import SequenceDB
+from aesara.graph.optdb import OptimizationDatabase, SequenceDB
 from aesara.graph.utils import (
     InconsistencyError,
     MethodNotDefined,
@@ -193,21 +195,19 @@ class InplaceElemwiseOptimizer(GraphRewriter):
                 print(blanc, n, ndim[n], file=stream)
 
     def apply(self, fgraph):
-        """
-        Usage: InplaceElemwiseOptimizer(op).optimize(fgraph)
-
-        Attempts to replace all Broadcast ops by versions of them
-        that operate inplace. It operates greedily: for each Broadcast
-        Op that is encountered, for each output, tries each input to
-        see if it can operate inplace on that input. If so, makes the
-        change and go to the next output or Broadcast Op.
+        r"""
+        Attempts to replace all `Elemwise`\s by versions of them that operate
+        inplace. It operates greedily: for each `Elemwise` that is encountered,
+        for each output, it tries each input to see if it can operate inplace
+        on that input. If so, it makes the change and goes to the next output
+        or `Elemwise`.
 
         Examples
         --------
-        `x + y + z -> x += y += z`
-        (x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
+        x + y + z -> x += y += z
+        `(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)`
 
         """
         # We should not validate too often as this takes too much time to
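The substitution described in that docstring, `x + y + z -> x += y += z`, trades allocating a fresh output buffer for reusing an input's buffer. A toy illustration of the distinction using Python lists standing in for tensors (a loose analogy only, not aesara's implementation):

```python
x = [1.0, 2.0]
y = [10.0, 20.0]
buf = x  # remember x's identity (its "buffer")

# Out-of-place: allocates a brand-new list for the result.
out_of_place = [a + b for a, b in zip(x, y)]

# "Inplace": write the result back into x's existing storage.
x[:] = [a + b for a, b in zip(x, y)]
assert x is buf  # same object; no new allocation for the result
```

The inplace rewrite is only valid when no other part of the graph still needs the original value of `x`, which is why the optimizer checks each input before committing the change.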
@@ -225,7 +225,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
         # Maybe Aesara should do online toposort as in
         # http://code.google.com/p/acyclic
         #
-        # The next longest optimizer is the canonizer phase.
+        # The next longest rewriter is the canonizer phase.
         # Then I think it is the [io_?]toposort (need to validate) so check if
         # the solution is also applicable there.
@@ -429,8 +429,8 @@ class InplaceElemwiseOptimizer(GraphRewriter):
             if check_each_change != 1 and not raised_warning:
                 print(
                     (
-                        "Some inplace optimization was not "
-                        "performed due to unexpected error:"
+                        "Some inplace rewriting was not "
+                        "performed due to an unexpected error:"
                     ),
                     file=sys.stderr,
                 )
@@ -450,8 +450,8 @@ class InplaceElemwiseOptimizer(GraphRewriter):
         if not raised_warning:
             print(
                 (
-                    "Some inplace optimization was not "
-                    "performed due to unexpected error"
+                    "Some inplace rewriting was not "
+                    "performed due to an unexpected error"
                 ),
                 file=sys.stderr,
             )
@@ -478,91 +478,111 @@ compile.optdb.register(
...
@@ -478,91 +478,111 @@ compile.optdb.register(
)
)
def
register_useless
(
lopt
,
*
tags
,
**
kwargs
):
def
register_useless
(
if
isinstance
(
lopt
,
str
):
node_rewriter
:
Union
[
OptimizationDatabase
,
NodeRewriter
,
str
],
*
tags
,
**
kwargs
):
if
isinstance
(
node_rewriter
,
str
):
def
register
(
inner_
lopt
):
def
register
(
inner_
rewriter
:
Union
[
OptimizationDatabase
,
Rewriter
]
):
return
register_useless
(
inner_
lopt
,
lopt
,
*
tags
,
**
kwargs
)
return
register_useless
(
inner_
rewriter
,
node_rewriter
,
*
tags
,
**
kwargs
)
return
register
return
register
else
:
else
:
name
=
kwargs
.
pop
(
"name"
,
None
)
or
lopt
.
__name__
name
=
kwargs
.
pop
(
"name"
,
None
)
or
node_rewriter
.
__name__
compile
.
mode
.
local_useless
.
register
(
compile
.
mode
.
local_useless
.
register
(
name
,
lopt
,
"fast_run"
,
*
tags
,
position
=
"last"
,
**
kwargs
name
,
node_rewriter
,
"fast_run"
,
*
tags
,
position
=
"last"
,
**
kwargs
)
)
return
lopt
return
node_rewriter
def
register_canonicalize
(
lopt
,
*
tags
,
**
kwargs
):
def
register_canonicalize
(
if
isinstance
(
lopt
,
str
):
node_rewriter
:
Union
[
OptimizationDatabase
,
NodeRewriter
,
str
],
*
tags
:
str
,
**
kwargs
):
if
isinstance
(
node_rewriter
,
str
):
def
register
(
inner_
lopt
):
def
register
(
inner_
rewriter
:
Union
[
OptimizationDatabase
,
Rewriter
]
):
return
register_canonicalize
(
inner_
lopt
,
lopt
,
*
tags
,
**
kwargs
)
return
register_canonicalize
(
inner_
rewriter
,
node_rewriter
,
*
tags
,
**
kwargs
)
return
register
return
register
else
:
else
:
name
=
kwargs
.
pop
(
"name"
,
None
)
or
lopt
.
__name__
name
=
kwargs
.
pop
(
"name"
,
None
)
or
node_rewriter
.
__name__
compile
.
optdb
[
"canonicalize"
]
.
register
(
compile
.
optdb
[
"canonicalize"
]
.
register
(
name
,
lopt
,
"fast_run"
,
"fast_compile"
,
*
tags
,
**
kwargs
name
,
node_rewriter
,
"fast_run"
,
"fast_compile"
,
*
tags
,
**
kwargs
)
)
return
lopt
return
node_rewriter
def
register_stabilize
(
lopt
,
*
tags
,
**
kwargs
):
def
register_stabilize
(
if
isinstance
(
lopt
,
str
):
node_rewriter
:
Union
[
OptimizationDatabase
,
NodeRewriter
,
str
],
*
tags
:
str
,
**
kwargs
):
if
isinstance
(
node_rewriter
,
str
):
def
register
(
inner_
lopt
):
def
register
(
inner_
rewriter
:
Union
[
OptimizationDatabase
,
Rewriter
]
):
return
register_stabilize
(
inner_
lopt
,
lopt
,
*
tags
,
**
kwargs
)
return
register_stabilize
(
inner_
rewriter
,
node_rewriter
,
*
tags
,
**
kwargs
)
return
register
return
register
else
:
else
:
name
=
kwargs
.
pop
(
"name"
,
None
)
or
lopt
.
__name__
name
=
kwargs
.
pop
(
"name"
,
None
)
or
node_rewriter
.
__name__
compile
.
optdb
[
"stabilize"
]
.
register
(
name
,
lopt
,
"fast_run"
,
*
tags
,
**
kwargs
)
compile
.
optdb
[
"stabilize"
]
.
register
(
return
lopt
name
,
node_rewriter
,
"fast_run"
,
*
tags
,
**
kwargs
)
return
node_rewriter
def register_specialize(
    node_rewriter: Union[OptimizationDatabase, NodeRewriter, str],
    *tags: str,
    **kwargs
):
    if isinstance(node_rewriter, str):

        def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
            return register_specialize(inner_rewriter, node_rewriter, *tags, **kwargs)

        return register
    else:
        name = kwargs.pop("name", None) or node_rewriter.__name__
        compile.optdb["specialize"].register(
            name, node_rewriter, "fast_run", *tags, **kwargs
        )
        return node_rewriter
def register_uncanonicalize(
    node_rewriter: Union[OptimizationDatabase, NodeRewriter, str],
    *tags: str,
    **kwargs
):
    if isinstance(node_rewriter, str):

        def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
            return register_uncanonicalize(inner_rewriter, node_rewriter, *tags, **kwargs)

        return register
    else:
        name = (kwargs and kwargs.pop("name", None)) or node_rewriter.__name__
        compile.optdb["uncanonicalize"].register(
            name, node_rewriter, "fast_run", *tags, **kwargs
        )
        return node_rewriter
def register_specialize_device(
    node_rewriter: Union[OptimizationDatabase, Rewriter, str],
    *tags: str,
    **kwargs
):
    if isinstance(node_rewriter, str):

        def register(inner_rewriter: Union[OptimizationDatabase, Rewriter]):
            return register_specialize_device(inner_rewriter, node_rewriter, *tags, **kwargs)

        return register
    else:
        name = (kwargs and kwargs.pop("name", None)) or node_rewriter.__name__
        compile.optdb["specialize_device"].register(
            name, node_rewriter, "fast_run", *tags, **kwargs
        )
        return node_rewriter
def apply_local_dimshuffle_lift(fgraph, var):
...
@@ -762,19 +782,17 @@ pprint.assign(MakeVector, MakeVectorPrinter())
class ShapeFeature(Feature):
    r"""A `Feature` that tracks shape information in a graph.

    This `Feature` aids in the replacement of all `Shape`\s and `Subtensor`\s of `Shape`\s with
    `Shape_i` and `MakeVector` `Op`\s.

    This `Feature` and its associated rewrites have several goals:

    1. to "lift" `Shape`\s to as close to the inputs as possible,
    2. to infer the shape of every node in the graph in terms of the
       input shapes, and
    3. remove fill `Op`\s (e.g. `Second`) from the graph.

    Lifting shapes as close to the inputs as possible is important for
    canonicalization because it is very bad form to have to compute
...
@@ -782,7 +800,7 @@ class ShapeFeature(Feature):
    of time to compute such outputs.  But it is important to get rid
    of these outputs as early as possible in the compilation process
    because the extra computations make it appear as if many internal
    graph nodes have multiple clients.  Many rewrites refuse to
    work on nodes with multiple clients.

    Lifting is done by using an `<Op>.infer_shape` function if one is
...
@@ -802,7 +820,7 @@ class ShapeFeature(Feature):
    input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),).

    Inferring the shape of internal nodes in the graph is important
    for doing size-driven rewrites.  If we know how big various
    intermediate results will be, we can estimate the cost of many Ops
    accurately, and generate c-code that is specific [e.g. unrolled]
    to particular sizes.
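The per-`Op` shape propagation described in the docstring can be illustrated with a small, self-contained sketch (these `*_infer_shape` helpers are assumptions for illustration, not Aesara's `infer_shape` API): each operation maps its input shapes to its output shapes, and shapes flow forward through the graph.

```python
# Toy shape inference, mimicking the `<Op>.infer_shape` contract described above.


def dot_infer_shape(input_shapes):
    # e.g. input_shapes = ((x0, x1), (y0, y1)) -> ((x0, y1),)
    (x0, x1), (y0, y1) = input_shapes
    assert x1 == y0, "inner dimensions must agree"
    return ((x0, y1),)


def elemwise_infer_shape(input_shapes):
    # An elemwise output has the shape of its (broadcast-compatible) input.
    return (input_shapes[0],)


# Propagate shapes through dot(A, B) followed by an elemwise op.
a_shape, b_shape = (3, 4), (4, 5)
(dot_shape,) = dot_infer_shape((a_shape, b_shape))
(out_shape,) = elemwise_infer_shape((dot_shape,))
print(out_shape)  # -> (3, 5)
```

With this kind of propagation, every intermediate shape is known symbolically before anything is computed, which is what makes the size-driven rewrites mentioned above possible.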
...
@@ -818,14 +836,12 @@ class ShapeFeature(Feature):
    shape, either via a .tag or some similar hacking; and 2) to
    add an optional In() argument to promise that inputs will
    have a certain shape (or even to have certain shapes in
    certain dimensions).

    We can't automatically infer the shape of shared variables as they can
    change of shape during the execution by default.

    To use this shape information in rewrites, use the
    ``shape_of`` dictionary.

    For example:
...
@@ -888,10 +904,10 @@ class ShapeFeature(Feature):
        return o_shapes

    def get_shape(self, var, idx):
        """Rewrites can call this to get a `Shape_i`.

        It is better to call this then use directly ``shape_of[var][idx]``
        as this method should update `shape_of` if needed.

        TODO: Up to now, we don't update it in all cases. Update in all cases.
        """
...
@@ -977,11 +993,9 @@ class ShapeFeature(Feature):
        error reporting.

        """
        assert s_i is not None

        if s_i == 1:
            return self.lscalar_one
        if isinstance(s_i, float) and int(s_i) == s_i:
            s_i = int(s_i)
...
@@ -1080,10 +1094,9 @@ class ShapeFeature(Feature):
            else:
                shape_vars.append(self.unpack(s[i], r))

        assert all(
            not hasattr(r.type, "broadcastable")
            or not r.type.broadcastable[i]
            or self.lscalar_one.equals(shape_vars[i])
            or self.lscalar_one.equals(extract_constant(shape_vars[i]))
            for i in range(r.type.ndim)
        )
...
@@ -1118,9 +1131,9 @@ class ShapeFeature(Feature):
            and other_r.owner.inputs == r.owner.inputs
            and other_r.owner.op == r.owner.op
        ):
            # We are doing a merge, so the two shape graphs will be the
            # same.  This is only done so that we call `ancestors` less
            # frequently.
            return

        # Merge other_shape with r_shape, giving the priority to other_shape
...
@@ -1168,10 +1181,7 @@ class ShapeFeature(Feature):
                    or not r.type.broadcastable[i]
                    and not other_r.type.broadcastable[i]
                )
                or self.lscalar_one.equals(merged_shape[i])
                or self.lscalar_one.equals(
                    extract_constant(merged_shape[i], only_process_constants=True)
                )
...
@@ -1194,10 +1204,9 @@ class ShapeFeature(Feature):
            else:
                new_shape.append(s_j)

        assert all(
            not hasattr(r.type, "broadcastable")
            or not r.type.broadcastable[idx]
            or self.lscalar_one.equals(new_shape[idx])
            or self.lscalar_one.equals(extract_constant(new_shape[idx]))
            for idx in range(r.type.ndim)
        )
...
@@ -1273,7 +1282,7 @@ class ShapeFeature(Feature):
        )

        # Ensure shapes are in 'int64'. This is to make sure the assert
        # found in the `local_useless_subtensor` rewrite does not fail.
        for sh_idx, sh in enumerate(o_shapes):
            if sh is None:
                continue
...
@@ -1444,7 +1453,7 @@ class ShapeFeature(Feature):
class ShapeOptimizer(GraphRewriter):
    """Rewriter that adds `ShapeFeature` as a feature."""

    def add_requirements(self, fgraph):
        fgraph.attach_feature(ShapeFeature())
...
@@ -1454,7 +1463,7 @@ class ShapeOptimizer(GraphRewriter):
class UnShapeOptimizer(GraphRewriter):
    """Rewriter that removes `ShapeFeature` as a feature."""

    def apply(self, fgraph):
        for feature in fgraph._features:
...
@@ -1528,7 +1537,7 @@ def local_elemwise_alloc(fgraph, node):
    for idx, i in enumerate(node.inputs):
        if i.type.broadcastable == node.outputs[0].type.broadcastable:
            # Prefer an input that is not an `Alloc` nor a `DimShuffle` of an
            # `Alloc`, so that all `Alloc`s can be rewritten.
            if idx not in alloc_idxs:
                ref_var_idx = idx
                break
...
@@ -1585,7 +1594,7 @@ def local_elemwise_alloc(fgraph, node):
            new_inputs[idx] = new_alloc

    # If this assert is triggered, it means we are recreating an equivalent graph
    # which would result in cyclical merge rewrites.
    if all(new is old for new, old in zip(new_inputs, node.inputs)):
        return
...
@@ -1702,9 +1711,9 @@ compile.optdb.register(
def local_useless_fill(fgraph, node):
    """fill(s,v) -> v

    This rewrite is only needed in FAST_COMPILE mode to make the code
    more readable.  Normally, it is done by the `local_fill_to_alloc`
    rewrite.

    """
    r, v = node.inputs
...
@@ -1789,9 +1798,9 @@ def local_alloc_sink_dimshuffle(fgraph, node):
def local_alloc_empty_to_zeros(fgraph, node):
    """This convert AllocEmpty to Alloc of 0.

    This helps one investigate NaNs in `NanGuardMode`.  Not registered by
    default. To activate it, use the setting
    ``optimizer_including == alloc_empty_to_zeros``.

    """
    if isinstance(node.op, AllocEmpty):
        return [zeros(node.inputs, dtype=node.outputs[0].dtype)]
...
@@ -1811,7 +1820,6 @@ compile.optdb.register(
@node_rewriter([Shape])
def local_shape_to_shape_i(fgraph, node):
    if isinstance(node.op, Shape):
        if not hasattr(fgraph, "shape_feature"):
            return
        shape_feature = fgraph.shape_feature
...
@@ -1850,16 +1858,18 @@ def local_track_shape_i(fgraph, node):
@node_rewriter([Elemwise])
def local_useless_elemwise(fgraph, node):
    """
    eq(x, x) -> 1
    neq(x, x) -> 0
    mul(x) -> x
    add(x) -> x
    identity(x) -> x

    and(x, 1) -> x  (if x.dtype == 'bool')
    and(x, 0) -> zeros_like(x)
    or(x, 0) -> x
    or(x, 1) -> ones_like(x)  (if x.dtype == 'bool')
    xor(x, x) -> zeros_like(x)

    TODO: This implementation is painfully redundant.

    """
    if isinstance(node.op, Elemwise):
...
@@ -1905,7 +1915,7 @@ def local_useless_elemwise(fgraph, node):
                    return [zeros_like(node.inputs[1], dtype=dtype, opt=True)]
                elif node.outputs[0].dtype == "bool":
                    # If the output is not Boolean, it is the bitwise AND,
                    # and this rewrite would be wrong
                    return [node.inputs[1].astype(node.outputs[0].dtype)]

            if isinstance(node.inputs[1], TensorConstant):
...
@@ -1917,7 +1927,7 @@ def local_useless_elemwise(fgraph, node):
                    return [zeros_like(node.inputs[0], dtype=dtype, opt=True)]
                elif node.outputs[0].dtype == "bool":
                    # If the output is not Boolean, it is the bitwise AND,
                    # and this rewrite would be wrong
                    return [node.inputs[0].astype(node.outputs[0].dtype)]

        elif isinstance(node.op.scalar_op, aes.OR) and len(node.inputs) == 2:
...
@@ -1931,7 +1941,7 @@ def local_useless_elemwise(fgraph, node):
                    return [node.inputs[1].astype(node.outputs[0].dtype)]
                elif node.outputs[0].dtype == "bool":
                    # If the output is not Boolean, it is the bitwise OR,
                    # and this rewrite would be wrong
                    return [ones_like(node.inputs[1], dtype=dtype, opt=True)]

            if isinstance(node.inputs[1], TensorConstant):
...
@@ -1943,7 +1953,7 @@ def local_useless_elemwise(fgraph, node):
                    return [node.inputs[0].astype(node.outputs[0].dtype)]
                elif node.outputs[0].dtype == "bool":
                    # If the output is not Boolean, it is the bitwise OR,
                    # and this rewrite would be wrong
                    return [ones_like(node.inputs[0], dtype=dtype, opt=True)]

        elif isinstance(node.op.scalar_op, aes.XOR) and len(node.inputs) == 2:
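The boolean identities that `local_useless_elemwise` exploits can be spot-checked with plain Python boolean semantics. The sketch below is illustrative only (it is not Aesara code, and the helper names are invented); it confirms that the ``and``/``or`` rules in the docstring are exact for boolean operands.

```python
# Toy versions of the and/or simplifications for boolean x, as listed in the
# local_useless_elemwise docstring above.


def useless_and(x, const):
    # and(x, 1) -> x ; and(x, 0) -> zeros_like(x)
    return x if const == 1 else False


def useless_or(x, const):
    # or(x, 0) -> x ; or(x, 1) -> ones_like(x)
    return x if const == 0 else True


for x in (True, False):
    assert useless_and(x, 1) == (x and True)
    assert useless_and(x, 0) == (x and False)
    assert useless_or(x, 0) == (x or False)
    assert useless_or(x, 1) == (x or True)
print("identities hold")
```

This is also why the rewrite checks ``node.outputs[0].dtype == "bool"``: for integer dtypes the same `Op` is bitwise, and replacing ``and(x, 1)`` with ``x`` would be wrong.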
...
@@ -2081,12 +2091,11 @@ def local_remove_useless_assert(fgraph, node):
@node_rewriter([Assert])
def local_remove_all_assert(fgraph, node):
    r"""A rewrite that removes all `Assert`\s from a graph.

    Notes
    -----
    See the :ref:`unsafe` section.

    """
    if not isinstance(node.op, Assert):
...
@@ -2346,7 +2355,7 @@ def local_join_make_vector(fgraph, node):
    Join(0, make_vector1, make_vector2, ...) => Join(0, make_vector12, ...)

    This, in combination with the `local_join_1` rewrite, can make `Join`\s
    completely disappear.
    """
    if not isinstance(node.op, Join) or node.outputs[0].ndim != 1:
...
@@ -2388,16 +2397,16 @@ def local_join_make_vector(fgraph, node):
@node_rewriter([Elemwise])
def local_useless_switch(fgraph, node):
    """
    This rewrite makes the following changes in a graph:

    ``at.switch(cond, left, right)`` ->
        ``if cond is constant and cond == 0``: right
        ``if cond is constant and cond != 0``: left
        ``if left is right`` -> ``left``

    and

    ``at.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X))`` -> ``shape_i{id}(X)``

    """
    if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, aes.Switch):
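The `switch` simplifications listed in that docstring can be sketched with a plain Python stand-in for ``at.switch`` (an illustration under assumed names, not Aesara's implementation): a constant condition selects a branch statically, and identical branches make the switch useless.

```python
# Toy simplifier mirroring the local_useless_switch cases above.


def simplify_switch(cond, left, right):
    # if cond is a known constant, pick a branch statically
    if isinstance(cond, (bool, int)):
        return left if cond != 0 else right
    # if both branches are the same object, the switch is useless
    if left is right:
        return left
    # otherwise keep a symbolic switch node
    return ("switch", cond, left, right)


x = object()
assert simplify_switch(0, "left", "right") == "right"
assert simplify_switch(1, "left", "right") == "left"
assert simplify_switch("cond", x, x) is x
print("switch simplifications hold")
```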
...
@@ -2545,7 +2554,7 @@ def local_reshape_chain(op):
        # replaced the shape by one for which this cannot be guessed.
        # We should try to figure out why we lost the information about this
        # constant value... but in the meantime, better not apply this
        # rewrite.
        if rval.broadcastable == node.outputs[0].broadcastable:
            return [rval]
        else:
...
@@ -2709,10 +2718,12 @@ def local_reshape_to_dimshuffle(fgraph, node):
@node_rewriter([Reshape])
def local_reshape_lift(fgraph, node):
    """
    Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x))

    Notes
    -----
    This rewrite is needed by `log1msigm_to_softplus` in order to get applied
    when there is a reshape.

    """
    if (
...
@@ -2840,10 +2851,9 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
    The number of dimensions is validated at call time by Aesara itself.

    """
    # TODO: use broadcast flag?

    # TODO: don't do this rewrite as a `NodeRewriter`.
    # Analyze the graph in terms of elemwise subgraphs, and then
    # replace each subgraph with a Composite version.
...
@@ -2851,8 +2861,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
    # fit within the parameter space of 256 bytes
    #
    # TODO: Merge with multiple output to merge when an inputs
    # have multiple clients. This can't be done with a `NodeRewriter`

    # TODO: Related: Support composites with multiple outputs
...
@@ -2963,7 +2972,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
                        except (NotImplementedError, MethodNotDefined):
                            _logger.warning(
                                (
                                    "Rewrite warning: "
                                    f"The Op {i.owner.op.scalar_op} does not provide a C implementation."
                                    " As well as being potentially slow, this also disables "
                                    "loop fusion."
...
@@ -3015,10 +3024,9 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
            return False

    if new_nb_input != len(inputs) or len(s_inputs) != len(inputs):
        # TODO FIXME: This shouldn't be a generic `Exception`
        raise Exception(
            "Something has gone wrong with the elemwise fusion rewrite; skipping."
        )

    s_new_out = node.op.scalar_op(*s_g, return_list=True)
...
@@ -3034,7 +3042,7 @@ your code will run correctly, but may be slower."""
            name = str(s_new_out[0].owner.op)
            _logger.warning(
                (
                    "Rewrite warning: "
                    f"The Op {name} does not provide a C implementation."
                    " As well as being potentially slow, this also disables "
                    "loop fusion."
...
@@ -3086,15 +3094,15 @@ local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fc
class FusionOptimizer(GraphRewriter):
    """Graph rewriter that simply runs node fusion operations.

    TODO: This is basically an `EquilibriumOptimizer`; we should just use that.

    """

    def __init__(self, node_rewriter):
        super().__init__()
        self.node_rewriter = node_rewriter

    def add_requirements(self, fgraph):
        fgraph.attach_feature(ReplaceValidate())
...
@@ -3118,7 +3126,7 @@ class FusionOptimizer(GraphRewriter):
            for node in nodelist:
                # Don't try to fuse node that have already been fused.
                if node in fgraph.apply_nodes:
                    new_outputs = self.node_rewriter(fgraph, node)
                    if new_outputs:
                        assert len(new_outputs) == len(node.outputs)
                        try:
...
@@ -3174,7 +3182,7 @@ class FusionOptimizer(GraphRewriter):
if config.tensor__local_elemwise_fusion:
    _logger.debug("Enabling Elemwise fusion rewriters in fast_run")
    # Must be after gpu(48.5) and before AddDestroyHandler(49.5)
    fuse_seqopt = SequenceDB()
    fuse_seqopt.register(
...
@@ -3194,7 +3202,7 @@ if config.tensor__local_elemwise_fusion:
        position=49,
    )
else:
    _logger.debug("Not enabling Elemwise fusion rewriters in fast_run")

compile.optdb.register(
    "elemwise_fusion",
    FusionOptimizer(local_elemwise_fusion),
...
@@ -3239,10 +3247,14 @@ def local_view_op(fgraph, node):
@register_specialize
@node_rewriter([Alloc])
def local_merge_alloc(fgraph, node):
    """
    This rewriter takes care of the following cases:

        Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
        Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
        Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) -> Alloc(m, x, assert(y1, y1==y2), z, w)

    """
    if not isinstance(node.op, Alloc):
        return False
    if not node.inputs[0].owner or not isinstance(node.inputs[0].owner.op, Alloc):
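The nested-`Alloc` merge above can be sketched on plain shape tuples instead of Aesara graphs (an illustration only, with an invented `merge_alloc` helper): broadcasting an already-broadcast value collapses into a single broadcast, provided each non-unit inner dimension agrees with the corresponding outer one, which mirrors the ``assert(y1, y1 == y2)`` case.

```python
# Toy shape-level version of the local_merge_alloc cases above.


def merge_alloc(inner_shape, outer_shape):
    # Align the inner shape against the trailing dimensions of the outer one.
    offset = len(outer_shape) - len(inner_shape)
    for inner_dim, outer_dim in zip(inner_shape, outer_shape[offset:]):
        # Inner dims must be 1 (broadcastable) or equal to the outer dim.
        if inner_dim != 1 and inner_dim != outer_dim:
            raise ValueError(f"incompatible dims: {inner_dim} vs {outer_dim}")
    return outer_shape


# Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
print(merge_alloc((5, 1, 1), (2, 5, 3, 4)))  # -> (2, 5, 3, 4)
```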
...
@@ -3276,11 +3288,7 @@ def local_merge_alloc(fgraph, node):
@register_useless("fast_compile")
@node_rewriter([TopKOp])
def local_useless_topk(fgraph, node):
    """Remove unused `TopKOp` outputs."""
    op = node.op
    if not isinstance(op, TopKOp):
        return
...

aesara/tensor/nnet/basic.py (view file @ 0ce6eceb)

...
@@ -1849,16 +1849,6 @@ crossentropy_categorical_1hot = CrossentropyCategorical1Hot()
@register_specialize("fast_compile")
@optimizer
def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
    def search_make_one_sub():
        for node in fgraph.toposort():
            if node.op == crossentropy_categorical_1hot:
...
@@ -1887,18 +1877,13 @@ def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph):
@optimizer
def crossentropy_to_crossentropy_with_softmax(fgraph):
    """
    This is a stabilization rewrite that is more general than
    `crossentropy_to_crossentropy_with_softmax_with_bias`.

    Notes
    -----
    It must be executed after `local_softmax_with_bias` during the
    specialization passes.

    """
...