Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
cba9c812
提交
cba9c812
authored
7月 15, 2015
作者:
Frédéric Bastien
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3130 from harlouci/flake8_gof
Flake8 gof
上级
a1783c0b
4bae84c6
隐藏空白字符变更
内嵌
并排
正在显示
12 个修改的文件
包含
281 行增加
和
291 行删除
+281
-291
configparser.py
theano/configparser.py
+1
-1
cc.py
theano/gof/cc.py
+43
-74
cmodule.py
theano/gof/cmodule.py
+56
-54
destroyhandler.py
theano/gof/destroyhandler.py
+33
-31
fg.py
theano/gof/fg.py
+20
-21
link.py
theano/gof/link.py
+18
-16
opt.py
theano/gof/opt.py
+77
-85
utils.py
theano/gof/utils.py
+30
-0
nvcc_compiler.py
theano/sandbox/cuda/nvcc_compiler.py
+1
-1
utils.py
theano/sparse/utils.py
+1
-1
utils.py
theano/tensor/utils.py
+1
-1
test_flake8.py
theano/tests/test_flake8.py
+0
-6
没有找到文件。
theano/configparser.py
浏览文件 @
cba9c812
...
@@ -177,7 +177,7 @@ def get_config_md5():
...
@@ -177,7 +177,7 @@ def get_config_md5():
"""
"""
all_opts
=
sorted
([
c
for
c
in
_config_var_list
if
c
.
in_c_key
],
all_opts
=
sorted
([
c
for
c
in
_config_var_list
if
c
.
in_c_key
],
key
=
lambda
cv
:
cv
.
fullname
)
key
=
lambda
cv
:
cv
.
fullname
)
return
theano
.
gof
.
cc
.
hash_from_code
(
'
\n
'
.
join
(
return
theano
.
gof
.
utils
.
hash_from_code
(
'
\n
'
.
join
(
[
'
%
s =
%
s'
%
(
cv
.
fullname
,
cv
.
__get__
())
for
cv
in
all_opts
]))
[
'
%
s =
%
s'
%
(
cv
.
fullname
,
cv
.
__get__
())
for
cv
in
all_opts
]))
...
...
theano/gof/cc.py
浏览文件 @
cba9c812
"""
"""
Defines Linkers that deal with C implementations.
Defines Linkers that deal with C implementations.
"""
"""
from
__future__
import
print_function
from
__future__
import
print_function
# Python imports
# Python imports
from
copy
import
copy
from
copy
import
copy
import
os
import
os
import
sys
import
sys
from
theano.compat
import
izip
import
logging
import
numpy
import
numpy
import
theano
from
theano
import
config
from
theano.compat
import
PY3
from
theano.compat
import
PY3
from
theano.compat
import
izip
from
six
import
string_types
,
reraise
from
six
import
string_types
,
reraise
from
six.moves
import
StringIO
,
xrange
from
six.moves
import
StringIO
,
xrange
from
theano.gof.utils
import
MethodNotDefined
import
theano
from
theano
import
config
if
PY3
:
import
hashlib
def
hash_from_code
(
msg
):
# hashlib.md5() requires an object that supports buffer interface,
# but Python 3 (unicode) strings don't.
if
isinstance
(
msg
,
str
):
msg
=
msg
.
encode
()
# Python 3 does not like module names that start with
# a digit.
return
'm'
+
hashlib
.
md5
(
msg
)
.
hexdigest
()
else
:
import
hashlib
def
hash_from_code
(
msg
):
try
:
return
hashlib
.
md5
(
msg
)
.
hexdigest
()
except
TypeError
:
assert
isinstance
(
msg
,
numpy
.
ndarray
)
return
hashlib
.
md5
(
numpy
.
getbuffer
(
msg
))
.
hexdigest
()
def
hash_from_file
(
file_path
):
"""Return the MD5 hash of a file."""
return
hash_from_code
(
open
(
file_path
,
'rb'
)
.
read
())
# Note that we need to do this before importing cutils, since when there is
# Note that we need to do this before importing cutils, since when there is
# no theano cache dir initialized yet, importing cutils may require compilation
# no theano cache dir initialized yet, importing cutils may require compilation
# of cutils_ext.
# of cutils_ext.
from
theano.configparser
import
AddConfigVar
,
StrParam
from
theano.configparser
import
AddConfigVar
,
StrParam
AddConfigVar
(
'gcc.cxxflags'
,
"Extra compiler flags for gcc"
,
StrParam
(
""
))
# gof imports
# gof imports
from
theano.gof
import
graph
from
theano.gof
import
graph
from
theano.gof
import
link
from
theano.gof
import
link
from
theano.gof
import
utils
from
theano.gof
import
utils
from
theano.gof
import
cmodule
from
theano.gof.compilelock
import
get_lock
,
release_lock
from
theano.gof.compilelock
import
get_lock
,
release_lock
from
theano.gof.callcache
import
CallCache
from
theano.gof
import
cmodule
AddConfigVar
(
'gcc.cxxflags'
,
"Extra compiler flags for gcc"
,
StrParam
(
""
))
import
logging
_logger
=
logging
.
getLogger
(
"theano.gof.cc"
)
_logger
=
logging
.
getLogger
(
"theano.gof.cc"
)
from
theano.gof.callcache
import
CallCache
run_cthunk
=
None
# Will be imported only when needed.
run_cthunk
=
None
# Will be imported only when needed.
...
@@ -314,9 +284,8 @@ def get_c_declare(r, name, sub):
...
@@ -314,9 +284,8 @@ def get_c_declare(r, name, sub):
"""Wrapper around c_declare that declares py_name"""
"""Wrapper around c_declare that declares py_name"""
if
any
([
c
!=
"output"
and
getattr
(
c
.
op
,
'check_input'
,
if
any
([
c
!=
"output"
and
getattr
(
c
.
op
,
'check_input'
,
config
.
check_input
)
for
(
c
,
_
)
in
r
.
clients
])
or
(
r
.
owner
config
.
check_input
)
for
(
c
,
_
)
in
r
.
clients
])
or
(
and
getattr
(
r
.
owner
.
op
,
'check_input'
,
True
)):
r
.
owner
and
getattr
(
r
.
owner
.
op
,
'check_input'
,
True
)):
c_declare
=
r
.
type
.
c_declare
(
name
,
sub
,
True
)
c_declare
=
r
.
type
.
c_declare
(
name
,
sub
,
True
)
else
:
else
:
c_declare
=
r
.
type
.
c_declare
(
name
,
sub
,
False
)
c_declare
=
r
.
type
.
c_declare
(
name
,
sub
,
False
)
...
@@ -532,7 +501,7 @@ class CLinker(link.Linker):
...
@@ -532,7 +501,7 @@ class CLinker(link.Linker):
if
isinstance
(
r
,
graph
.
Constant
)
and
if
isinstance
(
r
,
graph
.
Constant
)
and
r
not
in
self
.
inputs
)
r
not
in
self
.
inputs
)
self
.
temps
=
list
(
set
(
self
.
variables
)
.
difference
(
self
.
temps
=
list
(
set
(
self
.
variables
)
.
difference
(
self
.
inputs
)
.
difference
(
self
.
outputs
)
.
difference
(
self
.
orphans
))
self
.
inputs
)
.
difference
(
self
.
outputs
)
.
difference
(
self
.
orphans
))
self
.
consts
=
[]
self
.
consts
=
[]
def
code_gen
(
self
):
def
code_gen
(
self
):
...
@@ -821,7 +790,7 @@ class CLinker(link.Linker):
...
@@ -821,7 +790,7 @@ class CLinker(link.Linker):
ret
=
[]
ret
=
[]
# generic support code
# generic support code
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
y
.
op
for
y
in
self
.
node_order
]:
try
:
try
:
ret
.
append
(
x
.
c_support_code
())
ret
.
append
(
x
.
c_support_code
())
except
utils
.
MethodNotDefined
:
except
utils
.
MethodNotDefined
:
...
@@ -840,11 +809,11 @@ class CLinker(link.Linker):
...
@@ -840,11 +809,11 @@ class CLinker(link.Linker):
# FillMissing must disable some of them. Putting -ffast-math would
# FillMissing must disable some of them. Putting -ffast-math would
# make it disable all other parameter at the same time.
# make it disable all other parameter at the same time.
ret
+=
[
"-fno-math-errno"
,
ret
+=
[
"-fno-math-errno"
,
#"-funsafe-math-optimizations",
#
"-funsafe-math-optimizations",
#"-fno-signaling-nans",
#
"-fno-signaling-nans",
#"-fcx-limited-range",
#
"-fcx-limited-range",
#"-fno-rounding-math",
#
"-fno-rounding-math",
#"-ffinite-math-only",
#
"-ffinite-math-only",
# the current code generate label event if they are not used.
# the current code generate label event if they are not used.
# Could use gcc attribute for those label only
# Could use gcc attribute for those label only
...
@@ -853,7 +822,7 @@ class CLinker(link.Linker):
...
@@ -853,7 +822,7 @@ class CLinker(link.Linker):
"-Wno-write-strings"
,
# generated by our code generator...
"-Wno-write-strings"
,
# generated by our code generator...
]
]
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
y
.
op
for
y
in
self
.
node_order
]:
try
:
try
:
ret
+=
x
.
c_compile_args
()
ret
+=
x
.
c_compile_args
()
except
utils
.
MethodNotDefined
:
except
utils
.
MethodNotDefined
:
...
@@ -866,7 +835,7 @@ class CLinker(link.Linker):
...
@@ -866,7 +835,7 @@ class CLinker(link.Linker):
# to reorder them
# to reorder them
ret
+=
c_compiler
.
compile_args
()
ret
+=
c_compiler
.
compile_args
()
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
y
.
op
for
y
in
self
.
node_order
]:
try
:
try
:
for
i
in
x
.
c_no_compile_args
():
for
i
in
x
.
c_no_compile_args
():
try
:
try
:
...
@@ -886,7 +855,7 @@ class CLinker(link.Linker):
...
@@ -886,7 +855,7 @@ class CLinker(link.Linker):
"""
"""
ret
=
[]
ret
=
[]
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
y
.
op
for
y
in
self
.
node_order
]:
try
:
try
:
ret
+=
x
.
c_headers
()
ret
+=
x
.
c_headers
()
except
utils
.
MethodNotDefined
:
except
utils
.
MethodNotDefined
:
...
@@ -901,7 +870,7 @@ class CLinker(link.Linker):
...
@@ -901,7 +870,7 @@ class CLinker(link.Linker):
"""
"""
ret
=
[]
ret
=
[]
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
y
.
op
for
y
in
self
.
node_order
]:
try
:
try
:
ret
+=
x
.
c_init_code
()
ret
+=
x
.
c_init_code
()
except
utils
.
MethodNotDefined
:
except
utils
.
MethodNotDefined
:
...
@@ -911,7 +880,7 @@ class CLinker(link.Linker):
...
@@ -911,7 +880,7 @@ class CLinker(link.Linker):
def
c_compiler
(
self
):
def
c_compiler
(
self
):
c_compiler
=
None
c_compiler
=
None
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
y
.
op
for
y
in
self
.
node_order
]:
if
hasattr
(
x
,
'c_compiler'
):
if
hasattr
(
x
,
'c_compiler'
):
x_compiler
=
x
.
c_compiler
()
x_compiler
=
x
.
c_compiler
()
else
:
else
:
...
@@ -938,7 +907,7 @@ class CLinker(link.Linker):
...
@@ -938,7 +907,7 @@ class CLinker(link.Linker):
"""
"""
ret
=
[]
ret
=
[]
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
y
.
op
for
y
in
self
.
node_order
]:
try
:
try
:
ret
+=
x
.
c_header_dirs
()
ret
+=
x
.
c_header_dirs
()
except
utils
.
MethodNotDefined
:
except
utils
.
MethodNotDefined
:
...
@@ -954,7 +923,7 @@ class CLinker(link.Linker):
...
@@ -954,7 +923,7 @@ class CLinker(link.Linker):
"""
"""
ret
=
[]
ret
=
[]
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
y
.
op
for
y
in
self
.
node_order
]:
try
:
try
:
ret
+=
x
.
c_libraries
()
ret
+=
x
.
c_libraries
()
except
utils
.
MethodNotDefined
:
except
utils
.
MethodNotDefined
:
...
@@ -970,7 +939,7 @@ class CLinker(link.Linker):
...
@@ -970,7 +939,7 @@ class CLinker(link.Linker):
"""
"""
ret
=
[]
ret
=
[]
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
for
x
in
[
y
.
type
for
y
in
self
.
variables
]
+
[
y
.
op
for
y
in
self
.
node_order
]:
y
.
op
for
y
in
self
.
node_order
]:
try
:
try
:
ret
+=
x
.
c_lib_dirs
()
ret
+=
x
.
c_lib_dirs
()
except
utils
.
MethodNotDefined
:
except
utils
.
MethodNotDefined
:
...
@@ -1150,7 +1119,7 @@ class CLinker(link.Linker):
...
@@ -1150,7 +1119,7 @@ class CLinker(link.Linker):
libraries
=
self
.
libraries
(),
libraries
=
self
.
libraries
(),
header_dirs
=
self
.
header_dirs
(),
header_dirs
=
self
.
header_dirs
(),
c_compiler
=
self
.
c_compiler
(),
c_compiler
=
self
.
c_compiler
(),
)
)
def
cmodule_key_
(
self
,
fgraph
,
no_recycling
,
compile_args
=
None
,
def
cmodule_key_
(
self
,
fgraph
,
no_recycling
,
compile_args
=
None
,
libraries
=
None
,
header_dirs
=
None
,
insert_config_md5
=
True
,
libraries
=
None
,
header_dirs
=
None
,
insert_config_md5
=
True
,
...
@@ -1335,7 +1304,6 @@ class CLinker(link.Linker):
...
@@ -1335,7 +1304,6 @@ class CLinker(link.Linker):
preargs
.
remove
(
'-DREPLACE_WITH_AMDLIBM'
)
preargs
.
remove
(
'-DREPLACE_WITH_AMDLIBM'
)
if
'amdlibm'
in
libs
:
if
'amdlibm'
in
libs
:
libs
.
remove
(
'amdlibm'
)
libs
.
remove
(
'amdlibm'
)
src_code
=
mod
.
code
()
get_lock
()
get_lock
()
try
:
try
:
_logger
.
debug
(
"LOCATION
%
s"
,
str
(
location
))
_logger
.
debug
(
"LOCATION
%
s"
,
str
(
location
))
...
@@ -1371,9 +1339,9 @@ class CLinker(link.Linker):
...
@@ -1371,9 +1339,9 @@ class CLinker(link.Linker):
code
=
self
.
instantiate_code
(
1
+
len
(
self
.
args
))
code
=
self
.
instantiate_code
(
1
+
len
(
self
.
args
))
instantiate
=
cmodule
.
ExtFunction
(
'instantiate'
,
code
,
instantiate
=
cmodule
.
ExtFunction
(
'instantiate'
,
code
,
method
=
cmodule
.
METH_VARARGS
)
method
=
cmodule
.
METH_VARARGS
)
#
['error_storage'] + argnames,
#
['error_storage'] + argnames,
#
local_dict = d,
#
local_dict = d,
# global_dict = {})
# global_dict = {})
# Static methods that can run and destroy the struct built by
# Static methods that can run and destroy the struct built by
# instantiate.
# instantiate.
...
@@ -1498,7 +1466,7 @@ class _CThunk(object):
...
@@ -1498,7 +1466,7 @@ class _CThunk(object):
global
run_cthunk
global
run_cthunk
if
run_cthunk
is
None
:
if
run_cthunk
is
None
:
# Lazy import to avoid compilation when importing theano.
# Lazy import to avoid compilation when importing theano.
from
theano.gof.cutils
import
run_cthunk
from
theano.gof.cutils
import
run_cthunk
# noqa
self
.
cthunk
=
cthunk
self
.
cthunk
=
cthunk
self
.
init_tasks
=
init_tasks
self
.
init_tasks
=
init_tasks
self
.
tasks
=
tasks
self
.
tasks
=
tasks
...
@@ -1534,7 +1502,8 @@ class _CThunk(object):
...
@@ -1534,7 +1502,8 @@ class _CThunk(object):
exc_value
.
__thunk_trace__
=
trace
exc_value
.
__thunk_trace__
=
trace
except
Exception
:
except
Exception
:
print
((
'ERROR retrieving error_storage.'
print
((
'ERROR retrieving error_storage.'
' Was the error set in the c code?'
),
end
=
' '
,
file
=
sys
.
stderr
)
'Was the error set in the c code?'
),
end
=
' '
,
file
=
sys
.
stderr
)
print
(
self
.
error_storage
,
file
=
sys
.
stderr
)
print
(
self
.
error_storage
,
file
=
sys
.
stderr
)
raise
raise
reraise
(
exc_type
,
exc_value
,
exc_trace
)
reraise
(
exc_type
,
exc_value
,
exc_trace
)
...
@@ -1641,11 +1610,11 @@ class OpWiseCLinker(link.LocalLinker):
...
@@ -1641,11 +1610,11 @@ class OpWiseCLinker(link.LocalLinker):
for
node
in
order
:
for
node
in
order
:
if
self
.
allow_gc
:
if
self
.
allow_gc
:
post_thunk_old_storage
.
append
(
[
storage_map
[
input
]
post_thunk_old_storage
.
append
(
for
input
in
node
.
inputs
[
storage_map
[
input
]
for
input
in
node
.
inputs
if
((
input
in
computed
)
and
if
((
input
in
computed
)
and
(
input
not
in
fgraph
.
outputs
)
and
(
input
not
in
fgraph
.
outputs
)
and
node
==
last_user
[
input
])])
node
==
last_user
[
input
])])
if
no_recycling
is
True
:
if
no_recycling
is
True
:
no_recycling
=
list
(
storage_map
.
values
())
no_recycling
=
list
(
storage_map
.
values
())
...
@@ -1741,12 +1710,12 @@ class DualLinker(link.Linker):
...
@@ -1741,12 +1710,12 @@ class DualLinker(link.Linker):
no_recycling
=
self
.
no_recycling
no_recycling
=
self
.
no_recycling
_f
,
i1
,
o1
,
thunks1
,
order1
=
(
_f
,
i1
,
o1
,
thunks1
,
order1
=
(
link
.
PerformLinker
(
schedule
=
self
.
schedule
)
.
accept
(
fgraph
,
link
.
PerformLinker
(
schedule
=
self
.
schedule
)
.
accept
(
no_recycling
=
no_recycling
)
.
make_all
(
**
kwargs
))
fgraph
,
no_recycling
=
no_recycling
)
.
make_all
(
**
kwargs
))
kwargs
.
pop
(
'input_storage'
,
None
)
kwargs
.
pop
(
'input_storage'
,
None
)
_f
,
i2
,
o2
,
thunks2
,
order2
=
(
_f
,
i2
,
o2
,
thunks2
,
order2
=
(
OpWiseCLinker
(
schedule
=
self
.
schedule
)
.
accept
(
fgraph
,
OpWiseCLinker
(
schedule
=
self
.
schedule
)
.
accept
(
no_recycling
=
no_recycling
)
.
make_all
(
**
kwargs
))
fgraph
,
no_recycling
=
no_recycling
)
.
make_all
(
**
kwargs
))
def
f
():
def
f
():
for
input1
,
input2
in
izip
(
i1
,
i2
):
for
input1
,
input2
in
izip
(
i1
,
i2
):
...
...
theano/gof/cmodule.py
浏览文件 @
cba9c812
"""Generate and compile C modules for Python,
"""Generate and compile C modules for Python,
"""
"""
from
__future__
import
print_function
from
__future__
import
print_function
import
atexit
import
atexit
import
six.moves.cPickle
as
pickle
import
six.moves.cPickle
as
pickle
import
logging
import
logging
...
@@ -15,12 +17,6 @@ import time
...
@@ -15,12 +17,6 @@ import time
import
platform
import
platform
import
distutils.sysconfig
import
distutils.sysconfig
importlib
=
None
try
:
import
importlib
except
ImportError
:
pass
import
numpy.distutils
# TODO: TensorType should handle this
import
numpy.distutils
# TODO: TensorType should handle this
import
theano
import
theano
...
@@ -28,7 +24,7 @@ from theano.compat import PY3, decode, decode_iter
...
@@ -28,7 +24,7 @@ from theano.compat import PY3, decode, decode_iter
from
six
import
b
,
BytesIO
,
StringIO
,
string_types
,
iteritems
from
six
import
b
,
BytesIO
,
StringIO
,
string_types
,
iteritems
from
theano.gof.utils
import
flatten
from
theano.gof.utils
import
flatten
from
theano.configparser
import
config
from
theano.configparser
import
config
from
theano.gof.
cc
import
hash_from_code
from
theano.gof.
utils
import
hash_from_code
from
theano.misc.windows
import
(
subprocess_Popen
,
from
theano.misc.windows
import
(
subprocess_Popen
,
output_subprocess_Popen
)
output_subprocess_Popen
)
...
@@ -38,10 +34,17 @@ from theano.gof.compiledir import gcc_version_str, local_bitwidth
...
@@ -38,10 +34,17 @@ from theano.gof.compiledir import gcc_version_str, local_bitwidth
from
theano.configparser
import
AddConfigVar
,
BoolParam
from
theano.configparser
import
AddConfigVar
,
BoolParam
AddConfigVar
(
'cmodule.mac_framework_link'
,
importlib
=
None
"If set to True, breaks certain MacOS installations with the infamous "
try
:
"Bus Error"
,
import
importlib
BoolParam
(
False
))
except
ImportError
:
pass
AddConfigVar
(
'cmodule.mac_framework_link'
,
"If set to True, breaks certain MacOS installations with the infamous "
"Bus Error"
,
BoolParam
(
False
))
AddConfigVar
(
'cmodule.warn_no_version'
,
AddConfigVar
(
'cmodule.warn_no_version'
,
"If True, will print a warning when compiling one or more Op "
"If True, will print a warning when compiling one or more Op "
...
@@ -131,15 +134,16 @@ class ExtFunction(object):
...
@@ -131,15 +134,16 @@ class ExtFunction(object):
It goes into the DynamicModule's method table.
It goes into the DynamicModule's method table.
"""
"""
return
'
\t
{"
%
s",
%
s,
%
s, "
%
s"}'
%
(
return
'
\t
{"
%
s",
%
s,
%
s, "
%
s"}'
%
(
self
.
name
,
self
.
name
,
self
.
method
,
self
.
doc
)
self
.
name
,
self
.
name
,
self
.
method
,
self
.
doc
)
class
DynamicModule
(
object
):
class
DynamicModule
(
object
):
def
__init__
(
self
,
name
=
None
):
def
__init__
(
self
,
name
=
None
):
assert
name
is
None
,
(
"The 'name' parameter of DynamicModule"
assert
name
is
None
,
(
" cannot be specified anymore. Instead, 'code_hash'"
"The 'name' parameter of DynamicModule"
" will be automatically computed and can be used as"
" cannot be specified anymore. Instead, 'code_hash'"
" the module's name."
)
" will be automatically computed and can be used as"
" the module's name."
)
# While the module is not finalized, we can call add_...
# While the module is not finalized, we can call add_...
# when it is finalized, a hash is computed and used instead of
# when it is finalized, a hash is computed and used instead of
# the placeholder, and as module name.
# the placeholder, and as module name.
...
@@ -171,18 +175,18 @@ static struct PyModuleDef moduledef = {{
...
@@ -171,18 +175,18 @@ static struct PyModuleDef moduledef = {{
}};
}};
"""
.
format
(
name
=
self
.
hash_placeholder
),
file
=
stream
)
"""
.
format
(
name
=
self
.
hash_placeholder
),
file
=
stream
)
print
((
"PyMODINIT_FUNC PyInit_
%
s(void) {"
%
print
((
"PyMODINIT_FUNC PyInit_
%
s(void) {"
%
self
.
hash_placeholder
),
file
=
stream
)
self
.
hash_placeholder
),
file
=
stream
)
for
block
in
self
.
init_blocks
:
for
block
in
self
.
init_blocks
:
print
(
' '
,
block
,
file
=
stream
)
print
(
' '
,
block
,
file
=
stream
)
print
(
" PyObject *m = PyModule_Create(&moduledef);"
,
file
=
stream
)
print
(
" PyObject *m = PyModule_Create(&moduledef);"
,
file
=
stream
)
print
(
" return m;"
,
file
=
stream
)
print
(
" return m;"
,
file
=
stream
)
else
:
else
:
print
((
"PyMODINIT_FUNC init
%
s(void){"
%
print
((
"PyMODINIT_FUNC init
%
s(void){"
%
self
.
hash_placeholder
),
file
=
stream
)
self
.
hash_placeholder
),
file
=
stream
)
for
block
in
self
.
init_blocks
:
for
block
in
self
.
init_blocks
:
print
(
' '
,
block
,
file
=
stream
)
print
(
' '
,
block
,
file
=
stream
)
print
(
' '
,
(
'(void) Py_InitModule("
%
s", MyMethods);'
print
(
' '
,
(
'(void) Py_InitModule("
%
s", MyMethods);'
%
self
.
hash_placeholder
),
file
=
stream
)
%
self
.
hash_placeholder
),
file
=
stream
)
print
(
"}"
,
file
=
stream
)
print
(
"}"
,
file
=
stream
)
def
add_include
(
self
,
str
):
def
add_include
(
self
,
str
):
...
@@ -351,9 +355,9 @@ def is_same_entry(entry_1, entry_2):
...
@@ -351,9 +355,9 @@ def is_same_entry(entry_1, entry_2):
if
os
.
path
.
realpath
(
entry_1
)
==
os
.
path
.
realpath
(
entry_2
):
if
os
.
path
.
realpath
(
entry_1
)
==
os
.
path
.
realpath
(
entry_2
):
return
True
return
True
if
(
os
.
path
.
basename
(
entry_1
)
==
os
.
path
.
basename
(
entry_2
)
and
if
(
os
.
path
.
basename
(
entry_1
)
==
os
.
path
.
basename
(
entry_2
)
and
(
os
.
path
.
basename
(
os
.
path
.
dirname
(
entry_1
))
==
(
os
.
path
.
basename
(
os
.
path
.
dirname
(
entry_1
))
==
os
.
path
.
basename
(
os
.
path
.
dirname
(
entry_2
)))
and
os
.
path
.
basename
(
os
.
path
.
dirname
(
entry_2
)))
and
os
.
path
.
basename
(
os
.
path
.
dirname
(
entry_1
))
.
startswith
(
'tmp'
)):
os
.
path
.
basename
(
os
.
path
.
dirname
(
entry_1
))
.
startswith
(
'tmp'
)):
return
True
return
True
return
False
return
False
...
@@ -429,8 +433,8 @@ def get_safe_part(key):
...
@@ -429,8 +433,8 @@ def get_safe_part(key):
# Find the md5 hash part.
# Find the md5 hash part.
c_link_key
=
key
[
1
]
c_link_key
=
key
[
1
]
for
key_element
in
c_link_key
[
1
:]:
for
key_element
in
c_link_key
[
1
:]:
if
(
isinstance
(
key_element
,
string_types
)
if
(
isinstance
(
key_element
,
string_types
)
and
and
key_element
.
startswith
(
'md5:'
)):
key_element
.
startswith
(
'md5:'
)):
md5
=
key_element
[
4
:]
md5
=
key_element
[
4
:]
break
break
...
@@ -761,9 +765,9 @@ class ModuleCache(object):
...
@@ -761,9 +765,9 @@ class ModuleCache(object):
# simpler to implement).
# simpler to implement).
rmtree
(
root
,
ignore_nocleanup
=
True
,
rmtree
(
root
,
ignore_nocleanup
=
True
,
msg
=
(
msg
=
(
'invalid cache entry format -- this '
'invalid cache entry format -- this '
'should not happen unless your cache '
'should not happen unless your cache '
'was really old'
),
'was really old'
),
level
=
logging
.
WARN
)
level
=
logging
.
WARN
)
continue
continue
...
@@ -964,7 +968,7 @@ class ModuleCache(object):
...
@@ -964,7 +968,7 @@ class ModuleCache(object):
# process that could be changing the file at the same
# process that could be changing the file at the same
# time.
# time.
if
(
key
[
0
]
and
not
key_broken
and
if
(
key
[
0
]
and
not
key_broken
and
self
.
check_for_broken_eq
):
self
.
check_for_broken_eq
):
self
.
check_key
(
key
,
key_data
.
key_pkl
)
self
.
check_key
(
key
,
key_data
.
key_pkl
)
self
.
_update_mappings
(
key
,
key_data
,
module
.
__file__
,
check_in_keys
=
not
key_broken
)
self
.
_update_mappings
(
key
,
key_data
,
module
.
__file__
,
check_in_keys
=
not
key_broken
)
return
module
return
module
...
@@ -1149,15 +1153,14 @@ class ModuleCache(object):
...
@@ -1149,15 +1153,14 @@ class ModuleCache(object):
# This is to make debugging in pdb easier, by providing
# This is to make debugging in pdb easier, by providing
# the offending keys in the local context.
# the offending keys in the local context.
# key_data_keys = list(key_data.keys)
# key_data_keys = list(key_data.keys)
#
#
import pdb; pdb.set_trace()
# import pdb; pdb.set_trace()
pass
pass
elif
found
>
1
:
elif
found
>
1
:
msg
=
'Multiple equal keys found in unpickled KeyData file'
msg
=
'Multiple equal keys found in unpickled KeyData file'
if
msg
:
if
msg
:
raise
AssertionError
(
raise
AssertionError
(
"
%
s. Verify the __eq__ and __hash__ functions of your "
"
%
s. Verify the __eq__ and __hash__ functions of your "
"Ops. The file is:
%
s. The key is:
%
s"
%
"Ops. The file is:
%
s. The key is:
%
s"
%
(
msg
,
key_pkl
,
key
))
(
msg
,
key_pkl
,
key
))
# Also verify that there exists no other loaded key that would be equal
# Also verify that there exists no other loaded key that would be equal
# to this key. In order to speed things up, we only compare to keys
# to this key. In order to speed things up, we only compare to keys
# with the same version part and config md5, since we can assume this
# with the same version part and config md5, since we can assume this
...
@@ -1195,10 +1198,10 @@ class ModuleCache(object):
...
@@ -1195,10 +1198,10 @@ class ModuleCache(object):
if
age_thresh_del
<
self
.
age_thresh_use
:
if
age_thresh_del
<
self
.
age_thresh_use
:
if
age_thresh_del
>
0
:
if
age_thresh_del
>
0
:
_logger
.
warning
(
"Clearing modules that were not deemed "
_logger
.
warning
(
"Clearing modules that were not deemed "
"too old to use: age_thresh_del=
%
d, "
"too old to use: age_thresh_del=
%
d, "
"self.age_thresh_use=
%
d"
,
"self.age_thresh_use=
%
d"
,
age_thresh_del
,
age_thresh_del
,
self
.
age_thresh_use
)
self
.
age_thresh_use
)
else
:
else
:
_logger
.
info
(
"Clearing all modules."
)
_logger
.
info
(
"Clearing all modules."
)
age_thresh_use
=
age_thresh_del
age_thresh_use
=
age_thresh_del
...
@@ -1210,8 +1213,8 @@ class ModuleCache(object):
...
@@ -1210,8 +1213,8 @@ class ModuleCache(object):
# processes and get all module that are too old to use
# processes and get all module that are too old to use
# (not loaded in self.entry_from_key).
# (not loaded in self.entry_from_key).
too_old_to_use
=
self
.
refresh
(
too_old_to_use
=
self
.
refresh
(
age_thresh_use
=
age_thresh_use
,
age_thresh_use
=
age_thresh_use
,
delete_if_problem
=
delete_if_problem
)
delete_if_problem
=
delete_if_problem
)
for
entry
in
too_old_to_use
:
for
entry
in
too_old_to_use
:
# TODO: we are assuming that modules that haven't been
# TODO: we are assuming that modules that haven't been
...
@@ -1242,8 +1245,8 @@ class ModuleCache(object):
...
@@ -1242,8 +1245,8 @@ class ModuleCache(object):
"""
"""
with
compilelock
.
lock_ctx
():
with
compilelock
.
lock_ctx
():
self
.
clear_old
(
self
.
clear_old
(
age_thresh_del
=-
1.0
,
age_thresh_del
=-
1.0
,
delete_if_problem
=
delete_if_problem
)
delete_if_problem
=
delete_if_problem
)
self
.
clear_unversioned
(
min_age
=
unversioned_min_age
)
self
.
clear_unversioned
(
min_age
=
unversioned_min_age
)
if
clear_base_files
:
if
clear_base_files
:
self
.
clear_base_files
()
self
.
clear_base_files
()
...
@@ -1333,7 +1336,7 @@ class ModuleCache(object):
...
@@ -1333,7 +1336,7 @@ class ModuleCache(object):
if
filename
.
startswith
(
'tmp'
):
if
filename
.
startswith
(
'tmp'
):
try
:
try
:
open
(
os
.
path
.
join
(
self
.
dirname
,
filename
,
'key.pkl'
)
open
(
os
.
path
.
join
(
self
.
dirname
,
filename
,
'key.pkl'
)
)
.
close
()
)
.
close
()
has_key
=
True
has_key
=
True
except
IOError
:
except
IOError
:
has_key
=
False
has_key
=
False
...
@@ -1420,8 +1423,8 @@ def get_module_cache(dirname, init_args=None):
...
@@ -1420,8 +1423,8 @@ def get_module_cache(dirname, init_args=None):
'was created prior to this call'
)
'was created prior to this call'
)
if
_module_cache
.
dirname
!=
dirname
:
if
_module_cache
.
dirname
!=
dirname
:
_logger
.
warning
(
"Returning module cache instance with different "
_logger
.
warning
(
"Returning module cache instance with different "
"dirname (
%
s) than you requested (
%
s)"
,
"dirname (
%
s) than you requested (
%
s)"
,
_module_cache
.
dirname
,
dirname
)
_module_cache
.
dirname
,
dirname
)
return
_module_cache
return
_module_cache
...
@@ -1685,7 +1688,7 @@ class GCC_compiler(Compiler):
...
@@ -1685,7 +1688,7 @@ class GCC_compiler(Compiler):
break
break
if
(
'g++'
not
in
theano
.
config
.
cxx
and
if
(
'g++'
not
in
theano
.
config
.
cxx
and
'clang++'
not
in
theano
.
config
.
cxx
):
'clang++'
not
in
theano
.
config
.
cxx
):
_logger
.
warn
(
_logger
.
warn
(
"OPTIMIZATION WARNING: your Theano flag `cxx` seems not to be"
"OPTIMIZATION WARNING: your Theano flag `cxx` seems not to be"
" the g++ compiler. So we disable the compiler optimization"
" the g++ compiler. So we disable the compiler optimization"
...
@@ -1719,9 +1722,9 @@ class GCC_compiler(Compiler):
...
@@ -1719,9 +1722,9 @@ class GCC_compiler(Compiler):
selected_lines
=
[]
selected_lines
=
[]
for
line
in
lines
:
for
line
in
lines
:
if
(
"COLLECT_GCC_OPTIONS="
in
line
or
if
(
"COLLECT_GCC_OPTIONS="
in
line
or
"CFLAGS="
in
line
or
"CFLAGS="
in
line
or
"CXXFLAGS="
in
line
or
"CXXFLAGS="
in
line
or
"-march=native"
in
line
):
"-march=native"
in
line
):
continue
continue
elif
"-march="
in
line
:
elif
"-march="
in
line
:
selected_lines
.
append
(
line
.
strip
())
selected_lines
.
append
(
line
.
strip
())
...
@@ -1805,9 +1808,9 @@ class GCC_compiler(Compiler):
...
@@ -1805,9 +1808,9 @@ class GCC_compiler(Compiler):
for
line
in
default_lines
:
for
line
in
default_lines
:
if
line
.
startswith
(
part
[
0
]):
if
line
.
startswith
(
part
[
0
]):
part2
=
[
p
for
p
in
join_options
(
line
.
split
())
part2
=
[
p
for
p
in
join_options
(
line
.
split
())
if
(
not
'march'
in
p
and
if
(
'march'
not
in
p
and
not
'mtune'
in
p
and
'mtune'
not
in
p
and
not
'target-cpu'
in
p
)]
'target-cpu'
not
in
p
)]
new_flags
=
[
p
for
p
in
part
if
p
not
in
part2
]
new_flags
=
[
p
for
p
in
part
if
p
not
in
part2
]
# Replace '-target-cpu value', which is an option
# Replace '-target-cpu value', which is an option
# of clang, with '-march=value', for g++
# of clang, with '-march=value', for g++
...
@@ -2021,14 +2024,13 @@ class GCC_compiler(Compiler):
...
@@ -2021,14 +2024,13 @@ class GCC_compiler(Compiler):
cmd
.
append
(
cppfilename
)
cmd
.
append
(
cppfilename
)
cmd
.
extend
([
'-L
%
s'
%
ldir
for
ldir
in
lib_dirs
])
cmd
.
extend
([
'-L
%
s'
%
ldir
for
ldir
in
lib_dirs
])
cmd
.
extend
([
'-l
%
s'
%
l
for
l
in
libs
])
cmd
.
extend
([
'-l
%
s'
%
l
for
l
in
libs
])
#print >> sys.stderr, 'COMPILING W CMD', cmd
#
print >> sys.stderr, 'COMPILING W CMD', cmd
_logger
.
debug
(
'Running cmd:
%
s'
,
' '
.
join
(
cmd
))
_logger
.
debug
(
'Running cmd:
%
s'
,
' '
.
join
(
cmd
))
def
print_command_line_error
():
def
print_command_line_error
():
# Print command line when a problem occurred.
# Print command line when a problem occurred.
print
((
print
((
"Problem occurred during compilation with the "
"Problem occurred during compilation with the "
"command line below:"
),
file
=
sys
.
stderr
)
"command line below:"
),
file
=
sys
.
stderr
)
print
(
' '
.
join
(
cmd
),
file
=
sys
.
stderr
)
print
(
' '
.
join
(
cmd
),
file
=
sys
.
stderr
)
try
:
try
:
...
...
theano/gof/destroyhandler.py
浏览文件 @
cba9c812
...
@@ -46,7 +46,6 @@ def _contains_cycle(fgraph, orderings):
...
@@ -46,7 +46,6 @@ def _contains_cycle(fgraph, orderings):
"""
"""
# These are lists of Variable instances
# These are lists of Variable instances
inputs
=
fgraph
.
inputs
outputs
=
fgraph
.
outputs
outputs
=
fgraph
.
outputs
# this is hard-coded reimplementation of functions from graph.py
# this is hard-coded reimplementation of functions from graph.py
...
@@ -65,8 +64,6 @@ def _contains_cycle(fgraph, orderings):
...
@@ -65,8 +64,6 @@ def _contains_cycle(fgraph, orderings):
# (defaultdict runs faster than dict in the case where the key
# (defaultdict runs faster than dict in the case where the key
# is not in the dictionary, at least in CPython)
# is not in the dictionary, at least in CPython)
iset
=
set
(
inputs
)
# IG: I tried converting parent_counts to use an id for the key,
# IG: I tried converting parent_counts to use an id for the key,
# so that the dict would do reference counting on its keys.
# so that the dict would do reference counting on its keys.
# This caused a slowdown.
# This caused a slowdown.
...
@@ -236,9 +233,9 @@ def fast_inplace_check(inputs):
...
@@ -236,9 +233,9 @@ def fast_inplace_check(inputs):
protected_inputs
.
extend
(
fgraph
.
outputs
)
protected_inputs
.
extend
(
fgraph
.
outputs
)
inputs
=
[
i
for
i
in
inputs
if
inputs
=
[
i
for
i
in
inputs
if
not
isinstance
(
i
,
graph
.
Constant
)
not
isinstance
(
i
,
graph
.
Constant
)
and
and
not
fgraph
.
destroyers
(
i
)
not
fgraph
.
destroyers
(
i
)
and
and
i
not
in
protected_inputs
]
i
not
in
protected_inputs
]
return
inputs
return
inputs
if
0
:
if
0
:
...
@@ -293,7 +290,7 @@ if 0:
...
@@ -293,7 +290,7 @@ if 0:
TODO: WRITEME: what does this do besides the checks?
TODO: WRITEME: what does this do besides the checks?
"""
"""
#
###### Do the checking ##########
#
#
Do the checking
#
already_there
=
False
already_there
=
False
if
self
.
fgraph
not
in
[
None
,
fgraph
]:
if
self
.
fgraph
not
in
[
None
,
fgraph
]:
raise
Exception
(
"A DestroyHandler instance can only serve"
raise
Exception
(
"A DestroyHandler instance can only serve"
...
@@ -309,7 +306,7 @@ if 0:
...
@@ -309,7 +306,7 @@ if 0:
"DestroyHandler feature is already present or in"
"DestroyHandler feature is already present or in"
" conflict with another plugin."
)
" conflict with another plugin."
)
#
###### end of checking ###########
#
#
end of checking
#
def
get_destroyers_of
(
r
):
def
get_destroyers_of
(
r
):
droot
,
impact
,
root_destroyer
=
self
.
refresh_droot_impact
()
droot
,
impact
,
root_destroyer
=
self
.
refresh_droot_impact
()
...
@@ -362,8 +359,8 @@ if 0:
...
@@ -362,8 +359,8 @@ if 0:
"Multiple destroyers of
%
s"
%
input_root
)
"Multiple destroyers of
%
s"
%
input_root
)
droot
[
input_root
]
=
input_root
droot
[
input_root
]
=
input_root
root_destroyer
[
input_root
]
=
app
root_destroyer
[
input_root
]
=
app
#input_impact = set([input_root])
#
input_impact = set([input_root])
#add_impact(input_root, self.view_o, input_impact)
#
add_impact(input_root, self.view_o, input_impact)
input_impact
=
get_impact
(
input_root
,
self
.
view_o
)
input_impact
=
get_impact
(
input_root
,
self
.
view_o
)
for
v
in
input_impact
:
for
v
in
input_impact
:
assert
v
not
in
droot
assert
v
not
in
droot
...
@@ -390,7 +387,7 @@ if 0:
...
@@ -390,7 +387,7 @@ if 0:
def
on_import
(
self
,
fgraph
,
app
,
reason
):
def
on_import
(
self
,
fgraph
,
app
,
reason
):
"""Add Apply instance to set which must be computed"""
"""Add Apply instance to set which must be computed"""
#if app in self.debug_all_apps: raise ProtocolError("double import")
#
if app in self.debug_all_apps: raise ProtocolError("double import")
# self.debug_all_apps.add(app)
# self.debug_all_apps.add(app)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
...
@@ -421,7 +418,7 @@ if 0:
...
@@ -421,7 +418,7 @@ if 0:
def
on_prune
(
self
,
fgraph
,
app
,
reason
):
def
on_prune
(
self
,
fgraph
,
app
,
reason
):
"""Remove Apply instance from set which must be computed"""
"""Remove Apply instance from set which must be computed"""
#if app not in self.debug_all_apps: raise ProtocolError("prune without import")
#
if app not in self.debug_all_apps: raise ProtocolError("prune without import")
# self.debug_all_apps.remove(app)
# self.debug_all_apps.remove(app)
# UPDATE self.clients
# UPDATE self.clients
...
@@ -458,7 +455,7 @@ if 0:
...
@@ -458,7 +455,7 @@ if 0:
# considered 'outputs' of the graph.
# considered 'outputs' of the graph.
pass
pass
else
:
else
:
#if app not in self.debug_all_apps: raise ProtocolError("change without import")
#
if app not in self.debug_all_apps: raise ProtocolError("change without import")
# UPDATE self.clients
# UPDATE self.clients
self
.
clients
[
old_r
][
app
]
-=
1
self
.
clients
[
old_r
][
app
]
-=
1
...
@@ -529,9 +526,10 @@ if 0:
...
@@ -529,9 +526,10 @@ if 0:
droot
,
impact
,
__ignore
=
self
.
refresh_droot_impact
()
droot
,
impact
,
__ignore
=
self
.
refresh_droot_impact
()
# check for destruction of constants
# check for destruction of constants
illegal_destroy
=
[
r
for
r
in
droot
if
illegal_destroy
=
[
getattr
(
r
.
tag
,
'indestructible'
,
False
)
or
r
for
r
in
droot
if
isinstance
(
r
,
graph
.
Constant
)]
getattr
(
r
.
tag
,
'indestructible'
,
False
)
or
isinstance
(
r
,
graph
.
Constant
)]
if
illegal_destroy
:
if
illegal_destroy
:
# print 'destroying illegally'
# print 'destroying illegally'
raise
InconsistencyError
(
raise
InconsistencyError
(
...
@@ -603,7 +601,7 @@ if 0:
...
@@ -603,7 +601,7 @@ if 0:
if
input
in
root_impact
\
if
input
in
root_impact
\
and
(
i
not
in
tolerated
or
input
is
not
destroyed_variable
):
and
(
i
not
in
tolerated
or
input
is
not
destroyed_variable
):
raise
InconsistencyError
(
"Input aliasing:
%
s (
%
i,
%
i)"
raise
InconsistencyError
(
"Input aliasing:
%
s (
%
i,
%
i)"
%
(
app
,
destroyed_idx
,
i
))
%
(
app
,
destroyed_idx
,
i
))
# add the rule: app must be preceded by all other Apply instances that
# add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input
# depend on destroyed_input
...
@@ -621,7 +619,7 @@ if 0:
...
@@ -621,7 +619,7 @@ if 0:
return
rval
return
rval
class
DestroyHandler
(
toolbox
.
Bookkeeper
):
class
DestroyHandler
(
toolbox
.
Bookkeeper
):
# noqa
"""
"""
The DestroyHandler class detects when a graph is impossible to evaluate
The DestroyHandler class detects when a graph is impossible to evaluate
because of aliasing and destructive operations.
because of aliasing and destructive operations.
...
@@ -702,7 +700,7 @@ class DestroyHandler(toolbox.Bookkeeper):
...
@@ -702,7 +700,7 @@ class DestroyHandler(toolbox.Bookkeeper):
TODO: WRITEME: what does this do besides the checks?
TODO: WRITEME: what does this do besides the checks?
"""
"""
#
###### Do the checking ##########
#
#
Do the checking
#
already_there
=
False
already_there
=
False
if
self
.
fgraph
is
fgraph
:
if
self
.
fgraph
is
fgraph
:
already_there
=
True
already_there
=
True
...
@@ -720,7 +718,7 @@ class DestroyHandler(toolbox.Bookkeeper):
...
@@ -720,7 +718,7 @@ class DestroyHandler(toolbox.Bookkeeper):
"DestroyHandler feature is already present"
"DestroyHandler feature is already present"
" or in conflict with another plugin."
)
" or in conflict with another plugin."
)
#
###### Annotate the FunctionGraph ###########
#
#
Annotate the FunctionGraph
#
self
.
unpickle
(
fgraph
)
self
.
unpickle
(
fgraph
)
fgraph
.
destroy_handler
=
self
fgraph
.
destroy_handler
=
self
...
@@ -945,12 +943,13 @@ class DestroyHandler(toolbox.Bookkeeper):
...
@@ -945,12 +943,13 @@ class DestroyHandler(toolbox.Bookkeeper):
droot
,
impact
,
__ignore
=
self
.
refresh_droot_impact
()
droot
,
impact
,
__ignore
=
self
.
refresh_droot_impact
()
# check for destruction of constants
# check for destruction of constants
illegal_destroy
=
[
r
for
r
in
droot
if
\
illegal_destroy
=
[
r
for
r
in
droot
if
getattr
(
r
.
tag
,
'indestructible'
,
False
)
or
\
getattr
(
r
.
tag
,
'indestructible'
,
False
)
or
isinstance
(
r
,
graph
.
Constant
)]
isinstance
(
r
,
graph
.
Constant
)]
if
illegal_destroy
:
if
illegal_destroy
:
raise
InconsistencyError
(
"Attempting to destroy indestructible variables:
%
s"
%
raise
InconsistencyError
(
illegal_destroy
)
"Attempting to destroy indestructible variables:
%
s"
%
illegal_destroy
)
# add destroyed variable clients as computational dependencies
# add destroyed variable clients as computational dependencies
for
app
in
self
.
destroyers
:
for
app
in
self
.
destroyers
:
...
@@ -995,24 +994,27 @@ class DestroyHandler(toolbox.Bookkeeper):
...
@@ -995,24 +994,27 @@ class DestroyHandler(toolbox.Bookkeeper):
# CHECK FOR INPUT ALIASING
# CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import
# OPT: pre-compute this on import
tolerate_same
=
getattr
(
app
.
op
,
'destroyhandler_tolerate_same'
,
[])
tolerate_same
=
getattr
(
app
.
op
,
'destroyhandler_tolerate_same'
,
[])
assert
isinstance
(
tolerate_same
,
list
)
assert
isinstance
(
tolerate_same
,
list
)
tolerated
=
OrderedSet
(
idx1
for
idx0
,
idx1
in
tolerate_same
tolerated
=
OrderedSet
(
idx1
for
idx0
,
idx1
in
tolerate_same
if
idx0
==
destroyed_idx
)
if
idx0
==
destroyed_idx
)
tolerated
.
add
(
destroyed_idx
)
tolerated
.
add
(
destroyed_idx
)
tolerate_aliased
=
getattr
(
app
.
op
,
'destroyhandler_tolerate_aliased'
,
[])
tolerate_aliased
=
getattr
(
app
.
op
,
'destroyhandler_tolerate_aliased'
,
[])
assert
isinstance
(
tolerate_aliased
,
list
)
assert
isinstance
(
tolerate_aliased
,
list
)
ignored
=
OrderedSet
(
idx1
for
idx0
,
idx1
in
tolerate_aliased
ignored
=
OrderedSet
(
idx1
for
idx0
,
idx1
in
tolerate_aliased
if
idx0
==
destroyed_idx
)
if
idx0
==
destroyed_idx
)
# print 'tolerated', tolerated
# print 'tolerated', tolerated
# print 'ignored', ignored
# print 'ignored', ignored
for
i
,
input
in
enumerate
(
app
.
inputs
):
for
i
,
input
in
enumerate
(
app
.
inputs
):
if
i
in
ignored
:
if
i
in
ignored
:
continue
continue
if
input
in
root_impact
\
if
input
in
root_impact
\
and
(
i
not
in
tolerated
or
input
is
not
destroyed_variable
):
and
(
i
not
in
tolerated
or
input
is
not
destroyed_variable
):
raise
InconsistencyError
(
"Input aliasing:
%
s (
%
i,
%
i)"
raise
InconsistencyError
(
"Input aliasing:
%
s (
%
i,
%
i)"
%
(
app
,
destroyed_idx
,
i
))
%
(
app
,
destroyed_idx
,
i
))
# add the rule: app must be preceded by all other Apply instances that
# add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input
# depend on destroyed_input
...
...
theano/gof/fg.py
浏览文件 @
cba9c812
...
@@ -13,7 +13,6 @@ from theano.gof import graph
...
@@ -13,7 +13,6 @@ from theano.gof import graph
from
theano.gof
import
utils
from
theano.gof
import
utils
from
theano.gof
import
toolbox
from
theano.gof
import
toolbox
from
theano
import
config
from
theano
import
config
import
warnings
from
theano.compat
import
OrderedDict
from
theano.compat
import
OrderedDict
from
six
import
iteritems
,
itervalues
from
six
import
iteritems
,
itervalues
...
@@ -22,6 +21,7 @@ from theano.misc.ordered_set import OrderedSet
...
@@ -22,6 +21,7 @@ from theano.misc.ordered_set import OrderedSet
NullType
=
None
NullType
=
None
class
CachedConstantError
(
Exception
):
class
CachedConstantError
(
Exception
):
"""An exception thrown when we put in a FunctionGraph a Constant
"""An exception thrown when we put in a FunctionGraph a Constant
that is cached. This should not happen as the user can reuse this
that is cached. This should not happen as the user can reuse this
...
@@ -143,7 +143,7 @@ class FunctionGraph(utils.object2):
...
@@ -143,7 +143,7 @@ class FunctionGraph(utils.object2):
self
.
variable_locks
=
{}
self
.
variable_locks
=
{}
self
.
profile
=
None
self
.
profile
=
None
#
## Setup a Variable ##
#
#
Setup a Variable
#
def
__setup_r__
(
self
,
r
):
def
__setup_r__
(
self
,
r
):
# sets up r so it belongs to this fgraph
# sets up r so it belongs to this fgraph
if
getattr
(
r
,
'cached'
,
False
):
if
getattr
(
r
,
'cached'
,
False
):
...
@@ -152,12 +152,12 @@ class FunctionGraph(utils.object2):
...
@@ -152,12 +152,12 @@ class FunctionGraph(utils.object2):
" graph that has a cached constant. This should not happen."
" graph that has a cached constant. This should not happen."
" Clone the graph before building the FunctionGraph."
)
" Clone the graph before building the FunctionGraph."
)
if
(
hasattr
(
r
,
'fgraph'
)
and
if
(
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
None
and
r
.
fgraph
is
not
None
and
r
.
fgraph
is
not
self
):
r
.
fgraph
is
not
self
):
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
r
.
fgraph
=
self
r
.
fgraph
=
self
r
.
clients
=
[]
r
.
clients
=
[]
#self.execute_callbacks('on_setup_variable', r)
#
self.execute_callbacks('on_setup_variable', r)
def
__setup_node__
(
self
,
node
):
def
__setup_node__
(
self
,
node
):
# sets up node so it belongs to this fgraph
# sets up node so it belongs to this fgraph
...
@@ -177,7 +177,7 @@ class FunctionGraph(utils.object2):
...
@@ -177,7 +177,7 @@ class FunctionGraph(utils.object2):
str
(
node
.
op
),
str
(
node
.
op
.
destroy_map
)))
str
(
node
.
op
),
str
(
node
.
op
.
destroy_map
)))
node
.
fgraph
=
self
node
.
fgraph
=
self
node
.
deps
=
{}
node
.
deps
=
{}
#self.execute_callbacks('on_setup_node', node)
#
self.execute_callbacks('on_setup_node', node)
def
disown
(
self
):
def
disown
(
self
):
""" WRITEME
""" WRITEME
...
@@ -201,7 +201,7 @@ class FunctionGraph(utils.object2):
...
@@ -201,7 +201,7 @@ class FunctionGraph(utils.object2):
self
.
inputs
=
None
self
.
inputs
=
None
self
.
outputs
=
None
self
.
outputs
=
None
#
## clients ##
#
#
clients
#
def
clients
(
self
,
r
):
def
clients
(
self
,
r
):
"""
"""
Set of all the (node, i) pairs such that node.inputs[i] is r.
Set of all the (node, i) pairs such that node.inputs[i] is r.
...
@@ -221,9 +221,9 @@ class FunctionGraph(utils.object2):
...
@@ -221,9 +221,9 @@ class FunctionGraph(utils.object2):
if
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
)):
if
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
)):
print
(
'ERROR: clients intersect!'
,
file
=
sys
.
stderr
)
print
(
'ERROR: clients intersect!'
,
file
=
sys
.
stderr
)
print
(
' RCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
print
(
' RCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
r
.
clients
],
file
=
sys
.
stderr
)
for
n
,
i
in
r
.
clients
],
file
=
sys
.
stderr
)
print
(
' NCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
print
(
' NCLIENTS of'
,
r
,
[(
n
,
i
,
type
(
n
),
id
(
n
))
for
n
,
i
in
new_clients
],
file
=
sys
.
stderr
)
for
n
,
i
in
new_clients
],
file
=
sys
.
stderr
)
assert
not
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
))
assert
not
set
(
r
.
clients
)
.
intersection
(
set
(
new_clients
))
r
.
clients
+=
new_clients
r
.
clients
+=
new_clients
...
@@ -245,7 +245,7 @@ class FunctionGraph(utils.object2):
...
@@ -245,7 +245,7 @@ class FunctionGraph(utils.object2):
return
True
return
True
return
False
return
False
#
## import ##
#
#
import
#
def
__import_r__
(
self
,
variable
,
reason
):
def
__import_r__
(
self
,
variable
,
reason
):
global
NullType
global
NullType
if
NullType
is
None
:
if
NullType
is
None
:
...
@@ -279,9 +279,8 @@ class FunctionGraph(utils.object2):
...
@@ -279,9 +279,8 @@ class FunctionGraph(utils.object2):
if
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
self
:
if
hasattr
(
r
,
'fgraph'
)
and
r
.
fgraph
is
not
self
:
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
raise
Exception
(
"
%
s is already owned by another fgraph"
%
r
)
if
(
r
.
owner
is
None
and
if
(
r
.
owner
is
None
and
not
isinstance
(
r
,
graph
.
Constant
)
and
not
isinstance
(
r
,
graph
.
Constant
)
and
r
not
in
self
.
inputs
):
r
not
in
self
.
inputs
):
# Verbose error message
# Verbose error message
# Show a complete chain of variables from the missing input to an output
# Show a complete chain of variables from the missing input to an output
if
config
.
exception_verbosity
==
'high'
:
if
config
.
exception_verbosity
==
'high'
:
...
@@ -373,7 +372,7 @@ class FunctionGraph(utils.object2):
...
@@ -373,7 +372,7 @@ class FunctionGraph(utils.object2):
assert
node
.
fgraph
is
self
assert
node
.
fgraph
is
self
self
.
execute_callbacks
(
'on_import'
,
node
,
reason
)
self
.
execute_callbacks
(
'on_import'
,
node
,
reason
)
#
## prune ##
#
#
prune
#
def
__prune_r__
(
self
,
variable
,
reason
=
None
):
def
__prune_r__
(
self
,
variable
,
reason
=
None
):
"""Should be called for variable that aren't used anymore:
"""Should be called for variable that aren't used anymore:
len(var.clients) == 0
len(var.clients) == 0
...
@@ -430,7 +429,7 @@ class FunctionGraph(utils.object2):
...
@@ -430,7 +429,7 @@ class FunctionGraph(utils.object2):
self
.
__remove_clients__
(
input
,
[(
apply_node
,
i
)],
reason
=
reason
)
self
.
__remove_clients__
(
input
,
[(
apply_node
,
i
)],
reason
=
reason
)
# self.__prune_r__(apply_node.inputs)
# self.__prune_r__(apply_node.inputs)
#
## change input ##
#
#
change input
#
def
change_input
(
self
,
node
,
i
,
new_r
,
reason
=
None
):
def
change_input
(
self
,
node
,
i
,
new_r
,
reason
=
None
):
"""WRITEME
"""WRITEME
Changes node.inputs[i] to new_r.
Changes node.inputs[i] to new_r.
...
@@ -475,7 +474,7 @@ class FunctionGraph(utils.object2):
...
@@ -475,7 +474,7 @@ class FunctionGraph(utils.object2):
if
prune
:
if
prune
:
self
.
__prune_r__
(
r
,
reason
=
reason
)
self
.
__prune_r__
(
r
,
reason
=
reason
)
#
## replace ##
#
#
replace
#
def
replace
(
self
,
r
,
new_r
,
reason
=
None
,
verbose
=
None
):
def
replace
(
self
,
r
,
new_r
,
reason
=
None
,
verbose
=
None
):
""" WRITEME
""" WRITEME
This is the main interface to manipulate the subgraph in FunctionGraph.
This is the main interface to manipulate the subgraph in FunctionGraph.
...
@@ -582,7 +581,7 @@ class FunctionGraph(utils.object2):
...
@@ -582,7 +581,7 @@ class FunctionGraph(utils.object2):
if
detach
is
not
None
:
if
detach
is
not
None
:
detach
(
self
)
detach
(
self
)
#
## callback utils ##
#
#
callback utils
#
def
execute_callbacks
(
self
,
name
,
*
args
,
**
kwargs
):
def
execute_callbacks
(
self
,
name
,
*
args
,
**
kwargs
):
"""WRITEME
"""WRITEME
Calls
Calls
...
@@ -618,7 +617,7 @@ class FunctionGraph(utils.object2):
...
@@ -618,7 +617,7 @@ class FunctionGraph(utils.object2):
d
[
feature
]
=
fn
(
*
args
)
d
[
feature
]
=
fn
(
*
args
)
return
d
return
d
#
## misc ##
#
#
misc
#
def
toposort
(
self
):
def
toposort
(
self
):
"""WRITEME
"""WRITEME
Returns an ordering of the graph's Apply nodes such that:
Returns an ordering of the graph's Apply nodes such that:
...
@@ -712,8 +711,8 @@ class FunctionGraph(utils.object2):
...
@@ -712,8 +711,8 @@ class FunctionGraph(utils.object2):
missing
,
excess
)
missing
,
excess
)
for
variable
in
variables
:
for
variable
in
variables
:
if
(
variable
.
owner
is
None
and
if
(
variable
.
owner
is
None
and
variable
not
in
self
.
inputs
and
variable
not
in
self
.
inputs
and
not
isinstance
(
variable
,
graph
.
Constant
)):
not
isinstance
(
variable
,
graph
.
Constant
)):
raise
Exception
(
"Undeclared input."
,
variable
)
raise
Exception
(
"Undeclared input."
,
variable
)
if
variable
.
fgraph
is
not
self
:
if
variable
.
fgraph
is
not
self
:
raise
Exception
(
"Variable should belong to the FunctionGraph."
,
raise
Exception
(
"Variable should belong to the FunctionGraph."
,
...
@@ -737,7 +736,7 @@ class FunctionGraph(utils.object2):
...
@@ -737,7 +736,7 @@ class FunctionGraph(utils.object2):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__str__
()
return
self
.
__str__
()
#
## clone ##
#
#
clone
#
def
clone
(
self
,
check_integrity
=
True
):
def
clone
(
self
,
check_integrity
=
True
):
"""WRITEME"""
"""WRITEME"""
return
self
.
clone_get_equiv
(
check_integrity
)[
0
]
return
self
.
clone_get_equiv
(
check_integrity
)[
0
]
...
...
theano/gof/link.py
浏览文件 @
cba9c812
...
@@ -7,14 +7,14 @@ import traceback
...
@@ -7,14 +7,14 @@ import traceback
import
numpy
import
numpy
import
theano
import
theano
from
theano.compat
import
PY3
,
izip
from
theano.compat
import
izip
from
six
import
reraise
from
six
import
reraise
from
six.moves
import
StringIO
from
six.moves
import
StringIO
from
theano.gof
import
utils
from
theano.gof
import
utils
from
theano.gof
import
graph
from
theano.gof
import
graph
from
theano.gof.type
import
Type
from
theano.gof.type
import
Type
from
.utils
import
MethodNotDefined
,
undef
from
.utils
import
undef
__excepthook
=
sys
.
excepthook
__excepthook
=
sys
.
excepthook
...
@@ -281,9 +281,9 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
...
@@ -281,9 +281,9 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
else
:
else
:
detailed_err_msg
+=
"
\n
"
detailed_err_msg
+=
"
\n
"
detailed_err_msg
+=
" TotalSize:
%
s Byte(s)
%.3
f GB
\n
"
%
(
detailed_err_msg
+=
" TotalSize:
%
s Byte(s)
%.3
f GB
\n
"
%
(
total_size
,
total_size
/
1024.
/
1024
/
1024
)
total_size
,
total_size
/
1024.
/
1024
/
1024
)
detailed_err_msg
+=
" TotalSize inputs:
%
s Byte(s)
%.3
f BG
\n
"
%
(
detailed_err_msg
+=
" TotalSize inputs:
%
s Byte(s)
%.3
f BG
\n
"
%
(
total_size_inputs
,
total_size_inputs
/
1024.
/
1024
/
1024
)
total_size_inputs
,
total_size_inputs
/
1024.
/
1024
/
1024
)
else
:
else
:
hints
.
append
(
hints
.
append
(
...
@@ -326,7 +326,7 @@ class Linker(object):
...
@@ -326,7 +326,7 @@ class Linker(object):
raise
utils
.
MethodNotDefined
(
"make_thunk"
,
type
(
self
),
raise
utils
.
MethodNotDefined
(
"make_thunk"
,
type
(
self
),
self
.
__class__
.
__name__
)
self
.
__class__
.
__name__
)
#
# DELETEME #
#
#
DELETEME
#
def
make_function
(
self
,
unpack_single
=
True
,
**
kwargs
):
def
make_function
(
self
,
unpack_single
=
True
,
**
kwargs
):
"""
"""
Returns a function that takes values corresponding to the inputs of the
Returns a function that takes values corresponding to the inputs of the
...
@@ -350,8 +350,8 @@ class Linker(object):
...
@@ -350,8 +350,8 @@ class Linker(object):
def
execute
(
*
args
):
def
execute
(
*
args
):
def
e_arity
(
takes
,
got
):
def
e_arity
(
takes
,
got
):
return
'Function call takes exactly
%
i
%
s (
%
i given)'
\
return
'Function call takes exactly
%
i
%
s (
%
i given)'
%
(
%
(
takes
,
[
'argument'
,
'arguments'
][
takes
>
1
],
got
)
takes
,
[
'argument'
,
'arguments'
][
takes
>
1
],
got
)
if
(
len
(
args
)
!=
len
(
inputs
)):
if
(
len
(
args
)
!=
len
(
inputs
)):
raise
TypeError
(
e_arity
(
len
(
inputs
),
len
(
args
)))
raise
TypeError
(
e_arity
(
len
(
inputs
),
len
(
args
)))
for
arg
,
variable
in
izip
(
args
,
inputs
):
for
arg
,
variable
in
izip
(
args
,
inputs
):
...
@@ -394,7 +394,7 @@ class Container(object):
...
@@ -394,7 +394,7 @@ class Container(object):
"""
"""
if
not
isinstance
(
storage
,
list
)
or
not
len
(
storage
)
>=
1
:
if
not
isinstance
(
storage
,
list
)
or
not
len
(
storage
)
>=
1
:
raise
TypeError
(
"storage must be a list of length at least one"
)
raise
TypeError
(
"storage must be a list of length at least one"
)
#self.r = r
#
self.r = r
if
isinstance
(
r
,
Type
):
if
isinstance
(
r
,
Type
):
self
.
type
=
r
self
.
type
=
r
else
:
else
:
...
@@ -454,12 +454,11 @@ class Container(object):
...
@@ -454,12 +454,11 @@ class Container(object):
deepcopy
(
self
.
strict
,
memo
=
memo
),
deepcopy
(
self
.
strict
,
memo
=
memo
),
deepcopy
(
self
.
allow_downcast
,
memo
=
memo
),
deepcopy
(
self
.
allow_downcast
,
memo
=
memo
),
deepcopy
(
self
.
name
,
memo
=
memo
),
deepcopy
(
self
.
name
,
memo
=
memo
),
)
)
# Work around NumPy deepcopy of ndarray with 0 dimention that
# Work around NumPy deepcopy of ndarray with 0 dimention that
# don't return an ndarray.
# don't return an ndarray.
if
(
r
.
storage
[
0
]
is
not
None
and
if
(
r
.
storage
[
0
]
is
not
None
and
not
self
.
type
.
is_valid_value
(
r
.
storage
[
0
])):
not
self
.
type
.
is_valid_value
(
r
.
storage
[
0
])):
assert
not
data_was_in_memo
assert
not
data_was_in_memo
assert
self
.
type
.
is_valid_value
(
self
.
storage
[
0
])
assert
self
.
type
.
is_valid_value
(
self
.
storage
[
0
])
# This should also work for read only container.
# This should also work for read only container.
...
@@ -672,7 +671,7 @@ class PerformLinker(LocalLinker):
...
@@ -672,7 +671,7 @@ class PerformLinker(LocalLinker):
no_recycling
=
[]
no_recycling
=
[]
if
self
.
fgraph
is
not
None
and
self
.
fgraph
is
not
fgraph
:
if
self
.
fgraph
is
not
None
and
self
.
fgraph
is
not
fgraph
:
return
type
(
self
)(
allow_gc
=
self
.
allow_gc
)
.
accept
(
fgraph
,
no_recycling
)
return
type
(
self
)(
allow_gc
=
self
.
allow_gc
)
.
accept
(
fgraph
,
no_recycling
)
#raise Exception("Cannot accept from a Linker that is already tied to another FunctionGraph.")
#
raise Exception("Cannot accept from a Linker that is already tied to another FunctionGraph.")
self
.
fgraph
=
fgraph
self
.
fgraph
=
fgraph
self
.
no_recycling
=
no_recycling
self
.
no_recycling
=
no_recycling
return
self
return
self
...
@@ -721,9 +720,12 @@ class PerformLinker(LocalLinker):
...
@@ -721,9 +720,12 @@ class PerformLinker(LocalLinker):
for
node
in
order
:
for
node
in
order
:
if
self
.
allow_gc
:
if
self
.
allow_gc
:
post_thunk_old_storage
.
append
([
storage_map
[
input
]
post_thunk_old_storage
.
append
(
for
input
in
node
.
inputs
[
storage_map
[
input
]
if
(
input
in
computed
)
and
(
input
not
in
fgraph
.
outputs
)
and
node
==
last_user
[
input
]])
for
input
in
node
.
inputs
if
(
input
in
computed
)
and
(
input
not
in
fgraph
.
outputs
)
and
(
node
==
last_user
[
input
])])
if
no_recycling
is
True
:
if
no_recycling
is
True
:
# True seems like some special code for *everything*?? -JB
# True seems like some special code for *everything*?? -JB
...
@@ -855,7 +857,7 @@ class WrapLinker(Linker):
...
@@ -855,7 +857,7 @@ class WrapLinker(Linker):
make_all
+=
[
l
.
make_all
(
**
kwargs
)
for
l
in
self
.
linkers
[
1
:]]
make_all
+=
[
l
.
make_all
(
**
kwargs
)
for
l
in
self
.
linkers
[
1
:]]
fns
,
input_lists
,
output_lists
,
thunk_lists
,
order_lists
\
fns
,
input_lists
,
output_lists
,
thunk_lists
,
order_lists
\
=
zip
(
*
make_all
)
=
zip
(
*
make_all
)
order_list0
=
order_lists
[
0
]
order_list0
=
order_lists
[
0
]
for
order_list
in
order_lists
[
1
:]:
for
order_list
in
order_lists
[
1
:]:
...
...
theano/gof/opt.py
浏览文件 @
cba9c812
...
@@ -29,6 +29,7 @@ from . import destroyhandler as dh
...
@@ -29,6 +29,7 @@ from . import destroyhandler as dh
_logger
=
logging
.
getLogger
(
'theano.gof.opt'
)
_logger
=
logging
.
getLogger
(
'theano.gof.opt'
)
_optimizer_idx
=
[
0
]
_optimizer_idx
=
[
0
]
def
_list_of_nodes
(
fgraph
):
def
_list_of_nodes
(
fgraph
):
return
list
(
graph
.
io_toposort
(
fgraph
.
inputs
,
fgraph
.
outputs
))
return
list
(
graph
.
io_toposort
(
fgraph
.
inputs
,
fgraph
.
outputs
))
...
@@ -99,7 +100,7 @@ class Optimizer(object):
...
@@ -99,7 +100,7 @@ class Optimizer(object):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'name'
,
None
)
name
=
getattr
(
self
,
'name'
,
None
)
print
(
"
%
s
%
s
%
s id=
%
i"
%
(
print
(
"
%
s
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
name
,
id
(
self
)),
file
=
stream
)
(
' '
*
level
),
self
.
__class__
.
__name__
,
name
,
id
(
self
)),
file
=
stream
)
@staticmethod
@staticmethod
def
print_profile
(
stream
,
prof
,
level
=
0
):
def
print_profile
(
stream
,
prof
,
level
=
0
):
...
@@ -121,9 +122,9 @@ class FromFunctionOptimizer(Optimizer):
...
@@ -121,9 +122,9 @@ class FromFunctionOptimizer(Optimizer):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
(
"
%
s
%
s id=
%
i"
%
(
print
(
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
' '
*
level
,
str
(
self
.
apply
),
str
(
self
.
apply
),
id
(
self
)),
file
=
stream
)
id
(
self
)),
file
=
stream
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
self
.
fn
(
*
args
,
**
kwargs
)
return
self
.
fn
(
*
args
,
**
kwargs
)
...
@@ -222,7 +223,7 @@ class SeqOptimizer(Optimizer, list):
...
@@ -222,7 +223,7 @@ class SeqOptimizer(Optimizer, list):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'name'
,
None
)
name
=
getattr
(
self
,
'name'
,
None
)
print
(
"
%
s
%
s
%
s id=
%
i"
%
(
print
(
"
%
s
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
name
,
id
(
self
)),
file
=
stream
)
(
' '
*
level
),
self
.
__class__
.
__name__
,
name
,
id
(
self
)),
file
=
stream
)
# This way, -1 will do all depth
# This way, -1 will do all depth
if
depth
!=
0
:
if
depth
!=
0
:
depth
-=
1
depth
-=
1
...
@@ -241,8 +242,8 @@ class SeqOptimizer(Optimizer, list):
...
@@ -241,8 +242,8 @@ class SeqOptimizer(Optimizer, list):
elif
hasattr
(
opts
,
"__name__"
):
elif
hasattr
(
opts
,
"__name__"
):
print
(
blanc
,
opts
.
__name__
,
end
=
' '
,
file
=
stream
)
print
(
blanc
,
opts
.
__name__
,
end
=
' '
,
file
=
stream
)
print
((
" time
%.3
fs for
%
d/
%
d nodes"
print
((
" time
%.3
fs for
%
d/
%
d nodes"
" before/after optimization"
%
(
" before/after optimization"
%
(
sum
(
prof
),
nb_node_before
,
nb_node_after
)),
file
=
stream
)
sum
(
prof
),
nb_node_before
,
nb_node_after
)),
file
=
stream
)
print
(
blanc
,
"
%.3
fs for fgraph.validate()"
%
(
validate_time
),
file
=
stream
)
print
(
blanc
,
"
%.3
fs for fgraph.validate()"
%
(
validate_time
),
file
=
stream
)
print
(
blanc
,
"
%.3
fs for callback"
%
(
callback_time
),
file
=
stream
)
print
(
blanc
,
"
%.3
fs for callback"
%
(
callback_time
),
file
=
stream
)
if
level
==
0
:
if
level
==
0
:
...
@@ -324,7 +325,7 @@ class SeqOptimizer(Optimizer, list):
...
@@ -324,7 +325,7 @@ class SeqOptimizer(Optimizer, list):
new_t
[
idx
]
+=
p
[
1
][
p
[
0
]
.
index
(
l
)]
new_t
[
idx
]
+=
p
[
1
][
p
[
0
]
.
index
(
l
)]
if
hasattr
(
l
,
'merge_profile'
):
if
hasattr
(
l
,
'merge_profile'
):
assert
len
(
p
[
6
][
p
[
0
]
.
index
(
l
)])
==
\
assert
len
(
p
[
6
][
p
[
0
]
.
index
(
l
)])
==
\
len
(
new_sub_profile
[
idx
])
len
(
new_sub_profile
[
idx
])
new_sub_profile
[
idx
]
=
l
.
merge_profile
(
new_sub_profile
[
idx
]
=
l
.
merge_profile
(
new_sub_profile
[
idx
],
p
[
6
][
p
[
0
]
.
index
(
l
)])
new_sub_profile
[
idx
],
p
[
6
][
p
[
0
]
.
index
(
l
)])
else
:
else
:
...
@@ -729,6 +730,7 @@ def pre_constant_merge(vars):
...
@@ -729,6 +730,7 @@ def pre_constant_merge(vars):
const_sig_inv
=
{}
const_sig_inv
=
{}
if
isinstance
(
vars
,
graph
.
Variable
):
if
isinstance
(
vars
,
graph
.
Variable
):
vars
=
[
vars
]
vars
=
[
vars
]
def
recursive_merge
(
var
):
def
recursive_merge
(
var
):
if
var
in
seen_var
:
if
var
in
seen_var
:
return
var
return
var
...
@@ -761,7 +763,7 @@ def pre_constant_merge(vars):
...
@@ -761,7 +763,7 @@ def pre_constant_merge(vars):
########################
########################
#
## Local Optimizers ##
#
#
Local Optimizers
#
########################
########################
class
LocalOptimizer
(
object
):
class
LocalOptimizer
(
object
):
...
@@ -817,12 +819,14 @@ class LocalOptimizer(object):
...
@@ -817,12 +819,14 @@ class LocalOptimizer(object):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
(
"
%
s
%
s id=
%
i"
%
(
print
(
"
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
id
(
self
)),
file
=
stream
)
(
' '
*
level
),
self
.
__class__
.
__name__
,
id
(
self
)),
file
=
stream
)
theano
.
configparser
.
AddConfigVar
(
'metaopt.verbose'
,
theano
.
configparser
.
AddConfigVar
(
"Enable verbose output for meta optimizers"
,
'metaopt.verbose'
,
theano
.
configparser
.
BoolParam
(
False
),
in_c_key
=
False
)
"Enable verbose output for meta optimizers"
,
theano
.
configparser
.
BoolParam
(
False
),
in_c_key
=
False
)
class
LocalMetaOptimizer
(
LocalOptimizer
):
class
LocalMetaOptimizer
(
LocalOptimizer
):
...
@@ -933,9 +937,9 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
...
@@ -933,9 +937,9 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
(
"
%
s
%
s id=
%
i"
%
(
print
(
"
%
s
%
s id=
%
i"
%
(
' '
*
level
,
' '
*
level
,
str
(
self
.
transform
),
str
(
self
.
transform
),
id
(
self
)),
file
=
stream
)
id
(
self
)),
file
=
stream
)
def
local_optimizer
(
tracks
,
inplace
=
False
):
def
local_optimizer
(
tracks
,
inplace
=
False
):
...
@@ -992,7 +996,7 @@ class LocalOptGroup(LocalOptimizer):
...
@@ -992,7 +996,7 @@ class LocalOptGroup(LocalOptimizer):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
(
"
%
s
%
s id=
%
i"
%
(
print
(
"
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
id
(
self
)),
file
=
stream
)
(
' '
*
level
),
self
.
__class__
.
__name__
,
id
(
self
)),
file
=
stream
)
if
depth
!=
0
:
if
depth
!=
0
:
depth
-=
1
depth
-=
1
for
lopt
in
self
.
opts
:
for
lopt
in
self
.
opts
:
...
@@ -1003,19 +1007,6 @@ class LocalOptGroup(LocalOptimizer):
...
@@ -1003,19 +1007,6 @@ class LocalOptGroup(LocalOptimizer):
opt
.
add_requirements
(
fgraph
)
opt
.
add_requirements
(
fgraph
)
class
_LocalOpKeyOptGroup
(
LocalOptGroup
):
"""WRITEME"""
def
__init__
(
self
,
optimizers
):
if
any
(
not
hasattr
(
opt
,
'op_key'
),
optimizers
):
raise
TypeError
(
"All LocalOptimizers passed here must have an op_key method."
)
CompositeLocalOptimizer
.
__init__
(
self
,
optimizers
)
def
op_key
(
self
):
return
[
opt
.
op_key
()
for
opt
in
self
.
opts
]
class
OpSub
(
LocalOptimizer
):
class
OpSub
(
LocalOptimizer
):
"""WRITEME
"""WRITEME
Replaces the application of a certain op by the application of
Replaces the application of a certain op by the application of
...
@@ -1086,10 +1077,10 @@ class OpRemove(LocalOptimizer):
...
@@ -1086,10 +1077,10 @@ class OpRemove(LocalOptimizer):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
(
"
%
s
%
s(
%
s) id=
%
i"
%
(
print
(
"
%
s
%
s(
%
s) id=
%
i"
%
(
' '
*
level
,
' '
*
level
,
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
,
str
(
self
.
op
),
str
(
self
.
op
),
id
(
self
)),
file
=
stream
)
id
(
self
)),
file
=
stream
)
class
PatternSub
(
LocalOptimizer
):
class
PatternSub
(
LocalOptimizer
):
...
@@ -1217,6 +1208,7 @@ class PatternSub(LocalOptimizer):
...
@@ -1217,6 +1208,7 @@ class PatternSub(LocalOptimizer):
if
node
.
op
!=
self
.
op
:
if
node
.
op
!=
self
.
op
:
return
False
return
False
# TODO: if we remove pdb, do this speed things up?
# TODO: if we remove pdb, do this speed things up?
def
match
(
pattern
,
expr
,
u
,
allow_multiple_clients
=
False
,
pdb
=
False
):
def
match
(
pattern
,
expr
,
u
,
allow_multiple_clients
=
False
,
pdb
=
False
):
# TODO move outside match
# TODO move outside match
def
retry_with_equiv
():
def
retry_with_equiv
():
...
@@ -1233,9 +1225,8 @@ class PatternSub(LocalOptimizer):
...
@@ -1233,9 +1225,8 @@ class PatternSub(LocalOptimizer):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
if
expr
.
owner
is
None
:
if
expr
.
owner
is
None
:
return
False
return
False
if
(
not
(
expr
.
owner
.
op
==
pattern
[
0
])
if
(
not
(
expr
.
owner
.
op
==
pattern
[
0
])
or
or
(
not
allow_multiple_clients
(
not
allow_multiple_clients
and
len
(
expr
.
clients
)
>
1
)):
and
len
(
expr
.
clients
)
>
1
)):
return
retry_with_equiv
()
return
retry_with_equiv
()
if
len
(
pattern
)
-
1
!=
len
(
expr
.
owner
.
inputs
):
if
len
(
pattern
)
-
1
!=
len
(
expr
.
owner
.
inputs
):
return
retry_with_equiv
()
return
retry_with_equiv
()
...
@@ -1263,16 +1254,16 @@ class PatternSub(LocalOptimizer):
...
@@ -1263,16 +1254,16 @@ class PatternSub(LocalOptimizer):
return
retry_with_equiv
()
return
retry_with_equiv
()
else
:
else
:
u
=
u
.
merge
(
expr
,
v
)
u
=
u
.
merge
(
expr
,
v
)
elif
(
isinstance
(
pattern
,
(
int
,
float
))
elif
(
isinstance
(
pattern
,
(
int
,
float
))
and
and
isinstance
(
expr
,
graph
.
Constant
)):
isinstance
(
expr
,
graph
.
Constant
)):
if
numpy
.
all
(
if
numpy
.
all
(
theano
.
tensor
.
constant
(
pattern
)
.
value
==
expr
.
value
):
theano
.
tensor
.
constant
(
pattern
)
.
value
==
expr
.
value
):
return
u
return
u
else
:
else
:
return
retry_with_equiv
()
return
retry_with_equiv
()
elif
(
isinstance
(
pattern
,
graph
.
Constant
)
elif
(
isinstance
(
pattern
,
graph
.
Constant
)
and
and
isinstance
(
expr
,
graph
.
Constant
)
isinstance
(
expr
,
graph
.
Constant
)
and
and
pattern
.
equals
(
expr
)):
pattern
.
equals
(
expr
)):
return
u
return
u
else
:
else
:
return
retry_with_equiv
()
return
retry_with_equiv
()
...
@@ -1308,17 +1299,17 @@ class PatternSub(LocalOptimizer):
...
@@ -1308,17 +1299,17 @@ class PatternSub(LocalOptimizer):
def
pattern_to_str
(
pattern
):
def
pattern_to_str
(
pattern
):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
if
isinstance
(
pattern
,
(
list
,
tuple
)):
return
"
%
s(
%
s)"
%
(
return
"
%
s(
%
s)"
%
(
str
(
pattern
[
0
]),
str
(
pattern
[
0
]),
", "
.
join
([
pattern_to_str
(
p
)
for
p
in
pattern
[
1
:]]))
", "
.
join
([
pattern_to_str
(
p
)
for
p
in
pattern
[
1
:]]))
elif
isinstance
(
pattern
,
dict
):
elif
isinstance
(
pattern
,
dict
):
return
"
%
s subject to
%
s"
%
(
return
"
%
s subject to
%
s"
%
(
pattern_to_str
(
pattern
[
'pattern'
]),
pattern_to_str
(
pattern
[
'pattern'
]),
str
(
pattern
.
get
(
'constraint'
,
'no conditions'
)))
str
(
pattern
.
get
(
'constraint'
,
'no conditions'
)))
else
:
else
:
return
str
(
pattern
)
return
str
(
pattern
)
return
"
%
s ->
%
s"
%
(
return
"
%
s ->
%
s"
%
(
pattern_to_str
(
self
.
in_pattern
),
pattern_to_str
(
self
.
in_pattern
),
pattern_to_str
(
self
.
out_pattern
))
pattern_to_str
(
self
.
out_pattern
))
def
__repr__
(
self
):
def
__repr__
(
self
):
return
str
(
self
)
return
str
(
self
)
...
@@ -1326,16 +1317,16 @@ class PatternSub(LocalOptimizer):
...
@@ -1326,16 +1317,16 @@ class PatternSub(LocalOptimizer):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'__name__'
,
getattr
(
self
,
'name'
,
None
))
name
=
getattr
(
self
,
'__name__'
,
getattr
(
self
,
'name'
,
None
))
print
(
"
%
s
%
s
%
s(
%
s,
%
s) id=
%
i"
%
(
print
(
"
%
s
%
s
%
s(
%
s,
%
s) id=
%
i"
%
(
' '
*
level
,
' '
*
level
,
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
,
name
,
name
,
str
(
self
.
in_pattern
),
str
(
self
.
in_pattern
),
str
(
self
.
out_pattern
),
str
(
self
.
out_pattern
),
id
(
self
)),
file
=
stream
)
id
(
self
)),
file
=
stream
)
##################
##################
#
## Navigators ##
#
#
Navigators
#
##################
##################
# Use the following classes to apply LocalOptimizers
# Use the following classes to apply LocalOptimizers
...
@@ -1545,7 +1536,7 @@ class NavigatorOptimizer(Optimizer):
...
@@ -1545,7 +1536,7 @@ class NavigatorOptimizer(Optimizer):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
print
(
"
%
s
%
s (
%
i)"
%
(
print
(
"
%
s
%
s (
%
i)"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
id
(
self
)),
file
=
stream
)
(
' '
*
level
),
self
.
__class__
.
__name__
,
id
(
self
)),
file
=
stream
)
if
depth
!=
0
:
if
depth
!=
0
:
self
.
local_opt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
self
.
local_opt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
depth
=
(
depth
-
1
))
depth
=
(
depth
-
1
))
...
@@ -1734,7 +1725,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1734,7 +1725,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self
.
final_optimizers
=
final_optimizers
self
.
final_optimizers
=
final_optimizers
self
.
max_use_ratio
=
max_use_ratio
self
.
max_use_ratio
=
max_use_ratio
assert
self
.
max_use_ratio
is
not
None
,
(
assert
self
.
max_use_ratio
is
not
None
,
(
'max_use_ratio has to be a number'
)
'max_use_ratio has to be a number'
)
def
get_local_optimizers
(
self
):
def
get_local_optimizers
(
self
):
for
opt
in
self
.
local_optimizers_all
:
for
opt
in
self
.
local_optimizers_all
:
...
@@ -1811,8 +1802,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1811,8 +1802,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created
[
gopt
]
+=
change_tracker
.
nb_imported
-
nb
node_created
[
gopt
]
+=
change_tracker
.
nb_imported
-
nb
if
global_process_count
[
gopt
]
>
max_use
:
if
global_process_count
[
gopt
]
>
max_use
:
max_use_abort
=
True
max_use_abort
=
True
opt_name
=
(
getattr
(
gopt
,
"name"
,
None
)
opt_name
=
(
getattr
(
gopt
,
"name"
,
None
)
or
or
getattr
(
gopt
,
"__name__"
,
""
))
getattr
(
gopt
,
"__name__"
,
""
))
global_sub_profs
.
append
(
sub_profs
)
global_sub_profs
.
append
(
sub_profs
)
global_opt_timing
.
append
(
float
(
time
.
time
()
-
t0
))
global_opt_timing
.
append
(
float
(
time
.
time
()
-
t0
))
...
@@ -1858,8 +1849,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1858,8 +1849,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created
[
lopt
]
+=
change_tracker
.
nb_imported
-
nb
node_created
[
lopt
]
+=
change_tracker
.
nb_imported
-
nb
if
global_process_count
[
lopt
]
>
max_use
:
if
global_process_count
[
lopt
]
>
max_use
:
max_use_abort
=
True
max_use_abort
=
True
opt_name
=
(
getattr
(
lopt
,
"name"
,
None
)
opt_name
=
(
getattr
(
lopt
,
"name"
,
None
)
or
or
getattr
(
lopt
,
"__name__"
,
""
))
getattr
(
lopt
,
"__name__"
,
""
))
if
node
not
in
fgraph
.
apply_nodes
:
if
node
not
in
fgraph
.
apply_nodes
:
# go to next node
# go to next node
break
break
...
@@ -1884,8 +1875,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1884,8 +1875,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created
[
gopt
]
+=
change_tracker
.
nb_imported
-
nb
node_created
[
gopt
]
+=
change_tracker
.
nb_imported
-
nb
if
global_process_count
[
gopt
]
>
max_use
:
if
global_process_count
[
gopt
]
>
max_use
:
max_use_abort
=
True
max_use_abort
=
True
opt_name
=
(
getattr
(
gopt
,
"name"
,
None
)
opt_name
=
(
getattr
(
gopt
,
"name"
,
None
)
or
or
getattr
(
gopt
,
"__name__"
,
""
))
getattr
(
gopt
,
"__name__"
,
""
))
final_sub_profs
.
append
(
sub_profs
)
final_sub_profs
.
append
(
sub_profs
)
global_opt_timing
[
-
1
]
+=
time
.
time
()
-
t_before_final_opt
global_opt_timing
[
-
1
]
+=
time
.
time
()
-
t_before_final_opt
...
@@ -1896,9 +1887,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1896,9 +1887,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
end_nb_nodes
=
len
(
fgraph
.
apply_nodes
)
end_nb_nodes
=
len
(
fgraph
.
apply_nodes
)
if
max_use_abort
:
if
max_use_abort
:
_logger
.
error
(
"EquilibriumOptimizer max'ed out by '
%
s'"
%
opt_name
_logger
.
error
(
"EquilibriumOptimizer max'ed out by '
%
s'"
%
opt_name
+
+
". You can safely raise the current threshold of "
". You can safely raise the current threshold of "
+
+
"
%
f with the theano flag 'optdb.max_use_ratio'."
%
"
%
f with the theano flag 'optdb.max_use_ratio'."
%
config
.
optdb
.
max_use_ratio
)
config
.
optdb
.
max_use_ratio
)
fgraph
.
remove_feature
(
change_tracker
)
fgraph
.
remove_feature
(
change_tracker
)
return
(
self
,
loop_timing
,
loop_process_count
,
return
(
self
,
loop_timing
,
loop_process_count
,
...
@@ -1909,7 +1900,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1909,7 +1900,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
name
=
getattr
(
self
,
'name'
,
None
)
name
=
getattr
(
self
,
'name'
,
None
)
print
(
"
%
s
%
s
%
s id=
%
i"
%
(
print
(
"
%
s
%
s
%
s id=
%
i"
%
(
(
' '
*
level
),
self
.
__class__
.
__name__
,
name
,
id
(
self
)),
file
=
stream
)
(
' '
*
level
),
self
.
__class__
.
__name__
,
name
,
id
(
self
)),
file
=
stream
)
if
depth
!=
0
:
if
depth
!=
0
:
for
lopt
in
self
.
get_local_optimizers
():
for
lopt
in
self
.
get_local_optimizers
():
lopt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
lopt
.
print_summary
(
stream
,
level
=
(
level
+
2
),
...
@@ -1925,11 +1916,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1925,11 +1916,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
blanc
=
(
' '
*
level
)
blanc
=
(
' '
*
level
)
print
(
blanc
,
"EquilibriumOptimizer"
,
end
=
' '
,
file
=
stream
)
print
(
blanc
,
"EquilibriumOptimizer"
,
end
=
' '
,
file
=
stream
)
print
(
blanc
,
getattr
(
opt
,
"name"
,
print
(
blanc
,
getattr
(
opt
,
"name"
,
getattr
(
opt
,
"__name__"
,
""
)),
file
=
stream
)
getattr
(
opt
,
"__name__"
,
""
)),
file
=
stream
)
print
(
blanc
,
" time
%.3
fs for
%
d passes"
%
(
print
(
blanc
,
" time
%.3
fs for
%
d passes"
%
(
sum
(
loop_timing
),
len
(
loop_timing
)),
file
=
stream
)
sum
(
loop_timing
),
len
(
loop_timing
)),
file
=
stream
)
print
(
blanc
,
" nb nodes (start, end, max)
%
d
%
d
%
d"
%
(
print
(
blanc
,
" nb nodes (start, end, max)
%
d
%
d
%
d"
%
(
start_nb_nodes
,
end_nb_nodes
,
max_nb_nodes
),
file
=
stream
)
start_nb_nodes
,
end_nb_nodes
,
max_nb_nodes
),
file
=
stream
)
print
(
blanc
,
" time io_toposort
%.3
fs"
%
sum
(
print
(
blanc
,
" time io_toposort
%.3
fs"
%
sum
(
io_toposort_timing
),
file
=
stream
)
io_toposort_timing
),
file
=
stream
)
s
=
sum
([
time_opts
[
o
]
for
o
in
opt
.
get_local_optimizers
()])
s
=
sum
([
time_opts
[
o
]
for
o
in
opt
.
get_local_optimizers
()])
...
@@ -1948,12 +1939,12 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1948,12 +1939,12 @@ class EquilibriumOptimizer(NavigatorOptimizer):
if
len
(
d
)
>
5
:
if
len
(
d
)
>
5
:
lopt
+=
" ..."
lopt
+=
" ..."
print
(
blanc
,
(
'
%2
d -
%.3
fs
%
d (
%.3
fs in global opts, '
print
(
blanc
,
(
'
%2
d -
%.3
fs
%
d (
%.3
fs in global opts, '
'
%.3
fs io_toposort) -
%
d nodes -
%
s'
%
(
'
%.3
fs io_toposort) -
%
d nodes -
%
s'
%
(
i
,
loop_timing
[
i
],
i
,
loop_timing
[
i
],
sum
(
loop_process_count
[
i
]
.
values
()),
sum
(
loop_process_count
[
i
]
.
values
()),
global_opt_timing
[
i
],
global_opt_timing
[
i
],
io_toposort_timing
[
i
],
nb_nodes
[
i
],
io_toposort_timing
[
i
],
nb_nodes
[
i
],
lopt
)),
file
=
stream
)
lopt
)),
file
=
stream
)
count_opt
=
[]
count_opt
=
[]
not_used
=
[]
not_used
=
[]
...
@@ -1975,8 +1966,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1975,8 +1966,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
not_used_time
+=
time_opts
[
o
]
not_used_time
+=
time_opts
[
o
]
if
count_opt
:
if
count_opt
:
print
(
blanc
,
\
print
(
blanc
,
' times - times applied - nb node created - name:'
,
file
=
stream
)
' times - times applied - nb node created - name:'
,
file
=
stream
)
count_opt
.
sort
()
count_opt
.
sort
()
for
(
t
,
count
,
n_created
,
o
)
in
count_opt
[::
-
1
]:
for
(
t
,
count
,
n_created
,
o
)
in
count_opt
[::
-
1
]:
print
(
blanc
,
'
%.3
fs -
%
d -
%
d -
%
s'
%
(
print
(
blanc
,
'
%.3
fs -
%
d -
%
d -
%
s'
%
(
...
@@ -2010,7 +2002,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -2010,7 +2002,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
@staticmethod
@staticmethod
def
merge_profile
(
prof1
,
prof2
):
def
merge_profile
(
prof1
,
prof2
):
#(opt, loop_timing, loop_process_count, max_nb_nodes,
#
(opt, loop_timing, loop_process_count, max_nb_nodes,
# global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
# global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
local_optimizers
=
OrderedSet
(
prof1
[
0
]
.
get_local_optimizers
())
.
union
(
local_optimizers
=
OrderedSet
(
prof1
[
0
]
.
get_local_optimizers
())
.
union
(
prof2
[
0
]
.
get_local_optimizers
())
prof2
[
0
]
.
get_local_optimizers
())
...
@@ -2085,7 +2077,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -2085,7 +2077,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
final_sub_profs
)
final_sub_profs
)
#################
#################
#
## Utilities ##
#
#
Utilities
#
#################
#################
...
@@ -2096,7 +2088,7 @@ def _check_chain(r, chain):
...
@@ -2096,7 +2088,7 @@ def _check_chain(r, chain):
while
chain
:
while
chain
:
elem
=
chain
.
pop
()
elem
=
chain
.
pop
()
if
elem
is
None
:
if
elem
is
None
:
if
not
r
.
owner
is
None
:
if
r
.
owner
is
not
None
:
return
False
return
False
elif
r
.
owner
is
None
:
elif
r
.
owner
is
None
:
return
False
return
False
...
@@ -2105,20 +2097,20 @@ def _check_chain(r, chain):
...
@@ -2105,20 +2097,20 @@ def _check_chain(r, chain):
return
False
return
False
else
:
else
:
try
:
try
:
if
(
issubclass
(
elem
,
op
.
Op
)
if
(
issubclass
(
elem
,
op
.
Op
)
and
and
not
isinstance
(
r
.
owner
.
op
,
elem
)):
not
isinstance
(
r
.
owner
.
op
,
elem
)):
return
False
return
False
except
TypeError
:
except
TypeError
:
return
False
return
False
if
chain
:
if
chain
:
r
=
r
.
owner
.
inputs
[
chain
.
pop
()]
r
=
r
.
owner
.
inputs
[
chain
.
pop
()]
# print 'check_chain', _check_chain.n_calls
# print 'check_chain', _check_chain.n_calls
#_check_chain.n_calls += 1
#
_check_chain.n_calls += 1
# The return value will be used as a Boolean, but some Variables cannot
# The return value will be used as a Boolean, but some Variables cannot
# be used as Booleans (the results of comparisons, for instance)
# be used as Booleans (the results of comparisons, for instance)
return
(
r
is
not
None
)
return
(
r
is
not
None
)
#_check_chain.n_calls = 0
#
_check_chain.n_calls = 0
def
check_chain
(
r
,
*
chain
):
def
check_chain
(
r
,
*
chain
):
...
...
theano/gof/utils.py
浏览文件 @
cba9c812
...
@@ -3,9 +3,11 @@ import linecache
...
@@ -3,9 +3,11 @@ import linecache
import
traceback
import
traceback
import
sys
import
sys
import
numpy
from
six
import
iteritems
from
six
import
iteritems
from
theano
import
config
from
theano
import
config
from
theano.compat
import
PY3
def
simple_extract_stack
(
f
=
None
,
limit
=
None
):
def
simple_extract_stack
(
f
=
None
,
limit
=
None
):
...
@@ -435,3 +437,31 @@ def remove(predicate, coll):
...
@@ -435,3 +437,31 @@ def remove(predicate, coll):
[1, 3]
[1, 3]
"""
"""
return
[
x
for
x
in
coll
if
not
predicate
(
x
)]
return
[
x
for
x
in
coll
if
not
predicate
(
x
)]
if
PY3
:
import
hashlib
def
hash_from_code
(
msg
):
# hashlib.md5() requires an object that supports buffer interface,
# but Python 3 (unicode) strings don't.
if
isinstance
(
msg
,
str
):
msg
=
msg
.
encode
()
# Python 3 does not like module names that start with
# a digit.
return
'm'
+
hashlib
.
md5
(
msg
)
.
hexdigest
()
else
:
import
hashlib
def
hash_from_code
(
msg
):
try
:
return
hashlib
.
md5
(
msg
)
.
hexdigest
()
except
TypeError
:
assert
isinstance
(
msg
,
numpy
.
ndarray
)
return
hashlib
.
md5
(
numpy
.
getbuffer
(
msg
))
.
hexdigest
()
def
hash_from_file
(
file_path
):
"""Return the MD5 hash of a file."""
return
hash_from_code
(
open
(
file_path
,
'rb'
)
.
read
())
theano/sandbox/cuda/nvcc_compiler.py
浏览文件 @
cba9c812
...
@@ -10,7 +10,7 @@ import numpy
...
@@ -10,7 +10,7 @@ import numpy
from
theano.compat
import
decode
,
decode_iter
from
theano.compat
import
decode
,
decode_iter
from
theano.gof
import
local_bitwidth
from
theano.gof
import
local_bitwidth
from
theano.gof.
cc
import
hash_from_file
from
theano.gof.
utils
import
hash_from_file
from
theano.gof.cmodule
import
(
std_libs
,
std_lib_dirs
,
from
theano.gof.cmodule
import
(
std_libs
,
std_lib_dirs
,
std_include_dirs
,
dlimport
,
std_include_dirs
,
dlimport
,
Compiler
,
Compiler
,
...
...
theano/sparse/utils.py
浏览文件 @
cba9c812
from
theano.gof.
cc
import
hash_from_code
from
theano.gof.
utils
import
hash_from_code
def
hash_from_sparse
(
data
):
def
hash_from_sparse
(
data
):
...
...
theano/tensor/utils.py
浏览文件 @
cba9c812
...
@@ -2,7 +2,7 @@ import numpy
...
@@ -2,7 +2,7 @@ import numpy
import
theano
import
theano
from
theano.compat
import
izip
from
theano.compat
import
izip
from
theano.gof.
cc
import
hash_from_code
from
theano.gof.
utils
import
hash_from_code
def
hash_from_ndarray
(
data
):
def
hash_from_ndarray
(
data
):
...
...
theano/tests/test_flake8.py
浏览文件 @
cba9c812
...
@@ -233,16 +233,10 @@ whitelist_flake8 = [
...
@@ -233,16 +233,10 @@ whitelist_flake8 = [
"sparse/sandbox/sp2.py"
,
"sparse/sandbox/sp2.py"
,
"sparse/sandbox/truedot.py"
,
"sparse/sandbox/truedot.py"
,
"sparse/sandbox/sp.py"
,
"sparse/sandbox/sp.py"
,
"gof/destroyhandler.py"
,
"gof/unify.py"
,
"gof/unify.py"
,
"gof/graph.py"
,
"gof/graph.py"
,
"gof/__init__.py"
,
"gof/__init__.py"
,
"gof/cc.py"
,
"gof/opt.py"
,
"gof/link.py"
,
"gof/fg.py"
,
"gof/op.py"
,
"gof/op.py"
,
"gof/cmodule.py"
,
"gof/tests/test_cmodule.py"
,
"gof/tests/test_cmodule.py"
,
"gof/tests/test_destroyhandler.py"
,
"gof/tests/test_destroyhandler.py"
,
"gof/tests/test_opt.py"
,
"gof/tests/test_opt.py"
,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论