Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2e3f17cb
提交
2e3f17cb
authored
10月 21, 2020
作者:
Brandon T. Willard
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Apply pyupgrade to theano.gof
上级
beb93a44
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
20 个修改的文件
包含
186 行增加
和
223 行删除
+186
-223
callcache.py
theano/gof/callcache.py
+4
-4
cc.py
theano/gof/cc.py
+14
-15
cmodule.py
theano/gof/cmodule.py
+36
-37
compiledir.py
theano/gof/compiledir.py
+17
-32
compilelock.py
theano/gof/compilelock.py
+3
-9
destroyhandler.py
theano/gof/destroyhandler.py
+4
-6
fg.py
theano/gof/fg.py
+0
-2
graph.py
theano/gof/graph.py
+5
-7
lazylinker_c.py
theano/gof/lazylinker_c.py
+2
-2
link.py
theano/gof/link.py
+6
-8
op.py
theano/gof/op.py
+16
-19
opt.py
theano/gof/opt.py
+0
-0
optdb.py
theano/gof/optdb.py
+15
-15
params_type.py
theano/gof/params_type.py
+10
-7
sched.py
theano/gof/sched.py
+4
-4
toolbox.py
theano/gof/toolbox.py
+15
-15
type.py
theano/gof/type.py
+10
-12
unify.py
theano/gof/unify.py
+7
-7
utils.py
theano/gof/utils.py
+13
-16
vm.py
theano/gof/vm.py
+5
-6
没有找到文件。
theano/gof/callcache.py
浏览文件 @
2e3f17cb
...
@@ -6,15 +6,15 @@ import six.moves.cPickle as pickle
...
@@ -6,15 +6,15 @@ import six.moves.cPickle as pickle
_logger
=
logging
.
getLogger
(
"theano.gof.callcache"
)
_logger
=
logging
.
getLogger
(
"theano.gof.callcache"
)
class
CallCache
(
object
)
:
class
CallCache
:
def
__init__
(
self
,
filename
=
None
):
def
__init__
(
self
,
filename
=
None
):
self
.
filename
=
filename
self
.
filename
=
filename
try
:
try
:
if
filename
is
None
:
if
filename
is
None
:
raise
IO
Error
(
"bad filename"
)
# just goes to except
raise
OS
Error
(
"bad filename"
)
# just goes to except
with
open
(
filename
,
"r"
)
as
f
:
with
open
(
filename
)
as
f
:
self
.
cache
=
pickle
.
load
(
f
)
self
.
cache
=
pickle
.
load
(
f
)
except
IO
Error
:
except
OS
Error
:
self
.
cache
=
{}
self
.
cache
=
{}
def
persist
(
self
,
filename
=
None
):
def
persist
(
self
,
filename
=
None
):
...
...
theano/gof/cc.py
浏览文件 @
2e3f17cb
...
@@ -8,7 +8,6 @@ import sys
...
@@ -8,7 +8,6 @@ import sys
from
copy
import
copy
from
copy
import
copy
import
numpy
as
np
import
numpy
as
np
from
six
import
reraise
,
string_types
from
six.moves
import
StringIO
from
six.moves
import
StringIO
import
theano
import
theano
...
@@ -233,7 +232,7 @@ def struct_gen(args, struct_builders, blocks, sub):
...
@@ -233,7 +232,7 @@ def struct_gen(args, struct_builders, blocks, sub):
# declares the storage
# declares the storage
storage_decl
=
"
\n
"
.
join
([
"PyObject*
%
s;"
%
arg
for
arg
in
args
])
storage_decl
=
"
\n
"
.
join
([
"PyObject*
%
s;"
%
arg
for
arg
in
args
])
# in the constructor, sets the storage to the arguments
# in the constructor, sets the storage to the arguments
storage_set
=
"
\n
"
.
join
([
"this->
%
s =
%
s;"
%
(
arg
,
arg
)
for
arg
in
args
])
storage_set
=
"
\n
"
.
join
([
"this->
{} = {};"
.
format
(
arg
,
arg
)
for
arg
in
args
])
# increments the storage's refcount in the constructor
# increments the storage's refcount in the constructor
storage_incref
=
"
\n
"
.
join
([
"Py_XINCREF(
%
s);"
%
arg
for
arg
in
args
])
storage_incref
=
"
\n
"
.
join
([
"Py_XINCREF(
%
s);"
%
arg
for
arg
in
args
])
# decrements the storage's refcount in the destructor
# decrements the storage's refcount in the destructor
...
@@ -359,7 +358,7 @@ def get_c_declare(r, name, sub):
...
@@ -359,7 +358,7 @@ def get_c_declare(r, name, sub):
[
[
getattr
(
c
.
op
,
"check_input"
,
config
.
check_input
)
getattr
(
c
.
op
,
"check_input"
,
config
.
check_input
)
for
(
c
,
_
)
in
r
.
clients
for
(
c
,
_
)
in
r
.
clients
if
not
isinstance
(
c
,
str
ing_types
)
if
not
isinstance
(
c
,
str
)
]
]
)
or
(
r
.
owner
and
getattr
(
r
.
owner
.
op
,
"check_input"
,
config
.
check_input
)):
)
or
(
r
.
owner
and
getattr
(
r
.
owner
.
op
,
"check_input"
,
config
.
check_input
)):
c_declare
=
r
.
type
.
c_declare
(
name
,
sub
,
True
)
c_declare
=
r
.
type
.
c_declare
(
name
,
sub
,
True
)
...
@@ -405,7 +404,7 @@ def get_c_extract(r, name, sub):
...
@@ -405,7 +404,7 @@ def get_c_extract(r, name, sub):
[
[
getattr
(
c
.
op
,
"check_input"
,
config
.
check_input
)
getattr
(
c
.
op
,
"check_input"
,
config
.
check_input
)
for
(
c
,
_
)
in
r
.
clients
for
(
c
,
_
)
in
r
.
clients
if
not
isinstance
(
c
,
str
ing_types
)
if
not
isinstance
(
c
,
str
)
]
]
):
):
# check_broadcast is just an hack to easily remove just the
# check_broadcast is just an hack to easily remove just the
...
@@ -415,7 +414,7 @@ def get_c_extract(r, name, sub):
...
@@ -415,7 +414,7 @@ def get_c_extract(r, name, sub):
[
[
getattr
(
c
.
op
,
"check_broadcast"
,
True
)
getattr
(
c
.
op
,
"check_broadcast"
,
True
)
for
(
c
,
_
)
in
r
.
clients
for
(
c
,
_
)
in
r
.
clients
if
not
isinstance
(
c
,
str
ing_types
)
if
not
isinstance
(
c
,
str
)
]
]
):
):
c_extract
=
r
.
type
.
c_extract
(
name
,
sub
,
True
)
c_extract
=
r
.
type
.
c_extract
(
name
,
sub
,
True
)
...
@@ -849,7 +848,7 @@ class CLinker(link.Linker):
...
@@ -849,7 +848,7 @@ class CLinker(link.Linker):
pass
pass
else
:
else
:
# The following will be executed if the "try" block succeeds
# The following will be executed if the "try" block succeeds
assert
isinstance
(
c_support_code_apply
[
-
1
],
str
ing_types
),
(
assert
isinstance
(
c_support_code_apply
[
-
1
],
str
),
(
str
(
node
.
op
)
+
" didn't return a string for c_support_code_apply"
str
(
node
.
op
)
+
" didn't return a string for c_support_code_apply"
)
)
...
@@ -858,13 +857,13 @@ class CLinker(link.Linker):
...
@@ -858,13 +857,13 @@ class CLinker(link.Linker):
except
utils
.
MethodNotDefined
:
except
utils
.
MethodNotDefined
:
pass
pass
else
:
else
:
assert
isinstance
(
c_init_code_apply
[
-
1
],
str
ing_types
),
(
assert
isinstance
(
c_init_code_apply
[
-
1
],
str
),
(
str
(
node
.
op
)
+
" didn't return a string for c_init_code_apply"
str
(
node
.
op
)
+
" didn't return a string for c_init_code_apply"
)
)
try
:
try
:
struct_init
=
op
.
c_init_code_struct
(
node
,
name
,
sub_struct
)
struct_init
=
op
.
c_init_code_struct
(
node
,
name
,
sub_struct
)
assert
isinstance
(
struct_init
,
str
ing_types
),
(
assert
isinstance
(
struct_init
,
str
),
(
str
(
node
.
op
)
+
" didn't return a string for c_init_code_struct"
str
(
node
.
op
)
+
" didn't return a string for c_init_code_struct"
)
)
except
utils
.
MethodNotDefined
:
except
utils
.
MethodNotDefined
:
...
@@ -872,7 +871,7 @@ class CLinker(link.Linker):
...
@@ -872,7 +871,7 @@ class CLinker(link.Linker):
try
:
try
:
struct_support
=
op
.
c_support_code_struct
(
node
,
name
)
struct_support
=
op
.
c_support_code_struct
(
node
,
name
)
assert
isinstance
(
struct_support
,
str
ing_types
),
(
assert
isinstance
(
struct_support
,
str
),
(
str
(
node
.
op
)
+
" didn't return a string for c_support_code_struct"
str
(
node
.
op
)
+
" didn't return a string for c_support_code_struct"
)
)
except
utils
.
MethodNotDefined
:
except
utils
.
MethodNotDefined
:
...
@@ -880,7 +879,7 @@ class CLinker(link.Linker):
...
@@ -880,7 +879,7 @@ class CLinker(link.Linker):
try
:
try
:
struct_cleanup
=
op
.
c_cleanup_code_struct
(
node
,
name
)
struct_cleanup
=
op
.
c_cleanup_code_struct
(
node
,
name
)
assert
isinstance
(
struct_cleanup
,
str
ing_types
),
(
assert
isinstance
(
struct_cleanup
,
str
),
(
str
(
node
.
op
)
+
" didn't return a string for c_cleanup_code_struct"
str
(
node
.
op
)
+
" didn't return a string for c_cleanup_code_struct"
)
)
except
utils
.
MethodNotDefined
:
except
utils
.
MethodNotDefined
:
...
@@ -891,7 +890,7 @@ class CLinker(link.Linker):
...
@@ -891,7 +890,7 @@ class CLinker(link.Linker):
behavior
=
op
.
c_code
(
node
,
name
,
isyms
,
osyms
,
sub
)
behavior
=
op
.
c_code
(
node
,
name
,
isyms
,
osyms
,
sub
)
except
utils
.
MethodNotDefined
:
except
utils
.
MethodNotDefined
:
raise
NotImplementedError
(
"
%
s cannot produce C code"
%
op
)
raise
NotImplementedError
(
"
%
s cannot produce C code"
%
op
)
assert
isinstance
(
behavior
,
str
ing_types
),
(
assert
isinstance
(
behavior
,
str
),
(
str
(
node
.
op
)
+
" didn't return a string for c_code"
str
(
node
.
op
)
+
" didn't return a string for c_code"
)
)
# To help understand what is following. It help read the c code.
# To help understand what is following. It help read the c code.
...
@@ -1429,7 +1428,7 @@ class CLinker(link.Linker):
...
@@ -1429,7 +1428,7 @@ class CLinker(link.Linker):
# set of variables that have been computed by nodes we have
# set of variables that have been computed by nodes we have
# seen 'so far' in the loop below
# seen 'so far' in the loop below
fgraph_computed_set
=
set
()
fgraph_computed_set
=
set
()
fgraph_inputs_dict
=
dict
((
i
,
(
-
1
,
pos
))
for
pos
,
i
in
enumerate
(
fgraph
.
inputs
))
fgraph_inputs_dict
=
{
i
:
(
-
1
,
pos
)
for
pos
,
i
in
enumerate
(
fgraph
.
inputs
)}
constant_ids
=
dict
()
constant_ids
=
dict
()
op_pos
=
{}
# Apply -> topological position
op_pos
=
{}
# Apply -> topological position
...
@@ -1794,7 +1793,7 @@ class CLinker(link.Linker):
...
@@ -1794,7 +1793,7 @@ class CLinker(link.Linker):
return
code
.
getvalue
()
return
code
.
getvalue
()
class
_CThunk
(
object
)
:
class
_CThunk
:
"""
"""
A thunk with a C implementation.
A thunk with a C implementation.
...
@@ -1864,7 +1863,7 @@ class _CThunk(object):
...
@@ -1864,7 +1863,7 @@ class _CThunk(object):
)
)
print
(
self
.
error_storage
,
file
=
sys
.
stderr
)
print
(
self
.
error_storage
,
file
=
sys
.
stderr
)
raise
raise
r
eraise
(
exc_type
,
exc_value
,
exc_trace
)
r
aise
exc_value
.
with_traceback
(
exc_trace
)
class
OpWiseCLinker
(
link
.
LocalLinker
):
class
OpWiseCLinker
(
link
.
LocalLinker
):
...
@@ -2130,7 +2129,7 @@ class DualLinker(link.Linker):
...
@@ -2130,7 +2129,7 @@ class DualLinker(link.Linker):
return
f
,
i1
,
o1
return
f
,
i1
,
o1
class
HideC
(
object
)
:
class
HideC
:
def
__hide
(
*
args
):
def
__hide
(
*
args
):
raise
utils
.
MethodNotDefined
()
raise
utils
.
MethodNotDefined
()
...
...
theano/gof/cmodule.py
浏览文件 @
2e3f17cb
...
@@ -19,7 +19,7 @@ import warnings
...
@@ -19,7 +19,7 @@ import warnings
import
numpy.distutils
import
numpy.distutils
import
six.moves.cPickle
as
pickle
import
six.moves.cPickle
as
pickle
from
six
import
BytesIO
,
StringIO
,
b
,
string_types
from
six
import
BytesIO
,
StringIO
,
b
import
theano
import
theano
from
theano
import
config
from
theano
import
config
...
@@ -53,8 +53,6 @@ class MissingGXX(Exception):
...
@@ -53,8 +53,6 @@ class MissingGXX(Exception):
"""
"""
pass
def
debug_counter
(
name
,
every
=
1
):
def
debug_counter
(
name
,
every
=
1
):
"""
"""
...
@@ -70,10 +68,10 @@ def debug_counter(name, every=1):
...
@@ -70,10 +68,10 @@ def debug_counter(name, every=1):
setattr
(
debug_counter
,
name
,
getattr
(
debug_counter
,
name
,
0
)
+
1
)
setattr
(
debug_counter
,
name
,
getattr
(
debug_counter
,
name
,
0
)
+
1
)
n
=
getattr
(
debug_counter
,
name
)
n
=
getattr
(
debug_counter
,
name
)
if
n
%
every
==
0
:
if
n
%
every
==
0
:
print
(
"debug_counter [
%
s]:
%
s"
%
(
name
,
n
),
file
=
sys
.
stderr
)
print
(
"debug_counter [
{}]: {}"
.
format
(
name
,
n
),
file
=
sys
.
stderr
)
class
ExtFunction
(
object
)
:
class
ExtFunction
:
"""
"""
A C function to put into a DynamicModule.
A C function to put into a DynamicModule.
...
@@ -118,10 +116,12 @@ class ExtFunction(object):
...
@@ -118,10 +116,12 @@ class ExtFunction(object):
It goes into the DynamicModule's method table.
It goes into the DynamicModule's method table.
"""
"""
return
'
\t
{"
%
s",
%
s,
%
s, "
%
s"}'
%
(
self
.
name
,
self
.
name
,
self
.
method
,
self
.
doc
)
return
'
\t
{{"{}", {}, {}, "{}"}}'
.
format
(
self
.
name
,
self
.
name
,
self
.
method
,
self
.
doc
)
class
DynamicModule
(
object
)
:
class
DynamicModule
:
def
__init__
(
self
,
name
=
None
):
def
__init__
(
self
,
name
=
None
):
assert
name
is
None
,
(
assert
name
is
None
,
(
"The 'name' parameter of DynamicModule"
"The 'name' parameter of DynamicModule"
...
@@ -436,7 +436,7 @@ def get_module_hash(src_code, key):
...
@@ -436,7 +436,7 @@ def get_module_hash(src_code, key):
# This should be the C++ compilation command line parameters or the
# This should be the C++ compilation command line parameters or the
# libraries to link against.
# libraries to link against.
to_hash
+=
list
(
key_element
)
to_hash
+=
list
(
key_element
)
elif
isinstance
(
key_element
,
str
ing_types
):
elif
isinstance
(
key_element
,
str
):
if
key_element
.
startswith
(
"md5:"
)
or
key_element
.
startswith
(
"hash:"
):
if
key_element
.
startswith
(
"md5:"
)
or
key_element
.
startswith
(
"hash:"
):
# This is actually a sha256 hash of the config options.
# This is actually a sha256 hash of the config options.
# Currently, we still keep md5 to don't break old Theano.
# Currently, we still keep md5 to don't break old Theano.
...
@@ -481,7 +481,7 @@ def get_safe_part(key):
...
@@ -481,7 +481,7 @@ def get_safe_part(key):
# rest of the cache mechanism will just skip that key.
# rest of the cache mechanism will just skip that key.
hash
=
None
hash
=
None
for
key_element
in
c_link_key
[
1
:]:
for
key_element
in
c_link_key
[
1
:]:
if
isinstance
(
key_element
,
str
ing_types
):
if
isinstance
(
key_element
,
str
):
if
key_element
.
startswith
(
"md5:"
):
if
key_element
.
startswith
(
"md5:"
):
hash
=
key_element
[
4
:]
hash
=
key_element
[
4
:]
break
break
...
@@ -492,7 +492,7 @@ def get_safe_part(key):
...
@@ -492,7 +492,7 @@ def get_safe_part(key):
return
key
[
0
]
+
(
hash
,)
return
key
[
0
]
+
(
hash
,)
class
KeyData
(
object
)
:
class
KeyData
:
"""
"""
Used to store the key information in the cache.
Used to store the key information in the cache.
...
@@ -594,7 +594,7 @@ class KeyData(object):
...
@@ -594,7 +594,7 @@ class KeyData(object):
pass
pass
class
ModuleCache
(
object
)
:
class
ModuleCache
:
"""
"""
Interface to the cache of dynamically compiled modules on disk.
Interface to the cache of dynamically compiled modules on disk.
...
@@ -1011,7 +1011,7 @@ class ModuleCache(object):
...
@@ -1011,7 +1011,7 @@ class ModuleCache(object):
# Test to see that the file is [present and] readable.
# Test to see that the file is [present and] readable.
open
(
entry
)
.
close
()
open
(
entry
)
.
close
()
gone
=
False
gone
=
False
except
IO
Error
:
except
OS
Error
:
gone
=
True
gone
=
True
if
gone
:
if
gone
:
...
@@ -1140,7 +1140,7 @@ class ModuleCache(object):
...
@@ -1140,7 +1140,7 @@ class ModuleCache(object):
key_pkl
=
os
.
path
.
join
(
location
,
"key.pkl"
)
key_pkl
=
os
.
path
.
join
(
location
,
"key.pkl"
)
assert
not
os
.
path
.
exists
(
key_pkl
)
assert
not
os
.
path
.
exists
(
key_pkl
)
key_data
=
KeyData
(
key_data
=
KeyData
(
keys
=
set
([
key
])
,
module_hash
=
module_hash
,
key_pkl
=
key_pkl
,
entry
=
name
keys
=
{
key
}
,
module_hash
=
module_hash
,
key_pkl
=
key_pkl
,
entry
=
name
)
)
key_broken
=
False
key_broken
=
False
...
@@ -1518,7 +1518,7 @@ class ModuleCache(object):
...
@@ -1518,7 +1518,7 @@ class ModuleCache(object):
fname
=
os
.
path
.
join
(
self
.
dirname
,
filename
,
"key.pkl"
)
fname
=
os
.
path
.
join
(
self
.
dirname
,
filename
,
"key.pkl"
)
open
(
fname
)
.
close
()
open
(
fname
)
.
close
()
has_key
=
True
has_key
=
True
except
IO
Error
:
except
OS
Error
:
has_key
=
False
has_key
=
False
if
not
has_key
:
if
not
has_key
:
# Use the compiled file by default
# Use the compiled file by default
...
@@ -1724,12 +1724,10 @@ def std_lib_dirs_and_libs():
...
@@ -1724,12 +1724,10 @@ def std_lib_dirs_and_libs():
for
f
,
lib
in
[(
"libpython27.a"
,
"libpython 1.2"
)]:
for
f
,
lib
in
[(
"libpython27.a"
,
"libpython 1.2"
)]:
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
libdir
,
f
)):
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
libdir
,
f
)):
print
(
print
(
(
"Your Python version is from Canopy. "
"Your Python version is from Canopy. "
+
"You need to install the package '"
+
"You need to install the package '"
+
lib
+
lib
+
"' from Canopy package manager."
+
"' from Canopy package manager."
)
)
)
libdirs
=
[
libdirs
=
[
# Used in older Canopy
# Used in older Canopy
...
@@ -1747,12 +1745,10 @@ def std_lib_dirs_and_libs():
...
@@ -1747,12 +1745,10 @@ def std_lib_dirs_and_libs():
]
]
):
):
print
(
print
(
(
"Your Python version is from Canopy. "
"Your Python version is from Canopy. "
+
"You need to install the package '"
+
"You need to install the package '"
+
lib
+
lib
+
"' from Canopy package manager."
+
"' from Canopy package manager."
)
)
)
python_lib_dirs
.
insert
(
0
,
libdir
)
python_lib_dirs
.
insert
(
0
,
libdir
)
std_lib_dirs_and_libs
.
data
=
[
libname
],
python_lib_dirs
std_lib_dirs_and_libs
.
data
=
[
libname
],
python_lib_dirs
...
@@ -1833,15 +1829,15 @@ def gcc_llvm():
...
@@ -1833,15 +1829,15 @@ def gcc_llvm():
# Normally this should not happen as we should not try to
# Normally this should not happen as we should not try to
# compile when g++ is not available. If this happen, it
# compile when g++ is not available. If this happen, it
# will crash later so supposing it is not llvm is "safe".
# will crash later so supposing it is not llvm is "safe".
output
=
b
(
""
)
output
=
b
""
gcc_llvm
.
is_llvm
=
b
(
"llvm"
)
in
output
gcc_llvm
.
is_llvm
=
b
"llvm"
in
output
return
gcc_llvm
.
is_llvm
return
gcc_llvm
.
is_llvm
gcc_llvm
.
is_llvm
=
None
gcc_llvm
.
is_llvm
=
None
class
Compiler
(
object
)
:
class
Compiler
:
"""
"""
Meta compiler that offer some generic function.
Meta compiler that offer some generic function.
...
@@ -2077,7 +2073,7 @@ class GCC_compiler(Compiler):
...
@@ -2077,7 +2073,7 @@ class GCC_compiler(Compiler):
# as stdin (which is the default) results in the process
# as stdin (which is the default) results in the process
# waiting forever without returning. For that reason,
# waiting forever without returning. For that reason,
# we use a pipe, and use the empty string as input.
# we use a pipe, and use the empty string as input.
(
stdout
,
stderr
)
=
p
.
communicate
(
input
=
b
(
""
)
)
(
stdout
,
stderr
)
=
p
.
communicate
(
input
=
b
""
)
if
p
.
returncode
!=
0
:
if
p
.
returncode
!=
0
:
return
None
return
None
...
@@ -2355,7 +2351,7 @@ class GCC_compiler(Compiler):
...
@@ -2355,7 +2351,7 @@ class GCC_compiler(Compiler):
line
.
startswith
(
"#define hypot _hypot"
)
for
line
in
config_h
line
.
startswith
(
"#define hypot _hypot"
)
for
line
in
config_h
):
):
cxxflags
.
append
(
"-D_hypot=hypot"
)
cxxflags
.
append
(
"-D_hypot=hypot"
)
except
IO
Error
:
except
OS
Error
:
pass
pass
return
cxxflags
return
cxxflags
...
@@ -2472,9 +2468,9 @@ class GCC_compiler(Compiler):
...
@@ -2472,9 +2468,9 @@ class GCC_compiler(Compiler):
if
dist_suffix
is
not
None
and
dist_suffix
!=
""
:
if
dist_suffix
is
not
None
and
dist_suffix
!=
""
:
suffix
=
dist_suffix
suffix
=
dist_suffix
filepath
=
"
%
s
%
s"
%
(
module_name
,
suffix
)
filepath
=
"
{}{}"
.
format
(
module_name
,
suffix
)
else
:
else
:
filepath
=
"
%
s.
%
s"
%
(
module_name
,
get_lib_extension
())
filepath
=
"
{}.{}"
.
format
(
module_name
,
get_lib_extension
())
lib_filename
=
os
.
path
.
join
(
location
,
filepath
)
lib_filename
=
os
.
path
.
join
(
location
,
filepath
)
...
@@ -2488,10 +2484,13 @@ class GCC_compiler(Compiler):
...
@@ -2488,10 +2484,13 @@ class GCC_compiler(Compiler):
# to support path that includes spaces, we need to wrap it with double quotes on Windows
# to support path that includes spaces, we need to wrap it with double quotes on Windows
path_wrapper
=
'"'
if
os
.
name
==
"nt"
else
""
path_wrapper
=
'"'
if
os
.
name
==
"nt"
else
""
cmd
.
extend
(
cmd
.
extend
(
[
"-I
%
s
%
s
%
s"
%
(
path_wrapper
,
idir
,
path_wrapper
)
for
idir
in
include_dirs
]
[
"-I{}{}{}"
.
format
(
path_wrapper
,
idir
,
path_wrapper
)
for
idir
in
include_dirs
]
)
)
cmd
.
extend
(
cmd
.
extend
(
[
"-L
%
s
%
s
%
s"
%
(
path_wrapper
,
ldir
,
path_wrapper
)
for
ldir
in
lib_dirs
]
[
"-L
{}{}{}"
.
format
(
path_wrapper
,
ldir
,
path_wrapper
)
for
ldir
in
lib_dirs
]
)
)
if
hide_symbols
and
sys
.
platform
!=
"win32"
:
if
hide_symbols
and
sys
.
platform
!=
"win32"
:
# This has been available since gcc 4.0 so we suppose it
# This has been available since gcc 4.0 so we suppose it
...
@@ -2501,8 +2500,8 @@ class GCC_compiler(Compiler):
...
@@ -2501,8 +2500,8 @@ class GCC_compiler(Compiler):
# improved loading times on most platforms (win32 is
# improved loading times on most platforms (win32 is
# different, as usual).
# different, as usual).
cmd
.
append
(
"-fvisibility=hidden"
)
cmd
.
append
(
"-fvisibility=hidden"
)
cmd
.
extend
([
"-o"
,
"
%
s
%
s
%
s"
%
(
path_wrapper
,
lib_filename
,
path_wrapper
)])
cmd
.
extend
([
"-o"
,
"
{}{}{}"
.
format
(
path_wrapper
,
lib_filename
,
path_wrapper
)])
cmd
.
append
(
"
%
s
%
s
%
s"
%
(
path_wrapper
,
cppfilename
,
path_wrapper
))
cmd
.
append
(
"
{}{}{}"
.
format
(
path_wrapper
,
cppfilename
,
path_wrapper
))
cmd
.
extend
([
"-l
%
s"
%
l
for
l
in
libs
])
cmd
.
extend
([
"-l
%
s"
%
l
for
l
in
libs
])
# print >> sys.stderr, 'COMPILING W CMD', cmd
# print >> sys.stderr, 'COMPILING W CMD', cmd
_logger
.
debug
(
"Running cmd:
%
s"
,
" "
.
join
(
cmd
))
_logger
.
debug
(
"Running cmd:
%
s"
,
" "
.
join
(
cmd
))
...
...
theano/gof/compiledir.py
浏览文件 @
2e3f17cb
...
@@ -4,7 +4,6 @@ import shutil
...
@@ -4,7 +4,6 @@ import shutil
import
numpy
as
np
import
numpy
as
np
import
six.moves.cPickle
as
pickle
import
six.moves.cPickle
as
pickle
from
six
import
string_types
import
theano
import
theano
from
theano
import
config
from
theano
import
config
...
@@ -46,7 +45,7 @@ def cleanup():
...
@@ -46,7 +45,7 @@ def cleanup():
# force the removing of key
# force the removing of key
have_npy_abi_version
=
False
have_npy_abi_version
=
False
break
break
elif
isinstance
(
obj
,
str
ing_types
):
elif
isinstance
(
obj
,
str
):
if
obj
.
startswith
(
"NPY_ABI_VERSION=0x"
):
if
obj
.
startswith
(
"NPY_ABI_VERSION=0x"
):
have_npy_abi_version
=
True
have_npy_abi_version
=
True
elif
obj
.
startswith
(
"c_compiler_str="
):
elif
obj
.
startswith
(
"c_compiler_str="
):
...
@@ -67,7 +66,7 @@ def cleanup():
...
@@ -67,7 +66,7 @@ def cleanup():
if
keydata
.
key_pkl
!=
filename
:
if
keydata
.
key_pkl
!=
filename
:
keydata
.
key_pkl
=
filename
keydata
.
key_pkl
=
filename
keydata
.
remove_key
(
key
)
keydata
.
remove_key
(
key
)
except
IO
Error
:
except
OS
Error
:
_logger
.
error
(
_logger
.
error
(
"Could not remove file '
%
s'. To complete "
"Could not remove file '
%
s'. To complete "
"the clean-up, please remove manually "
"the clean-up, please remove manually "
...
@@ -84,7 +83,7 @@ def cleanup():
...
@@ -84,7 +83,7 @@ def cleanup():
"the directory containing it."
,
"the directory containing it."
,
filename
,
filename
,
)
)
except
IO
Error
:
except
OS
Error
:
_logger
.
error
(
_logger
.
error
(
"Could not clean up this directory: '
%
s'. To complete "
"Could not clean up this directory: '
%
s'. To complete "
"the clean-up, please remove it manually."
,
"the clean-up, please remove it manually."
,
...
@@ -126,29 +125,21 @@ def print_compiledir_content():
...
@@ -126,29 +125,21 @@ def print_compiledir_content():
try
:
try
:
keydata
=
pickle
.
load
(
file
)
keydata
=
pickle
.
load
(
file
)
ops
=
list
(
ops
=
list
(
set
(
{
x
for
x
in
flatten
(
keydata
.
keys
)
if
isinstance
(
x
,
theano
.
gof
.
Op
)}
[
x
for
x
in
flatten
(
keydata
.
keys
)
if
isinstance
(
x
,
theano
.
gof
.
Op
)
]
)
)
)
# Whatever the case, we count compilations for OP classes.
# Whatever the case, we count compilations for OP classes.
for
op_class
in
set
([
op
.
__class__
for
op
in
ops
])
:
for
op_class
in
{
op
.
__class__
for
op
in
ops
}
:
table_op_class
.
setdefault
(
op_class
,
0
)
table_op_class
.
setdefault
(
op_class
,
0
)
table_op_class
[
op_class
]
+=
1
table_op_class
[
op_class
]
+=
1
if
len
(
ops
)
==
0
:
if
len
(
ops
)
==
0
:
zeros_op
+=
1
zeros_op
+=
1
else
:
else
:
types
=
list
(
types
=
list
(
set
(
{
[
x
x
for
x
in
flatten
(
keydata
.
keys
)
for
x
in
flatten
(
keydata
.
keys
)
if
isinstance
(
x
,
theano
.
gof
.
Type
)
if
isinstance
(
x
,
theano
.
gof
.
Type
)
}
]
)
)
)
compile_start
=
compile_end
=
float
(
"nan"
)
compile_start
=
compile_end
=
float
(
"nan"
)
for
fn
in
os
.
listdir
(
os
.
path
.
join
(
compiledir
,
dir
)):
for
fn
in
os
.
listdir
(
os
.
path
.
join
(
compiledir
,
dir
)):
...
@@ -177,7 +168,7 @@ def print_compiledir_content():
...
@@ -177,7 +168,7 @@ def print_compiledir_content():
nb_keys
.
setdefault
(
len
(
keydata
.
keys
),
0
)
nb_keys
.
setdefault
(
len
(
keydata
.
keys
),
0
)
nb_keys
[
len
(
keydata
.
keys
)]
+=
1
nb_keys
[
len
(
keydata
.
keys
)]
+=
1
except
IO
Error
:
except
OS
Error
:
pass
pass
except
AttributeError
:
except
AttributeError
:
_logger
.
error
(
"Could not read key file '
%
s'."
,
filename
)
_logger
.
error
(
"Could not read key file '
%
s'."
,
filename
)
...
@@ -221,16 +212,12 @@ def print_compiledir_content():
...
@@ -221,16 +212,12 @@ def print_compiledir_content():
big_key_files
=
sorted
(
big_key_files
,
key
=
lambda
t
:
str
(
t
[
1
]))
big_key_files
=
sorted
(
big_key_files
,
key
=
lambda
t
:
str
(
t
[
1
]))
big_total_size
=
sum
([
sz
for
_
,
sz
,
_
in
big_key_files
])
big_total_size
=
sum
([
sz
for
_
,
sz
,
_
in
big_key_files
])
print
(
print
(
(
"There are directories with key files bigger than
%
d bytes "
"There are directories with key files bigger than
%
d bytes "
"(they probably contain big tensor constants)"
%
max_key_file_size
"(they probably contain big tensor constants)"
%
max_key_file_size
)
)
)
print
(
print
(
(
"They use
%
d bytes out of
%
d (total size used by all key files)"
"They use
%
d bytes out of
%
d (total size used by all key files)"
""
%
(
big_total_size
,
total_key_sizes
)
""
%
(
big_total_size
,
total_key_sizes
)
)
)
)
for
dir
,
size
,
ops
in
big_key_files
:
for
dir
,
size
,
ops
in
big_key_files
:
...
@@ -246,10 +233,8 @@ def print_compiledir_content():
...
@@ -246,10 +233,8 @@ def print_compiledir_content():
print
(
n_k
,
n_m
)
print
(
n_k
,
n_m
)
print
()
print
()
print
(
print
(
(
"Skipped
%
d files that contained 0 op "
"Skipped
%
d files that contained 0 op "
"(are they always theano.scalar ops?)"
%
zeros_op
"(are they always theano.scalar ops?)"
%
zeros_op
)
)
)
...
...
theano/gof/compilelock.py
浏览文件 @
2e3f17cb
...
@@ -10,7 +10,6 @@ import time
...
@@ -10,7 +10,6 @@ import time
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
import
numpy
as
np
import
numpy
as
np
from
six
import
PY3
from
theano
import
config
from
theano
import
config
...
@@ -283,14 +282,9 @@ def lock(tmp_dir, timeout=notset, min_wait=None, max_wait=None, verbosity=1):
...
@@ -283,14 +282,9 @@ def lock(tmp_dir, timeout=notset, min_wait=None, max_wait=None, verbosity=1):
nb_wait
+=
1
nb_wait
+=
1
time
.
sleep
(
random
.
uniform
(
min_wait
,
max_wait
))
time
.
sleep
(
random
.
uniform
(
min_wait
,
max_wait
))
if
PY3
:
exception
=
FileExistsError
# noqa
else
:
exception
=
OSError
try
:
try
:
os
.
mkdir
(
tmp_dir
)
os
.
mkdir
(
tmp_dir
)
except
exception
:
except
FileExistsError
:
# Error while creating the directory: someone else
# Error while creating the directory: someone else
# must have tried at the exact same time.
# must have tried at the exact same time.
nb_error
+=
1
nb_error
+=
1
...
@@ -332,7 +326,7 @@ def refresh_lock(lock_file):
...
@@ -332,7 +326,7 @@ def refresh_lock(lock_file):
unique id, using a new (randomly generated) id, which is also returned.
unique id, using a new (randomly generated) id, which is also returned.
"""
"""
unique_id
=
"
%
s_
%
s_
%
s"
%
(
unique_id
=
"
{}_{}_{}"
.
format
(
os
.
getpid
(),
os
.
getpid
(),
""
.
join
([
str
(
random
.
randint
(
0
,
9
))
for
i
in
range
(
10
)]),
""
.
join
([
str
(
random
.
randint
(
0
,
9
))
for
i
in
range
(
10
)]),
hostname
,
hostname
,
...
@@ -355,7 +349,7 @@ def refresh_lock(lock_file):
...
@@ -355,7 +349,7 @@ def refresh_lock(lock_file):
return
unique_id
return
unique_id
class
Unlocker
(
object
)
:
class
Unlocker
:
"""
"""
Class wrapper around release mechanism so that the lock is automatically
Class wrapper around release mechanism so that the lock is automatically
released when the program exits (even when crashing or being interrupted),
released when the program exits (even when crashing or being interrupted),
...
...
theano/gof/destroyhandler.py
浏览文件 @
2e3f17cb
...
@@ -22,8 +22,6 @@ class ProtocolError(Exception):
...
@@ -22,8 +22,6 @@ class ProtocolError(Exception):
"""
"""
pass
def
_contains_cycle
(
fgraph
,
orderings
):
def
_contains_cycle
(
fgraph
,
orderings
):
"""
"""
...
@@ -762,17 +760,17 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
...
@@ -762,17 +760,17 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# OPT: pre-compute this on import
# OPT: pre-compute this on import
tolerate_same
=
getattr
(
app
.
op
,
"destroyhandler_tolerate_same"
,
[])
tolerate_same
=
getattr
(
app
.
op
,
"destroyhandler_tolerate_same"
,
[])
assert
isinstance
(
tolerate_same
,
list
)
assert
isinstance
(
tolerate_same
,
list
)
tolerated
=
set
(
tolerated
=
{
idx1
for
idx0
,
idx1
in
tolerate_same
if
idx0
==
destroyed_idx
idx1
for
idx0
,
idx1
in
tolerate_same
if
idx0
==
destroyed_idx
)
}
tolerated
.
add
(
destroyed_idx
)
tolerated
.
add
(
destroyed_idx
)
tolerate_aliased
=
getattr
(
tolerate_aliased
=
getattr
(
app
.
op
,
"destroyhandler_tolerate_aliased"
,
[]
app
.
op
,
"destroyhandler_tolerate_aliased"
,
[]
)
)
assert
isinstance
(
tolerate_aliased
,
list
)
assert
isinstance
(
tolerate_aliased
,
list
)
ignored
=
set
(
ignored
=
{
idx1
for
idx0
,
idx1
in
tolerate_aliased
if
idx0
==
destroyed_idx
idx1
for
idx0
,
idx1
in
tolerate_aliased
if
idx0
==
destroyed_idx
)
}
for
i
,
input
in
enumerate
(
app
.
inputs
):
for
i
,
input
in
enumerate
(
app
.
inputs
):
if
i
in
ignored
:
if
i
in
ignored
:
continue
continue
...
...
theano/gof/fg.py
浏览文件 @
2e3f17cb
...
@@ -26,8 +26,6 @@ class InconsistencyError(Exception):
...
@@ -26,8 +26,6 @@ class InconsistencyError(Exception):
"""
"""
pass
class
MissingInputError
(
Exception
):
class
MissingInputError
(
Exception
):
"""
"""
...
...
theano/gof/graph.py
浏览文件 @
2e3f17cb
...
@@ -7,8 +7,6 @@ from collections import deque
...
@@ -7,8 +7,6 @@ from collections import deque
from
copy
import
copy
from
copy
import
copy
from
itertools
import
count
from
itertools
import
count
from
six
import
integer_types
,
string_types
import
theano
import
theano
from
theano
import
config
from
theano
import
config
from
theano.gof.utils
import
(
from
theano.gof.utils
import
(
...
@@ -174,7 +172,7 @@ class Apply(Node):
...
@@ -174,7 +172,7 @@ class Apply(Node):
raise
ValueError
(
raise
ValueError
(
"
%
s.default_output should be an output index."
%
self
.
op
"
%
s.default_output should be an output index."
%
self
.
op
)
)
elif
not
isinstance
(
do
,
int
eger_types
):
elif
not
isinstance
(
do
,
int
):
raise
ValueError
(
"
%
s.default_output should be an int or long"
%
self
.
op
)
raise
ValueError
(
"
%
s.default_output should be an int or long"
%
self
.
op
)
elif
do
<
0
or
do
>=
len
(
self
.
outputs
):
elif
do
<
0
or
do
>=
len
(
self
.
outputs
):
raise
ValueError
(
"
%
s.default_output is out of range."
%
self
.
op
)
raise
ValueError
(
"
%
s.default_output is out of range."
%
self
.
op
)
...
@@ -395,11 +393,11 @@ class Variable(Node):
...
@@ -395,11 +393,11 @@ class Variable(Node):
raise
TypeError
(
"owner must be an Apply instance"
,
owner
)
raise
TypeError
(
"owner must be an Apply instance"
,
owner
)
self
.
owner
=
owner
self
.
owner
=
owner
if
index
is
not
None
and
not
isinstance
(
index
,
int
eger_types
):
if
index
is
not
None
and
not
isinstance
(
index
,
int
):
raise
TypeError
(
"index must be an int"
,
index
)
raise
TypeError
(
"index must be an int"
,
index
)
self
.
index
=
index
self
.
index
=
index
if
name
is
not
None
and
not
isinstance
(
name
,
str
ing_types
):
if
name
is
not
None
and
not
isinstance
(
name
,
str
):
raise
TypeError
(
"name must be a string"
,
name
)
raise
TypeError
(
"name must be a string"
,
name
)
self
.
name
=
name
self
.
name
=
name
...
@@ -1156,7 +1154,7 @@ default_leaf_formatter = str
...
@@ -1156,7 +1154,7 @@ default_leaf_formatter = str
def
default_node_formatter
(
op
,
argstrings
):
def
default_node_formatter
(
op
,
argstrings
):
return
"
%
s(
%
s)"
%
(
op
.
op
,
", "
.
join
(
argstrings
))
return
"
{}({})"
.
format
(
op
.
op
,
", "
.
join
(
argstrings
))
def
io_connection_pattern
(
inputs
,
outputs
):
def
io_connection_pattern
(
inputs
,
outputs
):
...
@@ -1331,7 +1329,7 @@ def view_roots(r):
...
@@ -1331,7 +1329,7 @@ def view_roots(r):
if
owner
is
not
None
:
if
owner
is
not
None
:
try
:
try
:
view_map
=
owner
.
op
.
view_map
view_map
=
owner
.
op
.
view_map
view_map
=
dict
((
owner
.
outputs
[
o
],
i
)
for
o
,
i
in
view_map
.
items
())
view_map
=
{
owner
.
outputs
[
o
]:
i
for
o
,
i
in
view_map
.
items
()}
except
AttributeError
:
except
AttributeError
:
return
[
r
]
return
[
r
]
if
r
in
view_map
:
if
r
in
view_map
:
...
...
theano/gof/lazylinker_c.py
浏览文件 @
2e3f17cb
...
@@ -59,11 +59,11 @@ try:
...
@@ -59,11 +59,11 @@ try:
if
not
os
.
path
.
exists
(
init_file
):
if
not
os
.
path
.
exists
(
init_file
):
try
:
try
:
open
(
init_file
,
"w"
)
.
close
()
open
(
init_file
,
"w"
)
.
close
()
except
IO
Error
as
e
:
except
OS
Error
as
e
:
if
os
.
path
.
exists
(
init_file
):
if
os
.
path
.
exists
(
init_file
):
pass
# has already been created
pass
# has already been created
else
:
else
:
e
.
args
+=
(
"
%
s exist?
%
s"
%
(
location
,
os
.
path
.
exists
(
location
)),)
e
.
args
+=
(
"
{} exist? {}"
.
format
(
location
,
os
.
path
.
exists
(
location
)),)
raise
raise
_need_reload
=
False
_need_reload
=
False
...
...
theano/gof/link.py
浏览文件 @
2e3f17cb
...
@@ -4,7 +4,6 @@ from copy import copy, deepcopy
...
@@ -4,7 +4,6 @@ from copy import copy, deepcopy
from
sys
import
getsizeof
from
sys
import
getsizeof
import
numpy
as
np
import
numpy
as
np
from
six
import
reraise
from
six.moves
import
StringIO
from
six.moves
import
StringIO
import
theano
import
theano
...
@@ -123,7 +122,7 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
...
@@ -123,7 +122,7 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
exc_type
,
exc_value
,
exc_trace
=
exc_info
exc_type
,
exc_value
,
exc_trace
=
exc_info
if
exc_type
==
KeyboardInterrupt
:
if
exc_type
==
KeyboardInterrupt
:
# print a simple traceback from KeyboardInterrupt
# print a simple traceback from KeyboardInterrupt
r
eraise
(
exc_type
,
exc_value
,
exc_trace
)
r
aise
exc_value
.
with_traceback
(
exc_trace
)
try
:
try
:
trace
=
node
.
outputs
[
0
]
.
tag
.
trace
trace
=
node
.
outputs
[
0
]
.
tag
.
trace
except
AttributeError
:
except
AttributeError
:
...
@@ -315,11 +314,11 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
...
@@ -315,11 +314,11 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
detailed_err_msg
+=
", TotalSize:
%
s Byte(s)
\n
"
%
item
[
3
]
detailed_err_msg
+=
", TotalSize:
%
s Byte(s)
\n
"
%
item
[
3
]
else
:
else
:
detailed_err_msg
+=
"
\n
"
detailed_err_msg
+=
"
\n
"
detailed_err_msg
+=
" TotalSize:
%
s Byte(s)
%.3
f GB
\n
"
%
(
detailed_err_msg
+=
" TotalSize:
{} Byte(s) {:.3f} GB
\n
"
.
format
(
total_size
,
total_size
,
total_size
/
1024.0
/
1024
/
1024
,
total_size
/
1024.0
/
1024
/
1024
,
)
)
detailed_err_msg
+=
" TotalSize inputs:
%
s Byte(s)
%.3
f GB
\n
"
%
(
detailed_err_msg
+=
" TotalSize inputs:
{} Byte(s) {:.3f} GB
\n
"
.
format
(
total_size_inputs
,
total_size_inputs
,
total_size_inputs
/
1024.0
/
1024
/
1024
,
total_size_inputs
/
1024.0
/
1024
/
1024
,
)
)
...
@@ -341,11 +340,10 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
...
@@ -341,11 +340,10 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
)
)
# Some exception need extra parameter in inputs. So forget the
# Some exception need extra parameter in inputs. So forget the
# extra long error message in that case.
# extra long error message in that case.
pass
raise
exc_value
.
with_traceback
(
exc_trace
)
reraise
(
exc_type
,
exc_value
,
exc_trace
)
class
Linker
(
object
)
:
class
Linker
:
"""
"""
WRITEME
WRITEME
...
@@ -434,7 +432,7 @@ class Linker(object):
...
@@ -434,7 +432,7 @@ class Linker(object):
# TODO: Move this class to the compile module, where it is used (and for which it exists).
# TODO: Move this class to the compile module, where it is used (and for which it exists).
class
Container
(
object
)
:
class
Container
:
"""
"""
This class joins a variable with its computed value.
This class joins a variable with its computed value.
...
...
theano/gof/op.py
浏览文件 @
2e3f17cb
...
@@ -122,7 +122,7 @@ def compute_test_value(node):
...
@@ -122,7 +122,7 @@ def compute_test_value(node):
output
.
tag
.
test_value
=
storage_map
[
output
][
0
]
output
.
tag
.
test_value
=
storage_map
[
output
][
0
]
class
CLinkerObject
(
object
)
:
class
CLinkerObject
:
"""
"""
Standard elements of an Op or Type used with the CLinker.
Standard elements of an Op or Type used with the CLinker.
...
@@ -550,7 +550,7 @@ class CLinkerOp(CLinkerObject):
...
@@ -550,7 +550,7 @@ class CLinkerOp(CLinkerObject):
)
)
class
PureOp
(
object
)
:
class
PureOp
:
"""A class that models and constructs operations in a graph.
"""A class that models and constructs operations in a graph.
A `PureOp` instance has several responsibilities:
A `PureOp` instance has several responsibilities:
...
@@ -842,7 +842,6 @@ class Op(object2, PureOp, CLinkerOp):
...
@@ -842,7 +842,6 @@ class Op(object2, PureOp, CLinkerOp):
good to do so.
good to do so.
"""
"""
pass
def
make_c_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
):
def
make_c_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
):
"""Like make_thunk, but will only try to make a C thunk."""
"""Like make_thunk, but will only try to make a C thunk."""
...
@@ -1263,19 +1262,17 @@ class COp(Op):
...
@@ -1263,19 +1262,17 @@ class COp(Op):
section_re
=
re
.
compile
(
r"^#section ([a-zA-Z0-9_]+)$"
,
re
.
MULTILINE
)
section_re
=
re
.
compile
(
r"^#section ([a-zA-Z0-9_]+)$"
,
re
.
MULTILINE
)
backward_re
=
re
.
compile
(
r"^THEANO_(APPLY|SUPPORT)_CODE_SECTION$"
,
re
.
MULTILINE
)
backward_re
=
re
.
compile
(
r"^THEANO_(APPLY|SUPPORT)_CODE_SECTION$"
,
re
.
MULTILINE
)
# This is the set of allowed markers
# This is the set of allowed markers
SECTIONS
=
set
(
SECTIONS
=
{
[
"init_code"
,
"init_code"
,
"init_code_apply"
,
"init_code_apply"
,
"init_code_struct"
,
"init_code_struct"
,
"support_code"
,
"support_code"
,
"support_code_apply"
,
"support_code_apply"
,
"support_code_struct"
,
"support_code_struct"
,
"cleanup_code_struct"
,
"cleanup_code_struct"
,
"code"
,
"code"
,
"code_cleanup"
,
"code_cleanup"
,
}
]
)
@classmethod
@classmethod
def
get_path
(
cls
,
f
):
def
get_path
(
cls
,
f
):
...
@@ -1535,10 +1532,10 @@ class COp(Op):
...
@@ -1535,10 +1532,10 @@ class COp(Op):
def
get_sub_macros
(
self
,
sub
):
def
get_sub_macros
(
self
,
sub
):
define_macros
=
[]
define_macros
=
[]
undef_macros
=
[]
undef_macros
=
[]
define_macros
.
append
(
"#define FAIL
%
s"
%
(
self
.
_lquote_macro
(
sub
[
"fail"
]),
))
define_macros
.
append
(
"#define FAIL
{}"
.
format
(
self
.
_lquote_macro
(
sub
[
"fail"
])
))
undef_macros
.
append
(
"#undef FAIL"
)
undef_macros
.
append
(
"#undef FAIL"
)
if
"params"
in
sub
:
if
"params"
in
sub
:
define_macros
.
append
(
"#define PARAMS
%
s"
%
(
sub
[
"params"
],
))
define_macros
.
append
(
"#define PARAMS
{}"
.
format
(
sub
[
"params"
]
))
undef_macros
.
append
(
"#undef PARAMS"
)
undef_macros
.
append
(
"#undef PARAMS"
)
return
"
\n
"
.
join
(
define_macros
),
"
\n
"
.
join
(
undef_macros
)
return
"
\n
"
.
join
(
define_macros
),
"
\n
"
.
join
(
undef_macros
)
...
@@ -1584,7 +1581,7 @@ class COp(Op):
...
@@ -1584,7 +1581,7 @@ class COp(Op):
params
=
""
params
=
""
if
"params"
in
sub
:
if
"params"
in
sub
:
params
=
",
%
s"
%
(
sub
[
"params"
],
)
params
=
",
{}"
.
format
(
sub
[
"params"
]
)
# Generate the C code
# Generate the C code
return
"""
return
"""
...
...
theano/gof/opt.py
浏览文件 @
2e3f17cb
差异被折叠。
点击展开。
theano/gof/optdb.py
浏览文件 @
2e3f17cb
...
@@ -2,7 +2,7 @@ import copy
...
@@ -2,7 +2,7 @@ import copy
import
math
import
math
import
sys
import
sys
from
six
import
StringIO
,
integer_types
from
six
import
StringIO
from
theano
import
config
from
theano
import
config
from
theano.compat
import
DefaultOrderedDict
from
theano.compat
import
DefaultOrderedDict
...
@@ -10,7 +10,7 @@ from theano.gof import opt
...
@@ -10,7 +10,7 @@ from theano.gof import opt
from
theano.misc.ordered_set
import
OrderedSet
from
theano.misc.ordered_set
import
OrderedSet
class
DB
(
object
)
:
class
DB
:
def
__hash__
(
self
):
def
__hash__
(
self
):
if
not
hasattr
(
self
,
"_optimizer_idx"
):
if
not
hasattr
(
self
,
"_optimizer_idx"
):
self
.
_optimizer_idx
=
opt
.
_optimizer_idx
[
0
]
self
.
_optimizer_idx
=
opt
.
_optimizer_idx
[
0
]
...
@@ -169,7 +169,7 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s".
...
@@ -169,7 +169,7 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s".
print
(
" db"
,
self
.
__db__
,
file
=
stream
)
print
(
" db"
,
self
.
__db__
,
file
=
stream
)
class
Query
(
object
)
:
class
Query
:
"""
"""
Parameters
Parameters
...
@@ -296,7 +296,7 @@ class EquilibriumDB(DB):
...
@@ -296,7 +296,7 @@ class EquilibriumDB(DB):
"""
"""
def
__init__
(
self
,
ignore_newtrees
=
True
,
tracks_on_change_inputs
=
False
):
def
__init__
(
self
,
ignore_newtrees
=
True
,
tracks_on_change_inputs
=
False
):
super
(
EquilibriumDB
,
self
)
.
__init__
()
super
()
.
__init__
()
self
.
ignore_newtrees
=
ignore_newtrees
self
.
ignore_newtrees
=
ignore_newtrees
self
.
tracks_on_change_inputs
=
tracks_on_change_inputs
self
.
tracks_on_change_inputs
=
tracks_on_change_inputs
self
.
__final__
=
{}
self
.
__final__
=
{}
...
@@ -307,12 +307,12 @@ class EquilibriumDB(DB):
...
@@ -307,12 +307,12 @@ class EquilibriumDB(DB):
cleanup
=
kwtags
.
pop
(
"cleanup"
,
False
)
cleanup
=
kwtags
.
pop
(
"cleanup"
,
False
)
# An opt should not be final and clean up
# An opt should not be final and clean up
assert
not
(
final_opt
and
cleanup
)
assert
not
(
final_opt
and
cleanup
)
super
(
EquilibriumDB
,
self
)
.
register
(
name
,
obj
,
*
tags
,
**
kwtags
)
super
()
.
register
(
name
,
obj
,
*
tags
,
**
kwtags
)
self
.
__final__
[
name
]
=
final_opt
self
.
__final__
[
name
]
=
final_opt
self
.
__cleanup__
[
name
]
=
cleanup
self
.
__cleanup__
[
name
]
=
cleanup
def
query
(
self
,
*
tags
,
**
kwtags
):
def
query
(
self
,
*
tags
,
**
kwtags
):
_opts
=
super
(
EquilibriumDB
,
self
)
.
query
(
*
tags
,
**
kwtags
)
_opts
=
super
()
.
query
(
*
tags
,
**
kwtags
)
final_opts
=
[
o
for
o
in
_opts
if
self
.
__final__
.
get
(
o
.
name
,
False
)]
final_opts
=
[
o
for
o
in
_opts
if
self
.
__final__
.
get
(
o
.
name
,
False
)]
cleanup_opts
=
[
o
for
o
in
_opts
if
self
.
__cleanup__
.
get
(
o
.
name
,
False
)]
cleanup_opts
=
[
o
for
o
in
_opts
if
self
.
__cleanup__
.
get
(
o
.
name
,
False
)]
opts
=
[
o
for
o
in
_opts
if
o
not
in
final_opts
and
o
not
in
cleanup_opts
]
opts
=
[
o
for
o
in
_opts
if
o
not
in
final_opts
and
o
not
in
cleanup_opts
]
...
@@ -349,19 +349,19 @@ class SequenceDB(DB):
...
@@ -349,19 +349,19 @@ class SequenceDB(DB):
seq_opt
=
opt
.
SeqOptimizer
seq_opt
=
opt
.
SeqOptimizer
def
__init__
(
self
,
failure_callback
=
opt
.
SeqOptimizer
.
warn
):
def
__init__
(
self
,
failure_callback
=
opt
.
SeqOptimizer
.
warn
):
super
(
SequenceDB
,
self
)
.
__init__
()
super
()
.
__init__
()
self
.
__position__
=
{}
self
.
__position__
=
{}
self
.
failure_callback
=
failure_callback
self
.
failure_callback
=
failure_callback
def
register
(
self
,
name
,
obj
,
position
,
*
tags
):
def
register
(
self
,
name
,
obj
,
position
,
*
tags
):
super
(
SequenceDB
,
self
)
.
register
(
name
,
obj
,
*
tags
)
super
()
.
register
(
name
,
obj
,
*
tags
)
if
position
==
"last"
:
if
position
==
"last"
:
if
len
(
self
.
__position__
)
==
0
:
if
len
(
self
.
__position__
)
==
0
:
self
.
__position__
[
name
]
=
0
self
.
__position__
[
name
]
=
0
else
:
else
:
self
.
__position__
[
name
]
=
max
(
self
.
__position__
.
values
())
+
1
self
.
__position__
[
name
]
=
max
(
self
.
__position__
.
values
())
+
1
else
:
else
:
assert
isinstance
(
position
,
(
integer_types
,
float
))
assert
isinstance
(
position
,
(
(
int
,)
,
float
))
self
.
__position__
[
name
]
=
position
self
.
__position__
[
name
]
=
position
def
query
(
self
,
*
tags
,
**
kwtags
):
def
query
(
self
,
*
tags
,
**
kwtags
):
...
@@ -373,7 +373,7 @@ class SequenceDB(DB):
...
@@ -373,7 +373,7 @@ class SequenceDB(DB):
Only optimizations with position less than the cutoff are returned.
Only optimizations with position less than the cutoff are returned.
"""
"""
opts
=
super
(
SequenceDB
,
self
)
.
query
(
*
tags
,
**
kwtags
)
opts
=
super
()
.
query
(
*
tags
,
**
kwtags
)
position_cutoff
=
kwtags
.
pop
(
"position_cutoff"
,
config
.
optdb
.
position_cutoff
)
position_cutoff
=
kwtags
.
pop
(
"position_cutoff"
,
config
.
optdb
.
position_cutoff
)
position_dict
=
self
.
__position__
position_dict
=
self
.
__position__
...
@@ -442,7 +442,7 @@ class LocalGroupDB(DB):
...
@@ -442,7 +442,7 @@ class LocalGroupDB(DB):
def
__init__
(
def
__init__
(
self
,
apply_all_opts
=
False
,
profile
=
False
,
local_opt
=
opt
.
LocalOptGroup
self
,
apply_all_opts
=
False
,
profile
=
False
,
local_opt
=
opt
.
LocalOptGroup
):
):
super
(
LocalGroupDB
,
self
)
.
__init__
()
super
()
.
__init__
()
self
.
failure_callback
=
None
self
.
failure_callback
=
None
self
.
apply_all_opts
=
apply_all_opts
self
.
apply_all_opts
=
apply_all_opts
self
.
profile
=
profile
self
.
profile
=
profile
...
@@ -450,7 +450,7 @@ class LocalGroupDB(DB):
...
@@ -450,7 +450,7 @@ class LocalGroupDB(DB):
self
.
local_opt
=
local_opt
self
.
local_opt
=
local_opt
def
register
(
self
,
name
,
obj
,
*
tags
,
**
kwargs
):
def
register
(
self
,
name
,
obj
,
*
tags
,
**
kwargs
):
super
(
LocalGroupDB
,
self
)
.
register
(
name
,
obj
,
*
tags
)
super
()
.
register
(
name
,
obj
,
*
tags
)
position
=
kwargs
.
pop
(
"position"
,
"last"
)
position
=
kwargs
.
pop
(
"position"
,
"last"
)
if
position
==
"last"
:
if
position
==
"last"
:
if
len
(
self
.
__position__
)
==
0
:
if
len
(
self
.
__position__
)
==
0
:
...
@@ -458,12 +458,12 @@ class LocalGroupDB(DB):
...
@@ -458,12 +458,12 @@ class LocalGroupDB(DB):
else
:
else
:
self
.
__position__
[
name
]
=
max
(
self
.
__position__
.
values
())
+
1
self
.
__position__
[
name
]
=
max
(
self
.
__position__
.
values
())
+
1
else
:
else
:
assert
isinstance
(
position
,
(
integer_types
,
float
))
assert
isinstance
(
position
,
(
(
int
,)
,
float
))
self
.
__position__
[
name
]
=
position
self
.
__position__
[
name
]
=
position
def
query
(
self
,
*
tags
,
**
kwtags
):
def
query
(
self
,
*
tags
,
**
kwtags
):
# For the new `useless` optimizer
# For the new `useless` optimizer
opts
=
list
(
super
(
LocalGroupDB
,
self
)
.
query
(
*
tags
,
**
kwtags
))
opts
=
list
(
super
()
.
query
(
*
tags
,
**
kwtags
))
opts
.
sort
(
key
=
lambda
obj
:
(
self
.
__position__
[
obj
.
name
],
obj
.
name
))
opts
.
sort
(
key
=
lambda
obj
:
(
self
.
__position__
[
obj
.
name
],
obj
.
name
))
ret
=
self
.
local_opt
(
ret
=
self
.
local_opt
(
...
@@ -482,7 +482,7 @@ class TopoDB(DB):
...
@@ -482,7 +482,7 @@ class TopoDB(DB):
def
__init__
(
def
__init__
(
self
,
db
,
order
=
"in_to_out"
,
ignore_newtrees
=
False
,
failure_callback
=
None
self
,
db
,
order
=
"in_to_out"
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
):
super
(
TopoDB
,
self
)
.
__init__
()
super
()
.
__init__
()
self
.
db
=
db
self
.
db
=
db
self
.
order
=
order
self
.
order
=
order
self
.
ignore_newtrees
=
ignore_newtrees
self
.
ignore_newtrees
=
ignore_newtrees
...
...
theano/gof/params_type.py
浏览文件 @
2e3f17cb
...
@@ -153,13 +153,13 @@ class Params(dict):
...
@@ -153,13 +153,13 @@ class Params(dict):
raise
TypeError
(
raise
TypeError
(
'Params: ParamsType attribute "
%
s" not in Params args.'
%
field
'Params: ParamsType attribute "
%
s" not in Params args.'
%
field
)
)
super
(
Params
,
self
)
.
__init__
(
**
kwargs
)
super
()
.
__init__
(
**
kwargs
)
self
.
__dict__
.
update
(
__params_type__
=
params_type
,
__signatures__
=
None
)
self
.
__dict__
.
update
(
__params_type__
=
params_type
,
__signatures__
=
None
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
"Params(
%
s)"
%
", "
.
join
(
return
"Params(
%
s)"
%
", "
.
join
(
[
[
(
"
%
s:
%
s:
%
s"
%
(
k
,
type
(
self
[
k
])
.
__name__
,
self
[
k
]))
(
"
{}:{}:{}"
.
format
(
k
,
type
(
self
[
k
])
.
__name__
,
self
[
k
]))
for
k
in
sorted
(
self
.
keys
())
for
k
in
sorted
(
self
.
keys
())
]
]
)
)
...
@@ -270,14 +270,14 @@ class ParamsType(Type):
...
@@ -270,14 +270,14 @@ class ParamsType(Type):
if
enum_types
:
if
enum_types
:
# We don't want same enum names in different enum types.
# We don't want same enum names in different enum types.
if
sum
(
len
(
t
)
for
t
in
enum_types
)
!=
len
(
if
sum
(
len
(
t
)
for
t
in
enum_types
)
!=
len
(
set
(
k
for
t
in
enum_types
for
k
in
t
)
{
k
for
t
in
enum_types
for
k
in
t
}
):
):
raise
AttributeError
(
raise
AttributeError
(
"ParamsType: found different enum types with common constants names."
"ParamsType: found different enum types with common constants names."
)
)
# We don't want same aliases in different enum types.
# We don't want same aliases in different enum types.
if
sum
(
len
(
t
.
aliases
)
for
t
in
enum_types
)
!=
len
(
if
sum
(
len
(
t
.
aliases
)
for
t
in
enum_types
)
!=
len
(
set
(
alias
for
t
in
enum_types
for
alias
in
t
.
aliases
)
{
alias
for
t
in
enum_types
for
alias
in
t
.
aliases
}
):
):
raise
AttributeError
(
raise
AttributeError
(
"ParamsType: found different enum types with common constants aliases."
"ParamsType: found different enum types with common constants aliases."
...
@@ -319,11 +319,14 @@ class ParamsType(Type):
...
@@ -319,11 +319,14 @@ class ParamsType(Type):
# Now we can access value of each enum defined inside enum types wrapped into the current ParamsType.
# Now we can access value of each enum defined inside enum types wrapped into the current ParamsType.
if
key
in
self
.
__const_to_enum
:
if
key
in
self
.
__const_to_enum
:
return
self
.
__const_to_enum
[
key
][
key
]
return
self
.
__const_to_enum
[
key
][
key
]
return
super
(
ParamsType
,
self
)
.
__getattr__
(
self
,
key
)
return
super
()
.
__getattr__
(
self
,
key
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
"ParamsType<
%
s>"
%
", "
.
join
(
return
"ParamsType<
%
s>"
%
", "
.
join
(
[(
"
%
s:
%
s"
%
(
self
.
fields
[
i
],
self
.
types
[
i
]))
for
i
in
range
(
self
.
length
)]
[
(
"{}:{}"
.
format
(
self
.
fields
[
i
],
self
.
types
[
i
]))
for
i
in
range
(
self
.
length
)
]
)
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
...
@@ -345,7 +348,7 @@ class ParamsType(Type):
...
@@ -345,7 +348,7 @@ class ParamsType(Type):
types_string
=
","
.
join
(
str
(
t
)
for
t
in
self
.
types
)
.
encode
(
"utf-8"
)
types_string
=
","
.
join
(
str
(
t
)
for
t
in
self
.
types
)
.
encode
(
"utf-8"
)
fields_hex
=
hashlib
.
sha256
(
fields_string
)
.
hexdigest
()
fields_hex
=
hashlib
.
sha256
(
fields_string
)
.
hexdigest
()
types_hex
=
hashlib
.
sha256
(
types_string
)
.
hexdigest
()
types_hex
=
hashlib
.
sha256
(
types_string
)
.
hexdigest
()
return
"_Params_
%
s_
%
s"
%
(
fields_hex
,
types_hex
)
return
"_Params_
{}_{}"
.
format
(
fields_hex
,
types_hex
)
def
has_type
(
self
,
theano_type
):
def
has_type
(
self
,
theano_type
):
"""
"""
...
...
theano/gof/sched.py
浏览文件 @
2e3f17cb
...
@@ -139,8 +139,8 @@ def _toposort(edges):
...
@@ -139,8 +139,8 @@ def _toposort(edges):
"""
"""
incoming_edges
=
reverse_dict
(
edges
)
incoming_edges
=
reverse_dict
(
edges
)
incoming_edges
=
dict
((
k
,
set
(
val
))
for
k
,
val
in
incoming_edges
.
items
())
incoming_edges
=
{
k
:
set
(
val
)
for
k
,
val
in
incoming_edges
.
items
()}
S
=
set
((
v
for
v
in
edges
if
v
not
in
incoming_edges
))
S
=
{
v
for
v
in
edges
if
v
not
in
incoming_edges
}
L
=
[]
L
=
[]
while
S
:
while
S
:
...
@@ -189,8 +189,8 @@ def posort(nodes, *cmps):
...
@@ -189,8 +189,8 @@ def posort(nodes, *cmps):
[0, 8, 2, 4, 6, 1, 3, 5, 7, 9, 16, 18, 10, 12, 14, 17, 19, 11, 13, 15]
[0, 8, 2, 4, 6, 1, 3, 5, 7, 9, 16, 18, 10, 12, 14, 17, 19, 11, 13, 15]
"""
"""
comes_before
=
dict
((
a
,
set
())
for
a
in
nodes
)
comes_before
=
{
a
:
set
()
for
a
in
nodes
}
comes_after
=
dict
((
a
,
set
())
for
a
in
nodes
)
comes_after
=
{
a
:
set
()
for
a
in
nodes
}
def
add_links
(
a
,
b
):
# b depends on a
def
add_links
(
a
,
b
):
# b depends on a
comes_after
[
a
]
.
add
(
b
)
comes_after
[
a
]
.
add
(
b
)
...
...
theano/gof/toolbox.py
浏览文件 @
2e3f17cb
...
@@ -21,8 +21,6 @@ class AlreadyThere(Exception):
...
@@ -21,8 +21,6 @@ class AlreadyThere(Exception):
"""
"""
pass
class
ReplacementDidntRemovedError
(
Exception
):
class
ReplacementDidntRemovedError
(
Exception
):
"""
"""
...
@@ -32,8 +30,6 @@ class ReplacementDidntRemovedError(Exception):
...
@@ -32,8 +30,6 @@ class ReplacementDidntRemovedError(Exception):
"""
"""
pass
class
BadOptimization
(
Exception
):
class
BadOptimization
(
Exception
):
"""
"""
...
@@ -103,7 +99,7 @@ class BadOptimization(Exception):
...
@@ -103,7 +99,7 @@ class BadOptimization(Exception):
old_graph
=
None
,
old_graph
=
None
,
new_graph
=
None
,
new_graph
=
None
,
):
):
super
(
BadOptimization
,
self
)
.
__init__
()
super
()
.
__init__
()
self
.
old_r
=
old_r
self
.
old_r
=
old_r
self
.
new_r
=
new_r
self
.
new_r
=
new_r
self
.
old_r_val
=
old_r_val
self
.
old_r_val
=
old_r_val
...
@@ -139,7 +135,7 @@ class BadOptimization(Exception):
...
@@ -139,7 +135,7 @@ class BadOptimization(Exception):
return
self
.
full_err
return
self
.
full_err
sio
=
StringIO
()
sio
=
StringIO
()
val_str_len_limit
=
800
val_str_len_limit
=
800
print
(
"BadOptimization Error"
,
super
(
BadOptimization
,
self
)
.
__str__
(),
file
=
sio
)
print
(
"BadOptimization Error"
,
super
()
.
__str__
(),
file
=
sio
)
print
(
" Variable: id"
,
id
(
self
.
new_r
),
self
.
new_r
,
file
=
sio
)
print
(
" Variable: id"
,
id
(
self
.
new_r
),
self
.
new_r
,
file
=
sio
)
print
(
" Op"
,
self
.
new_r
.
owner
,
file
=
sio
)
print
(
" Op"
,
self
.
new_r
.
owner
,
file
=
sio
)
print
(
" Value Type:"
,
type
(
self
.
new_r_val
),
file
=
sio
)
print
(
" Value Type:"
,
type
(
self
.
new_r_val
),
file
=
sio
)
...
@@ -225,7 +221,7 @@ class BadOptimization(Exception):
...
@@ -225,7 +221,7 @@ class BadOptimization(Exception):
return
sio
.
getvalue
()
return
sio
.
getvalue
()
class
Feature
(
object
)
:
class
Feature
:
"""
"""
Base class for FunctionGraph extensions.
Base class for FunctionGraph extensions.
...
@@ -466,7 +462,9 @@ class Validator(Feature):
...
@@ -466,7 +462,9 @@ class Validator(Feature):
r
=
uf
.
f_locals
.
get
(
"r"
,
""
)
r
=
uf
.
f_locals
.
get
(
"r"
,
""
)
reason
=
uf_info
.
function
reason
=
uf_info
.
function
print
(
print
(
"validate failed on node
%
s.
\n
Reason:
%
s,
%
s"
%
(
r
,
reason
,
e
)
"validate failed on node {}.
\n
Reason: {}, {}"
.
format
(
r
,
reason
,
e
)
)
)
raise
raise
t1
=
time
.
time
()
t1
=
time
.
time
()
...
@@ -578,7 +576,9 @@ class ReplaceValidate(History, Validator):
...
@@ -578,7 +576,9 @@ class ReplaceValidate(History, Validator):
except
Exception
as
e
:
except
Exception
as
e
:
fgraph
.
revert
(
chk
)
fgraph
.
revert
(
chk
)
if
verbose
:
if
verbose
:
print
(
"validate failed on node
%
s.
\n
Reason:
%
s,
%
s"
%
(
r
,
reason
,
e
))
print
(
"validate failed on node {}.
\n
Reason: {}, {}"
.
format
(
r
,
reason
,
e
)
)
raise
raise
if
config
.
scan
.
debug
:
if
config
.
scan
.
debug
:
scans2
=
[
scans2
=
[
...
@@ -731,15 +731,15 @@ class PrintListener(Feature):
...
@@ -731,15 +731,15 @@ class PrintListener(Feature):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
if
self
.
active
:
if
self
.
active
:
print
(
"-- importing:
%
s, reason:
%
s"
%
(
node
,
reason
))
print
(
"-- importing:
{}, reason: {}"
.
format
(
node
,
reason
))
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
if
self
.
active
:
if
self
.
active
:
print
(
"-- pruning:
%
s, reason:
%
s"
%
(
node
,
reason
))
print
(
"-- pruning:
{}, reason: {}"
.
format
(
node
,
reason
))
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
=
None
):
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
=
None
):
if
self
.
active
:
if
self
.
active
:
print
(
"-- changing (
%
s.inputs[
%
s]) from
%
s to
%
s"
%
(
node
,
i
,
r
,
new_r
))
print
(
"-- changing (
{}.inputs[{}]) from {} to {}"
.
format
(
node
,
i
,
r
,
new_r
))
class
PreserveNames
(
Feature
):
class
PreserveNames
(
Feature
):
...
@@ -918,9 +918,9 @@ def is_same_graph(var1, var2, givens=None):
...
@@ -918,9 +918,9 @@ def is_same_graph(var1, var2, givens=None):
for
to_replace
,
replace_by
in
givens
.
items
():
for
to_replace
,
replace_by
in
givens
.
items
():
# Map a substitution variable to the computational graphs it
# Map a substitution variable to the computational graphs it
# belongs to.
# belongs to.
inside
=
dict
(
inside
=
{
(
v
,
[
in_var
(
v
,
k
)
for
k
in
(
1
,
2
)])
for
v
in
(
to_replace
,
replace_by
)
v
:
[
in_var
(
v
,
k
)
for
k
in
(
1
,
2
)]
for
v
in
(
to_replace
,
replace_by
)
)
}
if
(
if
(
inside
[
to_replace
][
0
]
inside
[
to_replace
][
0
]
and
not
inside
[
to_replace
][
1
]
and
not
inside
[
to_replace
][
1
]
...
...
theano/gof/type.py
浏览文件 @
2e3f17cb
...
@@ -10,8 +10,6 @@ import ctypes
...
@@ -10,8 +10,6 @@ import ctypes
import
platform
import
platform
import
re
import
re
from
six
import
string_types
import
theano
import
theano
from
theano
import
change_flags
from
theano
import
change_flags
from
theano.gof
import
graph
,
utils
from
theano.gof
import
graph
,
utils
...
@@ -272,7 +270,7 @@ class CLinkerType(CLinkerObject):
...
@@ -272,7 +270,7 @@ class CLinkerType(CLinkerObject):
return
()
return
()
class
PureType
(
object
)
:
class
PureType
:
"""
"""
Interface specification for variable type instances.
Interface specification for variable type instances.
...
@@ -707,10 +705,10 @@ class CDataType(Type):
...
@@ -707,10 +705,10 @@ class CDataType(Type):
extra_support_code
=
""
,
extra_support_code
=
""
,
version
=
None
,
version
=
None
,
):
):
assert
isinstance
(
ctype
,
str
ing_types
)
assert
isinstance
(
ctype
,
str
)
self
.
ctype
=
ctype
self
.
ctype
=
ctype
if
freefunc
is
not
None
:
if
freefunc
is
not
None
:
assert
isinstance
(
freefunc
,
str
ing_types
)
assert
isinstance
(
freefunc
,
str
)
self
.
freefunc
=
freefunc
self
.
freefunc
=
freefunc
self
.
headers
=
tuple
(
headers
)
self
.
headers
=
tuple
(
headers
)
self
.
header_dirs
=
tuple
(
header_dirs
)
self
.
header_dirs
=
tuple
(
header_dirs
)
...
@@ -848,7 +846,7 @@ if (py_%(name)s == NULL) { %(freefunc)s(%(name)s); }
...
@@ -848,7 +846,7 @@ if (py_%(name)s == NULL) { %(freefunc)s(%(name)s); }
return
v
return
v
def
__str__
(
self
):
def
__str__
(
self
):
return
"
%
s{
%
s}"
%
(
self
.
__class__
.
__name__
,
self
.
ctype
)
return
"
{}{{{}}}"
.
format
(
self
.
__class__
.
__name__
,
self
.
ctype
)
def
__setstate__
(
self
,
dct
):
def
__setstate__
(
self
,
dct
):
self
.
__dict__
.
update
(
dct
)
self
.
__dict__
.
update
(
dct
)
...
@@ -1034,7 +1032,7 @@ class EnumType(Type, dict):
...
@@ -1034,7 +1032,7 @@ class EnumType(Type, dict):
raise
TypeError
(
raise
TypeError
(
"
%
s: some aliases have same names as constants."
%
type
(
self
)
.
__name__
"
%
s: some aliases have same names as constants."
%
type
(
self
)
.
__name__
)
)
super
(
EnumType
,
self
)
.
__init__
(
**
kwargs
)
super
()
.
__init__
(
**
kwargs
)
def
fromalias
(
self
,
alias
):
def
fromalias
(
self
,
alias
):
"""
"""
...
@@ -1060,11 +1058,11 @@ class EnumType(Type, dict):
...
@@ -1060,11 +1058,11 @@ class EnumType(Type, dict):
names_to_aliases
=
{
constant_name
:
""
for
constant_name
in
self
}
names_to_aliases
=
{
constant_name
:
""
for
constant_name
in
self
}
for
alias
in
self
.
aliases
:
for
alias
in
self
.
aliases
:
names_to_aliases
[
self
.
aliases
[
alias
]]
=
"(
%
s)"
%
alias
names_to_aliases
[
self
.
aliases
[
alias
]]
=
"(
%
s)"
%
alias
return
"
%
s<
%
s>(
%
s)"
%
(
return
"
{}<{}>({})"
.
format
(
type
(
self
)
.
__name__
,
type
(
self
)
.
__name__
,
self
.
ctype
,
self
.
ctype
,
", "
.
join
(
", "
.
join
(
"
%
s
%
s:
%
s"
%
(
k
,
names_to_aliases
[
k
],
self
[
k
])
"
{}{}:{}"
.
format
(
k
,
names_to_aliases
[
k
],
self
[
k
])
for
k
in
sorted
(
self
.
keys
())
for
k
in
sorted
(
self
.
keys
())
),
),
)
)
...
@@ -1298,7 +1296,7 @@ class EnumList(EnumType):
...
@@ -1298,7 +1296,7 @@ class EnumList(EnumType):
kwargs
.
update
(
ctype
=
ctype
)
kwargs
.
update
(
ctype
=
ctype
)
if
cname
is
not
None
:
if
cname
is
not
None
:
kwargs
.
update
(
cname
=
cname
)
kwargs
.
update
(
cname
=
cname
)
super
(
EnumList
,
self
)
.
__init__
(
**
kwargs
)
super
()
.
__init__
(
**
kwargs
)
class
CEnumType
(
EnumList
):
class
CEnumType
(
EnumList
):
...
@@ -1336,7 +1334,7 @@ class CEnumType(EnumList):
...
@@ -1336,7 +1334,7 @@ class CEnumType(EnumList):
return
self
.
pyint_compat_code
+
self
.
c_to_string
()
return
self
.
pyint_compat_code
+
self
.
c_to_string
()
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
):
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
):
swapped_dict
=
dict
((
v
,
k
)
for
(
k
,
v
)
in
self
.
items
())
swapped_dict
=
{
v
:
k
for
(
k
,
v
)
in
self
.
items
()}
# swapped_dict's keys are integers.
# swapped_dict's keys are integers.
return
"""
return
"""
...
@@ -1360,4 +1358,4 @@ class CEnumType(EnumList):
...
@@ -1360,4 +1358,4 @@ class CEnumType(EnumList):
)
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
1
,
super
(
CEnumType
,
self
)
.
c_code_cache_version
())
return
(
1
,
super
()
.
c_code_cache_version
())
theano/gof/unify.py
浏览文件 @
2e3f17cb
...
@@ -19,7 +19,7 @@ from theano.gof.utils import ANY_TYPE, FALL_THROUGH, comm_guard
...
@@ -19,7 +19,7 @@ from theano.gof.utils import ANY_TYPE, FALL_THROUGH, comm_guard
################################
################################
class
Variable
(
object
)
:
class
Variable
:
"""
"""
Serves as a base class of variables for the purpose of unification.
Serves as a base class of variables for the purpose of unification.
"Unification" here basically means matching two patterns, see the
"Unification" here basically means matching two patterns, see the
...
@@ -46,7 +46,9 @@ class Variable(object):
...
@@ -46,7 +46,9 @@ class Variable(object):
return
(
return
(
self
.
__class__
.
__name__
self
.
__class__
.
__name__
+
"("
+
"("
+
", "
.
join
(
"
%
s=
%
s"
%
(
key
,
value
)
for
key
,
value
in
self
.
__dict__
.
items
())
+
", "
.
join
(
"{}={}"
.
format
(
key
,
value
)
for
key
,
value
in
self
.
__dict__
.
items
()
)
+
")"
+
")"
)
)
...
@@ -60,8 +62,6 @@ class FreeVariable(Variable):
...
@@ -60,8 +62,6 @@ class FreeVariable(Variable):
"""
"""
pass
class
BoundVariable
(
Variable
):
class
BoundVariable
(
Variable
):
"""
"""
...
@@ -70,7 +70,7 @@ class BoundVariable(Variable):
...
@@ -70,7 +70,7 @@ class BoundVariable(Variable):
"""
"""
def
__init__
(
self
,
name
,
value
):
def
__init__
(
self
,
name
,
value
):
super
(
BoundVariable
,
self
)
.
__init__
(
name
=
name
)
super
()
.
__init__
(
name
=
name
)
self
.
value
=
value
self
.
value
=
value
...
@@ -82,7 +82,7 @@ class OrVariable(Variable):
...
@@ -82,7 +82,7 @@ class OrVariable(Variable):
"""
"""
def
__init__
(
self
,
name
,
options
):
def
__init__
(
self
,
name
,
options
):
super
(
OrVariable
,
self
)
.
__init__
(
name
=
name
)
super
()
.
__init__
(
name
=
name
)
self
.
options
=
options
self
.
options
=
options
...
@@ -94,7 +94,7 @@ class NotVariable(Variable):
...
@@ -94,7 +94,7 @@ class NotVariable(Variable):
"""
"""
def
__init__
(
self
,
name
,
not_options
):
def
__init__
(
self
,
name
,
not_options
):
super
(
NotVariable
,
self
)
.
__init__
(
name
=
name
)
super
()
.
__init__
(
name
=
name
)
self
.
not_options
=
not_options
self
.
not_options
=
not_options
...
...
theano/gof/utils.py
浏览文件 @
2e3f17cb
...
@@ -3,7 +3,6 @@ import sys
...
@@ -3,7 +3,6 @@ import sys
import
traceback
import
traceback
import
numpy
as
np
import
numpy
as
np
from
six
import
integer_types
,
string_types
,
with_metaclass
from
six.moves
import
StringIO
from
six.moves
import
StringIO
from
theano
import
config
from
theano
import
config
...
@@ -161,8 +160,6 @@ undef = object()
...
@@ -161,8 +160,6 @@ undef = object()
class
TestValueError
(
Exception
):
class
TestValueError
(
Exception
):
"""Base exception class for all test value errors."""
"""Base exception class for all test value errors."""
pass
class
MethodNotDefined
(
Exception
):
class
MethodNotDefined
(
Exception
):
"""
"""
...
@@ -173,8 +170,6 @@ class MethodNotDefined(Exception):
...
@@ -173,8 +170,6 @@ class MethodNotDefined(Exception):
"""
"""
pass
class
MetaObject
(
type
):
class
MetaObject
(
type
):
def
__new__
(
cls
,
name
,
bases
,
dct
):
def
__new__
(
cls
,
name
,
bases
,
dct
):
...
@@ -182,7 +177,7 @@ class MetaObject(type):
...
@@ -182,7 +177,7 @@ class MetaObject(type):
if
props
is
not
None
:
if
props
is
not
None
:
if
not
isinstance
(
props
,
tuple
):
if
not
isinstance
(
props
,
tuple
):
raise
TypeError
(
"__props__ has to be a tuple"
)
raise
TypeError
(
"__props__ has to be a tuple"
)
if
not
all
(
isinstance
(
p
,
str
ing_types
)
for
p
in
props
):
if
not
all
(
isinstance
(
p
,
str
)
for
p
in
props
):
raise
TypeError
(
"elements of __props__ have to be strings"
)
raise
TypeError
(
"elements of __props__ have to be strings"
)
def
_props
(
self
):
def
_props
(
self
):
...
@@ -201,7 +196,7 @@ class MetaObject(type):
...
@@ -201,7 +196,7 @@ class MetaObject(type):
least all the original props.
least all the original props.
"""
"""
return
dict
([(
a
,
getattr
(
self
,
a
))
for
a
in
props
])
return
{
a
:
getattr
(
self
,
a
)
for
a
in
props
}
dct
[
"_props_dict"
]
=
_props_dict
dct
[
"_props_dict"
]
=
_props_dict
...
@@ -225,14 +220,16 @@ class MetaObject(type):
...
@@ -225,14 +220,16 @@ class MetaObject(type):
if
len
(
props
)
==
0
:
if
len
(
props
)
==
0
:
def
__str__
(
self
):
def
__str__
(
self
):
return
"
%
s"
%
(
self
.
__class__
.
__name__
,
)
return
"
{}"
.
format
(
self
.
__class__
.
__name__
)
else
:
else
:
def
__str__
(
self
):
def
__str__
(
self
):
return
"
%
s{
%
s}"
%
(
return
"
{}{{{}}}"
.
format
(
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
,
", "
.
join
(
"
%
s=
%
r"
%
(
p
,
getattr
(
self
,
p
))
for
p
in
props
),
", "
.
join
(
"{}={!r}"
.
format
(
p
,
getattr
(
self
,
p
))
for
p
in
props
),
)
)
dct
[
"__str__"
]
=
__str__
dct
[
"__str__"
]
=
__str__
...
@@ -240,14 +237,14 @@ class MetaObject(type):
...
@@ -240,14 +237,14 @@ class MetaObject(type):
return
type
.
__new__
(
cls
,
name
,
bases
,
dct
)
return
type
.
__new__
(
cls
,
name
,
bases
,
dct
)
class
object2
(
with_metaclass
(
MetaObject
,
object
)
):
class
object2
(
metaclass
=
MetaObject
):
__slots__
=
[]
__slots__
=
[]
def
__ne__
(
self
,
other
):
def
__ne__
(
self
,
other
):
return
not
self
==
other
return
not
self
==
other
class
Scratchpad
(
object
)
:
class
Scratchpad
:
def
clear
(
self
):
def
clear
(
self
):
self
.
__dict__
.
clear
()
self
.
__dict__
.
clear
()
...
@@ -264,7 +261,7 @@ class Scratchpad(object):
...
@@ -264,7 +261,7 @@ class Scratchpad(object):
def
info
(
self
):
def
info
(
self
):
print
(
"<theano.gof.utils.scratchpad instance at
%
i>"
%
id
(
self
))
print
(
"<theano.gof.utils.scratchpad instance at
%
i>"
%
id
(
self
))
for
k
,
v
in
self
.
__dict__
.
items
():
for
k
,
v
in
self
.
__dict__
.
items
():
print
(
"
%
s:
%
s"
%
(
k
,
v
))
print
(
"
{}: {}"
.
format
(
k
,
v
))
class
ValidatingScratchpad
(
Scratchpad
):
class
ValidatingScratchpad
(
Scratchpad
):
...
@@ -330,7 +327,7 @@ def deprecated(filename, msg=""):
...
@@ -330,7 +327,7 @@ def deprecated(filename, msg=""):
def
g
(
*
args
,
**
kwargs
):
def
g
(
*
args
,
**
kwargs
):
if
printme
[
0
]:
if
printme
[
0
]:
print
(
"WARNING:
%
s.
%
s deprecated.
%
s"
%
(
filename
,
f
.
__name__
,
msg
))
print
(
"WARNING:
{}.{} deprecated. {}"
.
format
(
filename
,
f
.
__name__
,
msg
))
printme
[
0
]
=
False
printme
[
0
]
=
False
return
f
(
*
args
,
**
kwargs
)
return
f
(
*
args
,
**
kwargs
)
...
@@ -404,7 +401,7 @@ def toposort(prereqs_d):
...
@@ -404,7 +401,7 @@ def toposort(prereqs_d):
for
x
,
prereqs
in
prereqs_d
.
items
():
for
x
,
prereqs
in
prereqs_d
.
items
():
for
prereq
in
prereqs
:
for
prereq
in
prereqs
:
postreqs_d
.
setdefault
(
prereq
,
set
())
.
add
(
x
)
postreqs_d
.
setdefault
(
prereq
,
set
())
.
add
(
x
)
next
=
set
([
k
for
k
in
prereqs_d
if
not
prereqs_d
[
k
]])
next
=
{
k
for
k
in
prereqs_d
if
not
prereqs_d
[
k
]}
while
next
:
while
next
:
bases
=
next
bases
=
next
next
=
set
()
next
=
set
()
...
@@ -449,7 +446,7 @@ RETRY = Keyword("RETRY", False)
...
@@ -449,7 +446,7 @@ RETRY = Keyword("RETRY", False)
FAILURE
=
Keyword
(
"FAILURE"
,
False
)
FAILURE
=
Keyword
(
"FAILURE"
,
False
)
simple_types
=
integer_types
+
string_types
+
(
float
,
bool
,
None
.
__class__
,
Keyword
)
simple_types
=
(
int
,
str
,
float
,
bool
,
type
(
None
)
,
Keyword
)
ANY_TYPE
=
Keyword
(
"ANY_TYPE"
)
ANY_TYPE
=
Keyword
(
"ANY_TYPE"
)
...
...
theano/gof/vm.py
浏览文件 @
2e3f17cb
...
@@ -30,8 +30,8 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
...
@@ -30,8 +30,8 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
for
var
in
fgraph
.
variables
:
for
var
in
fgraph
.
variables
:
viewed_by
[
var
]
=
[]
viewed_by
[
var
]
=
[]
view_of
=
{}
view_of
=
{}
pre_allocated
=
set
(
[]
)
pre_allocated
=
set
()
allocated
=
set
(
[]
)
allocated
=
set
()
for
idx
in
range
(
len
(
order
)):
for
idx
in
range
(
len
(
order
)):
node
=
order
[
idx
]
node
=
order
[
idx
]
...
@@ -120,7 +120,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
...
@@ -120,7 +120,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
return
reallocated_info
return
reallocated_info
class
VM
(
object
)
:
class
VM
:
"""
"""
A VM object's __call__ method evaluates a Theano program.
A VM object's __call__ method evaluates a Theano program.
...
@@ -273,7 +273,7 @@ class LoopGC(VM):
...
@@ -273,7 +273,7 @@ class LoopGC(VM):
"""
"""
def
__init__
(
self
,
nodes
,
thunks
,
pre_call_clear
,
post_thunk_clear
):
def
__init__
(
self
,
nodes
,
thunks
,
pre_call_clear
,
post_thunk_clear
):
super
(
LoopGC
,
self
)
.
__init__
(
nodes
,
thunks
,
pre_call_clear
)
super
()
.
__init__
(
nodes
,
thunks
,
pre_call_clear
)
self
.
post_thunk_clear
=
post_thunk_clear
self
.
post_thunk_clear
=
post_thunk_clear
# Some other part of Theano query that information
# Some other part of Theano query that information
self
.
allow_gc
=
True
self
.
allow_gc
=
True
...
@@ -353,7 +353,7 @@ class Stack(VM):
...
@@ -353,7 +353,7 @@ class Stack(VM):
callback
=
None
,
callback
=
None
,
callback_input
=
None
,
callback_input
=
None
,
):
):
super
(
Stack
,
self
)
.
__init__
(
nodes
,
thunks
,
pre_call_clear
)
super
()
.
__init__
(
nodes
,
thunks
,
pre_call_clear
)
self
.
allow_gc
=
allow_gc
self
.
allow_gc
=
allow_gc
self
.
message
=
""
self
.
message
=
""
...
@@ -712,7 +712,6 @@ except (OSError, theano.gof.cmodule.MissingGXX) as e:
...
@@ -712,7 +712,6 @@ except (OSError, theano.gof.cmodule.MissingGXX) as e:
assert
not
[
x
for
x
in
_config_var_list
if
x
.
fullname
==
"linker"
][
assert
not
[
x
for
x
in
_config_var_list
if
x
.
fullname
==
"linker"
][
0
0
]
.
default
.
startswith
(
"cvm"
),
e
]
.
default
.
startswith
(
"cvm"
),
e
pass
class
VM_Linker
(
link
.
LocalLinker
):
class
VM_Linker
(
link
.
LocalLinker
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论