Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
46d46d76
提交
46d46d76
authored
8月 13, 2015
作者:
Iban Harlouchet
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
numpydoc for theano/gof/opt.py
上级
ae99f41d
显示空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
275 行增加
和
133 行删除
+275
-133
opt.py
theano/gof/opt.py
+275
-133
没有找到文件。
theano/gof/opt.py
浏览文件 @
46d46d76
"""
"""
Defines the base class for optimizations as well as a certain
Defines the base class for optimizations as well as a certain
amount of useful generic optimization tools.
amount of useful generic optimization tools.
"""
"""
from
__future__
import
print_function
from
__future__
import
print_function
...
@@ -35,10 +36,13 @@ def _list_of_nodes(fgraph):
...
@@ -35,10 +36,13 @@ def _list_of_nodes(fgraph):
class
Optimizer
(
object
):
class
Optimizer
(
object
):
"""WRITEME
"""
WRITEME
An L{Optimizer} can be applied to an L{FunctionGraph} to transform it.
An L{Optimizer} can be applied to an L{FunctionGraph} to transform it.
It can represent an optimization or in general any kind
It can represent an optimization or in general any kind
of transformation you could apply to an L{FunctionGraph}.
of transformation you could apply to an L{FunctionGraph}.
"""
"""
def
__hash__
(
self
):
def
__hash__
(
self
):
...
@@ -58,19 +62,25 @@ class Optimizer(object):
...
@@ -58,19 +62,25 @@ class Optimizer(object):
return
id
(
self
)
!=
id
(
other
)
return
id
(
self
)
!=
id
(
other
)
def
apply
(
self
,
fgraph
):
def
apply
(
self
,
fgraph
):
"""WRITEME
"""
WRITEME
Applies the optimization to the provided L{FunctionGraph}. It may
Applies the optimization to the provided L{FunctionGraph}. It may
use all the methods defined by the L{FunctionGraph}. If the
use all the methods defined by the L{FunctionGraph}. If the
L{Optimizer} needs to use a certain tool, such as an
L{Optimizer} needs to use a certain tool, such as an
L{InstanceFinder}, it can do so in its L{add_requirements} method.
L{InstanceFinder}, it can do so in its L{add_requirements} method.
"""
"""
pass
pass
def
optimize
(
self
,
fgraph
,
*
args
,
**
kwargs
):
def
optimize
(
self
,
fgraph
,
*
args
,
**
kwargs
):
"""WRITEME
"""
This is meant as a shortcut to::
WRITEME
This is meant as a shortcut to:
opt.add_requirements(fgraph)
opt.add_requirements(fgraph)
opt.apply(fgraph)
opt.apply(fgraph)
"""
"""
self
.
add_requirements
(
fgraph
)
self
.
add_requirements
(
fgraph
)
try
:
try
:
...
@@ -82,18 +92,24 @@ class Optimizer(object):
...
@@ -82,18 +92,24 @@ class Optimizer(object):
return
ret
return
ret
def
__call__
(
self
,
fgraph
):
def
__call__
(
self
,
fgraph
):
"""WRITEME
"""
Same as self.optimize(fgraph)
WRITEME
Same as self.optimize(fgraph).
"""
"""
return
self
.
optimize
(
fgraph
)
return
self
.
optimize
(
fgraph
)
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
"""WRITEME
"""
WRITEME
Add features to the fgraph that are required to apply the optimization.
Add features to the fgraph that are required to apply the optimization.
For example:
For example:
fgraph.attach_feature(History())
fgraph.attach_feature(History())
fgraph.attach_feature(MyFeature())
fgraph.attach_feature(MyFeature())
etc.
etc.
"""
"""
pass
pass
...
@@ -111,7 +127,10 @@ class Optimizer(object):
...
@@ -111,7 +127,10 @@ class Optimizer(object):
class
FromFunctionOptimizer
(
Optimizer
):
class
FromFunctionOptimizer
(
Optimizer
):
"""WRITEME"""
"""
WRITEME
"""
def
__init__
(
self
,
fn
,
requirements
=
()):
def
__init__
(
self
,
fn
,
requirements
=
()):
self
.
apply
=
fn
self
.
apply
=
fn
self
.
requirements
=
requirements
self
.
requirements
=
requirements
...
@@ -134,14 +153,20 @@ class FromFunctionOptimizer(Optimizer):
...
@@ -134,14 +153,20 @@ class FromFunctionOptimizer(Optimizer):
def
optimizer
(
f
):
def
optimizer
(
f
):
"""decorator for FromFunctionOptimizer"""
"""
Decorator for FromFunctionOptimizer.
"""
rval
=
FromFunctionOptimizer
(
f
)
rval
=
FromFunctionOptimizer
(
f
)
rval
.
__name__
=
f
.
__name__
rval
.
__name__
=
f
.
__name__
return
rval
return
rval
def
inplace_optimizer
(
f
):
def
inplace_optimizer
(
f
):
"""decorator for FromFunctionOptimizer"""
"""
Decorator for FromFunctionOptimizer.
"""
dh_handler
=
dh
.
DestroyHandler
dh_handler
=
dh
.
DestroyHandler
requirements
=
(
lambda
fgraph
:
requirements
=
(
lambda
fgraph
:
fgraph
.
attach_feature
(
dh_handler
()),)
fgraph
.
attach_feature
(
dh_handler
()),)
...
@@ -152,13 +177,18 @@ def inplace_optimizer(f):
...
@@ -152,13 +177,18 @@ def inplace_optimizer(f):
class
SeqOptimizer
(
Optimizer
,
list
):
class
SeqOptimizer
(
Optimizer
,
list
):
# inherit from Optimizer first to get Optimizer.__hash__
# inherit from Optimizer first to get Optimizer.__hash__
"""WRITEME
"""
WRITEME
Takes a list of L{Optimizer} instances and applies them
Takes a list of L{Optimizer} instances and applies them
sequentially.
sequentially.
"""
"""
@staticmethod
@staticmethod
def
warn
(
exc
,
self
,
optimizer
):
def
warn
(
exc
,
self
,
optimizer
):
"""Default failure_callback for SeqOptimizer
"""
Default failure_callback for SeqOptimizer.
"""
"""
_logger
.
error
(
"SeqOptimizer apply
%
s"
%
str
(
optimizer
))
_logger
.
error
(
"SeqOptimizer apply
%
s"
%
str
(
optimizer
))
_logger
.
error
(
"Traceback:"
)
_logger
.
error
(
"Traceback:"
)
...
@@ -169,15 +199,21 @@ class SeqOptimizer(Optimizer, list):
...
@@ -169,15 +199,21 @@ class SeqOptimizer(Optimizer, list):
pdb
.
post_mortem
(
sys
.
exc_info
()[
2
])
pdb
.
post_mortem
(
sys
.
exc_info
()[
2
])
def
__init__
(
self
,
*
opts
,
**
kw
):
def
__init__
(
self
,
*
opts
,
**
kw
):
"""WRITEME"""
"""
WRITEME
"""
if
len
(
opts
)
==
1
and
isinstance
(
opts
[
0
],
(
list
,
tuple
)):
if
len
(
opts
)
==
1
and
isinstance
(
opts
[
0
],
(
list
,
tuple
)):
opts
=
opts
[
0
]
opts
=
opts
[
0
]
self
[:]
=
opts
self
[:]
=
opts
self
.
failure_callback
=
kw
.
pop
(
'failure_callback'
,
None
)
self
.
failure_callback
=
kw
.
pop
(
'failure_callback'
,
None
)
def
apply
(
self
,
fgraph
):
def
apply
(
self
,
fgraph
):
"""WRITEME
"""
WRITEME
Applies each L{Optimizer} in self in turn.
Applies each L{Optimizer} in self in turn.
"""
"""
l
=
[]
l
=
[]
if
fgraph
.
profile
:
if
fgraph
.
profile
:
...
@@ -286,6 +322,7 @@ class SeqOptimizer(Optimizer, list):
...
@@ -286,6 +322,7 @@ class SeqOptimizer(Optimizer, list):
def
merge_profile
(
prof1
,
prof2
):
def
merge_profile
(
prof1
,
prof2
):
"""
"""
Merge 2 profiles returned by this cass apply() fct.
Merge 2 profiles returned by this cass apply() fct.
"""
"""
new_t
=
[]
new_t
=
[]
new_l
=
[]
new_l
=
[]
...
@@ -354,7 +391,11 @@ class SeqOptimizer(Optimizer, list):
...
@@ -354,7 +391,11 @@ class SeqOptimizer(Optimizer, list):
class
_metadict
:
class
_metadict
:
"""WRITEME"""
"""
WRITEME
"""
# dict that accepts unhashable keys
# dict that accepts unhashable keys
# uses an associative list
# uses an associative list
# for internal use only
# for internal use only
...
@@ -430,6 +471,7 @@ class MergeFeature(object):
...
@@ -430,6 +471,7 @@ class MergeFeature(object):
That way, the MergeOptimizer can remember the result of the last merge
That way, the MergeOptimizer can remember the result of the last merge
pass on the fgraph.
pass on the fgraph.
"""
"""
def
on_attach
(
self
,
fgraph
):
def
on_attach
(
self
,
fgraph
):
assert
not
hasattr
(
fgraph
,
'merge_feature'
)
assert
not
hasattr
(
fgraph
,
'merge_feature'
)
...
@@ -493,7 +535,10 @@ class MergeFeature(object):
...
@@ -493,7 +535,10 @@ class MergeFeature(object):
self
.
seen_constants
.
discard
(
id
(
c
))
self
.
seen_constants
.
discard
(
id
(
c
))
def
process_constant
(
self
,
fgraph
,
c
):
def
process_constant
(
self
,
fgraph
,
c
):
"""Check if a constant can be merged, and queue that replacement"""
"""
Check if a constant can be merged, and queue that replacement.
"""
if
id
(
c
)
in
self
.
seen_constants
:
if
id
(
c
)
in
self
.
seen_constants
:
return
return
sig
=
c
.
merge_signature
()
sig
=
c
.
merge_signature
()
...
@@ -511,7 +556,10 @@ class MergeFeature(object):
...
@@ -511,7 +556,10 @@ class MergeFeature(object):
self
.
seen_constants
.
add
(
id
(
c
))
self
.
seen_constants
.
add
(
id
(
c
))
def
process_node
(
self
,
fgraph
,
node
):
def
process_node
(
self
,
fgraph
,
node
):
"""Check if a node can be merged, and queue that replacement."""
"""
Check if a node can be merged, and queue that replacement.
"""
if
node
in
self
.
nodes_seen
:
if
node
in
self
.
nodes_seen
:
return
return
...
@@ -570,6 +618,7 @@ class MergeOptimizer(Optimizer):
...
@@ -570,6 +618,7 @@ class MergeOptimizer(Optimizer):
The first step of merging is constant-merging, so that all clients of an
The first step of merging is constant-merging, so that all clients of an
int(1) for example, are transferred to a particular instance of int(1).
int(1) for example, are transferred to a particular instance of int(1).
"""
"""
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
...
@@ -678,6 +727,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
...
@@ -678,6 +727,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
Merge-based implementation of `theano.gof.graph.is_same_graph`.
Merge-based implementation of `theano.gof.graph.is_same_graph`.
See help on `theano.gof.graph.is_same_graph` for additional documentation.
See help on `theano.gof.graph.is_same_graph` for additional documentation.
"""
"""
if
givens
is
None
:
if
givens
is
None
:
givens
=
{}
givens
=
{}
...
@@ -718,13 +768,15 @@ def pre_constant_merge(vars):
...
@@ -718,13 +768,15 @@ def pre_constant_merge(vars):
`vars` is a list of nodes, and we want to merge together nodes
`vars` is a list of nodes, and we want to merge together nodes
that are constant inputs used to compute nodes in that list.
that are constant inputs used to compute nodes in that list.
:note: This function will ignore nodes that are in an fgraph.
Notes
-----
This function will ignore nodes that are in an fgraph.
It is used to pre-merge nodes generated inside an optimization,
It is used to pre-merge nodes generated inside an optimization,
before it is inserted in the fgraph.
before it is inserted in the fgraph.
It is useful if there are many such replacements to make,
It is useful if there are many such replacements to make,
so that DebugMode will not check each of them.
so that DebugMode will not check each of them.
"""
"""
seen_var
=
set
()
seen_var
=
set
()
# signature -> variable (for constants)
# signature -> variable (for constants)
const_sig_inv
=
{}
const_sig_inv
=
{}
...
@@ -767,10 +819,12 @@ def pre_constant_merge(vars):
...
@@ -767,10 +819,12 @@ def pre_constant_merge(vars):
########################
########################
class
LocalOptimizer
(
object
):
class
LocalOptimizer
(
object
):
"""A class for node-based optimizations.
"""
A class for node-based optimizations.
Instances should implement the transform function,
Instances should implement the transform function,
and be passed to configure a fgraph-based Optimizer instance.
and be passed to configure a fgraph-based Optimizer instance.
"""
"""
def
__hash__
(
self
):
def
__hash__
(
self
):
...
@@ -784,11 +838,13 @@ class LocalOptimizer(object):
...
@@ -784,11 +838,13 @@ class LocalOptimizer(object):
Return the list of op classes that this opt applies to.
Return the list of op classes that this opt applies to.
Return None to apply to all nodes.
Return None to apply to all nodes.
"""
"""
return
None
return
None
def
transform
(
self
,
node
):
def
transform
(
self
,
node
):
"""Transform a subgraph whose output is `node`.
"""
Transform a subgraph whose output is `node`.
Subclasses should implement this function so that it returns one of two
Subclasses should implement this function so that it returns one of two
kinds of things:
kinds of things:
...
@@ -800,7 +856,9 @@ class LocalOptimizer(object):
...
@@ -800,7 +856,9 @@ class LocalOptimizer(object):
- dict(old variables -> new variables). A dictionary that map
- dict(old variables -> new variables). A dictionary that map
from old variables to new variables to replace.
from old variables to new variables to replace.
:type node: an Apply instance
Parameters
----------
node : an Apply instance
"""
"""
...
@@ -810,8 +868,8 @@ class LocalOptimizer(object):
...
@@ -810,8 +868,8 @@ class LocalOptimizer(object):
def
add_requirements
(
self
,
fgraph
):
def
add_requirements
(
self
,
fgraph
):
"""
"""
If this local optimization wants to add some requirements to the
If this local optimization wants to add some requirements to the
fgraph,
fgraph,
this is the place to do it.
This is the place to do it.
"""
"""
# Added by default
# Added by default
# fgraph.attach_feature(toolbox.ReplaceValidate())
# fgraph.attach_feature(toolbox.ReplaceValidate())
...
@@ -830,8 +888,11 @@ theano.configparser.AddConfigVar(
...
@@ -830,8 +888,11 @@ theano.configparser.AddConfigVar(
class
LocalMetaOptimizer
(
LocalOptimizer
):
class
LocalMetaOptimizer
(
LocalOptimizer
):
"""Base class for meta-optimizers that try a set of LocalOptimizers
"""
to replace a node and choose the one that executes the fastest"""
Base class for meta-optimizers that try a set of LocalOptimizers
to replace a node and choose the one that executes the fastest.
"""
def
__init__
(
self
,
tracks
=
None
,
optimizers
=
()):
def
__init__
(
self
,
tracks
=
None
,
optimizers
=
()):
self
.
_tracks
=
tracks
self
.
_tracks
=
tracks
...
@@ -907,9 +968,12 @@ class LocalMetaOptimizer(LocalOptimizer):
...
@@ -907,9 +968,12 @@ class LocalMetaOptimizer(LocalOptimizer):
return
return
def
provide_inputs
(
self
,
node
,
inputs
):
def
provide_inputs
(
self
,
node
,
inputs
):
"""If implemented, returns a dictionary mapping all symbolic variables
"""
in ``inputs`` to SharedVariable instances of suitable dummy values. The
If implemented, returns a dictionary mapping all symbolic variables
``node`` can be inspected to infer required input shapes."""
in ``inputs`` to SharedVariable instances of suitable dummy values.
The ``node`` can be inspected to infer required input shapes.
"""
raise
NotImplementedError
()
raise
NotImplementedError
()
def
time_call
(
self
,
fn
):
def
time_call
(
self
,
fn
):
...
@@ -919,7 +983,10 @@ class LocalMetaOptimizer(LocalOptimizer):
...
@@ -919,7 +983,10 @@ class LocalMetaOptimizer(LocalOptimizer):
class
FromFunctionLocalOptimizer
(
LocalOptimizer
):
class
FromFunctionLocalOptimizer
(
LocalOptimizer
):
"""WRITEME"""
"""
WRITEME
"""
def
__init__
(
self
,
fn
,
tracks
=
None
,
requirements
=
()):
def
__init__
(
self
,
fn
,
tracks
=
None
,
requirements
=
()):
self
.
transform
=
fn
self
.
transform
=
fn
self
.
_tracks
=
tracks
self
.
_tracks
=
tracks
...
@@ -945,7 +1012,10 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
...
@@ -945,7 +1012,10 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
def
local_optimizer
(
tracks
,
inplace
=
False
):
def
local_optimizer
(
tracks
,
inplace
=
False
):
def
decorator
(
f
):
def
decorator
(
f
):
"""WRITEME"""
"""
WRITEME
"""
if
tracks
is
not
None
:
if
tracks
is
not
None
:
if
len
(
tracks
)
is
0
:
if
len
(
tracks
)
is
0
:
raise
ValueError
(
"Use None instead of an empty list to apply to all nodes."
,
f
.
__module__
,
f
.
__name__
)
raise
ValueError
(
"Use None instead of an empty list to apply to all nodes."
,
f
.
__module__
,
f
.
__name__
)
...
@@ -964,7 +1034,10 @@ def local_optimizer(tracks, inplace=False):
...
@@ -964,7 +1034,10 @@ def local_optimizer(tracks, inplace=False):
class
LocalOptGroup
(
LocalOptimizer
):
class
LocalOptGroup
(
LocalOptimizer
):
"""WRITEME"""
"""
WRITEME
"""
def
__init__
(
self
,
*
optimizers
):
def
__init__
(
self
,
*
optimizers
):
if
len
(
optimizers
)
==
1
and
isinstance
(
optimizers
[
0
],
list
):
if
len
(
optimizers
)
==
1
and
isinstance
(
optimizers
[
0
],
list
):
...
@@ -1009,12 +1082,23 @@ class LocalOptGroup(LocalOptimizer):
...
@@ -1009,12 +1082,23 @@ class LocalOptGroup(LocalOptimizer):
class
OpSub
(
LocalOptimizer
):
class
OpSub
(
LocalOptimizer
):
"""WRITEME
"""
WRITEME
Replaces the application of a certain op by the application of
Replaces the application of a certain op by the application of
another op that take the same inputs as what they are replacing.
another op that takes the same inputs as what they are replacing.
Parameters
----------
op1, op2
op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs.
e.g. OpSub(add, sub) ==>
Examples
--------
OpSub(add, sub) ==>
add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
"""
"""
# an OpSub does not apply to the nodes it produces
# an OpSub does not apply to the nodes it produces
...
@@ -1023,10 +1107,6 @@ class OpSub(LocalOptimizer):
...
@@ -1023,10 +1107,6 @@ class OpSub(LocalOptimizer):
retains_inputs
=
True
retains_inputs
=
True
def
__init__
(
self
,
op1
,
op2
,
transfer_tags
=
True
):
def
__init__
(
self
,
op1
,
op2
,
transfer_tags
=
True
):
"""
op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs.
"""
self
.
op1
=
op1
self
.
op1
=
op1
self
.
op2
=
op2
self
.
op2
=
op2
self
.
transfer_tags
=
transfer_tags
self
.
transfer_tags
=
transfer_tags
...
@@ -1052,9 +1132,12 @@ class OpSub(LocalOptimizer):
...
@@ -1052,9 +1132,12 @@ class OpSub(LocalOptimizer):
class
OpRemove
(
LocalOptimizer
):
class
OpRemove
(
LocalOptimizer
):
"""WRITEME
"""
WRITEME
Removes all applications of an op by transferring each of its
Removes all applications of an op by transferring each of its
outputs to the corresponding input.
outputs to the corresponding input.
"""
"""
reentrant
=
False
# no nodes are added at all
reentrant
=
False
# no nodes are added at all
...
@@ -1085,7 +1168,9 @@ class OpRemove(LocalOptimizer):
...
@@ -1085,7 +1168,9 @@ class OpRemove(LocalOptimizer):
class
PatternSub
(
LocalOptimizer
):
class
PatternSub
(
LocalOptimizer
):
"""WRITEME
"""
WRITEME
@todo update
@todo update
Replaces all occurrences of the input pattern by the output pattern:
Replaces all occurrences of the input pattern by the output pattern:
...
@@ -1123,7 +1208,38 @@ class PatternSub(LocalOptimizer):
...
@@ -1123,7 +1208,38 @@ class PatternSub(LocalOptimizer):
trying to match and returns True or False according to an
trying to match and returns True or False according to an
arbitrary criterion.
arbitrary criterion.
Examples:
The constructor creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern.
Parameters
----------
in_pattern
The input pattern that we want to replace.
out_pattern
The replacement pattern.
allow_multiple_clients : bool
If False, the pattern matching will fail if one of the subpatterns has
more than one client.
skip_identities_fn : TODO
name
Allows to override this optimizer name.
pdb : bool
If True, we invoke pdb when the first node in the pattern matches.
tracks : optional
The values that self.tracks() will return. Useful to speed up
optimization sometimes.
get_nodes : optional
If you provide `tracks`, you must provide this parameter. It must be a
function that takes the tracked node and returns a list of nodes on
which we will try this optimizer.
Notes
-----
`tracks` and `get_nodes` can be used to make this optimizer track a less
frequent Op, so this will make this optimizer tried less frequently.
Examples
--------
PatternSub((add, 'x', 'y'), (add, 'y', 'x'))
PatternSub((add, 'x', 'y'), (add, 'y', 'x'))
PatternSub((multiply, 'x', 'x'), (square, 'x'))
PatternSub((multiply, 'x', 'x'), (square, 'x'))
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
...
@@ -1137,31 +1253,6 @@ class PatternSub(LocalOptimizer):
...
@@ -1137,31 +1253,6 @@ class PatternSub(LocalOptimizer):
allow_multiple_clients
=
False
,
allow_multiple_clients
=
False
,
skip_identities_fn
=
None
,
name
=
None
,
pdb
=
False
,
skip_identities_fn
=
None
,
name
=
None
,
pdb
=
False
,
tracks
=
(),
get_nodes
=
None
):
tracks
=
(),
get_nodes
=
None
):
"""
Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern.
:param in_pattern: the input pattern that we want to replace
:param out_pattern: the replacement pattern
:param allow_multiple_clients: if False, the pattern matching will fail
if one of the subpatterns has more than
one client.
:param skip_identities_fn: TODO
:param name: Allow to override this optimizer name
:param pdb: if True, we invoke pdb when the first node in the
pattern match.
:param tracks: Optional. The values that self.tracks() will
return. Useful to speed up optimization some times.
:param get_nodes: Optional. If you provide `tracks`, you must
provide this parameter. It must be a function that take the
tracked node and return a list of node on which we will try
this optimizer.
`tracks` and `get_nodes` can be used to make this optimizer
track a less frequent Op, so this will make this optimizer
tried less frequently,
"""
self
.
in_pattern
=
in_pattern
self
.
in_pattern
=
in_pattern
self
.
out_pattern
=
out_pattern
self
.
out_pattern
=
out_pattern
if
isinstance
(
in_pattern
,
(
list
,
tuple
)):
if
isinstance
(
in_pattern
,
(
list
,
tuple
)):
...
@@ -1196,6 +1287,7 @@ class PatternSub(LocalOptimizer):
...
@@ -1196,6 +1287,7 @@ class PatternSub(LocalOptimizer):
"""
"""
Checks if the graph from node corresponds to in_pattern. If it does,
Checks if the graph from node corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement.
constructs out_pattern and performs the replacement.
"""
"""
if
get_nodes
and
self
.
get_nodes
is
not
None
:
if
get_nodes
and
self
.
get_nodes
is
not
None
:
for
real_node
in
self
.
get_nodes
(
node
):
for
real_node
in
self
.
get_nodes
(
node
):
...
@@ -1357,12 +1449,40 @@ class Updater:
...
@@ -1357,12 +1449,40 @@ class Updater:
class
NavigatorOptimizer
(
Optimizer
):
class
NavigatorOptimizer
(
Optimizer
):
"""Abstract class
"""
Abstract class.
Parameters
----------
local_opt
A LocalOptimizer to apply over a FunctionGraph (or None is Ok too).
ignore_newtrees
- True: new subgraphs returned by an optimization is not a
candidate for optimization.
- False: new subgraphs returned by an optimization is a candidate
for optimization.
- 'auto': let the local_opt set this parameter via its 'reentrant'
attribute.
failure_callback
A function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables
will be 'None'.
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
transform().
If this parameter is None, then exceptions are not caught here
(raised normally).
"""
"""
@staticmethod
@staticmethod
def
warn
(
exc
,
nav
,
repl_pairs
,
local_opt
):
def
warn
(
exc
,
nav
,
repl_pairs
,
local_opt
):
"""failure_callback for NavigatorOptimizer: print traceback
"""
Failure_callback for NavigatorOptimizer: print traceback.
"""
"""
if
config
.
on_opt_error
!=
'ignore'
:
if
config
.
on_opt_error
!=
'ignore'
:
_logger
.
error
(
"Optimization failure due to:
%
s"
%
str
(
local_opt
))
_logger
.
error
(
"Optimization failure due to:
%
s"
%
str
(
local_opt
))
...
@@ -1377,9 +1497,11 @@ class NavigatorOptimizer(Optimizer):
...
@@ -1377,9 +1497,11 @@ class NavigatorOptimizer(Optimizer):
@staticmethod
@staticmethod
def
warn_inplace
(
exc
,
nav
,
repl_pairs
,
local_opt
):
def
warn_inplace
(
exc
,
nav
,
repl_pairs
,
local_opt
):
"""failure_callback for NavigatorOptimizer
"""
Failure_callback for NavigatorOptimizer.
Ignore InconsistencyErrors, print traceback.
ignore InconsistencyErrors, print traceback
"""
"""
if
isinstance
(
exc
,
InconsistencyError
):
if
isinstance
(
exc
,
InconsistencyError
):
return
return
...
@@ -1387,36 +1509,14 @@ class NavigatorOptimizer(Optimizer):
...
@@ -1387,36 +1509,14 @@ class NavigatorOptimizer(Optimizer):
@staticmethod
@staticmethod
def
warn_ignore
(
exc
,
nav
,
repl_pairs
,
local_opt
):
def
warn_ignore
(
exc
,
nav
,
repl_pairs
,
local_opt
):
"""failure_callback for NavigatorOptimizer: ignore all errors
"""
Failure_callback for NavigatorOptimizer: ignore all errors.
"""
"""
pass
pass
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
'auto'
,
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
'auto'
,
failure_callback
=
None
):
failure_callback
=
None
):
"""
:param local_opt: a LocalOptimizer to apply over a FunctionGraph
(or None is Ok too).
:param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a
candidate for optimization
- False: new subgraphs returned by an optimization is a candidate
for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant'
attribute.
:param failure_callback:
a function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables
will be 'None'.
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
transform().
If this parameter is None, then exceptions are not caught here
(raised normally).
"""
self
.
local_opt
=
local_opt
self
.
local_opt
=
local_opt
if
ignore_newtrees
==
'auto'
:
if
ignore_newtrees
==
'auto'
:
self
.
ignore_newtrees
=
not
getattr
(
local_opt
,
'reentrant'
,
True
)
self
.
ignore_newtrees
=
not
getattr
(
local_opt
,
'reentrant'
,
True
)
...
@@ -1429,14 +1529,23 @@ class NavigatorOptimizer(Optimizer):
...
@@ -1429,14 +1529,23 @@ class NavigatorOptimizer(Optimizer):
Install some FunctionGraph listeners to help the navigator deal with
Install some FunctionGraph listeners to help the navigator deal with
the ignore_trees-related functionality.
the ignore_trees-related functionality.
:param importer: function that will be called whenever when
Parameters
optimizations add stuff to the graph.
----------
:param pruner: function to be called when optimizations remove stuff
importer
from graph.
Function that will be called whenever optimizations add stuff
:param chin: "on change input" called whenever an node's inputs change.
to the graph.
pruner
:returns: The FunctionGraph plugin that handles the three tasks.
Function to be called when optimizations remove stuff
from the graph.
chin
"on change input" called whenever a node's inputs change.
Returns
-------
object
The FunctionGraph plugin that handles the three tasks.
Keep this around so that you can detach later!
Keep this around so that you can detach later!
"""
"""
if
self
.
ignore_newtrees
:
if
self
.
ignore_newtrees
:
importer
=
None
importer
=
None
...
@@ -1449,11 +1558,18 @@ class NavigatorOptimizer(Optimizer):
...
@@ -1449,11 +1558,18 @@ class NavigatorOptimizer(Optimizer):
return
u
return
u
def
detach_updater
(
self
,
fgraph
,
u
):
def
detach_updater
(
self
,
fgraph
,
u
):
"""Undo the work of attach_updater.
"""
Undo the work of attach_updater.
Parameters
----------
u
A return-value of attach_updater.
:param u: a return-value of attach_updater
Returns
-------
None
:returns: None.
"""
"""
if
u
is
not
None
:
if
u
is
not
None
:
fgraph
.
remove_feature
(
u
)
fgraph
.
remove_feature
(
u
)
...
@@ -1470,12 +1586,20 @@ class NavigatorOptimizer(Optimizer):
...
@@ -1470,12 +1586,20 @@ class NavigatorOptimizer(Optimizer):
If there are no replacement candidates or the fgraph rejects the
If there are no replacement candidates or the fgraph rejects the
replacements, this function returns False.
replacements, this function returns False.
:param fgraph: a FunctionGraph
Parameters
:param node: an Apply instance in `fgraph`
----------
:param lopt: a LocalOptimizer instance that may have a better idea for
fgraph
A FunctionGraph.
node
An Apply instance in `fgraph`
lopt
A LocalOptimizer instance that may have a better idea for
how to compute node's outputs.
how to compute node's outputs.
:rtype: Bool
:returns: True iff the `node`'s outputs were replaced in the `fgraph`.
Returns
-------
bool
True iff the `node`'s outputs were replaced in the `fgraph`.
"""
"""
lopt
=
lopt
or
self
.
local_opt
lopt
=
lopt
or
self
.
local_opt
...
@@ -1544,7 +1668,10 @@ class NavigatorOptimizer(Optimizer):
...
@@ -1544,7 +1668,10 @@ class NavigatorOptimizer(Optimizer):
class
TopoOptimizer
(
NavigatorOptimizer
):
class
TopoOptimizer
(
NavigatorOptimizer
):
"""WRITEME"""
"""
WRITEME
"""
def
__init__
(
self
,
local_opt
,
order
=
'in_to_out'
,
ignore_newtrees
=
False
,
def
__init__
(
self
,
local_opt
,
order
=
'in_to_out'
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
failure_callback
=
None
):
...
@@ -1617,7 +1744,10 @@ class TopoOptimizer(NavigatorOptimizer):
...
@@ -1617,7 +1744,10 @@ class TopoOptimizer(NavigatorOptimizer):
class
OpKeyOptimizer
(
NavigatorOptimizer
):
class
OpKeyOptimizer
(
NavigatorOptimizer
):
"""WRITEME"""
"""
WRITEME
"""
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
False
,
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
failure_callback
=
None
):
...
@@ -1661,6 +1791,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
...
@@ -1661,6 +1791,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
Requires the following features:
Requires the following features:
- NodeFinder
- NodeFinder
- ReplaceValidate(Added by default)
- ReplaceValidate(Added by default)
"""
"""
super
(
OpKeyOptimizer
,
self
)
.
add_requirements
(
fgraph
)
super
(
OpKeyOptimizer
,
self
)
.
add_requirements
(
fgraph
)
fgraph
.
attach_feature
(
toolbox
.
NodeFinder
())
fgraph
.
attach_feature
(
toolbox
.
NodeFinder
())
...
@@ -1686,24 +1817,27 @@ class ChangeTracker:
...
@@ -1686,24 +1817,27 @@ class ChangeTracker:
class
EquilibriumOptimizer
(
NavigatorOptimizer
):
class
EquilibriumOptimizer
(
NavigatorOptimizer
):
"""
Apply optimizations until equilibrium point.
Parameters
----------
optimizers
List or set of local or global optimizations to apply until equilibrium.
max_use_ratio
Each optimizer can be applied at most (size of graph * this number)
times.
ignore_newtrees
See EquilibriumDB ignore_newtrees parameter definition.
"""
def
__init__
(
self
,
def
__init__
(
self
,
optimizers
,
optimizers
,
failure_callback
=
None
,
failure_callback
=
None
,
ignore_newtrees
=
True
,
ignore_newtrees
=
True
,
max_use_ratio
=
None
,
max_use_ratio
=
None
,
final_optimizers
=
None
):
final_optimizers
=
None
):
""" Apply optimizations until equilibrium point.
:param optimizers: list or set of local or global optimizations to
apply until equilibrium.
:param max_use_ratio: each optimizer can be applied at most
(size of graph * this number) times
:param ignore_newtrees: See EquilibriumDB ignore_newtrees
parameter definition
"""
super
(
EquilibriumOptimizer
,
self
)
.
__init__
(
super
(
EquilibriumOptimizer
,
self
)
.
__init__
(
None
,
None
,
ignore_newtrees
=
ignore_newtrees
,
ignore_newtrees
=
ignore_newtrees
,
...
@@ -2083,8 +2217,10 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -2083,8 +2217,10 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def
_check_chain
(
r
,
chain
):
def
_check_chain
(
r
,
chain
):
"""WRITEME"""
"""
WRITEME
"""
chain
=
list
(
reversed
(
chain
))
chain
=
list
(
reversed
(
chain
))
while
chain
:
while
chain
:
elem
=
chain
.
pop
()
elem
=
chain
.
pop
()
...
@@ -2115,17 +2251,20 @@ def _check_chain(r, chain):
...
@@ -2115,17 +2251,20 @@ def _check_chain(r, chain):
def
check_chain
(
r
,
*
chain
):
def
check_chain
(
r
,
*
chain
):
"""WRITEME"""
"""
WRITEME
"""
if
isinstance
(
r
,
graph
.
Apply
):
if
isinstance
(
r
,
graph
.
Apply
):
r
=
r
.
outputs
[
0
]
r
=
r
.
outputs
[
0
]
return
_check_chain
(
r
,
reduce
(
list
.
__iadd__
,
([
x
,
0
]
for
x
in
chain
)))
return
_check_chain
(
r
,
reduce
(
list
.
__iadd__
,
([
x
,
0
]
for
x
in
chain
)))
def
pre_greedy_local_optimizer
(
list_optimizations
,
out
):
def
pre_greedy_local_optimizer
(
list_optimizations
,
out
):
'''
"""
This function traverses the computation graph described by all
This function traverses the computation graph described by all
``node`` in the graph before the variable out but that are not in the
``node`` in the graph before the variable out but that are not in the
fgraph.
i
t applies each of the local_optimizations on the traversed graph.
fgraph.
I
t applies each of the local_optimizations on the traversed graph.
Its main use is to apply locally constant folding when generating
Its main use is to apply locally constant folding when generating
the graph of the indices of a subtensor.
the graph of the indices of a subtensor.
...
@@ -2133,11 +2272,14 @@ def pre_greedy_local_optimizer(list_optimizations, out):
...
@@ -2133,11 +2272,14 @@ def pre_greedy_local_optimizer(list_optimizations, out):
We should not apply optimizations on node that are in fgraph.
We should not apply optimizations on node that are in fgraph.
So we don't optimize node that have an attribute fgraph.
So we don't optimize node that have an attribute fgraph.
:note: This don't do an equilibrium... So if there is optimization
Notes
-----
This doesn't do an equilibrium... So if there is optimization
like local_upcast_elemwise_constant_inputs in the list, that
like local_upcast_elemwise_constant_inputs in the list, that
add additional node to the inputs of the node, it can
adds additional node to the inputs of the node, it can
be needed to call this function multiple time.
be needed to call this function multiple times.
'''
"""
def
local_recursive_function
(
list_opt
,
out
,
optimized_vars
,
depth
):
def
local_recursive_function
(
list_opt
,
out
,
optimized_vars
,
depth
):
if
not
getattr
(
out
,
'owner'
,
None
):
if
not
getattr
(
out
,
'owner'
,
None
):
return
[
out
],
optimized_vars
return
[
out
],
optimized_vars
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论