Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
eb6569b5
提交
eb6569b5
authored
7月 16, 2009
作者:
bergstra@tikuanyin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
ModuleCache works without the pkl file now, more robust to various errors
上级
ada92aea
显示空白字符变更
内嵌
并排
正在显示
9 个修改的文件
包含
374 行增加
和
67 行删除
+374
-67
__init__.py
theano/__init__.py
+10
-0
test_inplace_opt_for_value.py
theano/compile/tests/test_inplace_opt_for_value.py
+6
-1
cc.py
theano/gof/cc.py
+31
-7
cmodule.py
theano/gof/cmodule.py
+249
-59
op.py
theano/gof/op.py
+10
-0
test_cc.py
theano/gof/tests/test_cc.py
+5
-0
type.py
theano/gof/type.py
+10
-0
basic.py
theano/sparse/basic.py
+49
-0
blas.py
theano/tensor/blas.py
+4
-0
没有找到文件。
theano/__init__.py
浏览文件 @
eb6569b5
...
@@ -147,3 +147,13 @@ def dot(l, r):
...
@@ -147,3 +147,13 @@ def dot(l, r):
raise
NotImplementedError
(
"Dot failed for the following reaons:"
,
(
e0
,
e1
))
raise
NotImplementedError
(
"Dot failed for the following reaons:"
,
(
e0
,
e1
))
return
rval
return
rval
###
# Set a default logger
#
import
logging
logging_default_handler
=
logging
.
StreamHandler
()
logging
.
getLogger
(
"theano"
)
.
addHandler
(
logging_default_handler
)
logging
.
getLogger
(
"theano"
)
.
setLevel
(
logging
.
WARNING
)
theano/compile/tests/test_inplace_opt_for_value.py
浏览文件 @
eb6569b5
...
@@ -81,6 +81,11 @@ class TanhRnn(Op):
...
@@ -81,6 +81,11 @@ class TanhRnn(Op):
in which z[0] = z0.
in which z[0] = z0.
"""
"""
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
,
z0
,
A
):
def
make_node
(
self
,
x
,
z0
,
A
):
"""
"""
...
@@ -121,7 +126,7 @@ class TanhRnnGrad(Op):
...
@@ -121,7 +126,7 @@ class TanhRnnGrad(Op):
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
,
other
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
return
hash
(
type
(
self
))
def
make_node
(
self
,
A
,
z
,
gz
):
def
make_node
(
self
,
A
,
z
,
gz
):
...
...
theano/gof/cc.py
浏览文件 @
eb6569b5
...
@@ -26,10 +26,10 @@ import cmodule
...
@@ -26,10 +26,10 @@ import cmodule
import
logging
import
logging
_logger
=
logging
.
getLogger
(
"theano.gof.cc"
)
_logger
=
logging
.
getLogger
(
"theano.gof.cc"
)
def
info
(
*
args
):
def
info
(
*
args
):
sys
.
stderr
.
write
(
'INFO:'
+
' '
.
join
(
str
(
a
)
for
a
in
args
)
+
'
\n
'
)
#
sys.stderr.write('INFO:'+ ' '.join(str(a) for a in args)+'\n')
_logger
.
info
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
_logger
.
info
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
def
debug
(
*
args
):
def
debug
(
*
args
):
sys
.
stderr
.
write
(
'DEBUG:'
+
' '
.
join
(
str
(
a
)
for
a
in
args
)
+
'
\n
'
)
#
sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n')
_logger
.
debug
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
_logger
.
debug
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
def
warning
(
*
args
):
def
warning
(
*
args
):
sys
.
stderr
.
write
(
'WARNING:'
+
' '
.
join
(
str
(
a
)
for
a
in
args
)
+
'
\n
'
)
sys
.
stderr
.
write
(
'WARNING:'
+
' '
.
join
(
str
(
a
)
for
a
in
args
)
+
'
\n
'
)
...
@@ -367,6 +367,7 @@ class CLinker(link.Linker):
...
@@ -367,6 +367,7 @@ class CLinker(link.Linker):
# The orphans field is listified to ensure a consistent order.
# The orphans field is listified to ensure a consistent order.
self
.
orphans
=
list
(
r
for
r
in
self
.
variables
if
isinstance
(
r
,
graph
.
Value
)
and
r
not
in
self
.
inputs
)
#list(env.orphans.difference(self.outputs))
self
.
orphans
=
list
(
r
for
r
in
self
.
variables
if
isinstance
(
r
,
graph
.
Value
)
and
r
not
in
self
.
inputs
)
#list(env.orphans.difference(self.outputs))
self
.
temps
=
list
(
set
(
self
.
variables
)
.
difference
(
self
.
inputs
)
.
difference
(
self
.
outputs
)
.
difference
(
self
.
orphans
))
self
.
temps
=
list
(
set
(
self
.
variables
)
.
difference
(
self
.
inputs
)
.
difference
(
self
.
outputs
)
.
difference
(
self
.
orphans
))
self
.
consts
=
[]
self
.
node_order
=
env
.
toposort
()
self
.
node_order
=
env
.
toposort
()
def
code_gen
(
self
):
def
code_gen
(
self
):
...
@@ -390,7 +391,7 @@ class CLinker(link.Linker):
...
@@ -390,7 +391,7 @@ class CLinker(link.Linker):
env
=
self
.
env
env
=
self
.
env
consts
=
[]
self
.
consts
=
[]
symbol
=
{}
symbol
=
{}
...
@@ -428,7 +429,7 @@ class CLinker(link.Linker):
...
@@ -428,7 +429,7 @@ class CLinker(link.Linker):
if
isinstance
(
variable
,
graph
.
Constant
):
if
isinstance
(
variable
,
graph
.
Constant
):
try
:
try
:
symbol
[
variable
]
=
"("
+
variable
.
type
.
c_literal
(
variable
.
data
)
+
")"
symbol
[
variable
]
=
"("
+
variable
.
type
.
c_literal
(
variable
.
data
)
+
")"
consts
.
append
(
variable
)
self
.
consts
.
append
(
variable
)
self
.
orphans
.
remove
(
variable
)
self
.
orphans
.
remove
(
variable
)
continue
continue
except
(
utils
.
MethodNotDefined
,
NotImplementedError
):
except
(
utils
.
MethodNotDefined
,
NotImplementedError
):
...
@@ -530,6 +531,11 @@ class CLinker(link.Linker):
...
@@ -530,6 +531,11 @@ class CLinker(link.Linker):
self
.
tasks
=
tasks
self
.
tasks
=
tasks
all
=
self
.
inputs
+
self
.
outputs
+
self
.
orphans
all
=
self
.
inputs
+
self
.
outputs
+
self
.
orphans
if
(
self
.
init_tasks
,
self
.
tasks
)
!=
self
.
get_init_tasks
():
print
>>
sys
.
stderr
,
"init_tasks
\n
"
,
self
.
init_tasks
print
>>
sys
.
stderr
,
self
.
get_init_tasks
()[
0
]
print
>>
sys
.
stderr
,
"tasks
\n
"
,
self
.
tasks
print
>>
sys
.
stderr
,
self
.
get_init_tasks
()[
1
]
assert
(
self
.
init_tasks
,
self
.
tasks
)
==
self
.
get_init_tasks
()
assert
(
self
.
init_tasks
,
self
.
tasks
)
==
self
.
get_init_tasks
()
# List of indices that should be ignored when passing the arguments
# List of indices that should be ignored when passing the arguments
...
@@ -646,6 +652,14 @@ class CLinker(link.Linker):
...
@@ -646,6 +652,14 @@ class CLinker(link.Linker):
tasks
=
[]
tasks
=
[]
id
=
1
id
=
1
for
v
in
self
.
variables
:
for
v
in
self
.
variables
:
if
v
in
self
.
consts
:
continue
if
v
in
self
.
orphans
and
isinstance
(
v
,
graph
.
Constant
):
try
:
v
.
type
.
c_literal
(
v
.
data
)
#constant will be inlined, no need to get
continue
except
(
utils
.
MethodNotDefined
,
NotImplementedError
):
pass
init_tasks
.
append
((
v
,
'init'
,
id
))
init_tasks
.
append
((
v
,
'init'
,
id
))
tasks
.
append
((
v
,
'get'
,
id
+
1
))
tasks
.
append
((
v
,
'get'
,
id
+
1
))
id
+=
2
id
+=
2
...
@@ -687,7 +701,7 @@ class CLinker(link.Linker):
...
@@ -687,7 +701,7 @@ class CLinker(link.Linker):
The signature has the following form:
The signature has the following form:
{{{
{{{
'CLinker.cmodule_key',
'CLinker.cmodule_key',
compilation args, libraries,
op0, (input0.type, input1.type, input0 pos, input1 pos)
op0, (input0.type, input1.type, input0 pos, input1 pos)
op1, (...)
op1, (...)
...
...
...
@@ -717,6 +731,9 @@ class CLinker(link.Linker):
...
@@ -717,6 +731,9 @@ class CLinker(link.Linker):
env_computed_set
=
set
()
env_computed_set
=
set
()
op_pos
=
{}
# Apply -> topological position
op_pos
=
{}
# Apply -> topological position
rval
=
[
'CLinker.cmodule_key'
]
# will be cast to tuple on return
rval
=
[
'CLinker.cmodule_key'
]
# will be cast to tuple on return
rval
.
append
(
tuple
(
self
.
compile_args
()))
rval
.
append
(
tuple
(
self
.
libraries
()))
version
=
[]
# assert that every input to every node is one of'
# assert that every input to every node is one of'
# - an env input
# - an env input
...
@@ -735,12 +752,19 @@ class CLinker(link.Linker):
...
@@ -735,12 +752,19 @@ class CLinker(link.Linker):
return
(
op_pos
[
i
.
owner
],
i
.
owner
.
outputs
.
index
(
i
))
return
(
op_pos
[
i
.
owner
],
i
.
owner
.
outputs
.
index
(
i
))
for
opos
,
o
in
enumerate
(
order
):
for
opos
,
o
in
enumerate
(
order
):
version
.
append
(
o
.
op
.
c_code_cache_version
())
for
i
in
o
.
inputs
:
version
.
append
(
i
.
type
.
c_code_cache_version
())
for
i
in
o
.
outputs
:
version
.
append
(
i
.
type
.
c_code_cache_version
())
rval
.
append
((
o
.
op
,
tuple
((
i
.
type
,
graphpos
(
i
))
for
i
in
o
.
inputs
)))
rval
.
append
((
o
.
op
,
tuple
((
i
.
type
,
graphpos
(
i
))
for
i
in
o
.
inputs
)))
op_pos
[
o
]
=
opos
op_pos
[
o
]
=
opos
env_computed_set
.
update
(
o
.
outputs
)
env_computed_set
.
update
(
o
.
outputs
)
rval
=
tuple
(
rval
)
for
v
in
version
:
return
rval
if
not
v
:
#one of the ops or types here is unversioned
return
((),
tuple
(
rval
))
return
tuple
(
version
),
tuple
(
rval
)
def
compile_cmodule
(
self
,
location
=
None
):
def
compile_cmodule
(
self
,
location
=
None
):
"""
"""
...
...
theano/gof/cmodule.py
浏览文件 @
eb6569b5
"""Generate and compile C modules for Python
"""Generate and compile C modules for Python
,
"""
"""
import
os
,
tempfile
,
StringIO
,
sys
,
logging
,
subprocess
,
cPickle
,
atexit
import
os
,
tempfile
,
StringIO
,
sys
,
logging
,
subprocess
,
cPickle
,
atexit
,
time
,
shutil
,
stat
import
compilelock
# we will abuse the lockfile mechanism when reading and writing the registry
_logger
=
logging
.
getLogger
(
"theano.gof.cmodule"
)
_logger
=
logging
.
getLogger
(
"theano.gof.cmodule"
)
_logger
.
setLevel
(
logging
.
INFO
)
def
error
(
*
args
):
#sys.stderr.write('ERROR:'+ ' '.join(str(a) for a in args)+'\n')
_logger
.
error
(
"ERROR: "
+
' '
.
join
(
str
(
a
)
for
a
in
args
))
def
warning
(
*
args
):
def
warning
(
*
args
):
#sys.stderr.write('WARNING:'+ ' '.join(str(a) for a in args)+'\n')
#sys.stderr.write('WARNING:'+ ' '.join(str(a) for a in args)+'\n')
_logger
.
warning
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
_logger
.
warning
(
"WARNING: "
+
' '
.
join
(
str
(
a
)
for
a
in
args
))
def
info
(
*
args
):
def
info
(
*
args
):
#sys.stderr.write('INFO:'+ ' '.join(str(a) for a in args)+'\n')
#sys.stderr.write('INFO:'+ ' '.join(str(a) for a in args)+'\n')
_logger
.
info
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
_logger
.
info
(
"INFO: "
+
' '
.
join
(
str
(
a
)
for
a
in
args
))
def
debug
(
*
args
):
def
debug
(
*
args
):
#sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n')
#sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n')
_logger
.
debug
(
' '
.
join
(
str
(
a
)
for
a
in
args
))
_logger
.
debug
(
"DEBUG: "
+
' '
.
join
(
str
(
a
)
for
a
in
args
))
METH_VARARGS
=
"METH_VARARGS"
METH_VARARGS
=
"METH_VARARGS"
METH_NOARGS
=
"METH_NOARGS"
METH_NOARGS
=
"METH_NOARGS"
...
@@ -156,35 +162,158 @@ def dlimport(fullpath, suffix=None):
...
@@ -156,35 +162,158 @@ def dlimport(fullpath, suffix=None):
assert
fullpath
.
startswith
(
rval
.
__file__
)
assert
fullpath
.
startswith
(
rval
.
__file__
)
return
rval
return
rval
def
last_access_time
(
path
):
"""Return the number of seconds since the epoch of the last access of a given file"""
return
os
.
stat
(
path
)[
stat
.
ST_ATIME
]
def
module_name_from_dir
(
dirname
):
"""Scan the contents of a cache directory and return full path of the dynamic lib in it.
"""
files
=
os
.
listdir
(
dirname
)
names
=
[
file
for
file
in
files
if
file
.
endswith
(
'.so'
)]
if
len
(
names
)
!=
1
:
raise
Exception
(
'Failed to load .so from dir'
,
dirname
)
return
os
.
path
.
join
(
dirname
,
names
[
0
])
class
ModuleCache
(
object
):
class
ModuleCache
(
object
):
def
__init__
(
self
,
dirname
,
force_fresh
=
False
):
"""Interface to the cache of dynamically compiled modules on disk
Note that this interface does not assume exclusive use of the cache directory.
It is built to handle the case where multiple programs are also using instances of this
class to manage the same directory.
The cache works on the basis of keys. Keys are used to uniquely identify a dynamic module.
Keys should be tuples of length 2: (version, rest)
The ``rest`` can be anything hashable and picklable, that uniquely identifies the
computation in the module.
The ``version`` should be a hierarchy of tuples of integers.
If the ``version`` is either 0 or (), then the corresponding module is unversioned, and
will be deleted in an atexit() handler.
If the ``version`` is neither 0 nor (), then the module will be kept in the cache between
processes, but it may be deleted if another key comes
along that has the same ``rest``, and a ``version`` that is considered higher than the
first one.
:todo: Versioning functionality is planned for implementation later, it is not implemented
yet.
"""
dirname
=
""
"""The working directory that is managed by this interface"""
module_from_name
=
{}
"""maps module names to loaded module objects"""
entry_from_key
=
{}
"""Maps keys to the filename of a .so
"""
stats
=
[]
"""A list with counters for the number of hits, loads, compiles issued by module_from_key()
"""
force_fresh
=
False
"""True -> Ignore previously-compiled modules
"""
loaded_key_pkl
=
set
()
"""set of all key.pkl files that have been loaded.
"""
def
__init__
(
self
,
dirname
,
force_fresh
=
None
,
check_for_broken_eq
=
True
):
"""
:param check_for_broken_eq: A bad __eq__ implemenation can break this cache mechanism.
This option turns on a not-too-expensive sanity check during the load of an old cache
file.
"""
self
.
dirname
=
dirname
self
.
dirname
=
dirname
self
.
module_from_name
=
{}
self
.
module_from_name
=
dict
(
self
.
module_from_name
)
self
.
name_from_key_filename
=
os
.
path
.
join
(
self
.
dirname
,
'module_cache.pkl'
)
self
.
entry_from_key
=
dict
(
self
.
entry_from_key
)
self
.
name_from_key
=
{}
self
.
stats
=
[
0
,
0
,
0
]
self
.
stats
=
[
0
,
0
,
0
]
self
.
force_fresh
=
self
.
force_fresh
if
force_fresh
is
None
else
force_fresh
self
.
loaded_key_pkl
=
set
()
if
not
force_fresh
:
self
.
refresh
()
if
check_for_broken_eq
:
for
k0
in
self
.
entry_from_key
:
for
k1
in
self
.
entry_from_key
:
if
k0
==
k1
and
not
(
k0
is
k1
):
warning
((
"The __eq__ and __hash__ functions are broken for some element"
" in the following two keys. The cache mechanism will say that"
" graphs like this need recompiling, when they could have been"
" retrieved):"
))
warning
(
"Key 0:"
,
k0
)
warning
(
"Key 1:"
,
k1
)
def
refresh
(
self
):
"""Update self.entry_from_key by walking the cache directory structure.
Add entries that are not in the entry_from_key dictionary.
Remove entries which have been removed from the filesystem.
"""
compilelock
.
get_lock
()
try
:
try
:
f
=
file
(
self
.
name_from_key_filename
,
'r'
)
# add entries that are not in the entry_from_key dictionary
self
.
name_from_key
=
cPickle
.
load
(
f
)
for
root
,
dirs
,
files
in
os
.
walk
(
self
.
dirname
):
debug
(
'ModuleCache loaded'
,
len
(
self
.
name_from_key
))
if
os
.
path
.
join
(
root
,
'key.pkl'
)
in
self
.
loaded_key_pkl
:
f
.
close
()
continue
except
(
IOError
,
EOFError
):
if
'key.pkl'
in
files
:
debug
(
'cache load failed. Using fresh cache'
)
key_pkl
=
os
.
path
.
join
(
root
,
'key.pkl'
)
pass
debug
(
'refresh adding'
,
key_pkl
)
try
:
key
=
cPickle
.
load
(
file
(
key_pkl
))
except
:
error
(
"ModuleCache.refresh() Failed to unpickle cache key"
,
key_pkl
)
info
(
"Erasing broken file"
,
key_pkl
)
os
.
remove
(
key_pkl
)
continue
if
key
not
in
self
.
entry_from_key
:
entry
=
module_name_from_dir
(
root
)
self
.
entry_from_key
[
key
]
=
entry
# assert that we haven't already got this entry somehow
assert
entry
not
in
self
.
module_from_name
self
.
loaded_key_pkl
.
add
(
key_pkl
)
# remove entries that are not in the filesystem
items_copy
=
list
(
self
.
entry_from_key
.
iteritems
())
for
key
,
entry
in
items_copy
:
try
:
# test to see that the file is [present and] readable
open
(
entry
)
.
close
()
gone
=
False
except
IOError
:
gone
=
True
if
gone
:
# assert that we didn't have one of the deleted files
# loaded up and in use.
# If so, it should not have been deleted. This should be considered a
# failure of the OTHER process, that deleted it.
if
entry
in
self
.
module_from_name
:
error
(
"The module
%
s that was loaded by this ModuleCache can no longer be read from file... this could lead to problems."
%
name
)
del
self
.
module_from_name
[
entry
]
info
(
"deleting ModuleCache entry"
,
entry
)
del
self
.
entry_from_key
[
key
]
self
.
loaded_key_pkl
.
remove
(
os
.
path
.
join
(
os
.
path
.
dirname
(
entry
),
'key.pkl'
))
def
persist
(
self
):
finally
:
f
=
file
(
self
.
name_from_key_filename
,
'w'
)
compilelock
.
release_lock
()
cPickle
.
dump
(
self
.
name_from_key
,
f
)
f
.
close
()
def
module_from_key
(
self
,
key
,
fn
=
None
):
def
module_from_key
(
self
,
key
,
fn
=
None
):
rval
=
None
rval
=
None
if
key
in
self
.
name_from_key
:
try
:
_version
,
_rest
=
key
except
:
raise
ValueError
(
"Invalid key. key must have form (version, rest)"
,
key
)
if
key
in
self
.
entry_from_key
:
# we have seen this key either in this process or previously
# we have seen this key either in this process or previously
#debug('OLD KEY HASH', hash(key), hash(key[1][0]), key[1][0])
#debug('OLD KEY HASH', hash(key), hash(key[1][0]), key[1][0])
name
=
self
.
name
_from_key
[
key
]
name
=
self
.
entry
_from_key
[
key
]
if
name
not
in
self
.
module_from_name
:
if
name
not
in
self
.
module_from_name
:
#debug('loading name', name)
#debug('loading name', name)
...
@@ -199,49 +328,115 @@ class ModuleCache(object):
...
@@ -199,49 +328,115 @@ class ModuleCache(object):
#debug("LOCATION*", location)
#debug("LOCATION*", location)
try
:
try
:
module
=
fn
(
location
=
location
)
# WILL FAIL FOR BAD C CODE
module
=
fn
(
location
=
location
)
# WILL FAIL FOR BAD C CODE
finally
:
except
Exception
,
e
:
# >>TODO: erase location
shutil
.
rmtree
(
location
)
pass
#try:
#except Exception, ee:
debug
(
'NEW KEY HASH'
,
hash
(
key
),
hash
(
key
[
1
][
0
]),
key
[
1
][
0
])
#error('failed to cleanup location', location, ee)
for
k
,
n
in
self
.
name_from_key
.
iteritems
():
raise
if
k
==
key
:
debug
(
"HASH OF RELOAD IS DIFFERENT"
,
hash
(
k
),
hash
(
key
))
print
''
print
hash
(
k
[
0
])
print
hash
(
key
[
0
])
print
''
print
"OLD"
,
print
hash
(
k
[
1
][
0
])
print
k
[
1
][
0
]
.
rehash
()
print
""
print
"NEW"
,
hash
(
key
[
1
][
0
]),
key
[
1
][
0
]
.
rehash
()
print
''
print
hash
(
k
[
1
][
1
])
print
hash
(
key
[
1
][
1
])
assert
k
!=
key
name
=
module
.
__file__
name
=
module
.
__file__
#debug("LOCATION**", location)
#debug("NAME**", name)
assert
name
.
startswith
(
location
)
debug
(
"Adding module to cache"
,
key
,
name
)
assert
name
.
startswith
(
location
)
assert
name
not
in
self
.
module_from_name
assert
name
not
in
self
.
module_from_name
assert
key
not
in
self
.
name_from_key
assert
key
not
in
self
.
entry_from_key
self
.
name_from_key
[
key
]
=
name
key_pkl
=
os
.
path
.
join
(
location
,
'key.pkl'
)
key_file
=
file
(
key_pkl
,
'w'
)
try
:
cPickle
.
dump
(
key
,
key_file
,
cPickle
.
HIGHEST_PROTOCOL
)
key_file
.
close
()
key_broken
=
False
except
cPickle
.
PicklingError
:
key_file
.
close
()
os
.
remove
(
key_pkl
)
warning
(
"Cache leak due to unpickle-able key"
,
key
)
key_broken
=
True
if
_version
and
not
key_broken
:
key_from_file
=
cPickle
.
load
(
file
(
key_pkl
))
if
key
!=
key_from_file
:
raise
Exception
(
"key not equal to unpickled version (Hint: verify the __eq__ and __hash__ functions for your Ops"
,
(
key
,
key_from_file
))
self
.
entry_from_key
[
key
]
=
name
self
.
module_from_name
[
name
]
=
module
self
.
module_from_name
[
name
]
=
module
self
.
loaded_key_pkl
.
add
(
key_pkl
)
self
.
stats
[
2
]
+=
1
self
.
stats
[
2
]
+=
1
rval
=
module
rval
=
module
#debug('stats', self.stats, sum(self.stats))
#debug('stats', self.stats, sum(self.stats))
return
rval
return
rval
age_thresh
=
60
*
60
*
24
*
31
"""The default age threshold for `clear_old` (in seconds)
"""
def
clear_old
(
self
,
age_thresh
=
None
):
#default to a 31-day age_threshold
"""
Delete entries from the filesystem for cache entries that are too old.
:param age_thresh: dynamic modules whose last access time is more than ``age_thresh``
seconds ago will be erased.
"""
age_thresh
=
self
.
age_thresh
if
age_thresh
is
None
else
age_thresh
compilelock
.
get_lock
()
try
:
# update the age of modules that have been accessed by other processes
self
.
refresh
()
time_now
=
time
.
time
()
# the .items() is important here:
# we need to get a copy of the whole list of keys and entries
items_copy
=
list
(
self
.
entry_from_key
.
iteritems
())
for
key
,
entry
in
items_copy
:
age
=
time_now
-
last_access_time
(
entry
)
if
age
>
age_thresh
:
# TODO: we are assuming that modules that haven't been accessed in over
# age_thresh are not currently in use by other processes, but that could be
# false for long-running jobs...
assert
entry
not
in
self
.
module_from_name
del
self
.
entry_from_key
[
key
]
parent
=
os
.
path
.
dirname
(
entry
)
assert
parent
.
startswith
(
os
.
path
.
join
(
self
.
dirname
,
'tmp'
))
debug
(
"Removing cache dir"
,
parent
)
shutil
.
rmtree
(
parent
)
finally
:
compilelock
.
release_lock
()
def
clear
(
self
):
"""
Clear all the elements of the cache
"""
return
self
.
clear_old
(
-
1.0
)
def
clear_unversioned
(
self
):
"""Delete unversioned dynamic modules from the internal dictionaries and from the
filesystem.
"""
items_copy
=
list
(
self
.
entry_from_key
.
iteritems
())
for
key
,
entry
in
items_copy
:
version
,
rest
=
key
if
not
version
:
del
self
.
entry_from_key
[
key
]
# entry is guaranteed to be in this dictionary,
# because an unversioned entry should never have been loaded via refresh
assert
entry
in
self
.
module_from_name
del
self
.
module_from_name
[
entry
]
parent
=
os
.
path
.
dirname
(
entry
)
assert
parent
.
startswith
(
os
.
path
.
join
(
self
.
dirname
,
'tmp'
))
debug
(
"Removing unversioned dir"
,
parent
)
shutil
.
rmtree
(
parent
)
def
_on_atexit
(
self
):
self
.
refresh
()
self
.
clear_old
()
self
.
clear_unversioned
()
_module_cache
=
None
_module_cache
=
None
def
get_module_cache
(
dirname
):
def
get_module_cache
(
dirname
,
force_fresh
=
None
):
global
_module_cache
global
_module_cache
if
_module_cache
is
None
:
if
_module_cache
is
None
:
_module_cache
=
ModuleCache
(
dirname
,
force_fresh
=
False
)
_module_cache
=
ModuleCache
(
dirname
,
force_fresh
=
force_fresh
)
atexit
.
register
(
_module_cache
.
persis
t
)
atexit
.
register
(
_module_cache
.
_on_atexi
t
)
return
_module_cache
return
_module_cache
def
gcc_module_compile_str
(
module_name
,
src_code
,
location
=
None
,
include_dirs
=
[],
lib_dirs
=
[],
libs
=
[],
def
gcc_module_compile_str
(
module_name
,
src_code
,
location
=
None
,
include_dirs
=
[],
lib_dirs
=
[],
libs
=
[],
...
@@ -263,9 +458,10 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
...
@@ -263,9 +458,10 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
debug
(
'Writing module C++ code to'
,
cppfilename
)
debug
(
'Writing module C++ code to'
,
cppfilename
)
ofiles
=
[]
ofiles
=
[]
rval
=
None
rval
=
None
try
:
cppfile
.
write
(
src_code
)
cppfile
.
write
(
src_code
)
cppfile
.
close
()
cppfile
.
close
()
lib_filename
=
os
.
path
.
join
(
workdir
,
'
%
s.so'
%
module_name
)
lib_filename
=
os
.
path
.
join
(
workdir
,
'
%
s.so'
%
module_name
)
debug
(
'Generating shared lib'
,
lib_filename
)
debug
(
'Generating shared lib'
,
lib_filename
)
...
@@ -292,12 +488,6 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
...
@@ -292,12 +488,6 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
rval
=
dlimport
(
lib_filename
)
rval
=
dlimport
(
lib_filename
)
finally
:
warning
(
"TODO: cleanup"
)
#os.remove(cppfilename)
for
ofile
in
ofiles
:
#os.remove(ofiles[0])
pass
return
rval
return
rval
...
...
theano/gof/op.py
浏览文件 @
eb6569b5
...
@@ -162,6 +162,16 @@ class CLinkerOp(object):
...
@@ -162,6 +162,16 @@ class CLinkerOp(object):
raise
utils
.
MethodNotDefined
(
'
%
s.c_support_code'
\
raise
utils
.
MethodNotDefined
(
'
%
s.c_support_code'
\
%
self
.
__class__
.
__name__
)
%
self
.
__class__
.
__name__
)
def
c_code_cache_version
(
self
):
"""Return a tuple of integers indicating the version of this Op.
An empty tuple indicates an 'unversioned' Op that will not be cached between processes.
The cache mechanism may erase cached modules that have been superceded by newer
versions. See `ModuleCache` for details.
"""
return
(
1
,)
class
PureOp
(
object
):
class
PureOp
(
object
):
"""
"""
An :term:`Op` is a type of operation.
An :term:`Op` is a type of operation.
...
...
theano/gof/tests/test_cc.py
浏览文件 @
eb6569b5
...
@@ -57,6 +57,9 @@ class TDouble(Type):
...
@@ -57,6 +57,9 @@ class TDouble(Type):
free(
%(name)
s_bad_thing);
free(
%(name)
s_bad_thing);
"""
%
locals
()
"""
%
locals
()
def
c_code_cache_version
(
self
):
return
()
tdouble
=
TDouble
()
tdouble
=
TDouble
()
def
double
(
name
):
def
double
(
name
):
...
@@ -83,6 +86,8 @@ class MyOp(Op):
...
@@ -83,6 +86,8 @@ class MyOp(Op):
def
perform
(
self
,
node
,
inputs
,
(
out
,
)):
def
perform
(
self
,
node
,
inputs
,
(
out
,
)):
out
[
0
]
=
self
.
impl
(
*
inputs
)
out
[
0
]
=
self
.
impl
(
*
inputs
)
def
c_code_cache_version
(
self
):
return
()
class
Unary
(
MyOp
):
class
Unary
(
MyOp
):
...
...
theano/gof/type.py
浏览文件 @
eb6569b5
...
@@ -210,6 +210,16 @@ class CLinkerType(object):
...
@@ -210,6 +210,16 @@ class CLinkerType(object):
"""
"""
raise
MethodNotDefined
(
"c_support_code"
,
type
(
self
),
self
.
__class__
.
__name__
)
raise
MethodNotDefined
(
"c_support_code"
,
type
(
self
),
self
.
__class__
.
__name__
)
def
c_code_cache_version
(
self
):
"""Return a tuple of integers indicating the version of this Op.
An empty tuple indicates an 'unversioned' Op that will not be cached between processes.
The cache mechanism may erase cached modules that have been superceded by newer
versions. See `ModuleCache` for details.
"""
return
(
1
,)
class
PureType
(
object
):
class
PureType
(
object
):
"""Interface specification for variable type instances.
"""Interface specification for variable type instances.
...
...
theano/sparse/basic.py
浏览文件 @
eb6569b5
...
@@ -444,6 +444,10 @@ class DenseFromSparse(gof.op.Op):
...
@@ -444,6 +444,10 @@ class DenseFromSparse(gof.op.Op):
"""
"""
sparse_grad
=
True
sparse_grad
=
True
"""WRITEME"""
"""WRITEME"""
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
):
def
make_node
(
self
,
x
):
x
=
as_sparse_variable
(
x
)
x
=
as_sparse_variable
(
x
)
...
@@ -495,6 +499,10 @@ csc_from_dense = SparseFromDense('csc')
...
@@ -495,6 +499,10 @@ csc_from_dense = SparseFromDense('csc')
class
Transpose
(
gof
.
op
.
Op
):
class
Transpose
(
gof
.
op
.
Op
):
format_map
=
{
'csr'
:
'csc'
,
format_map
=
{
'csr'
:
'csc'
,
'csc'
:
'csr'
}
'csc'
:
'csr'
}
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
):
def
make_node
(
self
,
x
):
x
=
as_sparse_variable
(
x
)
x
=
as_sparse_variable
(
x
)
return
gof
.
Apply
(
self
,
return
gof
.
Apply
(
self
,
...
@@ -510,6 +518,10 @@ class Transpose(gof.op.Op):
...
@@ -510,6 +518,10 @@ class Transpose(gof.op.Op):
transpose
=
Transpose
()
transpose
=
Transpose
()
class
Neg
(
gof
.
op
.
Op
):
class
Neg
(
gof
.
op
.
Op
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
):
def
make_node
(
self
,
x
):
x
=
as_sparse_variable
(
x
)
x
=
as_sparse_variable
(
x
)
return
gof
.
Apply
(
self
,
[
x
],
[
x
.
type
()])
return
gof
.
Apply
(
self
,
[
x
],
[
x
.
type
()])
...
@@ -523,6 +535,10 @@ neg = Neg()
...
@@ -523,6 +535,10 @@ neg = Neg()
class
AddSS
(
gof
.
op
.
Op
):
class
AddSS
(
gof
.
op
.
Op
):
'''Add two sparse matrices '''
'''Add two sparse matrices '''
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
,
y
):
def
make_node
(
self
,
x
,
y
):
x
,
y
=
map
(
as_sparse_variable
,
[
x
,
y
])
x
,
y
=
map
(
as_sparse_variable
,
[
x
,
y
])
if
x
.
type
.
dtype
!=
y
.
type
.
dtype
:
if
x
.
type
.
dtype
!=
y
.
type
.
dtype
:
...
@@ -545,6 +561,10 @@ class AddSS(gof.op.Op):
...
@@ -545,6 +561,10 @@ class AddSS(gof.op.Op):
add_s_s
=
AddSS
()
add_s_s
=
AddSS
()
class
AddSD
(
gof
.
op
.
Op
):
class
AddSD
(
gof
.
op
.
Op
):
''' Add a sparse and a dense matrix '''
''' Add a sparse and a dense matrix '''
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
,
y
):
def
make_node
(
self
,
x
,
y
):
x
,
y
=
as_sparse_variable
(
x
),
tensor
.
as_tensor_variable
(
y
)
x
,
y
=
as_sparse_variable
(
x
),
tensor
.
as_tensor_variable
(
y
)
if
x
.
type
.
dtype
!=
y
.
type
.
dtype
:
if
x
.
type
.
dtype
!=
y
.
type
.
dtype
:
...
@@ -586,6 +606,10 @@ def sub(x,y):
...
@@ -586,6 +606,10 @@ def sub(x,y):
class
MulSS
(
gof
.
op
.
Op
):
class
MulSS
(
gof
.
op
.
Op
):
''' Elementwise multiply a sparse and a ndarray '''
''' Elementwise multiply a sparse and a ndarray '''
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
,
y
):
def
make_node
(
self
,
x
,
y
):
x
,
y
=
as_sparse_variable
(
x
),
as_sparse_variable
(
y
)
x
,
y
=
as_sparse_variable
(
x
),
as_sparse_variable
(
y
)
if
x
.
type
!=
y
.
type
:
if
x
.
type
!=
y
.
type
:
...
@@ -605,6 +629,10 @@ class MulSS(gof.op.Op):
...
@@ -605,6 +629,10 @@ class MulSS(gof.op.Op):
mul_s_s
=
MulSS
()
mul_s_s
=
MulSS
()
class
MulSD
(
gof
.
op
.
Op
):
class
MulSD
(
gof
.
op
.
Op
):
''' Elementwise multiply a sparse and a ndarray '''
''' Elementwise multiply a sparse and a ndarray '''
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
,
y
):
def
make_node
(
self
,
x
,
y
):
x
,
y
=
as_sparse_variable
(
x
),
tensor
.
as_tensor_variable
(
y
)
x
,
y
=
as_sparse_variable
(
x
),
tensor
.
as_tensor_variable
(
y
)
if
x
.
type
.
dtype
!=
y
.
type
.
dtype
:
if
x
.
type
.
dtype
!=
y
.
type
.
dtype
:
...
@@ -686,6 +714,10 @@ class StructuredDot(gof.Op):
...
@@ -686,6 +714,10 @@ class StructuredDot(gof.Op):
The output is presumed to be a dense matrix, and is represented by a TensorType instance.
The output is presumed to be a dense matrix, and is represented by a TensorType instance.
"""
"""
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
a
,
b
):
def
make_node
(
self
,
a
,
b
):
if
type
(
a
)
is
not
SparseVariable
and
type
(
a
)
is
not
SparseConstant
:
if
type
(
a
)
is
not
SparseVariable
and
type
(
a
)
is
not
SparseConstant
:
raise
TypeError
(
'First argument must be of type SparseVariable or SparseConstant'
);
raise
TypeError
(
'First argument must be of type SparseVariable or SparseConstant'
);
...
@@ -750,6 +782,10 @@ def structured_dot(x, y):
...
@@ -750,6 +782,10 @@ def structured_dot(x, y):
return
_structured_dot
(
y
.
T
,
x
.
T
)
.
T
return
_structured_dot
(
y
.
T
,
x
.
T
)
.
T
class
StructuredDotCSC
(
gof
.
Op
):
class
StructuredDotCSC
(
gof
.
Op
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
a_val
,
a_ind
,
a_ptr
,
a_nrows
,
b
):
def
make_node
(
self
,
a_val
,
a_ind
,
a_ptr
,
a_nrows
,
b
):
dtype_out
=
scalar
.
upcast
(
a_val
.
type
.
dtype
,
b
.
type
.
dtype
)
dtype_out
=
scalar
.
upcast
(
a_val
.
type
.
dtype
,
b
.
type
.
dtype
)
r
=
gof
.
Apply
(
self
,
[
a_val
,
a_ind
,
a_ptr
,
a_nrows
,
b
],
r
=
gof
.
Apply
(
self
,
[
a_val
,
a_ind
,
a_ptr
,
a_nrows
,
b
],
...
@@ -900,6 +936,10 @@ class StructuredDotCSC(gof.Op):
...
@@ -900,6 +936,10 @@ class StructuredDotCSC(gof.Op):
sd_csc
=
StructuredDotCSC
()
sd_csc
=
StructuredDotCSC
()
class
StructuredDotCSR
(
gof
.
Op
):
class
StructuredDotCSR
(
gof
.
Op
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
a_val
,
a_ind
,
a_ptr
,
b
):
def
make_node
(
self
,
a_val
,
a_ind
,
a_ptr
,
b
):
self
.
dtype_out
=
scalar
.
upcast
(
a_val
.
type
.
dtype
,
b
.
type
.
dtype
)
self
.
dtype_out
=
scalar
.
upcast
(
a_val
.
type
.
dtype
,
b
.
type
.
dtype
)
r
=
gof
.
Apply
(
self
,
[
a_val
,
a_ind
,
a_ptr
,
b
],
r
=
gof
.
Apply
(
self
,
[
a_val
,
a_ind
,
a_ptr
,
b
],
...
@@ -1055,6 +1095,10 @@ def structured_dot_grad(sparse_A, dense_B, ga):
...
@@ -1055,6 +1095,10 @@ def structured_dot_grad(sparse_A, dense_B, ga):
class
StructuredDotGradCSC
(
gof
.
Op
):
class
StructuredDotGradCSC
(
gof
.
Op
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
a_indices
,
a_indptr
,
b
,
g_ab
):
def
make_node
(
self
,
a_indices
,
a_indptr
,
b
,
g_ab
):
return
gof
.
Apply
(
self
,
[
a_indices
,
a_indptr
,
b
,
g_ab
],
return
gof
.
Apply
(
self
,
[
a_indices
,
a_indptr
,
b
,
g_ab
],
[
tensor
.
tensor
(
g_ab
.
dtype
,
(
False
,))])
[
tensor
.
tensor
(
g_ab
.
dtype
,
(
False
,))])
...
@@ -1155,6 +1199,10 @@ sdg_csc = StructuredDotGradCSC()
...
@@ -1155,6 +1199,10 @@ sdg_csc = StructuredDotGradCSC()
class
StructuredDotGradCSR
(
gof
.
Op
):
class
StructuredDotGradCSR
(
gof
.
Op
):
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
a_indices
,
a_indptr
,
b
,
g_ab
):
def
make_node
(
self
,
a_indices
,
a_indptr
,
b
,
g_ab
):
return
gof
.
Apply
(
self
,
[
a_indices
,
a_indptr
,
b
,
g_ab
],
[
tensor
.
tensor
(
b
.
dtype
,
(
False
,))])
return
gof
.
Apply
(
self
,
[
a_indices
,
a_indptr
,
b
,
g_ab
],
[
tensor
.
tensor
(
b
.
dtype
,
(
False
,))])
...
@@ -1256,3 +1304,4 @@ class StructuredDotGradCSR(gof.Op):
...
@@ -1256,3 +1304,4 @@ class StructuredDotGradCSR(gof.Op):
"""
%
dict
(
locals
(),
**
sub
)
"""
%
dict
(
locals
(),
**
sub
)
sdg_csr
=
StructuredDotGradCSR
()
sdg_csr
=
StructuredDotGradCSR
()
theano/tensor/blas.py
浏览文件 @
eb6569b5
...
@@ -49,6 +49,10 @@ class GemmRelated(Op):
...
@@ -49,6 +49,10 @@ class GemmRelated(Op):
This class provides a kind of templated gemm Op.
This class provides a kind of templated gemm Op.
"""
"""
def
__eq__
(
self
,
other
):
return
(
type
(
self
)
==
type
(
other
))
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
c_support_code
(
self
):
def
c_support_code
(
self
):
#return cblas_header_text()
#return cblas_header_text()
mod_str
=
"""
mod_str
=
"""
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论