Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
550a6e98
提交
550a6e98
authored
7月 14, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
8月 17, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rename LocalOptimizer to NodeRewriter
上级
214ef4cf
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
29 个修改的文件
包含
189 行增加
和
195 行删除
+189
-195
builders.py
aesara/compile/builders.py
+2
-2
__init__.py
aesara/graph/__init__.py
+1
-1
kanren.py
aesara/graph/kanren.py
+2
-2
opt.py
aesara/graph/opt.py
+0
-0
optdb.py
aesara/graph/optdb.py
+8
-14
ifelse.py
aesara/ifelse.py
+7
-7
ops.py
aesara/sandbox/linalg/ops.py
+8
-8
rng_mrg.py
aesara/sandbox/rng_mrg.py
+2
-2
opt.py
aesara/scan/opt.py
+8
-8
opt.py
aesara/sparse/opt.py
+14
-14
basic_opt.py
aesara/tensor/basic_opt.py
+0
-0
blas.py
aesara/tensor/blas.py
+10
-10
blas_c.py
aesara/tensor/blas_c.py
+5
-5
blas_scipy.py
aesara/tensor/blas_scipy.py
+3
-3
math_opt.py
aesara/tensor/math_opt.py
+0
-0
basic.py
aesara/tensor/nnet/basic.py
+10
-10
batchnorm.py
aesara/tensor/nnet/batchnorm.py
+4
-4
conv3d2d.py
aesara/tensor/nnet/conv3d2d.py
+2
-2
ctc.py
aesara/tensor/nnet/ctc.py
+2
-2
opt.py
aesara/tensor/nnet/opt.py
+13
-13
sigm.py
aesara/tensor/nnet/sigm.py
+3
-3
opt_uncanonicalize.py
aesara/tensor/opt_uncanonicalize.py
+7
-7
opt.py
aesara/tensor/random/opt.py
+5
-5
subtensor_opt.py
aesara/tensor/subtensor_opt.py
+27
-27
opt.py
aesara/typed_list/opt.py
+2
-2
graph_rewriting.rst
doc/extending/graph_rewriting.rst
+17
-17
test_debugmode.py
tests/compile/test_debugmode.py
+5
-5
test_opt.py
tests/graph/test_opt.py
+20
-20
test_basic_opt.py
tests/tensor/test_basic_opt.py
+2
-2
没有找到文件。
aesara/compile/builders.py
浏览文件 @
550a6e98
...
...
@@ -24,7 +24,7 @@ from aesara.graph.basic import (
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.null_type
import
NullType
from
aesara.graph.op
import
HasInnerGraph
,
Op
from
aesara.graph.opt
import
in2out
,
local_optimiz
er
from
aesara.graph.opt
import
in2out
,
node_rewrit
er
from
aesara.graph.utils
import
MissingInputError
from
aesara.tensor.basic_opt
import
ShapeFeature
...
...
@@ -928,7 +928,7 @@ class OpFromGraph(Op, HasInnerGraph):
output
[
0
]
=
variable
@
local_optimiz
er
([
OpFromGraph
])
@
node_rewrit
er
([
OpFromGraph
])
def
inline_ofg_expansion
(
fgraph
,
node
):
"""
This optimization expands internal graph of OpFromGraph.
...
...
aesara/graph/__init__.py
浏览文件 @
550a6e98
...
...
@@ -13,7 +13,7 @@ from aesara.graph.basic import (
from
aesara.graph.op
import
Op
from
aesara.graph.type
import
Type
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.opt
import
local_optimiz
er
,
optimizer
from
aesara.graph.opt
import
node_rewrit
er
,
optimizer
from
aesara.graph.opt_utils
import
optimize_graph
from
aesara.graph.optdb
import
OptimizationQuery
...
...
aesara/graph/kanren.py
浏览文件 @
550a6e98
...
...
@@ -6,11 +6,11 @@ from unification import var
from
unification.variable
import
Var
from
aesara.graph.basic
import
Apply
,
Variable
from
aesara.graph.opt
import
LocalOptimiz
er
from
aesara.graph.opt
import
NodeRewrit
er
from
aesara.graph.unify
import
eval_if_etuple
class
KanrenRelationSub
(
LocalOptimiz
er
):
class
KanrenRelationSub
(
NodeRewrit
er
):
r"""A local optimizer that uses `kanren` to match and replace terms.
See `kanren <https://github.com/pythological/kanren>`__ for more information
...
...
aesara/graph/opt.py
浏览文件 @
550a6e98
差异被折叠。
点击展开。
aesara/graph/optdb.py
浏览文件 @
550a6e98
...
...
@@ -11,14 +11,14 @@ from aesara.misc.ordered_set import OrderedSet
from
aesara.utils
import
DefaultOrderedDict
OptimizersType
=
Union
[
aesara_opt
.
GraphRewriter
,
aesara_opt
.
LocalOptimiz
er
]
OptimizersType
=
Union
[
aesara_opt
.
GraphRewriter
,
aesara_opt
.
NodeRewrit
er
]
class
OptimizationDatabase
:
r"""A class that represents a collection/database of optimizations.
These databases are used to logically organize collections of optimizers
(i.e. `GraphRewriter`\s and `
LocalOptimiz
er`).
(i.e. `GraphRewriter`\s and `
NodeRewrit
er`).
"""
def
__init__
(
self
):
...
...
@@ -62,7 +62,7 @@ class OptimizationDatabase:
(
OptimizationDatabase
,
aesara_opt
.
GraphRewriter
,
aesara_opt
.
LocalOptimiz
er
,
aesara_opt
.
NodeRewrit
er
,
),
):
raise
TypeError
(
f
"{optimizer} is not a valid optimizer type."
)
...
...
@@ -311,7 +311,7 @@ class EquilibriumDB(OptimizationDatabase):
Notes
-----
We can use `
LocalOptimiz
er` and `GraphRewriter` since `EquilibriumOptimizer`
We can use `
NodeRewrit
er` and `GraphRewriter` since `EquilibriumOptimizer`
supports both.
It is probably not a good idea to have ignore_newtrees=False and
...
...
@@ -474,24 +474,18 @@ class SequenceDB(OptimizationDatabase):
class
LocalGroupDB
(
SequenceDB
):
"""
Generate a local optimizer of type LocalOptGroup instead
of a global optimizer.
It supports the tracks, to only get applied to some Op.
"""
r"""A database that generates `NodeRewriter`\s of type `LocalOptGroup`."""
def
__init__
(
self
,
apply_all_opts
:
bool
=
False
,
profile
:
bool
=
False
,
local_opt
=
aesara_opt
.
LocalOptGroup
,
node_rewriter
=
aesara_opt
.
LocalOptGroup
,
):
super
()
.
__init__
(
failure_callback
=
None
)
self
.
apply_all_opts
=
apply_all_opts
self
.
profile
=
profile
self
.
local_opt
=
local_opt
self
.
node_rewriter
=
node_rewriter
self
.
__name__
:
str
=
""
def
register
(
self
,
name
,
obj
,
*
tags
,
position
=
"last"
,
**
kwargs
):
...
...
@@ -499,7 +493,7 @@ class LocalGroupDB(SequenceDB):
def
query
(
self
,
*
tags
,
**
kwtags
):
opts
=
list
(
super
()
.
query
(
*
tags
,
**
kwtags
))
ret
=
self
.
local_opt
(
ret
=
self
.
node_rewriter
(
*
opts
,
apply_all_opts
=
self
.
apply_all_opts
,
profile
=
self
.
profile
)
return
ret
...
...
aesara/ifelse.py
浏览文件 @
550a6e98
...
...
@@ -22,7 +22,7 @@ from aesara.compile import optdb
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Apply
,
Variable
,
clone_replace
,
is_in_ancestors
from
aesara.graph.op
import
_NoPythonOp
from
aesara.graph.opt
import
GraphRewriter
,
in2out
,
local_optimiz
er
from
aesara.graph.opt
import
GraphRewriter
,
in2out
,
node_rewrit
er
from
aesara.graph.type
import
HasDataType
,
HasShape
from
aesara.tensor.shape
import
Reshape
,
Shape
,
SpecifyShape
,
Unbroadcast
...
...
@@ -404,7 +404,7 @@ def ifelse(
return
tuple
(
rval
)
@
local_optimiz
er
([
IfElse
])
@
node_rewrit
er
([
IfElse
])
def
cond_make_inplace
(
fgraph
,
node
):
op
=
node
.
op
if
(
...
...
@@ -482,7 +482,7 @@ acceptable_ops = (
)
@
local_optimiz
er
(
acceptable_ops
)
@
node_rewrit
er
(
acceptable_ops
)
def
ifelse_lift_single_if_through_acceptable_ops
(
fgraph
,
main_node
):
"""This optimization lifts up certain ifelse instances.
...
...
@@ -529,7 +529,7 @@ def ifelse_lift_single_if_through_acceptable_ops(fgraph, main_node):
return
nw_outs
@
local_optimiz
er
([
IfElse
])
@
node_rewrit
er
([
IfElse
])
def
cond_merge_ifs_true
(
fgraph
,
node
):
op
=
node
.
op
if
not
isinstance
(
op
,
IfElse
):
...
...
@@ -556,7 +556,7 @@ def cond_merge_ifs_true(fgraph, node):
return
op
(
*
old_ins
,
return_list
=
True
)
@
local_optimiz
er
([
IfElse
])
@
node_rewrit
er
([
IfElse
])
def
cond_merge_ifs_false
(
fgraph
,
node
):
op
=
node
.
op
if
not
isinstance
(
op
,
IfElse
):
...
...
@@ -635,7 +635,7 @@ class CondMerge(GraphRewriter):
fgraph
.
replace_all_validate
(
pairs
,
reason
=
"cond_merge"
)
@
local_optimiz
er
([
IfElse
])
@
node_rewrit
er
([
IfElse
])
def
cond_remove_identical
(
fgraph
,
node
):
op
=
node
.
op
...
...
@@ -681,7 +681,7 @@ def cond_remove_identical(fgraph, node):
return
rval
@
local_optimiz
er
([
IfElse
])
@
node_rewrit
er
([
IfElse
])
def
cond_merge_random_op
(
fgraph
,
main_node
):
if
isinstance
(
main_node
.
op
,
IfElse
):
return
False
...
...
aesara/sandbox/linalg/ops.py
浏览文件 @
550a6e98
import
logging
from
aesara.graph.opt
import
local_optimiz
er
from
aesara.graph.opt
import
node_rewrit
er
from
aesara.tensor
import
basic
as
at
from
aesara.tensor.basic_opt
import
(
register_canonicalize
,
...
...
@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
@register_canonicalize
@
local_optimiz
er
([
DimShuffle
])
@
node_rewrit
er
([
DimShuffle
])
def
transinv_to_invtrans
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
DimShuffle
):
if
node
.
op
.
new_order
==
(
1
,
0
):
...
...
@@ -32,7 +32,7 @@ def transinv_to_invtrans(fgraph, node):
@register_stabilize
@
local_optimiz
er
([
Dot
,
Dot22
])
@
node_rewrit
er
([
Dot
,
Dot22
])
def
inv_as_solve
(
fgraph
,
node
):
"""
This utilizes a boolean `symmetric` tag on the matrices.
...
...
@@ -51,7 +51,7 @@ def inv_as_solve(fgraph, node):
@register_stabilize
@register_canonicalize
@
local_optimiz
er
([
Solve
])
@
node_rewrit
er
([
Solve
])
def
tag_solve_triangular
(
fgraph
,
node
):
"""
If a general solve() is applied to the output of a cholesky op, then
...
...
@@ -82,7 +82,7 @@ def tag_solve_triangular(fgraph, node):
@register_canonicalize
@register_stabilize
@register_specialize
@
local_optimiz
er
([
DimShuffle
])
@
node_rewrit
er
([
DimShuffle
])
def
no_transpose_symmetric
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
DimShuffle
):
x
=
node
.
inputs
[
0
]
...
...
@@ -92,7 +92,7 @@ def no_transpose_symmetric(fgraph, node):
@register_stabilize
@
local_optimiz
er
([
Solve
])
@
node_rewrit
er
([
Solve
])
def
psd_solve_with_chol
(
fgraph
,
node
):
"""
This utilizes a boolean `psd` tag on matrices.
...
...
@@ -111,7 +111,7 @@ def psd_solve_with_chol(fgraph, node):
@register_stabilize
@register_specialize
@
local_optimiz
er
([
Det
])
@
node_rewrit
er
([
Det
])
def
local_det_chol
(
fgraph
,
node
):
"""
If we have det(X) and there is already an L=cholesky(X)
...
...
@@ -129,7 +129,7 @@ def local_det_chol(fgraph, node):
@register_canonicalize
@register_stabilize
@register_specialize
@
local_optimiz
er
([
log
])
@
node_rewrit
er
([
log
])
def
local_log_prod_sqr
(
fgraph
,
node
):
"""
This utilizes a boolean `positive` tag on matrices.
...
...
aesara/sandbox/rng_mrg.py
浏览文件 @
550a6e98
...
...
@@ -25,7 +25,7 @@ from aesara.compile import optdb
from
aesara.configdefaults
import
config
from
aesara.gradient
import
undefined_grad
from
aesara.graph.basic
import
Apply
,
Constant
,
Variable
from
aesara.graph.opt
import
in2out
,
local_optimiz
er
from
aesara.graph.opt
import
in2out
,
node_rewrit
er
from
aesara.link.c.op
import
COp
,
Op
from
aesara.link.c.params_type
import
ParamsType
from
aesara.sandbox
import
multinomial
...
...
@@ -1343,7 +1343,7 @@ def _check_size(size):
return
at
.
as_tensor_variable
(
size
,
ndim
=
1
)
@
local_optimiz
er
((
mrg_uniform_base
,))
@
node_rewrit
er
((
mrg_uniform_base
,))
def
mrg_random_make_inplace
(
fgraph
,
node
):
op
=
node
.
op
...
...
aesara/scan/opt.py
浏览文件 @
550a6e98
...
...
@@ -28,7 +28,7 @@ from aesara.graph.destroyhandler import DestroyHandler
from
aesara.graph.features
import
ReplaceValidate
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
compute_test_value
from
aesara.graph.opt
import
GraphRewriter
,
in2out
,
local_optimiz
er
from
aesara.graph.opt
import
GraphRewriter
,
in2out
,
node_rewrit
er
from
aesara.graph.optdb
import
EquilibriumDB
,
SequenceDB
from
aesara.graph.type
import
HasShape
from
aesara.graph.utils
import
InconsistencyError
...
...
@@ -67,7 +67,7 @@ list_opt_slice = [
]
@
local_optimiz
er
([
Scan
])
@
node_rewrit
er
([
Scan
])
def
remove_constants_and_unused_inputs_scan
(
fgraph
,
node
):
"""Move constants into the inner graph, and remove unused inputs.
...
...
@@ -192,7 +192,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
return
False
@
local_optimiz
er
([
Scan
])
@
node_rewrit
er
([
Scan
])
def
push_out_non_seq_scan
(
fgraph
,
node
):
r"""Push out the variables inside the `Scan` that depend only on non-sequences.
...
...
@@ -400,7 +400,7 @@ def push_out_non_seq_scan(fgraph, node):
return
False
@
local_optimiz
er
([
Scan
])
@
node_rewrit
er
([
Scan
])
def
push_out_seq_scan
(
fgraph
,
node
):
r"""Push out the variables inside the `Scan` that depend only on constants and sequences.
...
...
@@ -812,7 +812,7 @@ def add_nitsot_outputs(
return
new_scan_node
,
{}
@
local_optimiz
er
([
Scan
])
@
node_rewrit
er
([
Scan
])
def
push_out_add_scan
(
fgraph
,
node
):
r"""Push `Add` operations performed at the end of the inner graph to the outside.
...
...
@@ -1113,7 +1113,7 @@ def sanitize(x):
return
at
.
as_tensor_variable
(
x
)
@
local_optimiz
er
([
Scan
])
@
node_rewrit
er
([
Scan
])
def
save_mem_new_scan
(
fgraph
,
node
):
r"""Graph optimizer that reduces scan memory consumption.
...
...
@@ -1950,7 +1950,7 @@ def make_equiv(lo, li):
return
left
,
right
@
local_optimiz
er
([
Scan
])
@
node_rewrit
er
([
Scan
])
def
scan_merge_inouts
(
fgraph
,
node
):
"""
This optimization attempts to merge a `Scan` `Op`'s identical outer inputs as well
...
...
@@ -2154,7 +2154,7 @@ def scan_merge_inouts(fgraph, node):
return
na
.
outer_outputs
@
local_optimiz
er
([
Scan
])
@
node_rewrit
er
([
Scan
])
def
push_out_dot1_scan
(
fgraph
,
node
):
r"""
This is another optimization that attempts to detect certain patterns of
...
...
aesara/sparse/opt.py
浏览文件 @
550a6e98
...
...
@@ -4,7 +4,7 @@ import aesara
import
aesara.scalar
as
aes
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Apply
from
aesara.graph.opt
import
PatternSub
,
TopoOptimizer
,
local_optimiz
er
from
aesara.graph.opt
import
PatternSub
,
TopoOptimizer
,
node_rewrit
er
from
aesara.link.c.op
import
COp
,
_NoPythonCOp
from
aesara.misc.safe_asarray
import
_asarray
from
aesara.sparse
import
basic
as
sparse
...
...
@@ -32,7 +32,7 @@ _is_dense = sparse._is_dense
# This is tested in tests/test_opt.py:test_local_csm_properties_csm
@
local_optimiz
er
([
csm_properties
])
@
node_rewrit
er
([
csm_properties
])
def
local_csm_properties_csm
(
fgraph
,
node
):
"""
If we find csm_properties(CSM(*args)), then we can replace that with the
...
...
@@ -51,7 +51,7 @@ register_specialize(local_csm_properties_csm)
# This is tested in tests/test_basic.py:test_remove0
@
local_optimiz
er
([
sparse
.
Remove0
])
@
node_rewrit
er
([
sparse
.
Remove0
])
def
local_inplace_remove0
(
fgraph
,
node
):
"""
Optimization to insert inplace versions of Remove0.
...
...
@@ -188,7 +188,7 @@ class AddSD_ccode(_NoPythonCOp):
return
(
2
,)
@
local_optimiz
er
([
sparse
.
AddSD
])
@
node_rewrit
er
([
sparse
.
AddSD
])
def
local_inplace_addsd_ccode
(
fgraph
,
node
):
"""
Optimization to insert inplace versions of AddSD.
...
...
@@ -218,7 +218,7 @@ aesara.compile.optdb.register(
@register_canonicalize
(
"fast_compile"
)
@register_specialize
@
local_optimiz
er
([
sparse
.
DenseFromSparse
])
@
node_rewrit
er
([
sparse
.
DenseFromSparse
])
def
local_dense_from_sparse_sparse_from_dense
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
sparse
.
DenseFromSparse
):
inp
=
node
.
inputs
[
0
]
...
...
@@ -226,7 +226,7 @@ def local_dense_from_sparse_sparse_from_dense(fgraph, node):
return
inp
.
owner
.
inputs
@
local_optimiz
er
([
sparse
.
AddSD
])
@
node_rewrit
er
([
sparse
.
AddSD
])
def
local_addsd_ccode
(
fgraph
,
node
):
"""
Convert AddSD to faster AddSD_ccode.
...
...
@@ -638,7 +638,7 @@ sd_csr = StructuredDotCSR()
# register a specialization to replace StructuredDot -> StructuredDotCSx
# This is tested in tests/test_basic.py:792
@
local_optimiz
er
([
sparse
.
_structured_dot
])
@
node_rewrit
er
([
sparse
.
_structured_dot
])
def
local_structured_dot
(
fgraph
,
node
):
if
node
.
op
==
sparse
.
_structured_dot
:
a
,
b
=
node
.
inputs
...
...
@@ -950,7 +950,7 @@ register_specialize(local_usmm, name="local_usmm")
# register a specialization to replace usmm_csc_dense -> usmm_csc_dense_inplace
# This is tested in tests/test_basic.py:UsmmTests
@
local_optimiz
er
([
usmm_csc_dense
])
@
node_rewrit
er
([
usmm_csc_dense
])
def
local_usmm_csc_dense_inplace
(
fgraph
,
node
):
if
node
.
op
==
usmm_csc_dense
:
return
[
usmm_csc_dense_inplace
(
*
node
.
inputs
)]
...
...
@@ -960,7 +960,7 @@ register_specialize(local_usmm_csc_dense_inplace, "cxx_only", "inplace")
# This is tested in tests/test_basic.py:UsmmTests
@
local_optimiz
er
([
usmm
])
@
node_rewrit
er
([
usmm
])
def
local_usmm_csx
(
fgraph
,
node
):
"""
usmm -> usmm_csc_dense
...
...
@@ -1120,7 +1120,7 @@ csm_grad_c = CSMGradC()
# register a specialization to replace csm_grad -> csm_grad_c
# This is tested in tests/test_opt.py:test_local_csm_grad_c
@
local_optimiz
er
([
csm_grad
(
None
)])
@
node_rewrit
er
([
csm_grad
(
None
)])
def
local_csm_grad_c
(
fgraph
,
node
):
"""
csm_grad(None) -> csm_grad_c
...
...
@@ -1404,7 +1404,7 @@ mul_s_d_csr = MulSDCSR()
# register a specialization to replace MulSD -> MulSDCSX
@
local_optimiz
er
([
sparse
.
mul_s_d
])
@
node_rewrit
er
([
sparse
.
mul_s_d
])
def
local_mul_s_d
(
fgraph
,
node
):
if
node
.
op
==
sparse
.
mul_s_d
:
x
,
y
=
node
.
inputs
...
...
@@ -1584,7 +1584,7 @@ mul_s_v_csr = MulSVCSR()
# register a specialization to replace MulSV -> MulSVCSR
@
local_optimiz
er
([
sparse
.
mul_s_v
])
@
node_rewrit
er
([
sparse
.
mul_s_v
])
def
local_mul_s_v
(
fgraph
,
node
):
if
node
.
op
==
sparse
.
mul_s_v
:
x
,
y
=
node
.
inputs
...
...
@@ -1762,7 +1762,7 @@ structured_add_s_v_csr = StructuredAddSVCSR()
# register a specialization to replace
# structured_add_s_v -> structured_add_s_v_csr
@
local_optimiz
er
([
sparse
.
structured_add_s_v
])
@
node_rewrit
er
([
sparse
.
structured_add_s_v
])
def
local_structured_add_s_v
(
fgraph
,
node
):
if
node
.
op
==
sparse
.
structured_add_s_v
:
x
,
y
=
node
.
inputs
...
...
@@ -2051,7 +2051,7 @@ sampling_dot_csr = SamplingDotCSR()
# register a specialization to replace SamplingDot -> SamplingDotCsr
@
local_optimiz
er
([
sparse
.
sampling_dot
])
@
node_rewrit
er
([
sparse
.
sampling_dot
])
def
local_sampling_dot_csr
(
fgraph
,
node
):
if
not
config
.
blas__ldflags
:
# The C implementation of SamplingDotCsr relies on BLAS routines
...
...
aesara/tensor/basic_opt.py
浏览文件 @
550a6e98
差异被折叠。
点击展开。
aesara/tensor/blas.py
浏览文件 @
550a6e98
...
...
@@ -150,7 +150,7 @@ from aesara.graph.opt import (
GraphRewriter
,
copy_stack_trace
,
in2out
,
local_optimiz
er
,
node_rewrit
er
,
)
from
aesara.graph.optdb
import
SequenceDB
from
aesara.graph.utils
import
InconsistencyError
,
MethodNotDefined
,
TestValueError
...
...
@@ -1733,7 +1733,7 @@ class Dot22(GemmRelated):
_dot22
=
Dot22
()
@
local_optimiz
er
([
Dot
])
@
node_rewrit
er
([
Dot
])
def
local_dot_to_dot22
(
fgraph
,
node
):
# This works for tensor.outer too because basic.outer is a macro that
# produces a dot(dimshuffle,dimshuffle) of form 4 below
...
...
@@ -1766,7 +1766,7 @@ def local_dot_to_dot22(fgraph, node):
_logger
.
info
(
f
"Not optimizing dot with inputs {x} {y} {x.type} {y.type}"
)
@
local_optimiz
er
([
gemm_no_inplace
],
inplace
=
True
)
@
node_rewrit
er
([
gemm_no_inplace
],
inplace
=
True
)
def
local_inplace_gemm
(
fgraph
,
node
):
if
node
.
op
==
gemm_no_inplace
:
new_out
=
[
gemm_inplace
(
*
node
.
inputs
)]
...
...
@@ -1774,7 +1774,7 @@ def local_inplace_gemm(fgraph, node):
return
new_out
@
local_optimiz
er
([
gemv_no_inplace
],
inplace
=
True
)
@
node_rewrit
er
([
gemv_no_inplace
],
inplace
=
True
)
def
local_inplace_gemv
(
fgraph
,
node
):
if
node
.
op
==
gemv_no_inplace
:
new_out
=
[
gemv_inplace
(
*
node
.
inputs
)]
...
...
@@ -1782,7 +1782,7 @@ def local_inplace_gemv(fgraph, node):
return
new_out
@
local_optimiz
er
([
ger
],
inplace
=
True
)
@
node_rewrit
er
([
ger
],
inplace
=
True
)
def
local_inplace_ger
(
fgraph
,
node
):
if
node
.
op
==
ger
:
new_out
=
[
ger_destructive
(
*
node
.
inputs
)]
...
...
@@ -1790,7 +1790,7 @@ def local_inplace_ger(fgraph, node):
return
new_out
@
local_optimiz
er
([
gemm_no_inplace
])
@
node_rewrit
er
([
gemm_no_inplace
])
def
local_gemm_to_gemv
(
fgraph
,
node
):
"""GEMM acting on row or column matrices -> GEMV."""
if
node
.
op
==
gemm_no_inplace
:
...
...
@@ -1807,7 +1807,7 @@ def local_gemm_to_gemv(fgraph, node):
return
new_out
@
local_optimiz
er
([
gemm_no_inplace
])
@
node_rewrit
er
([
gemm_no_inplace
])
def
local_gemm_to_ger
(
fgraph
,
node
):
"""GEMM computing an outer-product -> GER."""
if
node
.
op
==
gemm_no_inplace
:
...
...
@@ -1839,7 +1839,7 @@ def local_gemm_to_ger(fgraph, node):
# TODO: delete this optimization when we have the proper dot->gemm->ger pipeline
# working
@
local_optimiz
er
([
_dot22
])
@
node_rewrit
er
([
_dot22
])
def
local_dot22_to_ger_or_gemv
(
fgraph
,
node
):
"""dot22 computing an outer-product -> GER."""
if
node
.
op
==
_dot22
:
...
...
@@ -2033,7 +2033,7 @@ class Dot22Scalar(GemmRelated):
_dot22scalar
=
Dot22Scalar
()
@
local_optimiz
er
([
mul
])
@
node_rewrit
er
([
mul
])
def
local_dot22_to_dot22scalar
(
fgraph
,
node
):
"""
Notes
...
...
@@ -2651,7 +2651,7 @@ _batched_dot = BatchedDot()
# from opt import register_specialize, register_canonicalize
# @register_specialize
@
local_optimiz
er
([
sub
,
add
])
@
node_rewrit
er
([
sub
,
add
])
def
local_print_as_we_go_along
(
fgraph
,
node
):
if
node
.
op
in
(
sub
,
add
):
debugprint
(
node
)
...
...
aesara/tensor/blas_c.py
浏览文件 @
550a6e98
...
...
@@ -15,7 +15,7 @@ from aesara.tensor.blas import (
ger
,
ger_destructive
,
ldflags
,
local_optimiz
er
,
node_rewrit
er
,
optdb
,
)
...
...
@@ -344,7 +344,7 @@ cger_inplace = CGer(True)
cger_no_inplace
=
CGer
(
False
)
@
local_optimiz
er
([
ger
,
ger_destructive
])
@
node_rewrit
er
([
ger
,
ger_destructive
])
def
use_c_ger
(
fgraph
,
node
):
if
not
config
.
blas__ldflags
:
return
...
...
@@ -355,7 +355,7 @@ def use_c_ger(fgraph, node):
return
[
CGer
(
True
)(
*
node
.
inputs
)]
@
local_optimiz
er
([
CGer
(
False
)])
@
node_rewrit
er
([
CGer
(
False
)])
def
make_c_ger_destructive
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
CGer
)
and
not
node
.
op
.
destructive
:
return
[
cger_inplace
(
*
node
.
inputs
)]
...
...
@@ -699,7 +699,7 @@ int main() {
check_force_gemv_init
.
_force_init_beta
=
None
@
local_optimiz
er
([
gemv_inplace
,
gemv_no_inplace
])
@
node_rewrit
er
([
gemv_inplace
,
gemv_no_inplace
])
def
use_c_gemv
(
fgraph
,
node
):
if
not
config
.
blas__ldflags
:
return
...
...
@@ -710,7 +710,7 @@ def use_c_gemv(fgraph, node):
return
[
cgemv_inplace
(
*
node
.
inputs
)]
@
local_optimiz
er
([
CGemv
(
inplace
=
False
)])
@
node_rewrit
er
([
CGemv
(
inplace
=
False
)])
def
make_c_gemv_destructive
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
CGemv
)
and
not
node
.
op
.
inplace
:
inputs
=
list
(
node
.
inputs
)
...
...
aesara/tensor/blas_scipy.py
浏览文件 @
550a6e98
...
...
@@ -11,7 +11,7 @@ from aesara.tensor.blas import (
ger
,
ger_destructive
,
have_fblas
,
local_optimiz
er
,
node_rewrit
er
,
optdb
,
)
...
...
@@ -58,13 +58,13 @@ scipy_ger_no_inplace = ScipyGer(False)
scipy_ger_inplace
=
ScipyGer
(
True
)
@
local_optimiz
er
([
ger
,
ger_destructive
])
@
node_rewrit
er
([
ger
,
ger_destructive
])
def
use_scipy_ger
(
fgraph
,
node
):
if
node
.
op
==
ger
:
return
[
scipy_ger_no_inplace
(
*
node
.
inputs
)]
@
local_optimiz
er
([
scipy_ger_no_inplace
])
@
node_rewrit
er
([
scipy_ger_no_inplace
])
def
make_ger_destructive
(
fgraph
,
node
):
if
node
.
op
==
scipy_ger_no_inplace
:
return
[
scipy_ger_inplace
(
*
node
.
inputs
)]
...
...
aesara/tensor/math_opt.py
浏览文件 @
550a6e98
差异被折叠。
点击展开。
aesara/tensor/nnet/basic.py
浏览文件 @
550a6e98
...
...
@@ -18,7 +18,7 @@ from aesara.compile import optdb
from
aesara.gradient
import
DisconnectedType
,
grad_not_implemented
from
aesara.graph.basic
import
Apply
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
copy_stack_trace
,
local_optimiz
er
,
optimizer
from
aesara.graph.opt
import
copy_stack_trace
,
node_rewrit
er
,
optimizer
from
aesara.link.c.op
import
COp
from
aesara.raise_op
import
Assert
from
aesara.scalar
import
UnaryScalarOp
...
...
@@ -1046,7 +1046,7 @@ class LogSoftmax(COp):
# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
@register_specialize
(
"stabilize"
,
"fast_compile"
)
@
local_optimiz
er
([
Elemwise
])
@
node_rewrit
er
([
Elemwise
])
def
local_logsoftmax
(
fgraph
,
node
):
"""
Detect Log(Softmax(x)) and replace it with LogSoftmax(x)
...
...
@@ -1071,7 +1071,7 @@ def local_logsoftmax(fgraph, node):
# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
@register_specialize
(
"stabilize"
,
"fast_compile"
)
@
local_optimiz
er
([
SoftmaxGrad
])
@
node_rewrit
er
([
SoftmaxGrad
])
def
local_logsoftmax_grad
(
fgraph
,
node
):
"""
Detect Log(Softmax(x))'s grad and replace it with LogSoftmax(x)'s grad
...
...
@@ -1150,7 +1150,7 @@ def logsoftmax(c, axis=UNSET_AXIS):
@register_specialize
(
"fast_compile"
)
@
local_optimiz
er
([
softmax_legacy
])
@
node_rewrit
er
([
softmax_legacy
])
def
local_softmax_with_bias
(
fgraph
,
node
):
"""
Try to turn softmax(sum_of_stuff) -> softmax_w_bias(matrix, bias).
...
...
@@ -1954,7 +1954,7 @@ optdb.register(
@register_specialize
(
"fast_compile"
,
"local_crossentropy_to_crossentropy_with_softmax_grad"
)
# old name
@
local_optimiz
er
([
softmax_grad_legacy
])
@
node_rewrit
er
([
softmax_grad_legacy
])
def
local_softmax_grad_to_crossentropy_with_softmax_grad
(
fgraph
,
node
):
if
node
.
op
==
softmax_grad_legacy
and
node
.
inputs
[
1
]
.
ndim
==
2
:
g_coding_dist
,
coding_dist
=
node
.
inputs
...
...
@@ -1971,7 +1971,7 @@ def local_softmax_grad_to_crossentropy_with_softmax_grad(fgraph, node):
@register_specialize
(
"fast_compile"
)
@
local_optimiz
er
([
MaxAndArgmax
])
@
node_rewrit
er
([
MaxAndArgmax
])
def
local_argmax_pushdown
(
fgraph
,
node
):
if
(
isinstance
(
node
.
op
,
MaxAndArgmax
)
...
...
@@ -2060,7 +2060,7 @@ def _is_const(z, val, approx=False):
@register_specialize
(
"fast_compile"
)
@
local_optimiz
er
([
AdvancedSubtensor
,
log
])
@
node_rewrit
er
([
AdvancedSubtensor
,
log
])
def
local_advanced_indexing_crossentropy_onehot
(
fgraph
,
node
):
log_op
=
None
sm
=
None
...
...
@@ -2108,7 +2108,7 @@ def local_advanced_indexing_crossentropy_onehot(fgraph, node):
@register_specialize
(
"fast_compile"
)
@
local_optimiz
er
([
softmax_grad_legacy
])
@
node_rewrit
er
([
softmax_grad_legacy
])
def
local_advanced_indexing_crossentropy_onehot_grad
(
fgraph
,
node
):
if
not
(
node
.
op
==
softmax_grad_legacy
and
node
.
inputs
[
1
]
.
ndim
==
2
):
return
...
...
@@ -2323,7 +2323,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(fgraph, node):
@register_specialize
(
"fast_compile"
)
@
local_optimiz
er
([
softmax_with_bias
])
@
node_rewrit
er
([
softmax_with_bias
])
def
graph_merge_softmax_with_crossentropy_softmax
(
fgraph
,
node
):
if
node
.
op
==
softmax_with_bias
:
x
,
b
=
node
.
inputs
...
...
@@ -2340,7 +2340,7 @@ def graph_merge_softmax_with_crossentropy_softmax(fgraph, node):
@register_specialize
@register_stabilize
@register_canonicalize
@
local_optimiz
er
([
CrossentropySoftmax1HotWithBiasDx
])
@
node_rewrit
er
([
CrossentropySoftmax1HotWithBiasDx
])
def
local_useless_crossentropy_softmax_1hot_with_bias_dx_alloc
(
fgraph
,
node
):
"""
Replace a CrossentropySoftmax1HotWithBiasDx op, whose incoming gradient is
...
...
aesara/tensor/nnet/batchnorm.py
浏览文件 @
550a6e98
...
...
@@ -4,7 +4,7 @@ import aesara
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Apply
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
copy_stack_trace
,
local_optimiz
er
from
aesara.graph.opt
import
copy_stack_trace
,
node_rewrit
er
from
aesara.scalar
import
Composite
,
add
,
as_common_dtype
,
mul
,
sub
,
true_div
from
aesara.tensor
import
basic
as
at
from
aesara.tensor.basic
import
as_tensor_variable
...
...
@@ -778,7 +778,7 @@ class AbstractBatchNormTrainGrad(Op):
output_storage
[
2
][
0
]
=
g_wrt_bias
@
local_optimiz
er
([
AbstractBatchNormTrain
])
@
node_rewrit
er
([
AbstractBatchNormTrain
])
def
local_abstract_batch_norm_train
(
fgraph
,
node
):
if
not
isinstance
(
node
.
op
,
AbstractBatchNormTrain
):
return
None
...
...
@@ -832,7 +832,7 @@ def local_abstract_batch_norm_train(fgraph, node):
return
results
@
local_optimiz
er
([
AbstractBatchNormTrainGrad
])
@
node_rewrit
er
([
AbstractBatchNormTrainGrad
])
def
local_abstract_batch_norm_train_grad
(
fgraph
,
node
):
if
not
isinstance
(
node
.
op
,
AbstractBatchNormTrainGrad
):
return
None
...
...
@@ -866,7 +866,7 @@ def local_abstract_batch_norm_train_grad(fgraph, node):
return
results
@
local_optimiz
er
([
AbstractBatchNormInference
])
@
node_rewrit
er
([
AbstractBatchNormInference
])
def
local_abstract_batch_norm_inference
(
fgraph
,
node
):
if
not
isinstance
(
node
.
op
,
AbstractBatchNormInference
):
return
None
...
...
aesara/tensor/nnet/conv3d2d.py
浏览文件 @
550a6e98
...
...
@@ -3,7 +3,7 @@ from aesara import tensor as at
from
aesara.gradient
import
DisconnectedType
from
aesara.graph.basic
import
Apply
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
TopoOptimizer
,
copy_stack_trace
,
local_optimiz
er
from
aesara.graph.opt
import
TopoOptimizer
,
copy_stack_trace
,
node_rewrit
er
def
get_diagonal_subtensor_view
(
x
,
i0
,
i1
):
...
...
@@ -296,7 +296,7 @@ def conv3d(
return
out_5d
@
local_optimiz
er
([
DiagonalSubtensor
,
IncDiagonalSubtensor
])
@
node_rewrit
er
([
DiagonalSubtensor
,
IncDiagonalSubtensor
])
def
local_inplace_DiagonalSubtensor
(
fgraph
,
node
):
"""Also work for IncDiagonalSubtensor."""
if
(
...
...
aesara/tensor/nnet/ctc.py
浏览文件 @
550a6e98
...
...
@@ -5,7 +5,7 @@ import aesara.tensor as at
from
aesara.configdefaults
import
config
from
aesara.gradient
import
grad_undefined
from
aesara.graph.basic
import
Apply
from
aesara.graph.opt
import
local_optimiz
er
from
aesara.graph.opt
import
node_rewrit
er
from
aesara.link.c.cmodule
import
GCC_compiler
from
aesara.link.c.op
import
ExternalCOp
,
OpenMPOp
from
aesara.tensor.basic_opt
import
register_canonicalize
...
...
@@ -249,7 +249,7 @@ def ctc(activations, labels, input_lengths):
# Disable gradient computation if not needed
@register_canonicalize
(
"fast_compile"
)
@
local_optimiz
er
([
ConnectionistTemporalClassification
])
@
node_rewrit
er
([
ConnectionistTemporalClassification
])
def
local_ctc_no_grad
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
ConnectionistTemporalClassification
):
if
len
(
node
.
outputs
)
>
1
:
...
...
aesara/tensor/nnet/opt.py
浏览文件 @
550a6e98
...
...
@@ -11,7 +11,7 @@ from aesara.graph.opt import (
TopoOptimizer
,
copy_stack_trace
,
in2out
,
local_optimiz
er
,
node_rewrit
er
,
)
from
aesara.tensor.basic_opt
import
register_specialize_device
from
aesara.tensor.nnet.abstract_conv
import
(
...
...
@@ -37,7 +37,7 @@ from aesara.tensor.nnet.corr3d import Corr3dMM, Corr3dMMGradInputs, Corr3dMMGrad
from
aesara.tensor.type
import
TensorType
@
local_optimiz
er
([
SparseBlockGemv
],
inplace
=
True
)
@
node_rewrit
er
([
SparseBlockGemv
],
inplace
=
True
)
def
local_inplace_sparse_block_gemv
(
fgraph
,
node
):
"""
SparseBlockGemv(inplace=False) -> SparseBlockGemv(inplace=True)
...
...
@@ -60,7 +60,7 @@ compile.optdb.register(
)
# DEBUG
@
local_optimiz
er
([
SparseBlockOuter
],
inplace
=
True
)
@
node_rewrit
er
([
SparseBlockOuter
],
inplace
=
True
)
def
local_inplace_sparse_block_outer
(
fgraph
,
node
):
"""
SparseBlockOuter(inplace=False) -> SparseBlockOuter(inplace=True)
...
...
@@ -85,7 +85,7 @@ compile.optdb.register(
# Conv opts
@
local_optimiz
er
([
AbstractConv2d
])
@
node_rewrit
er
([
AbstractConv2d
])
def
local_abstractconv_gemm
(
fgraph
,
node
):
# If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_.
...
...
@@ -113,7 +113,7 @@ def local_abstractconv_gemm(fgraph, node):
return
[
rval
]
@
local_optimiz
er
([
AbstractConv3d
])
@
node_rewrit
er
([
AbstractConv3d
])
def
local_abstractconv3d_gemm
(
fgraph
,
node
):
# If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_.
...
...
@@ -139,7 +139,7 @@ def local_abstractconv3d_gemm(fgraph, node):
return
[
rval
]
@
local_optimiz
er
([
AbstractConv2d_gradWeights
])
@
node_rewrit
er
([
AbstractConv2d_gradWeights
])
def
local_abstractconv_gradweight_gemm
(
fgraph
,
node
):
# If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_.
...
...
@@ -169,7 +169,7 @@ def local_abstractconv_gradweight_gemm(fgraph, node):
return
[
rval
]
@
local_optimiz
er
([
AbstractConv3d_gradWeights
])
@
node_rewrit
er
([
AbstractConv3d_gradWeights
])
def
local_abstractconv3d_gradweight_gemm
(
fgraph
,
node
):
# If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_.
...
...
@@ -197,7 +197,7 @@ def local_abstractconv3d_gradweight_gemm(fgraph, node):
return
[
rval
]
@
local_optimiz
er
([
AbstractConv2d_gradInputs
])
@
node_rewrit
er
([
AbstractConv2d_gradInputs
])
def
local_abstractconv_gradinputs_gemm
(
fgraph
,
node
):
# If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_.
...
...
@@ -227,7 +227,7 @@ def local_abstractconv_gradinputs_gemm(fgraph, node):
return
[
rval
]
@
local_optimiz
er
([
AbstractConv3d_gradInputs
])
@
node_rewrit
er
([
AbstractConv3d_gradInputs
])
def
local_abstractconv3d_gradinputs_gemm
(
fgraph
,
node
):
# If config.blas__ldflags is empty, Aesara will use
# a NumPy C implementation of [sd]gemm_.
...
...
@@ -255,7 +255,7 @@ def local_abstractconv3d_gradinputs_gemm(fgraph, node):
return
[
rval
]
@
local_optimiz
er
([
AbstractConv2d
])
@
node_rewrit
er
([
AbstractConv2d
])
def
local_conv2d_cpu
(
fgraph
,
node
):
if
not
isinstance
(
node
.
op
,
AbstractConv2d
)
or
node
.
inputs
[
0
]
.
dtype
==
"float16"
:
...
...
@@ -287,7 +287,7 @@ def local_conv2d_cpu(fgraph, node):
return
[
rval
]
@
local_optimiz
er
([
AbstractConv2d_gradWeights
])
@
node_rewrit
er
([
AbstractConv2d_gradWeights
])
def
local_conv2d_gradweight_cpu
(
fgraph
,
node
):
if
(
not
isinstance
(
node
.
op
,
AbstractConv2d_gradWeights
)
...
...
@@ -396,7 +396,7 @@ def local_conv2d_gradweight_cpu(fgraph, node):
return
[
res
]
@
local_optimiz
er
([
AbstractConv2d_gradInputs
])
@
node_rewrit
er
([
AbstractConv2d_gradInputs
])
def
local_conv2d_gradinputs_cpu
(
fgraph
,
node
):
if
(
not
isinstance
(
node
.
op
,
AbstractConv2d_gradInputs
)
...
...
@@ -561,7 +561,7 @@ conv_groupopt.register(
# Verify that no AbstractConv are present in the graph
@
local_optimiz
er
(
@
node_rewrit
er
(
[
AbstractConv2d
,
AbstractConv2d_gradWeights
,
...
...
aesara/tensor/nnet/sigm.py
浏览文件 @
550a6e98
...
...
@@ -9,7 +9,7 @@ stability.
import
aesara
from
aesara
import
printing
from
aesara
import
scalar
as
aes
from
aesara.graph.opt
import
copy_stack_trace
,
local_optimiz
er
from
aesara.graph.opt
import
copy_stack_trace
,
node_rewrit
er
from
aesara.printing
import
pprint
from
aesara.scalar
import
sigmoid
as
scalar_sigmoid
from
aesara.scalar.math
import
Sigmoid
...
...
@@ -99,7 +99,7 @@ pprint.assign(ultra_fast_sigmoid, printing.FunctionPrinter(["ultra_fast_sigmoid"
# @opt.register_uncanonicalize
@
local_optimiz
er
(
None
)
@
node_rewrit
er
(
None
)
def
local_ultra_fast_sigmoid
(
fgraph
,
node
):
"""
When enabled, change all sigmoid to ultra_fast_sigmoid.
...
...
@@ -159,7 +159,7 @@ def hard_sigmoid(x):
# @opt.register_uncanonicalize
@
local_optimiz
er
([
sigmoid
])
@
node_rewrit
er
([
sigmoid
])
def
local_hard_sigmoid
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
Elemwise
)
and
node
.
op
.
scalar_op
==
scalar_sigmoid
:
out
=
hard_sigmoid
(
node
.
inputs
[
0
])
...
...
aesara/tensor/opt_uncanonicalize.py
浏览文件 @
550a6e98
...
...
@@ -34,7 +34,7 @@ supposed to be canonical.
import
logging
from
aesara
import
scalar
as
aes
from
aesara.graph.opt
import
copy_stack_trace
,
local_optimiz
er
from
aesara.graph.opt
import
copy_stack_trace
,
node_rewrit
er
from
aesara.tensor.basic
import
Alloc
,
alloc
,
constant
from
aesara.tensor.basic_opt
import
register_uncanonicalize
from
aesara.tensor.elemwise
import
CAReduce
,
DimShuffle
...
...
@@ -47,7 +47,7 @@ _logger = logging.getLogger("aesara.tensor.opt_uncanonicalize")
@register_uncanonicalize
@
local_optimiz
er
([
MaxAndArgmax
])
@
node_rewrit
er
([
MaxAndArgmax
])
def
local_max_and_argmax
(
fgraph
,
node
):
"""
If we don't use the argmax, change it to a max only.
...
...
@@ -66,7 +66,7 @@ def local_max_and_argmax(fgraph, node):
@register_uncanonicalize
@
local_optimiz
er
([
neg
])
@
node_rewrit
er
([
neg
])
def
local_max_to_min
(
fgraph
,
node
):
"""
Change -(max(-x)) to min.
...
...
@@ -95,7 +95,7 @@ def local_max_to_min(fgraph, node):
@register_uncanonicalize
@
local_optimiz
er
([
Alloc
])
@
node_rewrit
er
([
Alloc
])
def
local_alloc_dimshuffle
(
fgraph
,
node
):
"""
If a dimshuffle is inside an alloc and only adds dimension to the
...
...
@@ -118,7 +118,7 @@ def local_alloc_dimshuffle(fgraph, node):
@register_uncanonicalize
@
local_optimiz
er
([
Reshape
])
@
node_rewrit
er
([
Reshape
])
def
local_reshape_dimshuffle
(
fgraph
,
node
):
"""
If a dimshuffle is inside a reshape and does not change the order
...
...
@@ -147,7 +147,7 @@ def local_reshape_dimshuffle(fgraph, node):
@register_uncanonicalize
@
local_optimiz
er
([
DimShuffle
])
@
node_rewrit
er
([
DimShuffle
])
def
local_dimshuffle_alloc
(
fgraph
,
node
):
"""
If an alloc is inside a dimshuffle which only adds dimension to the left,
...
...
@@ -175,7 +175,7 @@ def local_dimshuffle_alloc(fgraph, node):
@register_uncanonicalize
@
local_optimiz
er
([
DimShuffle
])
@
node_rewrit
er
([
DimShuffle
])
def
local_dimshuffle_subtensor
(
fgraph
,
node
):
"""If a subtensor is inside a dimshuffle which only drop
broadcastable dimensions, scrap the dimshuffle and index the
...
...
aesara/tensor/random/opt.py
浏览文件 @
550a6e98
from
aesara.compile
import
optdb
from
aesara.configdefaults
import
config
from
aesara.graph.op
import
compute_test_value
from
aesara.graph.opt
import
in2out
,
local_optimiz
er
from
aesara.graph.opt
import
in2out
,
node_rewrit
er
from
aesara.tensor.basic
import
constant
,
get_vector_length
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.extra_ops
import
broadcast_to
...
...
@@ -39,7 +39,7 @@ def is_rv_used_in_graph(base_rv, node, fgraph):
return
not
all
(
_node_check
(
n
,
i
)
for
n
,
i
in
fgraph
.
clients
.
get
(
base_rv
,
()))
@
local_optimiz
er
([
RandomVariable
],
inplace
=
True
)
@
node_rewrit
er
([
RandomVariable
],
inplace
=
True
)
def
random_make_inplace
(
fgraph
,
node
):
op
=
node
.
op
...
...
@@ -61,7 +61,7 @@ optdb.register(
)
@
local_optimiz
er
(
tracks
=
None
)
@
node_rewrit
er
(
tracks
=
None
)
def
local_rv_size_lift
(
fgraph
,
node
):
"""Lift the ``size`` parameter in a ``RandomVariable``.
...
...
@@ -109,7 +109,7 @@ def local_rv_size_lift(fgraph, node):
return
new_node
.
outputs
@
local_optimiz
er
([
DimShuffle
])
@
node_rewrit
er
([
DimShuffle
])
def
local_dimshuffle_rv_lift
(
fgraph
,
node
):
"""Lift a ``DimShuffle`` through ``RandomVariable`` inputs.
...
...
@@ -266,7 +266,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
return
False
@
local_optimiz
er
([
Subtensor
,
AdvancedSubtensor1
,
AdvancedSubtensor
])
@
node_rewrit
er
([
Subtensor
,
AdvancedSubtensor1
,
AdvancedSubtensor
])
def
local_subtensor_rv_lift
(
fgraph
,
node
):
"""Lift a ``*Subtensor`` through ``RandomVariable`` inputs.
...
...
aesara/tensor/subtensor_opt.py
浏览文件 @
550a6e98
...
...
@@ -7,7 +7,7 @@ import aesara
import
aesara.scalar.basic
as
aes
from
aesara
import
compile
from
aesara.graph.basic
import
Constant
,
Variable
from
aesara.graph.opt
import
TopoOptimizer
,
copy_stack_trace
,
in2out
,
local_optimiz
er
from
aesara.graph.opt
import
TopoOptimizer
,
copy_stack_trace
,
in2out
,
node_rewrit
er
from
aesara.raise_op
import
Assert
from
aesara.tensor.basic
import
(
Alloc
,
...
...
@@ -202,7 +202,7 @@ def get_advsubtensor_axis(indices):
@register_specialize
@
local_optimiz
er
([
AdvancedSubtensor
])
@
node_rewrit
er
([
AdvancedSubtensor
])
def
local_replace_AdvancedSubtensor
(
fgraph
,
node
):
r"""
This rewrite converts expressions like ``X[..., y]`` into ``X.T[y].T``, for
...
...
@@ -231,7 +231,7 @@ def local_replace_AdvancedSubtensor(fgraph, node):
@register_specialize
@
local_optimiz
er
([
AdvancedIncSubtensor
])
@
node_rewrit
er
([
AdvancedIncSubtensor
])
def
local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1
(
fgraph
,
node
):
r"""Replace `AdvancedIncSubtensor`\s with `AdvancedIncSubtensor1`\s.
...
...
@@ -268,7 +268,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
@register_canonicalize
@register_stabilize
@register_specialize
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_subtensor_of_dot
(
fgraph
,
node
):
"""Rewrite ``at.dot(A, B)[idxs]`` into ``at.dot(A[idxs_a], B[idxs_b])``.
``idxs_a`` is the first ``A.ndim-1`` entries of ``idxs``, and ``idxs_b`` is
...
...
@@ -326,7 +326,7 @@ def local_subtensor_of_dot(fgraph, node):
@register_useless
@register_canonicalize
@register_specialize
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_useless_slice
(
fgraph
,
node
):
"""
Remove Subtensor of the form X[0, :] -> X[0]
...
...
@@ -362,7 +362,7 @@ def local_useless_slice(fgraph, node):
# fast_compile to allow opt subtensor(cast{float32}(make_vector))
@register_canonicalize
(
"fast_compile"
)
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_subtensor_lift
(
fgraph
,
node
):
"""
unary(x)[idx] -> unary(x[idx])#any broadcast pattern.
...
...
@@ -466,7 +466,7 @@ def local_subtensor_lift(fgraph, node):
@register_canonicalize
@register_specialize
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_subtensor_merge
(
fgraph
,
node
):
"""
Refactored optimization to deal with all cases of tensor merging.
...
...
@@ -537,7 +537,7 @@ def local_subtensor_merge(fgraph, node):
@register_specialize
@register_canonicalize
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_subtensor_remove_broadcastable_index
(
fgraph
,
node
):
"""
Remove broadcastable dimension with index 0 or -1
...
...
@@ -586,7 +586,7 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
@register_useless
@register_canonicalize
@register_specialize
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_subtensor_of_alloc
(
fgraph
,
node
):
"""
...
...
@@ -654,7 +654,7 @@ def local_subtensor_of_alloc(fgraph, node):
@register_specialize
@register_canonicalize
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_subtensor_inc_subtensor
(
fgraph
,
node
):
"""
Subtensor(SetSubtensor(x, y, idx), idx) -> y
...
...
@@ -694,7 +694,7 @@ def local_subtensor_inc_subtensor(fgraph, node):
@register_specialize
@register_canonicalize
(
"fast_compile"
)
@register_useless
@
local_optimiz
er
([
Subtensor
,
AdvancedSubtensor1
])
@
node_rewrit
er
([
Subtensor
,
AdvancedSubtensor1
])
def
local_subtensor_make_vector
(
fgraph
,
node
):
"""Perform ``*Subtensor*`` operations on ``MakeVector`` outputs when the indices are constant.
...
...
@@ -770,7 +770,7 @@ def local_subtensor_make_vector(fgraph, node):
@register_useless
@register_canonicalize
@register_specialize
@
local_optimiz
er
([
IncSubtensor
])
@
node_rewrit
er
([
IncSubtensor
])
def
local_useless_inc_subtensor
(
fgraph
,
node
):
r"""Remove redundant `IncSubtensor`\s.
...
...
@@ -834,7 +834,7 @@ def local_useless_inc_subtensor(fgraph, node):
@register_canonicalize
@register_specialize
@
local_optimiz
er
([
AdvancedIncSubtensor1
])
@
node_rewrit
er
([
AdvancedIncSubtensor1
])
def
local_set_to_inc_subtensor
(
fgraph
,
node
):
r"""
AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) ->
...
...
@@ -878,7 +878,7 @@ def local_set_to_inc_subtensor(fgraph, node):
@register_canonicalize
@register_specialize
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_useless_subtensor
(
fgraph
,
node
):
"""Remove `Subtensor` if it takes the full input."""
# This optimization needs ShapeOpt and fgraph.shape_feature
...
...
@@ -960,7 +960,7 @@ def local_useless_subtensor(fgraph, node):
@register_canonicalize
@register_specialize
@
local_optimiz
er
([
AdvancedSubtensor1
])
@
node_rewrit
er
([
AdvancedSubtensor1
])
def
local_useless_AdvancedSubtensor1
(
fgraph
,
node
):
"""Remove `AdvancedSubtensor1` if it takes the full input.
...
...
@@ -1116,7 +1116,7 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
@register_canonicalize
@
local_optimiz
er
([
add
])
@
node_rewrit
er
([
add
])
def
local_IncSubtensor_serialize
(
fgraph
,
node
):
"""
When using Subtensor, gradient graphs can be ugly.
...
...
@@ -1216,7 +1216,7 @@ compile.optdb.register(
# gemm is the first one now, at priority 70
@
local_optimiz
er
([
IncSubtensor
],
inplace
=
True
)
@
node_rewrit
er
([
IncSubtensor
],
inplace
=
True
)
def
local_inplace_setsubtensor
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
IncSubtensor
)
and
not
node
.
op
.
inplace
:
dta
=
node
.
op
.
destroyhandler_tolerate_aliased
...
...
@@ -1249,7 +1249,7 @@ compile.optdb.register(
)
@
local_optimiz
er
([
AdvancedIncSubtensor1
],
inplace
=
True
)
@
node_rewrit
er
([
AdvancedIncSubtensor1
],
inplace
=
True
)
def
local_inplace_AdvancedIncSubtensor1
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
AdvancedIncSubtensor1
)
and
not
node
.
op
.
inplace
:
new_op
=
node
.
op
.
clone_inplace
()
...
...
@@ -1270,7 +1270,7 @@ compile.optdb.register(
)
@
local_optimiz
er
([
AdvancedIncSubtensor
],
inplace
=
True
)
@
node_rewrit
er
([
AdvancedIncSubtensor
],
inplace
=
True
)
def
local_inplace_AdvancedIncSubtensor
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
AdvancedIncSubtensor
)
and
not
node
.
op
.
inplace
:
new_op
=
type
(
node
.
op
)(
...
...
@@ -1298,7 +1298,7 @@ compile.optdb.register(
# Register old name
@register_canonicalize
(
"local_incsubtensor_of_allocs"
)
@register_stabilize
(
"local_incsubtensor_of_allocs"
)
@
local_optimiz
er
([
IncSubtensor
,
AdvancedIncSubtensor
,
AdvancedIncSubtensor1
])
@
node_rewrit
er
([
IncSubtensor
,
AdvancedIncSubtensor
,
AdvancedIncSubtensor1
])
def
local_incsubtensor_of_zeros
(
fgraph
,
node
):
"""
IncSubtensor(x, zeros, idx) -> x
...
...
@@ -1323,7 +1323,7 @@ def local_incsubtensor_of_zeros(fgraph, node):
@register_canonicalize
@register_specialize
@
local_optimiz
er
([
IncSubtensor
])
@
node_rewrit
er
([
IncSubtensor
])
def
local_incsubtensor_of_zeros_to_setsubtensor
(
fgraph
,
node
):
"""
IncSubtensor(zeros, x, ...) -> SetSubtensor(zeros, x, ...)
...
...
@@ -1344,7 +1344,7 @@ def local_incsubtensor_of_zeros_to_setsubtensor(fgraph, node):
@register_canonicalize
(
"local_setsubtensor_of_allocs"
)
@register_stabilize
(
"local_setsubtensor_of_allocs"
)
@
local_optimiz
er
([
IncSubtensor
])
@
node_rewrit
er
([
IncSubtensor
])
def
local_setsubtensor_of_constants
(
fgraph
,
node
):
"""
SetSubtensor(x, x[idx], idx) -> x
...
...
@@ -1379,7 +1379,7 @@ def local_setsubtensor_of_constants(fgraph, node):
@register_canonicalize
@register_specialize
@
local_optimiz
er
([
AdvancedSubtensor1
])
@
node_rewrit
er
([
AdvancedSubtensor1
])
def
local_adv_sub1_adv_inc_sub1
(
fgraph
,
node
):
"""Optimize the possible AdvSub1(AdvSetSub1(...), ...).
...
...
@@ -1446,7 +1446,7 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node):
@register_stabilize
@register_canonicalize
@register_useless
@
local_optimiz
er
([
IncSubtensor
,
AdvancedIncSubtensor
,
AdvancedIncSubtensor1
])
@
node_rewrit
er
([
IncSubtensor
,
AdvancedIncSubtensor
,
AdvancedIncSubtensor1
])
def
local_useless_inc_subtensor_alloc
(
fgraph
,
node
):
"""
Replaces an [Advanced]IncSubtensor[1], whose increment is an `alloc` of
...
...
@@ -1552,7 +1552,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
@register_specialize
@register_canonicalize
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_subtensor_shape_constant
(
fgraph
,
node
):
r"""Simplify constant `Subtensor`\s on `Shape`\s dimensions that are known.
...
...
@@ -1606,7 +1606,7 @@ def local_subtensor_shape_constant(fgraph, node):
@register_canonicalize
@
local_optimiz
er
([
Subtensor
])
@
node_rewrit
er
([
Subtensor
])
def
local_subtensor_SpecifyShape_lift
(
fgraph
,
node
):
"""Lift ``specify_shape(x, s)[i_1, ..., i_n]`` to ``specify_shape(x[i1, ... , i_n], s[n:])``."""
...
...
@@ -1640,7 +1640,7 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
@register_specialize
@
local_optimiz
er
([
Join
])
@
node_rewrit
er
([
Join
])
def
local_join_subtensors
(
fgraph
,
node
):
r"""Simplify contiguous :class:`Subtensor`\s inside a :class:`Join`.
...
...
aesara/typed_list/opt.py
浏览文件 @
550a6e98
from
aesara.compile
import
optdb
from
aesara.graph.opt
import
TopoOptimizer
,
local_optimiz
er
from
aesara.graph.opt
import
TopoOptimizer
,
node_rewrit
er
from
aesara.typed_list.basic
import
Append
,
Extend
,
Insert
,
Remove
,
Reverse
@
local_optimiz
er
([
Append
,
Extend
,
Insert
,
Reverse
,
Remove
],
inplace
=
True
)
@
node_rewrit
er
([
Append
,
Extend
,
Insert
,
Reverse
,
Remove
],
inplace
=
True
)
def
typed_list_inplace_opt
(
fgraph
,
node
):
if
(
isinstance
(
node
.
op
,
(
Append
,
Extend
,
Insert
,
Reverse
,
Remove
))
...
...
doc/extending/graph_rewriting.rst
浏览文件 @
550a6e98
...
...
@@ -67,15 +67,15 @@ Local optimization
A local optimization is an object which defines the following methods:
.. class::
LocalOptimiz
er
.. class::
NodeRewrit
er
.. method:: transform(fgraph, node)
This method takes a :class:`FunctionGraph` and an :class:`Apply` node and
returns either ``False`` to signify that no changes are to be done or a
list of :class:`Variable`\s which matches the length of the node's ``outputs``
list. When the :class:`
LocalOptimiz
er` is applied by a :class:`NavigatorOptimizer`, the outputs
of the node passed as argument to the :class:`
LocalOptimiz
er` will be replaced by
list. When the :class:`
NodeRewrit
er` is applied by a :class:`NavigatorOptimizer`, the outputs
of the node passed as argument to the :class:`
NodeRewrit
er` will be replaced by
the list returned.
...
...
@@ -218,10 +218,10 @@ The local version of the above code would be the following:
.. testcode::
from aesara.graph.opt import
LocalOptimiz
er
from aesara.graph.opt import
NodeRewrit
er
class LocalSimplify(
LocalOptimiz
er):
class LocalSimplify(
NodeRewrit
er):
def transform(self, fgraph, node):
if node.op == true_div:
x, y = node.inputs
...
...
@@ -234,7 +234,7 @@ The local version of the above code would be the following:
return False
def tracks(self):
# This tells certain navigators to only apply this `
LocalOptimiz
er`
# This tells certain navigators to only apply this `
NodeRewrit
er`
# on these kinds of `Op`s
return [true_div]
...
...
@@ -242,7 +242,7 @@ The local version of the above code would be the following:
In this case, the transformation is defined in the
:meth:`
LocalOptimiz
er.transform` method, which is given an explicit
:meth:`
NodeRewrit
er.transform` method, which is given an explicit
:class:`Apply` node on which to work. The entire graph--as a ``fgraph``--is
also provided, in case global information is needed.
...
...
@@ -273,7 +273,7 @@ FunctionGraph(add(z, mul(x, true_div(z, x))))
:class:`OpSub`, :class:`OpRemove`, :class:`PatternSub`
++++++++++++++++++++++++++++++++++++++++++++++++++++++
Aesara defines some shortcuts to make :class:`
LocalOptimiz
er`\s:
Aesara defines some shortcuts to make :class:`
NodeRewrit
er`\s:
.. function:: OpSub(op1, op2)
...
...
@@ -433,7 +433,7 @@ This means that a relation that--say--represents :math:`x + x = 2 x` can be
utilized in both directions.
Currently, the local optimizer :class:`KanrenRelationSub` provides a means of
turning :mod:`kanren` relations into :class:`
LocalOptimiz
er`\s; however,
turning :mod:`kanren` relations into :class:`
NodeRewrit
er`\s; however,
:mod:`kanren` can always be used directly from within a custom :class:`Rewriter`, so
:class:`KanrenRelationSub` is not necessary.
...
...
@@ -561,7 +561,7 @@ serve as a basis for filtering.
The point of :obj:`optdb` is that you might want to apply many optimizations
to a computation graph in many unique patterns. For example, you might
want to do optimization X, then optimization Y, then optimization Z. And then
maybe optimization Y is an :class:`EquilibriumOptimizer` containing :class:`
LocalOptimiz
er`\s A, B
maybe optimization Y is an :class:`EquilibriumOptimizer` containing :class:`
NodeRewrit
er`\s A, B
and C which are applied on every node of the graph until they all fail to change
it. If some optimizations act up, we want an easy way to turn them off. Ditto if
some optimizations are very CPU-intensive and we don't want to take the time to
...
...
@@ -596,14 +596,14 @@ is returned. If the :class:`SequenceDB` contains :class:`OptimizationDatabase`
instances, the :class:`OptimizationQuery` will be passed to them as well and the
optimizers they return will be put in their places.
An :class:`EquilibriumDB` contains :class:`
LocalOptimiz
er` or :class:`OptimizationDatabase` objects. Each of them
An :class:`EquilibriumDB` contains :class:`
NodeRewrit
er` or :class:`OptimizationDatabase` objects. Each of them
has a name and an arbitrary number of tags. When a :class:`OptimizationQuery` is applied to
an :class:`EquilibriumDB`, all :class:`
LocalOptimiz
er`\s that match the query are
an :class:`EquilibriumDB`, all :class:`
NodeRewrit
er`\s that match the query are
inserted into an :class:`EquilibriumOptimizer`, which is returned. If the
:class:`SequenceDB` contains :class:`OptimizationDatabase` instances, the
:class:`OptimizationQuery` will be passed to them as well and the
:class:`
LocalOptimiz
er`\s they return will be put in their places
(note that as of yet no :class:`OptimizationDatabase` can produce :class:`
LocalOptimiz
er` objects, so this
:class:`
NodeRewrit
er`\s they return will be put in their places
(note that as of yet no :class:`OptimizationDatabase` can produce :class:`
NodeRewrit
er` objects, so this
is a moot point).
Aesara contains one principal :class:`OptimizationDatabase` object, :class:`optdb`, which
...
...
@@ -697,10 +697,10 @@ already-compiled functions will see no change. The 'order' parameter
Registering a :class:`
LocalOptimiz
er`
-----------------------------------
--
Registering a :class:`
NodeRewrit
er`
-----------------------------------
:class:`
LocalOptimiz
er`\s may be registered in two ways:
:class:`
NodeRewrit
er`\s may be registered in two ways:
* Wrap them in a :class:`NavigatorOptimizer` and insert them like a global optimizer
(see previous section).
...
...
tests/compile/test_debugmode.py
浏览文件 @
550a6e98
...
...
@@ -18,7 +18,7 @@ from aesara.configdefaults import config
from
aesara.graph.basic
import
Apply
,
Variable
from
aesara.graph.features
import
BadOptimization
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
local_optimiz
er
from
aesara.graph.opt
import
node_rewrit
er
from
aesara.graph.optdb
import
EquilibriumDB
from
aesara.link.c.op
import
COp
from
aesara.tensor.math
import
add
,
dot
,
log
...
...
@@ -237,7 +237,7 @@ def test_badthunkoutput():
def
test_badoptimization
():
@
local_optimiz
er
([
add
])
@
node_rewrit
er
([
add
])
def
insert_broken_add
(
fgraph
,
node
):
if
node
.
op
==
add
:
return
[
off_by_half
(
*
node
.
inputs
)]
...
...
@@ -263,7 +263,7 @@ def test_badoptimization():
def
test_badoptimization_opt_err
():
# This variant of test_badoptimization() replace the working code
# with a new apply node that will raise an error.
@
local_optimiz
er
([
add
])
@
node_rewrit
er
([
add
])
def
insert_bigger_b_add
(
fgraph
,
node
):
if
node
.
op
==
add
:
inputs
=
list
(
node
.
inputs
)
...
...
@@ -272,7 +272,7 @@ def test_badoptimization_opt_err():
return
[
node
.
op
(
*
inputs
)]
return
False
@
local_optimiz
er
([
add
])
@
node_rewrit
er
([
add
])
def
insert_bad_dtype
(
fgraph
,
node
):
if
node
.
op
==
add
:
inputs
=
list
(
node
.
inputs
)
...
...
@@ -326,7 +326,7 @@ def test_stochasticoptimization():
last_time_replaced
=
[
False
]
@
local_optimiz
er
([
add
])
@
node_rewrit
er
([
add
])
def
insert_broken_add_sometimes
(
fgraph
,
node
):
if
node
.
op
==
add
:
last_time_replaced
[
0
]
=
not
last_time_replaced
[
0
]
...
...
tests/graph/test_opt.py
浏览文件 @
550a6e98
...
...
@@ -15,10 +15,10 @@ from aesara.graph.opt import (
PatternSub
,
TopoOptimizer
,
in2out
,
local_optimizer
,
logging
,
node_rewriter
,
pre_constant_merge
,
pre_greedy_
local_optimiz
er
,
pre_greedy_
node_rewrit
er
,
)
from
aesara.raise_op
import
assert_op
from
aesara.tensor.basic_opt
import
constant_folding
...
...
@@ -547,7 +547,7 @@ def test_pre_constant_merge():
assert
res
==
[
adv
]
def
test_pre_greedy_
local_optimiz
er
():
def
test_pre_greedy_
node_rewrit
er
():
empty_fgraph
=
FunctionGraph
([],
[])
...
...
@@ -564,7 +564,7 @@ def test_pre_greedy_local_optimizer():
# This should fold `o1`, because it has only `Constant` arguments, and
# replace it with the `Constant` result
cst
=
pre_greedy_
local_optimiz
er
(
empty_fgraph
,
[
constant_folding
],
o2
)
cst
=
pre_greedy_
node_rewrit
er
(
empty_fgraph
,
[
constant_folding
],
o2
)
assert
cst
.
owner
.
inputs
[
0
]
.
owner
is
None
assert
cst
.
owner
.
inputs
[
1
]
is
c2
...
...
@@ -577,14 +577,14 @@ def test_pre_greedy_local_optimizer():
fg
=
FunctionGraph
([],
[
o1
],
clone
=
False
)
o2
=
op1
(
o1
,
c2
,
x
,
o3
,
o1
)
cst
=
pre_greedy_
local_optimiz
er
(
fg
,
[
constant_folding
],
o2
)
cst
=
pre_greedy_
node_rewrit
er
(
fg
,
[
constant_folding
],
o2
)
assert
cst
.
owner
.
inputs
[
0
]
is
o1
assert
cst
.
owner
.
inputs
[
4
]
is
cst
.
owner
.
inputs
[
0
]
# What exactly is this supposed to test?
ms
=
MakeSlice
()(
1
)
cst
=
pre_greedy_
local_optimiz
er
(
empty_fgraph
,
[
constant_folding
],
ms
)
cst
=
pre_greedy_
node_rewrit
er
(
empty_fgraph
,
[
constant_folding
],
ms
)
assert
isinstance
(
cst
,
SliceConstant
)
...
...
@@ -673,13 +673,13 @@ class TestLocalOptGroup:
fgraph
=
FunctionGraph
([
x
,
y
],
[
o1
],
clone
=
False
)
@
local_optimiz
er
(
None
)
@
node_rewrit
er
(
None
)
def
local_opt_1
(
fgraph
,
node
):
if
node
.
inputs
[
0
]
==
x
:
res
=
op2
(
y
,
*
node
.
inputs
[
1
:])
return
[
res
]
@
local_optimiz
er
(
None
)
@
node_rewrit
er
(
None
)
def
local_opt_2
(
fgraph
,
node
):
if
node
.
inputs
[
0
]
==
y
:
res
=
op2
(
x
,
*
node
.
inputs
[
1
:])
...
...
@@ -703,8 +703,8 @@ class TestLocalOptGroup:
)
def
test_
local_optimiz
er_str
():
@
local_optimiz
er
([
op1
,
MyOp
])
def
test_
node_rewrit
er_str
():
@
node_rewrit
er
([
op1
,
MyOp
])
def
local_opt_1
(
fgraph
,
node
):
pass
...
...
@@ -715,17 +715,17 @@ def test_local_optimizer_str():
assert
"local_opt_1"
in
res
def
test_
local_optimiz
er
():
def
test_
node_rewrit
er
():
with
pytest
.
raises
(
ValueError
):
@
local_optimiz
er
([])
@
node_rewrit
er
([])
def
local_bad_1
(
fgraph
,
node
):
return
node
.
outputs
with
pytest
.
raises
(
TypeError
):
@
local_optimiz
er
([
None
])
@
node_rewrit
er
([
None
])
def
local_bad_2
(
fgraph
,
node
):
return
node
.
outputs
...
...
@@ -748,7 +748,7 @@ def test_local_optimizer():
hits
=
[
0
]
@
local_optimiz
er
([
op1
,
MyNewOp
])
@
node_rewrit
er
([
op1
,
MyNewOp
])
def
local_opt_1
(
fgraph
,
node
,
hits
=
hits
):
hits
[
0
]
+=
1
return
node
.
outputs
...
...
@@ -766,24 +766,24 @@ def test_local_optimizer():
assert
hits
[
0
]
==
2
def
test_Tracking
LocalOptimiz
er
():
@
local_optimiz
er
(
None
)
def
test_Tracking
NodeRewrit
er
():
@
node_rewrit
er
(
None
)
def
local_opt_1
(
fgraph
,
node
):
pass
@
local_optimiz
er
([
op1
])
@
node_rewrit
er
([
op1
])
def
local_opt_2
(
fgraph
,
node
):
pass
@
local_optimiz
er
([
Op
])
@
node_rewrit
er
([
Op
])
def
local_opt_3
(
fgraph
,
node
):
pass
@
local_optimiz
er
([
MyOp
])
@
node_rewrit
er
([
MyOp
])
def
local_opt_4
(
fgraph
,
node
):
pass
@
local_optimiz
er
([
MyOp
])
@
node_rewrit
er
([
MyOp
])
def
local_opt_5
(
fgraph
,
node
):
pass
...
...
tests/tensor/test_basic_opt.py
浏览文件 @
550a6e98
...
...
@@ -16,7 +16,7 @@ from aesara.configdefaults import config
from
aesara.graph.basic
import
Apply
,
Constant
,
Variable
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
check_stack_trace
,
local_optimiz
er
,
out2in
from
aesara.graph.opt
import
check_stack_trace
,
node_rewrit
er
,
out2in
from
aesara.graph.opt_utils
import
optimize_graph
from
aesara.graph.optdb
import
OptimizationQuery
from
aesara.graph.type
import
Type
...
...
@@ -1752,7 +1752,7 @@ class TestShapeOptimizer:
identity_shape
=
IdentityShape
()
@
local_optimiz
er
([
IdentityNoShape
])
@
node_rewrit
er
([
IdentityNoShape
])
def
local_identity_noshape_to_identity_shape
(
fgraph
,
node
):
"""Optimization transforming the first Op into the second"""
if
isinstance
(
node
.
op
,
IdentityNoShape
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论