Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
46d46d76
提交
46d46d76
authored
8月 13, 2015
作者:
Iban Harlouchet
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
numpydoc for theano/gof/opt.py
上级
ae99f41d
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
301 行增加
和
159 行删除
+301
-159
opt.py
theano/gof/opt.py
+301
-159
没有找到文件。
theano/gof/opt.py
浏览文件 @
46d46d76
"""
Defines the base class for optimizations as well as a certain
amount of useful generic optimization tools.
"""
from
__future__
import
print_function
...
...
@@ -35,10 +36,13 @@ def _list_of_nodes(fgraph):
class
Optimizer
(
object
):
"""WRITEME
"""
WRITEME
An L{Optimizer} can be applied to an L{FunctionGraph} to transform it.
It can represent an optimization or in general any kind
of transformation you could apply to an L{FunctionGraph}.
"""
def
__hash__
(
self
):
...
...
@@ -58,19 +62,25 @@ class Optimizer(object):
return
id
(
self
)
!=
id
(
other
)
def
apply
(
self
,
fgraph
):
"""WRITEME
"""
WRITEME
Applies the optimization to the provided L{FunctionGraph}. It may
use all the methods defined by the L{FunctionGraph}. If the
L{Optimizer} needs to use a certain tool, such as an
L{InstanceFinder}, it can do so in its L{add_requirements} method.
"""
pass
def
optimize
(
self
,
fgraph
,
*
args
,
**
kwargs
):
"""WRITEME
This is meant as a shortcut to::
"""
WRITEME
This is meant as a shortcut to:
opt.add_requirements(fgraph)
opt.apply(fgraph)
"""
self
.
add_requirements
(
fgraph
)
try
:
...
...
@@ -82,18 +92,24 @@ class Optimizer(object):
return
ret
def
__call__
(
self
,
fgraph
):
"""WRITEME
Same as self.optimize(fgraph)
"""
WRITEME
Same as self.optimize(fgraph).
"""
return
self
.
optimize
(
fgraph
)
def
add_requirements
(
self
,
fgraph
):
"""WRITEME
"""
WRITEME
Add features to the fgraph that are required to apply the optimization.
For example:
fgraph.attach_feature(History())
fgraph.attach_feature(MyFeature())
etc.
"""
pass
...
...
@@ -111,7 +127,10 @@ class Optimizer(object):
class
FromFunctionOptimizer
(
Optimizer
):
"""WRITEME"""
"""
WRITEME
"""
def
__init__
(
self
,
fn
,
requirements
=
()):
self
.
apply
=
fn
self
.
requirements
=
requirements
...
...
@@ -134,14 +153,20 @@ class FromFunctionOptimizer(Optimizer):
def
optimizer
(
f
):
"""decorator for FromFunctionOptimizer"""
"""
Decorator for FromFunctionOptimizer.
"""
rval
=
FromFunctionOptimizer
(
f
)
rval
.
__name__
=
f
.
__name__
return
rval
def
inplace_optimizer
(
f
):
"""decorator for FromFunctionOptimizer"""
"""
Decorator for FromFunctionOptimizer.
"""
dh_handler
=
dh
.
DestroyHandler
requirements
=
(
lambda
fgraph
:
fgraph
.
attach_feature
(
dh_handler
()),)
...
...
@@ -152,13 +177,18 @@ def inplace_optimizer(f):
class
SeqOptimizer
(
Optimizer
,
list
):
# inherit from Optimizer first to get Optimizer.__hash__
"""WRITEME
"""
WRITEME
Takes a list of L{Optimizer} instances and applies them
sequentially.
"""
@staticmethod
def
warn
(
exc
,
self
,
optimizer
):
"""Default failure_callback for SeqOptimizer
"""
Default failure_callback for SeqOptimizer.
"""
_logger
.
error
(
"SeqOptimizer apply
%
s"
%
str
(
optimizer
))
_logger
.
error
(
"Traceback:"
)
...
...
@@ -169,15 +199,21 @@ class SeqOptimizer(Optimizer, list):
pdb
.
post_mortem
(
sys
.
exc_info
()[
2
])
def
__init__
(
self
,
*
opts
,
**
kw
):
"""WRITEME"""
"""
WRITEME
"""
if
len
(
opts
)
==
1
and
isinstance
(
opts
[
0
],
(
list
,
tuple
)):
opts
=
opts
[
0
]
self
[:]
=
opts
self
.
failure_callback
=
kw
.
pop
(
'failure_callback'
,
None
)
def
apply
(
self
,
fgraph
):
"""WRITEME
"""
WRITEME
Applies each L{Optimizer} in self in turn.
"""
l
=
[]
if
fgraph
.
profile
:
...
...
@@ -286,6 +322,7 @@ class SeqOptimizer(Optimizer, list):
def
merge_profile
(
prof1
,
prof2
):
"""
Merge 2 profiles returned by this cass apply() fct.
"""
new_t
=
[]
new_l
=
[]
...
...
@@ -354,7 +391,11 @@ class SeqOptimizer(Optimizer, list):
class
_metadict
:
"""WRITEME"""
"""
WRITEME
"""
# dict that accepts unhashable keys
# uses an associative list
# for internal use only
...
...
@@ -430,6 +471,7 @@ class MergeFeature(object):
That way, the MergeOptimizer can remember the result of the last merge
pass on the fgraph.
"""
def
on_attach
(
self
,
fgraph
):
assert
not
hasattr
(
fgraph
,
'merge_feature'
)
...
...
@@ -493,7 +535,10 @@ class MergeFeature(object):
self
.
seen_constants
.
discard
(
id
(
c
))
def
process_constant
(
self
,
fgraph
,
c
):
"""Check if a constant can be merged, and queue that replacement"""
"""
Check if a constant can be merged, and queue that replacement.
"""
if
id
(
c
)
in
self
.
seen_constants
:
return
sig
=
c
.
merge_signature
()
...
...
@@ -511,7 +556,10 @@ class MergeFeature(object):
self
.
seen_constants
.
add
(
id
(
c
))
def
process_node
(
self
,
fgraph
,
node
):
"""Check if a node can be merged, and queue that replacement."""
"""
Check if a node can be merged, and queue that replacement.
"""
if
node
in
self
.
nodes_seen
:
return
...
...
@@ -570,6 +618,7 @@ class MergeOptimizer(Optimizer):
The first step of merging is constant-merging, so that all clients of an
int(1) for example, are transferred to a particular instance of int(1).
"""
def
add_requirements
(
self
,
fgraph
):
...
...
@@ -678,6 +727,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
Merge-based implementation of `theano.gof.graph.is_same_graph`.
See help on `theano.gof.graph.is_same_graph` for additional documentation.
"""
if
givens
is
None
:
givens
=
{}
...
...
@@ -718,13 +768,15 @@ def pre_constant_merge(vars):
`vars` is a list of nodes, and we want to merge together nodes
that are constant inputs used to compute nodes in that list.
:note: This function will ignore nodes that are in an fgraph.
It is used to pre-merge nodes generated inside an optimization,
before it is inserted in the fgraph.
It is useful if there are many such replacements to make,
so that DebugMode will not check each of them.
"""
Notes
-----
This function will ignore nodes that are in an fgraph.
It is used to pre-merge nodes generated inside an optimization,
before it is inserted in the fgraph.
It is useful if there are many such replacements to make,
so that DebugMode will not check each of them.
"""
seen_var
=
set
()
# signature -> variable (for constants)
const_sig_inv
=
{}
...
...
@@ -767,10 +819,12 @@ def pre_constant_merge(vars):
########################
class
LocalOptimizer
(
object
):
"""A class for node-based optimizations.
"""
A class for node-based optimizations.
Instances should implement the transform function,
and be passed to configure a fgraph-based Optimizer instance.
"""
def
__hash__
(
self
):
...
...
@@ -784,11 +838,13 @@ class LocalOptimizer(object):
Return the list of op classes that this opt applies to.
Return None to apply to all nodes.
"""
return
None
def
transform
(
self
,
node
):
"""Transform a subgraph whose output is `node`.
"""
Transform a subgraph whose output is `node`.
Subclasses should implement this function so that it returns one of two
kinds of things:
...
...
@@ -800,7 +856,9 @@ class LocalOptimizer(object):
- dict(old variables -> new variables). A dictionary that map
from old variables to new variables to replace.
:type node: an Apply instance
Parameters
----------
node : an Apply instance
"""
...
...
@@ -810,8 +868,8 @@ class LocalOptimizer(object):
def
add_requirements
(
self
,
fgraph
):
"""
If this local optimization wants to add some requirements to the
fgraph,
This is the place to do it.
fgraph,
this is the place to do it.
"""
# Added by default
# fgraph.attach_feature(toolbox.ReplaceValidate())
...
...
@@ -830,8 +888,11 @@ theano.configparser.AddConfigVar(
class
LocalMetaOptimizer
(
LocalOptimizer
):
"""Base class for meta-optimizers that try a set of LocalOptimizers
to replace a node and choose the one that executes the fastest"""
"""
Base class for meta-optimizers that try a set of LocalOptimizers
to replace a node and choose the one that executes the fastest.
"""
def
__init__
(
self
,
tracks
=
None
,
optimizers
=
()):
self
.
_tracks
=
tracks
...
...
@@ -907,9 +968,12 @@ class LocalMetaOptimizer(LocalOptimizer):
return
def
provide_inputs
(
self
,
node
,
inputs
):
"""If implemented, returns a dictionary mapping all symbolic variables
in ``inputs`` to SharedVariable instances of suitable dummy values. The
``node`` can be inspected to infer required input shapes."""
"""
If implemented, returns a dictionary mapping all symbolic variables
in ``inputs`` to SharedVariable instances of suitable dummy values.
The ``node`` can be inspected to infer required input shapes.
"""
raise
NotImplementedError
()
def
time_call
(
self
,
fn
):
...
...
@@ -919,7 +983,10 @@ class LocalMetaOptimizer(LocalOptimizer):
class
FromFunctionLocalOptimizer
(
LocalOptimizer
):
"""WRITEME"""
"""
WRITEME
"""
def
__init__
(
self
,
fn
,
tracks
=
None
,
requirements
=
()):
self
.
transform
=
fn
self
.
_tracks
=
tracks
...
...
@@ -945,7 +1012,10 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
def
local_optimizer
(
tracks
,
inplace
=
False
):
def
decorator
(
f
):
"""WRITEME"""
"""
WRITEME
"""
if
tracks
is
not
None
:
if
len
(
tracks
)
is
0
:
raise
ValueError
(
"Use None instead of an empty list to apply to all nodes."
,
f
.
__module__
,
f
.
__name__
)
...
...
@@ -964,7 +1034,10 @@ def local_optimizer(tracks, inplace=False):
class
LocalOptGroup
(
LocalOptimizer
):
"""WRITEME"""
"""
WRITEME
"""
def
__init__
(
self
,
*
optimizers
):
if
len
(
optimizers
)
==
1
and
isinstance
(
optimizers
[
0
],
list
):
...
...
@@ -1009,12 +1082,23 @@ class LocalOptGroup(LocalOptimizer):
class
OpSub
(
LocalOptimizer
):
"""WRITEME
"""
WRITEME
Replaces the application of a certain op by the application of
another op that take the same inputs as what they are replacing.
another op that take
s
the same inputs as what they are replacing.
e.g. OpSub(add, sub) ==>
Parameters
----------
op1, op2
op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs.
Examples
--------
OpSub(add, sub) ==>
add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
"""
# an OpSub does not apply to the nodes it produces
...
...
@@ -1023,10 +1107,6 @@ class OpSub(LocalOptimizer):
retains_inputs
=
True
def
__init__
(
self
,
op1
,
op2
,
transfer_tags
=
True
):
"""
op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs.
"""
self
.
op1
=
op1
self
.
op2
=
op2
self
.
transfer_tags
=
transfer_tags
...
...
@@ -1052,9 +1132,12 @@ class OpSub(LocalOptimizer):
class
OpRemove
(
LocalOptimizer
):
"""WRITEME
"""
WRITEME
Removes all applications of an op by transferring each of its
outputs to the corresponding input.
"""
reentrant
=
False
# no nodes are added at all
...
...
@@ -1085,25 +1168,27 @@ class OpRemove(LocalOptimizer):
class
PatternSub
(
LocalOptimizer
):
"""WRITEME
"""
WRITEME
@todo update
Replaces all occurrences of the input pattern by the output pattern:
input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= dict(pattern = <input_pattern>,
input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= dict(pattern = <input_pattern>,
constraint = <constraint>)
sub_pattern ::= input_pattern
sub_pattern ::= string
sub_pattern ::= a Constant instance
sub_pattern ::= int
sub_pattern ::= float
constraint ::= lambda fgraph, expr: additional matching condition
output_pattern ::= (op, <output_pattern1>, <output_pattern2>, ...)
output_pattern ::= string
output_pattern ::= int
output_pattern ::= float
sub_pattern ::= input_pattern
sub_pattern ::= string
sub_pattern ::= a Constant instance
sub_pattern ::= int
sub_pattern ::= float
constraint ::= lambda fgraph, expr: additional matching condition
output_pattern ::= (op, <output_pattern1>, <output_pattern2>, ...)
output_pattern ::= string
output_pattern ::= int
output_pattern ::= float
Each string in the input pattern is a variable that will be set to
whatever expression is found in its place. If the same string is
...
...
@@ -1123,45 +1208,51 @@ class PatternSub(LocalOptimizer):
trying to match and returns True or False according to an
arbitrary criterion.
Examples:
PatternSub((add, 'x', 'y'), (add, 'y', 'x'))
PatternSub((multiply, 'x', 'x'), (square, 'x'))
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x'))
The constructor creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern.
Parameters
----------
in_pattern
The input pattern that we want to replace.
out_pattern
The replacement pattern.
allow_multiple_clients : bool
If False, the pattern matching will fail if one of the subpatterns has
more than one client.
skip_identities_fn : TODO
name
Allows to override this optimizer name.
pdb : bool
If True, we invoke pdb when the first node in the pattern matches.
tracks : optional
The values that self.tracks() will return. Useful to speed up
optimization sometimes.
get_nodes : optional
If you provide `tracks`, you must provide this parameter. It must be a
function that takes the tracked node and returns a list of nodes on
which we will try this optimizer.
Notes
-----
`tracks` and `get_nodes` can be used to make this optimizer track a less
frequent Op, so this will make this optimizer tried less frequently.
Examples
--------
PatternSub((add, 'x', 'y'), (add, 'y', 'x'))
PatternSub((multiply, 'x', 'x'), (square, 'x'))
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x'))
"""
def
__init__
(
self
,
in_pattern
,
out_pattern
,
allow_multiple_clients
=
False
,
skip_identities_fn
=
None
,
name
=
None
,
pdb
=
False
,
tracks
=
(),
get_nodes
=
None
):
"""
Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern.
:param in_pattern: the input pattern that we want to replace
:param out_pattern: the replacement pattern
:param allow_multiple_clients: if False, the pattern matching will fail
if one of the subpatterns has more than
one client.
:param skip_identities_fn: TODO
:param name: Allow to override this optimizer name
:param pdb: if True, we invoke pdb when the first node in the
pattern match.
:param tracks: Optional. The values that self.tracks() will
return. Useful to speed up optimization some times.
:param get_nodes: Optional. If you provide `tracks`, you must
provide this parameter. It must be a function that take the
tracked node and return a list of node on which we will try
this optimizer.
`tracks` and `get_nodes` can be used to make this optimizer
track a less frequent Op, so this will make this optimizer
tried less frequently,
"""
self
.
in_pattern
=
in_pattern
self
.
out_pattern
=
out_pattern
if
isinstance
(
in_pattern
,
(
list
,
tuple
)):
...
...
@@ -1196,6 +1287,7 @@ class PatternSub(LocalOptimizer):
"""
Checks if the graph from node corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement.
"""
if
get_nodes
and
self
.
get_nodes
is
not
None
:
for
real_node
in
self
.
get_nodes
(
node
):
...
...
@@ -1357,12 +1449,40 @@ class Updater:
class
NavigatorOptimizer
(
Optimizer
):
"""Abstract class
"""
Abstract class.
Parameters
----------
local_opt
A LocalOptimizer to apply over a FunctionGraph (or None is Ok too).
ignore_newtrees
- True: new subgraphs returned by an optimization is not a
candidate for optimization.
- False: new subgraphs returned by an optimization is a candidate
for optimization.
- 'auto': let the local_opt set this parameter via its 'reentrant'
attribute.
failure_callback
A function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables
will be 'None'.
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
transform().
If this parameter is None, then exceptions are not caught here
(raised normally).
"""
@staticmethod
def
warn
(
exc
,
nav
,
repl_pairs
,
local_opt
):
"""failure_callback for NavigatorOptimizer: print traceback
"""
Failure_callback for NavigatorOptimizer: print traceback.
"""
if
config
.
on_opt_error
!=
'ignore'
:
_logger
.
error
(
"Optimization failure due to:
%
s"
%
str
(
local_opt
))
...
...
@@ -1377,9 +1497,11 @@ class NavigatorOptimizer(Optimizer):
@staticmethod
def
warn_inplace
(
exc
,
nav
,
repl_pairs
,
local_opt
):
"""failure_callback for NavigatorOptimizer
"""
Failure_callback for NavigatorOptimizer.
Ignore InconsistencyErrors, print traceback.
ignore InconsistencyErrors, print traceback
"""
if
isinstance
(
exc
,
InconsistencyError
):
return
...
...
@@ -1387,36 +1509,14 @@ class NavigatorOptimizer(Optimizer):
@staticmethod
def
warn_ignore
(
exc
,
nav
,
repl_pairs
,
local_opt
):
"""failure_callback for NavigatorOptimizer: ignore all errors
"""
Failure_callback for NavigatorOptimizer: ignore all errors.
"""
pass
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
'auto'
,
failure_callback
=
None
):
"""
:param local_opt: a LocalOptimizer to apply over a FunctionGraph
(or None is Ok too).
:param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a
candidate for optimization
- False: new subgraphs returned by an optimization is a candidate
for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant'
attribute.
:param failure_callback:
a function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables
will be 'None'.
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
transform().
If this parameter is None, then exceptions are not caught here
(raised normally).
"""
self
.
local_opt
=
local_opt
if
ignore_newtrees
==
'auto'
:
self
.
ignore_newtrees
=
not
getattr
(
local_opt
,
'reentrant'
,
True
)
...
...
@@ -1429,14 +1529,23 @@ class NavigatorOptimizer(Optimizer):
Install some FunctionGraph listeners to help the navigator deal with
the ignore_trees-related functionality.
:param importer: function that will be called whenever when
optimizations add stuff to the graph.
:param pruner: function to be called when optimizations remove stuff
from graph.
:param chin: "on change input" called whenever an node's inputs change.
:returns: The FunctionGraph plugin that handles the three tasks.
Parameters
----------
importer
Function that will be called whenever optimizations add stuff
to the graph.
pruner
Function to be called when optimizations remove stuff
from the graph.
chin
"on change input" called whenever a node's inputs change.
Returns
-------
object
The FunctionGraph plugin that handles the three tasks.
Keep this around so that you can detach later!
"""
if
self
.
ignore_newtrees
:
importer
=
None
...
...
@@ -1449,18 +1558,25 @@ class NavigatorOptimizer(Optimizer):
return
u
def
detach_updater
(
self
,
fgraph
,
u
):
"""Undo the work of attach_updater.
"""
Undo the work of attach_updater.
Parameters
----------
u
A return-value of attach_updater.
:param u: a return-value of attach_updater
Returns
-------
None
:returns: None.
"""
if
u
is
not
None
:
fgraph
.
remove_feature
(
u
)
def
process_node
(
self
,
fgraph
,
node
,
lopt
=
None
):
"""
This function will use `lopt` to `transform` the `node`.
The
This function will use `lopt` to `transform` the `node`. The
`transform` method will return either False or a list of Variables
that are intended to replace `node.outputs`.
...
...
@@ -1470,12 +1586,20 @@ class NavigatorOptimizer(Optimizer):
If there are no replacement candidates or the fgraph rejects the
replacements, this function returns False.
:param fgraph: a FunctionGraph
:param node: an Apply instance in `fgraph`
:param lopt: a LocalOptimizer instance that may have a better idea for
Parameters
----------
fgraph
A FunctionGraph.
node
An Apply instance in `fgraph`
lopt
A LocalOptimizer instance that may have a better idea for
how to compute node's outputs.
:rtype: Bool
:returns: True iff the `node`'s outputs were replaced in the `fgraph`.
Returns
-------
bool
True iff the `node`'s outputs were replaced in the `fgraph`.
"""
lopt
=
lopt
or
self
.
local_opt
...
...
@@ -1544,7 +1668,10 @@ class NavigatorOptimizer(Optimizer):
class
TopoOptimizer
(
NavigatorOptimizer
):
"""WRITEME"""
"""
WRITEME
"""
def
__init__
(
self
,
local_opt
,
order
=
'in_to_out'
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
...
...
@@ -1617,7 +1744,10 @@ class TopoOptimizer(NavigatorOptimizer):
class
OpKeyOptimizer
(
NavigatorOptimizer
):
"""WRITEME"""
"""
WRITEME
"""
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
...
...
@@ -1661,6 +1791,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
Requires the following features:
- NodeFinder
- ReplaceValidate(Added by default)
"""
super
(
OpKeyOptimizer
,
self
)
.
add_requirements
(
fgraph
)
fgraph
.
attach_feature
(
toolbox
.
NodeFinder
())
...
...
@@ -1686,24 +1817,27 @@ class ChangeTracker:
class
EquilibriumOptimizer
(
NavigatorOptimizer
):
"""
Apply optimizations until equilibrium point.
Parameters
----------
optimizers
List or set of local or global optimizations to apply until equilibrium.
max_use_ratio
Each optimizer can be applied at most (size of graph * this number)
times.
ignore_newtrees
See EquilibriumDB ignore_newtrees parameter definition.
"""
def
__init__
(
self
,
optimizers
,
failure_callback
=
None
,
ignore_newtrees
=
True
,
max_use_ratio
=
None
,
final_optimizers
=
None
):
""" Apply optimizations until equilibrium point.
:param optimizers: list or set of local or global optimizations to
apply until equilibrium.
:param max_use_ratio: each optimizer can be applied at most
(size of graph * this number) times
:param ignore_newtrees: See EquilibriumDB ignore_newtrees
parameter definition
"""
super
(
EquilibriumOptimizer
,
self
)
.
__init__
(
None
,
ignore_newtrees
=
ignore_newtrees
,
...
...
@@ -2083,8 +2217,10 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def
_check_chain
(
r
,
chain
):
"""WRITEME"""
"""
WRITEME
"""
chain
=
list
(
reversed
(
chain
))
while
chain
:
elem
=
chain
.
pop
()
...
...
@@ -2115,17 +2251,20 @@ def _check_chain(r, chain):
def
check_chain
(
r
,
*
chain
):
"""WRITEME"""
"""
WRITEME
"""
if
isinstance
(
r
,
graph
.
Apply
):
r
=
r
.
outputs
[
0
]
return
_check_chain
(
r
,
reduce
(
list
.
__iadd__
,
([
x
,
0
]
for
x
in
chain
)))
def
pre_greedy_local_optimizer
(
list_optimizations
,
out
):
'''
"""
This function traverses the computation graph described by all
``node`` in the graph before the variable out but that are not in the
fgraph.
i
t applies each of the local_optimizations on the traversed graph.
fgraph.
I
t applies each of the local_optimizations on the traversed graph.
Its main use is to apply locally constant folding when generating
the graph of the indices of a subtensor.
...
...
@@ -2133,11 +2272,14 @@ def pre_greedy_local_optimizer(list_optimizations, out):
We should not apply optimizations on node that are in fgraph.
So we don't optimize node that have an attribute fgraph.
:note: This don't do an equilibrium... So if there is optimization
like local_upcast_elemwise_constant_inputs in the list, that
add additional node to the inputs of the node, it can
be needed to call this function multiple time.
'''
Notes
-----
This doesn't do an equilibrium... So if there is optimization
like local_upcast_elemwise_constant_inputs in the list, that
adds additional node to the inputs of the node, it can
be needed to call this function multiple times.
"""
def
local_recursive_function
(
list_opt
,
out
,
optimized_vars
,
depth
):
if
not
getattr
(
out
,
'owner'
,
None
):
return
[
out
],
optimized_vars
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论