testgroup / pytensor / Commits / e0821760

Commit e0821760, authored Aug 14, 2017 by abergeron, committed by GitHub on Aug 14, 2017.

    Merge pull request #6000 from ReyhaneAskari/new_destroy_handler

    New destroy handler

Parents: 26d47057, f03092a0

Showing 20 changed files with 147 additions and 38 deletions (+147 −38).
doc/faq.txt                                +9   −0
doc/tutorial/modes.txt                     +2   −1
theano/compile/debugmode.py                +20  −19
theano/compile/function_module.py          +8   −3
theano/configdefaults.py                   +1   −1
theano/gof/destroyhandler.py               +47  −9
theano/gof/tests/test_destroyhandler.py    +12  −0
theano/gpuarray/tests/test_dnn.py          +2   −0
theano/gpuarray/tests/test_linalg.py       +6   −0
theano/gpuarray/tests/test_opt.py          +2   −0
theano/scan_module/tests/test_scan.py      +3   −0
theano/sparse/tests/test_basic.py          +1   −0
theano/tensor/nnet/tests/test_nnet.py      +1   −2
theano/tensor/nnet/tests/test_opt.py       +4   −0
theano/tensor/opt.py                       +3   −3
theano/tensor/tests/test_basic.py          +3   −0
theano/tensor/tests/test_blas.py           +2   −0
theano/tensor/tests/test_opt.py            +2   −0
theano/tensor/tests/test_sharedvar.py      +5   −0
theano/tests/unittest_tools.py             +14  −0
doc/faq.txt

@@ -43,6 +43,8 @@ CPUs. In fact, Theano asks g++ what are the equivalent flags it uses, and re-use

    them directly.

    .. _faster-theano-function-compilation:

    Faster Theano Function Compilation
    ----------------------------------

@@ -67,6 +69,13 @@ compilation but it will also use more memory because

    resulting in a trade off between speed of compilation and memory
    usage.

    Alternatively, if the graph is big, using the flag ``cycle_detection=fast``
    will speed up the computations by removing some of the inplace
    optimizations. This allows Theano to skip a time-consuming cycle
    detection algorithm. If the graph is big enough, we suggest that you use
    this flag instead of ``optimizer_excluding=inplace``. It will result in a
    computation time that is in between fast compile and fast run.

    Theano flag `reoptimize_unpickled_function` controls if an unpickled
    theano function should reoptimize its graph or not. Theano users can
    use the standard python pickle tools to save a compiled theano
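The documentation hunk above introduces the ``cycle_detection=fast`` flag. As a usage sketch (this relies only on Theano's standard flag mechanisms; the script name is a placeholder), the flag can be set either through the environment or in the config file:

```
# via the environment:
THEANO_FLAGS='cycle_detection=fast' python your_script.py

# or in ~/.theanorc:
[global]
cycle_detection = fast
```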
doc/tutorial/modes.txt

@@ -225,7 +225,8 @@ stabilize        "+++++"      "++"           Only applies stability opts

    ================= ============ ============== ==================================================

    For a detailed list of the specific optimizations applied for each of these
    optimizers, see :ref:`optimizations`. Also, see :ref:`unsafe_optimization` and
    :ref:`faster-theano-function-compilation` for other trade-offs.

    .. _using_debugmode:
theano/compile/debugmode.py

@@ -2273,25 +2273,26 @@ class _Maker(FunctionMaker):  # inheritance buys a few helper functions

The DestroyHandler is now attached only when theano.config.cycle_detection is 'regular'. The resulting code:

                      "of", len(li), "events was stable.",
                      file=sys.stderr)
            self.fgraph = fgraph

            destroy_handler_added = False
            if theano.config.cycle_detection == 'regular':
                for feature in fgraph._features:
                    if isinstance(feature, gof.DestroyHandler):
                        destroy_handler_added = True
                        break
                if not destroy_handler_added:
                    fgraph.attach_feature(gof.DestroyHandler())

            for o in fgraph.outputs:
                try:
                    with change_flags(compute_test_value=config.compute_test_value_opt):
                        fgraph.replace_validate(o, _output_guard(o),
                                                reason='output_guard')
                    raise Exception("Output variable %s required output_guard, "
                                    "how was this output left unprotected against "
                                    "destructive operations?" % o)
                except gof.InconsistencyError:
                    # This output is already impossible to destroy.
                    # No guard necessary
                    pass

            linker = _Linker(self)
theano/compile/function_module.py

@@ -132,6 +132,11 @@ class Supervisor:

Supervisor.validate gains a fast path that uses the new fgraph.has_destroyers:

        self.protected = list(protected)

    def validate(self, fgraph):
        if config.cycle_detection == 'fast' and hasattr(fgraph, 'has_destroyers'):
            if fgraph.has_destroyers(self.protected):
                raise gof.InconsistencyError("Trying to destroy a protected"
                                             "Variable.")
            return True
        if not hasattr(fgraph, 'destroyers'):
            return True
        for r in self.protected + list(fgraph.outputs):

@@ -190,7 +195,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False):

fgraph.destroyers(input) is replaced by fgraph.has_destroyers([input]):

                    for spec, input in zip(input_specs, fgraph.inputs)
                    if not (spec.mutable or
                            (hasattr(fgraph, 'destroyers') and
                             fgraph.has_destroyers([input])))))
    # If named nodes are replaced, keep the name
    for feature in std_fgraph.features:

@@ -1111,7 +1116,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):

The hasattr(fgraph, 'get_destroyers_of') check is replaced by a check for the new attribute, and the local variable is renamed accordingly:

    # We can't use fgraph.inputs as this don't include Constant Value.
    all_graph_inputs = gof.graph.inputs(fgraph.outputs)
    has_destroyers_attr = hasattr(fgraph, 'has_destroyers')
    for i in xrange(len(fgraph.outputs)):
        views_of_output_i = set()

@@ -1142,7 +1147,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):

            # being updated
            if input_j in updated_fgraph_inputs:
                continue
            if input_j in views_of_output_i and not (has_destroyers_attr and
                                                     fgraph.has_destroyers([input_j])):
                # We don't put deep_copy_op if the input and the
                # output have borrow==True
                if input_j in fgraph.inputs:
theano/configdefaults.py

@@ -1575,7 +1575,7 @@ AddConfigVar('cycle_detection',

A stray newline at the end of the help string is removed:

             "The interaction of which one give the lower peak memory usage is"
             "complicated and not predictable, so if you are close to the peak"
             "memory usage, triyng both could give you a small gain.",
             EnumStr('regular', 'fast'),
             in_c_key=False)
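The option above is declared as an EnumStr with two allowed values. As a rough, self-contained illustration of that idea (plain Python, not Theano's actual configparser machinery; the EnumOption class is invented here), an enum-style config entry validates assignments against its fixed set of values:

```python
class EnumOption(object):
    """Config entry that only accepts one of a fixed set of string values.

    The first value passed is the default, mirroring EnumStr('regular', 'fast').
    """

    def __init__(self, default, *allowed):
        self.allowed = (default,) + allowed
        self.value = default

    def set(self, value):
        # Reject anything outside the declared set, like EnumStr would.
        if value not in self.allowed:
            raise ValueError("expected one of %r, got %r" % (self.allowed, value))
        self.value = value


cycle_detection = EnumOption('regular', 'fast')
print(cycle_detection.value)  # regular (the default)
cycle_detection.set('fast')
print(cycle_detection.value)  # fast
```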
theano/gof/destroyhandler.py

@@ -250,7 +250,7 @@ def fast_inplace_check(inputs):

    inputs = [i for i in inputs if
              not isinstance(i, graph.Constant) and
              not fgraph.has_destroyers([i]) and
              i not in protected_inputs]
    return inputs

@@ -297,7 +297,7 @@ class DestroyHandler(toolbox.Bookkeeper):  # noqa

pickle_rm_attr = ["destroyers"] becomes:

    pickle_rm_attr = ["destroyers", "has_destroyers"]

    def __init__(self, do_imports_on_attach=True, algo=None):
        self.fgraph = None

@@ -394,6 +394,41 @@ class DestroyHandler(toolbox.Bookkeeper):  # noqa

The new fgraph.has_destroyers is defined right after fgraph.destroyers:

            return []
        fgraph.destroyers = get_destroyers_of

        def has_destroyers(protected_list):
            if self.algo != 'fast':
                droot, _, root_destroyer = self.refresh_droot_impact()
                for protected_var in protected_list:
                    try:
                        root_destroyer[droot[protected_var]]
                        return True
                    except KeyError:
                        pass
                return False

            def recursive_destroys_finder(protected_var):
                # protected_var is the idx'th input of app.
                for (app, idx) in protected_var.clients:
                    if app == 'output':
                        continue
                    destroy_maps = getattr(app.op, 'destroy_map', {}).values()
                    # If idx is present, the apply node destroys protected_var.
                    if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
                        return True
                    for var_idx in getattr(app.op, 'view_map', {}).keys():
                        if idx in app.op.view_map[var_idx]:
                            # We need to recursively check the destroy_map of all the
                            # outputs that we have a view_map on.
                            if recursive_destroys_finder(app.outputs[var_idx]):
                                return True
                return False

            for protected_var in protected_list:
                if recursive_destroys_finder(protected_var):
                    return True
            return False

        fgraph.has_destroyers = has_destroyers

    def refresh_droot_impact(self):
        """
        Makes sure self.droot, self.impact, and self.root_destroyer are up to

@@ -416,6 +451,7 @@ class DestroyHandler(toolbox.Bookkeeper):  # noqa

        del self.stale_droot
        assert self.fgraph.destroyer_handler is self
        delattr(self.fgraph, 'destroyers')
        delattr(self.fgraph, 'has_destroyers')
        delattr(self.fgraph, 'destroy_handler')
        self.fgraph = None

@@ -452,11 +488,11 @@ class DestroyHandler(toolbox.Bookkeeper):  # noqa

            if len(v) > 0:
                self.fail_validate[app] = theano.gof.InconsistencyError(
                    "Destroyed variable has view_map. " + str(reason))
            elif d:
                d = d.get(inp_idx2, [])
                if len(d) > 0:
                    self.fail_validate[app] = theano.gof.InconsistencyError(
                        "Destroyed variable has destroy_map. " + str(reason))
        # These 2 assertions are commented since this function is called so many times
        # but they should be true.

@@ -474,13 +510,15 @@ class DestroyHandler(toolbox.Bookkeeper):  # noqa

        # print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
        # If it's a destructive op, add it to our watch list
        dmap = getattr(app.op, 'destroy_map', None)
        vmap = getattr(app.op, 'view_map', {})
        if dmap:
            self.destroyers.add(app)
            if self.algo == 'fast':
                self.fast_destroy(app, reason)
        # add this symbol to the forward and backward maps
        for o_idx, i_idx_list in iteritems(vmap):
            if len(i_idx_list) > 1:
                raise NotImplementedError(
                    'destroying this output invalidates multiple inputs',
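To make the fast-path check above concrete, here is a minimal, self-contained sketch of the same idea (plain Python, not Theano; the Var and Apply classes below are invented stand-ins for Theano's graph types): follow view chains through a variable's clients and report whether any downstream op eventually destroys it.

```python
# Toy sketch of the has_destroyers fast path: a variable counts as destroyed
# if a client op destroys it directly (destroy_map), or if some output that
# is a view of it (view_map) is destroyed further downstream.

class Var(object):
    def __init__(self):
        self.clients = []  # list of (apply_node, input_index) pairs


class Apply(object):
    def __init__(self, inputs, n_outputs, destroy_map=None, view_map=None):
        self.destroy_map = destroy_map or {}  # {output_idx: [input_idx, ...]}
        self.view_map = view_map or {}        # {output_idx: [input_idx, ...]}
        self.outputs = [Var() for _ in range(n_outputs)]
        for idx, v in enumerate(inputs):
            v.clients.append((self, idx))


def destroys_transitively(var):
    """True if some client destroys `var`, directly or through a view chain."""
    for app, idx in var.clients:
        # Direct destruction: input index `idx` appears in a destroy_map entry.
        if any(idx in idxs for idxs in app.destroy_map.values()):
            return True
        # Otherwise, recurse into outputs that are views of input `idx`.
        for out_idx, view_idxs in app.view_map.items():
            if idx in view_idxs and destroys_transitively(app.outputs[out_idx]):
                return True
    return False


x = Var()
view = Apply([x], 1, view_map={0: [0]})        # output 0 is a view of x
Apply(view.outputs, 1, destroy_map={0: [0]})   # destroys the view, hence x

y = Var()
Apply([y], 1)                                  # plain op, no destruction

print(destroys_transitively(x))  # True: destroyed through the view chain
print(destroys_transitively(y))  # False
```

This mirrors the recursion structure of recursive_destroys_finder, without the self.algo branch or the 'output' client sentinel that the real handler also deals with.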
theano/gof/tests/test_destroyhandler.py

@@ -11,6 +11,7 @@ from theano.gof.opt import (OpKeyOptimizer, PatternSub, NavigatorOptimizer,

    from theano.gof import destroyhandler
    from theano.gof.fg import FunctionGraph, InconsistencyError
    from theano.gof.toolbox import ReplaceValidate
    from theano.tests.unittest_tools import assertFailure_fast
    from theano.configparser import change_flags

The @assertFailure_fast decorator is added in front of eleven tests (one hunk each): test_aliased_inputs_replacement, test_usage_loop_through_views_2, test_destroyers_loop, test_aliased_inputs_tolerate, test_same_aliased_inputs_ignored, test_different_aliased_inputs_ignored, test_indirect_2, test_long_destroyers_loop, test_multi_destroyers_through_views, test_usage_loop_insert_views and test_multiple_inplace. For example:

    @assertFailure_fast
    def test_destroyers_loop():
        # AddInPlace(x, y) and AddInPlace(y, x) should not coexist
        x, y, z = inputs()
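The decorator itself lives in theano/tests/unittest_tools.py (+14 lines; its diff is not shown in this extraction). A minimal sketch of what such a decorator could look like is given below, under the assumption that it marks a test as expected to fail when the fast, approximate cycle detection is active; the fast_mode_active flag is invented here for illustration and stands in for a check on theano.config.cycle_detection.

```python
import functools

# Hypothetical stand-in for `theano.config.cycle_detection == 'fast'`.
fast_mode_active = True


def assert_failure_fast(test_fn):
    """Expect `test_fn` to fail under fast mode; run it normally otherwise."""
    @functools.wraps(test_fn)
    def wrapper(*args, **kwargs):
        if not fast_mode_active:
            return test_fn(*args, **kwargs)
        try:
            test_fn(*args, **kwargs)
        except AssertionError:
            return  # expected failure under the fast, approximate algorithm
        raise AssertionError("%s unexpectedly passed in fast mode"
                             % test_fn.__name__)
    return wrapper


@assert_failure_fast
def test_strict_check():
    # A check that only the exact (regular) algorithm satisfies.
    assert False, "fast mode misses this inconsistency"


test_strict_check()  # runs without raising: the failure was expected
```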
theano/gpuarray/tests/test_dnn.py

@utt.assertFailure_fast is added to two tests, at @@ -1754,6 +1754,7 @@ and @@ -1876,6 +1877,7 @@:

        f_abstract(X, Scale, Bias, Dy)

    @utt.assertFailure_fast
    def test_dnn_batchnorm_train_inplace():
        # test inplace_running_mean and inplace_running_var
        if not dnn.dnn_available(test_ctx_name):

        utt.assert_allclose(outputs_abstract[5], outputs_ref[5],
                            rtol=2e-3, atol=4e-5)  # dvar

    @utt.assertFailure_fast
    def test_batchnorm_inference_inplace():
        # test inplace
        if not dnn.dnn_available(test_ctx_name):
theano/gpuarray/tests/test_linalg.py

@utt.assertFailure_fast is added to six tests: TestGpuCholesky.test_diag_chol, TestGpuCholesky.test_dense_chol_lower, TestMagma.test_gpu_matrix_inverse_inplace, TestMagma.test_gpu_matrix_inverse_inplace_opt, TestMagma.test_gpu_cholesky_inplace and TestMagma.test_gpu_cholesky_inplace_opt. For example:

@@ -175,6 +175,7 @@ class TestGpuCholesky(unittest.TestCase):

            GpuCholesky(lower=True, inplace=False)(A)
        self.assertRaises(AssertionError, invalid_input_func)

    @utt.assertFailure_fast
    def test_diag_chol(self):
        # Diagonal matrix input Cholesky test.
        for lower in [True, False]:

and likewise:

    @utt.assertFailure_fast
    def test_gpu_matrix_inverse_inplace(self):
        N = 1000
        test_rng = np.random.RandomState(seed=1)
theano/gpuarray/tests/test_opt.py

@utt.assertFailure_fast is added to test_local_lift_solve (@@ -585,6 +585,7 @@) and test_local_lift_cholesky (@@ -619,6 +620,7 @@):

    @utt.assertFailure_fast
    def test_local_lift_solve():
        if not cusolver_available:
            raise SkipTest('No cuSolver')

    @utt.assertFailure_fast
    def test_local_lift_cholesky():
        if not cusolver_available:
            raise SkipTest('No cuSolver')
theano/scan_module/tests/test_scan.py

@utt.assertFailure_fast is added to T_Scan.test_inplace1 (@@ -886,6 +886,7 @@), test_inplace2 (@@ -950,6 +951,7 @@) and test_inplace3 (@@ -1021,6 +1023,7 @@):

        # simple rnn ; compute inplace version 1
        @utt.assertFailure_fast
        def test_inplace1(self):
            rng = np.random.RandomState(utt.fetch_seed())
            vW = asarrayX(np.random.uniform())
theano/sparse/tests/test_basic.py

@@ -3201,6 +3201,7 @@ import theano.tensor.tests.test_sharedvar

The shared-variable test factory call gains expect_fail_fast_shape_inplace=False:

        theano_fct_=lambda a: dense_from_sparse(a * 2.),
        ref_fct_=lambda a: np.asarray((a * 2).todense()),
        cast_value_=scipy.sparse.csr_matrix,
        expect_fail_fast_shape_inplace=False,
        )

    class test_shared_options(object):
        pass
theano/tensor/nnet/tests/test_nnet.py

Two hunks in T_CrossentropyCategorical1Hot (@@ -579,7 +579,6 @@ and @@ -652,7 +651,6 @@) each remove one line (not recoverable from this extraction). The surrounding context:

        theano.compile.mode.optdb.query(
            theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
        assert (fgraph.outputs[0].owner.op ==
                crossentropy_softmax_argmax_1hot_with_bias)

and:

        # print node.op
        # print '===='
        assert len(fgraph.toposort()) == 2
        assert (fgraph.outputs[0].owner.op ==
                crossentropy_softmax_argmax_1hot_with_bias)

@@ -1382,6 +1380,7 @@ def test_argmax_pushdown_bias():

        # print node.op
        types_to_check = (tensor.DimShuffle, tensor.Elemwise, tensor.Argmax)
        assert len(fgraph.toposort()) == 3
        for i, type in enumerate(types_to_check):
            assert isinstance(fgraph.toposort()[i].op, type)
        assert check_stack_trace(fgraph, ops_to_check=types_to_check)
theano/tensor/nnet/tests/test_opt.py

A new import is added at the top of the file:

    from __future__ import absolute_import, print_function, division
    import theano
    from theano import tensor
    from theano.tests.unittest_tools import assertFailure_fast
    from theano.gof.opt import check_stack_trace
    from theano.tensor.nnet.blocksparse import (
        sparse_block_dot, sparse_block_gemv_inplace, sparse_block_outer_inplace,

@@ -25,6 +26,9 @@ def test_blocksparse_inplace_gemv_opt():

        assert f.maker.fgraph.toposort()[-1].op.inplace
        assert check_stack_trace(f, ops_to_check=[sparse_block_gemv_inplace])

    if theano.config.mode != 'FAST_COMPILE':
        test_blocksparse_inplace_gemv_opt = assertFailure_fast(
            test_blocksparse_inplace_gemv_opt)

    def test_blocksparse_inplace_outer_opt():
        b = tensor.fmatrix()
...
theano/tensor/opt.py
...
@@ -265,8 +265,8 @@ class InplaceElemwiseOptimizer(Optimizer):
                 candidate_inputs = [i for i in xrange(len(node.inputs))
                                     if i not in baseline.values() and
                                     not isinstance(node.inputs[i], Constant) and
-                                    # Is next line costly?
-                                    not fgraph.destroyers(node.inputs[i]) and
+                                    # the next line should not be costly most of the time.
+                                    not fgraph.has_destroyers([node.inputs[i]]) and
                                     node.inputs[i] not in protected_inputs]
             else:
                 baseline = []
...
@@ -277,7 +277,7 @@ class InplaceElemwiseOptimizer(Optimizer):
                 # Remove here as faster.
                 candidate_inputs = [i for i in xrange(len(node.inputs))
                                     if not isinstance(node.inputs[i], Constant) and
-                                    not fgraph.destroyers(node.inputs[i]) and
+                                    not fgraph.has_destroyers([node.inputs[i]]) and
                                     node.inputs[i] not in protected_inputs]
 
         verbose = False
...
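The change above replaces the old `fgraph.destroyers(var)` query with `fgraph.has_destroyers([var])`: a boolean membership test over a batch of variables instead of materializing the list of destroyer nodes for each one. A hypothetical sketch of the difference — the class name, dict layout, and method bodies are illustrative only, not Theano's actual destroy handler:

```python
class DestroyHandlerSketch:
    """Tracks which apply nodes destroy (overwrite) which variables."""

    def __init__(self):
        # var -> set of nodes destroying it (illustrative layout only)
        self._destroyers = {}

    def destroyers(self, var):
        # old-style query: materialize the full list of destroyer nodes
        return list(self._destroyers.get(var, ()))

    def has_destroyers(self, variables):
        # new-style query: a boolean over a batch of variables, which a
        # cached set can answer without building intermediate lists
        return any(self._destroyers.get(v) for v in variables)


h = DestroyHandlerSketch()
h._destroyers['x'] = {'node_a'}
print(h.has_destroyers(['x', 'y']))  # True: 'x' has a destroyer
print(h.has_destroyers(['y']))       # False
```

In the optimizer above only the truth value is needed, so the boolean form lets the destroy handler cache the answer cheaply.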
theano/tensor/tests/test_basic.py
...
@@ -4806,6 +4806,9 @@ class T_exp(unittest.TestCase):
             np.asarray([[1.5089518, 1.48439076, -4.7820262],
                         [2.04832468, 0.50791564, -1.58892269]])])
+    if theano.config.cycle_detection == 'fast' and \
+            theano.config.mode != 'FAST_COMPILE':
+        test_grad_1 = unittest.expectedFailure(test_grad_1)
 
     def test_int(self):
         x = ivector()
         f = function([x], exp(x))
...
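The hunk above conditionally marks a test as an expected failure by reassigning the method inside the class body, so the decorator is applied only when the config check holds at class-definition time. A minimal standalone sketch of the same pattern, where `FLAG` stands in for the `theano.config` check:

```python
import unittest

# stand-in for: cycle_detection == 'fast' and mode != 'FAST_COMPILE'
FLAG = True


class TExample(unittest.TestCase):
    def test_grad_1(self):
        # in the real suite this fails under the fast cycle detector
        raise AssertionError("gradient check failed")

    # class-body reassignment: the decorator wraps the method only
    # when FLAG holds, otherwise the test runs normally
    if FLAG:
        test_grad_1 = unittest.expectedFailure(test_grad_1)
```

Running the suite then reports one expected failure instead of an error, which keeps the test green under the fast cycle detector without hiding it from the runner.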
theano/tensor/tests/test_blas.py
...
@@ -500,6 +500,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
             raise
 
+@unittest_tools.assertFailure_fast
 def test_gemm_opt0():
     # Many subgraphs whose dots can be eliminated
     X, Y, Z, a, b = XYZab()
...
@@ -528,6 +529,7 @@ def test_gemm_opt0():
     just_gemm([X, Y, Z, a, b], [Z - a * b * a * T.dot(X, Y)])
 
+@unittest_tools.assertFailure_fast
 def test_gemm_opt_double_gemm():
     # This is the pattern that shows up in the autoencoder
     X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
...
theano/tensor/tests/test_opt.py
...
@@ -1367,6 +1367,7 @@ class TestCompositeCodegen(unittest.TestCase):
         utt.assert_allclose(f([[1.]]), [[0.]])
 
+@utt.assertFailure_fast
 def test_log1p():
     m = theano.config.mode
     if m == 'FAST_COMPILE':
...
@@ -1989,6 +1990,7 @@ class test_local_subtensor_lift(unittest.TestCase):
         assert len(prog) == 3
         f([4, 5])  # let debugmode test something
 
+    @utt.assertFailure_fast
     def test4(self):
         # basic test that the optimization doesn't work with broadcasting
         # ... It *could* be extended to,
...
theano/tensor/tests/test_sharedvar.py
...
@@ -27,6 +27,7 @@ def makeSharedTester(shared_constructor_,
                      theano_fct_,
                      ref_fct_,
                      cast_value_=np.asarray,
+                     expect_fail_fast_shape_inplace=True,
                      ):
     """
     This is a generic fct to allow reusing the same test function
...
@@ -549,6 +550,10 @@ def makeSharedTester(shared_constructor_,
             assert sum([node.op.__class__.__name__ in ["Gemm", "GpuGemm", "StructuredDot"] for node in topo]) == 1
             assert all(node.op == tensor.blas.gemm_inplace for node in topo if isinstance(node.op, tensor.blas.Gemm))
             assert all(node.op.inplace for node in topo if node.op.__class__.__name__ == "GpuGemm")
+        if theano.config.cycle_detection == 'fast' and \
+                expect_fail_fast_shape_inplace and \
+                theano.config.mode != 'FAST_COMPILE':
+            test_specify_shape_inplace = unittest.expectedFailure(test_specify_shape_inplace)
 
         def test_values_eq(self):
             """ Test the type.values_eq[_approx] function"""
             dtype = self.dtype
...
theano/tests/unittest_tools.py
...
@@ -5,6 +5,7 @@ import logging
 import sys
 import unittest
 
 from parameterized import parameterized
+from nose.tools import assert_raises
 from six import integer_types
 from six.moves import StringIO
...
@@ -445,3 +446,16 @@ class AttemptManyTimes:
                 current_seed = str(int(current_seed) + 1)
 
     return attempt_multiple_times
+
+
+def assertFailure_fast(f):
+    """A decorator to handle the test cases that are failing when
+    THEANO_FLAGS=cycle_detection='fast'.
+    """
+    if theano.config.cycle_detection == 'fast':
+        def test_with_assert(*args, **kwargs):
+            with assert_raises(Exception):
+                f(*args, **kwargs)
+        return test_with_assert
+    else:
+        return f
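The new `assertFailure_fast` decorator above inverts a test's outcome when the fast cycle detector is enabled: the wrapped test passes only if the original raises. A self-contained sketch of the same pattern, using a plain module flag in place of `theano.config.cycle_detection` and a try/except in place of nose's `assert_raises`:

```python
cycle_detection = 'fast'  # stand-in for theano.config.cycle_detection


def assertFailure_fast(f):
    """Expect `f` to raise when the fast cycle detector is active."""
    if cycle_detection == 'fast':
        def test_with_assert(*args, **kwargs):
            try:
                f(*args, **kwargs)
            except Exception:
                return  # the failure was expected: the test passes
            raise AssertionError("%s unexpectedly passed" % f.__name__)
        return test_with_assert
    return f


@assertFailure_fast
def test_inplace_opt():
    # under the fast cycle detector the in-place rewrite is not applied,
    # so the original assertion fails
    raise AssertionError("op is not inplace")


test_inplace_opt()  # passes: the failure was expected
```

This lets the diff above mark whole tests (`test_gemm_opt0`, `test_log1p`, ...) as known failures under `cycle_detection='fast'` with a single decorator line, while leaving them untouched under the default detector.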