Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
4cf7afb4
提交
4cf7afb4
authored
5月 25, 2015
作者:
Frédéric Bastien
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2952 from abergeron/flake8
Flake8 work
上级
1003abb0
0761b0b8
全部展开
显示空白字符变更
内嵌
并排
正在显示
31 个修改的文件
包含
424 行增加
和
369 行删除
+424
-369
builders.py
theano/compile/builders.py
+7
-6
debugmode.py
theano/compile/debugmode.py
+0
-0
function.py
theano/compile/function.py
+60
-51
function_module.py
theano/compile/function_module.py
+0
-0
io.py
theano/compile/io.py
+66
-34
mode.py
theano/compile/mode.py
+17
-17
monitormode.py
theano/compile/monitormode.py
+1
-1
ops.py
theano/compile/ops.py
+50
-37
pfunc.py
theano/compile/pfunc.py
+0
-0
profilemode.py
theano/compile/profilemode.py
+0
-0
profiling.py
theano/compile/profiling.py
+0
-0
sharedvalue.py
theano/compile/sharedvalue.py
+11
-8
callcache.py
theano/gof/callcache.py
+2
-7
compiledir.py
theano/gof/compiledir.py
+4
-4
compilelock.py
theano/gof/compilelock.py
+2
-2
cutils.py
theano/gof/cutils.py
+42
-32
lazylinker_c.py
theano/gof/lazylinker_c.py
+8
-7
optdb.py
theano/gof/optdb.py
+6
-4
sched.py
theano/gof/sched.py
+5
-5
test_utils.py
theano/gof/tests/test_utils.py
+6
-3
toolbox.py
theano/gof/toolbox.py
+14
-10
type.py
theano/gof/type.py
+48
-35
utils.py
theano/gof/utils.py
+15
-15
vm.py
theano/gof/vm.py
+40
-30
opt_util.py
theano/sandbox/cuda/opt_util.py
+0
-1
test_fftconv.py
theano/sandbox/cuda/tests/test_fftconv.py
+3
-3
_test_mpi_roundtrip.py
theano/tensor/tests/_test_mpi_roundtrip.py
+2
-1
disturb_mem.py
theano/tests/disturb_mem.py
+2
-2
main.py
theano/tests/main.py
+12
-14
test_flake8.py
theano/tests/test_flake8.py
+0
-39
version.py
theano/version.py
+1
-1
没有找到文件。
theano/compile/builders.py
浏览文件 @
4cf7afb4
...
@@ -15,7 +15,9 @@ class OpFromGraph(gof.Op):
...
@@ -15,7 +15,9 @@ class OpFromGraph(gof.Op):
TODO:
TODO:
- examples for a multi-layer mlp. where?
- examples for a multi-layer mlp. where?
- __hash__, __eq__ otherwise won't merge, try gof.opt.is_same_graph_with_merge(op1.new_outputs, op2, new_outputs)
- __hash__, __eq__ otherwise won't merge, try
gof.opt.is_same_graph_with_merge(op1.new_outputs, op2,
new_outputs)
- c_code() to remove the double overhead?
- c_code() to remove the double overhead?
- opt to unfold it, work inplace on inputs
- opt to unfold it, work inplace on inputs
- grad() make it support DisconnectedType and the new interface
- grad() make it support DisconnectedType and the new interface
...
@@ -76,8 +78,6 @@ class OpFromGraph(gof.Op):
...
@@ -76,8 +78,6 @@ class OpFromGraph(gof.Op):
# not see them. Otherwise their is problem with the gradient.
# not see them. Otherwise their is problem with the gradient.
self
.
shared_inputs
=
[
var
for
var
in
gof
.
graph
.
inputs
(
outputs
)
self
.
shared_inputs
=
[
var
for
var
in
gof
.
graph
.
inputs
(
outputs
)
if
isinstance
(
var
,
SharedVariable
)]
if
isinstance
(
var
,
SharedVariable
)]
used_inputs
=
[
var
for
var
in
gof
.
graph
.
inputs
(
outputs
)
if
not
isinstance
(
var
,
gof
.
Constant
)]
shared_vars
=
[
var
.
type
()
for
var
in
self
.
shared_inputs
]
shared_vars
=
[
var
.
type
()
for
var
in
self
.
shared_inputs
]
new
=
rebuild_collect_shared
(
outputs
,
inputs
=
inputs
+
shared_vars
,
new
=
rebuild_collect_shared
(
outputs
,
inputs
=
inputs
+
shared_vars
,
replace
=
dict
(
zip
(
self
.
shared_inputs
,
replace
=
dict
(
zip
(
self
.
shared_inputs
,
...
@@ -110,8 +110,8 @@ class OpFromGraph(gof.Op):
...
@@ -110,8 +110,8 @@ class OpFromGraph(gof.Op):
def
make_node
(
self
,
*
inputs
):
def
make_node
(
self
,
*
inputs
):
for
input
,
type
in
zip
(
inputs
,
self
.
input_types
):
for
input
,
type
in
zip
(
inputs
,
self
.
input_types
):
if
not
type
==
input
.
type
:
if
not
type
==
input
.
type
:
raise
TypeError
(
"Wrong type, expected
%
s but got
%
s"
raise
TypeError
(
"Wrong type, expected
%
s but got
%
s"
%
%
(
type
,
input
.
type
))
(
type
,
input
.
type
))
return
gof
.
Apply
(
self
,
return
gof
.
Apply
(
self
,
list
(
inputs
)
+
self
.
shared_inputs
,
list
(
inputs
)
+
self
.
shared_inputs
,
[
type
()
for
type
in
self
.
output_types
])
[
type
()
for
type
in
self
.
output_types
])
...
@@ -143,7 +143,8 @@ class OpFromGraph(gof.Op):
...
@@ -143,7 +143,8 @@ class OpFromGraph(gof.Op):
grad_ops
=
self
.
grad_ops
grad_ops
=
self
.
grad_ops
else
:
else
:
gs
=
theano
.
gradient
.
grad
(
cost
=
None
,
gs
=
theano
.
gradient
.
grad
(
cost
=
None
,
known_grads
=
dict
(
zip
(
self
.
new_outputs
,
output_grads
)),
known_grads
=
dict
(
zip
(
self
.
new_outputs
,
output_grads
)),
wrt
=
self
.
new_inputs
,
wrt
=
self
.
new_inputs
,
disconnected_inputs
=
'ignore'
)
disconnected_inputs
=
'ignore'
)
...
...
theano/compile/debugmode.py
浏览文件 @
4cf7afb4
差异被折叠。
点击展开。
theano/compile/function.py
浏览文件 @
4cf7afb4
"""Define the `function` function
"""Define the `function` function
"""
"""
__docformat__
=
"restructuredtext en"
import
cPickle
import
cPickle
import
logging
import
logging
_logger
=
logging
.
getLogger
(
'theano.compile.function'
)
import
traceback
as
tb
import
traceback
as
tb
import
re
import
re
...
@@ -14,9 +11,11 @@ from theano.compile.function_module import orig_function
...
@@ -14,9 +11,11 @@ from theano.compile.function_module import orig_function
from
theano.compile.pfunc
import
pfunc
from
theano.compile.pfunc
import
pfunc
from
numpy
import
any
from
numpy
import
any
import
warnings
import
warnings
from
theano
import
gof
from
theano
import
compat
from
theano
import
compat
__docformat__
=
"restructuredtext en"
_logger
=
logging
.
getLogger
(
'theano.compile.function'
)
def
function_dump
(
filename
,
inputs
,
outputs
=
None
,
mode
=
None
,
updates
=
None
,
def
function_dump
(
filename
,
inputs
,
outputs
=
None
,
mode
=
None
,
updates
=
None
,
givens
=
None
,
givens
=
None
,
...
@@ -70,54 +69,67 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
...
@@ -70,54 +69,67 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
:type mode: string or `Mode` instance.
:type mode: string or `Mode` instance.
:param mode: compilation mode
:param mode: compilation mode
:type updates: iterable over pairs (shared_variable, new_expression). List, tuple or OrderedDict.
:type updates: iterable over pairs (shared_variable, new_expression).
:param updates: update the values for SharedVariable inputs according to these expressions
List, tuple or OrderedDict.
:param updates: update the values for SharedVariable inputs
according to these expressions
:type givens: iterable over pairs (Var1, Var2) of Variables. List,
tuple or dict. The Var1
:type givens: iterable over pairs (Var1, Var2) of Variables. List,
and Var2 in each pair must have the same Type.
tuple or dict. The Var1 and Var2 in each pair must
have the same Type.
:param givens: specific substitutions to make in the computation
graph (Var2 replaces
:param givens: specific substitutions to make in the computation
Var1).
graph (Var2 replaces
Var1).
:type no_default_updates: either bool or list of Variables
:type no_default_updates: either bool or list of Variables
:param no_default_updates: if True, do not perform any automatic update on Variables.
:param no_default_updates: if True, do not perform any automatic
If False (default), perform them all. Else, perform automatic updates on all Variables
update on Variables. If False (default), perform them
that are neither in "updates" nor in "no_default_updates".
all. Else, perform automatic updates on all Variables that are
neither in "updates" nor in "no_default_updates".
:param name: an optional name for this function. The profile mode will print the time spent in this function.
:param name: an optional name for this function. The profile mode
:param rebuild_strict: True (Default) is the safer and better tested setting, in which case
will print the time spent in this function.
`givens` must substitute new variables with the same Type as the variables they replace.
False is a you-better-know-what-you-are-doing setting, that permits `givens` to replace
:param rebuild_strict: True (Default) is the safer and better
variables with new variables of any Type. The consequence of changing a Type is that all
tested setting, in which case `givens` must substitute new
results depending on that variable may have a different Type too (the graph is rebuilt from
variables with the same Type as the variables they replace.
inputs to outputs). If one of the new types does not make sense for one of the Ops in the
False is a you-better-know-what-you-are-doing setting, that
permits `givens` to replace variables with new variables of
any Type. The consequence of changing a Type is that all
results depending on that variable may have a different Type
too (the graph is rebuilt from inputs to outputs). If one of
the new types does not make sense for one of the Ops in the
graph, an Exception will be raised.
graph, an Exception will be raised.
:type allow_input_downcast: Boolean or None
:type allow_input_downcast: Boolean or None
:param allow_input_downcast: True means that the values passed as
:param allow_input_downcast: True means that the values passed as
inputs when calling the function can be silently downcasted to fit
inputs when calling the function can be silently downcasted to
the dtype of the corresponding Variable, which may lose precision.
fit the dtype of the corresponding Variable, which may lose
False means that it will only be cast to a more general, or
precision. False means that it will only be cast to a more
precise, type. None (default) is almost like False, but allows
general, or precise, type. None (default) is almost like
downcasting of Python float scalars to floatX.
False, but allows downcasting of Python float scalars to
floatX.
:type profile: None, True, or ProfileStats instance
:type profile: None, True, or ProfileStats instance
:param profile: accumulate profiling information into a given ProfileStats
:param profile: accumulate profiling information into a given
instance. If argument is `True` then a new ProfileStats instance will be
ProfileStats instance. If argument is `True` then a new
used. This profiling object will be available via self.profile.
ProfileStats instance will be used. This profiling object
will be available via self.profile.
:param on_unused_input: What to do if a variable in the 'inputs' list is
:param on_unused_input: What to do if a variable in the 'inputs'
not used in the graph. Possible values are 'raise', 'warn', 'ignore' and None.
list is not used in the graph. Possible values are 'raise',
'warn', 'ignore' and None.
:rtype: Function instance
:rtype: Function instance
:returns: a callable object that will compute the outputs (given the inputs)
:returns: a callable object that will compute the outputs (given
and update the implicit function arguments according to the `updates`.
the inputs) and update the implicit function arguments
according to the `updates`.
:note: Regarding givens: Be careful to make sure that these substitutions are
:note: Regarding givens: Be careful to make sure that these
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in
substitutions are independent--behaviour when Var1 of one pair
another expression is undefined. Replacements specified with givens are different from
appears in the graph leading to Var2 in another expression is
optimizations in that Var2 is not expected to be equivalent to Var1.
undefined. Replacements specified with givens are different
from optimizations in that Var2 is not expected to be
equivalent to Var1.
Internal documentation:
Internal documentation:
...
@@ -195,10 +207,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
...
@@ -195,10 +207,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
was easier to develop the VM in Python then translate it to C instead
was easier to develop the VM in Python then translate it to C instead
of just writing it in C from scratch.
of just writing it in C from scratch.
CVM stands for C Virtual Machine.
CVM stands for C Virtual Machine.
"""
"""
if
isinstance
(
outputs
,
dict
):
if
isinstance
(
outputs
,
dict
):
output_items
=
outputs
.
items
()
output_items
=
outputs
.
items
()
...
@@ -214,7 +222,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
...
@@ -214,7 +222,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
output_keys
.
append
(
pair
[
0
])
output_keys
.
append
(
pair
[
0
])
outputs
.
append
(
pair
[
1
])
outputs
.
append
(
pair
[
1
])
else
:
else
:
output_keys
=
None
output_keys
=
None
...
@@ -256,12 +263,13 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
...
@@ -256,12 +263,13 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
if
givens
is
None
:
if
givens
is
None
:
givens
=
[]
givens
=
[]
if
not
isinstance
(
inputs
,
(
list
,
tuple
)):
if
not
isinstance
(
inputs
,
(
list
,
tuple
)):
raise
Exception
(
"Input variables of a Theano function should be"
raise
Exception
(
"Input variables of a Theano function should be "
" contained in a list, even when there is a single input."
)
"contained in a list, even when there is a single "
"input."
)
# compute some features of the arguments:
# compute some features of the arguments:
uses_In
=
any
([
isinstance
(
i
,
In
)
for
i
in
inputs
])
# N.B. the square brackets are ncessary
uses_In
=
any
([
isinstance
(
i
,
In
)
for
i
in
inputs
])
uses_tuple
=
any
([
isinstance
(
i
,
(
list
,
tuple
))
for
i
in
inputs
])
# N.B. the square brackets are ncessary
uses_tuple
=
any
([
isinstance
(
i
,
(
list
,
tuple
))
for
i
in
inputs
])
uses_updates
=
bool
(
updates
)
uses_updates
=
bool
(
updates
)
uses_givens
=
bool
(
givens
)
uses_givens
=
bool
(
givens
)
...
@@ -275,7 +283,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
...
@@ -275,7 +283,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
if
uses_In
or
uses_tuple
:
if
uses_In
or
uses_tuple
:
# we must use old semantics in this case.
# we must use old semantics in this case.
if
profile
:
if
profile
:
raise
NotImplementedError
(
'profiling not supported in old-style function'
)
raise
NotImplementedError
(
"profiling not supported in old-style "
"function"
)
if
uses_updates
or
uses_givens
:
if
uses_updates
or
uses_givens
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"In() instances and tuple inputs trigger the old "
"In() instances and tuple inputs trigger the old "
...
@@ -284,8 +293,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
...
@@ -284,8 +293,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
mode
=
mode
,
mode
=
mode
,
accept_inplace
=
accept_inplace
,
name
=
name
)
accept_inplace
=
accept_inplace
,
name
=
name
)
else
:
else
:
# note: pfunc will also call orig_function-- orig_function is
a choke point
# note: pfunc will also call orig_function-- orig_function is
# that all compilation must pass through
#
a choke point
that all compilation must pass through
fn
=
pfunc
(
params
=
inputs
,
fn
=
pfunc
(
params
=
inputs
,
outputs
=
outputs
,
outputs
=
outputs
,
mode
=
mode
,
mode
=
mode
,
...
...
theano/compile/function_module.py
浏览文件 @
4cf7afb4
差异被折叠。
点击展开。
theano/compile/io.py
浏览文件 @
4cf7afb4
"""Define `SymbolicInput`, `SymbolicOutput`, `In`, `Out` """
"""Define `SymbolicInput`, `SymbolicOutput`, `In`, `Out` """
__docformat__
=
'restructuredtext en'
from
theano
import
gof
from
theano
import
gof
from
sharedvalue
import
SharedVariable
from
sharedvalue
import
SharedVariable
...
@@ -7,6 +6,8 @@ from sharedvalue import SharedVariable
...
@@ -7,6 +6,8 @@ from sharedvalue import SharedVariable
import
logging
import
logging
_logger
=
logging
.
getLogger
(
"theano.compile.io"
)
_logger
=
logging
.
getLogger
(
"theano.compile.io"
)
__docformat__
=
'restructuredtext en'
class
SymbolicInput
(
object
):
class
SymbolicInput
(
object
):
"""
"""
...
@@ -17,34 +18,47 @@ class SymbolicInput(object):
...
@@ -17,34 +18,47 @@ class SymbolicInput(object):
not computed from its owner.
not computed from its owner.
name: Any type. (If autoname=True, defaults to variable.name).
name: Any type. (If autoname=True, defaults to variable.name).
If name is a valid Python identifier, this input can be set by
kwarg, and its value
If name is a valid Python identifier, this input can be set by
can be accessed by self.<name>.
kwarg, and its value
can be accessed by self.<name>.
update: Variable instance (default: None)
update: Variable instance (default: None)
value (see previous) will be replaced with this expression variable after each function call.
value (see previous) will be replaced with this expression
If update is None, the update will be the default value of the input.
variable after each function call. If update is None, the
update will be the default value of the input.
mutable: Bool (default: False if update is None, True if update is
not None)
True: permit the compiled function to modify the python object
being passed as the input
mutable: Bool (default: False if update is None, True if update is not None)
False: do not permit the compiled function to modify the
True: permit the compiled function to modify the python object being passed as the input
python object being passed as the input.
False: do not permit the compiled function to modify the python object being passed as the input.
strict: Bool (default: False)
strict: Bool (default: False)
True: means that the value you pass for this input must have exactly the right type
False: the value you pass for this input may be cast automatically to the proper type
True: means that the value you pass for this input must have
exactly the right type
False: the value you pass for this input may be cast
automatically to the proper type
allow_downcast: Bool or None (default: None)
allow_downcast: Bool or None (default: None)
Only applies when `strict` is False.
Only applies when `strict` is False.
True: the value you pass for this input can be silently
True: the value you pass for this input can be silently
downcasted to fit the right type, which may lose precision.
downcasted to fit the right type, which may lose precision.
False: the value will only be cast to a more general, or precise, type.
None: Almost like False, but allows downcast of Python floats to floatX.
False: the value will only be cast to a more general, or
precise, type. None: Almost like False, but allows downcast
of Python floats to floatX.
autoname: Bool (default: True)
autoname: Bool (default: True)
See the name option.
See the name option.
implicit: Bool (default: False)
implicit: Bool (default: False)
See help(In). Note that 'None' is not allowed here, since we
are in the
See help(In). Note that 'None' is not allowed here, since we
symbolic case.
are in the
symbolic case.
"""
"""
def
__init__
(
self
,
variable
,
name
=
None
,
update
=
None
,
mutable
=
None
,
def
__init__
(
self
,
variable
,
name
=
None
,
update
=
None
,
mutable
=
None
,
...
@@ -146,36 +160,54 @@ class In(SymbolicInput):
...
@@ -146,36 +160,54 @@ class In(SymbolicInput):
not computed from its owner.
not computed from its owner.
name: Any type. (If autoname=True, defaults to variable.name).
name: Any type. (If autoname=True, defaults to variable.name).
If name is a valid Python identifier, this input can be set by
kwarg, and its value
If name is a valid Python identifier, this input can be set by
can be accessed by self.<name>.
kwarg, and its value
can be accessed by self.<name>.
value: Any type.
value: Any type.
The initial/default value for this input. If update is None, this input acts just like
The initial/default value for this input. If update is None,
an argument with a default value in Python. If update is not None, changes to this
this input acts just like an argument with a default value in
value will "stick around", whether due to an update or a user's explicit action.
Python. If update is not None, changes to this value will
"stick around", whether due to an update or a user's explicit
action.
update: Variable instance (default: None)
update: Variable instance (default: None)
value (see previous) will be replaced with this expression variable after each function call.
value (see previous) will be replaced with this expression
If update is None, the update will be the default value of the input.
variable after each function call. If update is None, the
update will be the default value of the input.
mutable: Bool (default: False if update is None, True if update is not None)
mutable: Bool (default: False if update is None, True if update is
True: permit the compiled function to modify the python object being passed as the input
not None)
False: do not permit the compiled function to modify the python object being passed as the input.
True: permit the compiled function to modify the python object
being passed as the input
False: do not permit the compiled function to modify the
python object being passed as the input.
borrow: Bool (default: take the same value as mutable)
borrow: Bool (default: take the same value as mutable)
True: permit the output of the compiled function to be aliased to the input
True: permit the output of the compiled function to be aliased
to the input
False: do not permit any output to be aliased to the input
False: do not permit any output to be aliased to the input
strict: Bool (default: False)
strict: Bool (default: False)
True: means that the value you pass for this input must have exactly the right type
False: the value you pass for this input may be cast automatically to the proper type
True: means that the value you pass for this input must have
exactly the right type
False: the value you pass for this input may be cast
automatically to the proper type
allow_downcast: Bool or None (default: None)
allow_downcast: Bool or None (default: None)
Only applies when `strict` is False.
Only applies when `strict` is False.
True: the value you pass for this input can be silently
True: the value you pass for this input can be silently
downcasted to fit the right type, which may lose precision.
downcasted to fit the right type, which may lose precision.
False: the value will only be cast to a more general, or precise, type.
None: Almost like False, but allows downcast of Python floats to floatX.
False: the value will only be cast to a more general, or
precise, type. None: Almost like False, but allows downcast
of Python floats to floatX.
autoname: Bool (default: True)
autoname: Bool (default: True)
See the name option.
See the name option.
...
@@ -194,11 +226,11 @@ class In(SymbolicInput):
...
@@ -194,11 +226,11 @@ class In(SymbolicInput):
# Note: the documentation above is duplicated in doc/topics/function.txt,
# Note: the documentation above is duplicated in doc/topics/function.txt,
# try to keep it synchronized.
# try to keep it synchronized.
def
__init__
(
self
,
variable
,
name
=
None
,
value
=
None
,
update
=
None
,
def
__init__
(
self
,
variable
,
name
=
None
,
value
=
None
,
update
=
None
,
mutable
=
None
,
strict
=
False
,
allow_downcast
=
None
,
autoname
=
Tru
e
,
mutable
=
None
,
strict
=
False
,
allow_downcast
=
Non
e
,
implicit
=
None
,
borrow
=
None
,
shared
=
False
):
autoname
=
True
,
implicit
=
None
,
borrow
=
None
,
shared
=
False
):
# if shared, an input's value comes from its persistent
#
if shared, an input's value comes from its persistent storage, not from a default stored
#
storage, not from a default stored in the function or from
#
in the function or from
the caller
# the caller
self
.
shared
=
shared
self
.
shared
=
shared
if
borrow
is
None
:
if
borrow
is
None
:
...
...
theano/compile/mode.py
浏览文件 @
4cf7afb4
...
@@ -2,8 +2,6 @@
...
@@ -2,8 +2,6 @@
"""
"""
from
__future__
import
print_function
from
__future__
import
print_function
import
logging
import
logging
import
warnings
from
textwrap
import
dedent
import
numpy
import
numpy
...
@@ -11,24 +9,24 @@ import theano
...
@@ -11,24 +9,24 @@ import theano
from
theano
import
gof
from
theano
import
gof
import
theano.gof.vm
import
theano.gof.vm
from
theano.configparser
import
config
,
AddConfigVar
,
StrParam
from
theano.configparser
import
config
,
AddConfigVar
,
StrParam
from
theano.compile.ops
import
register_view_op_c_code
,
_output_guard
from
theano.compile.ops
import
_output_guard
_logger
=
logging
.
getLogger
(
'theano.compile.mode'
)
_logger
=
logging
.
getLogger
(
'theano.compile.mode'
)
AddConfigVar
(
'optimizer_excluding'
,
AddConfigVar
(
'optimizer_excluding'
,
(
"When using the default mode, we will remove optimizer with these
"
(
"When using the default mode, we will remove optimizer with
"
"
tags. Separate tags with ':'."
),
"these
tags. Separate tags with ':'."
),
StrParam
(
""
,
allow_override
=
False
),
StrParam
(
""
,
allow_override
=
False
),
in_c_key
=
False
)
in_c_key
=
False
)
AddConfigVar
(
'optimizer_including'
,
AddConfigVar
(
'optimizer_including'
,
(
"When using the default mode, we will add optimizer with these tags.
"
(
"When using the default mode, we will add optimizer with
"
"
Separate tags with ':'."
),
"these tags.
Separate tags with ':'."
),
StrParam
(
""
,
allow_override
=
False
),
StrParam
(
""
,
allow_override
=
False
),
in_c_key
=
False
)
in_c_key
=
False
)
AddConfigVar
(
'optimizer_requiring'
,
AddConfigVar
(
'optimizer_requiring'
,
(
"When using the default mode, we will require optimizer with these
"
(
"When using the default mode, we will require optimizer with
"
"
tags. Separate tags with ':'."
),
"these
tags. Separate tags with ':'."
),
StrParam
(
""
,
allow_override
=
False
),
StrParam
(
""
,
allow_override
=
False
),
in_c_key
=
False
)
in_c_key
=
False
)
...
@@ -50,9 +48,9 @@ def check_equal(x, y):
...
@@ -50,9 +48,9 @@ def check_equal(x, y):
y
=
y
.
todense
()
y
=
y
.
todense
()
if
isinstance
(
x
,
numpy
.
ndarray
)
and
isinstance
(
y
,
numpy
.
ndarray
):
if
isinstance
(
x
,
numpy
.
ndarray
)
and
isinstance
(
y
,
numpy
.
ndarray
):
if
(
x
.
dtype
!=
y
.
dtype
if
(
x
.
dtype
!=
y
.
dtype
or
or
x
.
shape
!=
y
.
shape
x
.
shape
!=
y
.
shape
or
or
numpy
.
any
(
abs
(
x
-
y
)
>
1e-10
)):
numpy
.
any
(
abs
(
x
-
y
)
>
1e-10
)):
raise
Exception
(
"Output mismatch."
,
raise
Exception
(
"Output mismatch."
,
{
'performlinker'
:
x
,
'clinker'
:
y
})
{
'performlinker'
:
x
,
'clinker'
:
y
})
else
:
else
:
...
@@ -287,7 +285,8 @@ class Mode(object):
...
@@ -287,7 +285,8 @@ class Mode(object):
def
__str__
(
self
):
def
__str__
(
self
):
return
"
%
s(linker =
%
s, optimizer =
%
s)"
%
(
self
.
__class__
.
__name__
,
return
"
%
s(linker =
%
s, optimizer =
%
s)"
%
(
self
.
__class__
.
__name__
,
self
.
provided_linker
,
self
.
provided_optimizer
)
self
.
provided_linker
,
self
.
provided_optimizer
)
def
__get_optimizer
(
self
):
def
__get_optimizer
(
self
):
if
isinstance
(
self
.
_optimizer
,
gof
.
Query
):
if
isinstance
(
self
.
_optimizer
,
gof
.
Query
):
...
@@ -364,10 +363,11 @@ def get_mode(orig_string):
...
@@ -364,10 +363,11 @@ def get_mode(orig_string):
# DebugMode use its own linker.
# DebugMode use its own linker.
ret
=
DebugMode
(
optimizer
=
config
.
optimizer
)
ret
=
DebugMode
(
optimizer
=
config
.
optimizer
)
else
:
else
:
# The import is needed in case string is ProfileMode
# This might be required if the string is 'ProfileMode'
from
profilemode
import
ProfileMode
,
prof_mode_instance_to_print
from
profilemode
import
ProfileMode
# noqa
ret
=
eval
(
string
from
profilemode
import
prof_mode_instance_to_print
+
'(linker=config.linker, optimizer=config.optimizer)'
)
ret
=
eval
(
string
+
'(linker=config.linker, optimizer=config.optimizer)'
)
elif
string
in
predefined_modes
:
elif
string
in
predefined_modes
:
ret
=
predefined_modes
[
string
]
ret
=
predefined_modes
[
string
]
else
:
else
:
...
...
theano/compile/monitormode.py
浏览文件 @
4cf7afb4
from
__future__
import
print_function
from
__future__
import
print_function
# Note: this code was initially copied from the 'pyutools' package by its
# Note: this code was initially copied from the 'pyutools' package by its
# original author, and re-licensed under Theano's license.
# original author, and re-licensed under Theano's license.
import
numpy
import
theano
import
theano
from
theano.compile.mode
import
Mode
from
theano.compile.mode
import
Mode
...
...
theano/compile/ops.py
浏览文件 @
4cf7afb4
...
@@ -71,11 +71,12 @@ class ViewOp(gof.Op):
...
@@ -71,11 +71,12 @@ class ViewOp(gof.Op):
version
=
[]
version
=
[]
# If any of the c code is unversionned, we have to return ()
# If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs.
# Else, we will return a list of (type name, version) pairs.
for
t
,
(
c
,
v
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
key
=
lambda
pair
:
str
(
pair
[
0
])):
for
t
,
(
c
,
v
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
key
=
lambda
pair
:
str
(
pair
[
0
])):
if
not
v
:
if
not
v
:
warnings
.
warn
(
"Type
%
s has C code for ViewOp, but it has "
warnings
.
warn
(
"Type
%
s has C code for ViewOp, but it has
no
"
"no version. You should add a 'version' keyword arg
"
"version. You should add a 'version' keyword
"
"
when calling register_view_op_c_code."
%
t
,
"arg
when calling register_view_op_c_code."
%
t
,
stacklevel
=
2
)
stacklevel
=
2
)
return
()
return
()
version
.
append
((
str
(
t
),
v
))
version
.
append
((
str
(
t
),
v
))
...
@@ -165,11 +166,13 @@ class DeepCopyOp(gof.Op):
...
@@ -165,11 +166,13 @@ class DeepCopyOp(gof.Op):
version
=
[]
version
=
[]
# If any of the c code is unversionned, we have to return ()
# If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs.
# Else, we will return a list of (type name, version) pairs.
for
t
,
(
c
,
v
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
key
=
lambda
pair
:
str
(
pair
[
0
])):
for
t
,
(
c
,
v
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
key
=
lambda
pair
:
str
(
pair
[
0
])):
if
not
v
:
if
not
v
:
warnings
.
warn
(
"Type
%
s has C code for DeepCopyOp, but it has "
warnings
.
warn
(
"Type
%
s has C code for DeepCopyOp, but it has "
"no version. You should add a 'version' keyword arg "
"no version. You should add a 'version' keyword"
"when calling register_deep_copy_op_c_code."
%
t
,
" arg when calling "
"register_deep_copy_op_c_code."
%
t
,
stacklevel
=
2
)
stacklevel
=
2
)
return
()
return
()
version
.
append
((
str
(
t
),
v
))
version
.
append
((
str
(
t
),
v
))
...
@@ -284,11 +287,12 @@ class Shape(gof.Op):
...
@@ -284,11 +287,12 @@ class Shape(gof.Op):
version
=
[]
version
=
[]
# If any of the c code is unversionned, we have to return ()
# If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs.
# Else, we will return a list of (type name, version) pairs.
for
t
,
(
c
,
v
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
key
=
lambda
pair
:
str
(
pair
[
0
])):
for
t
,
(
c
,
v
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
key
=
lambda
pair
:
str
(
pair
[
0
])):
if
not
v
:
if
not
v
:
warnings
.
warn
(
"Type
%
s has C code for Shape, but it has "
warnings
.
warn
(
"Type
%
s has C code for Shape, but it has
no
"
"no version. You should add a 'version' keyword arg
"
"version. You should add a 'version' keyword
"
"
when calling register_shape_c_code."
%
t
,
"arg
when calling register_shape_c_code."
%
t
,
stacklevel
=
2
)
stacklevel
=
2
)
return
()
return
()
version
.
append
((
str
(
t
),
v
))
version
.
append
((
str
(
t
),
v
))
...
@@ -301,7 +305,6 @@ class Shape(gof.Op):
...
@@ -301,7 +305,6 @@ class Shape(gof.Op):
shape
=
Shape
()
shape
=
Shape
()
_shape
=
shape
# was used in the past, now use shape directly.
_shape
=
shape
# was used in the past, now use shape directly.
#pprint.assign(_shape, printing.MemberPrinter('shape'))
class
Shape_i
(
gof
.
Op
):
class
Shape_i
(
gof
.
Op
):
...
@@ -389,8 +392,11 @@ class Shape_i(gof.Op):
...
@@ -389,8 +392,11 @@ class Shape_i(gof.Op):
return
[()]
return
[()]
def
grad
(
self
,
inp
,
grads
):
def
grad
(
self
,
inp
,
grads
):
return
[
theano
.
gradient
.
grad_not_implemented
(
op
=
self
,
x_pos
=
0
,
x
=
inp
[
0
],
return
[
theano
.
gradient
.
grad_not_implemented
(
comment
=
"No gradient for the shape of a matrix is implemented."
)]
op
=
self
,
x_pos
=
0
,
x
=
inp
[
0
],
comment
=
(
"No gradient for the shape of a matrix "
"is implemented."
))]
def
shape_i
(
var
,
i
,
fgraph
=
None
):
def
shape_i
(
var
,
i
,
fgraph
=
None
):
"""Equivalent of var.shape[i], but apply if possible the shape
"""Equivalent of var.shape[i], but apply if possible the shape
...
@@ -435,9 +441,10 @@ def shape_i(var, i, fgraph=None):
...
@@ -435,9 +441,10 @@ def shape_i(var, i, fgraph=None):
def
register_shape_i_c_code
(
typ
,
code
,
check_input
,
version
=
()):
def
register_shape_i_c_code
(
typ
,
code
,
check_input
,
version
=
()):
""" Tell Shape_i how to generate C code for a Theano Type
""" Tell Shape_i how to generate C code for a Theano Type
:param typ: A Theano type. It must be the Theano class itself and not an
:param typ: A Theano type. It must be the Theano class itself and not
instance of the class.
an instance of the class.
:param code: C code that gets the shape of dimensions
%(i)
s for the Theano type 'typ'.
:param code: C code that gets the shape of dimensions
%(i)
s for the
Theano type 'typ'.
Use
%(iname)
s and
%(oname)
s for the input and output C
Use
%(iname)
s and
%(oname)
s for the input and output C
variable names respectively.
variable names respectively.
:param version: A number indicating the version of the code, for cache.
:param version: A number indicating the version of the code, for cache.
...
@@ -620,7 +627,8 @@ class Rebroadcast(gof.Op):
...
@@ -620,7 +627,8 @@ class Rebroadcast(gof.Op):
return
type
(
self
)
==
type
(
other
)
and
self
.
axis
==
other
.
axis
return
type
(
self
)
==
type
(
other
)
and
self
.
axis
==
other
.
axis
def
__hash__
(
self
):
def
__hash__
(
self
):
items
=
sorted
(
self
.
axis
.
iteritems
())
# no ambiguity because each item key is unique
# no ambiguity because each item key is unique
items
=
sorted
(
self
.
axis
.
iteritems
())
return
hash
((
type
(
self
),
tuple
(
items
)))
return
hash
((
type
(
self
),
tuple
(
items
)))
def
__str__
(
self
):
def
__str__
(
self
):
...
@@ -637,9 +645,9 @@ class Rebroadcast(gof.Op):
...
@@ -637,9 +645,9 @@ class Rebroadcast(gof.Op):
def
make_node
(
self
,
x
):
def
make_node
(
self
,
x
):
if
self
.
axis
.
keys
()
and
(
x
.
ndim
<=
numpy
.
max
(
self
.
axis
.
keys
())):
if
self
.
axis
.
keys
()
and
(
x
.
ndim
<=
numpy
.
max
(
self
.
axis
.
keys
())):
raise
ValueError
(
'Trying to rebroadcast non-existent dimension'
)
raise
ValueError
(
'Trying to rebroadcast non-existent dimension'
)
t
=
x
.
type
.
clone
(
broadcastable
=
[
self
.
axis
.
get
(
i
,
b
)
t
=
x
.
type
.
clone
(
for
i
,
b
in
enumerate
(
broadcastable
=
[
self
.
axis
.
get
(
i
,
b
)
x
.
type
.
broadcastable
)])
for
i
,
b
in
enumerate
(
x
.
type
.
broadcastable
)])
return
gof
.
Apply
(
self
,
[
x
],
[
t
()])
return
gof
.
Apply
(
self
,
[
x
],
[
t
()])
def
perform
(
self
,
node
,
inp
,
out_
):
def
perform
(
self
,
node
,
inp
,
out_
):
...
@@ -702,9 +710,10 @@ class Rebroadcast(gof.Op):
...
@@ -702,9 +710,10 @@ class Rebroadcast(gof.Op):
for
t
,
(
c
,
v
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
for
t
,
(
c
,
v
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
key
=
lambda
pair
:
str
(
pair
[
0
])):
key
=
lambda
pair
:
str
(
pair
[
0
])):
if
not
v
:
if
not
v
:
warnings
.
warn
(
"Type
%
s has C code for Rebroadcast, but it has "
warnings
.
warn
(
"Type
%
s has C code for Rebroadcast, but it "
"no version. You should add a 'version' keyword arg "
"has no version. You should add a 'version' "
"when calling register_rebroadcast_c_code."
%
t
,
"keyword arg when calling "
"register_rebroadcast_c_code."
%
t
,
stacklevel
=
2
)
stacklevel
=
2
)
return
()
return
()
version
.
append
((
str
(
t
),
v
))
version
.
append
((
str
(
t
),
v
))
...
@@ -718,17 +727,18 @@ def register_specify_shape_c_code(typ, code, version=(),
...
@@ -718,17 +727,18 @@ def register_specify_shape_c_code(typ, code, version=(),
c_support_code_apply
=
None
):
c_support_code_apply
=
None
):
""" Tell SpecifyShape how to generate C code for a Theano Type
""" Tell SpecifyShape how to generate C code for a Theano Type
:param typ: A Theano type. It must be the Theano class itself and
not an
:param typ: A Theano type. It must be the Theano class itself and
instance of the class.
not an
instance of the class.
:param code: C code that checks the shape and returns a view for
the Theano type 'typ'.
:param code: C code that checks the shape and returns a view for
Use
%(iname)
s and
%(oname)
s for the input and output C
the Theano type 'typ'. Use
%(iname)
s and
%(oname)
s
variable names respectively.
for the input and output C variable names
%(shape)
s is the vector of shape of
%(iname)
s.
respectively.
%(shape)
s is the vector of shape of
Check that its length is good.
%(iname)
s.
Check that its length is good.
:param version: A number indicating the version of the code, for cache.
:param version: A number indicating the version of the code, for cache.
:param c_support_code_apply: extra code.
:param c_support_code_apply: extra code.
"""
"""
SpecifyShape
.
c_code_and_version
[
typ
]
=
(
code
,
version
,
c_support_code_apply
)
SpecifyShape
.
c_code_and_version
[
typ
]
=
(
code
,
version
,
c_support_code_apply
)
class
SpecifyShape
(
gof
.
Op
):
class
SpecifyShape
(
gof
.
Op
):
...
@@ -784,7 +794,8 @@ class SpecifyShape(gof.Op):
...
@@ -784,7 +794,8 @@ class SpecifyShape(gof.Op):
new_shape
=
[]
new_shape
=
[]
for
dim
in
xrange
(
node
.
inputs
[
0
]
.
ndim
):
for
dim
in
xrange
(
node
.
inputs
[
0
]
.
ndim
):
try
:
try
:
s
=
theano
.
tensor
.
get_scalar_constant_value
(
node
.
inputs
[
1
][
dim
])
s
=
theano
.
tensor
.
get_scalar_constant_value
(
node
.
inputs
[
1
][
dim
])
s
=
theano
.
tensor
.
as_tensor_variable
(
s
)
s
=
theano
.
tensor
.
as_tensor_variable
(
s
)
new_shape
.
append
(
s
)
new_shape
.
append
(
s
)
except
theano
.
tensor
.
NotScalarConstantError
:
except
theano
.
tensor
.
NotScalarConstantError
:
...
@@ -832,7 +843,8 @@ class SpecifyShape(gof.Op):
...
@@ -832,7 +843,8 @@ class SpecifyShape(gof.Op):
code
,
version
,
_
=
self
.
c_code_and_version
[
itype
]
code
,
version
,
_
=
self
.
c_code_and_version
[
itype
]
return
code
%
locals
()
return
code
%
locals
()
return
super
(
SpecifyShape
,
self
)
.
c_code
(
node
,
node
,
inames
,
onames
,
sub
)
return
super
(
SpecifyShape
,
self
)
.
c_code
(
node
,
node
,
inames
,
onames
,
sub
)
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
version
=
[]
version
=
[]
...
@@ -841,9 +853,10 @@ class SpecifyShape(gof.Op):
...
@@ -841,9 +853,10 @@ class SpecifyShape(gof.Op):
for
t
,
(
c
,
v
,
_
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
for
t
,
(
c
,
v
,
_
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
key
=
lambda
pair
:
str
(
pair
[
0
])):
key
=
lambda
pair
:
str
(
pair
[
0
])):
if
not
v
:
if
not
v
:
warnings
.
warn
(
"Type
%
s has C code for SpecifyShape, but it has "
warnings
.
warn
(
"Type
%
s has C code for SpecifyShape, but it "
"no version. You should add a 'version' keyword arg "
"has no version. You should add a 'version' "
"when calling register_specify_shape_c_code."
%
t
,
"keyword arg when calling "
"register_specify_shape_c_code."
%
t
,
stacklevel
=
2
)
stacklevel
=
2
)
return
()
return
()
version
.
append
((
str
(
t
),
v
))
version
.
append
((
str
(
t
),
v
))
...
...
theano/compile/pfunc.py
浏览文件 @
4cf7afb4
差异被折叠。
点击展开。
theano/compile/profilemode.py
浏览文件 @
4cf7afb4
差异被折叠。
点击展开。
theano/compile/profiling.py
浏览文件 @
4cf7afb4
差异被折叠。
点击展开。
theano/compile/sharedvalue.py
浏览文件 @
4cf7afb4
"""Provide a simple user friendly API to Theano-managed memory"""
"""Provide a simple user friendly API to Theano-managed memory"""
__docformat__
=
'restructuredtext en'
# Standard imports
# Standard imports
import
copy
import
copy
import
logging
import
logging
...
@@ -12,6 +10,7 @@ import numpy
...
@@ -12,6 +10,7 @@ import numpy
from
theano.gof
import
Container
,
Variable
,
generic
,
utils
from
theano.gof
import
Container
,
Variable
,
generic
,
utils
_logger
=
logging
.
getLogger
(
'theano.compile.sharedvalue'
)
_logger
=
logging
.
getLogger
(
'theano.compile.sharedvalue'
)
__docformat__
=
'restructuredtext en'
class
SharedVariable
(
Variable
):
class
SharedVariable
(
Variable
):
...
@@ -49,7 +48,8 @@ class SharedVariable(Variable):
...
@@ -49,7 +48,8 @@ class SharedVariable(Variable):
or copied, so they must have the correct type.
or copied, so they must have the correct type.
:param allow_downcast: Only applies if `strict` is False.
:param allow_downcast: Only applies if `strict` is False.
True -> allow assigned value to lose precision when cast during assignment.
True -> allow assigned value to lose precision when cast
during assignment.
False -> never allow precision loss.
False -> never allow precision loss.
None -> only allow downcasting of a Python float to a scalar floatX.
None -> only allow downcasting of a Python float to a scalar floatX.
...
@@ -65,12 +65,13 @@ class SharedVariable(Variable):
...
@@ -65,12 +65,13 @@ class SharedVariable(Variable):
if
container
is
not
None
:
if
container
is
not
None
:
self
.
container
=
container
self
.
container
=
container
if
(
value
is
not
None
)
or
(
strict
is
not
None
):
if
(
value
is
not
None
)
or
(
strict
is
not
None
):
raise
TypeError
(
raise
TypeError
(
'value and strict are ignored if you pass '
'value and strict are ignored if you pass
a container here'
)
'
a container here'
)
else
:
else
:
if
container
is
not
None
:
if
container
is
not
None
:
raise
TypeError
(
'Error to specify both value and container'
)
raise
TypeError
(
'Error to specify both value and container'
)
self
.
container
=
Container
(
self
,
self
.
container
=
Container
(
self
,
storage
=
[
type
.
filter
(
value
,
strict
=
strict
,
storage
=
[
type
.
filter
(
value
,
strict
=
strict
,
allow_downcast
=
allow_downcast
)],
allow_downcast
=
allow_downcast
)],
readonly
=
False
,
readonly
=
False
,
...
@@ -183,7 +184,8 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
...
@@ -183,7 +184,8 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
potential constructors to those that can accept those kwargs.
potential constructors to those that can accept those kwargs.
:note: Some shared variable have ``borrow`` as extra kwargs.
:note: Some shared variable have ``borrow`` as extra kwargs.
`See <http://deeplearning.net/software/theano/tutorial/aliasing.html#borrowing-when-creating-shared-variables>`_ for detail.
`See <http://deeplearning.net/software/theano/tutorial/aliasing.
\
html#borrowing-when-creating-shared-variables>`_ for detail.
:note: Some shared variable have ``broadcastable`` as extra kwargs.
:note: Some shared variable have ``broadcastable`` as extra kwargs.
As shared variable shapes can change, all dimensions default
As shared variable shapes can change, all dimensions default
...
@@ -200,7 +202,8 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
...
@@ -200,7 +202,8 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
try
:
try
:
if
isinstance
(
value
,
Variable
):
if
isinstance
(
value
,
Variable
):
raise
TypeError
(
" Shared variable constructor needs numeric values and not symbolic variables."
)
raise
TypeError
(
"Shared variable constructor needs numeric "
"values and not symbolic variables."
)
for
ctor
in
reversed
(
shared
.
constructors
):
for
ctor
in
reversed
(
shared
.
constructors
):
try
:
try
:
...
...
theano/gof/callcache.py
浏览文件 @
4cf7afb4
import
cPickle
,
logging
import
cPickle
import
logging
_logger
=
logging
.
getLogger
(
"theano.gof.callcache"
)
_logger
=
logging
.
getLogger
(
"theano.gof.callcache"
)
...
@@ -18,9 +19,6 @@ class CallCache(object):
...
@@ -18,9 +19,6 @@ class CallCache(object):
def
persist
(
self
,
filename
=
None
):
def
persist
(
self
,
filename
=
None
):
if
filename
is
None
:
if
filename
is
None
:
filename
=
self
.
filename
filename
=
self
.
filename
# backport
#filename = self.filename if filename is None else filename
f
=
open
(
filename
,
'w'
)
f
=
open
(
filename
,
'w'
)
cPickle
.
dump
(
self
.
cache
,
f
)
cPickle
.
dump
(
self
.
cache
,
f
)
f
.
close
()
f
.
close
()
...
@@ -28,9 +26,6 @@ class CallCache(object):
...
@@ -28,9 +26,6 @@ class CallCache(object):
def
call
(
self
,
fn
,
args
=
(),
key
=
None
):
def
call
(
self
,
fn
,
args
=
(),
key
=
None
):
if
key
is
None
:
if
key
is
None
:
key
=
(
fn
,
tuple
(
args
))
key
=
(
fn
,
tuple
(
args
))
# backport
#key = (fn, tuple(args)) if key is None else key
if
key
not
in
self
.
cache
:
if
key
not
in
self
.
cache
:
_logger
.
debug
(
'cache miss
%
i'
,
len
(
self
.
cache
))
_logger
.
debug
(
'cache miss
%
i'
,
len
(
self
.
cache
))
self
.
cache
[
key
]
=
fn
(
*
args
)
self
.
cache
[
key
]
=
fn
(
*
args
)
...
...
theano/gof/compiledir.py
浏览文件 @
4cf7afb4
...
@@ -8,7 +8,6 @@ import re
...
@@ -8,7 +8,6 @@ import re
import
shutil
import
shutil
import
struct
import
struct
import
socket
import
socket
import
subprocess
import
sys
import
sys
import
textwrap
import
textwrap
...
@@ -295,7 +294,8 @@ def cleanup():
...
@@ -295,7 +294,8 @@ def cleanup():
have_npy_abi_version
=
True
have_npy_abi_version
=
True
elif
obj
.
startswith
(
'c_compiler_str='
):
elif
obj
.
startswith
(
'c_compiler_str='
):
have_c_compiler
=
True
have_c_compiler
=
True
elif
(
isinstance
(
obj
,
(
theano
.
gof
.
Op
,
theano
.
gof
.
Type
))
and
elif
(
isinstance
(
obj
,
(
theano
.
gof
.
Op
,
theano
.
gof
.
Type
))
and
hasattr
(
obj
,
'c_code_cache_version'
)):
hasattr
(
obj
,
'c_code_cache_version'
)):
v
=
obj
.
c_code_cache_version
()
v
=
obj
.
c_code_cache_version
()
if
v
not
in
[(),
None
]
and
v
not
in
key
[
0
]:
if
v
not
in
[(),
None
]
and
v
not
in
key
[
0
]:
...
@@ -310,7 +310,7 @@ def cleanup():
...
@@ -310,7 +310,7 @@ def cleanup():
if
keydata
.
key_pkl
!=
filename
:
if
keydata
.
key_pkl
!=
filename
:
keydata
.
key_pkl
=
filename
keydata
.
key_pkl
=
filename
keydata
.
remove_key
(
key
)
keydata
.
remove_key
(
key
)
except
IOError
as
e
:
except
IOError
:
_logger
.
error
(
_logger
.
error
(
"Could not remove file '
%
s'. To complete "
"Could not remove file '
%
s'. To complete "
"the clean-up, please remove manually "
"the clean-up, please remove manually "
...
@@ -395,7 +395,7 @@ def print_compiledir_content():
...
@@ -395,7 +395,7 @@ def print_compiledir_content():
if
big_key_files
:
if
big_key_files
:
big_key_files
=
sorted
(
big_key_files
,
key
=
lambda
t
:
str
(
t
[
1
]))
big_key_files
=
sorted
(
big_key_files
,
key
=
lambda
t
:
str
(
t
[
1
]))
big_total_size
=
sum
([
s
ize
for
dir
,
size
,
ops
in
big_key_files
])
big_total_size
=
sum
([
s
z
for
_
,
sz
,
_
in
big_key_files
])
print
((
"There are directories with key files bigger than
%
d bytes "
print
((
"There are directories with key files bigger than
%
d bytes "
"(they probably contain big tensor constants)"
%
"(they probably contain big tensor constants)"
%
max_key_file_size
))
max_key_file_size
))
...
...
theano/gof/compilelock.py
浏览文件 @
4cf7afb4
...
@@ -102,8 +102,8 @@ def get_lock(lock_dir=None, **kw):
...
@@ -102,8 +102,8 @@ def get_lock(lock_dir=None, **kw):
# the lock state and raise an error.
# the lock state and raise an error.
while
get_lock
.
n_lock
>
0
:
while
get_lock
.
n_lock
>
0
:
release_lock
()
release_lock
()
raise
Exception
(
"For some unknow reason, the lock was already
taken,
"
raise
Exception
(
"For some unknow reason, the lock was already "
" but no start time was registered."
)
"
taken,
but no start time was registered."
)
now
=
time
.
time
()
now
=
time
.
time
()
if
now
-
get_lock
.
start_time
>
config
.
compile
.
timeout
/
2
:
if
now
-
get_lock
.
start_time
>
config
.
compile
.
timeout
/
2
:
lockpath
=
os
.
path
.
join
(
get_lock
.
lock_dir
,
'lock'
)
lockpath
=
os
.
path
.
join
(
get_lock
.
lock_dir
,
'lock'
)
...
...
theano/gof/cutils.py
浏览文件 @
4cf7afb4
...
@@ -14,18 +14,21 @@ if os.path.exists(os.path.join(config.compiledir, 'cutils_ext.so')):
...
@@ -14,18 +14,21 @@ if os.path.exists(os.path.join(config.compiledir, 'cutils_ext.so')):
def
compile_cutils_code
():
def
compile_cutils_code
():
types
=
[
'npy_'
+
t
for
t
in
[
'int8'
,
'int16'
,
'int32'
,
'int64'
,
'int128'
,
types
=
[
'npy_'
+
t
for
t
in
[
'int8'
,
'int16'
,
'int32'
,
'int64'
,
'int128'
,
'int256'
,
'uint8'
,
'uint16'
,
'uint32'
,
'uint64'
,
'uint128'
,
'uint256'
,
'int256'
,
'uint8'
,
'uint16'
,
'uint32'
,
'float16'
,
'float32'
,
'float64'
,
'float80'
,
'float96'
,
'float128'
,
'uint64'
,
'uint128'
,
'uint256'
,
'float16'
,
'float32'
,
'float64'
,
'float80'
,
'float96'
,
'float128'
,
'float256'
]]
'float256'
]]
complex_types
=
[
'npy_'
+
t
for
t
in
[
'complex32'
,
'complex64'
,
complex_types
=
[
'npy_'
+
t
for
t
in
[
'complex32'
,
'complex64'
,
'complex128'
,
'complex160'
,
'complex192'
,
'complex512'
]]
'complex128'
,
'complex160'
,
'complex192'
,
'complex512'
]]
inplace_map_template
=
"""
inplace_map_template
=
"""
#if defined(
%(typen)
s)
#if defined(
%(typen)
s)
static void
%(type)
s_inplace_add(PyArrayMapIterObject *mit, PyArrayIterObject *it, int inc_or_set)
static void
%(type)
s_inplace_add(PyArrayMapIterObject *mit,
PyArrayIterObject *it, int inc_or_set)
{
{
int index = mit->size;
int index = mit->size;
while (index--) {
while (index--) {
...
@@ -38,10 +41,13 @@ def compile_cutils_code():
...
@@ -38,10 +41,13 @@ def compile_cutils_code():
#endif
#endif
"""
"""
floatadd
=
"((
%(type)
s*)mit->dataptr)[0] = inc_or_set * ((
%(type)
s*)mit->dataptr)[0] + ((
%(type)
s*)it->dataptr)[0];"
floatadd
=
(
"((
%(type)
s*)mit->dataptr)[0] = inc_or_set * "
"((
%(type)
s*)mit->dataptr)[0] + ((
%(type)
s*)it->dataptr)[0];"
)
complexadd
=
"""
complexadd
=
"""
((
%(type)
s*)mit->dataptr)[0].real = inc_or_set * ((
%(type)
s*)mit->dataptr)[0].real + ((
%(type)
s*)it->dataptr)[0].real;
((
%(type)
s*)mit->dataptr)[0].real = inc_or_set *
((
%(type)
s*)mit->dataptr)[0].imag = inc_or_set * ((
%(type)
s*)mit->dataptr)[0].imag + ((
%(type)
s*)it->dataptr)[0].imag;
((
%(type)
s*)mit->dataptr)[0].real + ((
%(type)
s*)it->dataptr)[0].real;
((
%(type)
s*)mit->dataptr)[0].imag = inc_or_set *
((
%(type)
s*)mit->dataptr)[0].imag + ((
%(type)
s*)it->dataptr)[0].imag;
"""
"""
fns
=
''
.
join
([
inplace_map_template
%
{
'type'
:
t
,
'typen'
:
t
.
upper
(),
fns
=
''
.
join
([
inplace_map_template
%
{
'type'
:
t
,
'typen'
:
t
.
upper
(),
...
@@ -51,33 +57,36 @@ def compile_cutils_code():
...
@@ -51,33 +57,36 @@ def compile_cutils_code():
'op'
:
complexadd
%
{
'type'
:
t
}}
'op'
:
complexadd
%
{
'type'
:
t
}}
for
t
in
complex_types
])
for
t
in
complex_types
])
def
gen_binop
(
type
,
typen
):
return
"""
#if defined(
%(typen)
s)
%(type)
s_inplace_add,
#endif
"""
%
dict
(
type
=
type
,
typen
=
typen
)
fn_array
=
(
"static inplace_map_binop addition_funcs[] = {"
+
fn_array
=
(
"static inplace_map_binop addition_funcs[] = {"
+
''
.
join
([
"""
''
.
join
([
gen_binop
(
type
=
t
,
typen
=
t
.
upper
())
#if defined(
%(typen)
s)
for
t
in
types
+
complex_types
])
+
"NULL};
\n
"
)
%(type)
s_inplace_add,
#endif
def
gen_num
(
typen
):
"""
%
{
'type'
:
t
,
'typen'
:
t
.
upper
()}
return
"""
for
t
in
types
+
complex_types
])
+
#if defined(
%(typen)
s)
"""NULL};
%(typen)
s,
"""
)
#endif
"""
%
dict
(
type
=
type
,
typen
=
typen
)
type_number_array
=
(
"static int type_numbers[] = {"
+
type_number_array
=
(
"static int type_numbers[] = {"
+
''
.
join
([
"""
''
.
join
([
gen_num
(
typen
=
t
.
upper
())
#if defined(
%(typen)
s)
for
t
in
types
+
complex_types
])
+
"-1000};"
)
%(typen)
s,
#endif
"""
%
{
'type'
:
t
,
'typen'
:
t
.
upper
()}
for
t
in
types
+
complex_types
])
+
"-1000};"
)
code
=
(
"""
code
=
(
"""
#if NPY_API_VERSION >= 0x00000008
#if NPY_API_VERSION >= 0x00000008
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *, int inc_or_set);
typedef void (*inplace_map_binop)(PyArrayMapIterObject *,
"""
+
fns
+
fn_array
+
type_number_array
+
PyArrayIterObject *, int inc_or_set);
"""
+
fns
+
fn_array
+
type_number_array
+
"""
"""
static int
static int
map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inplace, int inc_or_set)
map_increment(PyArrayMapIterObject *mit, PyObject *op,
inplace_map_binop add_inplace, int inc_or_set)
{
{
PyArrayObject *arr = NULL;
PyArrayObject *arr = NULL;
PyArrayIterObject *it;
PyArrayIterObject *it;
...
@@ -129,7 +138,8 @@ inplace_increment(PyObject *dummy, PyObject *args)
...
@@ -129,7 +138,8 @@ inplace_increment(PyObject *dummy, PyObject *args)
return NULL;
return NULL;
}
}
if (!PyArray_Check(arg_a)) {
if (!PyArray_Check(arg_a)) {
PyErr_SetString(PyExc_ValueError, "needs an ndarray as first argument");
PyErr_SetString(PyExc_ValueError,
"needs an ndarray as first argument");
return NULL;
return NULL;
}
}
...
@@ -285,7 +295,7 @@ try:
...
@@ -285,7 +295,7 @@ try:
open
(
os
.
path
.
join
(
location
,
'__init__.py'
),
'w'
)
.
close
()
open
(
os
.
path
.
join
(
location
,
'__init__.py'
),
'w'
)
.
close
()
try
:
try
:
from
cutils_ext.cutils_ext
import
*
from
cutils_ext.cutils_ext
import
*
# noqa
except
ImportError
:
except
ImportError
:
get_lock
()
get_lock
()
# Ensure no-one else is currently modifying the content of the compilation
# Ensure no-one else is currently modifying the content of the compilation
...
@@ -296,11 +306,11 @@ try:
...
@@ -296,11 +306,11 @@ try:
# We must retry to import it as some other process could
# We must retry to import it as some other process could
# have been compiling it between the first failed import
# have been compiling it between the first failed import
# and when we receive the lock
# and when we receive the lock
from
cutils_ext.cutils_ext
import
*
from
cutils_ext.cutils_ext
import
*
# noqa
except
ImportError
:
except
ImportError
:
compile_cutils
()
compile_cutils
()
from
cutils_ext.cutils_ext
import
*
from
cutils_ext.cutils_ext
import
*
# noqa
finally
:
finally
:
# Release lock on compilation directory.
# Release lock on compilation directory.
...
...
theano/gof/lazylinker_c.py
浏览文件 @
4cf7afb4
...
@@ -15,12 +15,13 @@ _logger = logging.getLogger('theano.gof.lazylinker_c')
...
@@ -15,12 +15,13 @@ _logger = logging.getLogger('theano.gof.lazylinker_c')
force_compile
=
False
force_compile
=
False
version
=
0.21
# must match constant returned in function get_version()
version
=
0.21
# must match constant returned in function get_version()
lazylinker_ext
=
None
def
try_import
():
def
try_import
():
global
lazylinker_ext
global
lazylinker_ext
sys
.
path
[
0
:
0
]
=
[
config
.
compiledir
]
sys
.
path
[
0
:
0
]
=
[
config
.
compiledir
]
import
lazylinker_ext
import
lazylinker_ext
# noqa
del
sys
.
path
[
0
]
del
sys
.
path
[
0
]
...
@@ -43,11 +44,11 @@ try:
...
@@ -43,11 +44,11 @@ try:
# Try to make the location
# Try to make the location
os
.
mkdir
(
location
)
os
.
mkdir
(
location
)
except
OSError
as
e
:
except
OSError
as
e
:
# If we get an error, verify that the error was # 17, the
path already exists,
# If we get an error, verify that the error was # 17, the
#
and that it is a directory
#
path already exists, and that it is a directory Note: we
#
Note: we can't check if it exists before making it, because we are not holding
#
can't check if it exists before making it, because we
#
the lock right now, so we could race another process and get error 17 if we los
e
#
are not holding the lock right now, so we could rac
e
# the race
#
another process and get error 17 if we lose
the race
assert
e
.
errno
==
errno
.
EEXIST
assert
e
.
errno
==
errno
.
EEXIST
assert
os
.
path
.
isdir
(
location
)
assert
os
.
path
.
isdir
(
location
)
...
@@ -142,5 +143,5 @@ except ImportError:
...
@@ -142,5 +143,5 @@ except ImportError:
# Release lock on compilation directory.
# Release lock on compilation directory.
release_lock
()
release_lock
()
from
lazylinker_ext.lazylinker_ext
import
*
from
lazylinker_ext.lazylinker_ext
import
*
# noqa
assert
force_compile
or
(
version
==
get_version
())
assert
force_compile
or
(
version
==
get_version
())
theano/gof/optdb.py
浏览文件 @
4cf7afb4
...
@@ -32,7 +32,7 @@ class DB(object):
...
@@ -32,7 +32,7 @@ class DB(object):
self
.
__db__
=
DefaultOrderedDict
(
OrderedSet
)
self
.
__db__
=
DefaultOrderedDict
(
OrderedSet
)
self
.
_names
=
set
()
self
.
_names
=
set
()
self
.
name
=
None
# will be reset by register
self
.
name
=
None
# will be reset by register
#(via obj.name by the thing doing the registering)
#
(via obj.name by the thing doing the registering)
def
register
(
self
,
name
,
obj
,
*
tags
,
**
kwargs
):
def
register
(
self
,
name
,
obj
,
*
tags
,
**
kwargs
):
"""
"""
...
@@ -175,8 +175,10 @@ class Query(object):
...
@@ -175,8 +175,10 @@ class Query(object):
self
.
exclude
=
OrderedSet
(
self
.
exclude
)
self
.
exclude
=
OrderedSet
(
self
.
exclude
)
def
__str__
(
self
):
def
__str__
(
self
):
return
"Query{inc=
%
s,ex=
%
s,require=
%
s,subquery=
%
s,position_cutoff=
%
d}"
%
(
return
(
"Query{inc=
%
s,ex=
%
s,require=
%
s,subquery=
%
s,"
self
.
include
,
self
.
exclude
,
self
.
require
,
self
.
subquery
,
self
.
position_cutoff
)
"position_cutoff=
%
d}"
%
(
self
.
include
,
self
.
exclude
,
self
.
require
,
self
.
subquery
,
self
.
position_cutoff
))
# add all opt with this tag
# add all opt with this tag
def
including
(
self
,
*
tags
):
def
including
(
self
,
*
tags
):
...
@@ -268,7 +270,7 @@ class SequenceDB(DB):
...
@@ -268,7 +270,7 @@ class SequenceDB(DB):
position_cutoff
=
kwtags
.
pop
(
'position_cutoff'
,
position_cutoff
=
kwtags
.
pop
(
'position_cutoff'
,
config
.
optdb
.
position_cutoff
)
config
.
optdb
.
position_cutoff
)
if
len
(
tags
)
>=
1
and
isinstance
(
tags
[
0
],
Query
):
if
len
(
tags
)
>=
1
and
isinstance
(
tags
[
0
],
Query
):
# the call to super should have raise an error with a good message
# the call to super should have raise an error with a good message
assert
len
(
tags
)
==
1
assert
len
(
tags
)
==
1
if
getattr
(
tags
[
0
],
'position_cutoff'
,
None
):
if
getattr
(
tags
[
0
],
'position_cutoff'
,
None
):
position_cutoff
=
tags
[
0
]
.
position_cutoff
position_cutoff
=
tags
[
0
]
.
position_cutoff
...
...
theano/gof/sched.py
浏览文件 @
4cf7afb4
...
@@ -39,8 +39,8 @@ def make_depends():
...
@@ -39,8 +39,8 @@ def make_depends():
def
depends
(
pair
):
def
depends
(
pair
):
""" Returns True if a depends on b """
""" Returns True if a depends on b """
a
,
b
=
pair
a
,
b
=
pair
return
(
any
(
bout
in
a
.
inputs
for
bout
in
b
.
outputs
)
return
(
any
(
bout
in
a
.
inputs
for
bout
in
b
.
outputs
)
or
or
any
(
depends
((
ainp
.
owner
,
b
))
for
ainp
in
a
.
inputs
any
(
depends
((
ainp
.
owner
,
b
))
for
ainp
in
a
.
inputs
if
ainp
.
owner
))
if
ainp
.
owner
))
return
depends
return
depends
...
@@ -160,12 +160,12 @@ def posort(l, *cmps):
...
@@ -160,12 +160,12 @@ def posort(l, *cmps):
for
b
in
l
:
for
b
in
l
:
assert
not
(
b
in
comes_after
[
a
]
and
a
in
comes_after
[
b
])
assert
not
(
b
in
comes_after
[
a
]
and
a
in
comes_after
[
b
])
for
cmp
in
cmps
:
for
cmp
_fn
in
cmps
:
for
a
in
l
:
for
a
in
l
:
for
b
in
l
:
for
b
in
l
:
if
cmp
(
a
,
b
)
<
0
:
# a wants to come before b
if
cmp
_fn
(
a
,
b
)
<
0
:
# a wants to come before b
# if this wouldn't cause a cycle and isn't already known
# if this wouldn't cause a cycle and isn't already known
if
not
b
in
comes_before
[
a
]
and
not
b
in
comes_after
[
a
]:
if
b
not
in
comes_before
[
a
]
and
b
not
in
comes_after
[
a
]:
add_links
(
a
,
b
)
add_links
(
a
,
b
)
# check() # debug code
# check() # debug code
...
...
theano/gof/tests/test_utils.py
浏览文件 @
4cf7afb4
...
@@ -36,8 +36,11 @@ def test_give_variables_names_small():
...
@@ -36,8 +36,11 @@ def test_give_variables_names_small():
def
test_remove
():
def
test_remove
():
even
=
lambda
x
:
x
%
2
==
0
def
even
(
x
):
odd
=
lambda
x
:
x
%
2
==
1
return
x
%
2
==
0
# The list are neede as with python 3, remove and filter return generators
def
odd
(
x
):
return
x
%
2
==
1
# The list are needed as with python 3, remove and filter return generators
# and we can't compare generators.
# and we can't compare generators.
assert
list
(
remove
(
even
,
range
(
5
)))
==
list
(
filter
(
odd
,
range
(
5
)))
assert
list
(
remove
(
even
,
range
(
5
)))
==
list
(
filter
(
odd
,
range
(
5
)))
theano/gof/toolbox.py
浏览文件 @
4cf7afb4
...
@@ -214,9 +214,9 @@ class Validator(Feature):
...
@@ -214,9 +214,9 @@ class Validator(Feature):
class
ReplaceValidate
(
History
,
Validator
):
class
ReplaceValidate
(
History
,
Validator
):
pickle_rm_attr
=
[
"replace_validate"
,
"replace_all_validate"
,
pickle_rm_attr
=
(
[
"replace_validate"
,
"replace_all_validate"
,
"replace_all_validate_remove"
]
+
\
"replace_all_validate_remove"
]
+
History
.
pickle_rm_attr
+
Validator
.
pickle_rm_attr
History
.
pickle_rm_attr
+
Validator
.
pickle_rm_attr
)
def
on_attach
(
self
,
fgraph
):
def
on_attach
(
self
,
fgraph
):
for
attr
in
(
'replace_validate'
,
'replace_all_validate'
,
for
attr
in
(
'replace_validate'
,
'replace_all_validate'
,
...
@@ -256,11 +256,13 @@ class ReplaceValidate(History, Validator):
...
@@ -256,11 +256,13 @@ class ReplaceValidate(History, Validator):
try
:
try
:
fgraph
.
replace
(
r
,
new_r
,
reason
=
reason
,
verbose
=
False
)
fgraph
.
replace
(
r
,
new_r
,
reason
=
reason
,
verbose
=
False
)
except
Exception
as
e
:
except
Exception
as
e
:
if
(
'The type of the replacement must be the same'
not
in
msg
=
str
(
e
)
str
(
e
)
and
'does not belong to this FunctionGraph'
not
in
str
(
e
)):
s1
=
'The type of the replacement must be the same'
s2
=
'does not belong to this FunctionGraph'
if
(
s1
not
in
msg
and
s2
not
in
msg
):
out
=
sys
.
stderr
out
=
sys
.
stderr
print
(
"<<!! BUG IN FGRAPH.REPLACE OR A LISTENER !!>>"
,
end
=
' '
,
file
=
out
)
print
(
"<<!! BUG IN FGRAPH.REPLACE OR A LISTENER !!>>"
,
print
(
type
(
e
),
e
,
reason
,
file
=
out
)
type
(
e
),
e
,
reason
,
file
=
out
)
# this might fail if the error is in a listener:
# this might fail if the error is in a listener:
# (fgraph.replace kinda needs better internal error handling)
# (fgraph.replace kinda needs better internal error handling)
fgraph
.
revert
(
chk
)
fgraph
.
revert
(
chk
)
...
@@ -286,13 +288,14 @@ class ReplaceValidate(History, Validator):
...
@@ -286,13 +288,14 @@ class ReplaceValidate(History, Validator):
fgraph
.
revert
(
chk
)
fgraph
.
revert
(
chk
)
if
warn
:
if
warn
:
out
=
sys
.
stderr
out
=
sys
.
stderr
print
(
(
print
(
"WARNING: An optimization wanted to replace a Variable"
"WARNING: An optimization wanted to replace a Variable"
" in the graph, but the replacement for it doesn't"
" in the graph, but the replacement for it doesn't"
" remove it. We disabled the optimization."
" remove it. We disabled the optimization."
" Your function runs correctly, but it would be"
" Your function runs correctly, but it would be"
" appreciated if you submit this problem to the"
" appreciated if you submit this problem to the"
" mailing list theano-users so that we can fix it."
),
file
=
out
)
" mailing list theano-users so that we can fix it."
,
file
=
out
)
print
(
reason
,
replacements
,
file
=
out
)
print
(
reason
,
replacements
,
file
=
out
)
raise
ReplacementDidntRemovedError
()
raise
ReplacementDidntRemovedError
()
...
@@ -311,7 +314,8 @@ class NodeFinder(Bookkeeper):
...
@@ -311,7 +314,8 @@ class NodeFinder(Bookkeeper):
def
on_attach
(
self
,
fgraph
):
def
on_attach
(
self
,
fgraph
):
if
self
.
fgraph
is
not
None
:
if
self
.
fgraph
is
not
None
:
raise
Exception
(
"A NodeFinder instance can only serve one FunctionGraph."
)
raise
Exception
(
"A NodeFinder instance can only serve one "
"FunctionGraph."
)
if
hasattr
(
fgraph
,
'get_nodes'
):
if
hasattr
(
fgraph
,
'get_nodes'
):
raise
AlreadyThere
(
"NodeFinder is already present or in conflict"
raise
AlreadyThere
(
"NodeFinder is already present or in conflict"
" with another plugin."
)
" with another plugin."
)
...
...
theano/gof/type.py
浏览文件 @
4cf7afb4
"""WRITEME Defines the `Type` class."""
"""WRITEME Defines the `Type` class."""
__docformat__
=
"restructuredtext en"
from
theano.compat
import
PY3
from
theano.compat
import
PY3
from
theano.gof
import
utils
from
theano.gof
import
utils
...
@@ -13,6 +10,8 @@ from theano.gof import graph
...
@@ -13,6 +10,8 @@ from theano.gof import graph
########
########
from
theano.gof.op
import
CLinkerObject
from
theano.gof.op
import
CLinkerObject
__docformat__
=
"restructuredtext en"
class
CLinkerType
(
CLinkerObject
):
class
CLinkerType
(
CLinkerObject
):
"""Interface specification for Types that can be arguments to a `CLinkerOp`.
"""Interface specification for Types that can be arguments to a `CLinkerOp`.
...
@@ -45,7 +44,8 @@ class CLinkerType(CLinkerObject):
...
@@ -45,7 +44,8 @@ class CLinkerType(CLinkerObject):
- `MethodNotDefined`: Subclass does not implement this method
- `MethodNotDefined`: Subclass does not implement this method
"""
"""
raise
MethodNotDefined
(
"c_literal"
,
type
(
self
),
self
.
__class__
.
__name__
)
raise
MethodNotDefined
(
"c_literal"
,
type
(
self
),
self
.
__class__
.
__name__
)
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
"""Required: Return c code to declare variables that will be
"""Required: Return c code to declare variables that will be
...
@@ -56,7 +56,8 @@ class CLinkerType(CLinkerObject):
...
@@ -56,7 +56,8 @@ class CLinkerType(CLinkerObject):
return "PyObject ** addr_of_
%(name)
s;"
return "PyObject ** addr_of_
%(name)
s;"
:param name: the name of the ``PyObject *`` pointer that will the value for this Type
:param name: the name of the ``PyObject *`` pointer that will
the value for this Type
:type name: string
:type name: string
...
@@ -138,7 +139,8 @@ class CLinkerType(CLinkerObject):
...
@@ -138,7 +139,8 @@ class CLinkerType(CLinkerObject):
- `MethodNotDefined`: Subclass does not implement this method
- `MethodNotDefined`: Subclass does not implement this method
"""
"""
raise
MethodNotDefined
(
"c_extract"
,
type
(
self
),
self
.
__class__
.
__name__
)
raise
MethodNotDefined
(
"c_extract"
,
type
(
self
),
self
.
__class__
.
__name__
)
def
c_extract_out
(
self
,
name
,
sub
,
check_input
=
True
):
def
c_extract_out
(
self
,
name
,
sub
,
check_input
=
True
):
"""Optional: C code to extract a PyObject * instance.
"""Optional: C code to extract a PyObject * instance.
...
@@ -184,11 +186,12 @@ class CLinkerType(CLinkerObject):
...
@@ -184,11 +186,12 @@ class CLinkerType(CLinkerObject):
def
c_sync
(
self
,
name
,
sub
):
def
c_sync
(
self
,
name
,
sub
):
"""Required: Return c code to pack C types back into a PyObject.
"""Required: Return c code to pack C types back into a PyObject.
The code returned from this function must be templated using "
%(name)
s",
The code returned from this function must be templated using
representing the name that the caller wants to call this Variable. The
"
%(name)
s", representing the name that the caller wants to
returned code may set "py_
%(name)
s" to a PyObject* and that PyObject*
call this Variable. The returned code may set "py_
%(name)
s"
will be accessible from Python via variable.data. Do not forget to adjust
to a PyObject* and that PyObject* will be accessible from
reference counts if "py_
%(name)
s" is changed from its original value.
Python via variable.data. Do not forget to adjust reference
counts if "py_
%(name)
s" is changed from its original value.
:Parameters:
:Parameters:
- `name`: WRITEME
- `name`: WRITEME
...
@@ -205,10 +208,11 @@ class CLinkerType(CLinkerObject):
...
@@ -205,10 +208,11 @@ class CLinkerType(CLinkerObject):
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
"""Return a tuple of integers indicating the version of this Type.
"""Return a tuple of integers indicating the version of this Type.
An empty tuple indicates an 'unversioned' Type that will not be cached between processes.
An empty tuple indicates an 'unversioned' Type that will not
be cached between processes.
The cache mechanism may erase cached modules that have been
superceded by newer
The cache mechanism may erase cached modules that have been
versions. See `ModuleCache` for details.
superceded by newer
versions. See `ModuleCache` for details.
"""
"""
return
()
return
()
...
@@ -221,19 +225,21 @@ class PureType(object):
...
@@ -221,19 +225,21 @@ class PureType(object):
- creating `Variable` instances (conventionally, `__call__` does this), and
- creating `Variable` instances (conventionally, `__call__` does this), and
- filtering a value assigned to a `Variable` so that the value
conforms to restrictions
- filtering a value assigned to a `Variable` so that the value
imposed by the type (also known as casting, this is done by `filter`),
conforms to restrictions imposed by the type (also known as
casting, this is done by `filter`),
"""
"""
# the type that will be created by call to make_variable.
Variable
=
graph
.
Variable
Variable
=
graph
.
Variable
# the type that will be created by call to make_variable.
# the type that will be created by call to make_constant
Constant
=
graph
.
Constant
# the type that will be created by call to make_constant
Constant
=
graph
.
Constant
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
None
):
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
None
):
"""Required: Return data or an appropriately wrapped/converted data.
"""Required: Return data or an appropriately wrapped/converted data.
Subclass implementation should raise a TypeError exception if
the data is not of an
Subclass implementation should raise a TypeError exception if
acceptable type.
the data is not of an
acceptable type.
If strict is True, the data returned must be the same as the
If strict is True, the data returned must be the same as the
data passed as an argument. If it is False, and allow_downcast
data passed as an argument. If it is False, and allow_downcast
...
@@ -283,7 +289,8 @@ class PureType(object):
...
@@ -283,7 +289,8 @@ class PureType(object):
return
other
return
other
def
is_valid_value
(
self
,
a
):
def
is_valid_value
(
self
,
a
):
"""Required: Return True for any python object `a` that would be a legal value for a Variable of this Type"""
"""Required: Return True for any python object `a` that would be a
legal value for a Variable of this Type"""
try
:
try
:
self
.
filter
(
a
,
strict
=
True
)
self
.
filter
(
a
,
strict
=
True
)
return
True
return
True
...
@@ -291,7 +298,8 @@ class PureType(object):
...
@@ -291,7 +298,8 @@ class PureType(object):
return
False
return
False
def
value_validity_msg
(
self
,
a
):
def
value_validity_msg
(
self
,
a
):
"""Optional: return a message explaining the output of is_valid_value"""
"""Optional: return a message explaining the output of
is_valid_value"""
return
"none"
return
"none"
def
make_variable
(
self
,
name
=
None
):
def
make_variable
(
self
,
name
=
None
):
...
@@ -371,7 +379,8 @@ class Type(object2, PureType, CLinkerType):
...
@@ -371,7 +379,8 @@ class Type(object2, PureType, CLinkerType):
But you are encouraged to write your own, as described in WRITEME.
But you are encouraged to write your own, as described in WRITEME.
The following following code illustrates the use of a Type instance, here tensor.fvector:
The following following code illustrates the use of a Type
instance, here tensor.fvector:
.. code-block:: python
.. code-block:: python
...
@@ -381,17 +390,21 @@ class Type(object2, PureType, CLinkerType):
...
@@ -381,17 +390,21 @@ class Type(object2, PureType, CLinkerType):
# Create a second Variable with the same Type instance
# Create a second Variable with the same Type instance
c = tensor.fvector()
c = tensor.fvector()
Whenever you create a symbolic variable in theano (technically, `Variable`) it will contain a
Whenever you create a symbolic variable in theano (technically,
reference to a Type instance. That reference is typically constant during the lifetime of
`Variable`) it will contain a reference to a Type instance. That
the Variable. Many variables can refer to a single Type instance, as do b and c above. The
reference is typically constant during the lifetime of the
Type instance defines the kind of value which might end up in that variable when executing
Variable. Many variables can refer to a single Type instance, as
a `Function`. In this sense, theano is like a strongly-typed language because the types
do b and c above. The Type instance defines the kind of value
are included in the graph before the values. In our example above, b is a Variable which is
which might end up in that variable when executing a `Function`.
guaranteed to correspond to a numpy.ndarray of rank 1 when we try to do some computations
In this sense, theano is like a strongly-typed language because
the types are included in the graph before the values. In our
example above, b is a Variable which is guaranteed to correspond
to a numpy.ndarray of rank 1 when we try to do some computations
with it.
with it.
Many `Op` instances will raise an exception if they are applied to inputs with incorrect
Many `Op` instances will raise an exception if they are applied to
types. Type references are also useful to do type-checking in pattern-based optimizations.
inputs with incorrect types. Type references are also useful to
do type-checking in pattern-based optimizations.
"""
"""
def
convert_variable
(
self
,
var
):
def
convert_variable
(
self
,
var
):
...
@@ -451,8 +464,8 @@ class Generic(SingletonType):
...
@@ -451,8 +464,8 @@ class Generic(SingletonType):
"""
"""
Represents a generic Python object.
Represents a generic Python object.
This class implements the `PureType` and `CLinkerType` interfaces
for generic PyObject
This class implements the `PureType` and `CLinkerType` interfaces
instances.
for generic PyObject
instances.
EXAMPLE of what this means, or when you would use this type.
EXAMPLE of what this means, or when you would use this type.
...
...
theano/gof/utils.py
浏览文件 @
4cf7afb4
from
__future__
import
print_function
from
__future__
import
print_function
import
linecache
import
linecache
import
traceback
import
traceback
import
re
import
sys
import
sys
from
theano
import
config
from
theano
import
config
...
@@ -15,7 +14,6 @@ def simple_extract_stack(f=None, limit=None):
...
@@ -15,7 +14,6 @@ def simple_extract_stack(f=None, limit=None):
This is because this update cause an call to os.stat to get the
This is because this update cause an call to os.stat to get the
line content. This cause too much long on cluster.
line content. This cause too much long on cluster.
"""
"""
if
f
is
None
:
if
f
is
None
:
try
:
try
:
...
@@ -48,7 +46,7 @@ if sys.version_info[:2] > (3, 4):
...
@@ -48,7 +46,7 @@ if sys.version_info[:2] > (3, 4):
# I enable my implementation only for some python version just to
# I enable my implementation only for some python version just to
# be sure the Python internal do not change. If this work with
# be sure the Python internal do not change. If this work with
# other python version, you can enable it.
# other python version, you can enable it.
simple_extract_stack
=
traceback
.
extract_stack
simple_extract_stack
=
traceback
.
extract_stack
# noqa
def
add_tag_trace
(
thing
,
user_line
=
1
):
def
add_tag_trace
(
thing
,
user_line
=
1
):
...
@@ -190,8 +188,8 @@ def deprecated(filename, msg=''):
...
@@ -190,8 +188,8 @@ def deprecated(filename, msg=''):
def
g
(
*
args
,
**
kwargs
):
def
g
(
*
args
,
**
kwargs
):
if
printme
[
0
]:
if
printme
[
0
]:
print
(
'WARNING:
%
s.
%
s deprecated.
%
s'
\
print
(
'WARNING:
%
s.
%
s deprecated.
%
s'
%
%
(
filename
,
f
.
__name__
,
msg
))
(
filename
,
f
.
__name__
,
msg
))
printme
[
0
]
=
False
printme
[
0
]
=
False
return
f
(
*
args
,
**
kwargs
)
return
f
(
*
args
,
**
kwargs
)
return
g
return
g
...
@@ -220,7 +218,7 @@ def difference(seq1, seq2):
...
@@ -220,7 +218,7 @@ def difference(seq1, seq2):
raise
Exception
(
'not worth it'
)
raise
Exception
(
'not worth it'
)
set2
=
set
(
seq2
)
set2
=
set
(
seq2
)
return
[
x
for
x
in
seq1
if
x
not
in
set2
]
return
[
x
for
x
in
seq1
if
x
not
in
set2
]
except
Exception
as
e
:
except
Exception
:
# maybe a seq2 element is not hashable
# maybe a seq2 element is not hashable
# maybe seq2 is too short
# maybe seq2 is too short
# -> use O(len(seq1) * len(seq2)) algo
# -> use O(len(seq1) * len(seq2)) algo
...
@@ -311,11 +309,11 @@ def comm_guard(type1, type2):
...
@@ -311,11 +309,11 @@ def comm_guard(type1, type2):
old_f
=
f
.
func_globals
[
f
.
__name__
]
old_f
=
f
.
func_globals
[
f
.
__name__
]
def
new_f
(
arg1
,
arg2
,
*
rest
):
def
new_f
(
arg1
,
arg2
,
*
rest
):
if
(
type1
is
ANY_TYPE
or
isinstance
(
arg1
,
type1
))
\
if
(
(
type1
is
ANY_TYPE
or
isinstance
(
arg1
,
type1
))
and
and
(
type2
is
ANY_TYPE
or
isinstance
(
arg2
,
type2
)):
(
type2
is
ANY_TYPE
or
isinstance
(
arg2
,
type2
)
)):
pass
pass
elif
(
type1
is
ANY_TYPE
or
isinstance
(
arg2
,
type1
))
\
elif
(
(
type1
is
ANY_TYPE
or
isinstance
(
arg2
,
type1
))
and
and
(
type2
is
ANY_TYPE
or
isinstance
(
arg1
,
type2
)):
(
type2
is
ANY_TYPE
or
isinstance
(
arg1
,
type2
)
)):
arg1
,
arg2
=
arg2
,
arg1
arg1
,
arg2
=
arg2
,
arg1
else
:
else
:
return
old_f
(
arg1
,
arg2
,
*
rest
)
return
old_f
(
arg1
,
arg2
,
*
rest
)
...
@@ -337,7 +335,8 @@ def comm_guard(type1, type2):
...
@@ -337,7 +335,8 @@ def comm_guard(type1, type2):
return
type
.
__name__
return
type
.
__name__
new_f
.
__doc__
=
(
str
(
old_f
.
__doc__
)
+
"
\n
"
+
new_f
.
__doc__
=
(
str
(
old_f
.
__doc__
)
+
"
\n
"
+
", "
.
join
([
typename
(
type
)
for
type
in
(
type1
,
type2
)])
+
", "
.
join
([
typename
(
type
)
for
type
in
(
type1
,
type2
)])
+
"
\n
"
+
str
(
f
.
__doc__
or
""
))
"
\n
"
+
str
(
f
.
__doc__
or
""
))
return
new_f
return
new_f
...
@@ -406,15 +405,16 @@ def give_variables_names(variables):
...
@@ -406,15 +405,16 @@ def give_variables_names(variables):
This function is idempotent."""
This function is idempotent."""
names
=
map
(
lambda
var
:
var
.
name
,
variables
)
names
=
map
(
lambda
var
:
var
.
name
,
variables
)
h
=
hist
(
names
)
h
=
hist
(
names
)
bad_var
=
lambda
var
:
not
var
.
name
or
h
[
var
.
name
]
>
1
def
bad_var
(
var
):
return
not
var
.
name
or
h
[
var
.
name
]
>
1
for
i
,
var
in
enumerate
(
filter
(
bad_var
,
variables
)):
for
i
,
var
in
enumerate
(
filter
(
bad_var
,
variables
)):
var
.
name
=
(
var
.
name
or
""
)
+
"_
%
d"
%
i
var
.
name
=
(
var
.
name
or
""
)
+
"_
%
d"
%
i
if
not
unique
(
map
(
str
,
variables
)):
if
not
unique
(
map
(
str
,
variables
)):
raise
ValueError
(
"Not all variables have unique names."
raise
ValueError
(
"Not all variables have unique names. Maybe you've "
"Maybe you've named some of the variables identically"
)
"named some of the variables identically"
)
return
variables
return
variables
...
...
theano/gof/vm.py
浏览文件 @
4cf7afb4
...
@@ -53,7 +53,8 @@ AddConfigVar('vm.lazy',
...
@@ -53,7 +53,8 @@ AddConfigVar('vm.lazy',
in_c_key
=
False
)
in_c_key
=
False
)
def
calculate_reallocate_info
(
order
,
fgraph
,
storage_map
,
compute_map_re
,
dependencies
):
def
calculate_reallocate_info
(
order
,
fgraph
,
storage_map
,
compute_map_re
,
dependencies
):
reallocated_info
=
{}
reallocated_info
=
{}
viewed_by
=
{}
viewed_by
=
{}
for
var
in
fgraph
.
variables
:
for
var
in
fgraph
.
variables
:
...
@@ -74,14 +75,14 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
...
@@ -74,14 +75,14 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
ins
=
None
ins
=
None
if
dmap
and
idx_o
in
dmap
:
if
dmap
and
idx_o
in
dmap
:
idx_v
=
dmap
[
idx_o
]
idx_v
=
dmap
[
idx_o
]
assert
len
(
assert
len
(
idx_v
)
==
1
,
(
"Here we only support the possibility"
idx_v
)
==
1
,
"Here we only support the possibility to destroy one input"
" to destroy one input"
)
ins
=
node
.
inputs
[
idx_v
[
0
]]
ins
=
node
.
inputs
[
idx_v
[
0
]]
if
vmap
and
idx_o
in
vmap
:
if
vmap
and
idx_o
in
vmap
:
assert
ins
is
None
assert
ins
is
None
idx_v
=
vmap
[
idx_o
]
idx_v
=
vmap
[
idx_o
]
assert
len
(
assert
len
(
idx_v
)
==
1
,
(
"Here we only support the possibility"
idx_v
)
==
1
,
"Here we only support the possibility to view one input"
" to view one input"
)
ins
=
node
.
inputs
[
idx_v
[
0
]]
ins
=
node
.
inputs
[
idx_v
[
0
]]
if
ins
is
not
None
:
if
ins
is
not
None
:
assert
isinstance
(
ins
,
theano
.
Variable
)
assert
isinstance
(
ins
,
theano
.
Variable
)
...
@@ -92,10 +93,11 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
...
@@ -92,10 +93,11 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
for
ins
in
node
.
inputs
:
for
ins
in
node
.
inputs
:
assert
not
(
ins
in
view_of
and
viewed_by
[
ins
])
assert
not
(
ins
in
view_of
and
viewed_by
[
ins
])
if
(
getattr
(
ins
,
'ndim'
,
None
)
==
0
and
not
storage_map
[
ins
][
0
]
if
(
getattr
(
ins
,
'ndim'
,
None
)
==
0
and
not
storage_map
[
ins
][
0
]
and
and
ins
not
in
fgraph
.
outputs
and
ins
.
owner
ins
not
in
fgraph
.
outputs
and
ins
.
owner
and
and
all
([
compute_map_re
[
v
][
0
]
for
v
in
dependencies
.
get
(
ins
,
[])])
all
([
compute_map_re
[
v
][
0
]
and
ins
not
in
allocated
):
for
v
in
dependencies
.
get
(
ins
,
[])])
and
ins
not
in
allocated
):
# Constant Memory cannot be changed
# Constant Memory cannot be changed
# Constant and shared variables' storage_map value is not empty
# Constant and shared variables' storage_map value is not empty
reuse_out
=
None
reuse_out
=
None
...
@@ -105,8 +107,9 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
...
@@ -105,8 +107,9 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
if
reuse_out
:
if
reuse_out
:
break
break
for
out
in
order
[
i
]
.
outputs
:
for
out
in
order
[
i
]
.
outputs
:
if
(
getattr
(
out
,
'ndim'
,
None
)
==
0
and
out
not
in
pre_allocated
if
(
getattr
(
out
,
'ndim'
,
None
)
==
0
and
and
ins
.
type
==
out
.
type
):
out
not
in
pre_allocated
and
ins
.
type
==
out
.
type
):
reuse_out
=
out
reuse_out
=
out
pre_allocated
.
add
(
out
)
pre_allocated
.
add
(
out
)
allocated
.
add
(
ins
)
allocated
.
add
(
ins
)
...
@@ -122,8 +125,9 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
...
@@ -122,8 +125,9 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
if
reuse_out
:
if
reuse_out
:
break
break
for
out
in
order
[
i
]
.
outputs
:
for
out
in
order
[
i
]
.
outputs
:
if
(
getattr
(
out
,
'ndim'
,
None
)
==
0
and
out
not
in
pre_allocated
if
(
getattr
(
out
,
'ndim'
,
None
)
==
0
and
and
ins
.
type
==
out
.
type
):
out
not
in
pre_allocated
and
ins
.
type
==
out
.
type
):
reuse_out
=
out
reuse_out
=
out
pre_allocated
.
add
(
out
)
pre_allocated
.
add
(
out
)
allocated
.
add
(
ins
)
allocated
.
add
(
ins
)
...
@@ -508,7 +512,8 @@ class Stack(VM):
...
@@ -508,7 +512,8 @@ class Stack(VM):
st
=
"c"
st
=
"c"
self
.
variable_strides
[
var
]
=
st
self
.
variable_strides
[
var
]
=
st
except
Exception
:
except
Exception
:
link
.
raise_with_op
(
current_apply
,
link
.
raise_with_op
(
current_apply
,
self
.
thunks
[
self
.
node_idx
[
current_apply
]],
self
.
thunks
[
self
.
node_idx
[
current_apply
]],
storage_map
=
storage_map
)
storage_map
=
storage_map
)
for
o
in
current_apply
.
outputs
:
for
o
in
current_apply
.
outputs
:
...
@@ -521,9 +526,9 @@ class Stack(VM):
...
@@ -521,9 +526,9 @@ class Stack(VM):
for
i
in
current_apply
.
inputs
:
for
i
in
current_apply
.
inputs
:
# Garbage Collection -> check if anybody else uses
# Garbage Collection -> check if anybody else uses
# this input
# this input
if
(
dependencies
[
i
]
if
(
dependencies
[
i
]
and
and
i
.
owner
i
.
owner
and
and
i
not
in
self
.
outputs
):
i
not
in
self
.
outputs
):
if
all
(
compute_map
[
v
][
0
]
if
all
(
compute_map
[
v
][
0
]
for
v
in
dependencies
[
i
]):
for
v
in
dependencies
[
i
]):
storage_map
[
i
][
0
]
=
None
storage_map
[
i
][
0
]
=
None
...
@@ -544,10 +549,13 @@ class Stack(VM):
...
@@ -544,10 +549,13 @@ class Stack(VM):
'destroy_map'
,
'destroy_map'
,
False
)):
False
)):
warnings
.
warn
(
warnings
.
warn
(
"There was a bug that existed in the default Theano configuration,"
"There was a bug that existed in "
" only in the development version between July 5th 2012"
"the default Theano configuration,"
" and July 30th 2012. This was not in a released version."
" only in the development version "
" The bug was affecting this script."
,
"between July 5th 2012 and "
"July 30th 2012. This was not in "
"a released version. The bug was "
"affecting this script."
,
# The stack level is not good when
# The stack level is not good when
# inside a Scan.
# inside a Scan.
stacklevel
=
3
stacklevel
=
3
...
@@ -578,7 +586,8 @@ class Stack(VM):
...
@@ -578,7 +586,8 @@ class Stack(VM):
self
.
call_times
[
current_idx
]
+=
dt
self
.
call_times
[
current_idx
]
+=
dt
except
Exception
:
except
Exception
:
link
.
raise_with_op
(
current_apply
,
link
.
raise_with_op
(
current_apply
,
self
.
thunks
[
self
.
node_idx
[
current_apply
]],
self
.
thunks
[
self
.
node_idx
[
current_apply
]],
storage_map
=
storage_map
)
storage_map
=
storage_map
)
...
@@ -639,7 +648,7 @@ class Stack(VM):
...
@@ -639,7 +648,7 @@ class Stack(VM):
if
self
.
allow_gc
:
if
self
.
allow_gc
:
for
v
in
storage_map
:
for
v
in
storage_map
:
if
v
.
owner
and
not
v
in
self
.
outputs
:
if
v
.
owner
and
v
not
in
self
.
outputs
:
if
compute_map
[
v
][
0
]
==
2
:
if
compute_map
[
v
][
0
]
==
2
:
continue
continue
else
:
else
:
...
@@ -840,7 +849,6 @@ class VM_Linker(link.LocalLinker):
...
@@ -840,7 +849,6 @@ class VM_Linker(link.LocalLinker):
vars_idx_inv
[
i
]
=
var
vars_idx_inv
[
i
]
=
var
# put storage_map and compute_map into a int-based scheme
# put storage_map and compute_map into a int-based scheme
n_applies
=
len
(
nodes
)
storage_map_list
=
[
storage_map
[
vars_idx_inv
[
i
]]
storage_map_list
=
[
storage_map
[
vars_idx_inv
[
i
]]
for
i
in
xrange
(
len
(
vars_idx_inv
))]
for
i
in
xrange
(
len
(
vars_idx_inv
))]
compute_map_list
=
[
compute_map
[
vars_idx_inv
[
i
]]
compute_map_list
=
[
compute_map
[
vars_idx_inv
[
i
]]
...
@@ -988,7 +996,8 @@ class VM_Linker(link.LocalLinker):
...
@@ -988,7 +996,8 @@ class VM_Linker(link.LocalLinker):
else
:
else
:
dependencies
=
self
.
compute_gc_dependencies
(
storage_map
)
dependencies
=
self
.
compute_gc_dependencies
(
storage_map
)
reallocated_info
=
calculate_reallocate_info
(
order
,
fgraph
,
storage_map
,
compute_map_re
,
dependencies
)
reallocated_info
=
calculate_reallocate_info
(
order
,
fgraph
,
storage_map
,
compute_map_re
,
dependencies
)
for
node
in
order
:
for
node
in
order
:
try
:
try
:
...
@@ -1014,7 +1023,8 @@ class VM_Linker(link.LocalLinker):
...
@@ -1014,7 +1023,8 @@ class VM_Linker(link.LocalLinker):
lazy
=
config
.
vm
.
lazy
lazy
=
config
.
vm
.
lazy
if
lazy
is
None
:
if
lazy
is
None
:
lazy
=
not
all
([(
not
th
.
lazy
)
for
th
in
thunks
])
lazy
=
not
all
([(
not
th
.
lazy
)
for
th
in
thunks
])
if
not
(
lazy
or
(
config
.
profile
and
config
.
profile_memory
)
or
self
.
use_cloop
or
self
.
callback
):
if
not
(
lazy
or
(
config
.
profile
and
config
.
profile_memory
)
or
self
.
use_cloop
or
self
.
callback
):
for
pair
in
reallocated_info
.
values
():
for
pair
in
reallocated_info
.
values
():
storage_map
[
pair
[
1
]]
=
storage_map
[
pair
[
0
]]
storage_map
[
pair
[
1
]]
=
storage_map
[
pair
[
0
]]
...
@@ -1024,10 +1034,10 @@ class VM_Linker(link.LocalLinker):
...
@@ -1024,10 +1034,10 @@ class VM_Linker(link.LocalLinker):
for
node
in
order
:
for
node
in
order
:
clear_after_this_thunk
=
[]
clear_after_this_thunk
=
[]
for
input
in
node
.
inputs
:
for
input
in
node
.
inputs
:
if
(
(
input
in
computed
)
if
(
input
in
computed
and
and
(
input
not
in
fgraph
.
outputs
)
input
not
in
fgraph
.
outputs
and
and
(
node
==
last_user
[
input
])
node
==
last_user
[
input
]
and
and
input
not
in
reallocated_info
.
keys
()):
input
not
in
reallocated_info
.
keys
()):
clear_after_this_thunk
.
append
(
storage_map
[
input
])
clear_after_this_thunk
.
append
(
storage_map
[
input
])
post_thunk_clear
.
append
(
clear_after_this_thunk
)
post_thunk_clear
.
append
(
clear_after_this_thunk
)
else
:
else
:
...
...
theano/sandbox/cuda/opt_util.py
浏览文件 @
4cf7afb4
...
@@ -2,7 +2,6 @@ from functools import wraps
...
@@ -2,7 +2,6 @@ from functools import wraps
import
numpy
import
numpy
import
theano
from
theano
import
scalar
as
scal
,
Constant
from
theano
import
scalar
as
scal
,
Constant
from
theano.gof
import
local_optimizer
from
theano.gof
import
local_optimizer
from
theano.tensor
import
(
DimShuffle
,
get_scalar_constant_value
,
from
theano.tensor
import
(
DimShuffle
,
get_scalar_constant_value
,
...
...
theano/sandbox/cuda/tests/test_fftconv.py
浏览文件 @
4cf7afb4
...
@@ -7,13 +7,13 @@ from theano.tests import unittest_tools as utt
...
@@ -7,13 +7,13 @@ from theano.tests import unittest_tools as utt
# Skip tests if cuda_ndarray is not available.
# Skip tests if cuda_ndarray is not available.
from
nose.plugins.skip
import
SkipTest
from
nose.plugins.skip
import
SkipTest
import
theano.sandbox.cuda
as
cuda_ndarray
import
theano.sandbox.cuda
as
cuda_ndarray
if
not
cuda_ndarray
.
cuda_available
:
if
not
cuda_ndarray
.
cuda_available
:
# noqa
raise
SkipTest
(
'Optional package cuda not available'
)
raise
SkipTest
(
'Optional package cuda not available'
)
from
theano.misc.pycuda_init
import
pycuda_available
from
theano.misc.pycuda_init
import
pycuda_available
if
not
pycuda_available
:
if
not
pycuda_available
:
# noqa
raise
SkipTest
(
'Optional package pycuda not available'
)
raise
SkipTest
(
'Optional package pycuda not available'
)
from
theano.sandbox.cuda.fftconv
import
scikits_cuda_available
from
theano.sandbox.cuda.fftconv
import
scikits_cuda_available
if
not
scikits_cuda_available
:
if
not
scikits_cuda_available
:
# noqa
raise
SkipTest
(
'Optional package scikits.cuda not available'
)
raise
SkipTest
(
'Optional package scikits.cuda not available'
)
from
theano.sandbox.cuda
import
float32_shared_constructor
as
shared
from
theano.sandbox.cuda
import
float32_shared_constructor
as
shared
...
...
theano/tensor/tests/_test_mpi_roundtrip.py
浏览文件 @
4cf7afb4
...
@@ -2,13 +2,14 @@
...
@@ -2,13 +2,14 @@
# mpiexec -np 2 python _test_mpi_roundtrip.py
# mpiexec -np 2 python _test_mpi_roundtrip.py
from
mpi4py
import
MPI
from
mpi4py
import
MPI
comm
=
MPI
.
COMM_WORLD
import
theano
import
theano
from
theano.tensor.io
import
send
,
recv
,
mpi_cmps
from
theano.tensor.io
import
send
,
recv
,
mpi_cmps
from
theano.gof.sched
import
sort_schedule_fn
from
theano.gof.sched
import
sort_schedule_fn
import
numpy
as
np
import
numpy
as
np
from
sys
import
stdout
,
stderr
,
exit
from
sys
import
stdout
,
stderr
,
exit
comm
=
MPI
.
COMM_WORLD
rank
=
comm
.
Get_rank
()
rank
=
comm
.
Get_rank
()
size
=
comm
.
Get_size
()
size
=
comm
.
Get_size
()
...
...
theano/tests/disturb_mem.py
浏览文件 @
4cf7afb4
from
datetime
import
datetime
__authors__
=
"Ian Goodfellow"
__authors__
=
"Ian Goodfellow"
__credits__
=
[
"Ian Goodfellow"
]
__credits__
=
[
"Ian Goodfellow"
]
__license__
=
"3-clause BSD"
__license__
=
"3-clause BSD"
__maintainer__
=
"Ian Goodfellow"
__maintainer__
=
"Ian Goodfellow"
__email__
=
"goodfeli@iro"
__email__
=
"goodfeli@iro"
from
datetime
import
datetime
def
disturb_mem
():
def
disturb_mem
():
# Allocate a time-dependent amount of objects to increase
# Allocate a time-dependent amount of objects to increase
...
...
theano/tests/main.py
浏览文件 @
4cf7afb4
from
__future__
import
print_function
from
__future__
import
print_function
import
os
,
unittest
,
sys
import
os
import
nose.plugins.builtin
import
unittest
import
sys
from
nose.config
import
Config
from
nose.config
import
Config
from
nose.plugins.manager
import
PluginManager
from
nose.plugins.manager
import
PluginManager
from
numpy.testing.nosetester
import
import_nose
,
NoseTester
import
nose.plugins.builtin
from
numpy.testing.nosetester
import
NoseTester
from
numpy.testing.noseclasses
import
KnownFailure
,
NumpyTestProgram
from
numpy.testing.noseclasses
import
KnownFailure
,
NumpyTestProgram
...
@@ -31,7 +34,7 @@ class TheanoNoseTester(NoseTester):
...
@@ -31,7 +34,7 @@ class TheanoNoseTester(NoseTester):
:type extra_argv: list
:type extra_argv: list
:param extra_argv: List with any extra arguments to pass to nosetests.
:param extra_argv: List with any extra arguments to pass to nosetests.
"""
"""
#self.package_path = os.path.abspath(self.package_path)
#
self.package_path = os.path.abspath(self.package_path)
argv
=
[
__file__
,
self
.
package_path
]
argv
=
[
__file__
,
self
.
package_path
]
argv
+=
[
'--verbosity'
,
str
(
verbose
)]
argv
+=
[
'--verbosity'
,
str
(
verbose
)]
if
extra_argv
:
if
extra_argv
:
...
@@ -39,8 +42,6 @@ class TheanoNoseTester(NoseTester):
...
@@ -39,8 +42,6 @@ class TheanoNoseTester(NoseTester):
return
argv
return
argv
def
_show_system_info
(
self
):
def
_show_system_info
(
self
):
nose
=
import_nose
()
import
theano
import
theano
print
(
"Theano version
%
s"
%
theano
.
__version__
)
print
(
"Theano version
%
s"
%
theano
.
__version__
)
theano_dir
=
os
.
path
.
dirname
(
theano
.
__file__
)
theano_dir
=
os
.
path
.
dirname
(
theano
.
__file__
)
...
@@ -55,16 +56,14 @@ class TheanoNoseTester(NoseTester):
...
@@ -55,16 +56,14 @@ class TheanoNoseTester(NoseTester):
Takes the same arguments as `test`.
Takes the same arguments as `test`.
"""
"""
# fail with nice error message if nose is not present
nose
=
import_nose
()
# compile argv
# compile argv
argv
=
self
.
_test_argv
(
verbose
,
extra_argv
)
argv
=
self
.
_test_argv
(
verbose
,
extra_argv
)
# numpy way of doing coverage
# numpy way of doing coverage
if
coverage
:
if
coverage
:
argv
+=
[
'--cover-package=
%
s'
%
self
.
package_name
,
'--with-coverage'
,
argv
+=
[
'--cover-package=
%
s'
%
self
.
package_name
,
'--cover-tests'
,
'--cover-inclusive'
,
'--cover-erase'
]
'--with-coverage'
,
'--cover-tests'
,
'--cover-inclusive'
,
'--cover-erase'
]
# Capture output only if needed
# Capture output only if needed
if
not
capture
:
if
not
capture
:
...
@@ -91,7 +90,8 @@ class TheanoNoseTester(NoseTester):
...
@@ -91,7 +90,8 @@ class TheanoNoseTester(NoseTester):
:param extra_argv: List with any extra arguments to pass to nosetests.
:param extra_argv: List with any extra arguments to pass to nosetests.
:type coverage: bool
:type coverage: bool
:param coverage: If True, report coverage of Theano code. Default is False.
:param coverage: If True, report coverage of Theano
code. Default is False.
:type capture: bool
:type capture: bool
:param capture: If True, capture the standard output of the tests, like
:param capture: If True, capture the standard output of the tests, like
...
@@ -134,8 +134,6 @@ class TheanoNoseTester(NoseTester):
...
@@ -134,8 +134,6 @@ class TheanoNoseTester(NoseTester):
def
main
(
modulename
):
def
main
(
modulename
):
debug
=
False
if
0
:
if
0
:
unittest
.
main
()
unittest
.
main
()
elif
len
(
sys
.
argv
)
==
2
and
sys
.
argv
[
1
]
==
"--debug"
:
elif
len
(
sys
.
argv
)
==
2
and
sys
.
argv
[
1
]
==
"--debug"
:
...
...
theano/tests/test_flake8.py
浏览文件 @
4cf7afb4
...
@@ -20,7 +20,6 @@ __contact__ = "Saizheng Zhang <saizhenglisa..at..gmail.com>"
...
@@ -20,7 +20,6 @@ __contact__ = "Saizheng Zhang <saizhenglisa..at..gmail.com>"
whitelist_flake8
=
[
whitelist_flake8
=
[
"__init__.py"
,
"__init__.py"
,
"version.py"
,
"tests/test_gradient.py"
,
"tests/test_gradient.py"
,
"tests/test_config.py"
,
"tests/test_config.py"
,
"tests/diverse_tests.py"
,
"tests/diverse_tests.py"
,
...
@@ -31,37 +30,20 @@ whitelist_flake8 = [
...
@@ -31,37 +30,20 @@ whitelist_flake8 = [
"tests/test_record.py"
,
"tests/test_record.py"
,
"tests/__init__.py"
,
"tests/__init__.py"
,
"tests/test_updates.py"
,
"tests/test_updates.py"
,
"tests/main.py"
,
"tests/test_pickle_unpickle_theano_fn.py"
,
"tests/test_pickle_unpickle_theano_fn.py"
,
"tests/test_determinism.py"
,
"tests/test_determinism.py"
,
"tests/record.py"
,
"tests/record.py"
,
"tests/test_printing.py"
,
"tests/test_tutorial.py"
,
"tests/test_tutorial.py"
,
"tests/disturb_mem.py"
,
"tests/unittest_tools.py"
,
"tests/unittest_tools.py"
,
"compile/ops.py"
,
"compile/debugmode.py"
,
"compile/function.py"
,
"compile/pfunc.py"
,
"compile/mode.py"
,
"compile/profilemode.py"
,
"compile/builders.py"
,
"compile/__init__.py"
,
"compile/__init__.py"
,
"compile/profiling.py"
,
"compile/profiling.py"
,
"compile/function_module.py"
,
"compile/sharedvalue.py"
,
"compile/monitormode.py"
,
"compile/io.py"
,
"compile/module.py"
,
"compile/tests/test_builders.py"
,
"compile/tests/test_builders.py"
,
"compile/tests/test_misc.py"
,
"compile/tests/test_misc.py"
,
"compile/tests/test_monitormode.py"
,
"compile/tests/test_monitormode.py"
,
"compile/tests/test_function_module.py"
,
"compile/tests/test_function_module.py"
,
"compile/tests/test_inplace_opt_for_value.py"
,
"compile/tests/test_shared.py"
,
"compile/tests/test_shared.py"
,
"compile/tests/test_ops.py"
,
"compile/tests/test_ops.py"
,
"compile/tests/test_pfunc.py"
,
"compile/tests/test_pfunc.py"
,
"compile/tests/test_module.py"
,
"compile/tests/test_debugmode.py"
,
"compile/tests/test_debugmode.py"
,
"compile/tests/test_profiling.py"
,
"compile/tests/test_profiling.py"
,
"typed_list/type.py"
,
"typed_list/type.py"
,
...
@@ -94,16 +76,13 @@ whitelist_flake8 = [
...
@@ -94,16 +76,13 @@ whitelist_flake8 = [
"tensor/io.py"
,
"tensor/io.py"
,
"tensor/elemwise_cgen.py"
,
"tensor/elemwise_cgen.py"
,
"tensor/raw_random.py"
,
"tensor/raw_random.py"
,
"tensor/randomstreams.py"
,
"tensor/blas_scipy.py"
,
"tensor/blas_scipy.py"
,
"tensor/basic.py"
,
"tensor/basic.py"
,
"tensor/tests/test_subtensor.py"
,
"tensor/tests/test_subtensor.py"
,
"tensor/tests/test_utils.py"
,
"tensor/tests/test_utils.py"
,
"tensor/tests/test_nlinalg.py"
,
"tensor/tests/test_nlinalg.py"
,
"tensor/tests/test_randomstreams.py"
,
"tensor/tests/test_shared_randomstreams.py"
,
"tensor/tests/test_shared_randomstreams.py"
,
"tensor/tests/test_misc.py"
,
"tensor/tests/test_misc.py"
,
"tensor/tests/test_naacl09.py"
,
"tensor/tests/mlp_test.py"
,
"tensor/tests/mlp_test.py"
,
"tensor/tests/test_opt_uncanonicalize.py"
,
"tensor/tests/test_opt_uncanonicalize.py"
,
"tensor/tests/test_opt.py"
,
"tensor/tests/test_opt.py"
,
...
@@ -155,7 +134,6 @@ whitelist_flake8 = [
...
@@ -155,7 +134,6 @@ whitelist_flake8 = [
"sandbox/test_theano_object.py"
,
"sandbox/test_theano_object.py"
,
"sandbox/test_scan.py"
,
"sandbox/test_scan.py"
,
"sandbox/rng_mrg.py"
,
"sandbox/rng_mrg.py"
,
"sandbox/downsample.py"
,
"sandbox/solve.py"
,
"sandbox/solve.py"
,
"sandbox/theano_object.py"
,
"sandbox/theano_object.py"
,
"sandbox/scan.py"
,
"sandbox/scan.py"
,
...
@@ -190,7 +168,6 @@ whitelist_flake8 = [
...
@@ -190,7 +168,6 @@ whitelist_flake8 = [
"sandbox/cuda/nvcc_compiler.py"
,
"sandbox/cuda/nvcc_compiler.py"
,
"sandbox/cuda/neighbours.py"
,
"sandbox/cuda/neighbours.py"
,
"sandbox/cuda/tests/walltime.py"
,
"sandbox/cuda/tests/walltime.py"
,
"sandbox/cuda/tests/test_fftconv.py"
,
"sandbox/cuda/tests/test_gradient.py"
,
"sandbox/cuda/tests/test_gradient.py"
,
"sandbox/cuda/tests/test_neighbours.py"
,
"sandbox/cuda/tests/test_neighbours.py"
,
"sandbox/cuda/tests/test_conv_cuda_ndarray.py"
,
"sandbox/cuda/tests/test_conv_cuda_ndarray.py"
,
...
@@ -218,7 +195,6 @@ whitelist_flake8 = [
...
@@ -218,7 +195,6 @@ whitelist_flake8 = [
"sandbox/scan_module/tests/test_utils.py"
,
"sandbox/scan_module/tests/test_utils.py"
,
"sandbox/scan_module/tests/test_scan.py"
,
"sandbox/scan_module/tests/test_scan.py"
,
"sandbox/linalg/ops.py"
,
"sandbox/linalg/ops.py"
,
"sandbox/linalg/kron.py"
,
"sandbox/linalg/__init__.py"
,
"sandbox/linalg/__init__.py"
,
"sandbox/linalg/tests/test_linalg.py"
,
"sandbox/linalg/tests/test_linalg.py"
,
"sandbox/gpuarray/comp.py"
,
"sandbox/gpuarray/comp.py"
,
...
@@ -288,24 +264,12 @@ whitelist_flake8 = [
...
@@ -288,24 +264,12 @@ whitelist_flake8 = [
"sparse/sandbox/truedot.py"
,
"sparse/sandbox/truedot.py"
,
"sparse/sandbox/sp.py"
,
"sparse/sandbox/sp.py"
,
"gof/destroyhandler.py"
,
"gof/destroyhandler.py"
,
"gof/vm.py"
,
"gof/cutils.py"
,
"gof/compiledir.py"
,
"gof/unify.py"
,
"gof/unify.py"
,
"gof/lazylinker_c.py"
,
"gof/optdb.py"
,
"gof/utils.py"
,
"gof/graph.py"
,
"gof/graph.py"
,
"gof/callcache.py"
,
"gof/python25.py"
,
"gof/type.py"
,
"gof/__init__.py"
,
"gof/__init__.py"
,
"gof/cc.py"
,
"gof/cc.py"
,
"gof/opt.py"
,
"gof/opt.py"
,
"gof/compilelock.py"
,
"gof/link.py"
,
"gof/link.py"
,
"gof/sched.py"
,
"gof/toolbox.py"
,
"gof/fg.py"
,
"gof/fg.py"
,
"gof/op.py"
,
"gof/op.py"
,
"gof/cmodule.py"
,
"gof/cmodule.py"
,
...
@@ -322,9 +286,6 @@ whitelist_flake8 = [
...
@@ -322,9 +286,6 @@ whitelist_flake8 = [
"gof/tests/test_cc.py"
,
"gof/tests/test_cc.py"
,
"gof/tests/test_compute_test_value.py"
,
"gof/tests/test_compute_test_value.py"
,
"gof/sandbox/equilibrium.py"
,
"gof/sandbox/equilibrium.py"
,
"sandbox/cuda/opt_util.py"
,
"gof/tests/test_utils.py"
,
"tensor/tests/_test_mpi_roundtrip.py"
,
]
]
...
...
theano/version.py
浏览文件 @
4cf7afb4
try
:
try
:
from
theano.generated_version
import
*
from
theano.generated_version
import
*
# noqa
except
ImportError
:
except
ImportError
:
short_version
=
'unknown'
short_version
=
'unknown'
version
=
'unknown'
version
=
'unknown'
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论