Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e94955a7
提交
e94955a7
authored
10月 28, 2011
作者:
Frederic
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
pep8 fix.
上级
43cb38b8
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
57 行增加
和
40 行删除
+57
-40
optdb.py
theano/gof/optdb.py
+57
-40
没有找到文件。
theano/gof/optdb.py
浏览文件 @
e94955a7
import
sys
,
StringIO
import
StringIO
import
sys
if
sys
.
version_info
[:
2
]
>=
(
2
,
5
):
if
sys
.
version_info
[:
2
]
>=
(
2
,
5
):
from
collections
import
defaultdict
from
collections
import
defaultdict
else
:
else
:
from
python25
import
defaultdict
from
python25
import
defaultdict
import
numpy
import
numpy
import
opt
import
opt
from
theano.configparser
import
TheanoConfigParser
,
AddConfigVar
,
FloatParam
from
theano.configparser
import
AddConfigVar
,
FloatParam
from
theano
import
config
from
theano
import
config
AddConfigVar
(
'optdb.position_cutoff'
,
AddConfigVar
(
'optdb.position_cutoff'
,
'Where to stop eariler during optimization. It represent the position of the optimizer where to stop.'
,
'Where to stop eariler during optimization. It represent the'
' position of the optimizer where to stop.'
,
FloatParam
(
numpy
.
inf
),
FloatParam
(
numpy
.
inf
),
in_c_key
=
False
)
in_c_key
=
False
)
#upgraded to 20 to avoid EquibriumOptimizer error
#upgraded to 20 to avoid EquibriumOptimizer error
...
@@ -22,6 +24,7 @@ AddConfigVar('optdb.max_use_ratio',
...
@@ -22,6 +24,7 @@ AddConfigVar('optdb.max_use_ratio',
FloatParam
(
20
),
FloatParam
(
20
),
in_c_key
=
False
)
in_c_key
=
False
)
class
DB
(
object
):
class
DB
(
object
):
def
__hash__
(
self
):
def
__hash__
(
self
):
if
not
hasattr
(
self
,
'_optimizer_idx'
):
if
not
hasattr
(
self
,
'_optimizer_idx'
):
...
@@ -32,7 +35,7 @@ class DB(object):
...
@@ -32,7 +35,7 @@ class DB(object):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
__db__
=
defaultdict
(
set
)
self
.
__db__
=
defaultdict
(
set
)
self
.
_names
=
set
()
self
.
_names
=
set
()
self
.
name
=
None
#
will be reset by register
self
.
name
=
None
#
will be reset by register
#(via obj.name by the thing doing the registering)
#(via obj.name by the thing doing the registering)
def
register
(
self
,
name
,
obj
,
*
tags
):
def
register
(
self
,
name
,
obj
,
*
tags
):
...
@@ -42,13 +45,15 @@ class DB(object):
...
@@ -42,13 +45,15 @@ class DB(object):
if
not
isinstance
(
obj
,
(
DB
,
opt
.
Optimizer
,
opt
.
LocalOptimizer
)):
if
not
isinstance
(
obj
,
(
DB
,
opt
.
Optimizer
,
opt
.
LocalOptimizer
)):
raise
TypeError
(
'Object cannot be registered in OptDB'
,
obj
)
raise
TypeError
(
'Object cannot be registered in OptDB'
,
obj
)
if
name
in
self
.
__db__
:
if
name
in
self
.
__db__
:
raise
ValueError
(
'The name of the object cannot be an existing tag or the name of another existing object.'
,
obj
,
name
)
raise
ValueError
(
'The name of the object cannot be an existing'
' tag or the name of another existing object.'
,
obj
,
name
)
# This restriction is there because in many place we suppose that
# This restriction is there because in many place we suppose that
# something in the DB is there only once.
# something in the DB is there only once.
if
getattr
(
obj
,
'name'
,
""
)
in
self
.
__db__
:
if
getattr
(
obj
,
'name'
,
""
)
in
self
.
__db__
:
raise
ValueError
(
'''You can
\'
t register the same optimization
raise
ValueError
(
'''You can
\'
t register the same optimization
multiple time in a DB. Tryed to register "
%
s" again under the new name "
%
s".
multiple time in a DB. Tryed to register "
%
s" again under the new name "
%
s".
Use theano.gof.ProxyDB to work around that'''
%
(
obj
.
name
,
name
))
Use theano.gof.ProxyDB to work around that'''
%
(
obj
.
name
,
name
))
if
self
.
name
is
not
None
:
if
self
.
name
is
not
None
:
tags
=
tags
+
(
self
.
name
,)
tags
=
tags
+
(
self
.
name
,)
...
@@ -60,11 +65,12 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s".
...
@@ -60,11 +65,12 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s".
def
add_tags
(
self
,
name
,
*
tags
):
def
add_tags
(
self
,
name
,
*
tags
):
obj
=
self
.
__db__
[
name
]
obj
=
self
.
__db__
[
name
]
assert
len
(
obj
)
==
1
assert
len
(
obj
)
==
1
obj
=
obj
.
copy
()
.
pop
()
obj
=
obj
.
copy
()
.
pop
()
for
tag
in
tags
:
for
tag
in
tags
:
if
tag
in
self
.
_names
:
if
tag
in
self
.
_names
:
raise
ValueError
(
'The tag of the object collides with a name.'
,
obj
,
tag
)
raise
ValueError
(
'The tag of the object collides with a name.'
,
obj
,
tag
)
self
.
__db__
[
tag
]
.
add
(
obj
)
self
.
__db__
[
tag
]
.
add
(
obj
)
def
__query__
(
self
,
q
):
def
__query__
(
self
,
q
):
...
@@ -94,36 +100,41 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s".
...
@@ -94,36 +100,41 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s".
def
query
(
self
,
*
tags
,
**
kwtags
):
def
query
(
self
,
*
tags
,
**
kwtags
):
if
len
(
tags
)
>=
1
and
isinstance
(
tags
[
0
],
Query
):
if
len
(
tags
)
>=
1
and
isinstance
(
tags
[
0
],
Query
):
if
len
(
tags
)
>
1
or
kwtags
:
if
len
(
tags
)
>
1
or
kwtags
:
raise
TypeError
(
'If the first argument to query is a Query, there should be no other arguments.'
,
tags
,
kwtags
)
raise
TypeError
(
'If the first argument to query is a Query,'
' there should be no other arguments.'
,
tags
,
kwtags
)
return
self
.
__query__
(
tags
[
0
])
return
self
.
__query__
(
tags
[
0
])
include
=
[
tag
[
1
:]
for
tag
in
tags
if
tag
.
startswith
(
'+'
)]
include
=
[
tag
[
1
:]
for
tag
in
tags
if
tag
.
startswith
(
'+'
)]
require
=
[
tag
[
1
:]
for
tag
in
tags
if
tag
.
startswith
(
'&'
)]
require
=
[
tag
[
1
:]
for
tag
in
tags
if
tag
.
startswith
(
'&'
)]
exclude
=
[
tag
[
1
:]
for
tag
in
tags
if
tag
.
startswith
(
'-'
)]
exclude
=
[
tag
[
1
:]
for
tag
in
tags
if
tag
.
startswith
(
'-'
)]
if
len
(
include
)
+
len
(
require
)
+
len
(
exclude
)
<
len
(
tags
):
if
len
(
include
)
+
len
(
require
)
+
len
(
exclude
)
<
len
(
tags
):
raise
ValueError
(
"All tags must start with one of the following characters: '+', '&' or '-'"
,
tags
)
raise
ValueError
(
"All tags must start with one of the following"
return
self
.
__query__
(
Query
(
include
=
include
,
" characters: '+', '&' or '-'"
,
tags
)
require
=
require
,
return
self
.
__query__
(
Query
(
include
=
include
,
exclude
=
exclude
,
require
=
require
,
subquery
=
kwtags
))
exclude
=
exclude
,
subquery
=
kwtags
))
def
__getitem__
(
self
,
name
):
def
__getitem__
(
self
,
name
):
variables
=
self
.
__db__
[
name
]
variables
=
self
.
__db__
[
name
]
if
not
variables
:
if
not
variables
:
raise
KeyError
(
"Nothing registered for '
%
s'"
%
name
)
raise
KeyError
(
"Nothing registered for '
%
s'"
%
name
)
elif
len
(
variables
)
>
1
:
elif
len
(
variables
)
>
1
:
raise
ValueError
(
'More than one match for
%
s (please use query)'
%
name
)
raise
ValueError
(
'More than one match for
%
s (please use query)'
%
name
)
for
variable
in
variables
:
for
variable
in
variables
:
return
variable
return
variable
def
print_summary
(
self
,
stream
=
sys
.
stdout
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
):
print
>>
stream
,
"
%
s (id
%
i)"
%
(
self
.
__class__
.
__name__
,
id
(
self
))
print
>>
stream
,
"
%
s (id
%
i)"
%
(
self
.
__class__
.
__name__
,
id
(
self
))
print
>>
stream
,
" names"
,
self
.
_names
print
>>
stream
,
" names"
,
self
.
_names
print
>>
stream
,
" db"
,
self
.
__db__
print
>>
stream
,
" db"
,
self
.
__db__
class
Query
(
object
):
class
Query
(
object
):
def
__init__
(
self
,
include
,
require
=
None
,
exclude
=
None
,
subquery
=
None
,
position_cutoff
=
None
):
def
__init__
(
self
,
include
,
require
=
None
,
exclude
=
None
,
subquery
=
None
,
position_cutoff
=
None
):
"""
"""
:type position_cutoff: float
:type position_cutoff: float
:param position_cutoff: Used by SequenceDB to keep only optimizer that
:param position_cutoff: Used by SequenceDB to keep only optimizer that
...
@@ -142,6 +153,7 @@ class Query(object):
...
@@ -142,6 +153,7 @@ class Query(object):
self
.
exclude
,
self
.
exclude
,
self
.
subquery
,
self
.
subquery
,
self
.
position_cutoff
)
self
.
position_cutoff
)
#remove all opt with this tag
#remove all opt with this tag
def
excluding
(
self
,
*
tags
):
def
excluding
(
self
,
*
tags
):
return
Query
(
self
.
include
,
return
Query
(
self
.
include
,
...
@@ -149,6 +161,7 @@ class Query(object):
...
@@ -149,6 +161,7 @@ class Query(object):
self
.
exclude
.
union
(
tags
),
self
.
exclude
.
union
(
tags
),
self
.
subquery
,
self
.
subquery
,
self
.
position_cutoff
)
self
.
position_cutoff
)
#keep only opt with this tag.
#keep only opt with this tag.
def
requiring
(
self
,
*
tags
):
def
requiring
(
self
,
*
tags
):
return
Query
(
self
.
include
,
return
Query
(
self
.
include
,
...
@@ -158,17 +171,16 @@ class Query(object):
...
@@ -158,17 +171,16 @@ class Query(object):
self
.
position_cutoff
)
self
.
position_cutoff
)
class
EquilibriumDB
(
DB
):
class
EquilibriumDB
(
DB
):
"""
A set of potential optimizations which should be applied in an arbitrary order until
"""
A set of potential optimizations which should be applied in an
equilibrium is reached.
arbitrary order until
equilibrium is reached.
Canonicalize, Stabilize, and Specialize are all equilibrium optimizations.
Canonicalize, Stabilize, and Specialize are all equilibrium optimizations.
.. note::
.. note::
We can put LocalOptimizer and Optimizer as EquilibriumOptimizer suppor both.
We can put LocalOptimizer and Optimizer as EquilibriumOptimizer
suppor both.
"""
"""
...
@@ -186,15 +198,15 @@ class SequenceDB(DB):
...
@@ -186,15 +198,15 @@ class SequenceDB(DB):
Retrieve a sequence of optimizations (a SeqOptimizer) by calling query().
Retrieve a sequence of optimizations (a SeqOptimizer) by calling query().
Each potential optimization is registered with a floating-point position.
Each potential optimization is registered with a floating-point position.
No matter which optimizations are selected by a query, they are carried
out in order of
No matter which optimizations are selected by a query, they are carried
increasing position.
out in order of
increasing position.
The optdb itself (`theano.compile.mode.optdb`), from which (among many
other tags) fast_run
The optdb itself (`theano.compile.mode.optdb`), from which (among many
and fast_compile optimizers are drawn is a SequenceDB.
other tags) fast_run
and fast_compile optimizers are drawn is a SequenceDB.
"""
"""
def
__init__
(
self
,
failure_callback
=
opt
.
SeqOptimizer
.
warn
):
def
__init__
(
self
,
failure_callback
=
opt
.
SeqOptimizer
.
warn
):
super
(
SequenceDB
,
self
)
.
__init__
()
super
(
SequenceDB
,
self
)
.
__init__
()
self
.
__position__
=
{}
self
.
__position__
=
{}
self
.
failure_callback
=
failure_callback
self
.
failure_callback
=
failure_callback
...
@@ -206,26 +218,29 @@ class SequenceDB(DB):
...
@@ -206,26 +218,29 @@ class SequenceDB(DB):
def
query
(
self
,
*
tags
,
**
kwtags
):
def
query
(
self
,
*
tags
,
**
kwtags
):
"""
"""
:type position_cutoff: float or int
:type position_cutoff: float or int
:param position_cutoff: only optimizations with position less than the cutoff are returned.
:param position_cutoff: only optimizations with position less than
the cutoff are returned.
"""
"""
opts
=
super
(
SequenceDB
,
self
)
.
query
(
*
tags
,
**
kwtags
)
opts
=
super
(
SequenceDB
,
self
)
.
query
(
*
tags
,
**
kwtags
)
position_cutoff
=
kwtags
.
pop
(
'position_cutoff'
,
config
.
optdb
.
position_cutoff
)
position_cutoff
=
kwtags
.
pop
(
'position_cutoff'
,
if
len
(
tags
)
>=
1
and
isinstance
(
tags
[
0
],
Query
):
config
.
optdb
.
position_cutoff
)
if
len
(
tags
)
>=
1
and
isinstance
(
tags
[
0
],
Query
):
#the call to super should have raise an error with a good message
#the call to super should have raise an error with a good message
assert
len
(
tags
)
==
1
assert
len
(
tags
)
==
1
if
getattr
(
tags
[
0
],
'position_cutoff'
,
None
):
if
getattr
(
tags
[
0
],
'position_cutoff'
,
None
):
position_cutoff
=
tags
[
0
]
.
position_cutoff
position_cutoff
=
tags
[
0
]
.
position_cutoff
opts
=
[
o
for
o
in
opts
if
self
.
__position__
[
o
.
name
]
<
position_cutoff
]
opts
=
[
o
for
o
in
opts
if
self
.
__position__
[
o
.
name
]
<
position_cutoff
]
opts
.
sort
(
key
=
lambda
obj
:
self
.
__position__
[
obj
.
name
])
opts
.
sort
(
key
=
lambda
obj
:
self
.
__position__
[
obj
.
name
])
return
opt
.
SeqOptimizer
(
opts
,
failure_callback
=
self
.
failure_callback
)
return
opt
.
SeqOptimizer
(
opts
,
failure_callback
=
self
.
failure_callback
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
):
print
>>
stream
,
"SequenceDB (id
%
i)"
%
id
(
self
)
print
>>
stream
,
"SequenceDB (id
%
i)"
%
id
(
self
)
positions
=
self
.
__position__
.
items
()
positions
=
self
.
__position__
.
items
()
def
c
(
a
,
b
):
return
cmp
(
a
[
1
],
b
[
1
])
def
c
(
a
,
b
):
return
cmp
(
a
[
1
],
b
[
1
])
positions
.
sort
(
c
)
positions
.
sort
(
c
)
print
>>
stream
,
" position"
,
positions
print
>>
stream
,
" position"
,
positions
...
@@ -240,8 +255,10 @@ class SequenceDB(DB):
...
@@ -240,8 +255,10 @@ class SequenceDB(DB):
class
ProxyDB
(
DB
):
class
ProxyDB
(
DB
):
"""
"""
This is needed as we can't register the same DB mutiple time in different position
Wrap an existing proxy.
in a SequentialDB
This is needed as we can't register the same DB mutiple time in
different position in a SequentialDB
"""
"""
def
__init__
(
self
,
db
):
def
__init__
(
self
,
db
):
assert
isinstance
(
db
,
DB
),
""
assert
isinstance
(
db
,
DB
),
""
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论