Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2e3f17cb
提交
2e3f17cb
authored
10月 21, 2020
作者:
Brandon T. Willard
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Apply pyupgrade to theano.gof
上级
beb93a44
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
20 个修改的文件
包含
186 行增加
和
223 行删除
+186
-223
callcache.py
theano/gof/callcache.py
+4
-4
cc.py
theano/gof/cc.py
+14
-15
cmodule.py
theano/gof/cmodule.py
+36
-37
compiledir.py
theano/gof/compiledir.py
+17
-32
compilelock.py
theano/gof/compilelock.py
+3
-9
destroyhandler.py
theano/gof/destroyhandler.py
+4
-6
fg.py
theano/gof/fg.py
+0
-2
graph.py
theano/gof/graph.py
+5
-7
lazylinker_c.py
theano/gof/lazylinker_c.py
+2
-2
link.py
theano/gof/link.py
+6
-8
op.py
theano/gof/op.py
+16
-19
opt.py
theano/gof/opt.py
+0
-0
optdb.py
theano/gof/optdb.py
+15
-15
params_type.py
theano/gof/params_type.py
+10
-7
sched.py
theano/gof/sched.py
+4
-4
toolbox.py
theano/gof/toolbox.py
+15
-15
type.py
theano/gof/type.py
+10
-12
unify.py
theano/gof/unify.py
+7
-7
utils.py
theano/gof/utils.py
+13
-16
vm.py
theano/gof/vm.py
+5
-6
没有找到文件。
theano/gof/callcache.py
浏览文件 @
2e3f17cb
...
...
@@ -6,15 +6,15 @@ import six.moves.cPickle as pickle
_logger
=
logging
.
getLogger
(
"theano.gof.callcache"
)
class
CallCache
(
object
)
:
class
CallCache
:
def
__init__
(
self
,
filename
=
None
):
self
.
filename
=
filename
try
:
if
filename
is
None
:
raise
IO
Error
(
"bad filename"
)
# just goes to except
with
open
(
filename
,
"r"
)
as
f
:
raise
OS
Error
(
"bad filename"
)
# just goes to except
with
open
(
filename
)
as
f
:
self
.
cache
=
pickle
.
load
(
f
)
except
IO
Error
:
except
OS
Error
:
self
.
cache
=
{}
def
persist
(
self
,
filename
=
None
):
...
...
theano/gof/cc.py
浏览文件 @
2e3f17cb
...
...
@@ -8,7 +8,6 @@ import sys
from
copy
import
copy
import
numpy
as
np
from
six
import
reraise
,
string_types
from
six.moves
import
StringIO
import
theano
...
...
@@ -233,7 +232,7 @@ def struct_gen(args, struct_builders, blocks, sub):
# declares the storage
storage_decl
=
"
\n
"
.
join
([
"PyObject*
%
s;"
%
arg
for
arg
in
args
])
# in the constructor, sets the storage to the arguments
storage_set
=
"
\n
"
.
join
([
"this->
%
s =
%
s;"
%
(
arg
,
arg
)
for
arg
in
args
])
storage_set
=
"
\n
"
.
join
([
"this->
{} = {};"
.
format
(
arg
,
arg
)
for
arg
in
args
])
# increments the storage's refcount in the constructor
storage_incref
=
"
\n
"
.
join
([
"Py_XINCREF(
%
s);"
%
arg
for
arg
in
args
])
# decrements the storage's refcount in the destructor
...
...
@@ -359,7 +358,7 @@ def get_c_declare(r, name, sub):
[
getattr
(
c
.
op
,
"check_input"
,
config
.
check_input
)
for
(
c
,
_
)
in
r
.
clients
if
not
isinstance
(
c
,
str
ing_types
)
if
not
isinstance
(
c
,
str
)
]
)
or
(
r
.
owner
and
getattr
(
r
.
owner
.
op
,
"check_input"
,
config
.
check_input
)):
c_declare
=
r
.
type
.
c_declare
(
name
,
sub
,
True
)
...
...
@@ -405,7 +404,7 @@ def get_c_extract(r, name, sub):
[
getattr
(
c
.
op
,
"check_input"
,
config
.
check_input
)
for
(
c
,
_
)
in
r
.
clients
if
not
isinstance
(
c
,
str
ing_types
)
if
not
isinstance
(
c
,
str
)
]
):
# check_broadcast is just an hack to easily remove just the
...
...
@@ -415,7 +414,7 @@ def get_c_extract(r, name, sub):
[
getattr
(
c
.
op
,
"check_broadcast"
,
True
)
for
(
c
,
_
)
in
r
.
clients
if
not
isinstance
(
c
,
str
ing_types
)
if
not
isinstance
(
c
,
str
)
]
):
c_extract
=
r
.
type
.
c_extract
(
name
,
sub
,
True
)
...
...
@@ -849,7 +848,7 @@ class CLinker(link.Linker):
pass
else
:
# The following will be executed if the "try" block succeeds
assert
isinstance
(
c_support_code_apply
[
-
1
],
str
ing_types
),
(
assert
isinstance
(
c_support_code_apply
[
-
1
],
str
),
(
str
(
node
.
op
)
+
" didn't return a string for c_support_code_apply"
)
...
...
@@ -858,13 +857,13 @@ class CLinker(link.Linker):
except
utils
.
MethodNotDefined
:
pass
else
:
assert
isinstance
(
c_init_code_apply
[
-
1
],
str
ing_types
),
(
assert
isinstance
(
c_init_code_apply
[
-
1
],
str
),
(
str
(
node
.
op
)
+
" didn't return a string for c_init_code_apply"
)
try
:
struct_init
=
op
.
c_init_code_struct
(
node
,
name
,
sub_struct
)
assert
isinstance
(
struct_init
,
str
ing_types
),
(
assert
isinstance
(
struct_init
,
str
),
(
str
(
node
.
op
)
+
" didn't return a string for c_init_code_struct"
)
except
utils
.
MethodNotDefined
:
...
...
@@ -872,7 +871,7 @@ class CLinker(link.Linker):
try
:
struct_support
=
op
.
c_support_code_struct
(
node
,
name
)
assert
isinstance
(
struct_support
,
str
ing_types
),
(
assert
isinstance
(
struct_support
,
str
),
(
str
(
node
.
op
)
+
" didn't return a string for c_support_code_struct"
)
except
utils
.
MethodNotDefined
:
...
...
@@ -880,7 +879,7 @@ class CLinker(link.Linker):
try
:
struct_cleanup
=
op
.
c_cleanup_code_struct
(
node
,
name
)
assert
isinstance
(
struct_cleanup
,
str
ing_types
),
(
assert
isinstance
(
struct_cleanup
,
str
),
(
str
(
node
.
op
)
+
" didn't return a string for c_cleanup_code_struct"
)
except
utils
.
MethodNotDefined
:
...
...
@@ -891,7 +890,7 @@ class CLinker(link.Linker):
behavior
=
op
.
c_code
(
node
,
name
,
isyms
,
osyms
,
sub
)
except
utils
.
MethodNotDefined
:
raise
NotImplementedError
(
"
%
s cannot produce C code"
%
op
)
assert
isinstance
(
behavior
,
str
ing_types
),
(
assert
isinstance
(
behavior
,
str
),
(
str
(
node
.
op
)
+
" didn't return a string for c_code"
)
# To help understand what is following. It help read the c code.
...
...
@@ -1429,7 +1428,7 @@ class CLinker(link.Linker):
# set of variables that have been computed by nodes we have
# seen 'so far' in the loop below
fgraph_computed_set
=
set
()
fgraph_inputs_dict
=
dict
((
i
,
(
-
1
,
pos
))
for
pos
,
i
in
enumerate
(
fgraph
.
inputs
))
fgraph_inputs_dict
=
{
i
:
(
-
1
,
pos
)
for
pos
,
i
in
enumerate
(
fgraph
.
inputs
)}
constant_ids
=
dict
()
op_pos
=
{}
# Apply -> topological position
...
...
@@ -1794,7 +1793,7 @@ class CLinker(link.Linker):
return
code
.
getvalue
()
class
_CThunk
(
object
)
:
class
_CThunk
:
"""
A thunk with a C implementation.
...
...
@@ -1864,7 +1863,7 @@ class _CThunk(object):
)
print
(
self
.
error_storage
,
file
=
sys
.
stderr
)
raise
r
eraise
(
exc_type
,
exc_value
,
exc_trace
)
r
aise
exc_value
.
with_traceback
(
exc_trace
)
class
OpWiseCLinker
(
link
.
LocalLinker
):
...
...
@@ -2130,7 +2129,7 @@ class DualLinker(link.Linker):
return
f
,
i1
,
o1
class
HideC
(
object
)
:
class
HideC
:
def
__hide
(
*
args
):
raise
utils
.
MethodNotDefined
()
...
...
theano/gof/cmodule.py
浏览文件 @
2e3f17cb
...
...
@@ -19,7 +19,7 @@ import warnings
import
numpy.distutils
import
six.moves.cPickle
as
pickle
from
six
import
BytesIO
,
StringIO
,
b
,
string_types
from
six
import
BytesIO
,
StringIO
,
b
import
theano
from
theano
import
config
...
...
@@ -53,8 +53,6 @@ class MissingGXX(Exception):
"""
pass
def
debug_counter
(
name
,
every
=
1
):
"""
...
...
@@ -70,10 +68,10 @@ def debug_counter(name, every=1):
setattr
(
debug_counter
,
name
,
getattr
(
debug_counter
,
name
,
0
)
+
1
)
n
=
getattr
(
debug_counter
,
name
)
if
n
%
every
==
0
:
print
(
"debug_counter [
%
s]:
%
s"
%
(
name
,
n
),
file
=
sys
.
stderr
)
print
(
"debug_counter [
{}]: {}"
.
format
(
name
,
n
),
file
=
sys
.
stderr
)
class
ExtFunction
(
object
)
:
class
ExtFunction
:
"""
A C function to put into a DynamicModule.
...
...
@@ -118,10 +116,12 @@ class ExtFunction(object):
It goes into the DynamicModule's method table.
"""
return
'
\t
{"
%
s",
%
s,
%
s, "
%
s"}'
%
(
self
.
name
,
self
.
name
,
self
.
method
,
self
.
doc
)
return
'
\t
{{"{}", {}, {}, "{}"}}'
.
format
(
self
.
name
,
self
.
name
,
self
.
method
,
self
.
doc
)
class
DynamicModule
(
object
)
:
class
DynamicModule
:
def
__init__
(
self
,
name
=
None
):
assert
name
is
None
,
(
"The 'name' parameter of DynamicModule"
...
...
@@ -436,7 +436,7 @@ def get_module_hash(src_code, key):
# This should be the C++ compilation command line parameters or the
# libraries to link against.
to_hash
+=
list
(
key_element
)
elif
isinstance
(
key_element
,
str
ing_types
):
elif
isinstance
(
key_element
,
str
):
if
key_element
.
startswith
(
"md5:"
)
or
key_element
.
startswith
(
"hash:"
):
# This is actually a sha256 hash of the config options.
# Currently, we still keep md5 to don't break old Theano.
...
...
@@ -481,7 +481,7 @@ def get_safe_part(key):
# rest of the cache mechanism will just skip that key.
hash
=
None
for
key_element
in
c_link_key
[
1
:]:
if
isinstance
(
key_element
,
str
ing_types
):
if
isinstance
(
key_element
,
str
):
if
key_element
.
startswith
(
"md5:"
):
hash
=
key_element
[
4
:]
break
...
...
@@ -492,7 +492,7 @@ def get_safe_part(key):
return
key
[
0
]
+
(
hash
,)
class
KeyData
(
object
)
:
class
KeyData
:
"""
Used to store the key information in the cache.
...
...
@@ -594,7 +594,7 @@ class KeyData(object):
pass
class
ModuleCache
(
object
)
:
class
ModuleCache
:
"""
Interface to the cache of dynamically compiled modules on disk.
...
...
@@ -1011,7 +1011,7 @@ class ModuleCache(object):
# Test to see that the file is [present and] readable.
open
(
entry
)
.
close
()
gone
=
False
except
IO
Error
:
except
OS
Error
:
gone
=
True
if
gone
:
...
...
@@ -1140,7 +1140,7 @@ class ModuleCache(object):
key_pkl
=
os
.
path
.
join
(
location
,
"key.pkl"
)
assert
not
os
.
path
.
exists
(
key_pkl
)
key_data
=
KeyData
(
keys
=
set
([
key
])
,
module_hash
=
module_hash
,
key_pkl
=
key_pkl
,
entry
=
name
keys
=
{
key
}
,
module_hash
=
module_hash
,
key_pkl
=
key_pkl
,
entry
=
name
)
key_broken
=
False
...
...
@@ -1518,7 +1518,7 @@ class ModuleCache(object):
fname
=
os
.
path
.
join
(
self
.
dirname
,
filename
,
"key.pkl"
)
open
(
fname
)
.
close
()
has_key
=
True
except
IO
Error
:
except
OS
Error
:
has_key
=
False
if
not
has_key
:
# Use the compiled file by default
...
...
@@ -1724,12 +1724,10 @@ def std_lib_dirs_and_libs():
for
f
,
lib
in
[(
"libpython27.a"
,
"libpython 1.2"
)]:
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
libdir
,
f
)):
print
(
(
"Your Python version is from Canopy. "
+
"You need to install the package '"
+
lib
+
"' from Canopy package manager."
)
"Your Python version is from Canopy. "
+
"You need to install the package '"
+
lib
+
"' from Canopy package manager."
)
libdirs
=
[
# Used in older Canopy
...
...
@@ -1747,12 +1745,10 @@ def std_lib_dirs_and_libs():
]
):
print
(
(
"Your Python version is from Canopy. "
+
"You need to install the package '"
+
lib
+
"' from Canopy package manager."
)
"Your Python version is from Canopy. "
+
"You need to install the package '"
+
lib
+
"' from Canopy package manager."
)
python_lib_dirs
.
insert
(
0
,
libdir
)
std_lib_dirs_and_libs
.
data
=
[
libname
],
python_lib_dirs
...
...
@@ -1833,15 +1829,15 @@ def gcc_llvm():
# Normally this should not happen as we should not try to
# compile when g++ is not available. If this happen, it
# will crash later so supposing it is not llvm is "safe".
output
=
b
(
""
)
gcc_llvm
.
is_llvm
=
b
(
"llvm"
)
in
output
output
=
b
""
gcc_llvm
.
is_llvm
=
b
"llvm"
in
output
return
gcc_llvm
.
is_llvm
gcc_llvm
.
is_llvm
=
None
class
Compiler
(
object
)
:
class
Compiler
:
"""
Meta compiler that offer some generic function.
...
...
@@ -2077,7 +2073,7 @@ class GCC_compiler(Compiler):
# as stdin (which is the default) results in the process
# waiting forever without returning. For that reason,
# we use a pipe, and use the empty string as input.
(
stdout
,
stderr
)
=
p
.
communicate
(
input
=
b
(
""
)
)
(
stdout
,
stderr
)
=
p
.
communicate
(
input
=
b
""
)
if
p
.
returncode
!=
0
:
return
None
...
...
@@ -2355,7 +2351,7 @@ class GCC_compiler(Compiler):
line
.
startswith
(
"#define hypot _hypot"
)
for
line
in
config_h
):
cxxflags
.
append
(
"-D_hypot=hypot"
)
except
IO
Error
:
except
OS
Error
:
pass
return
cxxflags
...
...
@@ -2472,9 +2468,9 @@ class GCC_compiler(Compiler):
if
dist_suffix
is
not
None
and
dist_suffix
!=
""
:
suffix
=
dist_suffix
filepath
=
"
%
s
%
s"
%
(
module_name
,
suffix
)
filepath
=
"
{}{}"
.
format
(
module_name
,
suffix
)
else
:
filepath
=
"
%
s.
%
s"
%
(
module_name
,
get_lib_extension
())
filepath
=
"
{}.{}"
.
format
(
module_name
,
get_lib_extension
())
lib_filename
=
os
.
path
.
join
(
location
,
filepath
)
...
...
@@ -2488,10 +2484,13 @@ class GCC_compiler(Compiler):
# to support path that includes spaces, we need to wrap it with double quotes on Windows
path_wrapper
=
'"'
if
os
.
name
==
"nt"
else
""
cmd
.
extend
(
[
"-I
%
s
%
s
%
s"
%
(
path_wrapper
,
idir
,
path_wrapper
)
for
idir
in
include_dirs
]
[
"-I{}{}{}"
.
format
(
path_wrapper
,
idir
,
path_wrapper
)
for
idir
in
include_dirs
]
)
cmd
.
extend
(
[
"-L
%
s
%
s
%
s"
%
(
path_wrapper
,
ldir
,
path_wrapper
)
for
ldir
in
lib_dirs
]
[
"-L
{}{}{}"
.
format
(
path_wrapper
,
ldir
,
path_wrapper
)
for
ldir
in
lib_dirs
]
)
if
hide_symbols
and
sys
.
platform
!=
"win32"
:
# This has been available since gcc 4.0 so we suppose it
...
...
@@ -2501,8 +2500,8 @@ class GCC_compiler(Compiler):
# improved loading times on most platforms (win32 is
# different, as usual).
cmd
.
append
(
"-fvisibility=hidden"
)
cmd
.
extend
([
"-o"
,
"
%
s
%
s
%
s"
%
(
path_wrapper
,
lib_filename
,
path_wrapper
)])
cmd
.
append
(
"
%
s
%
s
%
s"
%
(
path_wrapper
,
cppfilename
,
path_wrapper
))
cmd
.
extend
([
"-o"
,
"
{}{}{}"
.
format
(
path_wrapper
,
lib_filename
,
path_wrapper
)])
cmd
.
append
(
"
{}{}{}"
.
format
(
path_wrapper
,
cppfilename
,
path_wrapper
))
cmd
.
extend
([
"-l
%
s"
%
l
for
l
in
libs
])
# print >> sys.stderr, 'COMPILING W CMD', cmd
_logger
.
debug
(
"Running cmd:
%
s"
,
" "
.
join
(
cmd
))
...
...
theano/gof/compiledir.py
浏览文件 @
2e3f17cb
...
...
@@ -4,7 +4,6 @@ import shutil
import
numpy
as
np
import
six.moves.cPickle
as
pickle
from
six
import
string_types
import
theano
from
theano
import
config
...
...
@@ -46,7 +45,7 @@ def cleanup():
# force the removing of key
have_npy_abi_version
=
False
break
elif
isinstance
(
obj
,
str
ing_types
):
elif
isinstance
(
obj
,
str
):
if
obj
.
startswith
(
"NPY_ABI_VERSION=0x"
):
have_npy_abi_version
=
True
elif
obj
.
startswith
(
"c_compiler_str="
):
...
...
@@ -67,7 +66,7 @@ def cleanup():
if
keydata
.
key_pkl
!=
filename
:
keydata
.
key_pkl
=
filename
keydata
.
remove_key
(
key
)
except
IO
Error
:
except
OS
Error
:
_logger
.
error
(
"Could not remove file '
%
s'. To complete "
"the clean-up, please remove manually "
...
...
@@ -84,7 +83,7 @@ def cleanup():
"the directory containing it."
,
filename
,
)
except
IO
Error
:
except
OS
Error
:
_logger
.
error
(
"Could not clean up this directory: '
%
s'. To complete "
"the clean-up, please remove it manually."
,
...
...
@@ -126,29 +125,21 @@ def print_compiledir_content():
try
:
keydata
=
pickle
.
load
(
file
)
ops
=
list
(
set
(
[
x
for
x
in
flatten
(
keydata
.
keys
)
if
isinstance
(
x
,
theano
.
gof
.
Op
)
]
)
{
x
for
x
in
flatten
(
keydata
.
keys
)
if
isinstance
(
x
,
theano
.
gof
.
Op
)}
)
# Whatever the case, we count compilations for OP classes.
for
op_class
in
set
([
op
.
__class__
for
op
in
ops
])
:
for
op_class
in
{
op
.
__class__
for
op
in
ops
}
:
table_op_class
.
setdefault
(
op_class
,
0
)
table_op_class
[
op_class
]
+=
1
if
len
(
ops
)
==
0
:
zeros_op
+=
1
else
:
types
=
list
(
set
(
[
x
for
x
in
flatten
(
keydata
.
keys
)
if
isinstance
(
x
,
theano
.
gof
.
Type
)
]
)
{
x
for
x
in
flatten
(
keydata
.
keys
)
if
isinstance
(
x
,
theano
.
gof
.
Type
)
}
)
compile_start
=
compile_end
=
float
(
"nan"
)
for
fn
in
os
.
listdir
(
os
.
path
.
join
(
compiledir
,
dir
)):
...
...
@@ -177,7 +168,7 @@ def print_compiledir_content():
nb_keys
.
setdefault
(
len
(
keydata
.
keys
),
0
)
nb_keys
[
len
(
keydata
.
keys
)]
+=
1
except
IO
Error
:
except
OS
Error
:
pass
except
AttributeError
:
_logger
.
error
(
"Could not read key file '
%
s'."
,
filename
)
...
...
@@ -221,16 +212,12 @@ def print_compiledir_content():
big_key_files
=
sorted
(
big_key_files
,
key
=
lambda
t
:
str
(
t
[
1
]))
big_total_size
=
sum
([
sz
for
_
,
sz
,
_
in
big_key_files
])
print
(
(
"There are directories with key files bigger than
%
d bytes "
"(they probably contain big tensor constants)"
%
max_key_file_size
)
"There are directories with key files bigger than
%
d bytes "
"(they probably contain big tensor constants)"
%
max_key_file_size
)
print
(
(
"They use
%
d bytes out of
%
d (total size used by all key files)"
""
%
(
big_total_size
,
total_key_sizes
)
)
"They use
%
d bytes out of
%
d (total size used by all key files)"
""
%
(
big_total_size
,
total_key_sizes
)
)
for
dir
,
size
,
ops
in
big_key_files
:
...
...
@@ -246,10 +233,8 @@ def print_compiledir_content():
print
(
n_k
,
n_m
)
print
()
print
(
(
"Skipped
%
d files that contained 0 op "
"(are they always theano.scalar ops?)"
%
zeros_op
)
"Skipped
%
d files that contained 0 op "
"(are they always theano.scalar ops?)"
%
zeros_op
)
...
...
theano/gof/compilelock.py
浏览文件 @
2e3f17cb
...
...
@@ -10,7 +10,6 @@ import time
from
contextlib
import
contextmanager
import
numpy
as
np
from
six
import
PY3
from
theano
import
config
...
...
@@ -283,14 +282,9 @@ def lock(tmp_dir, timeout=notset, min_wait=None, max_wait=None, verbosity=1):
nb_wait
+=
1
time
.
sleep
(
random
.
uniform
(
min_wait
,
max_wait
))
if
PY3
:
exception
=
FileExistsError
# noqa
else
:
exception
=
OSError
try
:
os
.
mkdir
(
tmp_dir
)
except
exception
:
except
FileExistsError
:
# Error while creating the directory: someone else
# must have tried at the exact same time.
nb_error
+=
1
...
...
@@ -332,7 +326,7 @@ def refresh_lock(lock_file):
unique id, using a new (randomly generated) id, which is also returned.
"""
unique_id
=
"
%
s_
%
s_
%
s"
%
(
unique_id
=
"
{}_{}_{}"
.
format
(
os
.
getpid
(),
""
.
join
([
str
(
random
.
randint
(
0
,
9
))
for
i
in
range
(
10
)]),
hostname
,
...
...
@@ -355,7 +349,7 @@ def refresh_lock(lock_file):
return
unique_id
class
Unlocker
(
object
)
:
class
Unlocker
:
"""
Class wrapper around release mechanism so that the lock is automatically
released when the program exits (even when crashing or being interrupted),
...
...
theano/gof/destroyhandler.py
浏览文件 @
2e3f17cb
...
...
@@ -22,8 +22,6 @@ class ProtocolError(Exception):
"""
pass
def
_contains_cycle
(
fgraph
,
orderings
):
"""
...
...
@@ -762,17 +760,17 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# OPT: pre-compute this on import
tolerate_same
=
getattr
(
app
.
op
,
"destroyhandler_tolerate_same"
,
[])
assert
isinstance
(
tolerate_same
,
list
)
tolerated
=
set
(
tolerated
=
{
idx1
for
idx0
,
idx1
in
tolerate_same
if
idx0
==
destroyed_idx
)
}
tolerated
.
add
(
destroyed_idx
)
tolerate_aliased
=
getattr
(
app
.
op
,
"destroyhandler_tolerate_aliased"
,
[]
)
assert
isinstance
(
tolerate_aliased
,
list
)
ignored
=
set
(
ignored
=
{
idx1
for
idx0
,
idx1
in
tolerate_aliased
if
idx0
==
destroyed_idx
)
}
for
i
,
input
in
enumerate
(
app
.
inputs
):
if
i
in
ignored
:
continue
...
...
theano/gof/fg.py
浏览文件 @
2e3f17cb
...
...
@@ -26,8 +26,6 @@ class InconsistencyError(Exception):
"""
pass
class
MissingInputError
(
Exception
):
"""
...
...
theano/gof/graph.py
浏览文件 @
2e3f17cb
...
...
@@ -7,8 +7,6 @@ from collections import deque
from
copy
import
copy
from
itertools
import
count
from
six
import
integer_types
,
string_types
import
theano
from
theano
import
config
from
theano.gof.utils
import
(
...
...
@@ -174,7 +172,7 @@ class Apply(Node):
raise
ValueError
(
"
%
s.default_output should be an output index."
%
self
.
op
)
elif
not
isinstance
(
do
,
int
eger_types
):
elif
not
isinstance
(
do
,
int
):
raise
ValueError
(
"
%
s.default_output should be an int or long"
%
self
.
op
)
elif
do
<
0
or
do
>=
len
(
self
.
outputs
):
raise
ValueError
(
"
%
s.default_output is out of range."
%
self
.
op
)
...
...
@@ -395,11 +393,11 @@ class Variable(Node):
raise
TypeError
(
"owner must be an Apply instance"
,
owner
)
self
.
owner
=
owner
if
index
is
not
None
and
not
isinstance
(
index
,
int
eger_types
):
if
index
is
not
None
and
not
isinstance
(
index
,
int
):
raise
TypeError
(
"index must be an int"
,
index
)
self
.
index
=
index
if
name
is
not
None
and
not
isinstance
(
name
,
str
ing_types
):
if
name
is
not
None
and
not
isinstance
(
name
,
str
):
raise
TypeError
(
"name must be a string"
,
name
)
self
.
name
=
name
...
...
@@ -1156,7 +1154,7 @@ default_leaf_formatter = str
def
default_node_formatter
(
op
,
argstrings
):
return
"
%
s(
%
s)"
%
(
op
.
op
,
", "
.
join
(
argstrings
))
return
"
{}({})"
.
format
(
op
.
op
,
", "
.
join
(
argstrings
))
def
io_connection_pattern
(
inputs
,
outputs
):
...
...
@@ -1331,7 +1329,7 @@ def view_roots(r):
if
owner
is
not
None
:
try
:
view_map
=
owner
.
op
.
view_map
view_map
=
dict
((
owner
.
outputs
[
o
],
i
)
for
o
,
i
in
view_map
.
items
())
view_map
=
{
owner
.
outputs
[
o
]:
i
for
o
,
i
in
view_map
.
items
()}
except
AttributeError
:
return
[
r
]
if
r
in
view_map
:
...
...
theano/gof/lazylinker_c.py
浏览文件 @
2e3f17cb
...
...
@@ -59,11 +59,11 @@ try:
if
not
os
.
path
.
exists
(
init_file
):
try
:
open
(
init_file
,
"w"
)
.
close
()
except
IO
Error
as
e
:
except
OS
Error
as
e
:
if
os
.
path
.
exists
(
init_file
):
pass
# has already been created
else
:
e
.
args
+=
(
"
%
s exist?
%
s"
%
(
location
,
os
.
path
.
exists
(
location
)),)
e
.
args
+=
(
"
{} exist? {}"
.
format
(
location
,
os
.
path
.
exists
(
location
)),)
raise
_need_reload
=
False
...
...
theano/gof/link.py
浏览文件 @
2e3f17cb
...
...
@@ -4,7 +4,6 @@ from copy import copy, deepcopy
from
sys
import
getsizeof
import
numpy
as
np
from
six
import
reraise
from
six.moves
import
StringIO
import
theano
...
...
@@ -123,7 +122,7 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
exc_type
,
exc_value
,
exc_trace
=
exc_info
if
exc_type
==
KeyboardInterrupt
:
# print a simple traceback from KeyboardInterrupt
r
eraise
(
exc_type
,
exc_value
,
exc_trace
)
r
aise
exc_value
.
with_traceback
(
exc_trace
)
try
:
trace
=
node
.
outputs
[
0
]
.
tag
.
trace
except
AttributeError
:
...
...
@@ -315,11 +314,11 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
detailed_err_msg
+=
", TotalSize:
%
s Byte(s)
\n
"
%
item
[
3
]
else
:
detailed_err_msg
+=
"
\n
"
detailed_err_msg
+=
" TotalSize:
%
s Byte(s)
%.3
f GB
\n
"
%
(
detailed_err_msg
+=
" TotalSize:
{} Byte(s) {:.3f} GB
\n
"
.
format
(
total_size
,
total_size
/
1024.0
/
1024
/
1024
,
)
detailed_err_msg
+=
" TotalSize inputs:
%
s Byte(s)
%.3
f GB
\n
"
%
(
detailed_err_msg
+=
" TotalSize inputs:
{} Byte(s) {:.3f} GB
\n
"
.
format
(
total_size_inputs
,
total_size_inputs
/
1024.0
/
1024
/
1024
,
)
...
...
@@ -341,11 +340,10 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
)
# Some exception need extra parameter in inputs. So forget the
# extra long error message in that case.
pass
reraise
(
exc_type
,
exc_value
,
exc_trace
)
raise
exc_value
.
with_traceback
(
exc_trace
)
class
Linker
(
object
)
:
class
Linker
:
"""
WRITEME
...
...
@@ -434,7 +432,7 @@ class Linker(object):
# TODO: Move this class to the compile module, where it is used (and for which it exists).
class
Container
(
object
)
:
class
Container
:
"""
This class joins a variable with its computed value.
...
...
theano/gof/op.py
浏览文件 @
2e3f17cb
...
...
@@ -122,7 +122,7 @@ def compute_test_value(node):
output
.
tag
.
test_value
=
storage_map
[
output
][
0
]
class
CLinkerObject
(
object
)
:
class
CLinkerObject
:
"""
Standard elements of an Op or Type used with the CLinker.
...
...
@@ -550,7 +550,7 @@ class CLinkerOp(CLinkerObject):
)
class
PureOp
(
object
)
:
class
PureOp
:
"""A class that models and constructs operations in a graph.
A `PureOp` instance has several responsibilities:
...
...
@@ -842,7 +842,6 @@ class Op(object2, PureOp, CLinkerOp):
good to do so.
"""
pass
def
make_c_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
):
"""Like make_thunk, but will only try to make a C thunk."""
...
...
@@ -1263,19 +1262,17 @@ class COp(Op):
section_re
=
re
.
compile
(
r"^#section ([a-zA-Z0-9_]+)$"
,
re
.
MULTILINE
)
backward_re
=
re
.
compile
(
r"^THEANO_(APPLY|SUPPORT)_CODE_SECTION$"
,
re
.
MULTILINE
)
# This is the set of allowed markers
SECTIONS
=
set
(
[
"init_code"
,
"init_code_apply"
,
"init_code_struct"
,
"support_code"
,
"support_code_apply"
,
"support_code_struct"
,
"cleanup_code_struct"
,
"code"
,
"code_cleanup"
,
]
)
SECTIONS
=
{
"init_code"
,
"init_code_apply"
,
"init_code_struct"
,
"support_code"
,
"support_code_apply"
,
"support_code_struct"
,
"cleanup_code_struct"
,
"code"
,
"code_cleanup"
,
}
@classmethod
def
get_path
(
cls
,
f
):
...
...
@@ -1535,10 +1532,10 @@ class COp(Op):
def
get_sub_macros
(
self
,
sub
):
define_macros
=
[]
undef_macros
=
[]
define_macros
.
append
(
"#define FAIL
%
s"
%
(
self
.
_lquote_macro
(
sub
[
"fail"
]),
))
define_macros
.
append
(
"#define FAIL
{}"
.
format
(
self
.
_lquote_macro
(
sub
[
"fail"
])
))
undef_macros
.
append
(
"#undef FAIL"
)
if
"params"
in
sub
:
define_macros
.
append
(
"#define PARAMS
%
s"
%
(
sub
[
"params"
],
))
define_macros
.
append
(
"#define PARAMS
{}"
.
format
(
sub
[
"params"
]
))
undef_macros
.
append
(
"#undef PARAMS"
)
return
"
\n
"
.
join
(
define_macros
),
"
\n
"
.
join
(
undef_macros
)
...
...
@@ -1584,7 +1581,7 @@ class COp(Op):
params
=
""
if
"params"
in
sub
:
params
=
",
%
s"
%
(
sub
[
"params"
],
)
params
=
",
{}"
.
format
(
sub
[
"params"
]
)
# Generate the C code
return
"""
...
...
theano/gof/opt.py
浏览文件 @
2e3f17cb
差异被折叠。
点击展开。
theano/gof/optdb.py
浏览文件 @
2e3f17cb
...
...
@@ -2,7 +2,7 @@ import copy
import
math
import
sys
from
six
import
StringIO
,
integer_types
from
six
import
StringIO
from
theano
import
config
from
theano.compat
import
DefaultOrderedDict
...
...
@@ -10,7 +10,7 @@ from theano.gof import opt
from
theano.misc.ordered_set
import
OrderedSet
class
DB
(
object
)
:
class
DB
:
def
__hash__
(
self
):
if
not
hasattr
(
self
,
"_optimizer_idx"
):
self
.
_optimizer_idx
=
opt
.
_optimizer_idx
[
0
]
...
...
@@ -169,7 +169,7 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s".
print
(
" db"
,
self
.
__db__
,
file
=
stream
)
class
Query
(
object
)
:
class
Query
:
"""
Parameters
...
...
@@ -296,7 +296,7 @@ class EquilibriumDB(DB):
"""
def
__init__
(
self
,
ignore_newtrees
=
True
,
tracks_on_change_inputs
=
False
):
super
(
EquilibriumDB
,
self
)
.
__init__
()
super
()
.
__init__
()
self
.
ignore_newtrees
=
ignore_newtrees
self
.
tracks_on_change_inputs
=
tracks_on_change_inputs
self
.
__final__
=
{}
...
...
@@ -307,12 +307,12 @@ class EquilibriumDB(DB):
cleanup
=
kwtags
.
pop
(
"cleanup"
,
False
)
# An opt should not be final and clean up
assert
not
(
final_opt
and
cleanup
)
super
(
EquilibriumDB
,
self
)
.
register
(
name
,
obj
,
*
tags
,
**
kwtags
)
super
()
.
register
(
name
,
obj
,
*
tags
,
**
kwtags
)
self
.
__final__
[
name
]
=
final_opt
self
.
__cleanup__
[
name
]
=
cleanup
def
query
(
self
,
*
tags
,
**
kwtags
):
_opts
=
super
(
EquilibriumDB
,
self
)
.
query
(
*
tags
,
**
kwtags
)
_opts
=
super
()
.
query
(
*
tags
,
**
kwtags
)
final_opts
=
[
o
for
o
in
_opts
if
self
.
__final__
.
get
(
o
.
name
,
False
)]
cleanup_opts
=
[
o
for
o
in
_opts
if
self
.
__cleanup__
.
get
(
o
.
name
,
False
)]
opts
=
[
o
for
o
in
_opts
if
o
not
in
final_opts
and
o
not
in
cleanup_opts
]
...
...
@@ -349,19 +349,19 @@ class SequenceDB(DB):
seq_opt
=
opt
.
SeqOptimizer
def
__init__
(
self
,
failure_callback
=
opt
.
SeqOptimizer
.
warn
):
super
(
SequenceDB
,
self
)
.
__init__
()
super
()
.
__init__
()
self
.
__position__
=
{}
self
.
failure_callback
=
failure_callback
def
register
(
self
,
name
,
obj
,
position
,
*
tags
):
super
(
SequenceDB
,
self
)
.
register
(
name
,
obj
,
*
tags
)
super
()
.
register
(
name
,
obj
,
*
tags
)
if
position
==
"last"
:
if
len
(
self
.
__position__
)
==
0
:
self
.
__position__
[
name
]
=
0
else
:
self
.
__position__
[
name
]
=
max
(
self
.
__position__
.
values
())
+
1
else
:
assert
isinstance
(
position
,
(
integer_types
,
float
))
assert
isinstance
(
position
,
(
(
int
,)
,
float
))
self
.
__position__
[
name
]
=
position
def
query
(
self
,
*
tags
,
**
kwtags
):
...
...
@@ -373,7 +373,7 @@ class SequenceDB(DB):
Only optimizations with position less than the cutoff are returned.
"""
opts
=
super
(
SequenceDB
,
self
)
.
query
(
*
tags
,
**
kwtags
)
opts
=
super
()
.
query
(
*
tags
,
**
kwtags
)
position_cutoff
=
kwtags
.
pop
(
"position_cutoff"
,
config
.
optdb
.
position_cutoff
)
position_dict
=
self
.
__position__
...
...
@@ -442,7 +442,7 @@ class LocalGroupDB(DB):
def
__init__
(
self
,
apply_all_opts
=
False
,
profile
=
False
,
local_opt
=
opt
.
LocalOptGroup
):
super
(
LocalGroupDB
,
self
)
.
__init__
()
super
()
.
__init__
()
self
.
failure_callback
=
None
self
.
apply_all_opts
=
apply_all_opts
self
.
profile
=
profile
...
...
@@ -450,7 +450,7 @@ class LocalGroupDB(DB):
self
.
local_opt
=
local_opt
def
register
(
self
,
name
,
obj
,
*
tags
,
**
kwargs
):
super
(
LocalGroupDB
,
self
)
.
register
(
name
,
obj
,
*
tags
)
super
()
.
register
(
name
,
obj
,
*
tags
)
position
=
kwargs
.
pop
(
"position"
,
"last"
)
if
position
==
"last"
:
if
len
(
self
.
__position__
)
==
0
:
...
...
@@ -458,12 +458,12 @@ class LocalGroupDB(DB):
else
:
self
.
__position__
[
name
]
=
max
(
self
.
__position__
.
values
())
+
1
else
:
assert
isinstance
(
position
,
(
integer_types
,
float
))
assert
isinstance
(
position
,
(
(
int
,)
,
float
))
self
.
__position__
[
name
]
=
position
def
query
(
self
,
*
tags
,
**
kwtags
):
# For the new `useless` optimizer
opts
=
list
(
super
(
LocalGroupDB
,
self
)
.
query
(
*
tags
,
**
kwtags
))
opts
=
list
(
super
()
.
query
(
*
tags
,
**
kwtags
))
opts
.
sort
(
key
=
lambda
obj
:
(
self
.
__position__
[
obj
.
name
],
obj
.
name
))
ret
=
self
.
local_opt
(
...
...
@@ -482,7 +482,7 @@ class TopoDB(DB):
def
__init__
(
self
,
db
,
order
=
"in_to_out"
,
ignore_newtrees
=
False
,
failure_callback
=
None
):
super
(
TopoDB
,
self
)
.
__init__
()
super
()
.
__init__
()
self
.
db
=
db
self
.
order
=
order
self
.
ignore_newtrees
=
ignore_newtrees
...
...
theano/gof/params_type.py
浏览文件 @
2e3f17cb
...
...
@@ -153,13 +153,13 @@ class Params(dict):
raise
TypeError
(
'Params: ParamsType attribute "
%
s" not in Params args.'
%
field
)
super
(
Params
,
self
)
.
__init__
(
**
kwargs
)
super
()
.
__init__
(
**
kwargs
)
self
.
__dict__
.
update
(
__params_type__
=
params_type
,
__signatures__
=
None
)
def
__repr__
(
self
):
return
"Params(
%
s)"
%
", "
.
join
(
[
(
"
%
s:
%
s:
%
s"
%
(
k
,
type
(
self
[
k
])
.
__name__
,
self
[
k
]))
(
"
{}:{}:{}"
.
format
(
k
,
type
(
self
[
k
])
.
__name__
,
self
[
k
]))
for
k
in
sorted
(
self
.
keys
())
]
)
...
...
@@ -270,14 +270,14 @@ class ParamsType(Type):
if
enum_types
:
# We don't want same enum names in different enum types.
if
sum
(
len
(
t
)
for
t
in
enum_types
)
!=
len
(
set
(
k
for
t
in
enum_types
for
k
in
t
)
{
k
for
t
in
enum_types
for
k
in
t
}
):
raise
AttributeError
(
"ParamsType: found different enum types with common constants names."
)
# We don't want same aliases in different enum types.
if
sum
(
len
(
t
.
aliases
)
for
t
in
enum_types
)
!=
len
(
set
(
alias
for
t
in
enum_types
for
alias
in
t
.
aliases
)
{
alias
for
t
in
enum_types
for
alias
in
t
.
aliases
}
):
raise
AttributeError
(
"ParamsType: found different enum types with common constants aliases."
...
...
@@ -319,11 +319,14 @@ class ParamsType(Type):
# Now we can access value of each enum defined inside enum types wrapped into the current ParamsType.
if
key
in
self
.
__const_to_enum
:
return
self
.
__const_to_enum
[
key
][
key
]
return
super
(
ParamsType
,
self
)
.
__getattr__
(
self
,
key
)
return
super
()
.
__getattr__
(
self
,
key
)
def
__repr__
(
self
):
return
"ParamsType<
%
s>"
%
", "
.
join
(
[(
"
%
s:
%
s"
%
(
self
.
fields
[
i
],
self
.
types
[
i
]))
for
i
in
range
(
self
.
length
)]
[
(
"{}:{}"
.
format
(
self
.
fields
[
i
],
self
.
types
[
i
]))
for
i
in
range
(
self
.
length
)
]
)
def
__eq__
(
self
,
other
):
...
...
@@ -345,7 +348,7 @@ class ParamsType(Type):
types_string
=
","
.
join
(
str
(
t
)
for
t
in
self
.
types
)
.
encode
(
"utf-8"
)
fields_hex
=
hashlib
.
sha256
(
fields_string
)
.
hexdigest
()
types_hex
=
hashlib
.
sha256
(
types_string
)
.
hexdigest
()
return
"_Params_
%
s_
%
s"
%
(
fields_hex
,
types_hex
)
return
"_Params_
{}_{}"
.
format
(
fields_hex
,
types_hex
)
def
has_type
(
self
,
theano_type
):
"""
...
...
theano/gof/sched.py
浏览文件 @
2e3f17cb
...
...
@@ -139,8 +139,8 @@ def _toposort(edges):
"""
incoming_edges
=
reverse_dict
(
edges
)
incoming_edges
=
dict
((
k
,
set
(
val
))
for
k
,
val
in
incoming_edges
.
items
())
S
=
set
((
v
for
v
in
edges
if
v
not
in
incoming_edges
))
incoming_edges
=
{
k
:
set
(
val
)
for
k
,
val
in
incoming_edges
.
items
()}
S
=
{
v
for
v
in
edges
if
v
not
in
incoming_edges
}
L
=
[]
while
S
:
...
...
@@ -189,8 +189,8 @@ def posort(nodes, *cmps):
[0, 8, 2, 4, 6, 1, 3, 5, 7, 9, 16, 18, 10, 12, 14, 17, 19, 11, 13, 15]
"""
comes_before
=
dict
((
a
,
set
())
for
a
in
nodes
)
comes_after
=
dict
((
a
,
set
())
for
a
in
nodes
)
comes_before
=
{
a
:
set
()
for
a
in
nodes
}
comes_after
=
{
a
:
set
()
for
a
in
nodes
}
def
add_links
(
a
,
b
):
# b depends on a
comes_after
[
a
]
.
add
(
b
)
...
...
theano/gof/toolbox.py
浏览文件 @
2e3f17cb
...
...
@@ -21,8 +21,6 @@ class AlreadyThere(Exception):
"""
pass
class
ReplacementDidntRemovedError
(
Exception
):
"""
...
...
@@ -32,8 +30,6 @@ class ReplacementDidntRemovedError(Exception):
"""
pass
class
BadOptimization
(
Exception
):
"""
...
...
@@ -103,7 +99,7 @@ class BadOptimization(Exception):
old_graph
=
None
,
new_graph
=
None
,
):
super
(
BadOptimization
,
self
)
.
__init__
()
super
()
.
__init__
()
self
.
old_r
=
old_r
self
.
new_r
=
new_r
self
.
old_r_val
=
old_r_val
...
...
@@ -139,7 +135,7 @@ class BadOptimization(Exception):
return
self
.
full_err
sio
=
StringIO
()
val_str_len_limit
=
800
print
(
"BadOptimization Error"
,
super
(
BadOptimization
,
self
)
.
__str__
(),
file
=
sio
)
print
(
"BadOptimization Error"
,
super
()
.
__str__
(),
file
=
sio
)
print
(
" Variable: id"
,
id
(
self
.
new_r
),
self
.
new_r
,
file
=
sio
)
print
(
" Op"
,
self
.
new_r
.
owner
,
file
=
sio
)
print
(
" Value Type:"
,
type
(
self
.
new_r_val
),
file
=
sio
)
...
...
@@ -225,7 +221,7 @@ class BadOptimization(Exception):
return
sio
.
getvalue
()
class
Feature
(
object
)
:
class
Feature
:
"""
Base class for FunctionGraph extensions.
...
...
@@ -466,7 +462,9 @@ class Validator(Feature):
r
=
uf
.
f_locals
.
get
(
"r"
,
""
)
reason
=
uf_info
.
function
print
(
"validate failed on node
%
s.
\n
Reason:
%
s,
%
s"
%
(
r
,
reason
,
e
)
"validate failed on node {}.
\n
Reason: {}, {}"
.
format
(
r
,
reason
,
e
)
)
raise
t1
=
time
.
time
()
...
...
@@ -578,7 +576,9 @@ class ReplaceValidate(History, Validator):
except
Exception
as
e
:
fgraph
.
revert
(
chk
)
if
verbose
:
print
(
"validate failed on node
%
s.
\n
Reason:
%
s,
%
s"
%
(
r
,
reason
,
e
))
print
(
"validate failed on node {}.
\n
Reason: {}, {}"
.
format
(
r
,
reason
,
e
)
)
raise
if
config
.
scan
.
debug
:
scans2
=
[
...
...
@@ -731,15 +731,15 @@ class PrintListener(Feature):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
if
self
.
active
:
print
(
"-- importing:
%
s, reason:
%
s"
%
(
node
,
reason
))
print
(
"-- importing:
{}, reason: {}"
.
format
(
node
,
reason
))
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
if
self
.
active
:
print
(
"-- pruning:
%
s, reason:
%
s"
%
(
node
,
reason
))
print
(
"-- pruning:
{}, reason: {}"
.
format
(
node
,
reason
))
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
=
None
):
if
self
.
active
:
print
(
"-- changing (
%
s.inputs[
%
s]) from
%
s to
%
s"
%
(
node
,
i
,
r
,
new_r
))
print
(
"-- changing (
{}.inputs[{}]) from {} to {}"
.
format
(
node
,
i
,
r
,
new_r
))
class
PreserveNames
(
Feature
):
...
...
@@ -918,9 +918,9 @@ def is_same_graph(var1, var2, givens=None):
for
to_replace
,
replace_by
in
givens
.
items
():
# Map a substitution variable to the computational graphs it
# belongs to.
inside
=
dict
(
(
v
,
[
in_var
(
v
,
k
)
for
k
in
(
1
,
2
)])
for
v
in
(
to_replace
,
replace_by
)
)
inside
=
{
v
:
[
in_var
(
v
,
k
)
for
k
in
(
1
,
2
)]
for
v
in
(
to_replace
,
replace_by
)
}
if
(
inside
[
to_replace
][
0
]
and
not
inside
[
to_replace
][
1
]
...
...
theano/gof/type.py
浏览文件 @
2e3f17cb
...
...
@@ -10,8 +10,6 @@ import ctypes
import
platform
import
re
from
six
import
string_types
import
theano
from
theano
import
change_flags
from
theano.gof
import
graph
,
utils
...
...
@@ -272,7 +270,7 @@ class CLinkerType(CLinkerObject):
return
()
class
PureType
(
object
)
:
class
PureType
:
"""
Interface specification for variable type instances.
...
...
@@ -707,10 +705,10 @@ class CDataType(Type):
extra_support_code
=
""
,
version
=
None
,
):
assert
isinstance
(
ctype
,
str
ing_types
)
assert
isinstance
(
ctype
,
str
)
self
.
ctype
=
ctype
if
freefunc
is
not
None
:
assert
isinstance
(
freefunc
,
str
ing_types
)
assert
isinstance
(
freefunc
,
str
)
self
.
freefunc
=
freefunc
self
.
headers
=
tuple
(
headers
)
self
.
header_dirs
=
tuple
(
header_dirs
)
...
...
@@ -848,7 +846,7 @@ if (py_%(name)s == NULL) { %(freefunc)s(%(name)s); }
return
v
def
__str__
(
self
):
return
"
%
s{
%
s}"
%
(
self
.
__class__
.
__name__
,
self
.
ctype
)
return
"
{}{{{}}}"
.
format
(
self
.
__class__
.
__name__
,
self
.
ctype
)
def
__setstate__
(
self
,
dct
):
self
.
__dict__
.
update
(
dct
)
...
...
@@ -1034,7 +1032,7 @@ class EnumType(Type, dict):
raise
TypeError
(
"
%
s: some aliases have same names as constants."
%
type
(
self
)
.
__name__
)
super
(
EnumType
,
self
)
.
__init__
(
**
kwargs
)
super
()
.
__init__
(
**
kwargs
)
def
fromalias
(
self
,
alias
):
"""
...
...
@@ -1060,11 +1058,11 @@ class EnumType(Type, dict):
names_to_aliases
=
{
constant_name
:
""
for
constant_name
in
self
}
for
alias
in
self
.
aliases
:
names_to_aliases
[
self
.
aliases
[
alias
]]
=
"(
%
s)"
%
alias
return
"
%
s<
%
s>(
%
s)"
%
(
return
"
{}<{}>({})"
.
format
(
type
(
self
)
.
__name__
,
self
.
ctype
,
", "
.
join
(
"
%
s
%
s:
%
s"
%
(
k
,
names_to_aliases
[
k
],
self
[
k
])
"
{}{}:{}"
.
format
(
k
,
names_to_aliases
[
k
],
self
[
k
])
for
k
in
sorted
(
self
.
keys
())
),
)
...
...
@@ -1298,7 +1296,7 @@ class EnumList(EnumType):
kwargs
.
update
(
ctype
=
ctype
)
if
cname
is
not
None
:
kwargs
.
update
(
cname
=
cname
)
super
(
EnumList
,
self
)
.
__init__
(
**
kwargs
)
super
()
.
__init__
(
**
kwargs
)
class
CEnumType
(
EnumList
):
...
...
@@ -1336,7 +1334,7 @@ class CEnumType(EnumList):
return
self
.
pyint_compat_code
+
self
.
c_to_string
()
def
c_extract
(
self
,
name
,
sub
,
check_input
=
True
):
swapped_dict
=
dict
((
v
,
k
)
for
(
k
,
v
)
in
self
.
items
())
swapped_dict
=
{
v
:
k
for
(
k
,
v
)
in
self
.
items
()}
# swapped_dict's keys are integers.
return
"""
...
...
@@ -1360,4 +1358,4 @@ class CEnumType(EnumList):
)
def
c_code_cache_version
(
self
):
return
(
1
,
super
(
CEnumType
,
self
)
.
c_code_cache_version
())
return
(
1
,
super
()
.
c_code_cache_version
())
theano/gof/unify.py
浏览文件 @
2e3f17cb
...
...
@@ -19,7 +19,7 @@ from theano.gof.utils import ANY_TYPE, FALL_THROUGH, comm_guard
################################
class
Variable
(
object
)
:
class
Variable
:
"""
Serves as a base class of variables for the purpose of unification.
"Unification" here basically means matching two patterns, see the
...
...
@@ -46,7 +46,9 @@ class Variable(object):
return
(
self
.
__class__
.
__name__
+
"("
+
", "
.
join
(
"
%
s=
%
s"
%
(
key
,
value
)
for
key
,
value
in
self
.
__dict__
.
items
())
+
", "
.
join
(
"{}={}"
.
format
(
key
,
value
)
for
key
,
value
in
self
.
__dict__
.
items
()
)
+
")"
)
...
...
@@ -60,8 +62,6 @@ class FreeVariable(Variable):
"""
pass
class
BoundVariable
(
Variable
):
"""
...
...
@@ -70,7 +70,7 @@ class BoundVariable(Variable):
"""
def
__init__
(
self
,
name
,
value
):
super
(
BoundVariable
,
self
)
.
__init__
(
name
=
name
)
super
()
.
__init__
(
name
=
name
)
self
.
value
=
value
...
...
@@ -82,7 +82,7 @@ class OrVariable(Variable):
"""
def
__init__
(
self
,
name
,
options
):
super
(
OrVariable
,
self
)
.
__init__
(
name
=
name
)
super
()
.
__init__
(
name
=
name
)
self
.
options
=
options
...
...
@@ -94,7 +94,7 @@ class NotVariable(Variable):
"""
def
__init__
(
self
,
name
,
not_options
):
super
(
NotVariable
,
self
)
.
__init__
(
name
=
name
)
super
()
.
__init__
(
name
=
name
)
self
.
not_options
=
not_options
...
...
theano/gof/utils.py
浏览文件 @
2e3f17cb
...
...
@@ -3,7 +3,6 @@ import sys
import
traceback
import
numpy
as
np
from
six
import
integer_types
,
string_types
,
with_metaclass
from
six.moves
import
StringIO
from
theano
import
config
...
...
@@ -161,8 +160,6 @@ undef = object()
class
TestValueError
(
Exception
):
"""Base exception class for all test value errors."""
pass
class
MethodNotDefined
(
Exception
):
"""
...
...
@@ -173,8 +170,6 @@ class MethodNotDefined(Exception):
"""
pass
class
MetaObject
(
type
):
def
__new__
(
cls
,
name
,
bases
,
dct
):
...
...
@@ -182,7 +177,7 @@ class MetaObject(type):
if
props
is
not
None
:
if
not
isinstance
(
props
,
tuple
):
raise
TypeError
(
"__props__ has to be a tuple"
)
if
not
all
(
isinstance
(
p
,
str
ing_types
)
for
p
in
props
):
if
not
all
(
isinstance
(
p
,
str
)
for
p
in
props
):
raise
TypeError
(
"elements of __props__ have to be strings"
)
def
_props
(
self
):
...
...
@@ -201,7 +196,7 @@ class MetaObject(type):
least all the original props.
"""
return
dict
([(
a
,
getattr
(
self
,
a
))
for
a
in
props
])
return
{
a
:
getattr
(
self
,
a
)
for
a
in
props
}
dct
[
"_props_dict"
]
=
_props_dict
...
...
@@ -225,14 +220,16 @@ class MetaObject(type):
if
len
(
props
)
==
0
:
def
__str__
(
self
):
return
"
%
s"
%
(
self
.
__class__
.
__name__
,
)
return
"
{}"
.
format
(
self
.
__class__
.
__name__
)
else
:
def
__str__
(
self
):
return
"
%
s{
%
s}"
%
(
return
"
{}{{{}}}"
.
format
(
self
.
__class__
.
__name__
,
", "
.
join
(
"
%
s=
%
r"
%
(
p
,
getattr
(
self
,
p
))
for
p
in
props
),
", "
.
join
(
"{}={!r}"
.
format
(
p
,
getattr
(
self
,
p
))
for
p
in
props
),
)
dct
[
"__str__"
]
=
__str__
...
...
@@ -240,14 +237,14 @@ class MetaObject(type):
return
type
.
__new__
(
cls
,
name
,
bases
,
dct
)
class
object2
(
with_metaclass
(
MetaObject
,
object
)
):
class
object2
(
metaclass
=
MetaObject
):
__slots__
=
[]
def
__ne__
(
self
,
other
):
return
not
self
==
other
class
Scratchpad
(
object
)
:
class
Scratchpad
:
def
clear
(
self
):
self
.
__dict__
.
clear
()
...
...
@@ -264,7 +261,7 @@ class Scratchpad(object):
def
info
(
self
):
print
(
"<theano.gof.utils.scratchpad instance at
%
i>"
%
id
(
self
))
for
k
,
v
in
self
.
__dict__
.
items
():
print
(
"
%
s:
%
s"
%
(
k
,
v
))
print
(
"
{}: {}"
.
format
(
k
,
v
))
class
ValidatingScratchpad
(
Scratchpad
):
...
...
@@ -330,7 +327,7 @@ def deprecated(filename, msg=""):
def
g
(
*
args
,
**
kwargs
):
if
printme
[
0
]:
print
(
"WARNING:
%
s.
%
s deprecated.
%
s"
%
(
filename
,
f
.
__name__
,
msg
))
print
(
"WARNING:
{}.{} deprecated. {}"
.
format
(
filename
,
f
.
__name__
,
msg
))
printme
[
0
]
=
False
return
f
(
*
args
,
**
kwargs
)
...
...
@@ -404,7 +401,7 @@ def toposort(prereqs_d):
for
x
,
prereqs
in
prereqs_d
.
items
():
for
prereq
in
prereqs
:
postreqs_d
.
setdefault
(
prereq
,
set
())
.
add
(
x
)
next
=
set
([
k
for
k
in
prereqs_d
if
not
prereqs_d
[
k
]])
next
=
{
k
for
k
in
prereqs_d
if
not
prereqs_d
[
k
]}
while
next
:
bases
=
next
next
=
set
()
...
...
@@ -449,7 +446,7 @@ RETRY = Keyword("RETRY", False)
FAILURE
=
Keyword
(
"FAILURE"
,
False
)
simple_types
=
integer_types
+
string_types
+
(
float
,
bool
,
None
.
__class__
,
Keyword
)
simple_types
=
(
int
,
str
,
float
,
bool
,
type
(
None
)
,
Keyword
)
ANY_TYPE
=
Keyword
(
"ANY_TYPE"
)
...
...
theano/gof/vm.py
浏览文件 @
2e3f17cb
...
...
@@ -30,8 +30,8 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
for
var
in
fgraph
.
variables
:
viewed_by
[
var
]
=
[]
view_of
=
{}
pre_allocated
=
set
(
[]
)
allocated
=
set
(
[]
)
pre_allocated
=
set
()
allocated
=
set
()
for
idx
in
range
(
len
(
order
)):
node
=
order
[
idx
]
...
...
@@ -120,7 +120,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
return
reallocated_info
class
VM
(
object
)
:
class
VM
:
"""
A VM object's __call__ method evaluates a Theano program.
...
...
@@ -273,7 +273,7 @@ class LoopGC(VM):
"""
def
__init__
(
self
,
nodes
,
thunks
,
pre_call_clear
,
post_thunk_clear
):
super
(
LoopGC
,
self
)
.
__init__
(
nodes
,
thunks
,
pre_call_clear
)
super
()
.
__init__
(
nodes
,
thunks
,
pre_call_clear
)
self
.
post_thunk_clear
=
post_thunk_clear
# Some other part of Theano query that information
self
.
allow_gc
=
True
...
...
@@ -353,7 +353,7 @@ class Stack(VM):
callback
=
None
,
callback_input
=
None
,
):
super
(
Stack
,
self
)
.
__init__
(
nodes
,
thunks
,
pre_call_clear
)
super
()
.
__init__
(
nodes
,
thunks
,
pre_call_clear
)
self
.
allow_gc
=
allow_gc
self
.
message
=
""
...
...
@@ -712,7 +712,6 @@ except (OSError, theano.gof.cmodule.MissingGXX) as e:
assert
not
[
x
for
x
in
_config_var_list
if
x
.
fullname
==
"linker"
][
0
]
.
default
.
startswith
(
"cvm"
),
e
pass
class
VM_Linker
(
link
.
LocalLinker
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论