Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
7e415c47
提交
7e415c47
authored
4月 11, 2012
作者:
Pascal Lamblin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
PEP 8
上级
81e1a1e9
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
218 行增加
和
124 行删除
+218
-124
opt.py
theano/gof/opt.py
+218
-124
没有找到文件。
theano/gof/opt.py
浏览文件 @
7e415c47
...
@@ -3,21 +3,23 @@ Defines the base class for optimizations as well as a certain
...
@@ -3,21 +3,23 @@ Defines the base class for optimizations as well as a certain
amount of useful generic optimization tools.
amount of useful generic optimization tools.
"""
"""
import
copy
import
copy
,
logging
,
sys
,
time
import
logging
import
sys
import
time
import
numpy
import
numpy
import
graph
import
graph
from
env
import
InconsistencyError
from
env
import
InconsistencyError
import
op
import
utils
import
utils
import
unify
import
unify
import
toolbox
import
toolbox
import
op
import
theano
import
theano
from
theano
import
config
from
theano
import
config
from
theano.gof.python25
import
any
,
all
,
deque
from
theano.gof.python25
import
any
,
all
,
deque
from
theano.configparser
import
AddConfigVar
,
BoolParam
,
config
from
theano.configparser
import
AddConfigVar
,
BoolParam
#if sys.version_info[:2] >= (2,5):
#if sys.version_info[:2] >= (2,5):
# from collections import defaultdict
# from collections import defaultdict
...
@@ -39,9 +41,11 @@ import traceback
...
@@ -39,9 +41,11 @@ import traceback
_optimizer_idx
=
[
0
]
_optimizer_idx
=
[
0
]
def
_list_of_nodes
(
env
):
def
_list_of_nodes
(
env
):
return
list
(
graph
.
io_toposort
(
env
.
inputs
,
env
.
outputs
))
return
list
(
graph
.
io_toposort
(
env
.
inputs
,
env
.
outputs
))
class
Optimizer
(
object
):
class
Optimizer
(
object
):
"""WRITEME
"""WRITEME
An L{Optimizer} can be applied to an L{Env} to transform it.
An L{Optimizer} can be applied to an L{Env} to transform it.
...
@@ -91,26 +95,30 @@ class Optimizer(object):
...
@@ -91,26 +95,30 @@ class Optimizer(object):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'name'
,
None
)
name
=
getattr
(
self
,
'name'
,
None
)
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
name
,
id
(
self
))
(
' '
*
level
),
self
.
__class__
.
__name__
,
name
,
id
(
self
))
class
FromFunctionOptimizer
(
Optimizer
):
class
FromFunctionOptimizer
(
Optimizer
):
"""WRITEME"""
"""WRITEME"""
def
__init__
(
self
,
fn
):
def
__init__
(
self
,
fn
):
self
.
apply
=
fn
self
.
apply
=
fn
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
# Added by default
# Added by default
#env.extend(toolbox.ReplaceValidate())
#env.extend(toolbox.ReplaceValidate())
pass
pass
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
str
(
self
.
apply
),
str
(
self
.
apply
),
id
(
self
))
id
(
self
))
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
self
.
fn
(
*
args
,
**
kwargs
)
return
self
.
fn
(
*
args
,
**
kwargs
)
def
optimizer
(
f
):
def
optimizer
(
f
):
"""decorator for FromFunctionOptimizer"""
"""decorator for FromFunctionOptimizer"""
rval
=
FromFunctionOptimizer
(
f
)
rval
=
FromFunctionOptimizer
(
f
)
...
@@ -118,7 +126,6 @@ def optimizer(f):
...
@@ -118,7 +126,6 @@ def optimizer(f):
return
rval
return
rval
class
SeqOptimizer
(
Optimizer
,
list
):
class
SeqOptimizer
(
Optimizer
,
list
):
#inherit from Optimizer first to get Optimizer.__hash__
#inherit from Optimizer first to get Optimizer.__hash__
"""WRITEME
"""WRITEME
...
@@ -129,7 +136,7 @@ class SeqOptimizer(Optimizer, list):
...
@@ -129,7 +136,7 @@ class SeqOptimizer(Optimizer, list):
def
warn
(
exc
,
self
,
optimizer
):
def
warn
(
exc
,
self
,
optimizer
):
"""Default failure_callback for SeqOptimizer
"""Default failure_callback for SeqOptimizer
"""
"""
_logger
.
error
(
"SeqOptimizer apply
%
s"
%
str
(
optimizer
))
_logger
.
error
(
"SeqOptimizer apply
%
s"
%
str
(
optimizer
))
_logger
.
error
(
"Traceback:"
)
_logger
.
error
(
"Traceback:"
)
_logger
.
error
(
traceback
.
format_exc
())
_logger
.
error
(
traceback
.
format_exc
())
if
config
.
on_opt_error
==
'raise'
:
if
config
.
on_opt_error
==
'raise'
:
...
@@ -146,14 +153,15 @@ class SeqOptimizer(Optimizer, list):
...
@@ -146,14 +153,15 @@ class SeqOptimizer(Optimizer, list):
"""WRITEME
"""WRITEME
Applies each L{Optimizer} in self in turn.
Applies each L{Optimizer} in self in turn.
"""
"""
l
=
[]
l
=
[]
nb_node_before
=
len
(
env
.
nodes
)
nb_node_before
=
len
(
env
.
nodes
)
for
optimizer
in
self
:
for
optimizer
in
self
:
try
:
try
:
t0
=
time
.
time
()
t0
=
time
.
time
()
optimizer
.
optimize
(
env
)
optimizer
.
optimize
(
env
)
l
.
append
(
float
(
time
.
time
()
-
t0
))
l
.
append
(
float
(
time
.
time
()
-
t0
))
except
AssertionError
:
# do not catch Assertion failures
except
AssertionError
:
# do not catch Assertion failures
raise
raise
except
Exception
,
e
:
except
Exception
,
e
:
if
self
.
failure_callback
:
if
self
.
failure_callback
:
...
@@ -192,7 +200,6 @@ class SeqOptimizer(Optimizer, list):
...
@@ -192,7 +200,6 @@ class SeqOptimizer(Optimizer, list):
#added to override the list's __neq__ implementation
#added to override the list's __neq__ implementation
return
id
(
self
)
!=
id
(
other
)
return
id
(
self
)
!=
id
(
other
)
def
__str__
(
self
):
def
__str__
(
self
):
return
"SeqOpt(
%
s)"
%
list
.
__str__
(
self
)
return
"SeqOpt(
%
s)"
%
list
.
__str__
(
self
)
...
@@ -201,14 +208,13 @@ class SeqOptimizer(Optimizer, list):
...
@@ -201,14 +208,13 @@ class SeqOptimizer(Optimizer, list):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'name'
,
None
)
name
=
getattr
(
self
,
'name'
,
None
)
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
name
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
name
,
id
(
self
))
# This way, -1 will do all depth
# This way, -1 will do all depth
if
depth
!=
0
:
if
depth
!=
0
:
depth
-=
1
depth
-=
1
for
opt
in
self
:
for
opt
in
self
:
opt
.
print_summary
(
stream
,
level
=
level
+
2
,
depth
=
depth
)
opt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
depth
=
depth
)
class
_metadict
:
class
_metadict
:
...
@@ -219,13 +225,15 @@ class _metadict:
...
@@ -219,13 +225,15 @@ class _metadict:
def
__init__
(
self
):
def
__init__
(
self
):
self
.
d
=
{}
self
.
d
=
{}
self
.
l
=
[]
self
.
l
=
[]
def
__getitem__
(
self
,
item
):
def
__getitem__
(
self
,
item
):
return
self
.
get
(
item
,
None
)
return
self
.
get
(
item
,
None
)
def
__setitem__
(
self
,
item
,
value
):
def
__setitem__
(
self
,
item
,
value
):
try
:
try
:
self
.
d
[
item
]
=
value
self
.
d
[
item
]
=
value
except
Exception
:
except
Exception
:
for
i
,
(
key
,
val
)
in
enumerate
(
self
.
l
):
for
i
,
(
key
,
val
)
in
enumerate
(
self
.
l
):
if
key
==
item
:
if
key
==
item
:
self
.
l
[
i
]
=
(
item
,
value
)
self
.
l
[
i
]
=
(
item
,
value
)
return
return
...
@@ -265,9 +273,11 @@ class _metadict:
...
@@ -265,9 +273,11 @@ class _metadict:
return
value
return
value
else
:
else
:
return
default
return
default
def
clear
(
self
):
def
clear
(
self
):
self
.
d
=
{}
self
.
d
=
{}
self
.
l
=
[]
self
.
l
=
[]
def
__str__
(
self
):
def
__str__
(
self
):
return
"(
%
s,
%
s)"
%
(
self
.
d
,
self
.
l
)
return
"(
%
s,
%
s)"
%
(
self
.
d
,
self
.
l
)
...
@@ -528,12 +538,13 @@ def pre_constant_merge(vars):
...
@@ -528,12 +538,13 @@ def pre_constant_merge(vars):
const_sig_inv
[
sig
]
=
var
const_sig_inv
[
sig
]
=
var
return
var
return
var
if
var
.
owner
:
if
var
.
owner
:
for
idx
,
inp
in
enumerate
(
var
.
owner
.
inputs
):
for
idx
,
inp
in
enumerate
(
var
.
owner
.
inputs
):
var
.
owner
.
inputs
[
idx
]
=
recursive_merge
(
inp
)
var
.
owner
.
inputs
[
idx
]
=
recursive_merge
(
inp
)
return
var
return
var
return
map
(
recursive_merge
,
vars
)
return
map
(
recursive_merge
,
vars
)
########################
########################
### Local Optimizers ###
### Local Optimizers ###
########################
########################
...
@@ -557,25 +568,31 @@ class LocalOptimizer(object):
...
@@ -557,25 +568,31 @@ class LocalOptimizer(object):
Subclasses should implement this function so that it returns one of two
Subclasses should implement this function so that it returns one of two
kinds of things:
kinds of things:
- False to indicate that no optimization can be applied to this `node`; or
- False to indicate that no optimization can be applied to this `node`;
or
- <list of variables> to use in place of `node`'s outputs in the greater graph.
- <list of variables> to use in place of `node`'s outputs in the
greater graph.
:type node: an Apply instance
:type node: an Apply instance
"""
"""
raise
utils
.
MethodNotDefined
(
"transform"
,
type
(
self
),
self
.
__class__
.
__name__
)
raise
utils
.
MethodNotDefined
(
"transform"
,
type
(
self
),
self
.
__class__
.
__name__
)
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
"""If this local optimization wants to add some requirements to the env,
"""
This is the place to do it."""
If this local optimization wants to add some requirements to the env,
This is the place to do it.
"""
# Added by default
# Added by default
#env.extend(toolbox.ReplaceValidate())
#env.extend(toolbox.ReplaceValidate())
pass
pass
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
id
(
self
))
class
FromFunctionLocalOptimizer
(
LocalOptimizer
):
class
FromFunctionLocalOptimizer
(
LocalOptimizer
):
"""WRITEME"""
"""WRITEME"""
...
@@ -584,15 +601,21 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
...
@@ -584,15 +601,21 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
tracks
=
[]
tracks
=
[]
self
.
transform
=
fn
self
.
transform
=
fn
self
.
_tracks
=
tracks
self
.
_tracks
=
tracks
def
tracks
(
self
):
def
tracks
(
self
):
return
self
.
_tracks
return
self
.
_tracks
def
__str__
(
self
):
def
__str__
(
self
):
return
getattr
(
self
,
'__name__'
,
'<FromFunctionLocalOptimizer instance>'
)
return
getattr
(
self
,
'__name__'
,
'<FromFunctionLocalOptimizer instance>'
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
str
(
self
.
transform
),
str
(
self
.
transform
),
id
(
self
))
id
(
self
))
def
local_optimizer
(
*
tracks
):
def
local_optimizer
(
*
tracks
):
def
decorator
(
f
):
def
decorator
(
f
):
"""WRITEME"""
"""WRITEME"""
...
@@ -607,11 +630,15 @@ class LocalOptGroup(LocalOptimizer):
...
@@ -607,11 +630,15 @@ class LocalOptGroup(LocalOptimizer):
def
__init__
(
self
,
*
optimizers
):
def
__init__
(
self
,
*
optimizers
):
self
.
opts
=
optimizers
self
.
opts
=
optimizers
self
.
reentrant
=
any
(
getattr
(
opt
,
'reentrant'
,
True
)
for
opt
in
optimizers
)
self
.
reentrant
=
any
(
getattr
(
opt
,
'reentrant'
,
True
)
self
.
retains_inputs
=
all
(
getattr
(
opt
,
'retains_inputs'
,
False
)
for
opt
in
optimizers
)
for
opt
in
optimizers
)
self
.
retains_inputs
=
all
(
getattr
(
opt
,
'retains_inputs'
,
False
)
for
opt
in
optimizers
)
def
__str__
(
self
):
def
__str__
(
self
):
return
getattr
(
self
,
'__name__'
,
'<theano.gof.opt.LocalOptGroup instance>'
+
str
([
str
(
o
)
for
o
in
self
.
opts
]))
return
getattr
(
self
,
'__name__'
,
(
'<theano.gof.opt.LocalOptGroup instance>'
+
str
([
str
(
o
)
for
o
in
self
.
opts
])))
def
transform
(
self
,
node
):
def
transform
(
self
,
node
):
for
opt
in
self
.
opts
:
for
opt
in
self
.
opts
:
...
@@ -620,11 +647,12 @@ class LocalOptGroup(LocalOptimizer):
...
@@ -620,11 +647,12 @@ class LocalOptGroup(LocalOptimizer):
return
repl
return
repl
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
id
(
self
))
if
depth
!=
0
:
if
depth
!=
0
:
depth
-=
1
depth
-=
1
for
lopt
in
self
.
opts
:
for
lopt
in
self
.
opts
:
lopt
.
print_summary
(
stream
,
level
=
level
+
2
,
depth
=
depth
)
lopt
.
print_summary
(
stream
,
level
=
(
level
+
2
)
,
depth
=
depth
)
class
_LocalOpKeyOptGroup
(
LocalOptGroup
):
class
_LocalOpKeyOptGroup
(
LocalOptGroup
):
...
@@ -644,13 +672,16 @@ class OpSub(LocalOptimizer):
...
@@ -644,13 +672,16 @@ class OpSub(LocalOptimizer):
Replaces the application of a certain op by the application of
Replaces the application of a certain op by the application of
another op that take the same inputs as what they are replacing.
another op that take the same inputs as what they are replacing.
e.g. OpSub(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
e.g. OpSub(add, sub) ==>
add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
"""
"""
reentrant
=
False
# an OpSub does not apply to the nodes it produces
# an OpSub does not apply to the nodes it produces
retains_inputs
=
True
# all the inputs of the original node are transferred to the outputs
reentrant
=
False
# all the inputs of the original node are transferred to the outputs
retains_inputs
=
True
def
__init__
(
self
,
op1
,
op2
,
transfer_tags
=
True
):
def
__init__
(
self
,
op1
,
op2
,
transfer_tags
=
True
):
"""
"""
op1.make_node and op2.make_node must take the same number of
op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs.
inputs and have the same number of outputs.
...
@@ -705,7 +736,8 @@ class OpRemove(LocalOptimizer):
...
@@ -705,7 +736,8 @@ class OpRemove(LocalOptimizer):
return
"
%
s(x) -> x"
%
(
self
.
op
)
return
"
%
s(x) -> x"
%
(
self
.
op
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
>>
stream
,
"
%
s
%
s(
%
s) id=
%
i"
%
(
' '
*
level
,
print
>>
stream
,
"
%
s
%
s(
%
s) id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
,
str
(
self
.
op
),
str
(
self
.
op
),
id
(
self
))
id
(
self
))
...
@@ -756,12 +788,12 @@ class PatternSub(LocalOptimizer):
...
@@ -756,12 +788,12 @@ class PatternSub(LocalOptimizer):
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((boggle, {'pattern': 'x',
PatternSub((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}),
'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x'))
(scrabble, 'x'))
"""
"""
def
__init__
(
self
,
in_pattern
,
out_pattern
,
allow_multiple_clients
=
False
,
def
__init__
(
self
,
in_pattern
,
out_pattern
,
allow_multiple_clients
=
False
,
skip_identities_fn
=
None
,
name
=
None
,
pdb
=
False
):
skip_identities_fn
=
None
,
name
=
None
,
pdb
=
False
):
"""
"""
Creates a PatternSub that replaces occurrences of
Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern.
in_pattern by occurrences of out_pattern.
...
@@ -771,7 +803,8 @@ class PatternSub(LocalOptimizer):
...
@@ -771,7 +803,8 @@ class PatternSub(LocalOptimizer):
:param allow_multiple_clients: if False, the pattern matching will fail
:param allow_multiple_clients: if False, the pattern matching will fail
if one of the subpatterns has more than
if one of the subpatterns has more than
one client.
one client.
:param pdb: if True, we invoke pdb when the first node in the pattern match.
:param pdb: if True, we invoke pdb when the first node in the
pattern match.
"""
"""
self
.
in_pattern
=
in_pattern
self
.
in_pattern
=
in_pattern
self
.
out_pattern
=
out_pattern
self
.
out_pattern
=
out_pattern
...
@@ -780,8 +813,11 @@ class PatternSub(LocalOptimizer):
...
@@ -780,8 +813,11 @@ class PatternSub(LocalOptimizer):
elif
isinstance
(
in_pattern
,
dict
):
elif
isinstance
(
in_pattern
,
dict
):
self
.
op
=
self
.
in_pattern
[
'pattern'
][
0
]
self
.
op
=
self
.
in_pattern
[
'pattern'
][
0
]
else
:
else
:
raise
TypeError
(
"The pattern to search for must start with a specific Op instance."
)
raise
TypeError
(
"The pattern to search for must start with "
self
.
__doc__
=
self
.
__class__
.
__doc__
+
"
\n\n
This instance does: "
+
str
(
self
)
+
"
\n
"
"a specific Op instance."
)
self
.
__doc__
=
(
self
.
__class__
.
__doc__
+
"
\n\n
This instance does: "
+
str
(
self
)
+
"
\n
"
)
self
.
allow_multiple_clients
=
allow_multiple_clients
self
.
allow_multiple_clients
=
allow_multiple_clients
self
.
skip_identities_fn
=
skip_identities_fn
self
.
skip_identities_fn
=
skip_identities_fn
if
name
:
if
name
:
...
@@ -816,7 +852,7 @@ class PatternSub(LocalOptimizer):
...
@@ -816,7 +852,7 @@ class PatternSub(LocalOptimizer):
if
node
.
op
!=
self
.
op
:
if
node
.
op
!=
self
.
op
:
return
False
return
False
def
match
(
pattern
,
expr
,
u
,
allow_multiple_clients
=
False
,
pdb
=
False
):
def
match
(
pattern
,
expr
,
u
,
allow_multiple_clients
=
False
,
pdb
=
False
):
def
retry_with_equiv
():
def
retry_with_equiv
():
expr_equiv
=
self
.
skip_identities
(
expr
)
expr_equiv
=
self
.
skip_identities
(
expr
)
if
expr_equiv
is
None
:
if
expr_equiv
is
None
:
...
@@ -829,7 +865,9 @@ class PatternSub(LocalOptimizer):
...
@@ -829,7 +865,9 @@ class PatternSub(LocalOptimizer):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
if
expr
.
owner
is
None
:
if
expr
.
owner
is
None
:
return
False
return
False
if
not
(
expr
.
owner
.
op
==
pattern
[
0
])
or
(
not
allow_multiple_clients
and
len
(
expr
.
clients
)
>
1
):
if
(
not
(
expr
.
owner
.
op
==
pattern
[
0
])
or
(
not
allow_multiple_clients
and
len
(
expr
.
clients
)
>
1
)):
return
retry_with_equiv
()
return
retry_with_equiv
()
if
len
(
pattern
)
-
1
!=
len
(
expr
.
owner
.
inputs
):
if
len
(
pattern
)
-
1
!=
len
(
expr
.
owner
.
inputs
):
return
retry_with_equiv
()
return
retry_with_equiv
()
...
@@ -841,10 +879,14 @@ class PatternSub(LocalOptimizer):
...
@@ -841,10 +879,14 @@ class PatternSub(LocalOptimizer):
try
:
try
:
real_pattern
=
pattern
[
'pattern'
]
real_pattern
=
pattern
[
'pattern'
]
except
KeyError
:
except
KeyError
:
raise
KeyError
(
"Malformed pattern:
%
s (expected key 'pattern')"
%
pattern
)
raise
KeyError
(
"Malformed pattern:
%
s (expected key 'pattern')"
%
pattern
)
constraint
=
pattern
.
get
(
'constraint'
,
lambda
expr
:
True
)
constraint
=
pattern
.
get
(
'constraint'
,
lambda
expr
:
True
)
if
constraint
(
expr
):
if
constraint
(
expr
):
return
match
(
real_pattern
,
expr
,
u
,
pattern
.
get
(
'allow_multiple_clients'
,
allow_multiple_clients
))
return
match
(
real_pattern
,
expr
,
u
,
pattern
.
get
(
'allow_multiple_clients'
,
allow_multiple_clients
))
else
:
else
:
return
retry_with_equiv
()
return
retry_with_equiv
()
elif
isinstance
(
pattern
,
basestring
):
elif
isinstance
(
pattern
,
basestring
):
...
@@ -853,17 +895,22 @@ class PatternSub(LocalOptimizer):
...
@@ -853,17 +895,22 @@ class PatternSub(LocalOptimizer):
return
retry_with_equiv
()
return
retry_with_equiv
()
else
:
else
:
u
=
u
.
merge
(
expr
,
v
)
u
=
u
.
merge
(
expr
,
v
)
elif
isinstance
(
pattern
,
(
int
,
float
))
and
isinstance
(
expr
,
graph
.
Constant
):
elif
(
isinstance
(
pattern
,
(
int
,
float
))
if
numpy
.
all
(
theano
.
tensor
.
constant
(
pattern
)
.
value
==
expr
.
value
):
and
isinstance
(
expr
,
graph
.
Constant
)):
if
numpy
.
all
(
theano
.
tensor
.
constant
(
pattern
)
.
value
==
expr
.
value
):
return
u
return
u
else
:
else
:
return
retry_with_equiv
()
return
retry_with_equiv
()
elif
isinstance
(
pattern
,
graph
.
Constant
)
and
isinstance
(
expr
,
graph
.
Constant
)
and
pattern
.
equals
(
expr
):
elif
(
isinstance
(
pattern
,
graph
.
Constant
)
and
isinstance
(
expr
,
graph
.
Constant
)
and
pattern
.
equals
(
expr
)):
return
u
return
u
else
:
else
:
return
retry_with_equiv
()
return
retry_with_equiv
()
if
pdb
:
if
pdb
:
import
pdb
;
pdb
.
set_trace
()
import
pdb
pdb
.
set_trace
()
return
u
return
u
def
build
(
pattern
,
u
):
def
build
(
pattern
,
u
):
...
@@ -872,11 +919,12 @@ class PatternSub(LocalOptimizer):
...
@@ -872,11 +919,12 @@ class PatternSub(LocalOptimizer):
return
pattern
[
0
](
*
args
)
return
pattern
[
0
](
*
args
)
elif
isinstance
(
pattern
,
basestring
):
elif
isinstance
(
pattern
,
basestring
):
return
u
[
unify
.
Var
(
pattern
)]
return
u
[
unify
.
Var
(
pattern
)]
elif
isinstance
(
pattern
,
(
int
,
float
)):
elif
isinstance
(
pattern
,
(
int
,
float
)):
return
pattern
return
pattern
else
:
else
:
return
pattern
.
clone
()
return
pattern
.
clone
()
u
=
match
(
self
.
in_pattern
,
node
.
out
,
unify
.
Unification
(),
True
,
self
.
pdb
)
u
=
match
(
self
.
in_pattern
,
node
.
out
,
unify
.
Unification
(),
True
,
self
.
pdb
)
if
u
:
if
u
:
p
=
self
.
out_pattern
p
=
self
.
out_pattern
new
=
build
(
p
,
u
)
new
=
build
(
p
,
u
)
...
@@ -886,23 +934,31 @@ class PatternSub(LocalOptimizer):
...
@@ -886,23 +934,31 @@ class PatternSub(LocalOptimizer):
return
False
return
False
def
__str__
(
self
):
def
__str__
(
self
):
if
getattr
(
self
,
'__name__'
,
None
):
if
getattr
(
self
,
'__name__'
,
None
):
return
self
.
__name__
return
self
.
__name__
def
pattern_to_str
(
pattern
):
def
pattern_to_str
(
pattern
):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
return
"
%
s(
%
s)"
%
(
str
(
pattern
[
0
]),
", "
.
join
([
pattern_to_str
(
p
)
for
p
in
pattern
[
1
:]]))
return
"
%
s(
%
s)"
%
(
str
(
pattern
[
0
]),
", "
.
join
([
pattern_to_str
(
p
)
for
p
in
pattern
[
1
:]]))
elif
isinstance
(
pattern
,
dict
):
elif
isinstance
(
pattern
,
dict
):
return
"
%
s subject to
%
s"
%
(
pattern_to_str
(
pattern
[
'pattern'
]),
str
(
pattern
.
get
(
'constraint'
,
'no conditions'
)))
return
"
%
s subject to
%
s"
%
(
pattern_to_str
(
pattern
[
'pattern'
]),
str
(
pattern
.
get
(
'constraint'
,
'no conditions'
)))
else
:
else
:
return
str
(
pattern
)
return
str
(
pattern
)
return
"
%
s ->
%
s"
%
(
pattern_to_str
(
self
.
in_pattern
),
pattern_to_str
(
self
.
out_pattern
))
return
"
%
s ->
%
s"
%
(
pattern_to_str
(
self
.
in_pattern
),
pattern_to_str
(
self
.
out_pattern
))
def
__repr__
(
self
):
def
__repr__
(
self
):
return
str
(
self
)
return
str
(
self
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'__name__'
,
getattr
(
self
,
'name'
,
None
))
name
=
getattr
(
self
,
'__name__'
,
getattr
(
self
,
'name'
,
None
))
print
>>
stream
,
"
%
s
%
s
%
s(
%
s,
%
s) id=
%
i"
%
(
' '
*
level
,
print
>>
stream
,
"
%
s
%
s
%
s(
%
s,
%
s) id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
,
name
,
name
,
str
(
self
.
in_pattern
),
str
(
self
.
in_pattern
),
...
@@ -930,37 +986,48 @@ class NavigatorOptimizer(Optimizer):
...
@@ -930,37 +986,48 @@ class NavigatorOptimizer(Optimizer):
_logger
.
error
(
traceback
.
format_exc
())
_logger
.
error
(
traceback
.
format_exc
())
if
isinstance
(
exc
,
AssertionError
)
or
config
.
on_opt_error
==
'raise'
:
if
isinstance
(
exc
,
AssertionError
)
or
config
.
on_opt_error
==
'raise'
:
raise
exc
raise
exc
@staticmethod
@staticmethod
def
warn_inplace
(
exc
,
nav
,
repl_pairs
,
local_opt
):
def
warn_inplace
(
exc
,
nav
,
repl_pairs
,
local_opt
):
"""failure_callback for NavigatorOptimizer: ignore InconsistencyErrors, print traceback
"""failure_callback for NavigatorOptimizer
ignore InconsistencyErrors, print traceback
"""
"""
if
isinstance
(
exc
,
InconsistencyError
):
if
isinstance
(
exc
,
InconsistencyError
):
return
return
return
NavigatorOptimizer
.
warn
(
exc
,
nav
,
repl_pairs
,
local_opt
)
return
NavigatorOptimizer
.
warn
(
exc
,
nav
,
repl_pairs
,
local_opt
)
@staticmethod
@staticmethod
def
warn_ignore
(
exc
,
nav
,
repl_pairs
,
local_opt
):
def
warn_ignore
(
exc
,
nav
,
repl_pairs
,
local_opt
):
"""failure_callback for NavigatorOptimizer: ignore all errors
"""failure_callback for NavigatorOptimizer: ignore all errors
"""
"""
pass
pass
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
'auto'
,
failure_callback
=
None
):
def
__init__
(
self
,
local_opt
,
ignore_newtrees
=
'auto'
,
failure_callback
=
None
):
"""
"""
:param local_opt: a LocalOptimizer to apply over a Env (or None is Ok too).
:param local_opt: a LocalOptimizer to apply over a Env
(or None is Ok too).
:param ignore_newtrees:
:param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a candidate for optimization
- True: new subgraphs returned by an optimization is not a
- False: new subgraphs returned by an optimization is a candidate for optimization
candidate for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant' attribute.
- False: new subgraphs returned by an optimization is a candidate
for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant'
attribute.
:param failure_callback:
:param failure_callback:
a function that takes (exception, navigator, [(old, new),
a function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables will be 'None'.
If the trouble is from local_opt.transform(), the new variables
will be 'None'.
If the trouble is from validation (the new types don't match for
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
example) then the new variables will be the ones created by
transform().
transform().
If this parameter is None, then exceptions are not caught here (raised normally).
If this parameter is None, then exceptions are not caught here
(raised normally).
"""
"""
self
.
local_opt
=
local_opt
self
.
local_opt
=
local_opt
if
ignore_newtrees
==
'auto'
:
if
ignore_newtrees
==
'auto'
:
...
@@ -969,15 +1036,19 @@ class NavigatorOptimizer(Optimizer):
...
@@ -969,15 +1036,19 @@ class NavigatorOptimizer(Optimizer):
self
.
ignore_newtrees
=
ignore_newtrees
self
.
ignore_newtrees
=
ignore_newtrees
self
.
failure_callback
=
failure_callback
self
.
failure_callback
=
failure_callback
def
attach_updater
(
self
,
env
,
importer
,
pruner
,
chin
=
None
):
def
attach_updater
(
self
,
env
,
importer
,
pruner
,
chin
=
None
):
"""Install some Env listeners to help the navigator deal with the ignore_trees-related functionality.
"""
Install some Env listeners to help the navigator deal with the
ignore_trees-related functionality.
:param importer: function that will be called whenever when optimizations add stuff to the graph.
:param importer: function that will be called whenever when
:param pruner: function to be called when optimizations remove stuff from graph.
optimizations add stuff to the graph.
:param pruner: function to be called when optimizations remove stuff
from graph.
:param chin: "on change input" called whenever an node's inputs change.
:param chin: "on change input" called whenever an node's inputs change.
:returns: The Env plugin that handles the three tasks.
Keep this around so that you can detach later!
:returns: The Env plugin that handles the three tasks.
Keep this around so that you can detach later!
"""
"""
if
self
.
ignore_newtrees
:
if
self
.
ignore_newtrees
:
importer
=
None
importer
=
None
...
@@ -1010,21 +1081,22 @@ class NavigatorOptimizer(Optimizer):
...
@@ -1010,21 +1081,22 @@ class NavigatorOptimizer(Optimizer):
if
u
is
not
None
:
if
u
is
not
None
:
env
.
remove_feature
(
u
)
env
.
remove_feature
(
u
)
def
process_node
(
self
,
env
,
node
,
lopt
=
None
):
def
process_node
(
self
,
env
,
node
,
lopt
=
None
):
"""
"""
This function will use `lopt` to `transform` the `node`. The `transform` method will
This function will use `lopt` to `transform` the `node`. The
return either False or a list of Variables that are intended to replace `node.outputs`.
`transform` method will return either False or a list of Variables
that are intended to replace `node.outputs`.
If the env accepts the replacement, then the optimization is
successful, and this
If the env accepts the replacement, then the optimization is
function returns True.
successful, and this
function returns True.
If there are no replacement candidates or the env rejects the
replacements, this
If there are no replacement candidates or the env rejects the
function returns False.
replacements, this
function returns False.
:param env: an Env
:param env: an Env
:param node: an Apply instance in `env`
:param node: an Apply instance in `env`
:param lopt: a LocalOptimizer instance that may have a better idea for
how to compute
:param lopt: a LocalOptimizer instance that may have a better idea for
node's outputs.
how to compute
node's outputs.
:rtype: Bool
:rtype: Bool
:returns: True iff the `node`'s outputs were replaced in the `env`.
:returns: True iff the `node`'s outputs were replaced in the `env`.
...
@@ -1034,16 +1106,19 @@ class NavigatorOptimizer(Optimizer):
...
@@ -1034,16 +1106,19 @@ class NavigatorOptimizer(Optimizer):
replacements
=
lopt
.
transform
(
node
)
replacements
=
lopt
.
transform
(
node
)
except
Exception
,
e
:
except
Exception
,
e
:
if
self
.
failure_callback
is
not
None
:
if
self
.
failure_callback
is
not
None
:
self
.
failure_callback
(
e
,
self
,
[(
x
,
None
)
for
x
in
node
.
outputs
],
lopt
)
self
.
failure_callback
(
e
,
self
,
[(
x
,
None
)
for
x
in
node
.
outputs
],
lopt
)
return
False
return
False
else
:
else
:
raise
raise
if
replacements
is
False
or
replacements
is
None
:
if
replacements
is
False
or
replacements
is
None
:
return
False
return
False
if
not
isinstance
(
replacements
,
(
tuple
,
list
)):
if
not
isinstance
(
replacements
,
(
tuple
,
list
)):
raise
TypeError
(
'Optimizer
%
s gave wrong type of replacement. Expected list or tuple.'
%
lopt
)
raise
TypeError
(
'Optimizer
%
s gave wrong type of replacement. '
'Expected list or tuple.'
%
lopt
)
if
len
(
node
.
outputs
)
!=
len
(
replacements
):
if
len
(
node
.
outputs
)
!=
len
(
replacements
):
raise
ValueError
(
'Optimizer
%
s gave wrong number of replacements'
%
lopt
)
raise
ValueError
(
'Optimizer
%
s gave wrong number of replacements'
%
lopt
)
# If an output would be replaced by itself, no need to perform
# If an output would be replaced by itself, no need to perform
# the replacement
# the replacement
repl_pairs
=
[(
r
,
rnew
)
for
r
,
rnew
in
zip
(
node
.
outputs
,
replacements
)
repl_pairs
=
[(
r
,
rnew
)
for
r
,
rnew
in
zip
(
node
.
outputs
,
replacements
)
...
@@ -1056,8 +1131,8 @@ class NavigatorOptimizer(Optimizer):
...
@@ -1056,8 +1131,8 @@ class NavigatorOptimizer(Optimizer):
except
Exception
,
e
:
except
Exception
,
e
:
# This means the replacements were rejected by the env.
# This means the replacements were rejected by the env.
#
#
# This is not supposed to happen. The default failure_callback
will print a
# This is not supposed to happen. The default failure_callback
# traceback as a warning.
#
will print a
traceback as a warning.
if
self
.
failure_callback
is
not
None
:
if
self
.
failure_callback
is
not
None
:
self
.
failure_callback
(
e
,
self
,
repl_pairs
,
lopt
)
self
.
failure_callback
(
e
,
self
,
repl_pairs
,
lopt
)
return
False
return
False
...
@@ -1072,26 +1147,33 @@ class NavigatorOptimizer(Optimizer):
...
@@ -1072,26 +1147,33 @@ class NavigatorOptimizer(Optimizer):
self
.
local_opt
.
add_requirements
(
env
)
self
.
local_opt
.
add_requirements
(
env
)
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
    """Write a one-line summary of this optimizer to *stream*.

    The line is indented by *level* spaces and shows the class name and
    the object's id.  Unless *depth* == 0, the wrapped local optimizer
    is asked to print its own summary two levels deeper.
    """
    indent = ' ' * level
    print >> stream, "%s%s (%i)" % (
        indent, self.__class__.__name__, id(self))
    if depth != 0:
        self.local_opt.print_summary(stream, level=(level + 2),
                                     depth=(depth - 1))
class TopoOptimizer(NavigatorOptimizer):
    """WRITEME"""

    def __init__(self, local_opt, order='in_to_out', ignore_newtrees=False,
                 failure_callback=None):
        # Validate the traversal direction before doing anything else.
        if order not in ['out_to_in', 'in_to_out']:
            raise ValueError("order must be 'out_to_in' or 'in_to_out'")
        self.order = order
        NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees,
                                    failure_callback)
def
apply
(
self
,
env
,
start_from
=
None
):
def
apply
(
self
,
env
,
start_from
=
None
):
if
start_from
is
None
:
start_from
=
env
.
outputs
if
start_from
is
None
:
start_from
=
env
.
outputs
q
=
deque
(
graph
.
io_toposort
(
env
.
inputs
,
start_from
))
q
=
deque
(
graph
.
io_toposort
(
env
.
inputs
,
start_from
))
def
importer
(
node
):
def
importer
(
node
):
if
node
is
not
current_node
:
if
node
is
not
current_node
:
q
.
append
(
node
)
q
.
append
(
node
)
def
pruner
(
node
):
def
pruner
(
node
):
if
node
is
not
current_node
:
if
node
is
not
current_node
:
try
:
try
:
...
@@ -1114,14 +1196,16 @@ class TopoOptimizer(NavigatorOptimizer):
...
@@ -1114,14 +1196,16 @@ class TopoOptimizer(NavigatorOptimizer):
self
.
detach_updater
(
env
,
u
)
self
.
detach_updater
(
env
,
u
)
class OpKeyOptimizer(NavigatorOptimizer):
    """WRITEME"""

    def __init__(self, local_opt, ignore_newtrees=False,
                 failure_callback=None):
        # This navigator looks nodes up by op, so the local optimizer
        # must be able to tell us which op(s) it applies to.
        if not hasattr(local_opt, 'op_key'):
            raise TypeError("LocalOptimizer for OpKeyOptimizer must have "
                            "an 'op_key' method.")
        NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees,
                                    failure_callback)
def
apply
(
self
,
env
):
def
apply
(
self
,
env
):
op
=
self
.
local_opt
.
op_key
()
op
=
self
.
local_opt
.
op_key
()
...
@@ -1129,9 +1213,12 @@ class OpKeyOptimizer(NavigatorOptimizer):
...
@@ -1129,9 +1213,12 @@ class OpKeyOptimizer(NavigatorOptimizer):
q
=
reduce
(
list
.
__iadd__
,
map
(
env
.
get_nodes
,
op
))
q
=
reduce
(
list
.
__iadd__
,
map
(
env
.
get_nodes
,
op
))
else
:
else
:
q
=
list
(
env
.
get_nodes
(
op
))
q
=
list
(
env
.
get_nodes
(
op
))
def
importer
(
node
):
def
importer
(
node
):
if
node
is
not
current_node
:
if
node
is
not
current_node
:
if
node
.
op
==
op
:
q
.
append
(
node
)
if
node
.
op
==
op
:
q
.
append
(
node
)
def
pruner
(
node
):
def
pruner
(
node
):
if
node
is
not
current_node
and
node
.
op
==
op
:
if
node
is
not
current_node
and
node
.
op
==
op
:
try
:
try
:
...
@@ -1159,7 +1246,6 @@ class OpKeyOptimizer(NavigatorOptimizer):
...
@@ -1159,7 +1246,6 @@ class OpKeyOptimizer(NavigatorOptimizer):
env
.
extend
(
toolbox
.
NodeFinder
())
env
.
extend
(
toolbox
.
NodeFinder
())
class ChangeTracker:
    """WRITEME"""

    def __init__(self):
        # Starts out clean; flipped to True elsewhere when the env changes.
        self.changed = False
...
@@ -1176,17 +1262,19 @@ class ChangeTracker:
...
@@ -1176,17 +1262,19 @@ class ChangeTracker:
def on_attach(self, env):
    # Publish this tracker on the env so other code can query/reset it.
    env.change_tracker = self
class
EquilibriumOptimizer
(
NavigatorOptimizer
):
class
EquilibriumOptimizer
(
NavigatorOptimizer
):
def
__init__
(
self
,
def
__init__
(
self
,
optimizers
,
optimizers
,
failure_callback
=
None
,
failure_callback
=
None
,
max_depth
=
None
,
max_depth
=
None
,
max_use_ratio
=
None
):
max_use_ratio
=
None
):
"""
"""
:param optimizers: list or set of local or global optimizations to
apply until
:param optimizers: list or set of local or global optimizations to
equilibrium.
apply until
equilibrium.
:param max_use_ratio: each optimizer can be applied at most (size of graph * this number)
:param max_use_ratio: each optimizer can be applied at most
(size of graph * this number) times
:param max_depth: TODO what does this do? (EquilibriumDB sets it to 5)
:param max_depth: TODO what does this do? (EquilibriumDB sets it to 5)
...
@@ -1194,8 +1282,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1194,8 +1282,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
super
(
EquilibriumOptimizer
,
self
)
.
__init__
(
super
(
EquilibriumOptimizer
,
self
)
.
__init__
(
None
,
None
,
ignore_newtrees
=
True
,
ignore_newtrees
=
True
,
failure_callback
=
failure_callback
)
failure_callback
=
failure_callback
)
self
.
local_optimizers
=
[]
self
.
local_optimizers
=
[]
self
.
global_optimizers
=
[]
self
.
global_optimizers
=
[]
...
@@ -1206,7 +1294,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1206,7 +1294,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self
.
global_optimizers
.
append
(
opt
)
self
.
global_optimizers
.
append
(
opt
)
self
.
max_depth
=
max_depth
self
.
max_depth
=
max_depth
self
.
max_use_ratio
=
max_use_ratio
self
.
max_use_ratio
=
max_use_ratio
assert
self
.
max_use_ratio
is
not
None
,
'max_use_ratio has to be a number'
assert
self
.
max_use_ratio
is
not
None
,
(
'max_use_ratio has to be a number'
)
def
add_requirements
(
self
,
env
):
def
add_requirements
(
self
,
env
):
super
(
EquilibriumOptimizer
,
self
)
.
add_requirements
(
env
)
super
(
EquilibriumOptimizer
,
self
)
.
add_requirements
(
env
)
...
@@ -1216,7 +1305,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1216,7 +1305,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
for
opt
in
self
.
global_optimizers
:
for
opt
in
self
.
global_optimizers
:
opt
.
add_requirements
(
env
)
opt
.
add_requirements
(
env
)
def
apply
(
self
,
env
,
start_from
=
None
):
def
apply
(
self
,
env
,
start_from
=
None
):
if
start_from
is
None
:
if
start_from
is
None
:
start_from
=
env
.
outputs
start_from
=
env
.
outputs
changed
=
True
changed
=
True
...
@@ -1251,9 +1340,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1251,9 +1340,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
nb_nodes
.
append
(
len
(
q
))
nb_nodes
.
append
(
len
(
q
))
max_nb_nodes
=
max
(
max_nb_nodes
,
len
(
q
))
max_nb_nodes
=
max
(
max_nb_nodes
,
len
(
q
))
max_use
=
max_nb_nodes
*
self
.
max_use_ratio
max_use
=
max_nb_nodes
*
self
.
max_use_ratio
def
importer
(
node
):
def
importer
(
node
):
if
node
is
not
current_node
:
if
node
is
not
current_node
:
q
.
append
(
node
)
q
.
append
(
node
)
def
pruner
(
node
):
def
pruner
(
node
):
if
node
is
not
current_node
:
if
node
is
not
current_node
:
try
:
try
:
...
@@ -1277,7 +1368,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1277,7 +1368,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
opt_name
=
(
getattr
(
lopt
,
"name"
,
None
)
opt_name
=
(
getattr
(
lopt
,
"name"
,
None
)
or
getattr
(
lopt
,
"__name__"
,
""
))
or
getattr
(
lopt
,
"__name__"
,
""
))
if
node
not
in
env
.
nodes
:
if
node
not
in
env
.
nodes
:
break
# go to next node
# go to next node
break
finally
:
finally
:
self
.
detach_updater
(
env
,
u
)
self
.
detach_updater
(
env
,
u
)
self
.
detach_updater
(
env
,
u
)
#TODO: erase this line, it's redundant at best
self
.
detach_updater
(
env
,
u
)
#TODO: erase this line, it's redundant at best
...
@@ -1314,10 +1406,12 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1314,10 +1406,12 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'name'
,
None
)
name
=
getattr
(
self
,
'name'
,
None
)
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
' '
*
level
,
self
.
__class__
.
__name__
,
name
,
id
(
self
))
print
>>
stream
,
"
%
s
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
name
,
id
(
self
))
if
depth
!=
0
:
if
depth
!=
0
:
for
lopt
in
self
.
local_optimizers
:
for
lopt
in
self
.
local_optimizers
:
lopt
.
print_summary
(
stream
,
level
=
level
+
2
,
depth
=
depth
-
1
)
lopt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
depth
=
(
depth
-
1
))
#################
#################
...
@@ -1340,7 +1434,8 @@ def _check_chain(r, chain):
...
@@ -1340,7 +1434,8 @@ def _check_chain(r, chain):
return
False
return
False
else
:
else
:
try
:
try
:
if
issubclass
(
elem
,
op
.
Op
)
and
not
isinstance
(
r
.
owner
.
op
,
elem
):
if
(
issubclass
(
elem
,
op
.
Op
)
and
not
isinstance
(
r
.
owner
.
op
,
elem
)):
return
False
return
False
except
TypeError
:
except
TypeError
:
return
False
return
False
...
@@ -1354,6 +1449,7 @@ def _check_chain(r, chain):
...
@@ -1354,6 +1449,7 @@ def _check_chain(r, chain):
return
(
r
is
not
None
)
return
(
r
is
not
None
)
#_check_chain.n_calls = 0
#_check_chain.n_calls = 0
def
check_chain
(
r
,
*
chain
):
def
check_chain
(
r
,
*
chain
):
"""WRITEME"""
"""WRITEME"""
if
isinstance
(
r
,
graph
.
Apply
):
if
isinstance
(
r
,
graph
.
Apply
):
...
@@ -1378,7 +1474,7 @@ def pre_greedy_local_optimizer(list_optimizations, out):
...
@@ -1378,7 +1474,7 @@ def pre_greedy_local_optimizer(list_optimizations, out):
add additional node to the inputs of the node, it can
add additional node to the inputs of the node, it can
be needed to call this function multiple time.
be needed to call this function multiple time.
'''
'''
def
local_recursive_function
(
list_opt
,
out
,
optimized_vars
,
depth
):
def
local_recursive_function
(
list_opt
,
out
,
optimized_vars
,
depth
):
if
not
getattr
(
out
,
'owner'
,
None
):
if
not
getattr
(
out
,
'owner'
,
None
):
return
[
out
],
optimized_vars
return
[
out
],
optimized_vars
node
=
out
.
owner
node
=
out
.
owner
...
@@ -1390,11 +1486,11 @@ def pre_greedy_local_optimizer(list_optimizations, out):
...
@@ -1390,11 +1486,11 @@ def pre_greedy_local_optimizer(list_optimizations, out):
else
:
else
:
if
inp
.
owner
:
if
inp
.
owner
:
outs
,
optimized_vars
=
local_recursive_function
(
outs
,
optimized_vars
=
local_recursive_function
(
list_opt
list_opt
,
,
inp
inp
,
,
optimized_vars
optimized_vars
,
,
depth
+
1
)
depth
+
1
)
for
k
,
v
in
zip
(
inp
.
owner
.
outputs
,
outs
):
for
k
,
v
in
zip
(
inp
.
owner
.
outputs
,
outs
):
optimized_vars
[
k
]
=
v
optimized_vars
[
k
]
=
v
nw_in
=
outs
[
inp
.
owner
.
outputs
.
index
(
inp
)]
nw_in
=
outs
[
inp
.
owner
.
outputs
.
index
(
inp
)]
...
@@ -1408,10 +1504,10 @@ def pre_greedy_local_optimizer(list_optimizations, out):
...
@@ -1408,10 +1504,10 @@ def pre_greedy_local_optimizer(list_optimizations, out):
ret
=
opt
.
transform
(
node
)
ret
=
opt
.
transform
(
node
)
if
ret
is
not
False
and
ret
is
not
None
:
if
ret
is
not
False
and
ret
is
not
None
:
assert
len
(
ret
)
==
len
(
node
.
outputs
)
assert
len
(
ret
)
==
len
(
node
.
outputs
)
for
k
,
v
in
zip
(
node
.
outputs
,
ret
):
for
k
,
v
in
zip
(
node
.
outputs
,
ret
):
optimized_vars
[
k
]
=
v
optimized_vars
[
k
]
=
v
results
=
ret
results
=
ret
if
ret
[
0
]
.
owner
:
if
ret
[
0
]
.
owner
:
node
=
out
.
owner
node
=
out
.
owner
else
:
else
:
break
break
...
@@ -1422,8 +1518,6 @@ def pre_greedy_local_optimizer(list_optimizations, out):
...
@@ -1422,8 +1518,6 @@ def pre_greedy_local_optimizer(list_optimizations, out):
return
final_outs
[
0
]
return
final_outs
[
0
]
############
############
### Misc ###
### Misc ###
############
############
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论