Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
550a6e98
提交
550a6e98
authored
7月 14, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
8月 17, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rename LocalOptimizer to NodeRewriter
上级
214ef4cf
显示空白字符变更
内嵌
并排
正在显示
29 个修改的文件
包含
479 行增加
和
444 行删除
+479
-444
builders.py
aesara/compile/builders.py
+2
-2
__init__.py
aesara/graph/__init__.py
+1
-1
kanren.py
aesara/graph/kanren.py
+2
-2
opt.py
aesara/graph/opt.py
+183
-142
optdb.py
aesara/graph/optdb.py
+8
-14
ifelse.py
aesara/ifelse.py
+7
-7
ops.py
aesara/sandbox/linalg/ops.py
+8
-8
rng_mrg.py
aesara/sandbox/rng_mrg.py
+2
-2
opt.py
aesara/scan/opt.py
+8
-8
opt.py
aesara/sparse/opt.py
+14
-14
basic_opt.py
aesara/tensor/basic_opt.py
+49
-49
blas.py
aesara/tensor/blas.py
+10
-10
blas_c.py
aesara/tensor/blas_c.py
+5
-5
blas_scipy.py
aesara/tensor/blas_scipy.py
+3
-3
math_opt.py
aesara/tensor/math_opt.py
+58
-58
basic.py
aesara/tensor/nnet/basic.py
+10
-10
batchnorm.py
aesara/tensor/nnet/batchnorm.py
+4
-4
conv3d2d.py
aesara/tensor/nnet/conv3d2d.py
+2
-2
ctc.py
aesara/tensor/nnet/ctc.py
+2
-2
opt.py
aesara/tensor/nnet/opt.py
+13
-13
sigm.py
aesara/tensor/nnet/sigm.py
+3
-3
opt_uncanonicalize.py
aesara/tensor/opt_uncanonicalize.py
+7
-7
opt.py
aesara/tensor/random/opt.py
+5
-5
subtensor_opt.py
aesara/tensor/subtensor_opt.py
+27
-27
opt.py
aesara/typed_list/opt.py
+2
-2
graph_rewriting.rst
doc/extending/graph_rewriting.rst
+17
-17
test_debugmode.py
tests/compile/test_debugmode.py
+5
-5
test_opt.py
tests/graph/test_opt.py
+20
-20
test_basic_opt.py
tests/tensor/test_basic_opt.py
+2
-2
没有找到文件。
aesara/compile/builders.py
浏览文件 @
550a6e98
...
@@ -24,7 +24,7 @@ from aesara.graph.basic import (
...
@@ -24,7 +24,7 @@ from aesara.graph.basic import (
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.null_type
import
NullType
from
aesara.graph.null_type
import
NullType
from
aesara.graph.op
import
HasInnerGraph
,
Op
from
aesara.graph.op
import
HasInnerGraph
,
Op
from
aesara.graph.opt
import
in2out
,
local_optimiz
er
from
aesara.graph.opt
import
in2out
,
node_rewrit
er
from
aesara.graph.utils
import
MissingInputError
from
aesara.graph.utils
import
MissingInputError
from
aesara.tensor.basic_opt
import
ShapeFeature
from
aesara.tensor.basic_opt
import
ShapeFeature
...
@@ -928,7 +928,7 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -928,7 +928,7 @@ class OpFromGraph(Op, HasInnerGraph):
output
[
0
]
=
variable
output
[
0
]
=
variable
@
local_optimiz
er
([
OpFromGraph
])
@
node_rewrit
er
([
OpFromGraph
])
def
inline_ofg_expansion
(
fgraph
,
node
):
def
inline_ofg_expansion
(
fgraph
,
node
):
"""
"""
This optimization expands internal graph of OpFromGraph.
This optimization expands internal graph of OpFromGraph.
...
...
aesara/graph/__init__.py
浏览文件 @
550a6e98
...
@@ -13,7 +13,7 @@ from aesara.graph.basic import (
...
@@ -13,7 +13,7 @@ from aesara.graph.basic import (
from
aesara.graph.op
import
Op
from
aesara.graph.op
import
Op
from
aesara.graph.type
import
Type
from
aesara.graph.type
import
Type
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.opt
import
local_optimiz
er
,
optimizer
from
aesara.graph.opt
import
node_rewrit
er
,
optimizer
from
aesara.graph.opt_utils
import
optimize_graph
from
aesara.graph.opt_utils
import
optimize_graph
from
aesara.graph.optdb
import
OptimizationQuery
from
aesara.graph.optdb
import
OptimizationQuery
...
...
aesara/graph/kanren.py
浏览文件 @
550a6e98
...
@@ -6,11 +6,11 @@ from unification import var
...
@@ -6,11 +6,11 @@ from unification import var
from
unification.variable
import
Var
from
unification.variable
import
Var
from
aesara.graph.basic
import
Apply
,
Variable
from
aesara.graph.basic
import
Apply
,
Variable
from
aesara.graph.opt
import
LocalOptimiz
er
from
aesara.graph.opt
import
NodeRewrit
er
from
aesara.graph.unify
import
eval_if_etuple
from
aesara.graph.unify
import
eval_if_etuple
class
KanrenRelationSub
(
LocalOptimiz
er
):
class
KanrenRelationSub
(
NodeRewrit
er
):
r"""A local optimizer that uses `kanren` to match and replace terms.
r"""A local optimizer that uses `kanren` to match and replace terms.
See `kanren <https://github.com/pythological/kanren>`__ for more information
See `kanren <https://github.com/pythological/kanren>`__ for more information
...
...
aesara/graph/opt.py
浏览文件 @
550a6e98
...
@@ -48,7 +48,7 @@ FailureCallbackType = Callable[
...
@@ -48,7 +48,7 @@ FailureCallbackType = Callable[
Exception
,
Exception
,
"NavigatorOptimizer"
,
"NavigatorOptimizer"
,
List
[
Tuple
[
Variable
,
None
]],
List
[
Tuple
[
Variable
,
None
]],
"
LocalOptimiz
er"
,
"
NodeRewrit
er"
,
Apply
,
Apply
,
],
],
None
,
None
,
...
@@ -142,13 +142,13 @@ class GraphRewriter(Rewriter):
...
@@ -142,13 +142,13 @@ class GraphRewriter(Rewriter):
)
)
class
LocalOptimiz
er
(
Rewriter
):
class
NodeRewrit
er
(
Rewriter
):
"""A
node-based optimizer
."""
"""A
`Rewriter` that is applied to an `Apply` node
."""
def
tracks
(
self
):
def
tracks
(
self
)
->
Optional
[
Sequence
[
Op
]]
:
"""Return the list of `Op` classes to which this
optimization
applies.
"""Return the list of `Op` classes to which this
rewrite
applies.
Returns ``None`` when the
optimization
applies to all nodes.
Returns ``None`` when the
rewrite
applies to all nodes.
"""
"""
return
None
return
None
...
@@ -162,23 +162,22 @@ class LocalOptimizer(Rewriter):
...
@@ -162,23 +162,22 @@ class LocalOptimizer(Rewriter):
Subclasses should implement this function so that it returns one of the
Subclasses should implement this function so that it returns one of the
following:
following:
- ``False`` to indicate that no optimization can be applied to this `node`;
- ``False`` to indicate that this rewrite cannot be applied to `node`
- A list of `Variable`\s to use in place of the `node`'s current outputs.
- A list of `Variable`\s to use in place of the `node`'s current outputs
- A ``dict`` mapping old `Variable`\s to `Variable`\s.
- A ``dict`` mapping old `Variable`\s to new `Variable`\s
Parameters
Parameters
----------
----------
fgraph
:
fgraph
A `FunctionGraph` containing `node`.
A `FunctionGraph` containing `node`.
node
:
node
An `Apply` node to be
transformed
.
An `Apply` node to be
rewritten
.
"""
"""
raise
NotImplementedError
()
raise
NotImplementedError
()
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
:
FunctionGraph
):
r"""Add required `Feature`\s to `fgraph`."""
r"""Add required `Feature`\s to `fgraph`."""
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
...
@@ -939,9 +938,9 @@ def pre_constant_merge(fgraph, variables):
...
@@ -939,9 +938,9 @@ def pre_constant_merge(fgraph, variables):
return
[
recursive_merge
(
v
)
for
v
in
variables
]
return
[
recursive_merge
(
v
)
for
v
in
variables
]
class
LocalMetaOptimizer
(
LocalOptimiz
er
):
class
LocalMetaOptimizer
(
NodeRewrit
er
):
r"""
r"""
Base class for meta-optimizers that try a set of `
LocalOptimiz
er`\s
Base class for meta-optimizers that try a set of `
NodeRewrit
er`\s
to replace a node and choose the one that executes the fastest.
to replace a node and choose the one that executes the fastest.
If the error `MetaNodeRewriterSkip` is raised during
If the error `MetaNodeRewriterSkip` is raised during
...
@@ -1058,8 +1057,8 @@ class LocalMetaOptimizer(LocalOptimizer):
...
@@ -1058,8 +1057,8 @@ class LocalMetaOptimizer(LocalOptimizer):
return
time
.
time
()
-
start
return
time
.
time
()
-
start
class
FromFunctionLocalOptimizer
(
LocalOptimiz
er
):
class
FromFunctionLocalOptimizer
(
NodeRewrit
er
):
"""A `
LocalOptimiz
er` constructed from a function."""
"""A `
NodeRewrit
er` constructed from a function."""
def
__init__
(
self
,
fn
,
tracks
=
None
,
requirements
=
()):
def
__init__
(
self
,
fn
,
tracks
=
None
,
requirements
=
()):
self
.
fn
=
fn
self
.
fn
=
fn
...
@@ -1095,7 +1094,7 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
...
@@ -1095,7 +1094,7 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
print
(
f
"{' ' * level}{self.transform} id={id(self)}"
,
file
=
stream
)
print
(
f
"{' ' * level}{self.transform} id={id(self)}"
,
file
=
stream
)
def
local_optimiz
er
(
def
node_rewrit
er
(
tracks
:
Optional
[
Sequence
[
Union
[
Op
,
type
]]],
tracks
:
Optional
[
Sequence
[
Union
[
Op
,
type
]]],
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
requirements
:
Optional
[
Tuple
[
type
,
...
]]
=
(),
requirements
:
Optional
[
Tuple
[
type
,
...
]]
=
(),
...
@@ -1150,12 +1149,12 @@ class LocalOptTracker:
...
@@ -1150,12 +1149,12 @@ class LocalOptTracker:
r"""A container that maps rewrites to `Op` instances and `Op`-type inheritance."""
r"""A container that maps rewrites to `Op` instances and `Op`-type inheritance."""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
tracked_instances
:
Dict
[
Op
,
List
[
LocalOptimiz
er
]]
=
{}
self
.
tracked_instances
:
Dict
[
Op
,
List
[
NodeRewrit
er
]]
=
{}
self
.
tracked_types
:
Dict
[
type
,
List
[
LocalOptimiz
er
]]
=
{}
self
.
tracked_types
:
Dict
[
type
,
List
[
NodeRewrit
er
]]
=
{}
self
.
untracked_opts
:
List
[
LocalOptimiz
er
]
=
[]
self
.
untracked_opts
:
List
[
NodeRewrit
er
]
=
[]
def
add_tracker
(
self
,
rw
:
LocalOptimiz
er
):
def
add_tracker
(
self
,
rw
:
NodeRewrit
er
):
"""Add a `
LocalOptimizer` to be keyed by its `LocalOptimiz
er.tracks` or applied generally."""
"""Add a `
NodeRewriter` to be keyed by its `NodeRewrit
er.tracks` or applied generally."""
tracks
=
rw
.
tracks
()
tracks
=
rw
.
tracks
()
if
tracks
is
None
:
if
tracks
is
None
:
...
@@ -1167,8 +1166,8 @@ class LocalOptTracker:
...
@@ -1167,8 +1166,8 @@ class LocalOptTracker:
else
:
else
:
self
.
tracked_instances
.
setdefault
(
c
,
[])
.
append
(
rw
)
self
.
tracked_instances
.
setdefault
(
c
,
[])
.
append
(
rw
)
def
_find_impl
(
self
,
cls
)
->
List
[
LocalOptimiz
er
]:
def
_find_impl
(
self
,
cls
)
->
List
[
NodeRewrit
er
]:
r"""Returns the `
LocalOptimiz
er`\s that apply to `cls` based on inheritance.
r"""Returns the `
NodeRewrit
er`\s that apply to `cls` based on inheritance.
This based on `functools._find_impl`.
This based on `functools._find_impl`.
"""
"""
...
@@ -1181,7 +1180,7 @@ class LocalOptTracker:
...
@@ -1181,7 +1180,7 @@ class LocalOptTracker:
return
matches
return
matches
@functools.lru_cache
()
@functools.lru_cache
()
def
get_trackers
(
self
,
op
:
Op
)
->
List
[
LocalOptimiz
er
]:
def
get_trackers
(
self
,
op
:
Op
)
->
List
[
NodeRewrit
er
]:
"""Get all the rewrites applicable to `op`."""
"""Get all the rewrites applicable to `op`."""
return
(
return
(
self
.
_find_impl
(
type
(
op
))
self
.
_find_impl
(
type
(
op
))
...
@@ -1198,8 +1197,8 @@ class LocalOptTracker:
...
@@ -1198,8 +1197,8 @@ class LocalOptTracker:
)
)
class
LocalOptGroup
(
LocalOptimiz
er
):
class
LocalOptGroup
(
NodeRewrit
er
):
r"""An optimizer that applies a list of `
LocalOptimiz
er`\s to a node.
r"""An optimizer that applies a list of `
NodeRewrit
er`\s to a node.
Attributes
Attributes
----------
----------
...
@@ -1390,7 +1389,7 @@ class LocalOptGroup(LocalOptimizer):
...
@@ -1390,7 +1389,7 @@ class LocalOptGroup(LocalOptimizer):
opt
.
add_requirements
(
fgraph
)
opt
.
add_requirements
(
fgraph
)
class
OpSub
(
LocalOptimiz
er
):
class
OpSub
(
NodeRewrit
er
):
"""
"""
Replaces the application of a certain `Op` by the application of
Replaces the application of a certain `Op` by the application of
...
@@ -1440,7 +1439,7 @@ class OpSub(LocalOptimizer):
...
@@ -1440,7 +1439,7 @@ class OpSub(LocalOptimizer):
return
f
"{self.op1} -> {self.op2}"
return
f
"{self.op1} -> {self.op2}"
class
OpRemove
(
LocalOptimiz
er
):
class
OpRemove
(
NodeRewrit
er
):
"""
"""
Removes all applications of an `Op` by transferring each of its
Removes all applications of an `Op` by transferring each of its
outputs to the corresponding input.
outputs to the corresponding input.
...
@@ -1473,7 +1472,7 @@ class OpRemove(LocalOptimizer):
...
@@ -1473,7 +1472,7 @@ class OpRemove(LocalOptimizer):
)
)
class
PatternSub
(
LocalOptimiz
er
):
class
PatternSub
(
NodeRewrit
er
):
"""Replace all occurrences of an input pattern with an output pattern.
"""Replace all occurrences of an input pattern with an output pattern.
The input and output patterns have the following syntax:
The input and output patterns have the following syntax:
...
@@ -1719,44 +1718,20 @@ class Updater(Feature):
...
@@ -1719,44 +1718,20 @@ class Updater(Feature):
class
NavigatorOptimizer
(
GraphRewriter
):
class
NavigatorOptimizer
(
GraphRewriter
):
r"""An optimizer that applies a `
LocalOptimiz
er` with considerations for the new nodes it creates.
r"""An optimizer that applies a `
NodeRewrit
er` with considerations for the new nodes it creates.
This optimizer also allows the `
LocalOptimiz
er` to use a special ``"remove"`` value
This optimizer also allows the `
NodeRewrit
er` to use a special ``"remove"`` value
in the ``dict``\s returned by :meth:`
LocalOptimiz
er`. `Variable`\s mapped to this
in the ``dict``\s returned by :meth:`
NodeRewrit
er`. `Variable`\s mapped to this
value are removed from the `FunctionGraph`.
value are removed from the `FunctionGraph`.
Parameters
----------
local_opt :
A `LocalOptimizer` to apply over a `FunctionGraph` (or ``None``).
ignore_newtrees :
- ``True``: new subgraphs returned by an optimization are not a
candidate for optimization.
- ``False``: new subgraphs returned by an optimization is a candidate
for optimization.
- ``'auto'``: let the `local_opt` set this parameter via its :attr:`reentrant`
attribute.
failure_callback
A function with the signature ``(exception, navigator, [(old, new),
(old,new),...])`` that is called when there's an exception.
If the exception is raised in ``local_opt.transform``, the ``new`` variables
will be ``None``.
If the exception is raised during validation (e.g. the new types don't
match) then the new variables will be the ones created by ``self.transform``.
If this parameter is ``None``, then exceptions are not caught here and
are raised normally.
"""
"""
@staticmethod
@staticmethod
def
warn
(
exc
,
nav
,
repl_pairs
,
local_opt
,
node
):
def
warn
(
exc
,
nav
,
repl_pairs
,
node_rewriter
,
node
):
"""A failure callback that prints a traceback."""
"""A failure callback that prints a traceback."""
if
config
.
on_opt_error
!=
"ignore"
:
if
config
.
on_opt_error
!=
"ignore"
:
_logger
.
error
(
f
"Optimization failure due to: {
local_opt
}"
)
_logger
.
error
(
f
"Optimization failure due to: {
node_rewriter
}"
)
_logger
.
error
(
f
"node: {node}"
)
_logger
.
error
(
f
"node: {node}"
)
_logger
.
error
(
"TRACEBACK:"
)
_logger
.
error
(
"TRACEBACK:"
)
_logger
.
error
(
traceback
.
format_exc
())
_logger
.
error
(
traceback
.
format_exc
())
...
@@ -1768,30 +1743,59 @@ class NavigatorOptimizer(GraphRewriter):
...
@@ -1768,30 +1743,59 @@ class NavigatorOptimizer(GraphRewriter):
raise
exc
raise
exc
@staticmethod
@staticmethod
def
warn_inplace
(
exc
,
nav
,
repl_pairs
,
local_opt
,
node
):
def
warn_inplace
(
exc
,
nav
,
repl_pairs
,
node_rewriter
,
node
):
r"""A failure callback that ignores `
`InconsistencyError`
`\s and prints a traceback.
r"""A failure callback that ignores `
InconsistencyError
`\s and prints a traceback.
If the error occurred during replacement, `
`repl_pairs`
` is set;
If the error occurred during replacement, `
repl_pairs
` is set;
otherwise, its value is ``None``.
otherwise, its value is ``None``.
"""
"""
if
isinstance
(
exc
,
InconsistencyError
):
if
isinstance
(
exc
,
InconsistencyError
):
return
return
return
NavigatorOptimizer
.
warn
(
exc
,
nav
,
repl_pairs
,
local_opt
,
node
)
return
NavigatorOptimizer
.
warn
(
exc
,
nav
,
repl_pairs
,
node_rewriter
,
node
)
@staticmethod
@staticmethod
def
warn_ignore
(
exc
,
nav
,
repl_pairs
,
local_opt
,
node
):
def
warn_ignore
(
exc
,
nav
,
repl_pairs
,
node_rewriter
,
node
):
"""A failure callback that ignores all errors."""
"""A failure callback that ignores all errors."""
def
__init__
(
def
__init__
(
self
,
self
,
local_opt
:
LocalOptimizer
,
node_rewriter
:
Optional
[
NodeRewriter
]
,
ignore_newtrees
:
Literal
[
True
,
False
,
"auto"
],
ignore_newtrees
:
Literal
[
True
,
False
,
"auto"
],
failure_callback
:
Optional
[
FailureCallbackType
]
=
None
,
failure_callback
:
Optional
[
FailureCallbackType
]
=
None
,
):
):
self
.
local_opt
=
local_opt
"""
Parameters
----------
node_rewriter
A `NodeRewriter` to apply over a `FunctionGraph` (or ``None``).
ignore_newtrees
- ``True``: new subgraphs returned by an optimization are not a
candidate for optimization.
- ``False``: new subgraphs returned by an optimization is a
candidate for optimization.
- ``'auto'``: let the `node_rewriter` set this parameter via its
:attr:`reentrant` attribute.
failure_callback
A function with the signature
``(exception, navigator, [(old, new), (old,new),...])``
that is called when there's an exception.
If the exception is raised in `node_rewriter.transform`, the
``new`` variables will be ``None``.
If the exception is raised during validation (e.g. the new types
don't match) then the new variables will be the ones created by
``self.transform``.
If this parameter is ``None``, then exceptions are not caught here
and are raised normally.
"""
self
.
node_rewriter
=
node_rewriter
if
ignore_newtrees
==
"auto"
:
if
ignore_newtrees
==
"auto"
:
self
.
ignore_newtrees
=
not
getattr
(
local_opt
,
"reentrant"
,
True
)
self
.
ignore_newtrees
=
not
getattr
(
node_rewriter
,
"reentrant"
,
True
)
else
:
else
:
self
.
ignore_newtrees
=
ignore_newtrees
self
.
ignore_newtrees
=
ignore_newtrees
self
.
failure_callback
=
failure_callback
self
.
failure_callback
=
failure_callback
...
@@ -1865,7 +1869,7 @@ class NavigatorOptimizer(GraphRewriter):
...
@@ -1865,7 +1869,7 @@ class NavigatorOptimizer(GraphRewriter):
node :
node :
An `Apply` instance in `fgraph`
An `Apply` instance in `fgraph`
lopt :
lopt :
A `
LocalOptimiz
er` instance that may have a better idea for
A `
NodeRewrit
er` instance that may have a better idea for
how to compute node's outputs.
how to compute node's outputs.
Returns
Returns
...
@@ -1874,7 +1878,7 @@ class NavigatorOptimizer(GraphRewriter):
...
@@ -1874,7 +1878,7 @@ class NavigatorOptimizer(GraphRewriter):
``True`` iff the `node`'s outputs were replaced in the `fgraph`.
``True`` iff the `node`'s outputs were replaced in the `fgraph`.
"""
"""
lopt
=
lopt
or
self
.
local_opt
lopt
=
lopt
or
self
.
node_rewriter
try
:
try
:
replacements
=
lopt
.
transform
(
fgraph
,
node
)
replacements
=
lopt
.
transform
(
fgraph
,
node
)
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -1896,19 +1900,17 @@ class NavigatorOptimizer(GraphRewriter):
...
@@ -1896,19 +1900,17 @@ class NavigatorOptimizer(GraphRewriter):
replacements
=
list
(
replacements
.
values
())
replacements
=
list
(
replacements
.
values
())
elif
not
isinstance
(
replacements
,
(
tuple
,
list
)):
elif
not
isinstance
(
replacements
,
(
tuple
,
list
)):
raise
TypeError
(
raise
TypeError
(
f
"
Local optimiz
er {lopt} gave wrong type of replacement. "
f
"
Node rewrit
er {lopt} gave wrong type of replacement. "
f
"Expected list or tuple; got {replacements}"
f
"Expected list or tuple; got {replacements}"
)
)
if
len
(
old_vars
)
!=
len
(
replacements
):
if
len
(
old_vars
)
!=
len
(
replacements
):
raise
ValueError
(
raise
ValueError
(
f
"Node rewriter {lopt} gave wrong number of replacements"
)
f
"Local optimizer {lopt} gave wrong number of replacements"
)
# None in the replacement mean that this variable isn't used
# None in the replacement mean that this variable isn't used
# and we want to remove it
# and we want to remove it
for
r
,
rnew
in
zip
(
old_vars
,
replacements
):
for
r
,
rnew
in
zip
(
old_vars
,
replacements
):
if
rnew
is
None
and
len
(
fgraph
.
clients
[
r
])
>
0
:
if
rnew
is
None
and
len
(
fgraph
.
clients
[
r
])
>
0
:
raise
ValueError
(
raise
ValueError
(
f
"
Local optimiz
er {lopt} tried to remove a variable"
f
"
Node rewrit
er {lopt} tried to remove a variable"
f
" that is being used: {r}"
f
" that is being used: {r}"
)
)
# If an output would be replaced by itself, no need to perform
# If an output would be replaced by itself, no need to perform
...
@@ -1939,21 +1941,23 @@ class NavigatorOptimizer(GraphRewriter):
...
@@ -1939,21 +1941,23 @@ class NavigatorOptimizer(GraphRewriter):
super
()
.
add_requirements
(
fgraph
)
super
()
.
add_requirements
(
fgraph
)
# Added by default
# Added by default
# fgraph.attach_feature(ReplaceValidate())
# fgraph.attach_feature(ReplaceValidate())
if
self
.
local_opt
:
if
self
.
node_rewriter
:
self
.
local_opt
.
add_requirements
(
fgraph
)
self
.
node_rewriter
.
add_requirements
(
fgraph
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
(
f
"{' ' * level}{self.__class__.__name__} id={id(self)}"
,
file
=
stream
)
print
(
f
"{' ' * level}{self.__class__.__name__} id={id(self)}"
,
file
=
stream
)
if
depth
!=
0
:
if
depth
!=
0
:
self
.
local_opt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
depth
=
(
depth
-
1
))
self
.
node_rewriter
.
print_summary
(
stream
,
level
=
(
level
+
2
),
depth
=
(
depth
-
1
)
)
class
TopoOptimizer
(
NavigatorOptimizer
):
class
TopoOptimizer
(
NavigatorOptimizer
):
"""An optimizer that applies a single `
LocalOptimiz
er` to each node in topological order (or reverse)."""
"""An optimizer that applies a single `
NodeRewrit
er` to each node in topological order (or reverse)."""
def
__init__
(
def
__init__
(
self
,
self
,
local_opt
:
LocalOptimiz
er
,
node_rewriter
:
NodeRewrit
er
,
order
:
Literal
[
"out_to_in"
,
"in_to_out"
]
=
"in_to_out"
,
order
:
Literal
[
"out_to_in"
,
"in_to_out"
]
=
"in_to_out"
,
ignore_newtrees
:
bool
=
False
,
ignore_newtrees
:
bool
=
False
,
failure_callback
:
Optional
[
FailureCallbackType
]
=
None
,
failure_callback
:
Optional
[
FailureCallbackType
]
=
None
,
...
@@ -1961,7 +1965,7 @@ class TopoOptimizer(NavigatorOptimizer):
...
@@ -1961,7 +1965,7 @@ class TopoOptimizer(NavigatorOptimizer):
if
order
not
in
(
"out_to_in"
,
"in_to_out"
):
if
order
not
in
(
"out_to_in"
,
"in_to_out"
):
raise
ValueError
(
"order must be 'out_to_in' or 'in_to_out'"
)
raise
ValueError
(
"order must be 'out_to_in' or 'in_to_out'"
)
self
.
order
=
order
self
.
order
=
order
super
()
.
__init__
(
local_opt
,
ignore_newtrees
,
failure_callback
)
super
()
.
__init__
(
node_rewriter
,
ignore_newtrees
,
failure_callback
)
def
apply
(
self
,
fgraph
,
start_from
=
None
):
def
apply
(
self
,
fgraph
,
start_from
=
None
):
if
start_from
is
None
:
if
start_from
is
None
:
...
@@ -2005,7 +2009,7 @@ class TopoOptimizer(NavigatorOptimizer):
...
@@ -2005,7 +2009,7 @@ class TopoOptimizer(NavigatorOptimizer):
io_t
,
io_t
,
loop_t
,
loop_t
,
callback_time
,
callback_time
,
self
.
local_opt
,
self
.
node_rewriter
,
)
)
@staticmethod
@staticmethod
...
@@ -2061,22 +2065,26 @@ class TopoOptimizer(NavigatorOptimizer):
...
@@ -2061,22 +2065,26 @@ class TopoOptimizer(NavigatorOptimizer):
def
topogroup_optimizer
(
def
topogroup_optimizer
(
order
,
*
local_opts
,
name
=
None
,
failure_callback
=
TopoOptimizer
.
warn_inplace
,
**
kwargs
order
,
*
node_rewriters
,
name
=
None
,
failure_callback
=
TopoOptimizer
.
warn_inplace
,
**
kwargs
,
):
):
"""Apply `
local_opt
s` from the input/output nodes to the output/input nodes of a graph.
"""Apply `
node_rewriter
s` from the input/output nodes to the output/input nodes of a graph.
This constructs `TopoOptimizer`s, and uses a `LocalOptGroup` when there's
This constructs `TopoOptimizer`s, and uses a `LocalOptGroup` when there's
more than one entry in `
local_opt
s`.
more than one entry in `
node_rewriter
s`.
"""
"""
if
len
(
local_opt
s
)
>
1
:
if
len
(
node_rewriter
s
)
>
1
:
# Don't wrap it uselessly if their is only 1 optimization.
# Don't wrap it uselessly if their is only 1 optimization.
local_opts
=
LocalOptGroup
(
*
local_opt
s
)
node_rewriters
=
LocalOptGroup
(
*
node_rewriter
s
)
else
:
else
:
(
local_opts
,)
=
local_opt
s
(
node_rewriters
,)
=
node_rewriter
s
if
not
name
:
if
not
name
:
name
=
local_opt
s
.
__name__
name
=
node_rewriter
s
.
__name__
ret
=
TopoOptimizer
(
ret
=
TopoOptimizer
(
local_opt
s
,
node_rewriter
s
,
order
=
order
,
order
=
order
,
failure_callback
=
failure_callback
,
failure_callback
=
failure_callback
,
**
kwargs
,
**
kwargs
,
...
@@ -2091,9 +2099,9 @@ out2in = partial(topogroup_optimizer, "out_to_in")
...
@@ -2091,9 +2099,9 @@ out2in = partial(topogroup_optimizer, "out_to_in")
class
OpKeyOptimizer
(
NavigatorOptimizer
):
class
OpKeyOptimizer
(
NavigatorOptimizer
):
r"""An optimizer that applies a `
LocalOptimiz
er` to specific `Op`\s.
r"""An optimizer that applies a `
NodeRewrit
er` to specific `Op`\s.
The `Op`\s are provided by a :meth:`
LocalOptimiz
er.op_key` method (either
The `Op`\s are provided by a :meth:`
NodeRewrit
er.op_key` method (either
as a list of `Op`\s or a single `Op`), and discovered within a
as a list of `Op`\s or a single `Op`), and discovered within a
`FunctionGraph` using the `NodeFinder` `Feature`.
`FunctionGraph` using the `NodeFinder` `Feature`.
...
@@ -2101,13 +2109,13 @@ class OpKeyOptimizer(NavigatorOptimizer):
...
@@ -2101,13 +2109,13 @@ class OpKeyOptimizer(NavigatorOptimizer):
"""
"""
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
def
__init__
(
self
,
node_rewriter
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
if
not
hasattr
(
local_opt
,
"op_key"
):
if
not
hasattr
(
node_rewriter
,
"op_key"
):
raise
TypeError
(
f
"{
local_opt
} must have an `op_key` method."
)
raise
TypeError
(
f
"{
node_rewriter
} must have an `op_key` method."
)
super
()
.
__init__
(
local_opt
,
ignore_newtrees
,
failure_callback
)
super
()
.
__init__
(
node_rewriter
,
ignore_newtrees
,
failure_callback
)
def
apply
(
self
,
fgraph
):
def
apply
(
self
,
fgraph
):
op
=
self
.
local_opt
.
op_key
()
op
=
self
.
node_rewriter
.
op_key
()
if
isinstance
(
op
,
(
list
,
tuple
)):
if
isinstance
(
op
,
(
list
,
tuple
)):
q
=
reduce
(
list
.
__iadd__
,
map
(
fgraph
.
get_nodes
,
op
))
q
=
reduce
(
list
.
__iadd__
,
map
(
fgraph
.
get_nodes
,
op
))
else
:
else
:
...
@@ -2175,68 +2183,86 @@ def merge_dict(d1, d2):
...
@@ -2175,68 +2183,86 @@ def merge_dict(d1, d2):
class
EquilibriumOptimizer
(
NavigatorOptimizer
):
class
EquilibriumOptimizer
(
NavigatorOptimizer
):
"""An optimizer that applies an optimization until a fixed-point/equilibrium is reached.
"""An `Rewriter` that applies an optimization until a fixed-point/equilibrium is reached."""
def
__init__
(
self
,
optimizers
:
Sequence
[
Rewriter
],
failure_callback
:
Optional
[
FailureCallbackType
]
=
None
,
ignore_newtrees
:
bool
=
True
,
tracks_on_change_inputs
:
bool
=
False
,
max_use_ratio
:
Optional
[
float
]
=
None
,
final_optimizers
:
Optional
[
Sequence
[
GraphRewriter
]]
=
None
,
cleanup_optimizers
:
Optional
[
Sequence
[
GraphRewriter
]]
=
None
,
):
"""
Parameters
Parameters
----------
----------
optimizers : list or set
optimizers
Local or global optimization
s to apply until equilibrium.
Node or graph rewriter
s to apply until equilibrium.
The global optimizer will be run at the start of each iteration before
The global optimizer will be run at the start of each iteration before
the local optimizer.
the node rewriter.
max_use_ratio : int or float
failure_callback
Each optimizer can be applied at most ``(size of graph * this number)``
See :attr:`NavigatorOptimizer.failure_callback`.
ignore_newtrees
See :attr:`NavigatorOptimizer.ignore_newtrees`.
tracks_on_change_inputs
See :attr:`NavigatorOptimizer.tracks_on_change_inputs`.
max_use_ratio
Each rewriter can be applied at most ``(size_of_graph * max_use_ratio)``
times.
times.
ignore_newtrees :
final_optimizers
See :attr:`EquilibriumDB.ignore_newtrees`.
Rewriters that will be run after each iteration.
final_optimizers :
cleanup_optimizers
Global optimizers that will be run after each iteration.
Rewriters applied after all graph rewriters, then when one
cleanup_optimizers :
`NodeRewriter` is applied, then after all final rewriters.
Global optimizers that apply a list of pre determined optimization.
They should not traverse the entire graph, since they are called
They must not traverse the graph as they are called very frequently.
very frequently. The `MergeOptimizer` is one example of a rewriter
The MergeOptimizer is one example of optimization that respect this.
that respect this.
They are applied after all global optimizers, then when one local
optimizer is applied, then after all final optimizers.
"""
"""
def
__init__
(
self
,
optimizers
,
failure_callback
=
None
,
ignore_newtrees
=
True
,
tracks_on_change_inputs
=
False
,
max_use_ratio
=
None
,
final_optimizers
=
None
,
cleanup_optimizers
=
None
,
):
super
()
.
__init__
(
super
()
.
__init__
(
None
,
ignore_newtrees
=
ignore_newtrees
,
failure_callback
=
failure_callback
None
,
ignore_newtrees
=
ignore_newtrees
,
failure_callback
=
failure_callback
)
)
self
.
global_optimizers
=
[]
self
.
global_optimizers
:
List
[
GraphRewriter
]
=
[]
self
.
final_optimizers
=
[]
self
.
cleanup_optimizers
=
[]
self
.
tracks_on_change_inputs
=
tracks_on_change_inputs
self
.
tracks_on_change_inputs
=
tracks_on_change_inputs
self
.
local_tracker
=
LocalOptTracker
()
self
.
local_tracker
=
LocalOptTracker
()
for
opt
in
optimizers
:
for
opt
in
optimizers
:
if
isinstance
(
opt
,
LocalOptimiz
er
):
if
isinstance
(
opt
,
NodeRewrit
er
):
self
.
local_tracker
.
add_tracker
(
opt
)
self
.
local_tracker
.
add_tracker
(
opt
)
else
:
else
:
assert
isinstance
(
opt
,
GraphRewriter
)
self
.
global_optimizers
.
append
(
opt
)
self
.
global_optimizers
.
append
(
opt
)
if
final_optimizers
:
if
final_optimizers
:
self
.
final_optimizers
=
final_optimizers
self
.
final_optimizers
=
list
(
final_optimizers
)
else
:
self
.
final_optimizers
=
[]
if
cleanup_optimizers
:
if
cleanup_optimizers
:
self
.
cleanup_optimizers
=
cleanup_optimizers
self
.
cleanup_optimizers
=
list
(
cleanup_optimizers
)
else
:
self
.
cleanup_optimizers
=
[]
self
.
max_use_ratio
=
max_use_ratio
self
.
max_use_ratio
=
max_use_ratio
def
get_
local_optimiz
ers
(
self
):
def
get_
node_rewrit
ers
(
self
):
yield
from
self
.
local_tracker
.
get_rewriters
()
yield
from
self
.
local_tracker
.
get_rewriters
()
def
get_local_optimizers
(
self
):
warnings
.
warn
(
"`get_local_optimizers` is deprecated; use `get_node_rewriters` instead."
,
DeprecationWarning
,
stacklevel
=
2
,
)
yield
from
self
.
get_node_rewriters
()
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
super
()
.
add_requirements
(
fgraph
)
super
()
.
add_requirements
(
fgraph
)
for
opt
in
self
.
get_
local_optimiz
ers
():
for
opt
in
self
.
get_
node_rewrit
ers
():
opt
.
add_requirements
(
fgraph
)
opt
.
add_requirements
(
fgraph
)
for
opt
in
self
.
global_optimizers
:
for
opt
in
self
.
global_optimizers
:
opt
.
add_requirements
(
fgraph
)
opt
.
add_requirements
(
fgraph
)
...
@@ -2274,7 +2300,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -2274,7 +2300,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
cleanup_sub_profs
=
[]
cleanup_sub_profs
=
[]
for
opt
in
(
for
opt
in
(
self
.
global_optimizers
self
.
global_optimizers
+
list
(
self
.
get_
local_optimiz
ers
())
+
list
(
self
.
get_
node_rewrit
ers
())
+
self
.
final_optimizers
+
self
.
final_optimizers
+
self
.
cleanup_optimizers
+
self
.
cleanup_optimizers
):
):
...
@@ -2468,7 +2494,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -2468,7 +2494,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
f
"{' ' * level}{self.__class__.__name__} {name} id={id(self)}"
,
file
=
stream
f
"{' ' * level}{self.__class__.__name__} {name} id={id(self)}"
,
file
=
stream
)
)
if
depth
!=
0
:
if
depth
!=
0
:
for
lopt
in
self
.
get_
local_optimiz
ers
():
for
lopt
in
self
.
get_
node_rewrit
ers
():
lopt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
depth
=
(
depth
-
1
))
lopt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
depth
=
(
depth
-
1
))
@staticmethod
@staticmethod
...
@@ -2502,7 +2528,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -2502,7 +2528,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
file
=
stream
,
file
=
stream
,
)
)
print
(
blanc
,
f
" time io_toposort {sum(io_toposort_timing):.3f}s"
,
file
=
stream
)
print
(
blanc
,
f
" time io_toposort {sum(io_toposort_timing):.3f}s"
,
file
=
stream
)
s
=
sum
(
time_opts
[
o
]
for
o
in
opt
.
get_
local_optimiz
ers
())
s
=
sum
(
time_opts
[
o
]
for
o
in
opt
.
get_
node_rewrit
ers
())
print
(
blanc
,
f
" time in local optimizers {s:.3f}s"
,
file
=
stream
)
print
(
blanc
,
f
" time in local optimizers {s:.3f}s"
,
file
=
stream
)
s
=
sum
(
time_opts
[
o
]
for
o
in
opt
.
global_optimizers
)
s
=
sum
(
time_opts
[
o
]
for
o
in
opt
.
global_optimizers
)
print
(
blanc
,
f
" time in global optimizers {s:.3f}s"
,
file
=
stream
)
print
(
blanc
,
f
" time in global optimizers {s:.3f}s"
,
file
=
stream
)
...
@@ -2534,7 +2560,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -2534,7 +2560,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
process_count
=
{}
process_count
=
{}
for
o
in
(
for
o
in
(
opt
.
global_optimizers
opt
.
global_optimizers
+
list
(
opt
.
get_
local_optimiz
ers
())
+
list
(
opt
.
get_
node_rewrit
ers
())
+
list
(
opt
.
final_optimizers
)
+
list
(
opt
.
final_optimizers
)
+
list
(
opt
.
cleanup_optimizers
)
+
list
(
opt
.
cleanup_optimizers
)
):
):
...
@@ -2605,8 +2631,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -2605,8 +2631,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def
merge_profile
(
prof1
,
prof2
):
def
merge_profile
(
prof1
,
prof2
):
# (opt, loop_timing, loop_process_count, max_nb_nodes,
# (opt, loop_timing, loop_process_count, max_nb_nodes,
# global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
# global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
local_optimizers
=
OrderedSet
(
prof1
[
0
]
.
get_local_optimiz
ers
())
.
union
(
node_rewriters
=
OrderedSet
(
prof1
[
0
]
.
get_node_rewrit
ers
())
.
union
(
prof2
[
0
]
.
get_
local_optimiz
ers
()
prof2
[
0
]
.
get_
node_rewrit
ers
()
)
)
global_optimizers
=
OrderedSet
(
prof1
[
0
]
.
global_optimizers
)
.
union
(
global_optimizers
=
OrderedSet
(
prof1
[
0
]
.
global_optimizers
)
.
union
(
prof2
[
0
]
.
global_optimizers
prof2
[
0
]
.
global_optimizers
...
@@ -2618,7 +2644,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -2618,7 +2644,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
OrderedSet
(
prof1
[
0
]
.
cleanup_optimizers
)
.
union
(
prof2
[
0
]
.
cleanup_optimizers
)
OrderedSet
(
prof1
[
0
]
.
cleanup_optimizers
)
.
union
(
prof2
[
0
]
.
cleanup_optimizers
)
)
)
new_opt
=
EquilibriumOptimizer
(
new_opt
=
EquilibriumOptimizer
(
local_optimiz
ers
.
union
(
global_optimizers
),
node_rewrit
ers
.
union
(
global_optimizers
),
max_use_ratio
=
1
,
max_use_ratio
=
1
,
final_optimizers
=
final_optimizers
,
final_optimizers
=
final_optimizers
,
cleanup_optimizers
=
cleanup_optimizers
,
cleanup_optimizers
=
cleanup_optimizers
,
...
@@ -2758,7 +2784,7 @@ def check_chain(r, *chain):
...
@@ -2758,7 +2784,7 @@ def check_chain(r, *chain):
return
_check_chain
(
r
,
reduce
(
list
.
__iadd__
,
([
x
,
0
]
for
x
in
chain
)))
return
_check_chain
(
r
,
reduce
(
list
.
__iadd__
,
([
x
,
0
]
for
x
in
chain
)))
def
pre_greedy_
local_optimiz
er
(
fgraph
,
optimizations
,
out
):
def
pre_greedy_
node_rewrit
er
(
fgraph
,
optimizations
,
out
):
"""Apply local optimizations to a graph.
"""Apply local optimizations to a graph.
This function traverses the computation graph in the graph before the
This function traverses the computation graph in the graph before the
...
@@ -2786,7 +2812,7 @@ def pre_greedy_local_optimizer(fgraph, optimizations, out):
...
@@ -2786,7 +2812,7 @@ def pre_greedy_local_optimizer(fgraph, optimizations, out):
----------
----------
fgraph : FunctionGraph
fgraph : FunctionGraph
The graph used to avoid/filter nodes.
The graph used to avoid/filter nodes.
optimizations : list of
LocalOptimiz
er
optimizations : list of
NodeRewrit
er
The list of local optimizations to apply
The list of local optimizations to apply
out : Variable
out : Variable
A `Variable` specifying the graph to optimize.
A `Variable` specifying the graph to optimize.
...
@@ -3065,6 +3091,21 @@ DEPRECATED_NAMES = [
...
@@ -3065,6 +3091,21 @@ DEPRECATED_NAMES = [
"`GlobalOptimizer` is deprecated: use `GraphRewriter` instead."
,
"`GlobalOptimizer` is deprecated: use `GraphRewriter` instead."
,
GraphRewriter
,
GraphRewriter
,
),
),
(
"LocalOptimizer"
,
"`LocalOptimizer` is deprecated: use `NodeRewriter` instead."
,
NodeRewriter
,
),
(
"local_optimizer"
,
"`local_optimizer` is deprecated: use `node_rewriter` instead."
,
node_rewriter
,
),
(
"pre_greedy_local_optimizer"
,
"`pre_greedy_local_optimizer` is deprecated: use `pre_greedy_node_rewriter` instead."
,
pre_greedy_node_rewriter
,
),
]
]
...
...
aesara/graph/optdb.py
浏览文件 @
550a6e98
...
@@ -11,14 +11,14 @@ from aesara.misc.ordered_set import OrderedSet
...
@@ -11,14 +11,14 @@ from aesara.misc.ordered_set import OrderedSet
from
aesara.utils
import
DefaultOrderedDict
from
aesara.utils
import
DefaultOrderedDict
OptimizersType
=
Union
[
aesara_opt
.
GraphRewriter
,
aesara_opt
.
LocalOptimiz
er
]
OptimizersType
=
Union
[
aesara_opt
.
GraphRewriter
,
aesara_opt
.
NodeRewrit
er
]
class
OptimizationDatabase
:
class
OptimizationDatabase
:
r"""A class that represents a collection/database of optimizations.
r"""A class that represents a collection/database of optimizations.
These databases are used to logically organize collections of optimizers
These databases are used to logically organize collections of optimizers
(i.e. `GraphRewriter`\s and `
LocalOptimiz
er`).
(i.e. `GraphRewriter`\s and `
NodeRewrit
er`).
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -62,7 +62,7 @@ class OptimizationDatabase:
...
@@ -62,7 +62,7 @@ class OptimizationDatabase:
(
(
OptimizationDatabase
,
OptimizationDatabase
,
aesara_opt
.
GraphRewriter
,
aesara_opt
.
GraphRewriter
,
aesara_opt
.
LocalOptimiz
er
,
aesara_opt
.
NodeRewrit
er
,
),
),
):
):
raise
TypeError
(
f
"{optimizer} is not a valid optimizer type."
)
raise
TypeError
(
f
"{optimizer} is not a valid optimizer type."
)
...
@@ -311,7 +311,7 @@ class EquilibriumDB(OptimizationDatabase):
...
@@ -311,7 +311,7 @@ class EquilibriumDB(OptimizationDatabase):
Notes
Notes
-----
-----
We can use `
LocalOptimiz
er` and `GraphRewriter` since `EquilibriumOptimizer`
We can use `
NodeRewrit
er` and `GraphRewriter` since `EquilibriumOptimizer`
supports both.
supports both.
It is probably not a good idea to have ignore_newtrees=False and
It is probably not a good idea to have ignore_newtrees=False and
...
@@ -474,24 +474,18 @@ class SequenceDB(OptimizationDatabase):
...
@@ -474,24 +474,18 @@ class SequenceDB(OptimizationDatabase):
class
LocalGroupDB
(
SequenceDB
):
class
LocalGroupDB
(
SequenceDB
):
"""
r"""A database that generates `NodeRewriter`\s of type `LocalOptGroup`."""
Generate a local optimizer of type LocalOptGroup instead
of a global optimizer.
It supports the tracks, to only get applied to some Op.
"""
def
__init__
(
def
__init__
(
self
,
self
,
apply_all_opts
:
bool
=
False
,
apply_all_opts
:
bool
=
False
,
profile
:
bool
=
False
,
profile
:
bool
=
False
,
local_opt
=
aesara_opt
.
LocalOptGroup
,
node_rewriter
=
aesara_opt
.
LocalOptGroup
,
):
):
super
()
.
__init__
(
failure_callback
=
None
)
super
()
.
__init__
(
failure_callback
=
None
)
self
.
apply_all_opts
=
apply_all_opts
self
.
apply_all_opts
=
apply_all_opts
self
.
profile
=
profile
self
.
profile
=
profile
self
.
local_opt
=
local_opt
self
.
node_rewriter
=
node_rewriter
self
.
__name__
:
str
=
""
self
.
__name__
:
str
=
""
def
register
(
self
,
name
,
obj
,
*
tags
,
position
=
"last"
,
**
kwargs
):
def
register
(
self
,
name
,
obj
,
*
tags
,
position
=
"last"
,
**
kwargs
):
...
@@ -499,7 +493,7 @@ class LocalGroupDB(SequenceDB):
...
@@ -499,7 +493,7 @@ class LocalGroupDB(SequenceDB):
def
query
(
self
,
*
tags
,
**
kwtags
):
def
query
(
self
,
*
tags
,
**
kwtags
):
opts
=
list
(
super
()
.
query
(
*
tags
,
**
kwtags
))
opts
=
list
(
super
()
.
query
(
*
tags
,
**
kwtags
))
ret
=
self
.
local_opt
(
ret
=
self
.
node_rewriter
(
*
opts
,
apply_all_opts
=
self
.
apply_all_opts
,
profile
=
self
.
profile
*
opts
,
apply_all_opts
=
self
.
apply_all_opts
,
profile
=
self
.
profile
)
)
return
ret
return
ret
...
...
aesara/ifelse.py
浏览文件 @
550a6e98
...
@@ -22,7 +22,7 @@ from aesara.compile import optdb
...
@@ -22,7 +22,7 @@ from aesara.compile import optdb
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Apply
,
Variable
,
clone_replace
,
is_in_ancestors
from
aesara.graph.basic
import
Apply
,
Variable
,
clone_replace
,
is_in_ancestors
from
aesara.graph.op
import
_NoPythonOp
from
aesara.graph.op
import
_NoPythonOp
from
aesara.graph.opt
import
GraphRewriter
,
in2out
,
local_optimiz
er
from
aesara.graph.opt
import
GraphRewriter
,
in2out
,
node_rewrit
er
from
aesara.graph.type
import
HasDataType
,
HasShape
from
aesara.graph.type
import
HasDataType
,
HasShape
from
aesara.tensor.shape
import
Reshape
,
Shape
,
SpecifyShape
,
Unbroadcast
from
aesara.tensor.shape
import
Reshape
,
Shape
,
SpecifyShape
,
Unbroadcast
...
@@ -404,7 +404,7 @@ def ifelse(
...
@@ -404,7 +404,7 @@ def ifelse(
return
tuple
(
rval
)
return
tuple
(
rval
)
@
local_optimiz
er
([
IfElse
])
@
node_rewrit
er
([
IfElse
])
def
cond_make_inplace
(
fgraph
,
node
):
def
cond_make_inplace
(
fgraph
,
node
):
op
=
node
.
op
op
=
node
.
op
if
(
if
(
...
@@ -482,7 +482,7 @@ acceptable_ops = (
...
@@ -482,7 +482,7 @@ acceptable_ops = (
)
)
@
local_optimiz
er
(
acceptable_ops
)
@
node_rewrit
er
(
acceptable_ops
)
def
ifelse_lift_single_if_through_acceptable_ops
(
fgraph
,
main_node
):
def
ifelse_lift_single_if_through_acceptable_ops
(
fgraph
,
main_node
):
"""This optimization lifts up certain ifelse instances.
"""This optimization lifts up certain ifelse instances.
...
@@ -529,7 +529,7 @@ def ifelse_lift_single_if_through_acceptable_ops(fgraph, main_node):
...
@@ -529,7 +529,7 @@ def ifelse_lift_single_if_through_acceptable_ops(fgraph, main_node):
return
nw_outs
return
nw_outs
@
local_optimiz
er
([
IfElse
])
@
node_rewrit
er
([
IfElse
])
def
cond_merge_ifs_true
(
fgraph
,
node
):
def
cond_merge_ifs_true
(
fgraph
,
node
):
op
=
node
.
op
op
=
node
.
op
if
not
isinstance
(
op
,
IfElse
):
if
not
isinstance
(
op
,
IfElse
):
...
@@ -556,7 +556,7 @@ def cond_merge_ifs_true(fgraph, node):
...
@@ -556,7 +556,7 @@ def cond_merge_ifs_true(fgraph, node):
return
op
(
*
old_ins
,
return_list
=
True
)
return
op
(
*
old_ins
,
return_list
=
True
)
@
local_optimiz
er
([
IfElse
])
@
node_rewrit
er
([
IfElse
])
def
cond_merge_ifs_false
(
fgraph
,
node
):
def
cond_merge_ifs_false
(
fgraph
,
node
):
op
=
node
.
op
op
=
node
.
op
if
not
isinstance
(
op
,
IfElse
):
if
not
isinstance
(
op
,
IfElse
):
...
@@ -635,7 +635,7 @@ class CondMerge(GraphRewriter):
...
@@ -635,7 +635,7 @@ class CondMerge(GraphRewriter):
fgraph
.
replace_all_validate
(
pairs
,
reason
=
"cond_merge"
)
fgraph
.
replace_all_validate
(
pairs
,
reason
=
"cond_merge"
)
@
local_optimiz
er
([
IfElse
])
@
node_rewrit
er
([
IfElse
])
def
cond_remove_identical
(
fgraph
,
node
):
def
cond_remove_identical
(
fgraph
,
node
):
op
=
node
.
op
op
=
node
.
op
...
@@ -681,7 +681,7 @@ def cond_remove_identical(fgraph, node):
...
@@ -681,7 +681,7 @@ def cond_remove_identical(fgraph, node):
return
rval
return
rval
@
local_optimiz
er
([
IfElse
])
@
node_rewrit
er
([
IfElse
])
def
cond_merge_random_op
(
fgraph
,
main_node
):
def
cond_merge_random_op
(
fgraph
,
main_node
):
if
isinstance
(
main_node
.
op
,
IfElse
):
if
isinstance
(
main_node
.
op
,
IfElse
):
return
False
return
False
...
...
aesara/sandbox/linalg/ops.py
浏览文件 @
550a6e98
import
logging
import
logging
from
aesara.graph.opt
import
local_optimiz
er
from
aesara.graph.opt
import
node_rewrit
er
from
aesara.tensor
import
basic
as
at
from
aesara.tensor
import
basic
as
at
from
aesara.tensor.basic_opt
import
(
from
aesara.tensor.basic_opt
import
(
register_canonicalize
,
register_canonicalize
,
...
@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
...
@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
DimShuffle
])
@
node_rewrit
er
([
DimShuffle
])
def
transinv_to_invtrans
(
fgraph
,
node
):
def
transinv_to_invtrans
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
DimShuffle
):
if
isinstance
(
node
.
op
,
DimShuffle
):
if
node
.
op
.
new_order
==
(
1
,
0
):
if
node
.
op
.
new_order
==
(
1
,
0
):
...
@@ -32,7 +32,7 @@ def transinv_to_invtrans(fgraph, node):
...
@@ -32,7 +32,7 @@ def transinv_to_invtrans(fgraph, node):
@register_stabilize
@register_stabilize
@
local_optimiz
er
([
Dot
,
Dot22
])
@
node_rewrit
er
([
Dot
,
Dot22
])
def
inv_as_solve
(
fgraph
,
node
):
def
inv_as_solve
(
fgraph
,
node
):
"""
"""
This utilizes a boolean `symmetric` tag on the matrices.
This utilizes a boolean `symmetric` tag on the matrices.
...
@@ -51,7 +51,7 @@ def inv_as_solve(fgraph, node):
...
@@ -51,7 +51,7 @@ def inv_as_solve(fgraph, node):
@register_stabilize
@register_stabilize
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Solve
])
@
node_rewrit
er
([
Solve
])
def
tag_solve_triangular
(
fgraph
,
node
):
def
tag_solve_triangular
(
fgraph
,
node
):
"""
"""
If a general solve() is applied to the output of a cholesky op, then
If a general solve() is applied to the output of a cholesky op, then
...
@@ -82,7 +82,7 @@ def tag_solve_triangular(fgraph, node):
...
@@ -82,7 +82,7 @@ def tag_solve_triangular(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_stabilize
@register_stabilize
@register_specialize
@register_specialize
@
local_optimiz
er
([
DimShuffle
])
@
node_rewrit
er
([
DimShuffle
])
def
no_transpose_symmetric
(
fgraph
,
node
):
def
no_transpose_symmetric
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
DimShuffle
):
if
isinstance
(
node
.
op
,
DimShuffle
):
x
=
node
.
inputs
[
0
]
x
=
node
.
inputs
[
0
]
...
@@ -92,7 +92,7 @@ def no_transpose_symmetric(fgraph, node):
...
@@ -92,7 +92,7 @@ def no_transpose_symmetric(fgraph, node):
@register_stabilize
@register_stabilize
@
local_optimiz
er
([
Solve
])
@
node_rewrit
er
([
Solve
])
def
psd_solve_with_chol
(
fgraph
,
node
):
def
psd_solve_with_chol
(
fgraph
,
node
):
"""
"""
This utilizes a boolean `psd` tag on matrices.
This utilizes a boolean `psd` tag on matrices.
...
@@ -111,7 +111,7 @@ def psd_solve_with_chol(fgraph, node):
...
@@ -111,7 +111,7 @@ def psd_solve_with_chol(fgraph, node):
@register_stabilize
@register_stabilize
@register_specialize
@register_specialize
@
local_optimiz
er
([
Det
])
@
node_rewrit
er
([
Det
])
def
local_det_chol
(
fgraph
,
node
):
def
local_det_chol
(
fgraph
,
node
):
"""
"""
If we have det(X) and there is already an L=cholesky(X)
If we have det(X) and there is already an L=cholesky(X)
...
@@ -129,7 +129,7 @@ def local_det_chol(fgraph, node):
...
@@ -129,7 +129,7 @@ def local_det_chol(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_stabilize
@register_stabilize
@register_specialize
@register_specialize
@
local_optimiz
er
([
log
])
@
node_rewrit
er
([
log
])
def
local_log_prod_sqr
(
fgraph
,
node
):
def
local_log_prod_sqr
(
fgraph
,
node
):
"""
"""
This utilizes a boolean `positive` tag on matrices.
This utilizes a boolean `positive` tag on matrices.
...
...
aesara/sandbox/rng_mrg.py
浏览文件 @
550a6e98
...
@@ -25,7 +25,7 @@ from aesara.compile import optdb
...
@@ -25,7 +25,7 @@ from aesara.compile import optdb
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.gradient
import
undefined_grad
from
aesara.gradient
import
undefined_grad
from
aesara.graph.basic
import
Apply
,
Constant
,
Variable
from
aesara.graph.basic
import
Apply
,
Constant
,
Variable
from
aesara.graph.opt
import
in2out
,
local_optimiz
er
from
aesara.graph.opt
import
in2out
,
node_rewrit
er
from
aesara.link.c.op
import
COp
,
Op
from
aesara.link.c.op
import
COp
,
Op
from
aesara.link.c.params_type
import
ParamsType
from
aesara.link.c.params_type
import
ParamsType
from
aesara.sandbox
import
multinomial
from
aesara.sandbox
import
multinomial
...
@@ -1343,7 +1343,7 @@ def _check_size(size):
...
@@ -1343,7 +1343,7 @@ def _check_size(size):
return
at
.
as_tensor_variable
(
size
,
ndim
=
1
)
return
at
.
as_tensor_variable
(
size
,
ndim
=
1
)
@
local_optimiz
er
((
mrg_uniform_base
,))
@
node_rewrit
er
((
mrg_uniform_base
,))
def
mrg_random_make_inplace
(
fgraph
,
node
):
def
mrg_random_make_inplace
(
fgraph
,
node
):
op
=
node
.
op
op
=
node
.
op
...
...
aesara/scan/opt.py
浏览文件 @
550a6e98
...
@@ -28,7 +28,7 @@ from aesara.graph.destroyhandler import DestroyHandler
...
@@ -28,7 +28,7 @@ from aesara.graph.destroyhandler import DestroyHandler
from
aesara.graph.features
import
ReplaceValidate
from
aesara.graph.features
import
ReplaceValidate
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
compute_test_value
from
aesara.graph.op
import
compute_test_value
from
aesara.graph.opt
import
GraphRewriter
,
in2out
,
local_optimiz
er
from
aesara.graph.opt
import
GraphRewriter
,
in2out
,
node_rewrit
er
from
aesara.graph.optdb
import
EquilibriumDB
,
SequenceDB
from
aesara.graph.optdb
import
EquilibriumDB
,
SequenceDB
from
aesara.graph.type
import
HasShape
from
aesara.graph.type
import
HasShape
from
aesara.graph.utils
import
InconsistencyError
from
aesara.graph.utils
import
InconsistencyError
...
@@ -67,7 +67,7 @@ list_opt_slice = [
...
@@ -67,7 +67,7 @@ list_opt_slice = [
]
]
@
local_optimiz
er
([
Scan
])
@
node_rewrit
er
([
Scan
])
def
remove_constants_and_unused_inputs_scan
(
fgraph
,
node
):
def
remove_constants_and_unused_inputs_scan
(
fgraph
,
node
):
"""Move constants into the inner graph, and remove unused inputs.
"""Move constants into the inner graph, and remove unused inputs.
...
@@ -192,7 +192,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
...
@@ -192,7 +192,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
return
False
return
False
@
local_optimiz
er
([
Scan
])
@
node_rewrit
er
([
Scan
])
def
push_out_non_seq_scan
(
fgraph
,
node
):
def
push_out_non_seq_scan
(
fgraph
,
node
):
r"""Push out the variables inside the `Scan` that depend only on non-sequences.
r"""Push out the variables inside the `Scan` that depend only on non-sequences.
...
@@ -400,7 +400,7 @@ def push_out_non_seq_scan(fgraph, node):
...
@@ -400,7 +400,7 @@ def push_out_non_seq_scan(fgraph, node):
return
False
return
False
@
local_optimiz
er
([
Scan
])
@
node_rewrit
er
([
Scan
])
def
push_out_seq_scan
(
fgraph
,
node
):
def
push_out_seq_scan
(
fgraph
,
node
):
r"""Push out the variables inside the `Scan` that depend only on constants and sequences.
r"""Push out the variables inside the `Scan` that depend only on constants and sequences.
...
@@ -812,7 +812,7 @@ def add_nitsot_outputs(
...
@@ -812,7 +812,7 @@ def add_nitsot_outputs(
return
new_scan_node
,
{}
return
new_scan_node
,
{}
@
local_optimiz
er
([
Scan
])
@
node_rewrit
er
([
Scan
])
def
push_out_add_scan
(
fgraph
,
node
):
def
push_out_add_scan
(
fgraph
,
node
):
r"""Push `Add` operations performed at the end of the inner graph to the outside.
r"""Push `Add` operations performed at the end of the inner graph to the outside.
...
@@ -1113,7 +1113,7 @@ def sanitize(x):
...
@@ -1113,7 +1113,7 @@ def sanitize(x):
return
at
.
as_tensor_variable
(
x
)
return
at
.
as_tensor_variable
(
x
)
@
local_optimiz
er
([
Scan
])
@
node_rewrit
er
([
Scan
])
def
save_mem_new_scan
(
fgraph
,
node
):
def
save_mem_new_scan
(
fgraph
,
node
):
r"""Graph optimizer that reduces scan memory consumption.
r"""Graph optimizer that reduces scan memory consumption.
...
@@ -1950,7 +1950,7 @@ def make_equiv(lo, li):
...
@@ -1950,7 +1950,7 @@ def make_equiv(lo, li):
return
left
,
right
return
left
,
right
@
local_optimiz
er
([
Scan
])
@
node_rewrit
er
([
Scan
])
def
scan_merge_inouts
(
fgraph
,
node
):
def
scan_merge_inouts
(
fgraph
,
node
):
"""
"""
This optimization attempts to merge a `Scan` `Op`'s identical outer inputs as well
This optimization attempts to merge a `Scan` `Op`'s identical outer inputs as well
...
@@ -2154,7 +2154,7 @@ def scan_merge_inouts(fgraph, node):
...
@@ -2154,7 +2154,7 @@ def scan_merge_inouts(fgraph, node):
return
na
.
outer_outputs
return
na
.
outer_outputs
@
local_optimiz
er
([
Scan
])
@
node_rewrit
er
([
Scan
])
def
push_out_dot1_scan
(
fgraph
,
node
):
def
push_out_dot1_scan
(
fgraph
,
node
):
r"""
r"""
This is another optimization that attempts to detect certain patterns of
This is another optimization that attempts to detect certain patterns of
...
...
aesara/sparse/opt.py
浏览文件 @
550a6e98
...
@@ -4,7 +4,7 @@ import aesara
...
@@ -4,7 +4,7 @@ import aesara
import
aesara.scalar
as
aes
import
aesara.scalar
as
aes
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Apply
from
aesara.graph.basic
import
Apply
from
aesara.graph.opt
import
PatternSub
,
TopoOptimizer
,
local_optimiz
er
from
aesara.graph.opt
import
PatternSub
,
TopoOptimizer
,
node_rewrit
er
from
aesara.link.c.op
import
COp
,
_NoPythonCOp
from
aesara.link.c.op
import
COp
,
_NoPythonCOp
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.sparse
import
basic
as
sparse
from
aesara.sparse
import
basic
as
sparse
...
@@ -32,7 +32,7 @@ _is_dense = sparse._is_dense
...
@@ -32,7 +32,7 @@ _is_dense = sparse._is_dense
# This is tested in tests/test_opt.py:test_local_csm_properties_csm
# This is tested in tests/test_opt.py:test_local_csm_properties_csm
@
local_optimiz
er
([
csm_properties
])
@
node_rewrit
er
([
csm_properties
])
def
local_csm_properties_csm
(
fgraph
,
node
):
def
local_csm_properties_csm
(
fgraph
,
node
):
"""
"""
If we find csm_properties(CSM(*args)), then we can replace that with the
If we find csm_properties(CSM(*args)), then we can replace that with the
...
@@ -51,7 +51,7 @@ register_specialize(local_csm_properties_csm)
...
@@ -51,7 +51,7 @@ register_specialize(local_csm_properties_csm)
# This is tested in tests/test_basic.py:test_remove0
# This is tested in tests/test_basic.py:test_remove0
@
local_optimiz
er
([
sparse
.
Remove0
])
@
node_rewrit
er
([
sparse
.
Remove0
])
def
local_inplace_remove0
(
fgraph
,
node
):
def
local_inplace_remove0
(
fgraph
,
node
):
"""
"""
Optimization to insert inplace versions of Remove0.
Optimization to insert inplace versions of Remove0.
...
@@ -188,7 +188,7 @@ class AddSD_ccode(_NoPythonCOp):
...
@@ -188,7 +188,7 @@ class AddSD_ccode(_NoPythonCOp):
return
(
2
,)
return
(
2
,)
@
local_optimiz
er
([
sparse
.
AddSD
])
@
node_rewrit
er
([
sparse
.
AddSD
])
def
local_inplace_addsd_ccode
(
fgraph
,
node
):
def
local_inplace_addsd_ccode
(
fgraph
,
node
):
"""
"""
Optimization to insert inplace versions of AddSD.
Optimization to insert inplace versions of AddSD.
...
@@ -218,7 +218,7 @@ aesara.compile.optdb.register(
...
@@ -218,7 +218,7 @@ aesara.compile.optdb.register(
@register_canonicalize
(
"fast_compile"
)
@register_canonicalize
(
"fast_compile"
)
@register_specialize
@register_specialize
@
local_optimiz
er
([
sparse
.
DenseFromSparse
])
@
node_rewrit
er
([
sparse
.
DenseFromSparse
])
def
local_dense_from_sparse_sparse_from_dense
(
fgraph
,
node
):
def
local_dense_from_sparse_sparse_from_dense
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
sparse
.
DenseFromSparse
):
if
isinstance
(
node
.
op
,
sparse
.
DenseFromSparse
):
inp
=
node
.
inputs
[
0
]
inp
=
node
.
inputs
[
0
]
...
@@ -226,7 +226,7 @@ def local_dense_from_sparse_sparse_from_dense(fgraph, node):
...
@@ -226,7 +226,7 @@ def local_dense_from_sparse_sparse_from_dense(fgraph, node):
return
inp
.
owner
.
inputs
return
inp
.
owner
.
inputs
@
local_optimiz
er
([
sparse
.
AddSD
])
@
node_rewrit
er
([
sparse
.
AddSD
])
def
local_addsd_ccode
(
fgraph
,
node
):
def
local_addsd_ccode
(
fgraph
,
node
):
"""
"""
Convert AddSD to faster AddSD_ccode.
Convert AddSD to faster AddSD_ccode.
...
@@ -638,7 +638,7 @@ sd_csr = StructuredDotCSR()
...
@@ -638,7 +638,7 @@ sd_csr = StructuredDotCSR()
# register a specialization to replace StructuredDot -> StructuredDotCSx
# register a specialization to replace StructuredDot -> StructuredDotCSx
# This is tested in tests/test_basic.py:792
# This is tested in tests/test_basic.py:792
@
local_optimiz
er
([
sparse
.
_structured_dot
])
@
node_rewrit
er
([
sparse
.
_structured_dot
])
def
local_structured_dot
(
fgraph
,
node
):
def
local_structured_dot
(
fgraph
,
node
):
if
node
.
op
==
sparse
.
_structured_dot
:
if
node
.
op
==
sparse
.
_structured_dot
:
a
,
b
=
node
.
inputs
a
,
b
=
node
.
inputs
...
@@ -950,7 +950,7 @@ register_specialize(local_usmm, name="local_usmm")
...
@@ -950,7 +950,7 @@ register_specialize(local_usmm, name="local_usmm")
# register a specialization to replace usmm_csc_dense -> usmm_csc_dense_inplace
# register a specialization to replace usmm_csc_dense -> usmm_csc_dense_inplace
# This is tested in tests/test_basic.py:UsmmTests
# This is tested in tests/test_basic.py:UsmmTests
@
local_optimiz
er
([
usmm_csc_dense
])
@
node_rewrit
er
([
usmm_csc_dense
])
def
local_usmm_csc_dense_inplace
(
fgraph
,
node
):
def
local_usmm_csc_dense_inplace
(
fgraph
,
node
):
if
node
.
op
==
usmm_csc_dense
:
if
node
.
op
==
usmm_csc_dense
:
return
[
usmm_csc_dense_inplace
(
*
node
.
inputs
)]
return
[
usmm_csc_dense_inplace
(
*
node
.
inputs
)]
...
@@ -960,7 +960,7 @@ register_specialize(local_usmm_csc_dense_inplace, "cxx_only", "inplace")
...
@@ -960,7 +960,7 @@ register_specialize(local_usmm_csc_dense_inplace, "cxx_only", "inplace")
# This is tested in tests/test_basic.py:UsmmTests
# This is tested in tests/test_basic.py:UsmmTests
@
local_optimiz
er
([
usmm
])
@
node_rewrit
er
([
usmm
])
def
local_usmm_csx
(
fgraph
,
node
):
def
local_usmm_csx
(
fgraph
,
node
):
"""
"""
usmm -> usmm_csc_dense
usmm -> usmm_csc_dense
...
@@ -1120,7 +1120,7 @@ csm_grad_c = CSMGradC()
...
@@ -1120,7 +1120,7 @@ csm_grad_c = CSMGradC()
# register a specialization to replace csm_grad -> csm_grad_c
# register a specialization to replace csm_grad -> csm_grad_c
# This is tested in tests/test_opt.py:test_local_csm_grad_c
# This is tested in tests/test_opt.py:test_local_csm_grad_c
@
local_optimiz
er
([
csm_grad
(
None
)])
@
node_rewrit
er
([
csm_grad
(
None
)])
def
local_csm_grad_c
(
fgraph
,
node
):
def
local_csm_grad_c
(
fgraph
,
node
):
"""
"""
csm_grad(None) -> csm_grad_c
csm_grad(None) -> csm_grad_c
...
@@ -1404,7 +1404,7 @@ mul_s_d_csr = MulSDCSR()
...
@@ -1404,7 +1404,7 @@ mul_s_d_csr = MulSDCSR()
# register a specialization to replace MulSD -> MulSDCSX
# register a specialization to replace MulSD -> MulSDCSX
@
local_optimiz
er
([
sparse
.
mul_s_d
])
@
node_rewrit
er
([
sparse
.
mul_s_d
])
def
local_mul_s_d
(
fgraph
,
node
):
def
local_mul_s_d
(
fgraph
,
node
):
if
node
.
op
==
sparse
.
mul_s_d
:
if
node
.
op
==
sparse
.
mul_s_d
:
x
,
y
=
node
.
inputs
x
,
y
=
node
.
inputs
...
@@ -1584,7 +1584,7 @@ mul_s_v_csr = MulSVCSR()
...
@@ -1584,7 +1584,7 @@ mul_s_v_csr = MulSVCSR()
# register a specialization to replace MulSV -> MulSVCSR
# register a specialization to replace MulSV -> MulSVCSR
@
local_optimiz
er
([
sparse
.
mul_s_v
])
@
node_rewrit
er
([
sparse
.
mul_s_v
])
def
local_mul_s_v
(
fgraph
,
node
):
def
local_mul_s_v
(
fgraph
,
node
):
if
node
.
op
==
sparse
.
mul_s_v
:
if
node
.
op
==
sparse
.
mul_s_v
:
x
,
y
=
node
.
inputs
x
,
y
=
node
.
inputs
...
@@ -1762,7 +1762,7 @@ structured_add_s_v_csr = StructuredAddSVCSR()
...
@@ -1762,7 +1762,7 @@ structured_add_s_v_csr = StructuredAddSVCSR()
# register a specialization to replace
# register a specialization to replace
# structured_add_s_v -> structured_add_s_v_csr
# structured_add_s_v -> structured_add_s_v_csr
@
local_optimiz
er
([
sparse
.
structured_add_s_v
])
@
node_rewrit
er
([
sparse
.
structured_add_s_v
])
def
local_structured_add_s_v
(
fgraph
,
node
):
def
local_structured_add_s_v
(
fgraph
,
node
):
if
node
.
op
==
sparse
.
structured_add_s_v
:
if
node
.
op
==
sparse
.
structured_add_s_v
:
x
,
y
=
node
.
inputs
x
,
y
=
node
.
inputs
...
@@ -2051,7 +2051,7 @@ sampling_dot_csr = SamplingDotCSR()
...
@@ -2051,7 +2051,7 @@ sampling_dot_csr = SamplingDotCSR()
# register a specialization to replace SamplingDot -> SamplingDotCsr
# register a specialization to replace SamplingDot -> SamplingDotCsr
@
local_optimiz
er
([
sparse
.
sampling_dot
])
@
node_rewrit
er
([
sparse
.
sampling_dot
])
def
local_sampling_dot_csr
(
fgraph
,
node
):
def
local_sampling_dot_csr
(
fgraph
,
node
):
if
not
config
.
blas__ldflags
:
if
not
config
.
blas__ldflags
:
# The C implementation of SamplingDotCsr relies on BLAS routines
# The C implementation of SamplingDotCsr relies on BLAS routines
...
...
aesara/tensor/basic_opt.py
浏览文件 @
550a6e98
...
@@ -32,7 +32,7 @@ from aesara.graph.opt import (
...
@@ -32,7 +32,7 @@ from aesara.graph.opt import (
check_chain
,
check_chain
,
copy_stack_trace
,
copy_stack_trace
,
in2out
,
in2out
,
local_optimiz
er
,
node_rewrit
er
,
)
)
from
aesara.graph.optdb
import
SequenceDB
from
aesara.graph.optdb
import
SequenceDB
from
aesara.graph.utils
import
(
from
aesara.graph.utils
import
(
...
@@ -605,7 +605,7 @@ def is_dimshuffle_useless(new_order, input):
...
@@ -605,7 +605,7 @@ def is_dimshuffle_useless(new_order, input):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
DimShuffle
])
@
node_rewrit
er
([
DimShuffle
])
def
local_dimshuffle_lift
(
fgraph
,
node
):
def
local_dimshuffle_lift
(
fgraph
,
node
):
"""
"""
"Lifts" DimShuffle through Elemwise operations and merges
"Lifts" DimShuffle through Elemwise operations and merges
...
@@ -651,7 +651,7 @@ def local_dimshuffle_lift(fgraph, node):
...
@@ -651,7 +651,7 @@ def local_dimshuffle_lift(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
DimShuffle
])
@
node_rewrit
er
([
DimShuffle
])
def
local_useless_dimshuffle_makevector
(
fgraph
,
node
):
def
local_useless_dimshuffle_makevector
(
fgraph
,
node
):
r"""Remove `DimShuffle`\s that drop one dimensional broadcastable `MakeVector`s.
r"""Remove `DimShuffle`\s that drop one dimensional broadcastable `MakeVector`s.
...
@@ -680,7 +680,7 @@ def local_useless_dimshuffle_makevector(fgraph, node):
...
@@ -680,7 +680,7 @@ def local_useless_dimshuffle_makevector(fgraph, node):
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Reshape
])
@
node_rewrit
er
([
Reshape
])
def
local_useless_dimshuffle_in_reshape
(
fgraph
,
node
):
def
local_useless_dimshuffle_in_reshape
(
fgraph
,
node
):
"""
"""
Removes useless DimShuffle operation inside Reshape:
Removes useless DimShuffle operation inside Reshape:
...
@@ -720,7 +720,7 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
...
@@ -720,7 +720,7 @@ def local_useless_dimshuffle_in_reshape(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
TensorFromScalar
])
@
node_rewrit
er
([
TensorFromScalar
])
def
local_tensor_scalar_tensor
(
fgraph
,
node
):
def
local_tensor_scalar_tensor
(
fgraph
,
node
):
"""tensor_from_scalar(scalar_from_tensor(x)) -> x"""
"""tensor_from_scalar(scalar_from_tensor(x)) -> x"""
if
isinstance
(
node
.
op
,
TensorFromScalar
):
if
isinstance
(
node
.
op
,
TensorFromScalar
):
...
@@ -734,7 +734,7 @@ def local_tensor_scalar_tensor(fgraph, node):
...
@@ -734,7 +734,7 @@ def local_tensor_scalar_tensor(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
ScalarFromTensor
])
@
node_rewrit
er
([
ScalarFromTensor
])
def
local_scalar_tensor_scalar
(
fgraph
,
node
):
def
local_scalar_tensor_scalar
(
fgraph
,
node
):
"""scalar_from_tensor(tensor_from_scalar(x)) -> x"""
"""scalar_from_tensor(tensor_from_scalar(x)) -> x"""
if
isinstance
(
node
.
op
,
ScalarFromTensor
):
if
isinstance
(
node
.
op
,
ScalarFromTensor
):
...
@@ -1474,7 +1474,7 @@ aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10
...
@@ -1474,7 +1474,7 @@ aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10
@register_specialize
(
"local_alloc_elemwise"
)
@register_specialize
(
"local_alloc_elemwise"
)
@
local_optimiz
er
([
Elemwise
])
@
node_rewrit
er
([
Elemwise
])
def
local_elemwise_alloc
(
fgraph
,
node
):
def
local_elemwise_alloc
(
fgraph
,
node
):
r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s.
r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s.
...
@@ -1595,7 +1595,7 @@ def local_elemwise_alloc(fgraph, node):
...
@@ -1595,7 +1595,7 @@ def local_elemwise_alloc(fgraph, node):
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Elemwise
])
@
node_rewrit
er
([
Elemwise
])
def
local_fill_sink
(
fgraph
,
node
):
def
local_fill_sink
(
fgraph
,
node
):
"""
"""
f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e)))
f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e)))
...
@@ -1647,7 +1647,7 @@ def local_fill_sink(fgraph, node):
...
@@ -1647,7 +1647,7 @@ def local_fill_sink(fgraph, node):
@register_specialize
@register_specialize
@register_stabilize
@register_stabilize
@
local_optimiz
er
([
fill
])
@
node_rewrit
er
([
fill
])
def
local_fill_to_alloc
(
fgraph
,
node
):
def
local_fill_to_alloc
(
fgraph
,
node
):
r"""Remove `fill`\s or replace them with `Alloc`\s.
r"""Remove `fill`\s or replace them with `Alloc`\s.
...
@@ -1698,7 +1698,7 @@ compile.optdb.register(
...
@@ -1698,7 +1698,7 @@ compile.optdb.register(
@register_canonicalize
(
"fast_compile"
)
@register_canonicalize
(
"fast_compile"
)
@register_useless
@register_useless
@
local_optimiz
er
([
fill
])
@
node_rewrit
er
([
fill
])
def
local_useless_fill
(
fgraph
,
node
):
def
local_useless_fill
(
fgraph
,
node
):
"""fill(s,v) -> v
"""fill(s,v) -> v
...
@@ -1721,7 +1721,7 @@ def local_useless_fill(fgraph, node):
...
@@ -1721,7 +1721,7 @@ def local_useless_fill(fgraph, node):
@register_stabilize
@register_stabilize
@register_canonicalize
@register_canonicalize
@register_useless
@register_useless
@
local_optimiz
er
([
Alloc
])
@
node_rewrit
er
([
Alloc
])
def
local_useless_alloc
(
fgraph
,
node
):
def
local_useless_alloc
(
fgraph
,
node
):
"""
"""
If the input type is the same as the output type (dtype and broadcast)
If the input type is the same as the output type (dtype and broadcast)
...
@@ -1751,7 +1751,7 @@ def local_useless_alloc(fgraph, node):
...
@@ -1751,7 +1751,7 @@ def local_useless_alloc(fgraph, node):
@register_specialize
@register_specialize
@register_stabilize
@register_stabilize
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Alloc
])
@
node_rewrit
er
([
Alloc
])
def
local_alloc_sink_dimshuffle
(
fgraph
,
node
):
def
local_alloc_sink_dimshuffle
(
fgraph
,
node
):
r"""Convert broadcastable leading dimensions in an `Alloc` to `DimShuffle`\s."""
r"""Convert broadcastable leading dimensions in an `Alloc` to `DimShuffle`\s."""
op
=
node
.
op
op
=
node
.
op
...
@@ -1785,7 +1785,7 @@ def local_alloc_sink_dimshuffle(fgraph, node):
...
@@ -1785,7 +1785,7 @@ def local_alloc_sink_dimshuffle(fgraph, node):
return
[
DimShuffle
(
inner
.
type
.
broadcastable
,
dimshuffle_new_order
)(
inner
)]
return
[
DimShuffle
(
inner
.
type
.
broadcastable
,
dimshuffle_new_order
)(
inner
)]
@
local_optimiz
er
([
AllocEmpty
])
@
node_rewrit
er
([
AllocEmpty
])
def
local_alloc_empty_to_zeros
(
fgraph
,
node
):
def
local_alloc_empty_to_zeros
(
fgraph
,
node
):
"""This convert AllocEmpty to Alloc of 0.
"""This convert AllocEmpty to Alloc of 0.
...
@@ -1808,7 +1808,7 @@ compile.optdb.register(
...
@@ -1808,7 +1808,7 @@ compile.optdb.register(
@register_specialize
@register_specialize
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Shape
])
@
node_rewrit
er
([
Shape
])
def
local_shape_to_shape_i
(
fgraph
,
node
):
def
local_shape_to_shape_i
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
Shape
):
if
isinstance
(
node
.
op
,
Shape
):
# This optimization needs ShapeOpt and fgraph.shape_feature
# This optimization needs ShapeOpt and fgraph.shape_feature
...
@@ -1824,7 +1824,7 @@ def local_shape_to_shape_i(fgraph, node):
...
@@ -1824,7 +1824,7 @@ def local_shape_to_shape_i(fgraph, node):
@register_specialize
@register_specialize
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Shape_i
])
@
node_rewrit
er
([
Shape_i
])
def
local_track_shape_i
(
fgraph
,
node
):
def
local_track_shape_i
(
fgraph
,
node
):
if
not
isinstance
(
node
.
op
,
Shape_i
):
if
not
isinstance
(
node
.
op
,
Shape_i
):
return
False
return
False
...
@@ -1847,7 +1847,7 @@ def local_track_shape_i(fgraph, node):
...
@@ -1847,7 +1847,7 @@ def local_track_shape_i(fgraph, node):
@register_useless
@register_useless
@register_canonicalize
(
"fast_compile"
)
@register_canonicalize
(
"fast_compile"
)
@register_specialize
@register_specialize
@
local_optimiz
er
([
Elemwise
])
@
node_rewrit
er
([
Elemwise
])
def
local_useless_elemwise
(
fgraph
,
node
):
def
local_useless_elemwise
(
fgraph
,
node
):
"""
"""
eq(x, x) -> 1
eq(x, x) -> 1
...
@@ -1952,7 +1952,7 @@ def local_useless_elemwise(fgraph, node):
...
@@ -1952,7 +1952,7 @@ def local_useless_elemwise(fgraph, node):
@register_specialize
@register_specialize
@
local_optimiz
er
([
Elemwise
])
@
node_rewrit
er
([
Elemwise
])
def
local_alloc_unary
(
fgraph
,
node
):
def
local_alloc_unary
(
fgraph
,
node
):
"""unary(alloc(x, shp)) -> alloc(unary(x), shp)"""
"""unary(alloc(x, shp)) -> alloc(unary(x), shp)"""
if
isinstance
(
node
.
op
,
Elemwise
)
and
len
(
node
.
inputs
)
==
1
:
if
isinstance
(
node
.
op
,
Elemwise
)
and
len
(
node
.
inputs
)
==
1
:
...
@@ -1974,7 +1974,7 @@ def local_alloc_unary(fgraph, node):
...
@@ -1974,7 +1974,7 @@ def local_alloc_unary(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
Elemwise
])
@
node_rewrit
er
([
Elemwise
])
def
local_cast_cast
(
fgraph
,
node
):
def
local_cast_cast
(
fgraph
,
node
):
"""cast(cast(x, dtype1), dtype2)
"""cast(cast(x, dtype1), dtype2)
...
@@ -2052,7 +2052,7 @@ def is_an_upcast(type1, type2):
...
@@ -2052,7 +2052,7 @@ def is_an_upcast(type1, type2):
@register_useless
@register_useless
@register_specialize
@register_specialize
@
local_optimiz
er
(
None
)
@
node_rewrit
er
(
None
)
def
local_remove_useless_assert
(
fgraph
,
node
):
def
local_remove_useless_assert
(
fgraph
,
node
):
if
not
isinstance
(
node
.
op
,
CheckAndRaise
):
if
not
isinstance
(
node
.
op
,
CheckAndRaise
):
return
False
return
False
...
@@ -2079,7 +2079,7 @@ def local_remove_useless_assert(fgraph, node):
...
@@ -2079,7 +2079,7 @@ def local_remove_useless_assert(fgraph, node):
return
[
new_var
]
return
[
new_var
]
@
local_optimiz
er
([
Assert
])
@
node_rewrit
er
([
Assert
])
def
local_remove_all_assert
(
fgraph
,
node
):
def
local_remove_all_assert
(
fgraph
,
node
):
"""An optimization disabled by default that removes all asserts from
"""An optimization disabled by default that removes all asserts from
the graph.
the graph.
...
@@ -2122,7 +2122,7 @@ compile.optdb["useless"].register(
...
@@ -2122,7 +2122,7 @@ compile.optdb["useless"].register(
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Elemwise
])
@
node_rewrit
er
([
Elemwise
])
def
local_upcast_elemwise_constant_inputs
(
fgraph
,
node
):
def
local_upcast_elemwise_constant_inputs
(
fgraph
,
node
):
"""This explicitly upcasts constant inputs to elemwise Ops, when
"""This explicitly upcasts constant inputs to elemwise Ops, when
those Ops do implicit upcasting anyway.
those Ops do implicit upcasting anyway.
...
@@ -2197,7 +2197,7 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
...
@@ -2197,7 +2197,7 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
@register_useless
@register_useless
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
Unbroadcast
])
@
node_rewrit
er
([
Unbroadcast
])
def
local_useless_unbroadcast
(
fgraph
,
node
):
def
local_useless_unbroadcast
(
fgraph
,
node
):
"""Remove `Unbroadcast` if it does not actually change the broadcasting pattern.
"""Remove `Unbroadcast` if it does not actually change the broadcasting pattern.
...
@@ -2225,7 +2225,7 @@ def local_useless_unbroadcast(fgraph, node):
...
@@ -2225,7 +2225,7 @@ def local_useless_unbroadcast(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
Unbroadcast
])
@
node_rewrit
er
([
Unbroadcast
])
def
local_unbroadcast_lift
(
fgraph
,
node
):
def
local_unbroadcast_lift
(
fgraph
,
node
):
"""
"""
Lifts `Unbroadcast` through unary Elemwise operations,
Lifts `Unbroadcast` through unary Elemwise operations,
...
@@ -2271,7 +2271,7 @@ def local_unbroadcast_lift(fgraph, node):
...
@@ -2271,7 +2271,7 @@ def local_unbroadcast_lift(fgraph, node):
@register_specialize
@register_specialize
@register_canonicalize
@register_canonicalize
@register_useless
@register_useless
@
local_optimiz
er
([
Join
])
@
node_rewrit
er
([
Join
])
def
local_join_1
(
fgraph
,
node
):
def
local_join_1
(
fgraph
,
node
):
"""Join(i, x) => x
"""Join(i, x) => x
...
@@ -2291,7 +2291,7 @@ def local_join_1(fgraph, node):
...
@@ -2291,7 +2291,7 @@ def local_join_1(fgraph, node):
@register_useless
@register_useless
@register_specialize
@register_specialize
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Join
])
@
node_rewrit
er
([
Join
])
def
local_join_empty
(
fgraph
,
node
):
def
local_join_empty
(
fgraph
,
node
):
"""Join(i, x, y, empty) => Join(i, x, y)
"""Join(i, x, y, empty) => Join(i, x, y)
...
@@ -2338,7 +2338,7 @@ def local_join_empty(fgraph, node):
...
@@ -2338,7 +2338,7 @@ def local_join_empty(fgraph, node):
@register_specialize
@register_specialize
@register_canonicalize
@register_canonicalize
@register_useless
@register_useless
@
local_optimiz
er
([
Join
])
@
node_rewrit
er
([
Join
])
def
local_join_make_vector
(
fgraph
,
node
):
def
local_join_make_vector
(
fgraph
,
node
):
r"""Merge `MakeVector` inputs within a `Join`.
r"""Merge `MakeVector` inputs within a `Join`.
...
@@ -2385,7 +2385,7 @@ def local_join_make_vector(fgraph, node):
...
@@ -2385,7 +2385,7 @@ def local_join_make_vector(fgraph, node):
@register_useless
(
"local_remove_switch_const_cond"
)
@register_useless
(
"local_remove_switch_const_cond"
)
@register_canonicalize
(
"fast_compile"
,
"local_remove_switch_const_cond"
)
@register_canonicalize
(
"fast_compile"
,
"local_remove_switch_const_cond"
)
@register_specialize
@register_specialize
@
local_optimiz
er
([
Elemwise
])
@
node_rewrit
er
([
Elemwise
])
def
local_useless_switch
(
fgraph
,
node
):
def
local_useless_switch
(
fgraph
,
node
):
"""
"""
This optimization makes the following changes in the graph:
This optimization makes the following changes in the graph:
...
@@ -2462,7 +2462,7 @@ def local_useless_switch(fgraph, node):
...
@@ -2462,7 +2462,7 @@ def local_useless_switch(fgraph, node):
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Elemwise
])
@
node_rewrit
er
([
Elemwise
])
def
local_merge_switch_same_cond
(
fgraph
,
node
):
def
local_merge_switch_same_cond
(
fgraph
,
node
):
"""
"""
Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
...
@@ -2499,7 +2499,7 @@ def local_merge_switch_same_cond(fgraph, node):
...
@@ -2499,7 +2499,7 @@ def local_merge_switch_same_cond(fgraph, node):
@register_useless
@register_useless
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
Split
])
@
node_rewrit
er
([
Split
])
def
local_useless_split
(
fgraph
,
node
):
def
local_useless_split
(
fgraph
,
node
):
"""Split{n_splits=1}(x, y) -> x
"""Split{n_splits=1}(x, y) -> x
...
@@ -2520,7 +2520,7 @@ def local_useless_split(fgraph, node):
...
@@ -2520,7 +2520,7 @@ def local_useless_split(fgraph, node):
def
local_reshape_chain
(
op
):
def
local_reshape_chain
(
op
):
@
local_optimiz
er
([
op
])
@
node_rewrit
er
([
op
])
def
f
(
fgraph
,
node
):
def
f
(
fgraph
,
node
):
"""
"""
Reshape(Reshape(shape1),shape2) -> Reshape(shape2)
Reshape(Reshape(shape1),shape2) -> Reshape(shape2)
...
@@ -2560,7 +2560,7 @@ register_canonicalize(local_reshape_chain(Reshape), name="local_reshape_chain")
...
@@ -2560,7 +2560,7 @@ register_canonicalize(local_reshape_chain(Reshape), name="local_reshape_chain")
@register_useless
@register_useless
@register_canonicalize
@register_canonicalize
@register_stabilize
@register_stabilize
@
local_optimiz
er
([
Reshape
])
@
node_rewrit
er
([
Reshape
])
def
local_useless_reshape
(
fgraph
,
node
):
def
local_useless_reshape
(
fgraph
,
node
):
"""
"""
Remove two kinds of useless reshape.
Remove two kinds of useless reshape.
...
@@ -2658,7 +2658,7 @@ def local_useless_reshape(fgraph, node):
...
@@ -2658,7 +2658,7 @@ def local_useless_reshape(fgraph, node):
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Reshape
])
@
node_rewrit
er
([
Reshape
])
def
local_reshape_to_dimshuffle
(
fgraph
,
node
):
def
local_reshape_to_dimshuffle
(
fgraph
,
node
):
"""
"""
Broadcastable dimensions in Reshape are replaced with dimshuffle.
Broadcastable dimensions in Reshape are replaced with dimshuffle.
...
@@ -2706,7 +2706,7 @@ def local_reshape_to_dimshuffle(fgraph, node):
...
@@ -2706,7 +2706,7 @@ def local_reshape_to_dimshuffle(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_stabilize
@register_stabilize
@
local_optimiz
er
([
Reshape
])
@
node_rewrit
er
([
Reshape
])
def
local_reshape_lift
(
fgraph
,
node
):
def
local_reshape_lift
(
fgraph
,
node
):
"""
"""
Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x))
Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x))
...
@@ -2736,7 +2736,7 @@ def local_reshape_lift(fgraph, node):
...
@@ -2736,7 +2736,7 @@ def local_reshape_lift(fgraph, node):
register_canonicalize
(
OpRemove
(
tensor_copy
),
name
=
"remove_tensor_copy"
)
register_canonicalize
(
OpRemove
(
tensor_copy
),
name
=
"remove_tensor_copy"
)
@
local_optimiz
er
(
None
)
@
node_rewrit
er
(
None
)
def
constant_folding
(
fgraph
,
node
):
def
constant_folding
(
fgraph
,
node
):
if
not
node
.
op
.
do_constant_folding
(
fgraph
,
node
):
if
not
node
.
op
.
do_constant_folding
(
fgraph
,
node
):
...
@@ -3092,9 +3092,9 @@ class FusionOptimizer(GraphRewriter):
...
@@ -3092,9 +3092,9 @@ class FusionOptimizer(GraphRewriter):
"""
"""
def
__init__
(
self
,
local_optimiz
er
):
def
__init__
(
self
,
node_rewrit
er
):
super
()
.
__init__
()
super
()
.
__init__
()
self
.
optimizer
=
local_optimiz
er
self
.
optimizer
=
node_rewrit
er
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
fgraph
.
attach_feature
(
ReplaceValidate
())
fgraph
.
attach_feature
(
ReplaceValidate
())
...
@@ -3206,7 +3206,7 @@ else:
...
@@ -3206,7 +3206,7 @@ else:
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Elemwise
])
@
node_rewrit
er
([
Elemwise
])
def
local_useless_composite
(
fgraph
,
node
):
def
local_useless_composite
(
fgraph
,
node
):
"""For elemwise Composite that have multiple outputs, remove the
"""For elemwise Composite that have multiple outputs, remove the
outputs that are not used.
outputs that are not used.
...
@@ -3227,7 +3227,7 @@ def local_useless_composite(fgraph, node):
...
@@ -3227,7 +3227,7 @@ def local_useless_composite(fgraph, node):
@register_canonicalize
(
"fast_compile"
)
@register_canonicalize
(
"fast_compile"
)
@register_useless
(
"fast_compile"
)
@register_useless
(
"fast_compile"
)
@
local_optimiz
er
(
None
)
@
node_rewrit
er
(
None
)
def
local_view_op
(
fgraph
,
node
):
def
local_view_op
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
ViewOp
):
if
isinstance
(
node
.
op
,
ViewOp
):
return
node
.
inputs
return
node
.
inputs
...
@@ -3237,7 +3237,7 @@ def local_view_op(fgraph, node):
...
@@ -3237,7 +3237,7 @@ def local_view_op(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_stabilize
@register_stabilize
@register_specialize
@register_specialize
@
local_optimiz
er
([
Alloc
])
@
node_rewrit
er
([
Alloc
])
def
local_merge_alloc
(
fgraph
,
node
):
def
local_merge_alloc
(
fgraph
,
node
):
# This opt takes care of several cases:
# This opt takes care of several cases:
# Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
# Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
...
@@ -3274,7 +3274,7 @@ def local_merge_alloc(fgraph, node):
...
@@ -3274,7 +3274,7 @@ def local_merge_alloc(fgraph, node):
@register_useless
(
"fast_compile"
)
@register_useless
(
"fast_compile"
)
@
local_optimiz
er
([
TopKOp
])
@
node_rewrit
er
([
TopKOp
])
def
local_useless_topk
(
fgraph
,
node
):
def
local_useless_topk
(
fgraph
,
node
):
"""
"""
TopKOp generates two outputs by default
TopKOp generates two outputs by default
...
@@ -3310,7 +3310,7 @@ def local_useless_topk(fgraph, node):
...
@@ -3310,7 +3310,7 @@ def local_useless_topk(fgraph, node):
@register_useless
@register_useless
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
SpecifyShape
])
@
node_rewrit
er
([
SpecifyShape
])
def
local_merge_consecutive_specify_shape
(
fgraph
,
node
):
def
local_merge_consecutive_specify_shape
(
fgraph
,
node
):
"""Replace ``specify_shape(specify_shape(x, s1), s2)`` with ``specify_shape(x, s3)``,
"""Replace ``specify_shape(specify_shape(x, s1), s2)`` with ``specify_shape(x, s3)``,
where s3 is the union of specified dimensions in s1 and s2, with preference given to s2.
where s3 is the union of specified dimensions in s1 and s2, with preference given to s2.
...
@@ -3336,7 +3336,7 @@ def local_merge_consecutive_specify_shape(fgraph, node):
...
@@ -3336,7 +3336,7 @@ def local_merge_consecutive_specify_shape(fgraph, node):
@register_useless
@register_useless
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Shape
])
@
node_rewrit
er
([
Shape
])
def
local_Shape_of_SpecifyShape
(
fgraph
,
node
):
def
local_Shape_of_SpecifyShape
(
fgraph
,
node
):
"""Replace ``specify_shape(x, s).shape`` with ``s``."""
"""Replace ``specify_shape(x, s).shape`` with ``s``."""
...
@@ -3360,7 +3360,7 @@ def local_Shape_of_SpecifyShape(fgraph, node):
...
@@ -3360,7 +3360,7 @@ def local_Shape_of_SpecifyShape(fgraph, node):
@register_useless
@register_useless
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Shape_i
])
@
node_rewrit
er
([
Shape_i
])
def
local_Shape_i_of_broadcastable
(
fgraph
,
node
):
def
local_Shape_i_of_broadcastable
(
fgraph
,
node
):
"""Replace ``shape_i(x, i)`` with ``1`` when ``x.broadcastable[i]`` is ``True``."""
"""Replace ``shape_i(x, i)`` with ``1`` when ``x.broadcastable[i]`` is ``True``."""
...
@@ -3378,7 +3378,7 @@ def local_Shape_i_of_broadcastable(fgraph, node):
...
@@ -3378,7 +3378,7 @@ def local_Shape_i_of_broadcastable(fgraph, node):
@register_useless
@register_useless
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Unique
])
@
node_rewrit
er
([
Unique
])
def
local_Unique_scalar
(
fgraph
,
node
):
def
local_Unique_scalar
(
fgraph
,
node
):
"""Convert ``unique(x)`` to ``x`` when ``x`` is a scalar."""
"""Convert ``unique(x)`` to ``x`` when ``x`` is a scalar."""
if
not
isinstance
(
node
.
op
,
Unique
):
if
not
isinstance
(
node
.
op
,
Unique
):
...
@@ -3399,7 +3399,7 @@ def local_Unique_scalar(fgraph, node):
...
@@ -3399,7 +3399,7 @@ def local_Unique_scalar(fgraph, node):
@register_useless
@register_useless
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Unique
])
@
node_rewrit
er
([
Unique
])
def
local_Unique_Alloc_lift
(
fgraph
,
node
):
def
local_Unique_Alloc_lift
(
fgraph
,
node
):
"""Convert ``unique(alloc(x, ...), axis=None)`` to ``unique(x, axis=None)``.
"""Convert ``unique(alloc(x, ...), axis=None)`` to ``unique(x, axis=None)``.
...
@@ -3432,7 +3432,7 @@ def local_Unique_Alloc_lift(fgraph, node):
...
@@ -3432,7 +3432,7 @@ def local_Unique_Alloc_lift(fgraph, node):
@register_useless
@register_useless
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Unique
])
@
node_rewrit
er
([
Unique
])
def
local_Unique_BroadcastTo_lift
(
fgraph
,
node
):
def
local_Unique_BroadcastTo_lift
(
fgraph
,
node
):
"""Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``.
"""Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``.
...
@@ -3465,7 +3465,7 @@ def local_Unique_BroadcastTo_lift(fgraph, node):
...
@@ -3465,7 +3465,7 @@ def local_Unique_BroadcastTo_lift(fgraph, node):
@register_useless
@register_useless
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Unique
])
@
node_rewrit
er
([
Unique
])
def
local_Unique_Repeat_lift
(
fgraph
,
node
):
def
local_Unique_Repeat_lift
(
fgraph
,
node
):
"""Convert ``unique(repeat(x, ...), axis=None)`` to ``unique(x, axis=None)``.
"""Convert ``unique(repeat(x, ...), axis=None)`` to ``unique(x, axis=None)``.
...
@@ -3498,7 +3498,7 @@ def local_Unique_Repeat_lift(fgraph, node):
...
@@ -3498,7 +3498,7 @@ def local_Unique_Repeat_lift(fgraph, node):
@register_useless
@register_useless
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Unique
])
@
node_rewrit
er
([
Unique
])
def
local_Unique_second
(
fgraph
,
node
):
def
local_Unique_second
(
fgraph
,
node
):
"""Convert ``unique(second(x, ...), axis=None)`` to ``second(x, axis=None)``.
"""Convert ``unique(second(x, ...), axis=None)`` to ``second(x, axis=None)``.
...
@@ -3535,7 +3535,7 @@ def local_Unique_second(fgraph, node):
...
@@ -3535,7 +3535,7 @@ def local_Unique_second(fgraph, node):
@register_useless
@register_useless
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
BroadcastTo
])
@
node_rewrit
er
([
BroadcastTo
])
def
local_remove_scalar_BroadcastTo
(
fgraph
,
node
):
def
local_remove_scalar_BroadcastTo
(
fgraph
,
node
):
bcast_shape
=
node
.
inputs
[
1
:]
bcast_shape
=
node
.
inputs
[
1
:]
...
...
aesara/tensor/blas.py
浏览文件 @
550a6e98
...
@@ -150,7 +150,7 @@ from aesara.graph.opt import (
...
@@ -150,7 +150,7 @@ from aesara.graph.opt import (
GraphRewriter
,
GraphRewriter
,
copy_stack_trace
,
copy_stack_trace
,
in2out
,
in2out
,
local_optimiz
er
,
node_rewrit
er
,
)
)
from
aesara.graph.optdb
import
SequenceDB
from
aesara.graph.optdb
import
SequenceDB
from
aesara.graph.utils
import
InconsistencyError
,
MethodNotDefined
,
TestValueError
from
aesara.graph.utils
import
InconsistencyError
,
MethodNotDefined
,
TestValueError
...
@@ -1733,7 +1733,7 @@ class Dot22(GemmRelated):
...
@@ -1733,7 +1733,7 @@ class Dot22(GemmRelated):
_dot22
=
Dot22
()
_dot22
=
Dot22
()
@
local_optimiz
er
([
Dot
])
@
node_rewrit
er
([
Dot
])
def
local_dot_to_dot22
(
fgraph
,
node
):
def
local_dot_to_dot22
(
fgraph
,
node
):
# This works for tensor.outer too because basic.outer is a macro that
# This works for tensor.outer too because basic.outer is a macro that
# produces a dot(dimshuffle,dimshuffle) of form 4 below
# produces a dot(dimshuffle,dimshuffle) of form 4 below
...
@@ -1766,7 +1766,7 @@ def local_dot_to_dot22(fgraph, node):
...
@@ -1766,7 +1766,7 @@ def local_dot_to_dot22(fgraph, node):
_logger
.
info
(
f
"Not optimizing dot with inputs {x} {y} {x.type} {y.type}"
)
_logger
.
info
(
f
"Not optimizing dot with inputs {x} {y} {x.type} {y.type}"
)
@
local_optimiz
er
([
gemm_no_inplace
],
inplace
=
True
)
@
node_rewrit
er
([
gemm_no_inplace
],
inplace
=
True
)
def
local_inplace_gemm
(
fgraph
,
node
):
def
local_inplace_gemm
(
fgraph
,
node
):
if
node
.
op
==
gemm_no_inplace
:
if
node
.
op
==
gemm_no_inplace
:
new_out
=
[
gemm_inplace
(
*
node
.
inputs
)]
new_out
=
[
gemm_inplace
(
*
node
.
inputs
)]
...
@@ -1774,7 +1774,7 @@ def local_inplace_gemm(fgraph, node):
...
@@ -1774,7 +1774,7 @@ def local_inplace_gemm(fgraph, node):
return
new_out
return
new_out
@
local_optimiz
er
([
gemv_no_inplace
],
inplace
=
True
)
@
node_rewrit
er
([
gemv_no_inplace
],
inplace
=
True
)
def
local_inplace_gemv
(
fgraph
,
node
):
def
local_inplace_gemv
(
fgraph
,
node
):
if
node
.
op
==
gemv_no_inplace
:
if
node
.
op
==
gemv_no_inplace
:
new_out
=
[
gemv_inplace
(
*
node
.
inputs
)]
new_out
=
[
gemv_inplace
(
*
node
.
inputs
)]
...
@@ -1782,7 +1782,7 @@ def local_inplace_gemv(fgraph, node):
...
@@ -1782,7 +1782,7 @@ def local_inplace_gemv(fgraph, node):
return
new_out
return
new_out
@
local_optimiz
er
([
ger
],
inplace
=
True
)
@
node_rewrit
er
([
ger
],
inplace
=
True
)
def
local_inplace_ger
(
fgraph
,
node
):
def
local_inplace_ger
(
fgraph
,
node
):
if
node
.
op
==
ger
:
if
node
.
op
==
ger
:
new_out
=
[
ger_destructive
(
*
node
.
inputs
)]
new_out
=
[
ger_destructive
(
*
node
.
inputs
)]
...
@@ -1790,7 +1790,7 @@ def local_inplace_ger(fgraph, node):
...
@@ -1790,7 +1790,7 @@ def local_inplace_ger(fgraph, node):
return
new_out
return
new_out
@
local_optimiz
er
([
gemm_no_inplace
])
@
node_rewrit
er
([
gemm_no_inplace
])
def
local_gemm_to_gemv
(
fgraph
,
node
):
def
local_gemm_to_gemv
(
fgraph
,
node
):
"""GEMM acting on row or column matrices -> GEMV."""
"""GEMM acting on row or column matrices -> GEMV."""
if
node
.
op
==
gemm_no_inplace
:
if
node
.
op
==
gemm_no_inplace
:
...
@@ -1807,7 +1807,7 @@ def local_gemm_to_gemv(fgraph, node):
...
@@ -1807,7 +1807,7 @@ def local_gemm_to_gemv(fgraph, node):
return
new_out
return
new_out
@
local_optimiz
er
([
gemm_no_inplace
])
@
node_rewrit
er
([
gemm_no_inplace
])
def
local_gemm_to_ger
(
fgraph
,
node
):
def
local_gemm_to_ger
(
fgraph
,
node
):
"""GEMM computing an outer-product -> GER."""
"""GEMM computing an outer-product -> GER."""
if
node
.
op
==
gemm_no_inplace
:
if
node
.
op
==
gemm_no_inplace
:
...
@@ -1839,7 +1839,7 @@ def local_gemm_to_ger(fgraph, node):
...
@@ -1839,7 +1839,7 @@ def local_gemm_to_ger(fgraph, node):
# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
# working
# working
@
local_optimiz
er
([
_dot22
])
@
node_rewrit
er
([
_dot22
])
def
local_dot22_to_ger_or_gemv
(
fgraph
,
node
):
def
local_dot22_to_ger_or_gemv
(
fgraph
,
node
):
"""dot22 computing an outer-product -> GER."""
"""dot22 computing an outer-product -> GER."""
if
node
.
op
==
_dot22
:
if
node
.
op
==
_dot22
:
...
@@ -2033,7 +2033,7 @@ class Dot22Scalar(GemmRelated):
...
@@ -2033,7 +2033,7 @@ class Dot22Scalar(GemmRelated):
_dot22scalar
=
Dot22Scalar
()
_dot22scalar
=
Dot22Scalar
()
@
local_optimiz
er
([
mul
])
@
node_rewrit
er
([
mul
])
def
local_dot22_to_dot22scalar
(
fgraph
,
node
):
def
local_dot22_to_dot22scalar
(
fgraph
,
node
):
"""
"""
Notes
Notes
...
@@ -2651,7 +2651,7 @@ _batched_dot = BatchedDot()
...
@@ -2651,7 +2651,7 @@ _batched_dot = BatchedDot()
# from opt import register_specialize, register_canonicalize
# from opt import register_specialize, register_canonicalize
# @register_specialize
# @register_specialize
@
local_optimiz
er
([
sub
,
add
])
@
node_rewrit
er
([
sub
,
add
])
def
local_print_as_we_go_along
(
fgraph
,
node
):
def
local_print_as_we_go_along
(
fgraph
,
node
):
if
node
.
op
in
(
sub
,
add
):
if
node
.
op
in
(
sub
,
add
):
debugprint
(
node
)
debugprint
(
node
)
...
...
aesara/tensor/blas_c.py
浏览文件 @
550a6e98
...
@@ -15,7 +15,7 @@ from aesara.tensor.blas import (
...
@@ -15,7 +15,7 @@ from aesara.tensor.blas import (
ger
,
ger
,
ger_destructive
,
ger_destructive
,
ldflags
,
ldflags
,
local_optimiz
er
,
node_rewrit
er
,
optdb
,
optdb
,
)
)
...
@@ -344,7 +344,7 @@ cger_inplace = CGer(True)
...
@@ -344,7 +344,7 @@ cger_inplace = CGer(True)
cger_no_inplace
=
CGer
(
False
)
cger_no_inplace
=
CGer
(
False
)
@
local_optimiz
er
([
ger
,
ger_destructive
])
@
node_rewrit
er
([
ger
,
ger_destructive
])
def
use_c_ger
(
fgraph
,
node
):
def
use_c_ger
(
fgraph
,
node
):
if
not
config
.
blas__ldflags
:
if
not
config
.
blas__ldflags
:
return
return
...
@@ -355,7 +355,7 @@ def use_c_ger(fgraph, node):
...
@@ -355,7 +355,7 @@ def use_c_ger(fgraph, node):
return
[
CGer
(
True
)(
*
node
.
inputs
)]
return
[
CGer
(
True
)(
*
node
.
inputs
)]
@
local_optimiz
er
([
CGer
(
False
)])
@
node_rewrit
er
([
CGer
(
False
)])
def
make_c_ger_destructive
(
fgraph
,
node
):
def
make_c_ger_destructive
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
CGer
)
and
not
node
.
op
.
destructive
:
if
isinstance
(
node
.
op
,
CGer
)
and
not
node
.
op
.
destructive
:
return
[
cger_inplace
(
*
node
.
inputs
)]
return
[
cger_inplace
(
*
node
.
inputs
)]
...
@@ -699,7 +699,7 @@ int main() {
...
@@ -699,7 +699,7 @@ int main() {
check_force_gemv_init
.
_force_init_beta
=
None
check_force_gemv_init
.
_force_init_beta
=
None
@
local_optimiz
er
([
gemv_inplace
,
gemv_no_inplace
])
@
node_rewrit
er
([
gemv_inplace
,
gemv_no_inplace
])
def
use_c_gemv
(
fgraph
,
node
):
def
use_c_gemv
(
fgraph
,
node
):
if
not
config
.
blas__ldflags
:
if
not
config
.
blas__ldflags
:
return
return
...
@@ -710,7 +710,7 @@ def use_c_gemv(fgraph, node):
...
@@ -710,7 +710,7 @@ def use_c_gemv(fgraph, node):
return
[
cgemv_inplace
(
*
node
.
inputs
)]
return
[
cgemv_inplace
(
*
node
.
inputs
)]
@
local_optimiz
er
([
CGemv
(
inplace
=
False
)])
@
node_rewrit
er
([
CGemv
(
inplace
=
False
)])
def
make_c_gemv_destructive
(
fgraph
,
node
):
def
make_c_gemv_destructive
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
CGemv
)
and
not
node
.
op
.
inplace
:
if
isinstance
(
node
.
op
,
CGemv
)
and
not
node
.
op
.
inplace
:
inputs
=
list
(
node
.
inputs
)
inputs
=
list
(
node
.
inputs
)
...
...
aesara/tensor/blas_scipy.py
浏览文件 @
550a6e98
...
@@ -11,7 +11,7 @@ from aesara.tensor.blas import (
...
@@ -11,7 +11,7 @@ from aesara.tensor.blas import (
ger
,
ger
,
ger_destructive
,
ger_destructive
,
have_fblas
,
have_fblas
,
local_optimiz
er
,
node_rewrit
er
,
optdb
,
optdb
,
)
)
...
@@ -58,13 +58,13 @@ scipy_ger_no_inplace = ScipyGer(False)
...
@@ -58,13 +58,13 @@ scipy_ger_no_inplace = ScipyGer(False)
scipy_ger_inplace
=
ScipyGer
(
True
)
scipy_ger_inplace
=
ScipyGer
(
True
)
@
local_optimiz
er
([
ger
,
ger_destructive
])
@
node_rewrit
er
([
ger
,
ger_destructive
])
def
use_scipy_ger
(
fgraph
,
node
):
def
use_scipy_ger
(
fgraph
,
node
):
if
node
.
op
==
ger
:
if
node
.
op
==
ger
:
return
[
scipy_ger_no_inplace
(
*
node
.
inputs
)]
return
[
scipy_ger_no_inplace
(
*
node
.
inputs
)]
@
local_optimiz
er
([
scipy_ger_no_inplace
])
@
node_rewrit
er
([
scipy_ger_no_inplace
])
def
make_ger_destructive
(
fgraph
,
node
):
def
make_ger_destructive
(
fgraph
,
node
):
if
node
.
op
==
scipy_ger_no_inplace
:
if
node
.
op
==
scipy_ger_no_inplace
:
return
[
scipy_ger_inplace
(
*
node
.
inputs
)]
return
[
scipy_ger_inplace
(
*
node
.
inputs
)]
...
...
aesara/tensor/math_opt.py
浏览文件 @
550a6e98
...
@@ -11,11 +11,11 @@ import aesara.scalar.math as aes_math
...
@@ -11,11 +11,11 @@ import aesara.scalar.math as aes_math
from
aesara.graph.basic
import
Constant
,
Variable
from
aesara.graph.basic
import
Constant
,
Variable
from
aesara.graph.opt
import
(
from
aesara.graph.opt
import
(
LocalOptGroup
,
LocalOptGroup
,
LocalOptimiz
er
,
NodeRewrit
er
,
PatternSub
,
PatternSub
,
copy_stack_trace
,
copy_stack_trace
,
in2out
,
in2out
,
local_optimiz
er
,
node_rewrit
er
,
)
)
from
aesara.graph.opt_utils
import
get_clients_at_depth
from
aesara.graph.opt_utils
import
get_clients_at_depth
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.misc.safe_asarray
import
_asarray
...
@@ -148,7 +148,7 @@ def fill_chain(new_out, orig_inputs):
...
@@ -148,7 +148,7 @@ def fill_chain(new_out, orig_inputs):
@register_canonicalize
@register_canonicalize
@register_stabilize
@register_stabilize
@
local_optimiz
er
([
Dot
])
@
node_rewrit
er
([
Dot
])
def
local_0_dot_x
(
fgraph
,
node
):
def
local_0_dot_x
(
fgraph
,
node
):
if
not
isinstance
(
node
.
op
,
Dot
):
if
not
isinstance
(
node
.
op
,
Dot
):
return
False
return
False
...
@@ -185,7 +185,7 @@ def local_0_dot_x(fgraph, node):
...
@@ -185,7 +185,7 @@ def local_0_dot_x(fgraph, node):
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
DimShuffle
])
@
node_rewrit
er
([
DimShuffle
])
def
local_lift_transpose_through_dot
(
fgraph
,
node
):
def
local_lift_transpose_through_dot
(
fgraph
,
node
):
"""Perform the rewrite ``dot(x,y).T -> dot(y.T, x.T)``
"""Perform the rewrite ``dot(x,y).T -> dot(y.T, x.T)``
...
@@ -229,7 +229,7 @@ def is_inverse_pair(node_op, prev_op, inv_pair):
...
@@ -229,7 +229,7 @@ def is_inverse_pair(node_op, prev_op, inv_pair):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
Elemwise
])
@
node_rewrit
er
([
Elemwise
])
def
local_func_inv
(
fgraph
,
node
):
def
local_func_inv
(
fgraph
,
node
):
"""
"""
Check for two consecutive operations that are functional inverses
Check for two consecutive operations that are functional inverses
...
@@ -271,7 +271,7 @@ def local_func_inv(fgraph, node):
...
@@ -271,7 +271,7 @@ def local_func_inv(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
Elemwise
])
@
node_rewrit
er
([
Elemwise
])
def
local_exp_log
(
fgraph
,
node
):
def
local_exp_log
(
fgraph
,
node
):
x
=
node
.
inputs
[
0
]
x
=
node
.
inputs
[
0
]
...
@@ -313,7 +313,7 @@ def local_exp_log(fgraph, node):
...
@@ -313,7 +313,7 @@ def local_exp_log(fgraph, node):
@register_specialize
@register_specialize
@
local_optimiz
er
([
Elemwise
])
@
node_rewrit
er
([
Elemwise
])
def
local_exp_log_nan_switch
(
fgraph
,
node
):
def
local_exp_log_nan_switch
(
fgraph
,
node
):
# Rewrites of the kind exp(log...(x)) that require a `nan` switch
# Rewrites of the kind exp(log...(x)) that require a `nan` switch
x
=
node
.
inputs
[
0
]
x
=
node
.
inputs
[
0
]
...
@@ -371,7 +371,7 @@ def local_exp_log_nan_switch(fgraph, node):
...
@@ -371,7 +371,7 @@ def local_exp_log_nan_switch(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
Sum
])
@
node_rewrit
er
([
Sum
])
def
local_sumsqr2dot
(
fgraph
,
node
):
def
local_sumsqr2dot
(
fgraph
,
node
):
"""
"""
This optimization detects
This optimization detects
...
@@ -418,7 +418,7 @@ def local_sumsqr2dot(fgraph, node):
...
@@ -418,7 +418,7 @@ def local_sumsqr2dot(fgraph, node):
@register_stabilize
@register_stabilize
@register_specialize
@register_specialize
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Elemwise
])
@
node_rewrit
er
([
Elemwise
])
def
local_expm1
(
fgraph
,
node
):
def
local_expm1
(
fgraph
,
node
):
"""
"""
This optimization detects exp(a)-1 and converts this to expm1(a).
This optimization detects exp(a)-1 and converts this to expm1(a).
...
@@ -446,7 +446,7 @@ def local_expm1(fgraph, node):
...
@@ -446,7 +446,7 @@ def local_expm1(fgraph, node):
@register_specialize
@register_specialize
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
mul
])
@
node_rewrit
er
([
mul
])
def
local_mul_switch_sink
(
fgraph
,
node
):
def
local_mul_switch_sink
(
fgraph
,
node
):
"""
"""
This optimization makes the following changes in the graph:
This optimization makes the following changes in the graph:
...
@@ -540,7 +540,7 @@ def local_mul_switch_sink(fgraph, node):
...
@@ -540,7 +540,7 @@ def local_mul_switch_sink(fgraph, node):
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
true_div
,
int_div
])
@
node_rewrit
er
([
true_div
,
int_div
])
def
local_div_switch_sink
(
fgraph
,
node
):
def
local_div_switch_sink
(
fgraph
,
node
):
"""
"""
This optimization makes the following changes in the graph:
This optimization makes the following changes in the graph:
...
@@ -616,33 +616,33 @@ def local_div_switch_sink(fgraph, node):
...
@@ -616,33 +616,33 @@ def local_div_switch_sink(fgraph, node):
return
False
return
False
class
AlgebraicCanonizer
(
LocalOptimiz
er
):
class
AlgebraicCanonizer
(
NodeRewrit
er
):
r"""
Simplification tool
.
r"""
A `Rewriter` that rewrites algebraic expressions
.
The variable is a `
`local_optimizer`
`. It is best used
The variable is a `
node_rewriter
`. It is best used
with a `
`TopoOptimizer`` in ``in_to_out``
order.
with a `
TopoOptimizer` in in-to-out
order.
Usage: ``AlgebraicCanonizer(main, inverse, reciprocal, calculate)``
Usage: ``AlgebraicCanonizer(main, inverse, reciprocal, calculate)``
Parameters
Parameters
----------
----------
main
main
A suitable `
`Op`
` class that is commutative, associative and
A suitable `
Op
` class that is commutative, associative and
takes one to an arbitrary number of inputs, e.g. add or
takes one to an arbitrary number of inputs, e.g. add or
mul
mul
inverse
inverse
An `
`Op`
` class such that ``inverse(main(x, y), y) == x``
An `
Op
` class such that ``inverse(main(x, y), y) == x``
e.g. ``sub`` or true_div
(e.g. `sub` or `true_div`).
reciprocal
reciprocal
A function such that ``main(x, reciprocal(y)) == inverse(x, y)``
A function such that ``main(x, reciprocal(y)) == inverse(x, y)``
e.g. ``neg`` or ``reciprocal``
(e.g. `neg` or `reciprocal`).
calculate
calculate
Function that takes a list of
numpy.ndarray
instances
Function that takes a list of
`numpy.ndarray`
instances
for the numerator, another list for the denumerator,
for the numerator, another list for the denumerator,
and calculates ``inverse(main(\*num), main(\*denum))``. It
and calculates ``inverse(main(\*num), main(\*denum))``. It
takes a keyword argument,
aslist. If True
, the value
takes a keyword argument,
`aslist`. If ``True``
, the value
should be returned as a list of one element, unless
should be returned as a list of one element, unless
the value is such that
value = main()
. In that case,
the value is such that
``value = main()``
. In that case,
the return value should be an empty list.
the return value should be an empty list.
Examples
Examples
...
@@ -654,7 +654,7 @@ class AlgebraicCanonizer(LocalOptimizer):
...
@@ -654,7 +654,7 @@ class AlgebraicCanonizer(LocalOptimizer):
>>> mul_canonizer = AlgebraicCanonizer(mul, true_div, inv, \\
>>> mul_canonizer = AlgebraicCanonizer(mul, true_div, inv, \\
... lambda n, d: prod(n) / prod(d))
... lambda n, d: prod(n) / prod(d))
Examples of optimizations `
`mul_canonizer`
` can perform:
Examples of optimizations `
mul_canonizer
` can perform:
| x / x -> 1
| x / x -> 1
| (x * y) / x -> y
| (x * y) / x -> y
...
@@ -1082,14 +1082,14 @@ register_canonicalize(local_mul_canonizer, name="local_mul_canonizer")
...
@@ -1082,14 +1082,14 @@ register_canonicalize(local_mul_canonizer, name="local_mul_canonizer")
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
neg
])
@
node_rewrit
er
([
neg
])
def
local_neg_to_mul
(
fgraph
,
node
):
def
local_neg_to_mul
(
fgraph
,
node
):
if
node
.
op
==
neg
:
if
node
.
op
==
neg
:
return
[
mul
(
np
.
array
(
-
1
,
dtype
=
node
.
inputs
[
0
]
.
dtype
),
node
.
inputs
[
0
])]
return
[
mul
(
np
.
array
(
-
1
,
dtype
=
node
.
inputs
[
0
]
.
dtype
),
node
.
inputs
[
0
])]
@register_specialize
@register_specialize
@
local_optimiz
er
([
Sum
,
Prod
])
@
node_rewrit
er
([
Sum
,
Prod
])
def
local_sum_prod_mul_by_scalar
(
fgraph
,
node
):
def
local_sum_prod_mul_by_scalar
(
fgraph
,
node
):
"""
"""
sum(scalar * smth) -> scalar * sum(smth)
sum(scalar * smth) -> scalar * sum(smth)
...
@@ -1175,7 +1175,7 @@ def local_sum_prod_mul_by_scalar(fgraph, node):
...
@@ -1175,7 +1175,7 @@ def local_sum_prod_mul_by_scalar(fgraph, node):
@register_specialize
@register_specialize
@
local_optimiz
er
([
Elemwise
])
@
node_rewrit
er
([
Elemwise
])
def
local_elemwise_sub_zeros
(
fgraph
,
node
):
def
local_elemwise_sub_zeros
(
fgraph
,
node
):
"""
"""
Elemwise{sub}(X,X) -> zeros_like(X)
Elemwise{sub}(X,X) -> zeros_like(X)
...
@@ -1197,7 +1197,7 @@ def local_elemwise_sub_zeros(fgraph, node):
...
@@ -1197,7 +1197,7 @@ def local_elemwise_sub_zeros(fgraph, node):
@register_specialize
@register_specialize
@register_stabilize
@register_stabilize
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Elemwise
])
@
node_rewrit
er
([
Elemwise
])
def
local_useless_elemwise_comparison
(
fgraph
,
node
):
def
local_useless_elemwise_comparison
(
fgraph
,
node
):
"""...
"""...
...
@@ -1407,7 +1407,7 @@ def local_useless_elemwise_comparison(fgraph, node):
...
@@ -1407,7 +1407,7 @@ def local_useless_elemwise_comparison(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
Sum
,
Prod
])
@
node_rewrit
er
([
Sum
,
Prod
])
def
local_sum_prod_div_dimshuffle
(
fgraph
,
node
):
def
local_sum_prod_div_dimshuffle
(
fgraph
,
node
):
"""
"""
sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b,
sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b,
...
@@ -1499,7 +1499,7 @@ def local_sum_prod_div_dimshuffle(fgraph, node):
...
@@ -1499,7 +1499,7 @@ def local_sum_prod_div_dimshuffle(fgraph, node):
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Sum
,
Prod
])
@
node_rewrit
er
([
Sum
,
Prod
])
def
local_sum_prod_all_to_none
(
fgraph
,
node
):
def
local_sum_prod_all_to_none
(
fgraph
,
node
):
"""
"""
Sum{0,1,...N} -> Sum{} or
Sum{0,1,...N} -> Sum{} or
...
@@ -1517,7 +1517,7 @@ def local_sum_prod_all_to_none(fgraph, node):
...
@@ -1517,7 +1517,7 @@ def local_sum_prod_all_to_none(fgraph, node):
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Sum
,
Prod
])
@
node_rewrit
er
([
Sum
,
Prod
])
def
local_op_of_op
(
fgraph
,
node
):
def
local_op_of_op
(
fgraph
,
node
):
"""
"""
Prod(Prod()) -> single Prod()
Prod(Prod()) -> single Prod()
...
@@ -1573,7 +1573,7 @@ ALL_REDUCE = (
...
@@ -1573,7 +1573,7 @@ ALL_REDUCE = (
@register_canonicalize
@register_canonicalize
@register_uncanonicalize
# Needed for MaxAndArgmax -> CAReduce
@register_uncanonicalize
# Needed for MaxAndArgmax -> CAReduce
@
local_optimiz
er
(
ALL_REDUCE
)
@
node_rewrit
er
(
ALL_REDUCE
)
def
local_reduce_join
(
fgraph
,
node
):
def
local_reduce_join
(
fgraph
,
node
):
"""
"""
CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
...
@@ -1645,7 +1645,7 @@ def local_reduce_join(fgraph, node):
...
@@ -1645,7 +1645,7 @@ def local_reduce_join(fgraph, node):
@register_canonicalize
(
"fast_compile"
,
"local_cut_useless_reduce"
)
@register_canonicalize
(
"fast_compile"
,
"local_cut_useless_reduce"
)
@register_useless
(
"local_cut_useless_reduce"
)
@register_useless
(
"local_cut_useless_reduce"
)
@
local_optimiz
er
(
ALL_REDUCE
)
@
node_rewrit
er
(
ALL_REDUCE
)
def
local_useless_reduce
(
fgraph
,
node
):
def
local_useless_reduce
(
fgraph
,
node
):
"""Sum(a, axis=[]) -> a"""
"""Sum(a, axis=[]) -> a"""
if
isinstance
(
node
.
op
,
CAReduce
):
if
isinstance
(
node
.
op
,
CAReduce
):
...
@@ -1658,7 +1658,7 @@ def local_useless_reduce(fgraph, node):
...
@@ -1658,7 +1658,7 @@ def local_useless_reduce(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_uncanonicalize
@register_uncanonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
(
ALL_REDUCE
)
@
node_rewrit
er
(
ALL_REDUCE
)
def
local_reduce_broadcastable
(
fgraph
,
node
):
def
local_reduce_broadcastable
(
fgraph
,
node
):
"""Remove reduction over broadcastable dimensions."""
"""Remove reduction over broadcastable dimensions."""
if
isinstance
(
node
.
op
,
CAReduce
):
if
isinstance
(
node
.
op
,
CAReduce
):
...
@@ -1700,7 +1700,7 @@ def local_reduce_broadcastable(fgraph, node):
...
@@ -1700,7 +1700,7 @@ def local_reduce_broadcastable(fgraph, node):
@register_specialize
@register_specialize
@
local_optimiz
er
([
Sum
,
Prod
])
@
node_rewrit
er
([
Sum
,
Prod
])
def
local_opt_alloc
(
fgraph
,
node
):
def
local_opt_alloc
(
fgraph
,
node
):
"""
"""
sum(alloc(constant,shapes...)) => constant*prod(shapes)
sum(alloc(constant,shapes...)) => constant*prod(shapes)
...
@@ -1764,7 +1764,7 @@ def local_opt_alloc(fgraph, node):
...
@@ -1764,7 +1764,7 @@ def local_opt_alloc(fgraph, node):
@register_specialize
@register_specialize
@
local_optimiz
er
([
neg
])
@
node_rewrit
er
([
neg
])
def
local_neg_div_neg
(
fgraph
,
node
):
def
local_neg_div_neg
(
fgraph
,
node
):
"""
"""
- (-a / b) -> a / b
- (-a / b) -> a / b
...
@@ -1788,7 +1788,7 @@ def local_neg_div_neg(fgraph, node):
...
@@ -1788,7 +1788,7 @@ def local_neg_div_neg(fgraph, node):
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
mul
])
@
node_rewrit
er
([
mul
])
def
local_mul_zero
(
fgraph
,
node
):
def
local_mul_zero
(
fgraph
,
node
):
"""
"""
As part of canonicalization, we replace multiplication by zero
As part of canonicalization, we replace multiplication by zero
...
@@ -1811,7 +1811,7 @@ def local_mul_zero(fgraph, node):
...
@@ -1811,7 +1811,7 @@ def local_mul_zero(fgraph, node):
# TODO: Add this to the canonicalization to reduce redundancy.
# TODO: Add this to the canonicalization to reduce redundancy.
@register_specialize
@register_specialize
@
local_optimiz
er
([
true_div
])
@
node_rewrit
er
([
true_div
])
def
local_div_to_reciprocal
(
fgraph
,
node
):
def
local_div_to_reciprocal
(
fgraph
,
node
):
if
node
.
op
==
true_div
and
np
.
all
(
get_constant
(
node
.
inputs
[
0
])
==
1.0
):
if
node
.
op
==
true_div
and
np
.
all
(
get_constant
(
node
.
inputs
[
0
])
==
1.0
):
out
=
node
.
outputs
[
0
]
out
=
node
.
outputs
[
0
]
...
@@ -1828,7 +1828,7 @@ def local_div_to_reciprocal(fgraph, node):
...
@@ -1828,7 +1828,7 @@ def local_div_to_reciprocal(fgraph, node):
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
reciprocal
])
@
node_rewrit
er
([
reciprocal
])
def
local_reciprocal_canon
(
fgraph
,
node
):
def
local_reciprocal_canon
(
fgraph
,
node
):
if
node
.
op
==
reciprocal
:
if
node
.
op
==
reciprocal
:
return
[
at_pow
(
node
.
inputs
[
0
],
-
1.0
)]
return
[
at_pow
(
node
.
inputs
[
0
],
-
1.0
)]
...
@@ -1837,7 +1837,7 @@ def local_reciprocal_canon(fgraph, node):
...
@@ -1837,7 +1837,7 @@ def local_reciprocal_canon(fgraph, node):
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
at_pow
])
@
node_rewrit
er
([
at_pow
])
def
local_pow_canonicalize
(
fgraph
,
node
):
def
local_pow_canonicalize
(
fgraph
,
node
):
if
node
.
op
==
at_pow
:
if
node
.
op
==
at_pow
:
cst
=
get_constant
(
node
.
inputs
[
1
])
cst
=
get_constant
(
node
.
inputs
[
1
])
...
@@ -1850,7 +1850,7 @@ def local_pow_canonicalize(fgraph, node):
...
@@ -1850,7 +1850,7 @@ def local_pow_canonicalize(fgraph, node):
@register_specialize
@register_specialize
@
local_optimiz
er
([
mul
])
@
node_rewrit
er
([
mul
])
def
local_mul_to_sqr
(
fgraph
,
node
):
def
local_mul_to_sqr
(
fgraph
,
node
):
"""
"""
x*x -> sqr(x)
x*x -> sqr(x)
...
@@ -1862,7 +1862,7 @@ def local_mul_to_sqr(fgraph, node):
...
@@ -1862,7 +1862,7 @@ def local_mul_to_sqr(fgraph, node):
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
int_div
])
@
node_rewrit
er
([
int_div
])
def
local_intdiv_by_one
(
fgraph
,
node
):
def
local_intdiv_by_one
(
fgraph
,
node
):
"""x // 1 -> x"""
"""x // 1 -> x"""
if
node
.
op
in
[
int_div
]:
if
node
.
op
in
[
int_div
]:
...
@@ -1874,7 +1874,7 @@ def local_intdiv_by_one(fgraph, node):
...
@@ -1874,7 +1874,7 @@ def local_intdiv_by_one(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
int_div
,
true_div
])
@
node_rewrit
er
([
int_div
,
true_div
])
def
local_zero_div
(
fgraph
,
node
):
def
local_zero_div
(
fgraph
,
node
):
"""0 / x -> 0"""
"""0 / x -> 0"""
if
isinstance
(
node
.
op
,
Elemwise
)
and
isinstance
(
if
isinstance
(
node
.
op
,
Elemwise
)
and
isinstance
(
...
@@ -1887,7 +1887,7 @@ def local_zero_div(fgraph, node):
...
@@ -1887,7 +1887,7 @@ def local_zero_div(fgraph, node):
@register_specialize
@register_specialize
@
local_optimiz
er
([
at_pow
])
@
node_rewrit
er
([
at_pow
])
def
local_pow_specialize
(
fgraph
,
node
):
def
local_pow_specialize
(
fgraph
,
node
):
# here, we are past the point of canonicalization, so we don't want
# here, we are past the point of canonicalization, so we don't want
# to put in un-necessary fills.
# to put in un-necessary fills.
...
@@ -1925,7 +1925,7 @@ def local_pow_specialize(fgraph, node):
...
@@ -1925,7 +1925,7 @@ def local_pow_specialize(fgraph, node):
@register_specialize_device
@register_specialize_device
@
local_optimiz
er
([
at_pow
])
@
node_rewrit
er
([
at_pow
])
def
local_pow_specialize_device
(
fgraph
,
node
):
def
local_pow_specialize_device
(
fgraph
,
node
):
"""
"""
This optimization is not the same on all device. We do it only on cpu here.
This optimization is not the same on all device. We do it only on cpu here.
...
@@ -1992,7 +1992,7 @@ def local_pow_specialize_device(fgraph, node):
...
@@ -1992,7 +1992,7 @@ def local_pow_specialize_device(fgraph, node):
@register_specialize
@register_specialize
@
local_optimiz
er
([
mul
])
@
node_rewrit
er
([
mul
])
def
local_mul_specialize
(
fgraph
,
node
):
def
local_mul_specialize
(
fgraph
,
node
):
"""
"""
Remove special-case constants from mul arguments and useless neg in inputs.
Remove special-case constants from mul arguments and useless neg in inputs.
...
@@ -2068,7 +2068,7 @@ def local_mul_specialize(fgraph, node):
...
@@ -2068,7 +2068,7 @@ def local_mul_specialize(fgraph, node):
@register_specialize
@register_specialize
@
local_optimiz
er
([
add
])
@
node_rewrit
er
([
add
])
def
local_add_specialize
(
fgraph
,
node
):
def
local_add_specialize
(
fgraph
,
node
):
"""Remove zeros from ``add``s.
"""Remove zeros from ``add``s.
...
@@ -2147,7 +2147,7 @@ local_mul_canonizer.add_simplifier(check_for_x_over_absX, "X_over_absX")
...
@@ -2147,7 +2147,7 @@ local_mul_canonizer.add_simplifier(check_for_x_over_absX, "X_over_absX")
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
at_abs
])
@
node_rewrit
er
([
at_abs
])
def
local_abs_lift
(
fgraph
,
node
):
def
local_abs_lift
(
fgraph
,
node
):
"""
"""
Move the abs toward the input.
Move the abs toward the input.
...
@@ -2165,7 +2165,7 @@ def local_abs_lift(fgraph, node):
...
@@ -2165,7 +2165,7 @@ def local_abs_lift(fgraph, node):
@register_specialize
@register_specialize
@
local_optimiz
er
([
mul
,
true_div
])
@
node_rewrit
er
([
mul
,
true_div
])
def
local_abs_merge
(
fgraph
,
node
):
def
local_abs_merge
(
fgraph
,
node
):
"""
"""
Merge abs generated by local_abs_lift when the canonizer don't
Merge abs generated by local_abs_lift when the canonizer don't
...
@@ -2201,7 +2201,7 @@ def local_abs_merge(fgraph, node):
...
@@ -2201,7 +2201,7 @@ def local_abs_merge(fgraph, node):
@register_stabilize
@register_stabilize
@register_specialize
@register_specialize
@
local_optimiz
er
([
log
])
@
node_rewrit
er
([
log
])
def
local_log1p
(
fgraph
,
node
):
def
local_log1p
(
fgraph
,
node
):
# log(1+x) -> log1p(x)
# log(1+x) -> log1p(x)
# log(1-x) -> log1p(-x)
# log(1-x) -> log1p(-x)
...
@@ -2234,7 +2234,7 @@ def local_log1p(fgraph, node):
...
@@ -2234,7 +2234,7 @@ def local_log1p(fgraph, node):
@register_stabilize
@register_stabilize
@register_specialize
@register_specialize
@
local_optimiz
er
([
log
])
@
node_rewrit
er
([
log
])
def
local_log_add_exp
(
fgraph
,
node
):
def
local_log_add_exp
(
fgraph
,
node
):
"""
"""
``log(exp(x)+exp(y)+exp(z)) = max + log(x-max, y-max, z-max)``
``log(exp(x)+exp(y)+exp(z)) = max + log(x-max, y-max, z-max)``
...
@@ -2266,7 +2266,7 @@ def local_log_add_exp(fgraph, node):
...
@@ -2266,7 +2266,7 @@ def local_log_add_exp(fgraph, node):
@register_stabilize
@register_stabilize
@register_specialize
@register_specialize
@
local_optimiz
er
([
log
])
@
node_rewrit
er
([
log
])
def
local_log_sum_exp
(
fgraph
,
node
):
def
local_log_sum_exp
(
fgraph
,
node
):
# log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max)))
# log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max)))
...
@@ -2435,7 +2435,7 @@ def attempt_distribution(factor, num, denum, out_type):
...
@@ -2435,7 +2435,7 @@ def attempt_distribution(factor, num, denum, out_type):
@register_canonicalize
@register_canonicalize
@register_stabilize
@register_stabilize
@
local_optimiz
er
([
mul
,
true_div
,
reciprocal
])
@
node_rewrit
er
([
mul
,
true_div
,
reciprocal
])
def
local_greedy_distributor
(
fgraph
,
node
):
def
local_greedy_distributor
(
fgraph
,
node
):
"""
"""
Optimize by reducing the number of multiplications and/or divisions.
Optimize by reducing the number of multiplications and/or divisions.
...
@@ -2609,7 +2609,7 @@ register_specialize(local_erf_neg_minus_one)
...
@@ -2609,7 +2609,7 @@ register_specialize(local_erf_neg_minus_one)
@register_stabilize
@register_stabilize
@register_specialize
@register_specialize
@
local_optimiz
er
([
log
])
@
node_rewrit
er
([
log
])
def
local_log_erfc
(
fgraph
,
node
):
def
local_log_erfc
(
fgraph
,
node
):
"""Stability optimization for `log(erfc(x))`.
"""Stability optimization for `log(erfc(x))`.
...
@@ -2652,7 +2652,7 @@ def local_log_erfc(fgraph, node):
...
@@ -2652,7 +2652,7 @@ def local_log_erfc(fgraph, node):
@register_stabilize
@register_stabilize
@register_specialize
@register_specialize
@
local_optimiz
er
([
true_div
])
@
node_rewrit
er
([
true_div
])
def
local_grad_log_erfc_neg
(
fgraph
,
node
):
def
local_grad_log_erfc_neg
(
fgraph
,
node
):
"""Stability optimization for the grad of `log(erfc(x))`.
"""Stability optimization for the grad of `log(erfc(x))`.
...
@@ -3093,7 +3093,7 @@ def is_neg(var):
...
@@ -3093,7 +3093,7 @@ def is_neg(var):
@register_stabilize
@register_stabilize
@
local_optimiz
er
([
true_div
])
@
node_rewrit
er
([
true_div
])
def
local_exp_over_1_plus_exp
(
fgraph
,
node
):
def
local_exp_over_1_plus_exp
(
fgraph
,
node
):
"""
"""
exp(x)/(1+exp(x)) -> sigm(x)
exp(x)/(1+exp(x)) -> sigm(x)
...
@@ -3447,7 +3447,7 @@ def perform_sigm_times_exp(
...
@@ -3447,7 +3447,7 @@ def perform_sigm_times_exp(
@register_stabilize
@register_stabilize
@
local_optimiz
er
([
mul
])
@
node_rewrit
er
([
mul
])
def
local_sigm_times_exp
(
fgraph
,
node
):
def
local_sigm_times_exp
(
fgraph
,
node
):
"""
"""
exp(x) * sigm(-x) -> sigm(x)
exp(x) * sigm(-x) -> sigm(x)
...
@@ -3476,7 +3476,7 @@ def local_sigm_times_exp(fgraph, node):
...
@@ -3476,7 +3476,7 @@ def local_sigm_times_exp(fgraph, node):
@register_stabilize
@register_stabilize
@
local_optimiz
er
([
reciprocal
])
@
node_rewrit
er
([
reciprocal
])
def
local_reciprocal_1_plus_exp
(
fgraph
,
node
):
def
local_reciprocal_1_plus_exp
(
fgraph
,
node
):
"""``reciprocal(1+exp(x)) -> sigm(-x)``
"""``reciprocal(1+exp(x)) -> sigm(-x)``
...
@@ -3558,7 +3558,7 @@ register_specialize(local_sigmoid_logit)
...
@@ -3558,7 +3558,7 @@ register_specialize(local_sigmoid_logit)
@register_canonicalize
@register_canonicalize
@register_useless
@register_useless
@
local_optimiz
er
([
_conj
])
@
node_rewrit
er
([
_conj
])
def
local_useless_conj
(
fgraph
,
node
):
def
local_useless_conj
(
fgraph
,
node
):
r"""Remove `conj` `Op`\s applied to non-imaginary variable types."""
r"""Remove `conj` `Op`\s applied to non-imaginary variable types."""
x
=
node
.
inputs
[
0
]
x
=
node
.
inputs
[
0
]
...
...
aesara/tensor/nnet/basic.py
浏览文件 @
550a6e98
...
@@ -18,7 +18,7 @@ from aesara.compile import optdb
...
@@ -18,7 +18,7 @@ from aesara.compile import optdb
from
aesara.gradient
import
DisconnectedType
,
grad_not_implemented
from
aesara.gradient
import
DisconnectedType
,
grad_not_implemented
from
aesara.graph.basic
import
Apply
from
aesara.graph.basic
import
Apply
from
aesara.graph.op
import
Op
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
copy_stack_trace
,
local_optimiz
er
,
optimizer
from
aesara.graph.opt
import
copy_stack_trace
,
node_rewrit
er
,
optimizer
from
aesara.link.c.op
import
COp
from
aesara.link.c.op
import
COp
from
aesara.raise_op
import
Assert
from
aesara.raise_op
import
Assert
from
aesara.scalar
import
UnaryScalarOp
from
aesara.scalar
import
UnaryScalarOp
...
@@ -1046,7 +1046,7 @@ class LogSoftmax(COp):
...
@@ -1046,7 +1046,7 @@ class LogSoftmax(COp):
# This is not registered in stabilize, as it cause some crossentropy
# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
# optimization to not be inserted.
@register_specialize
(
"stabilize"
,
"fast_compile"
)
@register_specialize
(
"stabilize"
,
"fast_compile"
)
@
local_optimiz
er
([
Elemwise
])
@
node_rewrit
er
([
Elemwise
])
def
local_logsoftmax
(
fgraph
,
node
):
def
local_logsoftmax
(
fgraph
,
node
):
"""
"""
Detect Log(Softmax(x)) and replace it with LogSoftmax(x)
Detect Log(Softmax(x)) and replace it with LogSoftmax(x)
...
@@ -1071,7 +1071,7 @@ def local_logsoftmax(fgraph, node):
...
@@ -1071,7 +1071,7 @@ def local_logsoftmax(fgraph, node):
# This is not registered in stabilize, as it cause some crossentropy
# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
# optimization to not be inserted.
@register_specialize
(
"stabilize"
,
"fast_compile"
)
@register_specialize
(
"stabilize"
,
"fast_compile"
)
@
local_optimiz
er
([
SoftmaxGrad
])
@
node_rewrit
er
([
SoftmaxGrad
])
def
local_logsoftmax_grad
(
fgraph
,
node
):
def
local_logsoftmax_grad
(
fgraph
,
node
):
"""
"""
Detect Log(Softmax(x))'s grad and replace it with LogSoftmax(x)'s grad
Detect Log(Softmax(x))'s grad and replace it with LogSoftmax(x)'s grad
...
@@ -1150,7 +1150,7 @@ def logsoftmax(c, axis=UNSET_AXIS):
...
@@ -1150,7 +1150,7 @@ def logsoftmax(c, axis=UNSET_AXIS):
@register_specialize
(
"fast_compile"
)
@register_specialize
(
"fast_compile"
)
@
local_optimiz
er
([
softmax_legacy
])
@
node_rewrit
er
([
softmax_legacy
])
def
local_softmax_with_bias
(
fgraph
,
node
):
def
local_softmax_with_bias
(
fgraph
,
node
):
"""
"""
Try to turn softmax(sum_of_stuff) -> softmax_w_bias(matrix, bias).
Try to turn softmax(sum_of_stuff) -> softmax_w_bias(matrix, bias).
...
@@ -1954,7 +1954,7 @@ optdb.register(
...
@@ -1954,7 +1954,7 @@ optdb.register(
@register_specialize
(
@register_specialize
(
"fast_compile"
,
"local_crossentropy_to_crossentropy_with_softmax_grad"
"fast_compile"
,
"local_crossentropy_to_crossentropy_with_softmax_grad"
)
# old name
)
# old name
@
local_optimiz
er
([
softmax_grad_legacy
])
@
node_rewrit
er
([
softmax_grad_legacy
])
def
local_softmax_grad_to_crossentropy_with_softmax_grad
(
fgraph
,
node
):
def
local_softmax_grad_to_crossentropy_with_softmax_grad
(
fgraph
,
node
):
if
node
.
op
==
softmax_grad_legacy
and
node
.
inputs
[
1
]
.
ndim
==
2
:
if
node
.
op
==
softmax_grad_legacy
and
node
.
inputs
[
1
]
.
ndim
==
2
:
g_coding_dist
,
coding_dist
=
node
.
inputs
g_coding_dist
,
coding_dist
=
node
.
inputs
...
@@ -1971,7 +1971,7 @@ def local_softmax_grad_to_crossentropy_with_softmax_grad(fgraph, node):
...
@@ -1971,7 +1971,7 @@ def local_softmax_grad_to_crossentropy_with_softmax_grad(fgraph, node):
@register_specialize
(
"fast_compile"
)
@register_specialize
(
"fast_compile"
)
@
local_optimiz
er
([
MaxAndArgmax
])
@
node_rewrit
er
([
MaxAndArgmax
])
def
local_argmax_pushdown
(
fgraph
,
node
):
def
local_argmax_pushdown
(
fgraph
,
node
):
if
(
if
(
isinstance
(
node
.
op
,
MaxAndArgmax
)
isinstance
(
node
.
op
,
MaxAndArgmax
)
...
@@ -2060,7 +2060,7 @@ def _is_const(z, val, approx=False):
...
@@ -2060,7 +2060,7 @@ def _is_const(z, val, approx=False):
@register_specialize
(
"fast_compile"
)
@register_specialize
(
"fast_compile"
)
@
local_optimiz
er
([
AdvancedSubtensor
,
log
])
@
node_rewrit
er
([
AdvancedSubtensor
,
log
])
def
local_advanced_indexing_crossentropy_onehot
(
fgraph
,
node
):
def
local_advanced_indexing_crossentropy_onehot
(
fgraph
,
node
):
log_op
=
None
log_op
=
None
sm
=
None
sm
=
None
...
@@ -2108,7 +2108,7 @@ def local_advanced_indexing_crossentropy_onehot(fgraph, node):
...
@@ -2108,7 +2108,7 @@ def local_advanced_indexing_crossentropy_onehot(fgraph, node):
@register_specialize
(
"fast_compile"
)
@register_specialize
(
"fast_compile"
)
@
local_optimiz
er
([
softmax_grad_legacy
])
@
node_rewrit
er
([
softmax_grad_legacy
])
def
local_advanced_indexing_crossentropy_onehot_grad
(
fgraph
,
node
):
def
local_advanced_indexing_crossentropy_onehot_grad
(
fgraph
,
node
):
if
not
(
node
.
op
==
softmax_grad_legacy
and
node
.
inputs
[
1
]
.
ndim
==
2
):
if
not
(
node
.
op
==
softmax_grad_legacy
and
node
.
inputs
[
1
]
.
ndim
==
2
):
return
return
...
@@ -2323,7 +2323,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(fgraph, node):
...
@@ -2323,7 +2323,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(fgraph, node):
@register_specialize
(
"fast_compile"
)
@register_specialize
(
"fast_compile"
)
@
local_optimiz
er
([
softmax_with_bias
])
@
node_rewrit
er
([
softmax_with_bias
])
def
graph_merge_softmax_with_crossentropy_softmax
(
fgraph
,
node
):
def
graph_merge_softmax_with_crossentropy_softmax
(
fgraph
,
node
):
if
node
.
op
==
softmax_with_bias
:
if
node
.
op
==
softmax_with_bias
:
x
,
b
=
node
.
inputs
x
,
b
=
node
.
inputs
...
@@ -2340,7 +2340,7 @@ def graph_merge_softmax_with_crossentropy_softmax(fgraph, node):
...
@@ -2340,7 +2340,7 @@ def graph_merge_softmax_with_crossentropy_softmax(fgraph, node):
@register_specialize
@register_specialize
@register_stabilize
@register_stabilize
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
CrossentropySoftmax1HotWithBiasDx
])
@
node_rewrit
er
([
CrossentropySoftmax1HotWithBiasDx
])
def
local_useless_crossentropy_softmax_1hot_with_bias_dx_alloc
(
fgraph
,
node
):
def
local_useless_crossentropy_softmax_1hot_with_bias_dx_alloc
(
fgraph
,
node
):
"""
"""
Replace a CrossentropySoftmax1HotWithBiasDx op, whose incoming gradient is
Replace a CrossentropySoftmax1HotWithBiasDx op, whose incoming gradient is
...
...
aesara/tensor/nnet/batchnorm.py
浏览文件 @
550a6e98
...
@@ -4,7 +4,7 @@ import aesara
...
@@ -4,7 +4,7 @@ import aesara
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Apply
from
aesara.graph.basic
import
Apply
from
aesara.graph.op
import
Op
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
copy_stack_trace
,
local_optimiz
er
from
aesara.graph.opt
import
copy_stack_trace
,
node_rewrit
er
from
aesara.scalar
import
Composite
,
add
,
as_common_dtype
,
mul
,
sub
,
true_div
from
aesara.scalar
import
Composite
,
add
,
as_common_dtype
,
mul
,
sub
,
true_div
from
aesara.tensor
import
basic
as
at
from
aesara.tensor
import
basic
as
at
from
aesara.tensor.basic
import
as_tensor_variable
from
aesara.tensor.basic
import
as_tensor_variable
...
@@ -778,7 +778,7 @@ class AbstractBatchNormTrainGrad(Op):
...
@@ -778,7 +778,7 @@ class AbstractBatchNormTrainGrad(Op):
output_storage
[
2
][
0
]
=
g_wrt_bias
output_storage
[
2
][
0
]
=
g_wrt_bias
@
local_optimiz
er
([
AbstractBatchNormTrain
])
@
node_rewrit
er
([
AbstractBatchNormTrain
])
def
local_abstract_batch_norm_train
(
fgraph
,
node
):
def
local_abstract_batch_norm_train
(
fgraph
,
node
):
if
not
isinstance
(
node
.
op
,
AbstractBatchNormTrain
):
if
not
isinstance
(
node
.
op
,
AbstractBatchNormTrain
):
return
None
return
None
...
@@ -832,7 +832,7 @@ def local_abstract_batch_norm_train(fgraph, node):
...
@@ -832,7 +832,7 @@ def local_abstract_batch_norm_train(fgraph, node):
return
results
return
results
@
local_optimiz
er
([
AbstractBatchNormTrainGrad
])
@
node_rewrit
er
([
AbstractBatchNormTrainGrad
])
def
local_abstract_batch_norm_train_grad
(
fgraph
,
node
):
def
local_abstract_batch_norm_train_grad
(
fgraph
,
node
):
if
not
isinstance
(
node
.
op
,
AbstractBatchNormTrainGrad
):
if
not
isinstance
(
node
.
op
,
AbstractBatchNormTrainGrad
):
return
None
return
None
...
@@ -866,7 +866,7 @@ def local_abstract_batch_norm_train_grad(fgraph, node):
...
@@ -866,7 +866,7 @@ def local_abstract_batch_norm_train_grad(fgraph, node):
return
results
return
results
@
local_optimiz
er
([
AbstractBatchNormInference
])
@
node_rewrit
er
([
AbstractBatchNormInference
])
def
local_abstract_batch_norm_inference
(
fgraph
,
node
):
def
local_abstract_batch_norm_inference
(
fgraph
,
node
):
if
not
isinstance
(
node
.
op
,
AbstractBatchNormInference
):
if
not
isinstance
(
node
.
op
,
AbstractBatchNormInference
):
return
None
return
None
...
...
aesara/tensor/nnet/conv3d2d.py
浏览文件 @
550a6e98
...
@@ -3,7 +3,7 @@ from aesara import tensor as at
...
@@ -3,7 +3,7 @@ from aesara import tensor as at
from
aesara.gradient
import
DisconnectedType
from
aesara.gradient
import
DisconnectedType
from
aesara.graph.basic
import
Apply
from
aesara.graph.basic
import
Apply
from
aesara.graph.op
import
Op
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
TopoOptimizer
,
copy_stack_trace
,
local_optimiz
er
from
aesara.graph.opt
import
TopoOptimizer
,
copy_stack_trace
,
node_rewrit
er
def
get_diagonal_subtensor_view
(
x
,
i0
,
i1
):
def
get_diagonal_subtensor_view
(
x
,
i0
,
i1
):
...
@@ -296,7 +296,7 @@ def conv3d(
...
@@ -296,7 +296,7 @@ def conv3d(
return
out_5d
return
out_5d
@
local_optimiz
er
([
DiagonalSubtensor
,
IncDiagonalSubtensor
])
@
node_rewrit
er
([
DiagonalSubtensor
,
IncDiagonalSubtensor
])
def
local_inplace_DiagonalSubtensor
(
fgraph
,
node
):
def
local_inplace_DiagonalSubtensor
(
fgraph
,
node
):
"""Also work for IncDiagonalSubtensor."""
"""Also work for IncDiagonalSubtensor."""
if
(
if
(
...
...
aesara/tensor/nnet/ctc.py
浏览文件 @
550a6e98
...
@@ -5,7 +5,7 @@ import aesara.tensor as at
...
@@ -5,7 +5,7 @@ import aesara.tensor as at
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.gradient
import
grad_undefined
from
aesara.gradient
import
grad_undefined
from
aesara.graph.basic
import
Apply
from
aesara.graph.basic
import
Apply
from
aesara.graph.opt
import
local_optimiz
er
from
aesara.graph.opt
import
node_rewrit
er
from
aesara.link.c.cmodule
import
GCC_compiler
from
aesara.link.c.cmodule
import
GCC_compiler
from
aesara.link.c.op
import
ExternalCOp
,
OpenMPOp
from
aesara.link.c.op
import
ExternalCOp
,
OpenMPOp
from
aesara.tensor.basic_opt
import
register_canonicalize
from
aesara.tensor.basic_opt
import
register_canonicalize
...
@@ -249,7 +249,7 @@ def ctc(activations, labels, input_lengths):
...
@@ -249,7 +249,7 @@ def ctc(activations, labels, input_lengths):
# Disable gradient computation if not needed
# Disable gradient computation if not needed
@register_canonicalize
(
"fast_compile"
)
@register_canonicalize
(
"fast_compile"
)
@
local_optimiz
er
([
ConnectionistTemporalClassification
])
@
node_rewrit
er
([
ConnectionistTemporalClassification
])
def
local_ctc_no_grad
(
fgraph
,
node
):
def
local_ctc_no_grad
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
ConnectionistTemporalClassification
):
if
isinstance
(
node
.
op
,
ConnectionistTemporalClassification
):
if
len
(
node
.
outputs
)
>
1
:
if
len
(
node
.
outputs
)
>
1
:
...
...
aesara/tensor/nnet/opt.py
浏览文件 @
550a6e98
...
@@ -11,7 +11,7 @@ from aesara.graph.opt import (
...
@@ -11,7 +11,7 @@ from aesara.graph.opt import (
TopoOptimizer
,
TopoOptimizer
,
copy_stack_trace
,
copy_stack_trace
,
in2out
,
in2out
,
local_optimiz
er
,
node_rewrit
er
,
)
)
from
aesara.tensor.basic_opt
import
register_specialize_device
from
aesara.tensor.basic_opt
import
register_specialize_device
from
aesara.tensor.nnet.abstract_conv
import
(
from
aesara.tensor.nnet.abstract_conv
import
(
...
@@ -37,7 +37,7 @@ from aesara.tensor.nnet.corr3d import Corr3dMM, Corr3dMMGradInputs, Corr3dMMGrad
...
@@ -37,7 +37,7 @@ from aesara.tensor.nnet.corr3d import Corr3dMM, Corr3dMMGradInputs, Corr3dMMGrad
from
aesara.tensor.type
import
TensorType
from
aesara.tensor.type
import
TensorType
@
local_optimiz
er
([
SparseBlockGemv
],
inplace
=
True
)
@
node_rewrit
er
([
SparseBlockGemv
],
inplace
=
True
)
def
local_inplace_sparse_block_gemv
(
fgraph
,
node
):
def
local_inplace_sparse_block_gemv
(
fgraph
,
node
):
"""
"""
SparseBlockGemv(inplace=False) -> SparseBlockGemv(inplace=True)
SparseBlockGemv(inplace=False) -> SparseBlockGemv(inplace=True)
...
@@ -60,7 +60,7 @@ compile.optdb.register(
...
@@ -60,7 +60,7 @@ compile.optdb.register(
)
# DEBUG
)
# DEBUG
@
local_optimiz
er
([
SparseBlockOuter
],
inplace
=
True
)
@
node_rewrit
er
([
SparseBlockOuter
],
inplace
=
True
)
def
local_inplace_sparse_block_outer
(
fgraph
,
node
):
def
local_inplace_sparse_block_outer
(
fgraph
,
node
):
"""
"""
SparseBlockOuter(inplace=False) -> SparseBlockOuter(inplace=True)
SparseBlockOuter(inplace=False) -> SparseBlockOuter(inplace=True)
...
@@ -85,7 +85,7 @@ compile.optdb.register(
...
@@ -85,7 +85,7 @@ compile.optdb.register(
# Conv opts
# Conv opts
@
local_optimiz
er
([
AbstractConv2d
])
@
node_rewrit
er
([
AbstractConv2d
])
def
local_abstractconv_gemm
(
fgraph
,
node
):
def
local_abstractconv_gemm
(
fgraph
,
node
):
# If config.blas__ldflags is empty, Aesara will use
# If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_.
# a NumPy C implementation of [sd]gemm_.
...
@@ -113,7 +113,7 @@ def local_abstractconv_gemm(fgraph, node):
...
@@ -113,7 +113,7 @@ def local_abstractconv_gemm(fgraph, node):
return
[
rval
]
return
[
rval
]
@
local_optimiz
er
([
AbstractConv3d
])
@
node_rewrit
er
([
AbstractConv3d
])
def
local_abstractconv3d_gemm
(
fgraph
,
node
):
def
local_abstractconv3d_gemm
(
fgraph
,
node
):
# If config.blas__ldflags is empty, Aesara will use
# If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_.
# a NumPy C implementation of [sd]gemm_.
...
@@ -139,7 +139,7 @@ def local_abstractconv3d_gemm(fgraph, node):
...
@@ -139,7 +139,7 @@ def local_abstractconv3d_gemm(fgraph, node):
return
[
rval
]
return
[
rval
]
@
local_optimiz
er
([
AbstractConv2d_gradWeights
])
@
node_rewrit
er
([
AbstractConv2d_gradWeights
])
def
local_abstractconv_gradweight_gemm
(
fgraph
,
node
):
def
local_abstractconv_gradweight_gemm
(
fgraph
,
node
):
# If config.blas__ldflags is empty, Aesara will use
# If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_.
# a NumPy C implementation of [sd]gemm_.
...
@@ -169,7 +169,7 @@ def local_abstractconv_gradweight_gemm(fgraph, node):
...
@@ -169,7 +169,7 @@ def local_abstractconv_gradweight_gemm(fgraph, node):
return
[
rval
]
return
[
rval
]
@
local_optimiz
er
([
AbstractConv3d_gradWeights
])
@
node_rewrit
er
([
AbstractConv3d_gradWeights
])
def
local_abstractconv3d_gradweight_gemm
(
fgraph
,
node
):
def
local_abstractconv3d_gradweight_gemm
(
fgraph
,
node
):
# If config.blas__ldflags is empty, Aesara will use
# If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_.
# a NumPy C implementation of [sd]gemm_.
...
@@ -197,7 +197,7 @@ def local_abstractconv3d_gradweight_gemm(fgraph, node):
...
@@ -197,7 +197,7 @@ def local_abstractconv3d_gradweight_gemm(fgraph, node):
return
[
rval
]
return
[
rval
]
@
local_optimiz
er
([
AbstractConv2d_gradInputs
])
@
node_rewrit
er
([
AbstractConv2d_gradInputs
])
def
local_abstractconv_gradinputs_gemm
(
fgraph
,
node
):
def
local_abstractconv_gradinputs_gemm
(
fgraph
,
node
):
# If config.blas__ldflags is empty, Aesara will use
# If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_.
# a NumPy C implementation of [sd]gemm_.
...
@@ -227,7 +227,7 @@ def local_abstractconv_gradinputs_gemm(fgraph, node):
...
@@ -227,7 +227,7 @@ def local_abstractconv_gradinputs_gemm(fgraph, node):
return
[
rval
]
return
[
rval
]
@
local_optimiz
er
([
AbstractConv3d_gradInputs
])
@
node_rewrit
er
([
AbstractConv3d_gradInputs
])
def
local_abstractconv3d_gradinputs_gemm
(
fgraph
,
node
):
def
local_abstractconv3d_gradinputs_gemm
(
fgraph
,
node
):
# If config.blas__ldflags is empty, Aesara will use
# If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_.
# a NumPy C implementation of [sd]gemm_.
...
@@ -255,7 +255,7 @@ def local_abstractconv3d_gradinputs_gemm(fgraph, node):
...
@@ -255,7 +255,7 @@ def local_abstractconv3d_gradinputs_gemm(fgraph, node):
return
[
rval
]
return
[
rval
]
@
local_optimiz
er
([
AbstractConv2d
])
@
node_rewrit
er
([
AbstractConv2d
])
def
local_conv2d_cpu
(
fgraph
,
node
):
def
local_conv2d_cpu
(
fgraph
,
node
):
if
not
isinstance
(
node
.
op
,
AbstractConv2d
)
or
node
.
inputs
[
0
]
.
dtype
==
"float16"
:
if
not
isinstance
(
node
.
op
,
AbstractConv2d
)
or
node
.
inputs
[
0
]
.
dtype
==
"float16"
:
...
@@ -287,7 +287,7 @@ def local_conv2d_cpu(fgraph, node):
...
@@ -287,7 +287,7 @@ def local_conv2d_cpu(fgraph, node):
return
[
rval
]
return
[
rval
]
@
local_optimiz
er
([
AbstractConv2d_gradWeights
])
@
node_rewrit
er
([
AbstractConv2d_gradWeights
])
def
local_conv2d_gradweight_cpu
(
fgraph
,
node
):
def
local_conv2d_gradweight_cpu
(
fgraph
,
node
):
if
(
if
(
not
isinstance
(
node
.
op
,
AbstractConv2d_gradWeights
)
not
isinstance
(
node
.
op
,
AbstractConv2d_gradWeights
)
...
@@ -396,7 +396,7 @@ def local_conv2d_gradweight_cpu(fgraph, node):
...
@@ -396,7 +396,7 @@ def local_conv2d_gradweight_cpu(fgraph, node):
return
[
res
]
return
[
res
]
@
local_optimiz
er
([
AbstractConv2d_gradInputs
])
@
node_rewrit
er
([
AbstractConv2d_gradInputs
])
def
local_conv2d_gradinputs_cpu
(
fgraph
,
node
):
def
local_conv2d_gradinputs_cpu
(
fgraph
,
node
):
if
(
if
(
not
isinstance
(
node
.
op
,
AbstractConv2d_gradInputs
)
not
isinstance
(
node
.
op
,
AbstractConv2d_gradInputs
)
...
@@ -561,7 +561,7 @@ conv_groupopt.register(
...
@@ -561,7 +561,7 @@ conv_groupopt.register(
# Verify that no AbstractConv are present in the graph
# Verify that no AbstractConv are present in the graph
@
local_optimiz
er
(
@
node_rewrit
er
(
[
[
AbstractConv2d
,
AbstractConv2d
,
AbstractConv2d_gradWeights
,
AbstractConv2d_gradWeights
,
...
...
aesara/tensor/nnet/sigm.py
浏览文件 @
550a6e98
...
@@ -9,7 +9,7 @@ stability.
...
@@ -9,7 +9,7 @@ stability.
import
aesara
import
aesara
from
aesara
import
printing
from
aesara
import
printing
from
aesara
import
scalar
as
aes
from
aesara
import
scalar
as
aes
from
aesara.graph.opt
import
copy_stack_trace
,
local_optimiz
er
from
aesara.graph.opt
import
copy_stack_trace
,
node_rewrit
er
from
aesara.printing
import
pprint
from
aesara.printing
import
pprint
from
aesara.scalar
import
sigmoid
as
scalar_sigmoid
from
aesara.scalar
import
sigmoid
as
scalar_sigmoid
from
aesara.scalar.math
import
Sigmoid
from
aesara.scalar.math
import
Sigmoid
...
@@ -99,7 +99,7 @@ pprint.assign(ultra_fast_sigmoid, printing.FunctionPrinter(["ultra_fast_sigmoid"
...
@@ -99,7 +99,7 @@ pprint.assign(ultra_fast_sigmoid, printing.FunctionPrinter(["ultra_fast_sigmoid"
# @opt.register_uncanonicalize
# @opt.register_uncanonicalize
@
local_optimiz
er
(
None
)
@
node_rewrit
er
(
None
)
def
local_ultra_fast_sigmoid
(
fgraph
,
node
):
def
local_ultra_fast_sigmoid
(
fgraph
,
node
):
"""
"""
When enabled, change all sigmoid to ultra_fast_sigmoid.
When enabled, change all sigmoid to ultra_fast_sigmoid.
...
@@ -159,7 +159,7 @@ def hard_sigmoid(x):
...
@@ -159,7 +159,7 @@ def hard_sigmoid(x):
# @opt.register_uncanonicalize
# @opt.register_uncanonicalize
@
local_optimiz
er
([
sigmoid
])
@
node_rewrit
er
([
sigmoid
])
def
local_hard_sigmoid
(
fgraph
,
node
):
def
local_hard_sigmoid
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
Elemwise
)
and
node
.
op
.
scalar_op
==
scalar_sigmoid
:
if
isinstance
(
node
.
op
,
Elemwise
)
and
node
.
op
.
scalar_op
==
scalar_sigmoid
:
out
=
hard_sigmoid
(
node
.
inputs
[
0
])
out
=
hard_sigmoid
(
node
.
inputs
[
0
])
...
...
aesara/tensor/opt_uncanonicalize.py
浏览文件 @
550a6e98
...
@@ -34,7 +34,7 @@ supposed to be canonical.
...
@@ -34,7 +34,7 @@ supposed to be canonical.
import
logging
import
logging
from
aesara
import
scalar
as
aes
from
aesara
import
scalar
as
aes
from
aesara.graph.opt
import
copy_stack_trace
,
local_optimiz
er
from
aesara.graph.opt
import
copy_stack_trace
,
node_rewrit
er
from
aesara.tensor.basic
import
Alloc
,
alloc
,
constant
from
aesara.tensor.basic
import
Alloc
,
alloc
,
constant
from
aesara.tensor.basic_opt
import
register_uncanonicalize
from
aesara.tensor.basic_opt
import
register_uncanonicalize
from
aesara.tensor.elemwise
import
CAReduce
,
DimShuffle
from
aesara.tensor.elemwise
import
CAReduce
,
DimShuffle
...
@@ -47,7 +47,7 @@ _logger = logging.getLogger("aesara.tensor.opt_uncanonicalize")
...
@@ -47,7 +47,7 @@ _logger = logging.getLogger("aesara.tensor.opt_uncanonicalize")
@register_uncanonicalize
@register_uncanonicalize
@
local_optimiz
er
([
MaxAndArgmax
])
@
node_rewrit
er
([
MaxAndArgmax
])
def
local_max_and_argmax
(
fgraph
,
node
):
def
local_max_and_argmax
(
fgraph
,
node
):
"""
"""
If we don't use the argmax, change it to a max only.
If we don't use the argmax, change it to a max only.
...
@@ -66,7 +66,7 @@ def local_max_and_argmax(fgraph, node):
...
@@ -66,7 +66,7 @@ def local_max_and_argmax(fgraph, node):
@register_uncanonicalize
@register_uncanonicalize
@
local_optimiz
er
([
neg
])
@
node_rewrit
er
([
neg
])
def
local_max_to_min
(
fgraph
,
node
):
def
local_max_to_min
(
fgraph
,
node
):
"""
"""
Change -(max(-x)) to min.
Change -(max(-x)) to min.
...
@@ -95,7 +95,7 @@ def local_max_to_min(fgraph, node):
...
@@ -95,7 +95,7 @@ def local_max_to_min(fgraph, node):
@register_uncanonicalize
@register_uncanonicalize
@
local_optimiz
er
([
Alloc
])
@
node_rewrit
er
([
Alloc
])
def
local_alloc_dimshuffle
(
fgraph
,
node
):
def
local_alloc_dimshuffle
(
fgraph
,
node
):
"""
"""
If a dimshuffle is inside an alloc and only adds dimension to the
If a dimshuffle is inside an alloc and only adds dimension to the
...
@@ -118,7 +118,7 @@ def local_alloc_dimshuffle(fgraph, node):
...
@@ -118,7 +118,7 @@ def local_alloc_dimshuffle(fgraph, node):
@register_uncanonicalize
@register_uncanonicalize
@
local_optimiz
er
([
Reshape
])
@
node_rewrit
er
([
Reshape
])
def
local_reshape_dimshuffle
(
fgraph
,
node
):
def
local_reshape_dimshuffle
(
fgraph
,
node
):
"""
"""
If a dimshuffle is inside a reshape and does not change the order
If a dimshuffle is inside a reshape and does not change the order
...
@@ -147,7 +147,7 @@ def local_reshape_dimshuffle(fgraph, node):
...
@@ -147,7 +147,7 @@ def local_reshape_dimshuffle(fgraph, node):
@register_uncanonicalize
@register_uncanonicalize
@
local_optimiz
er
([
DimShuffle
])
@
node_rewrit
er
([
DimShuffle
])
def
local_dimshuffle_alloc
(
fgraph
,
node
):
def
local_dimshuffle_alloc
(
fgraph
,
node
):
"""
"""
If an alloc is inside a dimshuffle which only adds dimension to the left,
If an alloc is inside a dimshuffle which only adds dimension to the left,
...
@@ -175,7 +175,7 @@ def local_dimshuffle_alloc(fgraph, node):
...
@@ -175,7 +175,7 @@ def local_dimshuffle_alloc(fgraph, node):
@register_uncanonicalize
@register_uncanonicalize
@
local_optimiz
er
([
DimShuffle
])
@
node_rewrit
er
([
DimShuffle
])
def
local_dimshuffle_subtensor
(
fgraph
,
node
):
def
local_dimshuffle_subtensor
(
fgraph
,
node
):
"""If a subtensor is inside a dimshuffle which only drop
"""If a subtensor is inside a dimshuffle which only drop
broadcastable dimensions, scrap the dimshuffle and index the
broadcastable dimensions, scrap the dimshuffle and index the
...
...
aesara/tensor/random/opt.py
浏览文件 @
550a6e98
from
aesara.compile
import
optdb
from
aesara.compile
import
optdb
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.graph.op
import
compute_test_value
from
aesara.graph.op
import
compute_test_value
from
aesara.graph.opt
import
in2out
,
local_optimiz
er
from
aesara.graph.opt
import
in2out
,
node_rewrit
er
from
aesara.tensor.basic
import
constant
,
get_vector_length
from
aesara.tensor.basic
import
constant
,
get_vector_length
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.extra_ops
import
broadcast_to
from
aesara.tensor.extra_ops
import
broadcast_to
...
@@ -39,7 +39,7 @@ def is_rv_used_in_graph(base_rv, node, fgraph):
...
@@ -39,7 +39,7 @@ def is_rv_used_in_graph(base_rv, node, fgraph):
return
not
all
(
_node_check
(
n
,
i
)
for
n
,
i
in
fgraph
.
clients
.
get
(
base_rv
,
()))
return
not
all
(
_node_check
(
n
,
i
)
for
n
,
i
in
fgraph
.
clients
.
get
(
base_rv
,
()))
@
local_optimiz
er
([
RandomVariable
],
inplace
=
True
)
@
node_rewrit
er
([
RandomVariable
],
inplace
=
True
)
def
random_make_inplace
(
fgraph
,
node
):
def
random_make_inplace
(
fgraph
,
node
):
op
=
node
.
op
op
=
node
.
op
...
@@ -61,7 +61,7 @@ optdb.register(
...
@@ -61,7 +61,7 @@ optdb.register(
)
)
@
local_optimiz
er
(
tracks
=
None
)
@
node_rewrit
er
(
tracks
=
None
)
def
local_rv_size_lift
(
fgraph
,
node
):
def
local_rv_size_lift
(
fgraph
,
node
):
"""Lift the ``size`` parameter in a ``RandomVariable``.
"""Lift the ``size`` parameter in a ``RandomVariable``.
...
@@ -109,7 +109,7 @@ def local_rv_size_lift(fgraph, node):
...
@@ -109,7 +109,7 @@ def local_rv_size_lift(fgraph, node):
return
new_node
.
outputs
return
new_node
.
outputs
@
local_optimiz
er
([
DimShuffle
])
@
node_rewrit
er
([
DimShuffle
])
def
local_dimshuffle_rv_lift
(
fgraph
,
node
):
def
local_dimshuffle_rv_lift
(
fgraph
,
node
):
"""Lift a ``DimShuffle`` through ``RandomVariable`` inputs.
"""Lift a ``DimShuffle`` through ``RandomVariable`` inputs.
...
@@ -266,7 +266,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
...
@@ -266,7 +266,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
return
False
return
False
@
local_optimiz
er
([
Subtensor
,
AdvancedSubtensor1
,
AdvancedSubtensor
])
@
node_rewrit
er
([
Subtensor
,
AdvancedSubtensor1
,
AdvancedSubtensor
])
def
local_subtensor_rv_lift
(
fgraph
,
node
):
def
local_subtensor_rv_lift
(
fgraph
,
node
):
"""Lift a ``*Subtensor`` through ``RandomVariable`` inputs.
"""Lift a ``*Subtensor`` through ``RandomVariable`` inputs.
...
...
aesara/tensor/subtensor_opt.py
浏览文件 @
550a6e98
...
@@ -7,7 +7,7 @@ import aesara
...
@@ -7,7 +7,7 @@ import aesara
import
aesara.scalar.basic
as
aes
import
aesara.scalar.basic
as
aes
from
aesara
import
compile
from
aesara
import
compile
from
aesara.graph.basic
import
Constant
,
Variable
from
aesara.graph.basic
import
Constant
,
Variable
from
aesara.graph.opt
import
TopoOptimizer
,
copy_stack_trace
,
in2out
,
local_optimiz
er
from
aesara.graph.opt
import
TopoOptimizer
,
copy_stack_trace
,
in2out
,
node_rewrit
er
from
aesara.raise_op
import
Assert
from
aesara.raise_op
import
Assert
from
aesara.tensor.basic
import
(
from
aesara.tensor.basic
import
(
Alloc
,
Alloc
,
...
@@ -202,7 +202,7 @@ def get_advsubtensor_axis(indices):
...
@@ -202,7 +202,7 @@ def get_advsubtensor_axis(indices):
@register_specialize
@register_specialize
@
local_optimiz
er
([
AdvancedSubtensor
])
@
node_rewrit
er
([
AdvancedSubtensor
])
def
local_replace_AdvancedSubtensor
(
fgraph
,
node
):
def
local_replace_AdvancedSubtensor
(
fgraph
,
node
):
r"""
r"""
This rewrite converts expressions like ``X[..., y]`` into ``X.T[y].T``, for
This rewrite converts expressions like ``X[..., y]`` into ``X.T[y].T``, for
...
@@ -231,7 +231,7 @@ def local_replace_AdvancedSubtensor(fgraph, node):
...
@@ -231,7 +231,7 @@ def local_replace_AdvancedSubtensor(fgraph, node):
@register_specialize
@register_specialize
@
local_optimiz
er
([
AdvancedIncSubtensor
])
@
node_rewrit
er
([
AdvancedIncSubtensor
])
def
local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1
(
fgraph
,
node
):
def
local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1
(
fgraph
,
node
):
r"""Replace `AdvancedIncSubtensor`\s with `AdvancedIncSubtensor1`\s.
r"""Replace `AdvancedIncSubtensor`\s with `AdvancedIncSubtensor1`\s.
...
@@ -268,7 +268,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
...
@@ -268,7 +268,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_stabilize
@register_stabilize
@register_specialize
@register_specialize
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_subtensor_of_dot
(
fgraph
,
node
):
def
local_subtensor_of_dot
(
fgraph
,
node
):
"""Rewrite ``at.dot(A, B)[idxs]`` into ``at.dot(A[idxs_a], B[idxs_b])``.
"""Rewrite ``at.dot(A, B)[idxs]`` into ``at.dot(A[idxs_a], B[idxs_b])``.
``idxs_a`` is the first ``A.ndim-1`` entries of ``idxs``, and ``idxs_b`` is
``idxs_a`` is the first ``A.ndim-1`` entries of ``idxs``, and ``idxs_b`` is
...
@@ -326,7 +326,7 @@ def local_subtensor_of_dot(fgraph, node):
...
@@ -326,7 +326,7 @@ def local_subtensor_of_dot(fgraph, node):
@register_useless
@register_useless
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_useless_slice
(
fgraph
,
node
):
def
local_useless_slice
(
fgraph
,
node
):
"""
"""
Remove Subtensor of the form X[0, :] -> X[0]
Remove Subtensor of the form X[0, :] -> X[0]
...
@@ -362,7 +362,7 @@ def local_useless_slice(fgraph, node):
...
@@ -362,7 +362,7 @@ def local_useless_slice(fgraph, node):
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
@register_canonicalize
(
"fast_compile"
)
@register_canonicalize
(
"fast_compile"
)
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_subtensor_lift
(
fgraph
,
node
):
def
local_subtensor_lift
(
fgraph
,
node
):
"""
"""
unary(x)[idx] -> unary(x[idx])#any broadcast pattern.
unary(x)[idx] -> unary(x[idx])#any broadcast pattern.
...
@@ -466,7 +466,7 @@ def local_subtensor_lift(fgraph, node):
...
@@ -466,7 +466,7 @@ def local_subtensor_lift(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_subtensor_merge
(
fgraph
,
node
):
def
local_subtensor_merge
(
fgraph
,
node
):
"""
"""
Refactored optimization to deal with all cases of tensor merging.
Refactored optimization to deal with all cases of tensor merging.
...
@@ -537,7 +537,7 @@ def local_subtensor_merge(fgraph, node):
...
@@ -537,7 +537,7 @@ def local_subtensor_merge(fgraph, node):
@register_specialize
@register_specialize
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_subtensor_remove_broadcastable_index
(
fgraph
,
node
):
def
local_subtensor_remove_broadcastable_index
(
fgraph
,
node
):
"""
"""
Remove broadcastable dimension with index 0 or -1
Remove broadcastable dimension with index 0 or -1
...
@@ -586,7 +586,7 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
...
@@ -586,7 +586,7 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
@register_useless
@register_useless
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_subtensor_of_alloc
(
fgraph
,
node
):
def
local_subtensor_of_alloc
(
fgraph
,
node
):
"""
"""
...
@@ -654,7 +654,7 @@ def local_subtensor_of_alloc(fgraph, node):
...
@@ -654,7 +654,7 @@ def local_subtensor_of_alloc(fgraph, node):
@register_specialize
@register_specialize
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_subtensor_inc_subtensor
(
fgraph
,
node
):
def
local_subtensor_inc_subtensor
(
fgraph
,
node
):
"""
"""
Subtensor(SetSubtensor(x, y, idx), idx) -> y
Subtensor(SetSubtensor(x, y, idx), idx) -> y
...
@@ -694,7 +694,7 @@ def local_subtensor_inc_subtensor(fgraph, node):
...
@@ -694,7 +694,7 @@ def local_subtensor_inc_subtensor(fgraph, node):
@register_specialize
@register_specialize
@register_canonicalize
(
"fast_compile"
)
@register_canonicalize
(
"fast_compile"
)
@register_useless
@register_useless
@
local_optimiz
er
([
Subtensor
,
AdvancedSubtensor1
])
@
node_rewrit
er
([
Subtensor
,
AdvancedSubtensor1
])
def
local_subtensor_make_vector
(
fgraph
,
node
):
def
local_subtensor_make_vector
(
fgraph
,
node
):
"""Perform ``*Subtensor*`` operations on ``MakeVector`` outputs when the indices are constant.
"""Perform ``*Subtensor*`` operations on ``MakeVector`` outputs when the indices are constant.
...
@@ -770,7 +770,7 @@ def local_subtensor_make_vector(fgraph, node):
...
@@ -770,7 +770,7 @@ def local_subtensor_make_vector(fgraph, node):
@register_useless
@register_useless
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
IncSubtensor
])
@
node_rewrit
er
([
IncSubtensor
])
def
local_useless_inc_subtensor
(
fgraph
,
node
):
def
local_useless_inc_subtensor
(
fgraph
,
node
):
r"""Remove redundant `IncSubtensor`\s.
r"""Remove redundant `IncSubtensor`\s.
...
@@ -834,7 +834,7 @@ def local_useless_inc_subtensor(fgraph, node):
...
@@ -834,7 +834,7 @@ def local_useless_inc_subtensor(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
AdvancedIncSubtensor1
])
@
node_rewrit
er
([
AdvancedIncSubtensor1
])
def
local_set_to_inc_subtensor
(
fgraph
,
node
):
def
local_set_to_inc_subtensor
(
fgraph
,
node
):
r"""
r"""
AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) ->
AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) ->
...
@@ -878,7 +878,7 @@ def local_set_to_inc_subtensor(fgraph, node):
...
@@ -878,7 +878,7 @@ def local_set_to_inc_subtensor(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_useless_subtensor
(
fgraph
,
node
):
def
local_useless_subtensor
(
fgraph
,
node
):
"""Remove `Subtensor` if it takes the full input."""
"""Remove `Subtensor` if it takes the full input."""
# This optimization needs ShapeOpt and fgraph.shape_feature
# This optimization needs ShapeOpt and fgraph.shape_feature
...
@@ -960,7 +960,7 @@ def local_useless_subtensor(fgraph, node):
...
@@ -960,7 +960,7 @@ def local_useless_subtensor(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
AdvancedSubtensor1
])
@
node_rewrit
er
([
AdvancedSubtensor1
])
def
local_useless_AdvancedSubtensor1
(
fgraph
,
node
):
def
local_useless_AdvancedSubtensor1
(
fgraph
,
node
):
"""Remove `AdvancedSubtensor1` if it takes the full input.
"""Remove `AdvancedSubtensor1` if it takes the full input.
...
@@ -1116,7 +1116,7 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
...
@@ -1116,7 +1116,7 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
add
])
@
node_rewrit
er
([
add
])
def
local_IncSubtensor_serialize
(
fgraph
,
node
):
def
local_IncSubtensor_serialize
(
fgraph
,
node
):
"""
"""
When using Subtensor, gradient graphs can be ugly.
When using Subtensor, gradient graphs can be ugly.
...
@@ -1216,7 +1216,7 @@ compile.optdb.register(
...
@@ -1216,7 +1216,7 @@ compile.optdb.register(
# gemm is the first one now, at priority 70
# gemm is the first one now, at priority 70
@
local_optimiz
er
([
IncSubtensor
],
inplace
=
True
)
@
node_rewrit
er
([
IncSubtensor
],
inplace
=
True
)
def
local_inplace_setsubtensor
(
fgraph
,
node
):
def
local_inplace_setsubtensor
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
IncSubtensor
)
and
not
node
.
op
.
inplace
:
if
isinstance
(
node
.
op
,
IncSubtensor
)
and
not
node
.
op
.
inplace
:
dta
=
node
.
op
.
destroyhandler_tolerate_aliased
dta
=
node
.
op
.
destroyhandler_tolerate_aliased
...
@@ -1249,7 +1249,7 @@ compile.optdb.register(
...
@@ -1249,7 +1249,7 @@ compile.optdb.register(
)
)
@
local_optimiz
er
([
AdvancedIncSubtensor1
],
inplace
=
True
)
@
node_rewrit
er
([
AdvancedIncSubtensor1
],
inplace
=
True
)
def
local_inplace_AdvancedIncSubtensor1
(
fgraph
,
node
):
def
local_inplace_AdvancedIncSubtensor1
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
AdvancedIncSubtensor1
)
and
not
node
.
op
.
inplace
:
if
isinstance
(
node
.
op
,
AdvancedIncSubtensor1
)
and
not
node
.
op
.
inplace
:
new_op
=
node
.
op
.
clone_inplace
()
new_op
=
node
.
op
.
clone_inplace
()
...
@@ -1270,7 +1270,7 @@ compile.optdb.register(
...
@@ -1270,7 +1270,7 @@ compile.optdb.register(
)
)
@
local_optimiz
er
([
AdvancedIncSubtensor
],
inplace
=
True
)
@
node_rewrit
er
([
AdvancedIncSubtensor
],
inplace
=
True
)
def
local_inplace_AdvancedIncSubtensor
(
fgraph
,
node
):
def
local_inplace_AdvancedIncSubtensor
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
AdvancedIncSubtensor
)
and
not
node
.
op
.
inplace
:
if
isinstance
(
node
.
op
,
AdvancedIncSubtensor
)
and
not
node
.
op
.
inplace
:
new_op
=
type
(
node
.
op
)(
new_op
=
type
(
node
.
op
)(
...
@@ -1298,7 +1298,7 @@ compile.optdb.register(
...
@@ -1298,7 +1298,7 @@ compile.optdb.register(
# Register old name
# Register old name
@register_canonicalize
(
"local_incsubtensor_of_allocs"
)
@register_canonicalize
(
"local_incsubtensor_of_allocs"
)
@register_stabilize
(
"local_incsubtensor_of_allocs"
)
@register_stabilize
(
"local_incsubtensor_of_allocs"
)
@
local_optimiz
er
([
IncSubtensor
,
AdvancedIncSubtensor
,
AdvancedIncSubtensor1
])
@
node_rewrit
er
([
IncSubtensor
,
AdvancedIncSubtensor
,
AdvancedIncSubtensor1
])
def
local_incsubtensor_of_zeros
(
fgraph
,
node
):
def
local_incsubtensor_of_zeros
(
fgraph
,
node
):
"""
"""
IncSubtensor(x, zeros, idx) -> x
IncSubtensor(x, zeros, idx) -> x
...
@@ -1323,7 +1323,7 @@ def local_incsubtensor_of_zeros(fgraph, node):
...
@@ -1323,7 +1323,7 @@ def local_incsubtensor_of_zeros(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
IncSubtensor
])
@
node_rewrit
er
([
IncSubtensor
])
def
local_incsubtensor_of_zeros_to_setsubtensor
(
fgraph
,
node
):
def
local_incsubtensor_of_zeros_to_setsubtensor
(
fgraph
,
node
):
"""
"""
IncSubtensor(zeros, x, ...) -> SetSubtensor(zeros, x, ...)
IncSubtensor(zeros, x, ...) -> SetSubtensor(zeros, x, ...)
...
@@ -1344,7 +1344,7 @@ def local_incsubtensor_of_zeros_to_setsubtensor(fgraph, node):
...
@@ -1344,7 +1344,7 @@ def local_incsubtensor_of_zeros_to_setsubtensor(fgraph, node):
@register_canonicalize
(
"local_setsubtensor_of_allocs"
)
@register_canonicalize
(
"local_setsubtensor_of_allocs"
)
@register_stabilize
(
"local_setsubtensor_of_allocs"
)
@register_stabilize
(
"local_setsubtensor_of_allocs"
)
@
local_optimiz
er
([
IncSubtensor
])
@
node_rewrit
er
([
IncSubtensor
])
def
local_setsubtensor_of_constants
(
fgraph
,
node
):
def
local_setsubtensor_of_constants
(
fgraph
,
node
):
"""
"""
SetSubtensor(x, x[idx], idx) -> x
SetSubtensor(x, x[idx], idx) -> x
...
@@ -1379,7 +1379,7 @@ def local_setsubtensor_of_constants(fgraph, node):
...
@@ -1379,7 +1379,7 @@ def local_setsubtensor_of_constants(fgraph, node):
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@
local_optimiz
er
([
AdvancedSubtensor1
])
@
node_rewrit
er
([
AdvancedSubtensor1
])
def
local_adv_sub1_adv_inc_sub1
(
fgraph
,
node
):
def
local_adv_sub1_adv_inc_sub1
(
fgraph
,
node
):
"""Optimize the possible AdvSub1(AdvSetSub1(...), ...).
"""Optimize the possible AdvSub1(AdvSetSub1(...), ...).
...
@@ -1446,7 +1446,7 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node):
...
@@ -1446,7 +1446,7 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node):
@register_stabilize
@register_stabilize
@register_canonicalize
@register_canonicalize
@register_useless
@register_useless
@
local_optimiz
er
([
IncSubtensor
,
AdvancedIncSubtensor
,
AdvancedIncSubtensor1
])
@
node_rewrit
er
([
IncSubtensor
,
AdvancedIncSubtensor
,
AdvancedIncSubtensor1
])
def
local_useless_inc_subtensor_alloc
(
fgraph
,
node
):
def
local_useless_inc_subtensor_alloc
(
fgraph
,
node
):
"""
"""
Replaces an [Advanced]IncSubtensor[1], whose increment is an `alloc` of
Replaces an [Advanced]IncSubtensor[1], whose increment is an `alloc` of
...
@@ -1552,7 +1552,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
...
@@ -1552,7 +1552,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
@register_specialize
@register_specialize
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_subtensor_shape_constant
(
fgraph
,
node
):
def
local_subtensor_shape_constant
(
fgraph
,
node
):
r"""Simplify constant `Subtensor`\s on `Shape`\s dimensions that are known.
r"""Simplify constant `Subtensor`\s on `Shape`\s dimensions that are known.
...
@@ -1606,7 +1606,7 @@ def local_subtensor_shape_constant(fgraph, node):
...
@@ -1606,7 +1606,7 @@ def local_subtensor_shape_constant(fgraph, node):
@register_canonicalize
@register_canonicalize
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_subtensor_SpecifyShape_lift
(
fgraph
,
node
):
def
local_subtensor_SpecifyShape_lift
(
fgraph
,
node
):
"""Lift ``specify_shape(x, s)[i_1, ..., i_n]`` to ``specify_shape(x[i1, ... , i_n], s[n:])``."""
"""Lift ``specify_shape(x, s)[i_1, ..., i_n]`` to ``specify_shape(x[i1, ... , i_n], s[n:])``."""
...
@@ -1640,7 +1640,7 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
...
@@ -1640,7 +1640,7 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
@register_specialize
@register_specialize
@
local_optimiz
er
([
Join
])
@
node_rewrit
er
([
Join
])
def
local_join_subtensors
(
fgraph
,
node
):
def
local_join_subtensors
(
fgraph
,
node
):
r"""Simplify contiguous :class:`Subtensor`\s inside a :class:`Join`.
r"""Simplify contiguous :class:`Subtensor`\s inside a :class:`Join`.
...
...
aesara/typed_list/opt.py
浏览文件 @
550a6e98
from
aesara.compile
import
optdb
from
aesara.compile
import
optdb
from
aesara.graph.opt
import
TopoOptimizer
,
local_optimiz
er
from
aesara.graph.opt
import
TopoOptimizer
,
node_rewrit
er
from
aesara.typed_list.basic
import
Append
,
Extend
,
Insert
,
Remove
,
Reverse
from
aesara.typed_list.basic
import
Append
,
Extend
,
Insert
,
Remove
,
Reverse
@
local_optimiz
er
([
Append
,
Extend
,
Insert
,
Reverse
,
Remove
],
inplace
=
True
)
@
node_rewrit
er
([
Append
,
Extend
,
Insert
,
Reverse
,
Remove
],
inplace
=
True
)
def
typed_list_inplace_opt
(
fgraph
,
node
):
def
typed_list_inplace_opt
(
fgraph
,
node
):
if
(
if
(
isinstance
(
node
.
op
,
(
Append
,
Extend
,
Insert
,
Reverse
,
Remove
))
isinstance
(
node
.
op
,
(
Append
,
Extend
,
Insert
,
Reverse
,
Remove
))
...
...
doc/extending/graph_rewriting.rst
浏览文件 @
550a6e98
...
@@ -67,15 +67,15 @@ Local optimization
...
@@ -67,15 +67,15 @@ Local optimization
A local optimization is an object which defines the following methods:
A local optimization is an object which defines the following methods:
.. class::
LocalOptimiz
er
.. class::
NodeRewrit
er
.. method:: transform(fgraph, node)
.. method:: transform(fgraph, node)
This method takes a :class:`FunctionGraph` and an :class:`Apply` node and
This method takes a :class:`FunctionGraph` and an :class:`Apply` node and
returns either ``False`` to signify that no changes are to be done or a
returns either ``False`` to signify that no changes are to be done or a
list of :class:`Variable`\s which matches the length of the node's ``outputs``
list of :class:`Variable`\s which matches the length of the node's ``outputs``
list. When the :class:`
LocalOptimiz
er` is applied by a :class:`NavigatorOptimizer`, the outputs
list. When the :class:`
NodeRewrit
er` is applied by a :class:`NavigatorOptimizer`, the outputs
of the node passed as argument to the :class:`
LocalOptimiz
er` will be replaced by
of the node passed as argument to the :class:`
NodeRewrit
er` will be replaced by
the list returned.
the list returned.
...
@@ -218,10 +218,10 @@ The local version of the above code would be the following:
...
@@ -218,10 +218,10 @@ The local version of the above code would be the following:
.. testcode::
.. testcode::
from aesara.graph.opt import
LocalOptimiz
er
from aesara.graph.opt import
NodeRewrit
er
class LocalSimplify(
LocalOptimiz
er):
class LocalSimplify(
NodeRewrit
er):
def transform(self, fgraph, node):
def transform(self, fgraph, node):
if node.op == true_div:
if node.op == true_div:
x, y = node.inputs
x, y = node.inputs
...
@@ -234,7 +234,7 @@ The local version of the above code would be the following:
...
@@ -234,7 +234,7 @@ The local version of the above code would be the following:
return False
return False
def tracks(self):
def tracks(self):
# This tells certain navigators to only apply this `
LocalOptimiz
er`
# This tells certain navigators to only apply this `
NodeRewrit
er`
# on these kinds of `Op`s
# on these kinds of `Op`s
return [true_div]
return [true_div]
...
@@ -242,7 +242,7 @@ The local version of the above code would be the following:
...
@@ -242,7 +242,7 @@ The local version of the above code would be the following:
In this case, the transformation is defined in the
In this case, the transformation is defined in the
:meth:`
LocalOptimiz
er.transform` method, which is given an explicit
:meth:`
NodeRewrit
er.transform` method, which is given an explicit
:class:`Apply` node on which to work. The entire graph--as a ``fgraph``--is
:class:`Apply` node on which to work. The entire graph--as a ``fgraph``--is
also provided, in case global information is needed.
also provided, in case global information is needed.
...
@@ -273,7 +273,7 @@ FunctionGraph(add(z, mul(x, true_div(z, x))))
...
@@ -273,7 +273,7 @@ FunctionGraph(add(z, mul(x, true_div(z, x))))
:class:`OpSub`, :class:`OpRemove`, :class:`PatternSub`
:class:`OpSub`, :class:`OpRemove`, :class:`PatternSub`
++++++++++++++++++++++++++++++++++++++++++++++++++++++
++++++++++++++++++++++++++++++++++++++++++++++++++++++
Aesara defines some shortcuts to make :class:`
LocalOptimiz
er`\s:
Aesara defines some shortcuts to make :class:`
NodeRewrit
er`\s:
.. function:: OpSub(op1, op2)
.. function:: OpSub(op1, op2)
...
@@ -433,7 +433,7 @@ This means that a relation that--say--represents :math:`x + x = 2 x` can be
...
@@ -433,7 +433,7 @@ This means that a relation that--say--represents :math:`x + x = 2 x` can be
utilized in both directions.
utilized in both directions.
Currently, the local optimizer :class:`KanrenRelationSub` provides a means of
Currently, the local optimizer :class:`KanrenRelationSub` provides a means of
turning :mod:`kanren` relations into :class:`
LocalOptimiz
er`\s; however,
turning :mod:`kanren` relations into :class:`
NodeRewrit
er`\s; however,
:mod:`kanren` can always be used directly from within a custom :class:`Rewriter`, so
:mod:`kanren` can always be used directly from within a custom :class:`Rewriter`, so
:class:`KanrenRelationSub` is not necessary.
:class:`KanrenRelationSub` is not necessary.
...
@@ -561,7 +561,7 @@ serve as a basis for filtering.
...
@@ -561,7 +561,7 @@ serve as a basis for filtering.
The point of :obj:`optdb` is that you might want to apply many optimizations
The point of :obj:`optdb` is that you might want to apply many optimizations
to a computation graph in many unique patterns. For example, you might
to a computation graph in many unique patterns. For example, you might
want to do optimization X, then optimization Y, then optimization Z. And then
want to do optimization X, then optimization Y, then optimization Z. And then
maybe optimization Y is an :class:`EquilibriumOptimizer` containing :class:`
LocalOptimiz
er`\s A, B
maybe optimization Y is an :class:`EquilibriumOptimizer` containing :class:`
NodeRewrit
er`\s A, B
and C which are applied on every node of the graph until they all fail to change
and C which are applied on every node of the graph until they all fail to change
it. If some optimizations act up, we want an easy way to turn them off. Ditto if
it. If some optimizations act up, we want an easy way to turn them off. Ditto if
some optimizations are very CPU-intensive and we don't want to take the time to
some optimizations are very CPU-intensive and we don't want to take the time to
...
@@ -596,14 +596,14 @@ is returned. If the :class:`SequenceDB` contains :class:`OptimizationDatabase`
...
@@ -596,14 +596,14 @@ is returned. If the :class:`SequenceDB` contains :class:`OptimizationDatabase`
instances, the :class:`OptimizationQuery` will be passed to them as well and the
instances, the :class:`OptimizationQuery` will be passed to them as well and the
optimizers they return will be put in their places.
optimizers they return will be put in their places.
An :class:`EquilibriumDB` contains :class:`
LocalOptimiz
er` or :class:`OptimizationDatabase` objects. Each of them
An :class:`EquilibriumDB` contains :class:`
NodeRewrit
er` or :class:`OptimizationDatabase` objects. Each of them
has a name and an arbitrary number of tags. When a :class:`OptimizationQuery` is applied to
has a name and an arbitrary number of tags. When a :class:`OptimizationQuery` is applied to
an :class:`EquilibriumDB`, all :class:`
LocalOptimiz
er`\s that match the query are
an :class:`EquilibriumDB`, all :class:`
NodeRewrit
er`\s that match the query are
inserted into an :class:`EquilibriumOptimizer`, which is returned. If the
inserted into an :class:`EquilibriumOptimizer`, which is returned. If the
:class:`SequenceDB` contains :class:`OptimizationDatabase` instances, the
:class:`SequenceDB` contains :class:`OptimizationDatabase` instances, the
:class:`OptimizationQuery` will be passed to them as well and the
:class:`OptimizationQuery` will be passed to them as well and the
:class:`
LocalOptimiz
er`\s they return will be put in their places
:class:`
NodeRewrit
er`\s they return will be put in their places
(note that as of yet no :class:`OptimizationDatabase` can produce :class:`
LocalOptimiz
er` objects, so this
(note that as of yet no :class:`OptimizationDatabase` can produce :class:`
NodeRewrit
er` objects, so this
is a moot point).
is a moot point).
Aesara contains one principal :class:`OptimizationDatabase` object, :class:`optdb`, which
Aesara contains one principal :class:`OptimizationDatabase` object, :class:`optdb`, which
...
@@ -697,10 +697,10 @@ already-compiled functions will see no change. The 'order' parameter
...
@@ -697,10 +697,10 @@ already-compiled functions will see no change. The 'order' parameter
Registering a :class:`
LocalOptimiz
er`
Registering a :class:`
NodeRewrit
er`
-----------------------------------
--
-----------------------------------
:class:`
LocalOptimiz
er`\s may be registered in two ways:
:class:`
NodeRewrit
er`\s may be registered in two ways:
* Wrap them in a :class:`NavigatorOptimizer` and insert them like a global optimizer
* Wrap them in a :class:`NavigatorOptimizer` and insert them like a global optimizer
(see previous section).
(see previous section).
...
...
tests/compile/test_debugmode.py
浏览文件 @
550a6e98
...
@@ -18,7 +18,7 @@ from aesara.configdefaults import config
...
@@ -18,7 +18,7 @@ from aesara.configdefaults import config
from
aesara.graph.basic
import
Apply
,
Variable
from
aesara.graph.basic
import
Apply
,
Variable
from
aesara.graph.features
import
BadOptimization
from
aesara.graph.features
import
BadOptimization
from
aesara.graph.op
import
Op
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
local_optimiz
er
from
aesara.graph.opt
import
node_rewrit
er
from
aesara.graph.optdb
import
EquilibriumDB
from
aesara.graph.optdb
import
EquilibriumDB
from
aesara.link.c.op
import
COp
from
aesara.link.c.op
import
COp
from
aesara.tensor.math
import
add
,
dot
,
log
from
aesara.tensor.math
import
add
,
dot
,
log
...
@@ -237,7 +237,7 @@ def test_badthunkoutput():
...
@@ -237,7 +237,7 @@ def test_badthunkoutput():
def
test_badoptimization
():
def
test_badoptimization
():
@
local_optimiz
er
([
add
])
@
node_rewrit
er
([
add
])
def
insert_broken_add
(
fgraph
,
node
):
def
insert_broken_add
(
fgraph
,
node
):
if
node
.
op
==
add
:
if
node
.
op
==
add
:
return
[
off_by_half
(
*
node
.
inputs
)]
return
[
off_by_half
(
*
node
.
inputs
)]
...
@@ -263,7 +263,7 @@ def test_badoptimization():
...
@@ -263,7 +263,7 @@ def test_badoptimization():
def
test_badoptimization_opt_err
():
def
test_badoptimization_opt_err
():
# This variant of test_badoptimization() replace the working code
# This variant of test_badoptimization() replace the working code
# with a new apply node that will raise an error.
# with a new apply node that will raise an error.
@
local_optimiz
er
([
add
])
@
node_rewrit
er
([
add
])
def
insert_bigger_b_add
(
fgraph
,
node
):
def
insert_bigger_b_add
(
fgraph
,
node
):
if
node
.
op
==
add
:
if
node
.
op
==
add
:
inputs
=
list
(
node
.
inputs
)
inputs
=
list
(
node
.
inputs
)
...
@@ -272,7 +272,7 @@ def test_badoptimization_opt_err():
...
@@ -272,7 +272,7 @@ def test_badoptimization_opt_err():
return
[
node
.
op
(
*
inputs
)]
return
[
node
.
op
(
*
inputs
)]
return
False
return
False
@
local_optimiz
er
([
add
])
@
node_rewrit
er
([
add
])
def
insert_bad_dtype
(
fgraph
,
node
):
def
insert_bad_dtype
(
fgraph
,
node
):
if
node
.
op
==
add
:
if
node
.
op
==
add
:
inputs
=
list
(
node
.
inputs
)
inputs
=
list
(
node
.
inputs
)
...
@@ -326,7 +326,7 @@ def test_stochasticoptimization():
...
@@ -326,7 +326,7 @@ def test_stochasticoptimization():
last_time_replaced
=
[
False
]
last_time_replaced
=
[
False
]
@
local_optimiz
er
([
add
])
@
node_rewrit
er
([
add
])
def
insert_broken_add_sometimes
(
fgraph
,
node
):
def
insert_broken_add_sometimes
(
fgraph
,
node
):
if
node
.
op
==
add
:
if
node
.
op
==
add
:
last_time_replaced
[
0
]
=
not
last_time_replaced
[
0
]
last_time_replaced
[
0
]
=
not
last_time_replaced
[
0
]
...
...
tests/graph/test_opt.py
浏览文件 @
550a6e98
...
@@ -15,10 +15,10 @@ from aesara.graph.opt import (
...
@@ -15,10 +15,10 @@ from aesara.graph.opt import (
PatternSub
,
PatternSub
,
TopoOptimizer
,
TopoOptimizer
,
in2out
,
in2out
,
local_optimizer
,
logging
,
logging
,
node_rewriter
,
pre_constant_merge
,
pre_constant_merge
,
pre_greedy_
local_optimiz
er
,
pre_greedy_
node_rewrit
er
,
)
)
from
aesara.raise_op
import
assert_op
from
aesara.raise_op
import
assert_op
from
aesara.tensor.basic_opt
import
constant_folding
from
aesara.tensor.basic_opt
import
constant_folding
...
@@ -547,7 +547,7 @@ def test_pre_constant_merge():
...
@@ -547,7 +547,7 @@ def test_pre_constant_merge():
assert
res
==
[
adv
]
assert
res
==
[
adv
]
def
test_pre_greedy_
local_optimiz
er
():
def
test_pre_greedy_
node_rewrit
er
():
empty_fgraph
=
FunctionGraph
([],
[])
empty_fgraph
=
FunctionGraph
([],
[])
...
@@ -564,7 +564,7 @@ def test_pre_greedy_local_optimizer():
...
@@ -564,7 +564,7 @@ def test_pre_greedy_local_optimizer():
# This should fold `o1`, because it has only `Constant` arguments, and
# This should fold `o1`, because it has only `Constant` arguments, and
# replace it with the `Constant` result
# replace it with the `Constant` result
cst
=
pre_greedy_
local_optimiz
er
(
empty_fgraph
,
[
constant_folding
],
o2
)
cst
=
pre_greedy_
node_rewrit
er
(
empty_fgraph
,
[
constant_folding
],
o2
)
assert
cst
.
owner
.
inputs
[
0
]
.
owner
is
None
assert
cst
.
owner
.
inputs
[
0
]
.
owner
is
None
assert
cst
.
owner
.
inputs
[
1
]
is
c2
assert
cst
.
owner
.
inputs
[
1
]
is
c2
...
@@ -577,14 +577,14 @@ def test_pre_greedy_local_optimizer():
...
@@ -577,14 +577,14 @@ def test_pre_greedy_local_optimizer():
fg
=
FunctionGraph
([],
[
o1
],
clone
=
False
)
fg
=
FunctionGraph
([],
[
o1
],
clone
=
False
)
o2
=
op1
(
o1
,
c2
,
x
,
o3
,
o1
)
o2
=
op1
(
o1
,
c2
,
x
,
o3
,
o1
)
cst
=
pre_greedy_
local_optimiz
er
(
fg
,
[
constant_folding
],
o2
)
cst
=
pre_greedy_
node_rewrit
er
(
fg
,
[
constant_folding
],
o2
)
assert
cst
.
owner
.
inputs
[
0
]
is
o1
assert
cst
.
owner
.
inputs
[
0
]
is
o1
assert
cst
.
owner
.
inputs
[
4
]
is
cst
.
owner
.
inputs
[
0
]
assert
cst
.
owner
.
inputs
[
4
]
is
cst
.
owner
.
inputs
[
0
]
# What exactly is this supposed to test?
# What exactly is this supposed to test?
ms
=
MakeSlice
()(
1
)
ms
=
MakeSlice
()(
1
)
cst
=
pre_greedy_
local_optimiz
er
(
empty_fgraph
,
[
constant_folding
],
ms
)
cst
=
pre_greedy_
node_rewrit
er
(
empty_fgraph
,
[
constant_folding
],
ms
)
assert
isinstance
(
cst
,
SliceConstant
)
assert
isinstance
(
cst
,
SliceConstant
)
...
@@ -673,13 +673,13 @@ class TestLocalOptGroup:
...
@@ -673,13 +673,13 @@ class TestLocalOptGroup:
fgraph
=
FunctionGraph
([
x
,
y
],
[
o1
],
clone
=
False
)
fgraph
=
FunctionGraph
([
x
,
y
],
[
o1
],
clone
=
False
)
@
local_optimiz
er
(
None
)
@
node_rewrit
er
(
None
)
def
local_opt_1
(
fgraph
,
node
):
def
local_opt_1
(
fgraph
,
node
):
if
node
.
inputs
[
0
]
==
x
:
if
node
.
inputs
[
0
]
==
x
:
res
=
op2
(
y
,
*
node
.
inputs
[
1
:])
res
=
op2
(
y
,
*
node
.
inputs
[
1
:])
return
[
res
]
return
[
res
]
@
local_optimiz
er
(
None
)
@
node_rewrit
er
(
None
)
def
local_opt_2
(
fgraph
,
node
):
def
local_opt_2
(
fgraph
,
node
):
if
node
.
inputs
[
0
]
==
y
:
if
node
.
inputs
[
0
]
==
y
:
res
=
op2
(
x
,
*
node
.
inputs
[
1
:])
res
=
op2
(
x
,
*
node
.
inputs
[
1
:])
...
@@ -703,8 +703,8 @@ class TestLocalOptGroup:
...
@@ -703,8 +703,8 @@ class TestLocalOptGroup:
)
)
def
test_
local_optimiz
er_str
():
def
test_
node_rewrit
er_str
():
@
local_optimiz
er
([
op1
,
MyOp
])
@
node_rewrit
er
([
op1
,
MyOp
])
def
local_opt_1
(
fgraph
,
node
):
def
local_opt_1
(
fgraph
,
node
):
pass
pass
...
@@ -715,17 +715,17 @@ def test_local_optimizer_str():
...
@@ -715,17 +715,17 @@ def test_local_optimizer_str():
assert
"local_opt_1"
in
res
assert
"local_opt_1"
in
res
def
test_
local_optimiz
er
():
def
test_
node_rewrit
er
():
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
@
local_optimiz
er
([])
@
node_rewrit
er
([])
def
local_bad_1
(
fgraph
,
node
):
def
local_bad_1
(
fgraph
,
node
):
return
node
.
outputs
return
node
.
outputs
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
@
local_optimiz
er
([
None
])
@
node_rewrit
er
([
None
])
def
local_bad_2
(
fgraph
,
node
):
def
local_bad_2
(
fgraph
,
node
):
return
node
.
outputs
return
node
.
outputs
...
@@ -748,7 +748,7 @@ def test_local_optimizer():
...
@@ -748,7 +748,7 @@ def test_local_optimizer():
hits
=
[
0
]
hits
=
[
0
]
@
local_optimiz
er
([
op1
,
MyNewOp
])
@
node_rewrit
er
([
op1
,
MyNewOp
])
def
local_opt_1
(
fgraph
,
node
,
hits
=
hits
):
def
local_opt_1
(
fgraph
,
node
,
hits
=
hits
):
hits
[
0
]
+=
1
hits
[
0
]
+=
1
return
node
.
outputs
return
node
.
outputs
...
@@ -766,24 +766,24 @@ def test_local_optimizer():
...
@@ -766,24 +766,24 @@ def test_local_optimizer():
assert
hits
[
0
]
==
2
assert
hits
[
0
]
==
2
def
test_Tracking
LocalOptimiz
er
():
def
test_Tracking
NodeRewrit
er
():
@
local_optimiz
er
(
None
)
@
node_rewrit
er
(
None
)
def
local_opt_1
(
fgraph
,
node
):
def
local_opt_1
(
fgraph
,
node
):
pass
pass
@
local_optimiz
er
([
op1
])
@
node_rewrit
er
([
op1
])
def
local_opt_2
(
fgraph
,
node
):
def
local_opt_2
(
fgraph
,
node
):
pass
pass
@
local_optimiz
er
([
Op
])
@
node_rewrit
er
([
Op
])
def
local_opt_3
(
fgraph
,
node
):
def
local_opt_3
(
fgraph
,
node
):
pass
pass
@
local_optimiz
er
([
MyOp
])
@
node_rewrit
er
([
MyOp
])
def
local_opt_4
(
fgraph
,
node
):
def
local_opt_4
(
fgraph
,
node
):
pass
pass
@
local_optimiz
er
([
MyOp
])
@
node_rewrit
er
([
MyOp
])
def
local_opt_5
(
fgraph
,
node
):
def
local_opt_5
(
fgraph
,
node
):
pass
pass
...
...
tests/tensor/test_basic_opt.py
浏览文件 @
550a6e98
...
@@ -16,7 +16,7 @@ from aesara.configdefaults import config
...
@@ -16,7 +16,7 @@ from aesara.configdefaults import config
from
aesara.graph.basic
import
Apply
,
Constant
,
Variable
from
aesara.graph.basic
import
Apply
,
Constant
,
Variable
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
Op
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
check_stack_trace
,
local_optimiz
er
,
out2in
from
aesara.graph.opt
import
check_stack_trace
,
node_rewrit
er
,
out2in
from
aesara.graph.opt_utils
import
optimize_graph
from
aesara.graph.opt_utils
import
optimize_graph
from
aesara.graph.optdb
import
OptimizationQuery
from
aesara.graph.optdb
import
OptimizationQuery
from
aesara.graph.type
import
Type
from
aesara.graph.type
import
Type
...
@@ -1752,7 +1752,7 @@ class TestShapeOptimizer:
...
@@ -1752,7 +1752,7 @@ class TestShapeOptimizer:
identity_shape
=
IdentityShape
()
identity_shape
=
IdentityShape
()
@
local_optimiz
er
([
IdentityNoShape
])
@
node_rewrit
er
([
IdentityNoShape
])
def
local_identity_noshape_to_identity_shape
(
fgraph
,
node
):
def
local_identity_noshape_to_identity_shape
(
fgraph
,
node
):
"""Optimization transforming the first Op into the second"""
"""Optimization transforming the first Op into the second"""
if
isinstance
(
node
.
op
,
IdentityNoShape
):
if
isinstance
(
node
.
op
,
IdentityNoShape
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论