Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
4cf7afb4
提交
4cf7afb4
authored
5月 25, 2015
作者:
Frédéric Bastien
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2952 from abergeron/flake8
Flake8 work
上级
1003abb0
0761b0b8
全部展开
显示空白字符变更
内嵌
并排
正在显示
31 个修改的文件
包含
424 行增加
和
369 行删除
+424
-369
builders.py
theano/compile/builders.py
+7
-6
debugmode.py
theano/compile/debugmode.py
+0
-0
function.py
theano/compile/function.py
+60
-51
function_module.py
theano/compile/function_module.py
+0
-0
io.py
theano/compile/io.py
+66
-34
mode.py
theano/compile/mode.py
+17
-17
monitormode.py
theano/compile/monitormode.py
+1
-1
ops.py
theano/compile/ops.py
+50
-37
pfunc.py
theano/compile/pfunc.py
+0
-0
profilemode.py
theano/compile/profilemode.py
+0
-0
profiling.py
theano/compile/profiling.py
+0
-0
sharedvalue.py
theano/compile/sharedvalue.py
+11
-8
callcache.py
theano/gof/callcache.py
+2
-7
compiledir.py
theano/gof/compiledir.py
+4
-4
compilelock.py
theano/gof/compilelock.py
+2
-2
cutils.py
theano/gof/cutils.py
+42
-32
lazylinker_c.py
theano/gof/lazylinker_c.py
+8
-7
optdb.py
theano/gof/optdb.py
+6
-4
sched.py
theano/gof/sched.py
+5
-5
test_utils.py
theano/gof/tests/test_utils.py
+6
-3
toolbox.py
theano/gof/toolbox.py
+14
-10
type.py
theano/gof/type.py
+48
-35
utils.py
theano/gof/utils.py
+15
-15
vm.py
theano/gof/vm.py
+40
-30
opt_util.py
theano/sandbox/cuda/opt_util.py
+0
-1
test_fftconv.py
theano/sandbox/cuda/tests/test_fftconv.py
+3
-3
_test_mpi_roundtrip.py
theano/tensor/tests/_test_mpi_roundtrip.py
+2
-1
disturb_mem.py
theano/tests/disturb_mem.py
+2
-2
main.py
theano/tests/main.py
+12
-14
test_flake8.py
theano/tests/test_flake8.py
+0
-39
version.py
theano/version.py
+1
-1
没有找到文件。
theano/compile/builders.py
浏览文件 @
4cf7afb4
...
...
@@ -15,7 +15,9 @@ class OpFromGraph(gof.Op):
TODO:
- examples for a multi-layer mlp. where?
- __hash__, __eq__ otherwise won't merge, try gof.opt.is_same_graph_with_merge(op1.new_outputs, op2, new_outputs)
- __hash__, __eq__ otherwise won't merge, try
gof.opt.is_same_graph_with_merge(op1.new_outputs, op2,
new_outputs)
- c_code() to remove the double overhead?
- opt to unfold it, work inplace on inputs
- grad() make it support DisconnectedType and the new interface
...
...
@@ -76,8 +78,6 @@ class OpFromGraph(gof.Op):
# not see them. Otherwise their is problem with the gradient.
self
.
shared_inputs
=
[
var
for
var
in
gof
.
graph
.
inputs
(
outputs
)
if
isinstance
(
var
,
SharedVariable
)]
used_inputs
=
[
var
for
var
in
gof
.
graph
.
inputs
(
outputs
)
if
not
isinstance
(
var
,
gof
.
Constant
)]
shared_vars
=
[
var
.
type
()
for
var
in
self
.
shared_inputs
]
new
=
rebuild_collect_shared
(
outputs
,
inputs
=
inputs
+
shared_vars
,
replace
=
dict
(
zip
(
self
.
shared_inputs
,
...
...
@@ -110,8 +110,8 @@ class OpFromGraph(gof.Op):
def
make_node
(
self
,
*
inputs
):
for
input
,
type
in
zip
(
inputs
,
self
.
input_types
):
if
not
type
==
input
.
type
:
raise
TypeError
(
"Wrong type, expected
%
s but got
%
s"
%
(
type
,
input
.
type
))
raise
TypeError
(
"Wrong type, expected
%
s but got
%
s"
%
(
type
,
input
.
type
))
return
gof
.
Apply
(
self
,
list
(
inputs
)
+
self
.
shared_inputs
,
[
type
()
for
type
in
self
.
output_types
])
...
...
@@ -143,7 +143,8 @@ class OpFromGraph(gof.Op):
grad_ops
=
self
.
grad_ops
else
:
gs
=
theano
.
gradient
.
grad
(
cost
=
None
,
known_grads
=
dict
(
zip
(
self
.
new_outputs
,
output_grads
)),
known_grads
=
dict
(
zip
(
self
.
new_outputs
,
output_grads
)),
wrt
=
self
.
new_inputs
,
disconnected_inputs
=
'ignore'
)
...
...
theano/compile/debugmode.py
浏览文件 @
4cf7afb4
差异被折叠。
点击展开。
theano/compile/function.py
浏览文件 @
4cf7afb4
"""Define the `function` function
"""
__docformat__
=
"restructuredtext en"
import
cPickle
import
logging
_logger
=
logging
.
getLogger
(
'theano.compile.function'
)
import
traceback
as
tb
import
re
...
...
@@ -14,9 +11,11 @@ from theano.compile.function_module import orig_function
from
theano.compile.pfunc
import
pfunc
from
numpy
import
any
import
warnings
from
theano
import
gof
from
theano
import
compat
__docformat__
=
"restructuredtext en"
_logger
=
logging
.
getLogger
(
'theano.compile.function'
)
def
function_dump
(
filename
,
inputs
,
outputs
=
None
,
mode
=
None
,
updates
=
None
,
givens
=
None
,
...
...
@@ -70,54 +69,67 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
:type mode: string or `Mode` instance.
:param mode: compilation mode
:type updates: iterable over pairs (shared_variable, new_expression). List, tuple or OrderedDict.
:param updates: update the values for SharedVariable inputs according to these expressions
:type updates: iterable over pairs (shared_variable, new_expression).
List, tuple or OrderedDict.
:param updates: update the values for SharedVariable inputs
according to these expressions
:type givens: iterable over pairs (Var1, Var2) of Variables. List,
tuple or dict. The Var1
and Var2 in each pair must have the same Type.
:param givens: specific substitutions to make in the computation
graph (Var2 replaces
Var1).
:type givens: iterable over pairs (Var1, Var2) of Variables. List,
tuple or dict. The Var1 and Var2 in each pair must
have the same Type.
:param givens: specific substitutions to make in the computation
graph (Var2 replaces
Var1).
:type no_default_updates: either bool or list of Variables
:param no_default_updates: if True, do not perform any automatic update on Variables.
If False (default), perform them all. Else, perform automatic updates on all Variables
that are neither in "updates" nor in "no_default_updates".
:param name: an optional name for this function. The profile mode will print the time spent in this function.
:param rebuild_strict: True (Default) is the safer and better tested setting, in which case
`givens` must substitute new variables with the same Type as the variables they replace.
False is a you-better-know-what-you-are-doing setting, that permits `givens` to replace
variables with new variables of any Type. The consequence of changing a Type is that all
results depending on that variable may have a different Type too (the graph is rebuilt from
inputs to outputs). If one of the new types does not make sense for one of the Ops in the
:param no_default_updates: if True, do not perform any automatic
update on Variables. If False (default), perform them
all. Else, perform automatic updates on all Variables that are
neither in "updates" nor in "no_default_updates".
:param name: an optional name for this function. The profile mode
will print the time spent in this function.
:param rebuild_strict: True (Default) is the safer and better
tested setting, in which case `givens` must substitute new
variables with the same Type as the variables they replace.
False is a you-better-know-what-you-are-doing setting, that
permits `givens` to replace variables with new variables of
any Type. The consequence of changing a Type is that all
results depending on that variable may have a different Type
too (the graph is rebuilt from inputs to outputs). If one of
the new types does not make sense for one of the Ops in the
graph, an Exception will be raised.
:type allow_input_downcast: Boolean or None
:param allow_input_downcast: True means that the values passed as
inputs when calling the function can be silently downcasted to fit
the dtype of the corresponding Variable, which may lose precision.
False means that it will only be cast to a more general, or
precise, type. None (default) is almost like False, but allows
downcasting of Python float scalars to floatX.
inputs when calling the function can be silently downcasted to
fit the dtype of the corresponding Variable, which may lose
precision. False means that it will only be cast to a more
general, or precise, type. None (default) is almost like
False, but allows downcasting of Python float scalars to
floatX.
:type profile: None, True, or ProfileStats instance
:param profile: accumulate profiling information into a given ProfileStats
instance. If argument is `True` then a new ProfileStats instance will be
used. This profiling object will be available via self.profile.
:param profile: accumulate profiling information into a given
ProfileStats instance. If argument is `True` then a new
ProfileStats instance will be used. This profiling object
will be available via self.profile.
:param on_unused_input: What to do if a variable in the 'inputs' list is
not used in the graph. Possible values are 'raise', 'warn', 'ignore' and None.
:param on_unused_input: What to do if a variable in the 'inputs'
list is not used in the graph. Possible values are 'raise',
'warn', 'ignore' and None.
:rtype: Function instance
:returns: a callable object that will compute the outputs (given the inputs)
and update the implicit function arguments according to the `updates`.
:returns: a callable object that will compute the outputs (given
the inputs) and update the implicit function arguments
according to the `updates`.
:note: Regarding givens: Be careful to make sure that these substitutions are
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in
another expression is undefined. Replacements specified with givens are different from
optimizations in that Var2 is not expected to be equivalent to Var1.
:note: Regarding givens: Be careful to make sure that these
substitutions are independent--behaviour when Var1 of one pair
appears in the graph leading to Var2 in another expression is
undefined. Replacements specified with givens are different
from optimizations in that Var2 is not expected to be
equivalent to Var1.
Internal documentation:
...
...
@@ -195,10 +207,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
was easier to develop the VM in Python then translate it to C instead
of just writing it in C from scratch.
CVM stands for C Virtual Machine.
"""
if
isinstance
(
outputs
,
dict
):
output_items
=
outputs
.
items
()
...
...
@@ -214,7 +222,6 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
output_keys
.
append
(
pair
[
0
])
outputs
.
append
(
pair
[
1
])
else
:
output_keys
=
None
...
...
@@ -256,12 +263,13 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
if
givens
is
None
:
givens
=
[]
if
not
isinstance
(
inputs
,
(
list
,
tuple
)):
raise
Exception
(
"Input variables of a Theano function should be"
" contained in a list, even when there is a single input."
)
raise
Exception
(
"Input variables of a Theano function should be "
"contained in a list, even when there is a single "
"input."
)
# compute some features of the arguments:
uses_In
=
any
([
isinstance
(
i
,
In
)
for
i
in
inputs
])
# N.B. the square brackets are ncessary
uses_tuple
=
any
([
isinstance
(
i
,
(
list
,
tuple
))
for
i
in
inputs
])
# N.B. the square brackets are ncessary
uses_In
=
any
([
isinstance
(
i
,
In
)
for
i
in
inputs
])
uses_tuple
=
any
([
isinstance
(
i
,
(
list
,
tuple
))
for
i
in
inputs
])
uses_updates
=
bool
(
updates
)
uses_givens
=
bool
(
givens
)
...
...
@@ -275,7 +283,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
if
uses_In
or
uses_tuple
:
# we must use old semantics in this case.
if
profile
:
raise
NotImplementedError
(
'profiling not supported in old-style function'
)
raise
NotImplementedError
(
"profiling not supported in old-style "
"function"
)
if
uses_updates
or
uses_givens
:
raise
NotImplementedError
(
"In() instances and tuple inputs trigger the old "
...
...
@@ -284,8 +293,8 @@ def function(inputs, outputs=None, mode=None, updates=None, givens=None,
mode
=
mode
,
accept_inplace
=
accept_inplace
,
name
=
name
)
else
:
# note: pfunc will also call orig_function-- orig_function is
a choke point
# that all compilation must pass through
# note: pfunc will also call orig_function-- orig_function is
#
a choke point
that all compilation must pass through
fn
=
pfunc
(
params
=
inputs
,
outputs
=
outputs
,
mode
=
mode
,
...
...
theano/compile/function_module.py
浏览文件 @
4cf7afb4
差异被折叠。
点击展开。
theano/compile/io.py
浏览文件 @
4cf7afb4
"""Define `SymbolicInput`, `SymbolicOutput`, `In`, `Out` """
__docformat__
=
'restructuredtext en'
from
theano
import
gof
from
sharedvalue
import
SharedVariable
...
...
@@ -7,6 +6,8 @@ from sharedvalue import SharedVariable
import
logging
_logger
=
logging
.
getLogger
(
"theano.compile.io"
)
__docformat__
=
'restructuredtext en'
class
SymbolicInput
(
object
):
"""
...
...
@@ -17,34 +18,47 @@ class SymbolicInput(object):
not computed from its owner.
name: Any type. (If autoname=True, defaults to variable.name).
If name is a valid Python identifier, this input can be set by
kwarg, and its value
can be accessed by self.<name>.
If name is a valid Python identifier, this input can be set by
kwarg, and its value
can be accessed by self.<name>.
update: Variable instance (default: None)
value (see previous) will be replaced with this expression variable after each function call.
If update is None, the update will be the default value of the input.
value (see previous) will be replaced with this expression
variable after each function call. If update is None, the
update will be the default value of the input.
mutable: Bool (default: False if update is None, True if update is
not None)
True: permit the compiled function to modify the python object
being passed as the input
mutable: Bool (default: False if update is None, True if update is not None)
True: permit the compiled function to modify the python object being passed as the input
False: do not permit the compiled function to modify the python object being passed as the input.
False: do not permit the compiled function to modify the
python object being passed as the input.
strict: Bool (default: False)
True: means that the value you pass for this input must have exactly the right type
False: the value you pass for this input may be cast automatically to the proper type
True: means that the value you pass for this input must have
exactly the right type
False: the value you pass for this input may be cast
automatically to the proper type
allow_downcast: Bool or None (default: None)
Only applies when `strict` is False.
True: the value you pass for this input can be silently
downcasted to fit the right type, which may lose precision.
False: the value will only be cast to a more general, or precise, type.
None: Almost like False, but allows downcast of Python floats to floatX.
False: the value will only be cast to a more general, or
precise, type. None: Almost like False, but allows downcast
of Python floats to floatX.
autoname: Bool (default: True)
See the name option.
implicit: Bool (default: False)
See help(In). Note that 'None' is not allowed here, since we
are in the
symbolic case.
See help(In). Note that 'None' is not allowed here, since we
are in the
symbolic case.
"""
def
__init__
(
self
,
variable
,
name
=
None
,
update
=
None
,
mutable
=
None
,
...
...
@@ -146,36 +160,54 @@ class In(SymbolicInput):
not computed from its owner.
name: Any type. (If autoname=True, defaults to variable.name).
If name is a valid Python identifier, this input can be set by
kwarg, and its value
can be accessed by self.<name>.
If name is a valid Python identifier, this input can be set by
kwarg, and its value
can be accessed by self.<name>.
value: Any type.
The initial/default value for this input. If update is None, this input acts just like
an argument with a default value in Python. If update is not None, changes to this
value will "stick around", whether due to an update or a user's explicit action.
The initial/default value for this input. If update is None,
this input acts just like an argument with a default value in
Python. If update is not None, changes to this value will
"stick around", whether due to an update or a user's explicit
action.
update: Variable instance (default: None)
value (see previous) will be replaced with this expression variable after each function call.
If update is None, the update will be the default value of the input.
value (see previous) will be replaced with this expression
variable after each function call. If update is None, the
update will be the default value of the input.
mutable: Bool (default: False if update is None, True if update is not None)
True: permit the compiled function to modify the python object being passed as the input
False: do not permit the compiled function to modify the python object being passed as the input.
mutable: Bool (default: False if update is None, True if update is
not None)
True: permit the compiled function to modify the python object
being passed as the input
False: do not permit the compiled function to modify the
python object being passed as the input.
borrow: Bool (default: take the same value as mutable)
True: permit the output of the compiled function to be aliased to the input
True: permit the output of the compiled function to be aliased
to the input
False: do not permit any output to be aliased to the input
strict: Bool (default: False)
True: means that the value you pass for this input must have exactly the right type
False: the value you pass for this input may be cast automatically to the proper type
True: means that the value you pass for this input must have
exactly the right type
False: the value you pass for this input may be cast
automatically to the proper type
allow_downcast: Bool or None (default: None)
Only applies when `strict` is False.
True: the value you pass for this input can be silently
downcasted to fit the right type, which may lose precision.
False: the value will only be cast to a more general, or precise, type.
None: Almost like False, but allows downcast of Python floats to floatX.
False: the value will only be cast to a more general, or
precise, type. None: Almost like False, but allows downcast
of Python floats to floatX.
autoname: Bool (default: True)
See the name option.
...
...
@@ -194,11 +226,11 @@ class In(SymbolicInput):
# Note: the documentation above is duplicated in doc/topics/function.txt,
# try to keep it synchronized.
def
__init__
(
self
,
variable
,
name
=
None
,
value
=
None
,
update
=
None
,
mutable
=
None
,
strict
=
False
,
allow_downcast
=
None
,
autoname
=
Tru
e
,
implicit
=
None
,
borrow
=
None
,
shared
=
False
):
#
if shared, an input's value comes from its persistent storage, not from a default stored
#
in the function or from
the caller
mutable
=
None
,
strict
=
False
,
allow_downcast
=
Non
e
,
autoname
=
True
,
implicit
=
None
,
borrow
=
None
,
shared
=
False
):
# if shared, an input's value comes from its persistent
#
storage, not from a default stored in the function or from
# the caller
self
.
shared
=
shared
if
borrow
is
None
:
...
...
theano/compile/mode.py
浏览文件 @
4cf7afb4
...
...
@@ -2,8 +2,6 @@
"""
from
__future__
import
print_function
import
logging
import
warnings
from
textwrap
import
dedent
import
numpy
...
...
@@ -11,24 +9,24 @@ import theano
from
theano
import
gof
import
theano.gof.vm
from
theano.configparser
import
config
,
AddConfigVar
,
StrParam
from
theano.compile.ops
import
register_view_op_c_code
,
_output_guard
from
theano.compile.ops
import
_output_guard
_logger
=
logging
.
getLogger
(
'theano.compile.mode'
)
AddConfigVar
(
'optimizer_excluding'
,
(
"When using the default mode, we will remove optimizer with these
"
"
tags. Separate tags with ':'."
),
(
"When using the default mode, we will remove optimizer with
"
"these
tags. Separate tags with ':'."
),
StrParam
(
""
,
allow_override
=
False
),
in_c_key
=
False
)
AddConfigVar
(
'optimizer_including'
,
(
"When using the default mode, we will add optimizer with these tags.
"
"
Separate tags with ':'."
),
(
"When using the default mode, we will add optimizer with
"
"these tags.
Separate tags with ':'."
),
StrParam
(
""
,
allow_override
=
False
),
in_c_key
=
False
)
AddConfigVar
(
'optimizer_requiring'
,
(
"When using the default mode, we will require optimizer with these
"
"
tags. Separate tags with ':'."
),
(
"When using the default mode, we will require optimizer with
"
"these
tags. Separate tags with ':'."
),
StrParam
(
""
,
allow_override
=
False
),
in_c_key
=
False
)
...
...
@@ -50,9 +48,9 @@ def check_equal(x, y):
y
=
y
.
todense
()
if
isinstance
(
x
,
numpy
.
ndarray
)
and
isinstance
(
y
,
numpy
.
ndarray
):
if
(
x
.
dtype
!=
y
.
dtype
or
x
.
shape
!=
y
.
shape
or
numpy
.
any
(
abs
(
x
-
y
)
>
1e-10
)):
if
(
x
.
dtype
!=
y
.
dtype
or
x
.
shape
!=
y
.
shape
or
numpy
.
any
(
abs
(
x
-
y
)
>
1e-10
)):
raise
Exception
(
"Output mismatch."
,
{
'performlinker'
:
x
,
'clinker'
:
y
})
else
:
...
...
@@ -287,7 +285,8 @@ class Mode(object):
def
__str__
(
self
):
return
"
%
s(linker =
%
s, optimizer =
%
s)"
%
(
self
.
__class__
.
__name__
,
self
.
provided_linker
,
self
.
provided_optimizer
)
self
.
provided_linker
,
self
.
provided_optimizer
)
def
__get_optimizer
(
self
):
if
isinstance
(
self
.
_optimizer
,
gof
.
Query
):
...
...
@@ -364,10 +363,11 @@ def get_mode(orig_string):
# DebugMode use its own linker.
ret
=
DebugMode
(
optimizer
=
config
.
optimizer
)
else
:
# The import is needed in case string is ProfileMode
from
profilemode
import
ProfileMode
,
prof_mode_instance_to_print
ret
=
eval
(
string
+
'(linker=config.linker, optimizer=config.optimizer)'
)
# This might be required if the string is 'ProfileMode'
from
profilemode
import
ProfileMode
# noqa
from
profilemode
import
prof_mode_instance_to_print
ret
=
eval
(
string
+
'(linker=config.linker, optimizer=config.optimizer)'
)
elif
string
in
predefined_modes
:
ret
=
predefined_modes
[
string
]
else
:
...
...
theano/compile/monitormode.py
浏览文件 @
4cf7afb4
from
__future__
import
print_function
# Note: this code was initially copied from the 'pyutools' package by its
# original author, and re-licensed under Theano's license.
import
numpy
import
theano
from
theano.compile.mode
import
Mode
...
...
theano/compile/ops.py
浏览文件 @
4cf7afb4
...
...
@@ -71,11 +71,12 @@ class ViewOp(gof.Op):
version
=
[]
# If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs.
for
t
,
(
c
,
v
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
key
=
lambda
pair
:
str
(
pair
[
0
])):
for
t
,
(
c
,
v
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
key
=
lambda
pair
:
str
(
pair
[
0
])):
if
not
v
:
warnings
.
warn
(
"Type
%
s has C code for ViewOp, but it has "
"no version. You should add a 'version' keyword arg
"
"
when calling register_view_op_c_code."
%
t
,
warnings
.
warn
(
"Type
%
s has C code for ViewOp, but it has
no
"
"version. You should add a 'version' keyword
"
"arg
when calling register_view_op_c_code."
%
t
,
stacklevel
=
2
)
return
()
version
.
append
((
str
(
t
),
v
))
...
...
@@ -165,11 +166,13 @@ class DeepCopyOp(gof.Op):
version
=
[]
# If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs.
for
t
,
(
c
,
v
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
key
=
lambda
pair
:
str
(
pair
[
0
])):
for
t
,
(
c
,
v
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
key
=
lambda
pair
:
str
(
pair
[
0
])):
if
not
v
:
warnings
.
warn
(
"Type
%
s has C code for DeepCopyOp, but it has "
"no version. You should add a 'version' keyword arg "
"when calling register_deep_copy_op_c_code."
%
t
,
"no version. You should add a 'version' keyword"
" arg when calling "
"register_deep_copy_op_c_code."
%
t
,
stacklevel
=
2
)
return
()
version
.
append
((
str
(
t
),
v
))
...
...
@@ -284,11 +287,12 @@ class Shape(gof.Op):
version
=
[]
# If any of the c code is unversionned, we have to return ()
# Else, we will return a list of (type name, version) pairs.
for
t
,
(
c
,
v
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
key
=
lambda
pair
:
str
(
pair
[
0
])):
for
t
,
(
c
,
v
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
key
=
lambda
pair
:
str
(
pair
[
0
])):
if
not
v
:
warnings
.
warn
(
"Type
%
s has C code for Shape, but it has "
"no version. You should add a 'version' keyword arg
"
"
when calling register_shape_c_code."
%
t
,
warnings
.
warn
(
"Type
%
s has C code for Shape, but it has
no
"
"version. You should add a 'version' keyword
"
"arg
when calling register_shape_c_code."
%
t
,
stacklevel
=
2
)
return
()
version
.
append
((
str
(
t
),
v
))
...
...
@@ -301,7 +305,6 @@ class Shape(gof.Op):
shape
=
Shape
()
_shape
=
shape
# was used in the past, now use shape directly.
#pprint.assign(_shape, printing.MemberPrinter('shape'))
class
Shape_i
(
gof
.
Op
):
...
...
@@ -389,8 +392,11 @@ class Shape_i(gof.Op):
return
[()]
def
grad
(
self
,
inp
,
grads
):
return
[
theano
.
gradient
.
grad_not_implemented
(
op
=
self
,
x_pos
=
0
,
x
=
inp
[
0
],
comment
=
"No gradient for the shape of a matrix is implemented."
)]
return
[
theano
.
gradient
.
grad_not_implemented
(
op
=
self
,
x_pos
=
0
,
x
=
inp
[
0
],
comment
=
(
"No gradient for the shape of a matrix "
"is implemented."
))]
def
shape_i
(
var
,
i
,
fgraph
=
None
):
"""Equivalent of var.shape[i], but apply if possible the shape
...
...
@@ -435,9 +441,10 @@ def shape_i(var, i, fgraph=None):
def
register_shape_i_c_code
(
typ
,
code
,
check_input
,
version
=
()):
""" Tell Shape_i how to generate C code for a Theano Type
:param typ: A Theano type. It must be the Theano class itself and not an
instance of the class.
:param code: C code that gets the shape of dimensions
%(i)
s for the Theano type 'typ'.
:param typ: A Theano type. It must be the Theano class itself and not
an instance of the class.
:param code: C code that gets the shape of dimensions
%(i)
s for the
Theano type 'typ'.
Use
%(iname)
s and
%(oname)
s for the input and output C
variable names respectively.
:param version: A number indicating the version of the code, for cache.
...
...
@@ -620,7 +627,8 @@ class Rebroadcast(gof.Op):
return
type
(
self
)
==
type
(
other
)
and
self
.
axis
==
other
.
axis
def
__hash__
(
self
):
items
=
sorted
(
self
.
axis
.
iteritems
())
# no ambiguity because each item key is unique
# no ambiguity because each item key is unique
items
=
sorted
(
self
.
axis
.
iteritems
())
return
hash
((
type
(
self
),
tuple
(
items
)))
def
__str__
(
self
):
...
...
@@ -637,9 +645,9 @@ class Rebroadcast(gof.Op):
def
make_node
(
self
,
x
):
if
self
.
axis
.
keys
()
and
(
x
.
ndim
<=
numpy
.
max
(
self
.
axis
.
keys
())):
raise
ValueError
(
'Trying to rebroadcast non-existent dimension'
)
t
=
x
.
type
.
clone
(
broadcastable
=
[
self
.
axis
.
get
(
i
,
b
)
for
i
,
b
in
enumerate
(
x
.
type
.
broadcastable
)])
t
=
x
.
type
.
clone
(
broadcastable
=
[
self
.
axis
.
get
(
i
,
b
)
for
i
,
b
in
enumerate
(
x
.
type
.
broadcastable
)])
return
gof
.
Apply
(
self
,
[
x
],
[
t
()])
def
perform
(
self
,
node
,
inp
,
out_
):
...
...
@@ -702,9 +710,10 @@ class Rebroadcast(gof.Op):
for
t
,
(
c
,
v
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
key
=
lambda
pair
:
str
(
pair
[
0
])):
if
not
v
:
warnings
.
warn
(
"Type
%
s has C code for Rebroadcast, but it has "
"no version. You should add a 'version' keyword arg "
"when calling register_rebroadcast_c_code."
%
t
,
warnings
.
warn
(
"Type
%
s has C code for Rebroadcast, but it "
"has no version. You should add a 'version' "
"keyword arg when calling "
"register_rebroadcast_c_code."
%
t
,
stacklevel
=
2
)
return
()
version
.
append
((
str
(
t
),
v
))
...
...
@@ -718,17 +727,18 @@ def register_specify_shape_c_code(typ, code, version=(),
c_support_code_apply
=
None
):
""" Tell SpecifyShape how to generate C code for a Theano Type
:param typ: A Theano type. It must be the Theano class itself and
not an
instance of the class.
:param code: C code that checks the shape and returns a view for
the Theano type 'typ'.
Use
%(iname)
s and
%(oname)
s for the input and output C
variable names respectively.
%(shape)
s is the vector of shape of
%(iname)
s.
Check that its length is good.
:param typ: A Theano type. It must be the Theano class itself and
not an
instance of the class.
:param code: C code that checks the shape and returns a view for
the Theano type 'typ'. Use
%(iname)
s and
%(oname)
s
for the input and output C variable names
respectively.
%(shape)
s is the vector of shape of
%(iname)
s.
Check that its length is good.
:param version: A number indicating the version of the code, for cache.
:param c_support_code_apply: extra code.
"""
SpecifyShape
.
c_code_and_version
[
typ
]
=
(
code
,
version
,
c_support_code_apply
)
SpecifyShape
.
c_code_and_version
[
typ
]
=
(
code
,
version
,
c_support_code_apply
)
class
SpecifyShape
(
gof
.
Op
):
...
...
@@ -784,7 +794,8 @@ class SpecifyShape(gof.Op):
new_shape
=
[]
for
dim
in
xrange
(
node
.
inputs
[
0
]
.
ndim
):
try
:
s
=
theano
.
tensor
.
get_scalar_constant_value
(
node
.
inputs
[
1
][
dim
])
s
=
theano
.
tensor
.
get_scalar_constant_value
(
node
.
inputs
[
1
][
dim
])
s
=
theano
.
tensor
.
as_tensor_variable
(
s
)
new_shape
.
append
(
s
)
except
theano
.
tensor
.
NotScalarConstantError
:
...
...
@@ -832,7 +843,8 @@ class SpecifyShape(gof.Op):
code
,
version
,
_
=
self
.
c_code_and_version
[
itype
]
return
code
%
locals
()
return
super
(
SpecifyShape
,
self
)
.
c_code
(
node
,
node
,
inames
,
onames
,
sub
)
return
super
(
SpecifyShape
,
self
)
.
c_code
(
node
,
node
,
inames
,
onames
,
sub
)
def
c_code_cache_version
(
self
):
version
=
[]
...
...
@@ -841,9 +853,10 @@ class SpecifyShape(gof.Op):
for
t
,
(
c
,
v
,
_
)
in
sorted
(
self
.
c_code_and_version
.
items
(),
key
=
lambda
pair
:
str
(
pair
[
0
])):
if
not
v
:
warnings
.
warn
(
"Type
%
s has C code for SpecifyShape, but it has "
"no version. You should add a 'version' keyword arg "
"when calling register_specify_shape_c_code."
%
t
,
warnings
.
warn
(
"Type
%
s has C code for SpecifyShape, but it "
"has no version. You should add a 'version' "
"keyword arg when calling "
"register_specify_shape_c_code."
%
t
,
stacklevel
=
2
)
return
()
version
.
append
((
str
(
t
),
v
))
...
...
theano/compile/pfunc.py
浏览文件 @
4cf7afb4
差异被折叠。
点击展开。
theano/compile/profilemode.py
浏览文件 @
4cf7afb4
差异被折叠。
点击展开。
theano/compile/profiling.py
浏览文件 @
4cf7afb4
差异被折叠。
点击展开。
theano/compile/sharedvalue.py
浏览文件 @
4cf7afb4
"""Provide a simple user friendly API to Theano-managed memory"""
__docformat__
=
'restructuredtext en'
# Standard imports
import
copy
import
logging
...
...
@@ -12,6 +10,7 @@ import numpy
from
theano.gof
import
Container
,
Variable
,
generic
,
utils
_logger
=
logging
.
getLogger
(
'theano.compile.sharedvalue'
)
__docformat__
=
'restructuredtext en'
class
SharedVariable
(
Variable
):
...
...
@@ -49,7 +48,8 @@ class SharedVariable(Variable):
or copied, so they must have the correct type.
:param allow_downcast: Only applies if `strict` is False.
True -> allow assigned value to lose precision when cast during assignment.
True -> allow assigned value to lose precision when cast
during assignment.
False -> never allow precision loss.
None -> only allow downcasting of a Python float to a scalar floatX.
...
...
@@ -65,12 +65,13 @@ class SharedVariable(Variable):
if
container
is
not
None
:
self
.
container
=
container
if
(
value
is
not
None
)
or
(
strict
is
not
None
):
raise
TypeError
(
'value and strict are ignored if you pass
a container here'
)
raise
TypeError
(
'value and strict are ignored if you pass '
'
a container here'
)
else
:
if
container
is
not
None
:
raise
TypeError
(
'Error to specify both value and container'
)
self
.
container
=
Container
(
self
,
self
.
container
=
Container
(
self
,
storage
=
[
type
.
filter
(
value
,
strict
=
strict
,
allow_downcast
=
allow_downcast
)],
readonly
=
False
,
...
...
@@ -183,7 +184,8 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
potential constructors to those that can accept those kwargs.
:note: Some shared variable have ``borrow`` as extra kwargs.
`See <http://deeplearning.net/software/theano/tutorial/aliasing.html#borrowing-when-creating-shared-variables>`_ for detail.
`See <http://deeplearning.net/software/theano/tutorial/aliasing.
\
html#borrowing-when-creating-shared-variables>`_ for detail.
:note: Some shared variable have ``broadcastable`` as extra kwargs.
As shared variable shapes can change, all dimensions default
...
...
@@ -200,7 +202,8 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
try
:
if
isinstance
(
value
,
Variable
):
raise
TypeError
(
" Shared variable constructor needs numeric values and not symbolic variables."
)
raise
TypeError
(
"Shared variable constructor needs numeric "
"values and not symbolic variables."
)
for
ctor
in
reversed
(
shared
.
constructors
):
try
:
...
...
theano/gof/callcache.py
浏览文件 @
4cf7afb4
import
cPickle
,
logging
import
cPickle
import
logging
_logger
=
logging
.
getLogger
(
"theano.gof.callcache"
)
...
...
@@ -18,9 +19,6 @@ class CallCache(object):
def
persist
(
self
,
filename
=
None
):
if
filename
is
None
:
filename
=
self
.
filename
# backport
#filename = self.filename if filename is None else filename
f
=
open
(
filename
,
'w'
)
cPickle
.
dump
(
self
.
cache
,
f
)
f
.
close
()
...
...
@@ -28,9 +26,6 @@ class CallCache(object):
def
call
(
self
,
fn
,
args
=
(),
key
=
None
):
if
key
is
None
:
key
=
(
fn
,
tuple
(
args
))
# backport
#key = (fn, tuple(args)) if key is None else key
if
key
not
in
self
.
cache
:
_logger
.
debug
(
'cache miss
%
i'
,
len
(
self
.
cache
))
self
.
cache
[
key
]
=
fn
(
*
args
)
...
...
theano/gof/compiledir.py
浏览文件 @
4cf7afb4
...
...
@@ -8,7 +8,6 @@ import re
import
shutil
import
struct
import
socket
import
subprocess
import
sys
import
textwrap
...
...
@@ -295,7 +294,8 @@ def cleanup():
have_npy_abi_version
=
True
elif
obj
.
startswith
(
'c_compiler_str='
):
have_c_compiler
=
True
elif
(
isinstance
(
obj
,
(
theano
.
gof
.
Op
,
theano
.
gof
.
Type
))
and
elif
(
isinstance
(
obj
,
(
theano
.
gof
.
Op
,
theano
.
gof
.
Type
))
and
hasattr
(
obj
,
'c_code_cache_version'
)):
v
=
obj
.
c_code_cache_version
()
if
v
not
in
[(),
None
]
and
v
not
in
key
[
0
]:
...
...
@@ -310,7 +310,7 @@ def cleanup():
if
keydata
.
key_pkl
!=
filename
:
keydata
.
key_pkl
=
filename
keydata
.
remove_key
(
key
)
except
IOError
as
e
:
except
IOError
:
_logger
.
error
(
"Could not remove file '
%
s'. To complete "
"the clean-up, please remove manually "
...
...
@@ -395,7 +395,7 @@ def print_compiledir_content():
if
big_key_files
:
big_key_files
=
sorted
(
big_key_files
,
key
=
lambda
t
:
str
(
t
[
1
]))
big_total_size
=
sum
([
s
ize
for
dir
,
size
,
ops
in
big_key_files
])
big_total_size
=
sum
([
s
z
for
_
,
sz
,
_
in
big_key_files
])
print
((
"There are directories with key files bigger than
%
d bytes "
"(they probably contain big tensor constants)"
%
max_key_file_size
))
...
...
theano/gof/compilelock.py
浏览文件 @
4cf7afb4
...
...
@@ -102,8 +102,8 @@ def get_lock(lock_dir=None, **kw):
# the lock state and raise an error.
while
get_lock
.
n_lock
>
0
:
release_lock
()
raise
Exception
(
"For some unknow reason, the lock was already
taken,
"
" but no start time was registered."
)
raise
Exception
(
"For some unknow reason, the lock was already "
"
taken,
but no start time was registered."
)
now
=
time
.
time
()
if
now
-
get_lock
.
start_time
>
config
.
compile
.
timeout
/
2
:
lockpath
=
os
.
path
.
join
(
get_lock
.
lock_dir
,
'lock'
)
...
...
theano/gof/cutils.py
浏览文件 @
4cf7afb4
...
...
@@ -14,18 +14,21 @@ if os.path.exists(os.path.join(config.compiledir, 'cutils_ext.so')):
def
compile_cutils_code
():
types
=
[
'npy_'
+
t
for
t
in
[
'int8'
,
'int16'
,
'int32'
,
'int64'
,
'int128'
,
'int256'
,
'uint8'
,
'uint16'
,
'uint32'
,
'uint64'
,
'uint128'
,
'uint256'
,
'float16'
,
'float32'
,
'float64'
,
'float80'
,
'float96'
,
'float128'
,
'int256'
,
'uint8'
,
'uint16'
,
'uint32'
,
'uint64'
,
'uint128'
,
'uint256'
,
'float16'
,
'float32'
,
'float64'
,
'float80'
,
'float96'
,
'float128'
,
'float256'
]]
complex_types
=
[
'npy_'
+
t
for
t
in
[
'complex32'
,
'complex64'
,
'complex128'
,
'complex160'
,
'complex192'
,
'complex512'
]]
'complex128'
,
'complex160'
,
'complex192'
,
'complex512'
]]
inplace_map_template
=
"""
#if defined(
%(typen)
s)
static void
%(type)
s_inplace_add(PyArrayMapIterObject *mit, PyArrayIterObject *it, int inc_or_set)
static void
%(type)
s_inplace_add(PyArrayMapIterObject *mit,
PyArrayIterObject *it, int inc_or_set)
{
int index = mit->size;
while (index--) {
...
...
@@ -38,10 +41,13 @@ def compile_cutils_code():
#endif
"""
floatadd
=
"((
%(type)
s*)mit->dataptr)[0] = inc_or_set * ((
%(type)
s*)mit->dataptr)[0] + ((
%(type)
s*)it->dataptr)[0];"
floatadd
=
(
"((
%(type)
s*)mit->dataptr)[0] = inc_or_set * "
"((
%(type)
s*)mit->dataptr)[0] + ((
%(type)
s*)it->dataptr)[0];"
)
complexadd
=
"""
((
%(type)
s*)mit->dataptr)[0].real = inc_or_set * ((
%(type)
s*)mit->dataptr)[0].real + ((
%(type)
s*)it->dataptr)[0].real;
((
%(type)
s*)mit->dataptr)[0].imag = inc_or_set * ((
%(type)
s*)mit->dataptr)[0].imag + ((
%(type)
s*)it->dataptr)[0].imag;
((
%(type)
s*)mit->dataptr)[0].real = inc_or_set *
((
%(type)
s*)mit->dataptr)[0].real + ((
%(type)
s*)it->dataptr)[0].real;
((
%(type)
s*)mit->dataptr)[0].imag = inc_or_set *
((
%(type)
s*)mit->dataptr)[0].imag + ((
%(type)
s*)it->dataptr)[0].imag;
"""
fns
=
''
.
join
([
inplace_map_template
%
{
'type'
:
t
,
'typen'
:
t
.
upper
(),
...
...
@@ -51,33 +57,36 @@ def compile_cutils_code():
'op'
:
complexadd
%
{
'type'
:
t
}}
for
t
in
complex_types
])
def
gen_binop
(
type
,
typen
):
return
"""
#if defined(
%(typen)
s)
%(type)
s_inplace_add,
#endif
"""
%
dict
(
type
=
type
,
typen
=
typen
)
fn_array
=
(
"static inplace_map_binop addition_funcs[] = {"
+
''
.
join
([
"""
#if defined(
%(typen)
s)
%(type)
s_inplace_add,
#endif
"""
%
{
'type'
:
t
,
'typen'
:
t
.
upper
()}
for
t
in
types
+
complex_types
])
+
"""NULL};
"""
)
''
.
join
([
gen_binop
(
type
=
t
,
typen
=
t
.
upper
())
for
t
in
types
+
complex_types
])
+
"NULL};
\n
"
)
def
gen_num
(
typen
):
return
"""
#if defined(
%(typen)
s)
%(typen)
s,
#endif
"""
%
dict
(
type
=
type
,
typen
=
typen
)
type_number_array
=
(
"static int type_numbers[] = {"
+
''
.
join
([
"""
#if defined(
%(typen)
s)
%(typen)
s,
#endif
"""
%
{
'type'
:
t
,
'typen'
:
t
.
upper
()}
for
t
in
types
+
complex_types
])
+
"-1000};"
)
''
.
join
([
gen_num
(
typen
=
t
.
upper
())
for
t
in
types
+
complex_types
])
+
"-1000};"
)
code
=
(
"""
#if NPY_API_VERSION >= 0x00000008
typedef void (*inplace_map_binop)(PyArrayMapIterObject *, PyArrayIterObject *, int inc_or_set);
"""
+
fns
+
fn_array
+
type_number_array
+
"""
typedef void (*inplace_map_binop)(PyArrayMapIterObject *,
PyArrayIterObject *, int inc_or_set);
"""
+
fns
+
fn_array
+
type_number_array
+
"""
static int
map_increment(PyArrayMapIterObject *mit, PyObject *op, inplace_map_binop add_inplace, int inc_or_set)
map_increment(PyArrayMapIterObject *mit, PyObject *op,
inplace_map_binop add_inplace, int inc_or_set)
{
PyArrayObject *arr = NULL;
PyArrayIterObject *it;
...
...
@@ -129,7 +138,8 @@ inplace_increment(PyObject *dummy, PyObject *args)
return NULL;
}
if (!PyArray_Check(arg_a)) {
PyErr_SetString(PyExc_ValueError, "needs an ndarray as first argument");
PyErr_SetString(PyExc_ValueError,
"needs an ndarray as first argument");
return NULL;
}
...
...
@@ -285,7 +295,7 @@ try:
open
(
os
.
path
.
join
(
location
,
'__init__.py'
),
'w'
)
.
close
()
try
:
from
cutils_ext.cutils_ext
import
*
from
cutils_ext.cutils_ext
import
*
# noqa
except
ImportError
:
get_lock
()
# Ensure no-one else is currently modifying the content of the compilation
...
...
@@ -296,11 +306,11 @@ try:
# We must retry to import it as some other process could
# have been compiling it between the first failed import
# and when we receive the lock
from
cutils_ext.cutils_ext
import
*
from
cutils_ext.cutils_ext
import
*
# noqa
except
ImportError
:
compile_cutils
()
from
cutils_ext.cutils_ext
import
*
from
cutils_ext.cutils_ext
import
*
# noqa
finally
:
# Release lock on compilation directory.
...
...
theano/gof/lazylinker_c.py
浏览文件 @
4cf7afb4
...
...
@@ -15,12 +15,13 @@ _logger = logging.getLogger('theano.gof.lazylinker_c')
force_compile
=
False
version
=
0.21
# must match constant returned in function get_version()
lazylinker_ext
=
None
def
try_import
():
global
lazylinker_ext
sys
.
path
[
0
:
0
]
=
[
config
.
compiledir
]
import
lazylinker_ext
import
lazylinker_ext
# noqa
del
sys
.
path
[
0
]
...
...
@@ -43,11 +44,11 @@ try:
# Try to make the location
os
.
mkdir
(
location
)
except
OSError
as
e
:
# If we get an error, verify that the error was # 17, the
path already exists,
#
and that it is a directory
#
Note: we can't check if it exists before making it, because we are not holding
#
the lock right now, so we could race another process and get error 17 if we los
e
# the race
# If we get an error, verify that the error was # 17, the
#
path already exists, and that it is a directory Note: we
#
can't check if it exists before making it, because we
#
are not holding the lock right now, so we could rac
e
#
another process and get error 17 if we lose
the race
assert
e
.
errno
==
errno
.
EEXIST
assert
os
.
path
.
isdir
(
location
)
...
...
@@ -142,5 +143,5 @@ except ImportError:
# Release lock on compilation directory.
release_lock
()
from
lazylinker_ext.lazylinker_ext
import
*
from
lazylinker_ext.lazylinker_ext
import
*
# noqa
assert
force_compile
or
(
version
==
get_version
())
theano/gof/optdb.py
浏览文件 @
4cf7afb4
...
...
@@ -32,7 +32,7 @@ class DB(object):
self
.
__db__
=
DefaultOrderedDict
(
OrderedSet
)
self
.
_names
=
set
()
self
.
name
=
None
# will be reset by register
#(via obj.name by the thing doing the registering)
#
(via obj.name by the thing doing the registering)
def
register
(
self
,
name
,
obj
,
*
tags
,
**
kwargs
):
"""
...
...
@@ -175,8 +175,10 @@ class Query(object):
self
.
exclude
=
OrderedSet
(
self
.
exclude
)
def
__str__
(
self
):
return
"Query{inc=
%
s,ex=
%
s,require=
%
s,subquery=
%
s,position_cutoff=
%
d}"
%
(
self
.
include
,
self
.
exclude
,
self
.
require
,
self
.
subquery
,
self
.
position_cutoff
)
return
(
"Query{inc=
%
s,ex=
%
s,require=
%
s,subquery=
%
s,"
"position_cutoff=
%
d}"
%
(
self
.
include
,
self
.
exclude
,
self
.
require
,
self
.
subquery
,
self
.
position_cutoff
))
# add all opt with this tag
def
including
(
self
,
*
tags
):
...
...
@@ -268,7 +270,7 @@ class SequenceDB(DB):
position_cutoff
=
kwtags
.
pop
(
'position_cutoff'
,
config
.
optdb
.
position_cutoff
)
if
len
(
tags
)
>=
1
and
isinstance
(
tags
[
0
],
Query
):
# the call to super should have raise an error with a good message
# the call to super should have raise an error with a good message
assert
len
(
tags
)
==
1
if
getattr
(
tags
[
0
],
'position_cutoff'
,
None
):
position_cutoff
=
tags
[
0
]
.
position_cutoff
...
...
theano/gof/sched.py
浏览文件 @
4cf7afb4
...
...
@@ -39,8 +39,8 @@ def make_depends():
def
depends
(
pair
):
""" Returns True if a depends on b """
a
,
b
=
pair
return
(
any
(
bout
in
a
.
inputs
for
bout
in
b
.
outputs
)
or
any
(
depends
((
ainp
.
owner
,
b
))
for
ainp
in
a
.
inputs
return
(
any
(
bout
in
a
.
inputs
for
bout
in
b
.
outputs
)
or
any
(
depends
((
ainp
.
owner
,
b
))
for
ainp
in
a
.
inputs
if
ainp
.
owner
))
return
depends
...
...
@@ -160,12 +160,12 @@ def posort(l, *cmps):
for
b
in
l
:
assert
not
(
b
in
comes_after
[
a
]
and
a
in
comes_after
[
b
])
for
cmp
in
cmps
:
for
cmp
_fn
in
cmps
:
for
a
in
l
:
for
b
in
l
:
if
cmp
(
a
,
b
)
<
0
:
# a wants to come before b
if
cmp
_fn
(
a
,
b
)
<
0
:
# a wants to come before b
# if this wouldn't cause a cycle and isn't already known
if
not
b
in
comes_before
[
a
]
and
not
b
in
comes_after
[
a
]:
if
b
not
in
comes_before
[
a
]
and
b
not
in
comes_after
[
a
]:
add_links
(
a
,
b
)
# check() # debug code
...
...
theano/gof/tests/test_utils.py
浏览文件 @
4cf7afb4
...
...
@@ -36,8 +36,11 @@ def test_give_variables_names_small():
def
test_remove
():
even
=
lambda
x
:
x
%
2
==
0
odd
=
lambda
x
:
x
%
2
==
1
# The list are neede as with python 3, remove and filter return generators
def
even
(
x
):
return
x
%
2
==
0
def
odd
(
x
):
return
x
%
2
==
1
# The list are needed as with python 3, remove and filter return generators
# and we can't compare generators.
assert
list
(
remove
(
even
,
range
(
5
)))
==
list
(
filter
(
odd
,
range
(
5
)))
theano/gof/toolbox.py
浏览文件 @
4cf7afb4
...
...
@@ -214,9 +214,9 @@ class Validator(Feature):
class
ReplaceValidate
(
History
,
Validator
):
pickle_rm_attr
=
[
"replace_validate"
,
"replace_all_validate"
,
"replace_all_validate_remove"
]
+
\
History
.
pickle_rm_attr
+
Validator
.
pickle_rm_attr
pickle_rm_attr
=
(
[
"replace_validate"
,
"replace_all_validate"
,
"replace_all_validate_remove"
]
+
History
.
pickle_rm_attr
+
Validator
.
pickle_rm_attr
)
def
on_attach
(
self
,
fgraph
):
for
attr
in
(
'replace_validate'
,
'replace_all_validate'
,
...
...
@@ -256,11 +256,13 @@ class ReplaceValidate(History, Validator):
try
:
fgraph
.
replace
(
r
,
new_r
,
reason
=
reason
,
verbose
=
False
)
except
Exception
as
e
:
if
(
'The type of the replacement must be the same'
not
in
str
(
e
)
and
'does not belong to this FunctionGraph'
not
in
str
(
e
)):
msg
=
str
(
e
)
s1
=
'The type of the replacement must be the same'
s2
=
'does not belong to this FunctionGraph'
if
(
s1
not
in
msg
and
s2
not
in
msg
):
out
=
sys
.
stderr
print
(
"<<!! BUG IN FGRAPH.REPLACE OR A LISTENER !!>>"
,
end
=
' '
,
file
=
out
)
print
(
type
(
e
),
e
,
reason
,
file
=
out
)
print
(
"<<!! BUG IN FGRAPH.REPLACE OR A LISTENER !!>>"
,
type
(
e
),
e
,
reason
,
file
=
out
)
# this might fail if the error is in a listener:
# (fgraph.replace kinda needs better internal error handling)
fgraph
.
revert
(
chk
)
...
...
@@ -286,13 +288,14 @@ class ReplaceValidate(History, Validator):
fgraph
.
revert
(
chk
)
if
warn
:
out
=
sys
.
stderr
print
(
(
print
(
"WARNING: An optimization wanted to replace a Variable"
" in the graph, but the replacement for it doesn't"
" remove it. We disabled the optimization."
" Your function runs correctly, but it would be"
" appreciated if you submit this problem to the"
" mailing list theano-users so that we can fix it."
),
file
=
out
)
" mailing list theano-users so that we can fix it."
,
file
=
out
)
print
(
reason
,
replacements
,
file
=
out
)
raise
ReplacementDidntRemovedError
()
...
...
@@ -311,7 +314,8 @@ class NodeFinder(Bookkeeper):
def
on_attach
(
self
,
fgraph
):
if
self
.
fgraph
is
not
None
:
raise
Exception
(
"A NodeFinder instance can only serve one FunctionGraph."
)
raise
Exception
(
"A NodeFinder instance can only serve one "
"FunctionGraph."
)
if
hasattr
(
fgraph
,
'get_nodes'
):
raise
AlreadyThere
(
"NodeFinder is already present or in conflict"
" with another plugin."
)
...
...
theano/gof/type.py
浏览文件 @
4cf7afb4
"""WRITEME Defines the `Type` class."""
__docformat__
=
"restructuredtext en"
from
theano.compat
import
PY3
from
theano.gof
import
utils
...
...
@@ -13,6 +10,8 @@ from theano.gof import graph
########
from
theano.gof.op
import
CLinkerObject
__docformat__
=
"restructuredtext en"
class
CLinkerType
(
CLinkerObject
):
"""Interface specification for Types that can be arguments to a `CLinkerOp`.
...
...
@@ -45,7 +44,8 @@ class CLinkerType(CLinkerObject):
- `MethodNotDefined`: Subclass does not implement this method
"""
raise
MethodNotDefined
(
"c_literal"
,
type
(
self
),
self
.
__class__
.
__name__
)
raise
MethodNotDefined
(
"c_literal"
,
type
(
self
),
self
.
__class__
.
__name__
)
def
c_declare
(
self
,
name
,
sub
,
check_input
=
True
):
"""Required: Return c code to declare variables that will be
...
...
@@ -56,7 +56,8 @@ class CLinkerType(CLinkerObject):
return "PyObject ** addr_of_
%(name)
s;"
:param name: the name of the ``PyObject *`` pointer that will the value for this Type
:param name: the name of the ``PyObject *`` pointer that will
the value for this Type
:type name: string
...
...
@@ -138,7 +139,8 @@ class CLinkerType(CLinkerObject):
- `MethodNotDefined`: Subclass does not implement this method
"""
raise
MethodNotDefined
(
"c_extract"
,
type
(
self
),
self
.
__class__
.
__name__
)
raise
MethodNotDefined
(
"c_extract"
,
type
(
self
),
self
.
__class__
.
__name__
)
def
c_extract_out
(
self
,
name
,
sub
,
check_input
=
True
):
"""Optional: C code to extract a PyObject * instance.
...
...
@@ -184,11 +186,12 @@ class CLinkerType(CLinkerObject):
def
c_sync
(
self
,
name
,
sub
):
"""Required: Return c code to pack C types back into a PyObject.
The code returned from this function must be templated using "
%(name)
s",
representing the name that the caller wants to call this Variable. The
returned code may set "py_
%(name)
s" to a PyObject* and that PyObject*
will be accessible from Python via variable.data. Do not forget to adjust
reference counts if "py_
%(name)
s" is changed from its original value.
The code returned from this function must be templated using
"
%(name)
s", representing the name that the caller wants to
call this Variable. The returned code may set "py_
%(name)
s"
to a PyObject* and that PyObject* will be accessible from
Python via variable.data. Do not forget to adjust reference
counts if "py_
%(name)
s" is changed from its original value.
:Parameters:
- `name`: WRITEME
...
...
@@ -205,10 +208,11 @@ class CLinkerType(CLinkerObject):
def
c_code_cache_version
(
self
):
"""Return a tuple of integers indicating the version of this Type.
An empty tuple indicates an 'unversioned' Type that will not be cached between processes.
An empty tuple indicates an 'unversioned' Type that will not
be cached between processes.
The cache mechanism may erase cached modules that have been
superceded by newer
versions. See `ModuleCache` for details.
The cache mechanism may erase cached modules that have been
superceded by newer
versions. See `ModuleCache` for details.
"""
return
()
...
...
@@ -221,19 +225,21 @@ class PureType(object):
- creating `Variable` instances (conventionally, `__call__` does this), and
- filtering a value assigned to a `Variable` so that the value
conforms to restrictions
imposed by the type (also known as casting, this is done by `filter`),
- filtering a value assigned to a `Variable` so that the value
conforms to restrictions imposed by the type (also known as
casting, this is done by `filter`),
"""
# the type that will be created by call to make_variable.
Variable
=
graph
.
Variable
Variable
=
graph
.
Variable
# the type that will be created by call to make_variable.
Constant
=
graph
.
Constant
# the type that will be created by call to make_constant
# the type that will be created by call to make_constant
Constant
=
graph
.
Constant
def
filter
(
self
,
data
,
strict
=
False
,
allow_downcast
=
None
):
"""Required: Return data or an appropriately wrapped/converted data.
Subclass implementation should raise a TypeError exception if
the data is not of an
acceptable type.
Subclass implementation should raise a TypeError exception if
the data is not of an
acceptable type.
If strict is True, the data returned must be the same as the
data passed as an argument. If it is False, and allow_downcast
...
...
@@ -283,7 +289,8 @@ class PureType(object):
return
other
def
is_valid_value
(
self
,
a
):
"""Required: Return True for any python object `a` that would be a legal value for a Variable of this Type"""
"""Required: Return True for any python object `a` that would be a
legal value for a Variable of this Type"""
try
:
self
.
filter
(
a
,
strict
=
True
)
return
True
...
...
@@ -291,7 +298,8 @@ class PureType(object):
return
False
def
value_validity_msg
(
self
,
a
):
"""Optional: return a message explaining the output of is_valid_value"""
"""Optional: return a message explaining the output of
is_valid_value"""
return
"none"
def
make_variable
(
self
,
name
=
None
):
...
...
@@ -371,7 +379,8 @@ class Type(object2, PureType, CLinkerType):
But you are encouraged to write your own, as described in WRITEME.
The following following code illustrates the use of a Type instance, here tensor.fvector:
The following following code illustrates the use of a Type
instance, here tensor.fvector:
.. code-block:: python
...
...
@@ -381,17 +390,21 @@ class Type(object2, PureType, CLinkerType):
# Create a second Variable with the same Type instance
c = tensor.fvector()
Whenever you create a symbolic variable in theano (technically, `Variable`) it will contain a
reference to a Type instance. That reference is typically constant during the lifetime of
the Variable. Many variables can refer to a single Type instance, as do b and c above. The
Type instance defines the kind of value which might end up in that variable when executing
a `Function`. In this sense, theano is like a strongly-typed language because the types
are included in the graph before the values. In our example above, b is a Variable which is
guaranteed to correspond to a numpy.ndarray of rank 1 when we try to do some computations
Whenever you create a symbolic variable in theano (technically,
`Variable`) it will contain a reference to a Type instance. That
reference is typically constant during the lifetime of the
Variable. Many variables can refer to a single Type instance, as
do b and c above. The Type instance defines the kind of value
which might end up in that variable when executing a `Function`.
In this sense, theano is like a strongly-typed language because
the types are included in the graph before the values. In our
example above, b is a Variable which is guaranteed to correspond
to a numpy.ndarray of rank 1 when we try to do some computations
with it.
Many `Op` instances will raise an exception if they are applied to inputs with incorrect
types. Type references are also useful to do type-checking in pattern-based optimizations.
Many `Op` instances will raise an exception if they are applied to
inputs with incorrect types. Type references are also useful to
do type-checking in pattern-based optimizations.
"""
def
convert_variable
(
self
,
var
):
...
...
@@ -451,8 +464,8 @@ class Generic(SingletonType):
"""
Represents a generic Python object.
This class implements the `PureType` and `CLinkerType` interfaces
for generic PyObject
instances.
This class implements the `PureType` and `CLinkerType` interfaces
for generic PyObject
instances.
EXAMPLE of what this means, or when you would use this type.
...
...
theano/gof/utils.py
浏览文件 @
4cf7afb4
from
__future__
import
print_function
import
linecache
import
traceback
import
re
import
sys
from
theano
import
config
...
...
@@ -15,7 +14,6 @@ def simple_extract_stack(f=None, limit=None):
This is because this update cause an call to os.stat to get the
line content. This cause too much long on cluster.
"""
if
f
is
None
:
try
:
...
...
@@ -48,7 +46,7 @@ if sys.version_info[:2] > (3, 4):
# I enable my implementation only for some python version just to
# be sure the Python internal do not change. If this work with
# other python version, you can enable it.
simple_extract_stack
=
traceback
.
extract_stack
simple_extract_stack
=
traceback
.
extract_stack
# noqa
def
add_tag_trace
(
thing
,
user_line
=
1
):
...
...
@@ -190,8 +188,8 @@ def deprecated(filename, msg=''):
def
g
(
*
args
,
**
kwargs
):
if
printme
[
0
]:
print
(
'WARNING:
%
s.
%
s deprecated.
%
s'
\
%
(
filename
,
f
.
__name__
,
msg
))
print
(
'WARNING:
%
s.
%
s deprecated.
%
s'
%
(
filename
,
f
.
__name__
,
msg
))
printme
[
0
]
=
False
return
f
(
*
args
,
**
kwargs
)
return
g
...
...
@@ -220,7 +218,7 @@ def difference(seq1, seq2):
raise
Exception
(
'not worth it'
)
set2
=
set
(
seq2
)
return
[
x
for
x
in
seq1
if
x
not
in
set2
]
except
Exception
as
e
:
except
Exception
:
# maybe a seq2 element is not hashable
# maybe seq2 is too short
# -> use O(len(seq1) * len(seq2)) algo
...
...
@@ -311,11 +309,11 @@ def comm_guard(type1, type2):
old_f
=
f
.
func_globals
[
f
.
__name__
]
def
new_f
(
arg1
,
arg2
,
*
rest
):
if
(
type1
is
ANY_TYPE
or
isinstance
(
arg1
,
type1
))
\
and
(
type2
is
ANY_TYPE
or
isinstance
(
arg2
,
type2
)):
if
(
(
type1
is
ANY_TYPE
or
isinstance
(
arg1
,
type1
))
and
(
type2
is
ANY_TYPE
or
isinstance
(
arg2
,
type2
)
)):
pass
elif
(
type1
is
ANY_TYPE
or
isinstance
(
arg2
,
type1
))
\
and
(
type2
is
ANY_TYPE
or
isinstance
(
arg1
,
type2
)):
elif
(
(
type1
is
ANY_TYPE
or
isinstance
(
arg2
,
type1
))
and
(
type2
is
ANY_TYPE
or
isinstance
(
arg1
,
type2
)
)):
arg1
,
arg2
=
arg2
,
arg1
else
:
return
old_f
(
arg1
,
arg2
,
*
rest
)
...
...
@@ -337,7 +335,8 @@ def comm_guard(type1, type2):
return
type
.
__name__
new_f
.
__doc__
=
(
str
(
old_f
.
__doc__
)
+
"
\n
"
+
", "
.
join
([
typename
(
type
)
for
type
in
(
type1
,
type2
)])
+
", "
.
join
([
typename
(
type
)
for
type
in
(
type1
,
type2
)])
+
"
\n
"
+
str
(
f
.
__doc__
or
""
))
return
new_f
...
...
@@ -406,15 +405,16 @@ def give_variables_names(variables):
This function is idempotent."""
names
=
map
(
lambda
var
:
var
.
name
,
variables
)
h
=
hist
(
names
)
bad_var
=
lambda
var
:
not
var
.
name
or
h
[
var
.
name
]
>
1
def
bad_var
(
var
):
return
not
var
.
name
or
h
[
var
.
name
]
>
1
for
i
,
var
in
enumerate
(
filter
(
bad_var
,
variables
)):
var
.
name
=
(
var
.
name
or
""
)
+
"_
%
d"
%
i
if
not
unique
(
map
(
str
,
variables
)):
raise
ValueError
(
"Not all variables have unique names."
"Maybe you've named some of the variables identically"
)
raise
ValueError
(
"Not all variables have unique names. Maybe you've "
"named some of the variables identically"
)
return
variables
...
...
theano/gof/vm.py
浏览文件 @
4cf7afb4
...
...
@@ -53,7 +53,8 @@ AddConfigVar('vm.lazy',
in_c_key
=
False
)
def
calculate_reallocate_info
(
order
,
fgraph
,
storage_map
,
compute_map_re
,
dependencies
):
def
calculate_reallocate_info
(
order
,
fgraph
,
storage_map
,
compute_map_re
,
dependencies
):
reallocated_info
=
{}
viewed_by
=
{}
for
var
in
fgraph
.
variables
:
...
...
@@ -74,14 +75,14 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
ins
=
None
if
dmap
and
idx_o
in
dmap
:
idx_v
=
dmap
[
idx_o
]
assert
len
(
idx_v
)
==
1
,
"Here we only support the possibility to destroy one input"
assert
len
(
idx_v
)
==
1
,
(
"Here we only support the possibility"
" to destroy one input"
)
ins
=
node
.
inputs
[
idx_v
[
0
]]
if
vmap
and
idx_o
in
vmap
:
assert
ins
is
None
idx_v
=
vmap
[
idx_o
]
assert
len
(
idx_v
)
==
1
,
"Here we only support the possibility to view one input"
assert
len
(
idx_v
)
==
1
,
(
"Here we only support the possibility"
" to view one input"
)
ins
=
node
.
inputs
[
idx_v
[
0
]]
if
ins
is
not
None
:
assert
isinstance
(
ins
,
theano
.
Variable
)
...
...
@@ -92,10 +93,11 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
for
ins
in
node
.
inputs
:
assert
not
(
ins
in
view_of
and
viewed_by
[
ins
])
if
(
getattr
(
ins
,
'ndim'
,
None
)
==
0
and
not
storage_map
[
ins
][
0
]
and
ins
not
in
fgraph
.
outputs
and
ins
.
owner
and
all
([
compute_map_re
[
v
][
0
]
for
v
in
dependencies
.
get
(
ins
,
[])])
and
ins
not
in
allocated
):
if
(
getattr
(
ins
,
'ndim'
,
None
)
==
0
and
not
storage_map
[
ins
][
0
]
and
ins
not
in
fgraph
.
outputs
and
ins
.
owner
and
all
([
compute_map_re
[
v
][
0
]
for
v
in
dependencies
.
get
(
ins
,
[])])
and
ins
not
in
allocated
):
# Constant Memory cannot be changed
# Constant and shared variables' storage_map value is not empty
reuse_out
=
None
...
...
@@ -105,8 +107,9 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
if
reuse_out
:
break
for
out
in
order
[
i
]
.
outputs
:
if
(
getattr
(
out
,
'ndim'
,
None
)
==
0
and
out
not
in
pre_allocated
and
ins
.
type
==
out
.
type
):
if
(
getattr
(
out
,
'ndim'
,
None
)
==
0
and
out
not
in
pre_allocated
and
ins
.
type
==
out
.
type
):
reuse_out
=
out
pre_allocated
.
add
(
out
)
allocated
.
add
(
ins
)
...
...
@@ -122,8 +125,9 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
if
reuse_out
:
break
for
out
in
order
[
i
]
.
outputs
:
if
(
getattr
(
out
,
'ndim'
,
None
)
==
0
and
out
not
in
pre_allocated
and
ins
.
type
==
out
.
type
):
if
(
getattr
(
out
,
'ndim'
,
None
)
==
0
and
out
not
in
pre_allocated
and
ins
.
type
==
out
.
type
):
reuse_out
=
out
pre_allocated
.
add
(
out
)
allocated
.
add
(
ins
)
...
...
@@ -508,7 +512,8 @@ class Stack(VM):
st
=
"c"
self
.
variable_strides
[
var
]
=
st
except
Exception
:
link
.
raise_with_op
(
current_apply
,
link
.
raise_with_op
(
current_apply
,
self
.
thunks
[
self
.
node_idx
[
current_apply
]],
storage_map
=
storage_map
)
for
o
in
current_apply
.
outputs
:
...
...
@@ -521,9 +526,9 @@ class Stack(VM):
for
i
in
current_apply
.
inputs
:
# Garbage Collection -> check if anybody else uses
# this input
if
(
dependencies
[
i
]
and
i
.
owner
and
i
not
in
self
.
outputs
):
if
(
dependencies
[
i
]
and
i
.
owner
and
i
not
in
self
.
outputs
):
if
all
(
compute_map
[
v
][
0
]
for
v
in
dependencies
[
i
]):
storage_map
[
i
][
0
]
=
None
...
...
@@ -544,10 +549,13 @@ class Stack(VM):
'destroy_map'
,
False
)):
warnings
.
warn
(
"There was a bug that existed in the default Theano configuration,"
" only in the development version between July 5th 2012"
" and July 30th 2012. This was not in a released version."
" The bug was affecting this script."
,
"There was a bug that existed in "
"the default Theano configuration,"
" only in the development version "
"between July 5th 2012 and "
"July 30th 2012. This was not in "
"a released version. The bug was "
"affecting this script."
,
# The stack level is not good when
# inside a Scan.
stacklevel
=
3
...
...
@@ -578,7 +586,8 @@ class Stack(VM):
self
.
call_times
[
current_idx
]
+=
dt
except
Exception
:
link
.
raise_with_op
(
current_apply
,
link
.
raise_with_op
(
current_apply
,
self
.
thunks
[
self
.
node_idx
[
current_apply
]],
storage_map
=
storage_map
)
...
...
@@ -639,7 +648,7 @@ class Stack(VM):
if
self
.
allow_gc
:
for
v
in
storage_map
:
if
v
.
owner
and
not
v
in
self
.
outputs
:
if
v
.
owner
and
v
not
in
self
.
outputs
:
if
compute_map
[
v
][
0
]
==
2
:
continue
else
:
...
...
@@ -840,7 +849,6 @@ class VM_Linker(link.LocalLinker):
vars_idx_inv
[
i
]
=
var
# put storage_map and compute_map into a int-based scheme
n_applies
=
len
(
nodes
)
storage_map_list
=
[
storage_map
[
vars_idx_inv
[
i
]]
for
i
in
xrange
(
len
(
vars_idx_inv
))]
compute_map_list
=
[
compute_map
[
vars_idx_inv
[
i
]]
...
...
@@ -988,7 +996,8 @@ class VM_Linker(link.LocalLinker):
else
:
dependencies
=
self
.
compute_gc_dependencies
(
storage_map
)
reallocated_info
=
calculate_reallocate_info
(
order
,
fgraph
,
storage_map
,
compute_map_re
,
dependencies
)
reallocated_info
=
calculate_reallocate_info
(
order
,
fgraph
,
storage_map
,
compute_map_re
,
dependencies
)
for
node
in
order
:
try
:
...
...
@@ -1014,7 +1023,8 @@ class VM_Linker(link.LocalLinker):
lazy
=
config
.
vm
.
lazy
if
lazy
is
None
:
lazy
=
not
all
([(
not
th
.
lazy
)
for
th
in
thunks
])
if
not
(
lazy
or
(
config
.
profile
and
config
.
profile_memory
)
or
self
.
use_cloop
or
self
.
callback
):
if
not
(
lazy
or
(
config
.
profile
and
config
.
profile_memory
)
or
self
.
use_cloop
or
self
.
callback
):
for
pair
in
reallocated_info
.
values
():
storage_map
[
pair
[
1
]]
=
storage_map
[
pair
[
0
]]
...
...
@@ -1024,10 +1034,10 @@ class VM_Linker(link.LocalLinker):
for
node
in
order
:
clear_after_this_thunk
=
[]
for
input
in
node
.
inputs
:
if
(
(
input
in
computed
)
and
(
input
not
in
fgraph
.
outputs
)
and
(
node
==
last_user
[
input
])
and
input
not
in
reallocated_info
.
keys
()):
if
(
input
in
computed
and
input
not
in
fgraph
.
outputs
and
node
==
last_user
[
input
]
and
input
not
in
reallocated_info
.
keys
()):
clear_after_this_thunk
.
append
(
storage_map
[
input
])
post_thunk_clear
.
append
(
clear_after_this_thunk
)
else
:
...
...
theano/sandbox/cuda/opt_util.py
浏览文件 @
4cf7afb4
...
...
@@ -2,7 +2,6 @@ from functools import wraps
import
numpy
import
theano
from
theano
import
scalar
as
scal
,
Constant
from
theano.gof
import
local_optimizer
from
theano.tensor
import
(
DimShuffle
,
get_scalar_constant_value
,
...
...
theano/sandbox/cuda/tests/test_fftconv.py
浏览文件 @
4cf7afb4
...
...
@@ -7,13 +7,13 @@ from theano.tests import unittest_tools as utt
# Skip tests if cuda_ndarray is not available.
from
nose.plugins.skip
import
SkipTest
import
theano.sandbox.cuda
as
cuda_ndarray
if
not
cuda_ndarray
.
cuda_available
:
if
not
cuda_ndarray
.
cuda_available
:
# noqa
raise
SkipTest
(
'Optional package cuda not available'
)
from
theano.misc.pycuda_init
import
pycuda_available
if
not
pycuda_available
:
if
not
pycuda_available
:
# noqa
raise
SkipTest
(
'Optional package pycuda not available'
)
from
theano.sandbox.cuda.fftconv
import
scikits_cuda_available
if
not
scikits_cuda_available
:
if
not
scikits_cuda_available
:
# noqa
raise
SkipTest
(
'Optional package scikits.cuda not available'
)
from
theano.sandbox.cuda
import
float32_shared_constructor
as
shared
...
...
theano/tensor/tests/_test_mpi_roundtrip.py
浏览文件 @
4cf7afb4
...
...
@@ -2,13 +2,14 @@
# mpiexec -np 2 python _test_mpi_roundtrip.py
from
mpi4py
import
MPI
comm
=
MPI
.
COMM_WORLD
import
theano
from
theano.tensor.io
import
send
,
recv
,
mpi_cmps
from
theano.gof.sched
import
sort_schedule_fn
import
numpy
as
np
from
sys
import
stdout
,
stderr
,
exit
comm
=
MPI
.
COMM_WORLD
rank
=
comm
.
Get_rank
()
size
=
comm
.
Get_size
()
...
...
theano/tests/disturb_mem.py
浏览文件 @
4cf7afb4
from
datetime
import
datetime
__authors__
=
"Ian Goodfellow"
__credits__
=
[
"Ian Goodfellow"
]
__license__
=
"3-clause BSD"
__maintainer__
=
"Ian Goodfellow"
__email__
=
"goodfeli@iro"
from
datetime
import
datetime
def
disturb_mem
():
# Allocate a time-dependent amount of objects to increase
...
...
theano/tests/main.py
浏览文件 @
4cf7afb4
from
__future__
import
print_function
import
os
,
unittest
,
sys
import
nose.plugins.builtin
import
os
import
unittest
import
sys
from
nose.config
import
Config
from
nose.plugins.manager
import
PluginManager
from
numpy.testing.nosetester
import
import_nose
,
NoseTester
import
nose.plugins.builtin
from
numpy.testing.nosetester
import
NoseTester
from
numpy.testing.noseclasses
import
KnownFailure
,
NumpyTestProgram
...
...
@@ -31,7 +34,7 @@ class TheanoNoseTester(NoseTester):
:type extra_argv: list
:param extra_argv: List with any extra arguments to pass to nosetests.
"""
#self.package_path = os.path.abspath(self.package_path)
#
self.package_path = os.path.abspath(self.package_path)
argv
=
[
__file__
,
self
.
package_path
]
argv
+=
[
'--verbosity'
,
str
(
verbose
)]
if
extra_argv
:
...
...
@@ -39,8 +42,6 @@ class TheanoNoseTester(NoseTester):
return
argv
def
_show_system_info
(
self
):
nose
=
import_nose
()
import
theano
print
(
"Theano version
%
s"
%
theano
.
__version__
)
theano_dir
=
os
.
path
.
dirname
(
theano
.
__file__
)
...
...
@@ -55,16 +56,14 @@ class TheanoNoseTester(NoseTester):
Takes the same arguments as `test`.
"""
# fail with nice error message if nose is not present
nose
=
import_nose
()
# compile argv
argv
=
self
.
_test_argv
(
verbose
,
extra_argv
)
# numpy way of doing coverage
if
coverage
:
argv
+=
[
'--cover-package=
%
s'
%
self
.
package_name
,
'--with-coverage'
,
'--cover-tests'
,
'--cover-inclusive'
,
'--cover-erase'
]
argv
+=
[
'--cover-package=
%
s'
%
self
.
package_name
,
'--with-coverage'
,
'--cover-tests'
,
'--cover-inclusive'
,
'--cover-erase'
]
# Capture output only if needed
if
not
capture
:
...
...
@@ -91,7 +90,8 @@ class TheanoNoseTester(NoseTester):
:param extra_argv: List with any extra arguments to pass to nosetests.
:type coverage: bool
:param coverage: If True, report coverage of Theano code. Default is False.
:param coverage: If True, report coverage of Theano
code. Default is False.
:type capture: bool
:param capture: If True, capture the standard output of the tests, like
...
...
@@ -134,8 +134,6 @@ class TheanoNoseTester(NoseTester):
def
main
(
modulename
):
debug
=
False
if
0
:
unittest
.
main
()
elif
len
(
sys
.
argv
)
==
2
and
sys
.
argv
[
1
]
==
"--debug"
:
...
...
theano/tests/test_flake8.py
浏览文件 @
4cf7afb4
...
...
@@ -20,7 +20,6 @@ __contact__ = "Saizheng Zhang <saizhenglisa..at..gmail.com>"
whitelist_flake8
=
[
"__init__.py"
,
"version.py"
,
"tests/test_gradient.py"
,
"tests/test_config.py"
,
"tests/diverse_tests.py"
,
...
...
@@ -31,37 +30,20 @@ whitelist_flake8 = [
"tests/test_record.py"
,
"tests/__init__.py"
,
"tests/test_updates.py"
,
"tests/main.py"
,
"tests/test_pickle_unpickle_theano_fn.py"
,
"tests/test_determinism.py"
,
"tests/record.py"
,
"tests/test_printing.py"
,
"tests/test_tutorial.py"
,
"tests/disturb_mem.py"
,
"tests/unittest_tools.py"
,
"compile/ops.py"
,
"compile/debugmode.py"
,
"compile/function.py"
,
"compile/pfunc.py"
,
"compile/mode.py"
,
"compile/profilemode.py"
,
"compile/builders.py"
,
"compile/__init__.py"
,
"compile/profiling.py"
,
"compile/function_module.py"
,
"compile/sharedvalue.py"
,
"compile/monitormode.py"
,
"compile/io.py"
,
"compile/module.py"
,
"compile/tests/test_builders.py"
,
"compile/tests/test_misc.py"
,
"compile/tests/test_monitormode.py"
,
"compile/tests/test_function_module.py"
,
"compile/tests/test_inplace_opt_for_value.py"
,
"compile/tests/test_shared.py"
,
"compile/tests/test_ops.py"
,
"compile/tests/test_pfunc.py"
,
"compile/tests/test_module.py"
,
"compile/tests/test_debugmode.py"
,
"compile/tests/test_profiling.py"
,
"typed_list/type.py"
,
...
...
@@ -94,16 +76,13 @@ whitelist_flake8 = [
"tensor/io.py"
,
"tensor/elemwise_cgen.py"
,
"tensor/raw_random.py"
,
"tensor/randomstreams.py"
,
"tensor/blas_scipy.py"
,
"tensor/basic.py"
,
"tensor/tests/test_subtensor.py"
,
"tensor/tests/test_utils.py"
,
"tensor/tests/test_nlinalg.py"
,
"tensor/tests/test_randomstreams.py"
,
"tensor/tests/test_shared_randomstreams.py"
,
"tensor/tests/test_misc.py"
,
"tensor/tests/test_naacl09.py"
,
"tensor/tests/mlp_test.py"
,
"tensor/tests/test_opt_uncanonicalize.py"
,
"tensor/tests/test_opt.py"
,
...
...
@@ -155,7 +134,6 @@ whitelist_flake8 = [
"sandbox/test_theano_object.py"
,
"sandbox/test_scan.py"
,
"sandbox/rng_mrg.py"
,
"sandbox/downsample.py"
,
"sandbox/solve.py"
,
"sandbox/theano_object.py"
,
"sandbox/scan.py"
,
...
...
@@ -190,7 +168,6 @@ whitelist_flake8 = [
"sandbox/cuda/nvcc_compiler.py"
,
"sandbox/cuda/neighbours.py"
,
"sandbox/cuda/tests/walltime.py"
,
"sandbox/cuda/tests/test_fftconv.py"
,
"sandbox/cuda/tests/test_gradient.py"
,
"sandbox/cuda/tests/test_neighbours.py"
,
"sandbox/cuda/tests/test_conv_cuda_ndarray.py"
,
...
...
@@ -218,7 +195,6 @@ whitelist_flake8 = [
"sandbox/scan_module/tests/test_utils.py"
,
"sandbox/scan_module/tests/test_scan.py"
,
"sandbox/linalg/ops.py"
,
"sandbox/linalg/kron.py"
,
"sandbox/linalg/__init__.py"
,
"sandbox/linalg/tests/test_linalg.py"
,
"sandbox/gpuarray/comp.py"
,
...
...
@@ -288,24 +264,12 @@ whitelist_flake8 = [
"sparse/sandbox/truedot.py"
,
"sparse/sandbox/sp.py"
,
"gof/destroyhandler.py"
,
"gof/vm.py"
,
"gof/cutils.py"
,
"gof/compiledir.py"
,
"gof/unify.py"
,
"gof/lazylinker_c.py"
,
"gof/optdb.py"
,
"gof/utils.py"
,
"gof/graph.py"
,
"gof/callcache.py"
,
"gof/python25.py"
,
"gof/type.py"
,
"gof/__init__.py"
,
"gof/cc.py"
,
"gof/opt.py"
,
"gof/compilelock.py"
,
"gof/link.py"
,
"gof/sched.py"
,
"gof/toolbox.py"
,
"gof/fg.py"
,
"gof/op.py"
,
"gof/cmodule.py"
,
...
...
@@ -322,9 +286,6 @@ whitelist_flake8 = [
"gof/tests/test_cc.py"
,
"gof/tests/test_compute_test_value.py"
,
"gof/sandbox/equilibrium.py"
,
"sandbox/cuda/opt_util.py"
,
"gof/tests/test_utils.py"
,
"tensor/tests/_test_mpi_roundtrip.py"
,
]
...
...
theano/version.py
浏览文件 @
4cf7afb4
try
:
from
theano.generated_version
import
*
from
theano.generated_version
import
*
# noqa
except
ImportError
:
short_version
=
'unknown'
version
=
'unknown'
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论