Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
7e415c47
提交
7e415c47
authored
4月 11, 2012
作者:
Pascal Lamblin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
PEP 8
上级
81e1a1e9
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
218 行增加
和
124 行删除
+218
-124
opt.py
theano/gof/opt.py
+218
-124
没有找到文件。
theano/gof/opt.py
浏览文件 @
7e415c47
...
...
@@ -3,21 +3,23 @@ Defines the base class for optimizations as well as a certain
amount of useful generic optimization tools.
"""
import
copy
,
logging
,
sys
,
time
import
copy
import
logging
import
sys
import
time
import
numpy
import
graph
from
env
import
InconsistencyError
import
op
import
utils
import
unify
import
toolbox
import
op
import
theano
from
theano
import
config
from
theano.gof.python25
import
any
,
all
,
deque
from
theano.configparser
import
AddConfigVar
,
BoolParam
,
config
from
theano.configparser
import
AddConfigVar
,
BoolParam
#if sys.version_info[:2] >= (2,5):
# from collections import defaultdict
...
...
@@ -39,9 +41,11 @@ import traceback
_optimizer_idx
=
[
0
]
def
_list_of_nodes
(
env
):
return
list
(
graph
.
io_toposort
(
env
.
inputs
,
env
.
outputs
))
class
Optimizer
(
object
):
"""WRITEME
An L{Optimizer} can be applied to an L{Env} to transform it.
...
...
@@ -91,26 +95,30 @@ class Optimizer(object):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'name'
,
None
)
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
name
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
name
,
id
(
self
))
class
FromFunctionOptimizer
(
Optimizer
):
"""WRITEME"""
def
__init__
(
self
,
fn
):
self
.
apply
=
fn
def
add_requirements
(
self
,
env
):
# Added by default
#env.extend(toolbox.ReplaceValidate())
pass
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
str
(
self
.
apply
),
id
(
self
))
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
self
.
fn
(
*
args
,
**
kwargs
)
def
optimizer
(
f
):
"""decorator for FromFunctionOptimizer"""
rval
=
FromFunctionOptimizer
(
f
)
...
...
@@ -118,7 +126,6 @@ def optimizer(f):
return
rval
class
SeqOptimizer
(
Optimizer
,
list
):
#inherit from Optimizer first to get Optimizer.__hash__
"""WRITEME
...
...
@@ -129,7 +136,7 @@ class SeqOptimizer(Optimizer, list):
def
warn
(
exc
,
self
,
optimizer
):
"""Default failure_callback for SeqOptimizer
"""
_logger
.
error
(
"SeqOptimizer apply
%
s"
%
str
(
optimizer
))
_logger
.
error
(
"SeqOptimizer apply
%
s"
%
str
(
optimizer
))
_logger
.
error
(
"Traceback:"
)
_logger
.
error
(
traceback
.
format_exc
())
if
config
.
on_opt_error
==
'raise'
:
...
...
@@ -146,14 +153,15 @@ class SeqOptimizer(Optimizer, list):
"""WRITEME
Applies each L{Optimizer} in self in turn.
"""
l
=
[]
l
=
[]
nb_node_before
=
len
(
env
.
nodes
)
for
optimizer
in
self
:
try
:
t0
=
time
.
time
()
t0
=
time
.
time
()
optimizer
.
optimize
(
env
)
l
.
append
(
float
(
time
.
time
()
-
t0
))
except
AssertionError
:
# do not catch Assertion failures
l
.
append
(
float
(
time
.
time
()
-
t0
))
except
AssertionError
:
# do not catch Assertion failures
raise
except
Exception
,
e
:
if
self
.
failure_callback
:
...
...
@@ -192,7 +200,6 @@ class SeqOptimizer(Optimizer, list):
#added to override the list's __neq__ implementation
return
id
(
self
)
!=
id
(
other
)
def
__str__
(
self
):
return
"SeqOpt(
%
s)"
%
list
.
__str__
(
self
)
...
...
@@ -201,14 +208,13 @@ class SeqOptimizer(Optimizer, list):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'name'
,
None
)
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
name
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
name
,
id
(
self
))
# This way, -1 will do all depth
if
depth
!=
0
:
depth
-=
1
for
opt
in
self
:
opt
.
print_summary
(
stream
,
level
=
level
+
2
,
depth
=
depth
)
opt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
depth
=
depth
)
class
_metadict
:
...
...
@@ -219,13 +225,15 @@ class _metadict:
def
__init__
(
self
):
self
.
d
=
{}
self
.
l
=
[]
def
__getitem__
(
self
,
item
):
return
self
.
get
(
item
,
None
)
def
__setitem__
(
self
,
item
,
value
):
try
:
self
.
d
[
item
]
=
value
except
Exception
:
for
i
,
(
key
,
val
)
in
enumerate
(
self
.
l
):
for
i
,
(
key
,
val
)
in
enumerate
(
self
.
l
):
if
key
==
item
:
self
.
l
[
i
]
=
(
item
,
value
)
return
...
...
@@ -265,9 +273,11 @@ class _metadict:
return
value
else
:
return
default
def
clear
(
self
):
self
.
d
=
{}
self
.
l
=
[]
def
__str__
(
self
):
return
"(
%
s,
%
s)"
%
(
self
.
d
,
self
.
l
)
...
...
@@ -528,12 +538,13 @@ def pre_constant_merge(vars):
const_sig_inv
[
sig
]
=
var
return
var
if
var
.
owner
:
for
idx
,
inp
in
enumerate
(
var
.
owner
.
inputs
):
for
idx
,
inp
in
enumerate
(
var
.
owner
.
inputs
):
var
.
owner
.
inputs
[
idx
]
=
recursive_merge
(
inp
)
return
var
return
map
(
recursive_merge
,
vars
)
########################
### Local Optimizers ###
########################
...
...
@@ -557,25 +568,31 @@ class LocalOptimizer(object):
Subclasses should implement this function so that it returns one of two
kinds of things:
- False to indicate that no optimization can be applied to this `node`; or
- <list of variables> to use in place of `node`'s outputs in the greater graph.
- False to indicate that no optimization can be applied to this `node`;
or
- <list of variables> to use in place of `node`'s outputs in the
greater graph.
:type node: an Apply instance
"""
raise
utils
.
MethodNotDefined
(
"transform"
,
type
(
self
),
self
.
__class__
.
__name__
)
raise
utils
.
MethodNotDefined
(
"transform"
,
type
(
self
),
self
.
__class__
.
__name__
)
def
add_requirements
(
self
,
env
):
"""If this local optimization wants to add some requirements to the env,
This is the place to do it."""
"""
If this local optimization wants to add some requirements to the env,
This is the place to do it.
"""
# Added by default
#env.extend(toolbox.ReplaceValidate())
pass
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
id
(
self
))
class
FromFunctionLocalOptimizer
(
LocalOptimizer
):
"""WRITEME"""
...
...
@@ -584,15 +601,21 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
tracks
=
[]
self
.
transform
=
fn
self
.
_tracks
=
tracks
def
tracks
(
self
):
return
self
.
_tracks
def
__str__
(
self
):
return
getattr
(
self
,
'__name__'
,
'<FromFunctionLocalOptimizer instance>'
)
return
getattr
(
self
,
'__name__'
,
'<FromFunctionLocalOptimizer instance>'
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
str
(
self
.
transform
),
id
(
self
))
def
local_optimizer
(
*
tracks
):
def
decorator
(
f
):
"""WRITEME"""
...
...
@@ -607,11 +630,15 @@ class LocalOptGroup(LocalOptimizer):
def
__init__
(
self
,
*
optimizers
):
self
.
opts
=
optimizers
self
.
reentrant
=
any
(
getattr
(
opt
,
'reentrant'
,
True
)
for
opt
in
optimizers
)
self
.
retains_inputs
=
all
(
getattr
(
opt
,
'retains_inputs'
,
False
)
for
opt
in
optimizers
)
self
.
reentrant
=
any
(
getattr
(
opt
,
'reentrant'
,
True
)
for
opt
in
optimizers
)
self
.
retains_inputs
=
all
(
getattr
(
opt
,
'retains_inputs'
,
False
)
for
opt
in
optimizers
)
def
__str__
(
self
):
return
getattr
(
self
,
'__name__'
,
'<theano.gof.opt.LocalOptGroup instance>'
+
str
([
str
(
o
)
for
o
in
self
.
opts
]))
return
getattr
(
self
,
'__name__'
,
(
'<theano.gof.opt.LocalOptGroup instance>'
+
str
([
str
(
o
)
for
o
in
self
.
opts
])))
def
transform
(
self
,
node
):
for
opt
in
self
.
opts
:
...
...
@@ -620,11 +647,12 @@ class LocalOptGroup(LocalOptimizer):
return
repl
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
id
(
self
))
if
depth
!=
0
:
depth
-=
1
for
lopt
in
self
.
opts
:
lopt
.
print_summary
(
stream
,
level
=
level
+
2
,
depth
=
depth
)
lopt
.
print_summary
(
stream
,
level
=
(
level
+
2
)
,
depth
=
depth
)
class
_LocalOpKeyOptGroup
(
LocalOptGroup
):
...
...
@@ -644,13 +672,16 @@ class OpSub(LocalOptimizer):
Replaces the application of a certain op by the application of
another op that take the same inputs as what they are replacing.
e.g. OpSub(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
e.g. OpSub(add, sub) ==>
add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
"""
reentrant
=
False
# an OpSub does not apply to the nodes it produces
retains_inputs
=
True
# all the inputs of the original node are transferred to the outputs
# an OpSub does not apply to the nodes it produces
reentrant
=
False
# all the inputs of the original node are transferred to the outputs
retains_inputs
=
True
def
__init__
(
self
,
op1
,
op2
,
transfer_tags
=
True
):
def
__init__
(
self
,
op1
,
op2
,
transfer_tags
=
True
):
"""
op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs.
...
...
@@ -705,7 +736,8 @@ class OpRemove(LocalOptimizer):
return
"
%
s(x) -> x"
%
(
self
.
op
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s(
%
s) id=
%
i"
%
(
' '
*
level
,
print
>>
stream
,
"
%
s
%
s(
%
s) id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
str
(
self
.
op
),
id
(
self
))
...
...
@@ -756,12 +788,12 @@ class PatternSub(LocalOptimizer):
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}),
'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x'))
"""
def
__init__
(
self
,
in_pattern
,
out_pattern
,
allow_multiple_clients
=
False
,
skip_identities_fn
=
None
,
name
=
None
,
pdb
=
False
):
def
__init__
(
self
,
in_pattern
,
out_pattern
,
allow_multiple_clients
=
False
,
skip_identities_fn
=
None
,
name
=
None
,
pdb
=
False
):
"""
Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern.
...
...
@@ -771,7 +803,8 @@ class PatternSub(LocalOptimizer):
:param allow_multiple_clients: if False, the pattern matching will fail
if one of the subpatterns has more than
one client.
:param pdb: if True, we invoke pdb when the first node in the pattern match.
:param pdb: if True, we invoke pdb when the first node in the
pattern match.
"""
self
.
in_pattern
=
in_pattern
self
.
out_pattern
=
out_pattern
...
...
@@ -780,8 +813,11 @@ class PatternSub(LocalOptimizer):
elif
isinstance
(
in_pattern
,
dict
):
self
.
op
=
self
.
in_pattern
[
'pattern'
][
0
]
else
:
raise
TypeError
(
"The pattern to search for must start with a specific Op instance."
)
self
.
__doc__
=
self
.
__class__
.
__doc__
+
"
\n\n
This instance does: "
+
str
(
self
)
+
"
\n
"
raise
TypeError
(
"The pattern to search for must start with "
"a specific Op instance."
)
self
.
__doc__
=
(
self
.
__class__
.
__doc__
+
"
\n\n
This instance does: "
+
str
(
self
)
+
"
\n
"
)
self
.
allow_multiple_clients
=
allow_multiple_clients
self
.
skip_identities_fn
=
skip_identities_fn
if
name
:
...
...
@@ -816,7 +852,7 @@ class PatternSub(LocalOptimizer):
if
node
.
op
!=
self
.
op
:
return
False
def
match
(
pattern
,
expr
,
u
,
allow_multiple_clients
=
False
,
pdb
=
False
):
def
match
(
pattern
,
expr
,
u
,
allow_multiple_clients
=
False
,
pdb
=
False
):
def
retry_with_equiv
():
expr_equiv
=
self
.
skip_identities
(
expr
)
if
expr_equiv
is
None
:
...
...
@@ -829,7 +865,9 @@ class PatternSub(LocalOptimizer):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
if
expr
.
owner
is
None
:
return
False
if
not
(
expr
.
owner
.
op
==
pattern
[
0
])
or
(
not
allow_multiple_clients
and
len
(
expr
.
clients
)
>
1
):
if
(
not
(
expr
.
owner
.
op
==
pattern
[
0
])
or
(
not
allow_multiple_clients
and
len
(
expr
.
clients
)
>
1
)):
return
retry_with_equiv
()
if
len
(
pattern
)
-
1
!=
len
(
expr
.
owner
.
inputs
):
return
retry_with_equiv
()
...
...
@@ -841,10 +879,14 @@ class PatternSub(LocalOptimizer):
try
:
real_pattern
=
pattern
[
'pattern'
]
except
KeyError
:
raise
KeyError
(
"Malformed pattern:
%
s (expected key 'pattern')"
%
pattern
)
raise
KeyError
(
"Malformed pattern:
%
s (expected key 'pattern')"
%
pattern
)
constraint
=
pattern
.
get
(
'constraint'
,
lambda
expr
:
True
)
if
constraint
(
expr
):
return
match
(
real_pattern
,
expr
,
u
,
pattern
.
get
(
'allow_multiple_clients'
,
allow_multiple_clients
))
return
match
(
real_pattern
,
expr
,
u
,
pattern
.
get
(
'allow_multiple_clients'
,
allow_multiple_clients
))
else
:
return
retry_with_equiv
()
elif
isinstance
(
pattern
,
basestring
):
...
...
@@ -853,17 +895,22 @@ class PatternSub(LocalOptimizer):
return
retry_with_equiv
()
else
:
u
=
u
.
merge
(
expr
,
v
)
elif
isinstance
(
pattern
,
(
int
,
float
))
and
isinstance
(
expr
,
graph
.
Constant
):
if
numpy
.
all
(
theano
.
tensor
.
constant
(
pattern
)
.
value
==
expr
.
value
):
elif
(
isinstance
(
pattern
,
(
int
,
float
))
and
isinstance
(
expr
,
graph
.
Constant
)):
if
numpy
.
all
(
theano
.
tensor
.
constant
(
pattern
)
.
value
==
expr
.
value
):
return
u
else
:
return
retry_with_equiv
()
elif
isinstance
(
pattern
,
graph
.
Constant
)
and
isinstance
(
expr
,
graph
.
Constant
)
and
pattern
.
equals
(
expr
):
elif
(
isinstance
(
pattern
,
graph
.
Constant
)
and
isinstance
(
expr
,
graph
.
Constant
)
and
pattern
.
equals
(
expr
)):
return
u
else
:
return
retry_with_equiv
()
if
pdb
:
import
pdb
;
pdb
.
set_trace
()
import
pdb
pdb
.
set_trace
()
return
u
def
build
(
pattern
,
u
):
...
...
@@ -872,11 +919,12 @@ class PatternSub(LocalOptimizer):
return
pattern
[
0
](
*
args
)
elif
isinstance
(
pattern
,
basestring
):
return
u
[
unify
.
Var
(
pattern
)]
elif
isinstance
(
pattern
,
(
int
,
float
)):
elif
isinstance
(
pattern
,
(
int
,
float
)):
return
pattern
else
:
return
pattern
.
clone
()
u
=
match
(
self
.
in_pattern
,
node
.
out
,
unify
.
Unification
(),
True
,
self
.
pdb
)
u
=
match
(
self
.
in_pattern
,
node
.
out
,
unify
.
Unification
(),
True
,
self
.
pdb
)
if
u
:
p
=
self
.
out_pattern
new
=
build
(
p
,
u
)
...
...
@@ -886,23 +934,31 @@ class PatternSub(LocalOptimizer):
return
False
def
__str__
(
self
):
if
getattr
(
self
,
'__name__'
,
None
):
if
getattr
(
self
,
'__name__'
,
None
):
return
self
.
__name__
def
pattern_to_str
(
pattern
):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
return
"
%
s(
%
s)"
%
(
str
(
pattern
[
0
]),
", "
.
join
([
pattern_to_str
(
p
)
for
p
in
pattern
[
1
:]]))
return
"
%
s(
%
s)"
%
(
str
(
pattern
[
0
]),
", "
.
join
([
pattern_to_str
(
p
)
for
p
in
pattern
[
1
:]]))
elif
isinstance
(
pattern
,
dict
):
return
"
%
s subject to
%
s"
%
(
pattern_to_str
(
pattern
[
'pattern'
]),
str
(
pattern
.
get
(
'constraint'
,
'no conditions'
)))
return
"
%
s subject to
%
s"
%
(
pattern_to_str
(
pattern
[
'pattern'
]),
str
(
pattern
.
get
(
'constraint'
,
'no conditions'
)))
else
:
return
str
(
pattern
)
return
"
%
s ->
%
s"
%
(
pattern_to_str
(
self
.
in_pattern
),
pattern_to_str
(
self
.
out_pattern
))
return
"
%
s ->
%
s"
%
(
pattern_to_str
(
self
.
in_pattern
),
pattern_to_str
(
self
.
out_pattern
))
def
__repr__
(
self
):
return
str
(
self
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'__name__'
,
getattr
(
self
,
'name'
,
None
))
print
>>
stream
,
"
%
s
%
s
%
s(
%
s,
%
s) id=
%
i"
%
(
' '
*
level
,
print
>>
stream
,
"
%
s
%
s
%
s(
%
s,
%
s) id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
name
,
str
(
self
.
in_pattern
),
...
...
@@ -930,37 +986,48 @@ class NavigatorOptimizer(Optimizer):
_logger
.
error
(
traceback
.
format_exc
())
if
isinstance
(
exc
,
AssertionError
)
or
config
.
on_opt_error
==
'raise'
:
raise
exc
@staticmethod
def
warn_inplace
(
exc
,
nav
,
repl_pairs
,
local_opt
):
"""failure_callback for NavigatorOptimizer: ignore InconsistencyErrors, print traceback
"""failure_callback for NavigatorOptimizer
ignore InconsistencyErrors, print traceback
"""
if
isinstance
(
exc
,
InconsistencyError
):
return
return
NavigatorOptimizer
.
warn
(
exc
,
nav
,
repl_pairs
,
local_opt
)
@staticmethod
def
warn_ignore
(
exc
,
nav
,
repl_pairs
,
local_opt
):
"""failure_callback for NavigatorOptimizer: ignore all errors
"""
pass
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
'auto'
,
failure_callback
=
None
):
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
'auto'
,
failure_callback
=
None
):
"""
:param local_opt: a LocalOptimizer to apply over a Env (or None is Ok too).
:param local_opt: a LocalOptimizer to apply over a Env
(or None is Ok too).
:param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a candidate for optimization
- False: new subgraphs returned by an optimization is a candidate for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant' attribute.
- True: new subgraphs returned by an optimization is not a
candidate for optimization
- False: new subgraphs returned by an optimization is a candidate
for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant'
attribute.
:param failure_callback:
a function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables will be 'None'.
If the trouble is from local_opt.transform(), the new variables
will be 'None'.
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
transform().
If this parameter is None, then exceptions are not caught here (raised normally).
If this parameter is None, then exceptions are not caught here
(raised normally).
"""
self
.
local_opt
=
local_opt
if
ignore_newtrees
==
'auto'
:
...
...
@@ -969,15 +1036,19 @@ class NavigatorOptimizer(Optimizer):
self
.
ignore_newtrees
=
ignore_newtrees
self
.
failure_callback
=
failure_callback
def
attach_updater
(
self
,
env
,
importer
,
pruner
,
chin
=
None
):
"""Install some Env listeners to help the navigator deal with the ignore_trees-related functionality.
def
attach_updater
(
self
,
env
,
importer
,
pruner
,
chin
=
None
):
"""
Install some Env listeners to help the navigator deal with the
ignore_trees-related functionality.
:param importer: function that will be called whenever when optimizations add stuff to the graph.
:param pruner: function to be called when optimizations remove stuff from graph.
:param importer: function that will be called whenever when
optimizations add stuff to the graph.
:param pruner: function to be called when optimizations remove stuff
from graph.
:param chin: "on change input" called whenever an node's inputs change.
:returns: The Env plugin that handles the three tasks.
Keep this around so that you can detach later!
:returns: The Env plugin that handles the three tasks.
Keep this around so that you can detach later!
"""
if
self
.
ignore_newtrees
:
importer
=
None
...
...
@@ -1010,21 +1081,22 @@ class NavigatorOptimizer(Optimizer):
if
u
is
not
None
:
env
.
remove_feature
(
u
)
def
process_node
(
self
,
env
,
node
,
lopt
=
None
):
def
process_node
(
self
,
env
,
node
,
lopt
=
None
):
"""
This function will use `lopt` to `transform` the `node`. The `transform` method will
return either False or a list of Variables that are intended to replace `node.outputs`.
This function will use `lopt` to `transform` the `node`. The
`transform` method will return either False or a list of Variables
that are intended to replace `node.outputs`.
If the env accepts the replacement, then the optimization is
successful, and this
function returns True.
If the env accepts the replacement, then the optimization is
successful, and this
function returns True.
If there are no replacement candidates or the env rejects the
replacements, this
function returns False.
If there are no replacement candidates or the env rejects the
replacements, this
function returns False.
:param env: an Env
:param node: an Apply instance in `env`
:param lopt: a LocalOptimizer instance that may have a better idea for
how to compute
node's outputs.
:param lopt: a LocalOptimizer instance that may have a better idea for
how to compute
node's outputs.
:rtype: Bool
:returns: True iff the `node`'s outputs were replaced in the `env`.
...
...
@@ -1034,16 +1106,19 @@ class NavigatorOptimizer(Optimizer):
replacements
=
lopt
.
transform
(
node
)
except
Exception
,
e
:
if
self
.
failure_callback
is
not
None
:
self
.
failure_callback
(
e
,
self
,
[(
x
,
None
)
for
x
in
node
.
outputs
],
lopt
)
self
.
failure_callback
(
e
,
self
,
[(
x
,
None
)
for
x
in
node
.
outputs
],
lopt
)
return
False
else
:
raise
if
replacements
is
False
or
replacements
is
None
:
return
False
if
not
isinstance
(
replacements
,
(
tuple
,
list
)):
raise
TypeError
(
'Optimizer
%
s gave wrong type of replacement. Expected list or tuple.'
%
lopt
)
raise
TypeError
(
'Optimizer
%
s gave wrong type of replacement. '
'Expected list or tuple.'
%
lopt
)
if
len
(
node
.
outputs
)
!=
len
(
replacements
):
raise
ValueError
(
'Optimizer
%
s gave wrong number of replacements'
%
lopt
)
raise
ValueError
(
'Optimizer
%
s gave wrong number of replacements'
%
lopt
)
# If an output would be replaced by itself, no need to perform
# the replacement
repl_pairs
=
[(
r
,
rnew
)
for
r
,
rnew
in
zip
(
node
.
outputs
,
replacements
)
...
...
@@ -1056,8 +1131,8 @@ class NavigatorOptimizer(Optimizer):
except
Exception
,
e
:
# This means the replacements were rejected by the env.
#
# This is not supposed to happen. The default failure_callback
will print a
# traceback as a warning.
# This is not supposed to happen. The default failure_callback
#
will print a
traceback as a warning.
if
self
.
failure_callback
is
not
None
:
self
.
failure_callback
(
e
,
self
,
repl_pairs
,
lopt
)
return
False
...
...
@@ -1072,26 +1147,33 @@ class NavigatorOptimizer(Optimizer):
self
.
local_opt
.
add_requirements
(
env
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s (
%
i)"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s (
%
i)"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
id
(
self
))
if
depth
!=
0
:
self
.
local_opt
.
print_summary
(
stream
,
level
=
level
+
2
,
depth
=
depth
-
1
)
self
.
local_opt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
depth
=
(
depth
-
1
))
class
TopoOptimizer
(
NavigatorOptimizer
):
"""WRITEME"""
def
__init__
(
self
,
local_opt
,
order
=
'in_to_out'
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
def
__init__
(
self
,
local_opt
,
order
=
'in_to_out'
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
if
order
not
in
[
'out_to_in'
,
'in_to_out'
]:
raise
ValueError
(
"order must be 'out_to_in' or 'in_to_out'"
)
self
.
order
=
order
NavigatorOptimizer
.
__init__
(
self
,
local_opt
,
ignore_newtrees
,
failure_callback
)
NavigatorOptimizer
.
__init__
(
self
,
local_opt
,
ignore_newtrees
,
failure_callback
)
def
apply
(
self
,
env
,
start_from
=
None
):
if
start_from
is
None
:
start_from
=
env
.
outputs
def
apply
(
self
,
env
,
start_from
=
None
):
if
start_from
is
None
:
start_from
=
env
.
outputs
q
=
deque
(
graph
.
io_toposort
(
env
.
inputs
,
start_from
))
def
importer
(
node
):
if
node
is
not
current_node
:
q
.
append
(
node
)
def
pruner
(
node
):
if
node
is
not
current_node
:
try
:
...
...
@@ -1114,14 +1196,16 @@ class TopoOptimizer(NavigatorOptimizer):
self
.
detach_updater
(
env
,
u
)
class
OpKeyOptimizer
(
NavigatorOptimizer
):
"""WRITEME"""
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
if
not
hasattr
(
local_opt
,
'op_key'
):
raise
TypeError
(
"LocalOptimizer for OpKeyOptimizer must have an 'op_key' method."
)
NavigatorOptimizer
.
__init__
(
self
,
local_opt
,
ignore_newtrees
,
failure_callback
)
raise
TypeError
(
"LocalOptimizer for OpKeyOptimizer must have "
"an 'op_key' method."
)
NavigatorOptimizer
.
__init__
(
self
,
local_opt
,
ignore_newtrees
,
failure_callback
)
def
apply
(
self
,
env
):
op
=
self
.
local_opt
.
op_key
()
...
...
@@ -1129,9 +1213,12 @@ class OpKeyOptimizer(NavigatorOptimizer):
q
=
reduce
(
list
.
__iadd__
,
map
(
env
.
get_nodes
,
op
))
else
:
q
=
list
(
env
.
get_nodes
(
op
))
def
importer
(
node
):
if
node
is
not
current_node
:
if
node
.
op
==
op
:
q
.
append
(
node
)
if
node
.
op
==
op
:
q
.
append
(
node
)
def
pruner
(
node
):
if
node
is
not
current_node
and
node
.
op
==
op
:
try
:
...
...
@@ -1159,7 +1246,6 @@ class OpKeyOptimizer(NavigatorOptimizer):
env
.
extend
(
toolbox
.
NodeFinder
())
class
ChangeTracker
:
def
__init__
(
self
):
self
.
changed
=
False
...
...
@@ -1176,17 +1262,19 @@ class ChangeTracker:
def
on_attach
(
self
,
env
):
env
.
change_tracker
=
self
class
EquilibriumOptimizer
(
NavigatorOptimizer
):
def
__init__
(
self
,
optimizers
,
failure_callback
=
None
,
max_depth
=
None
,
max_use_ratio
=
None
):
failure_callback
=
None
,
max_depth
=
None
,
max_use_ratio
=
None
):
"""
:param optimizers: list or set of local or global optimizations to
apply until
equilibrium.
:param optimizers: list or set of local or global optimizations to
apply until
equilibrium.
:param max_use_ratio: each optimizer can be applied at most (size of graph * this number)
:param max_use_ratio: each optimizer can be applied at most
(size of graph * this number) times
:param max_depth: TODO what does this do? (EquilibriumDB sets it to 5)
...
...
@@ -1194,8 +1282,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
super
(
EquilibriumOptimizer
,
self
)
.
__init__
(
None
,
ignore_newtrees
=
True
,
failure_callback
=
failure_callback
)
ignore_newtrees
=
True
,
failure_callback
=
failure_callback
)
self
.
local_optimizers
=
[]
self
.
global_optimizers
=
[]
...
...
@@ -1206,7 +1294,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self
.
global_optimizers
.
append
(
opt
)
self
.
max_depth
=
max_depth
self
.
max_use_ratio
=
max_use_ratio
assert
self
.
max_use_ratio
is
not
None
,
'max_use_ratio has to be a number'
assert
self
.
max_use_ratio
is
not
None
,
(
'max_use_ratio has to be a number'
)
def
add_requirements
(
self
,
env
):
super
(
EquilibriumOptimizer
,
self
)
.
add_requirements
(
env
)
...
...
@@ -1216,7 +1305,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
for
opt
in
self
.
global_optimizers
:
opt
.
add_requirements
(
env
)
def
apply
(
self
,
env
,
start_from
=
None
):
def
apply
(
self
,
env
,
start_from
=
None
):
if
start_from
is
None
:
start_from
=
env
.
outputs
changed
=
True
...
...
@@ -1251,9 +1340,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
nb_nodes
.
append
(
len
(
q
))
max_nb_nodes
=
max
(
max_nb_nodes
,
len
(
q
))
max_use
=
max_nb_nodes
*
self
.
max_use_ratio
def
importer
(
node
):
if
node
is
not
current_node
:
q
.
append
(
node
)
def
pruner
(
node
):
if
node
is
not
current_node
:
try
:
...
...
@@ -1277,7 +1368,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
opt_name
=
(
getattr
(
lopt
,
"name"
,
None
)
or
getattr
(
lopt
,
"__name__"
,
""
))
if
node
not
in
env
.
nodes
:
break
# go to next node
# go to next node
break
finally
:
self
.
detach_updater
(
env
,
u
)
self
.
detach_updater
(
env
,
u
)
#TODO: erase this line, it's redundant at best
...
...
@@ -1314,10 +1406,12 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'name'
,
None
)
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
name
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
name
,
id
(
self
))
if
depth
!=
0
:
for
lopt
in
self
.
local_optimizers
:
lopt
.
print_summary
(
stream
,
level
=
level
+
2
,
depth
=
depth
-
1
)
lopt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
depth
=
(
depth
-
1
))
#################
...
...
@@ -1340,7 +1434,8 @@ def _check_chain(r, chain):
return
False
else
:
try
:
if
issubclass
(
elem
,
op
.
Op
)
and
not
isinstance
(
r
.
owner
.
op
,
elem
):
if
(
issubclass
(
elem
,
op
.
Op
)
and
not
isinstance
(
r
.
owner
.
op
,
elem
)):
return
False
except
TypeError
:
return
False
...
...
@@ -1354,6 +1449,7 @@ def _check_chain(r, chain):
return
(
r
is
not
None
)
#_check_chain.n_calls = 0
def
check_chain
(
r
,
*
chain
):
"""WRITEME"""
if
isinstance
(
r
,
graph
.
Apply
):
...
...
@@ -1378,7 +1474,7 @@ def pre_greedy_local_optimizer(list_optimizations, out):
add additional node to the inputs of the node, it can
be needed to call this function multiple time.
'''
def
local_recursive_function
(
list_opt
,
out
,
optimized_vars
,
depth
):
def
local_recursive_function
(
list_opt
,
out
,
optimized_vars
,
depth
):
if
not
getattr
(
out
,
'owner'
,
None
):
return
[
out
],
optimized_vars
node
=
out
.
owner
...
...
@@ -1390,11 +1486,11 @@ def pre_greedy_local_optimizer(list_optimizations, out):
else
:
if
inp
.
owner
:
outs
,
optimized_vars
=
local_recursive_function
(
list_opt
,
inp
,
optimized_vars
,
depth
+
1
)
for
k
,
v
in
zip
(
inp
.
owner
.
outputs
,
outs
):
list_opt
,
inp
,
optimized_vars
,
depth
+
1
)
for
k
,
v
in
zip
(
inp
.
owner
.
outputs
,
outs
):
optimized_vars
[
k
]
=
v
nw_in
=
outs
[
inp
.
owner
.
outputs
.
index
(
inp
)]
...
...
@@ -1408,10 +1504,10 @@ def pre_greedy_local_optimizer(list_optimizations, out):
ret
=
opt
.
transform
(
node
)
if
ret
is
not
False
and
ret
is
not
None
:
assert
len
(
ret
)
==
len
(
node
.
outputs
)
for
k
,
v
in
zip
(
node
.
outputs
,
ret
):
for
k
,
v
in
zip
(
node
.
outputs
,
ret
):
optimized_vars
[
k
]
=
v
results
=
ret
if
ret
[
0
]
.
owner
:
if
ret
[
0
]
.
owner
:
node
=
out
.
owner
else
:
break
...
...
@@ -1422,8 +1518,6 @@ def pre_greedy_local_optimizer(list_optimizations, out):
return
final_outs
[
0
]
############
### Misc ###
############
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论