Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9f7a1b69
Unverified
提交
9f7a1b69
authored
10月 18, 2020
作者:
Brandon T. Willard
提交者:
GitHub
10月 18, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #110 from brandonwillard/enforce-sane-test-values
Refactor test value framework so that test value validation is performed up-front.
上级
53486aae
ea44b16d
隐藏空白字符变更
内嵌
并排
正在显示
22 个修改的文件
包含
192 行增加
和
464 行删除
+192
-464
pkl_utils.txt
doc/library/misc/pkl_utils.txt
+0
-4
test_compute_test_value.py
tests/gof/test_compute_test_value.py
+8
-6
test_fg.py
tests/gof/test_fg.py
+1
-27
test_fg_old_crash.pkl
tests/gof/test_fg_old_crash.pkl
+0
-0
test_op.py
tests/gof/test_op.py
+32
-100
test_multinomial.py
tests/gpuarray/test_multinomial.py
+3
-2
test_pickle.py
tests/gpuarray/test_pickle.py
+4
-10
test_type.py
tests/gpuarray/test_type.py
+3
-6
test_multinomial.py
tests/sandbox/test_multinomial.py
+0
-40
test_multinomial_wo_replacement.py
tests/sandbox/test_multinomial_wo_replacement.py
+1
-14
unittest_tools.py
tests/unittest_tools.py
+1
-1
debugmode.py
theano/compile/debugmode.py
+4
-4
cmodule.py
theano/gof/cmodule.py
+3
-3
compilelock.py
theano/gof/compilelock.py
+1
-1
graph.py
theano/gof/graph.py
+3
-2
op.py
theano/gof/op.py
+5
-59
utils.py
theano/gof/utils.py
+18
-1
vm.py
theano/gof/vm.py
+1
-1
pkl_utils.py
theano/misc/pkl_utils.py
+4
-91
printing.py
theano/printing.py
+3
-3
conv.py
theano/tensor/nnet/conv.py
+2
-2
opt.py
theano/tensor/opt.py
+95
-87
没有找到文件。
doc/library/misc/pkl_utils.txt
浏览文件 @
9f7a1b69
...
...
@@ -15,10 +15,6 @@
.. autoclass:: theano.misc.pkl_utils.StripPickler
.. autoclass:: theano.misc.pkl_utils.CompatUnpickler
.. seealso::
:ref:`tutorial_loadsave`
tests/gof/test_compute_test_value.py
浏览文件 @
9f7a1b69
...
...
@@ -167,14 +167,16 @@ class TestComputeTestValue:
@theano.change_flags
(
compute_test_value
=
"raise"
)
def
test_incorrect_type
(
self
):
x
=
tt
.
fmatrix
(
"x"
)
# Incorrect dtype (float64) for test_value
x
.
tag
.
test_value
=
np
.
random
.
rand
(
3
,
4
)
y
=
tt
.
dmatrix
(
"y"
)
y
.
tag
.
test_value
=
np
.
random
.
rand
(
4
,
5
)
x
=
tt
.
vector
(
"x"
)
with
pytest
.
raises
(
TypeError
):
tt
.
dot
(
x
,
y
)
# Incorrect shape for test value
x
.
tag
.
test_value
=
np
.
empty
((
2
,
2
))
x
=
tt
.
fmatrix
(
"x"
)
with
pytest
.
raises
(
TypeError
):
# Incorrect dtype (float64) for test value
x
.
tag
.
test_value
=
np
.
random
.
rand
(
3
,
4
)
@theano.change_flags
(
compute_test_value
=
"raise"
)
def
test_overided_function
(
self
):
...
...
tests/gof/test_fg.py
浏览文件 @
9f7a1b69
import
os
import
pickle
import
pytest
import
theano
from
theano.compat
import
PY3
from
theano.gof.fg
import
FunctionGraph
from
theano
import
tensor
as
tt
from
theano.gof.fg
import
FunctionGraph
class
TestFunctionGraph
:
...
...
@@ -16,24 +11,3 @@ class TestFunctionGraph:
s
=
pickle
.
dumps
(
func
)
pickle
.
loads
(
s
)
@pytest.mark.skipif
(
not
theano
.
config
.
cxx
,
reason
=
"G++ not available, so we need to skip this test."
)
@pytest.mark.slow
def
test_node_outputs_not_used
(
self
):
# In the past, we where removing some not used variable from
# fgraph.variables event if the apply had other output used in
# the graph. This caused a crash.
# This test run the pickle that reproduce this case.
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"test_fg_old_crash.pkl"
),
"rb"
)
as
f
:
from
theano.misc.pkl_utils
import
CompatUnpickler
if
PY3
:
u
=
CompatUnpickler
(
f
,
encoding
=
"latin1"
)
else
:
u
=
CompatUnpickler
(
f
)
d
=
u
.
load
()
f
=
theano
.
function
(
**
d
)
tests/gof/test_fg_old_crash.pkl
deleted
100644 → 0
浏览文件 @
53486aae
File deleted
tests/gof/test_op.py
浏览文件 @
9f7a1b69
import
numpy
as
np
import
pytest
import
theano
import
theano.gof.op
as
op
import
theano.tensor
as
tt
from
six
import
string_types
from
theano.gof.type
import
Type
,
Generic
from
theano
import
scalar
,
shared
from
theano.configparser
import
change_flags
from
theano.gof.graph
import
Apply
,
Variable
import
theano.tensor
as
T
from
theano
import
scalar
from
theano
import
shared
from
theano.gof.type
import
Generic
,
Type
config
=
theano
.
config
Op
=
op
.
Op
...
...
@@ -238,15 +238,15 @@ class TestMakeThunk:
__props__
=
()
itypes
=
[
T
.
dmatrix
]
otypes
=
[
T
.
dmatrix
]
itypes
=
[
tt
.
dmatrix
]
otypes
=
[
tt
.
dmatrix
]
def
perform
(
self
,
node
,
inputs
,
outputs
):
inp
=
inputs
[
0
]
output
=
outputs
[
0
]
output
[
0
]
=
inp
*
2
x_input
=
T
.
dmatrix
(
"x_input"
)
x_input
=
tt
.
dmatrix
(
"x_input"
)
f
=
theano
.
function
([
x_input
],
DoubleOp
()(
x_input
))
inp
=
np
.
random
.
rand
(
5
,
4
)
out
=
f
(
inp
)
...
...
@@ -255,17 +255,17 @@ class TestMakeThunk:
def
test_test_value_python_objects
():
for
x
in
([
0
,
1
,
2
],
0
,
0.5
,
1
):
assert
(
op
.
get_test_value
(
x
)
==
x
)
.
all
(
)
assert
np
.
all
(
op
.
get_test_value
(
x
)
==
x
)
def
test_test_value_ndarray
():
x
=
np
.
zeros
((
5
,
5
))
v
=
op
.
get_test_value
(
x
)
assert
(
v
==
x
)
.
all
(
)
assert
np
.
all
(
v
==
x
)
def
test_test_value_constant
():
x
=
T
.
as_tensor_variable
(
np
.
zeros
((
5
,
5
)))
x
=
tt
.
as_tensor_variable
(
np
.
zeros
((
5
,
5
)))
v
=
op
.
get_test_value
(
x
)
assert
np
.
all
(
v
==
np
.
zeros
((
5
,
5
)))
...
...
@@ -278,62 +278,37 @@ def test_test_value_shared():
assert
np
.
all
(
v
==
np
.
zeros
((
5
,
5
)))
@change_flags
(
compute_test_value
=
"raise"
)
def
test_test_value_op
():
try
:
prev_value
=
config
.
compute_test_value
config
.
compute_test_value
=
"raise"
x
=
T
.
log
(
np
.
ones
((
5
,
5
)))
v
=
op
.
get_test_value
(
x
)
assert
np
.
allclose
(
v
,
np
.
zeros
((
5
,
5
)))
finally
:
config
.
compute_test_value
=
prev_value
def
test_get_debug_values_no_debugger
():
"get_debug_values should return [] when debugger is off"
x
=
tt
.
log
(
np
.
ones
((
5
,
5
)))
v
=
op
.
get_test_value
(
x
)
prev_value
=
config
.
compute_test_value
try
:
config
.
compute_test_value
=
"off"
assert
np
.
allclose
(
v
,
np
.
zeros
((
5
,
5
)))
x
=
T
.
vector
()
for
x_val
in
op
.
get_debug_values
(
x
):
assert
False
@change_flags
(
compute_test_value
=
"off"
)
def
test_get_debug_values_no_debugger
():
"""Tests that `get_debug_values` returns `[]` when debugger is off."""
finally
:
config
.
compute_test_value
=
prev_value
x
=
tt
.
vector
()
assert
op
.
get_debug_values
(
x
)
==
[]
@change_flags
(
compute_test_value
=
"ignore"
)
def
test_get_det_debug_values_ignore
():
# get_debug_values should return [] when debugger is ignore
# and some values are missing
"""Tests that `get_debug_values` returns `[]` when debugger is set to "ignore" and some values are missing."""
prev_value
=
config
.
compute_test_value
try
:
config
.
compute_test_value
=
"ignore"
x
=
T
.
vector
()
for
x_val
in
op
.
get_debug_values
(
x
):
assert
False
finally
:
config
.
compute_test_value
=
prev_value
x
=
tt
.
vector
()
assert
op
.
get_debug_values
(
x
)
==
[]
def
test_get_debug_values_success
():
# tests that get_debug_value returns values when available
# (and the debugger is on)
"""Tests that `get_debug_value` returns values when available (and the debugger is on)."""
prev_value
=
config
.
compute_test_value
for
mode
in
[
"ignore"
,
"warn"
,
"raise"
]:
try
:
config
.
compute_test_value
=
mode
x
=
T
.
vector
()
with
change_flags
(
compute_test_value
=
mode
):
x
=
tt
.
vector
()
x
.
tag
.
test_value
=
np
.
zeros
((
4
,),
dtype
=
config
.
floatX
)
y
=
np
.
zeros
((
5
,
5
))
...
...
@@ -348,54 +323,11 @@ def test_get_debug_values_success():
assert
iters
==
1
finally
:
config
.
compute_test_value
=
prev_value
@change_flags
(
compute_test_value
=
"raise"
)
def
test_get_debug_values_exc
():
# tests that get_debug_value raises an exception when
# debugger is set to raise and a value is missing
prev_value
=
config
.
compute_test_value
try
:
config
.
compute_test_value
=
"raise"
x
=
T
.
vector
()
try
:
for
x_val
in
op
.
get_debug_values
(
x
):
# this assert catches the case where we
# erroneously get a value returned
assert
False
raised
=
False
except
AttributeError
:
raised
=
True
# this assert catches the case where we got []
# returned, and possibly issued a warning,
# rather than raising an exception
assert
raised
"""Tests that `get_debug_value` raises an exception when debugger is set to raise and a value is missing."""
finally
:
config
.
compute_test_value
=
prev_value
def
test_debug_error_message
():
# tests that debug_error_message raises an
# exception when it should.
prev_value
=
config
.
compute_test_value
for
mode
in
[
"ignore"
,
"raise"
]:
try
:
config
.
compute_test_value
=
mode
try
:
op
.
debug_error_message
(
"msg"
)
raised
=
False
except
ValueError
:
raised
=
True
assert
raised
finally
:
config
.
compute_test_value
=
prev_value
with
pytest
.
raises
(
AttributeError
):
x
=
tt
.
vector
()
assert
op
.
get_debug_values
(
x
)
==
[]
tests/gpuarray/test_multinomial.py
浏览文件 @
9f7a1b69
...
...
@@ -7,9 +7,10 @@ import theano
import
tests.unittest_tools
as
utt
from
pickle
import
Unpickler
from
theano
import
config
,
function
,
tensor
from
theano.compat
import
PY3
from
theano.misc.pkl_utils
import
CompatUnpickler
from
theano.sandbox
import
multinomial
from
theano.sandbox.rng_mrg
import
MRG_RandomStreams
as
RandomStreams
from
theano.gpuarray.multinomial
import
(
...
...
@@ -384,6 +385,6 @@ def test_unpickle_legacy_op():
if
not
PY3
:
with
open
(
os
.
path
.
join
(
testfile_dir
,
fname
),
"r"
)
as
fp
:
u
=
Compat
Unpickler
(
fp
)
u
=
Unpickler
(
fp
)
m
=
u
.
load
()
assert
isinstance
(
m
,
GPUAChoiceFromUniform
)
tests/gpuarray/test_pickle.py
浏览文件 @
9f7a1b69
...
...
@@ -13,9 +13,9 @@ import pytest
import
numpy
as
np
from
pickle
import
Unpickler
from
theano
import
config
from
theano.compat
import
PY3
from
theano.misc.pkl_utils
import
CompatUnpickler
from
theano.gpuarray.type
import
ContextNotDefined
...
...
@@ -37,10 +37,7 @@ def test_unpickle_gpuarray_as_numpy_ndarray_flag1():
fname
=
"GpuArray.pkl"
with
open
(
os
.
path
.
join
(
testfile_dir
,
fname
),
"rb"
)
as
fp
:
if
PY3
:
u
=
CompatUnpickler
(
fp
,
encoding
=
"latin1"
)
else
:
u
=
CompatUnpickler
(
fp
)
u
=
Unpickler
(
fp
,
encoding
=
"latin1"
)
with
pytest
.
raises
((
ImportError
,
ContextNotDefined
)):
u
.
load
()
finally
:
...
...
@@ -56,10 +53,7 @@ def test_unpickle_gpuarray_as_numpy_ndarray_flag2():
fname
=
"GpuArray.pkl"
with
open
(
os
.
path
.
join
(
testfile_dir
,
fname
),
"rb"
)
as
fp
:
if
PY3
:
u
=
CompatUnpickler
(
fp
,
encoding
=
"latin1"
)
else
:
u
=
CompatUnpickler
(
fp
)
u
=
Unpickler
(
fp
,
encoding
=
"latin1"
)
try
:
mat
=
u
.
load
()
except
ImportError
:
...
...
tests/gpuarray/test_type.py
浏览文件 @
9f7a1b69
...
...
@@ -5,10 +5,10 @@ import theano
pygpu
=
pytest
.
importorskip
(
"pygpu"
)
from
theano.compat
import
PY3
from
pickle
import
Unpickler
from
theano
import
config
from
theano.compile
import
DeepCopyOp
,
Rebroadcast
,
ViewOp
from
theano.misc.pkl_utils
import
CompatUnpickler
from
theano.gpuarray.type
import
GpuArrayType
,
gpuarray_shared_constructor
from
tests.gpuarray.config
import
test_ctx_name
...
...
@@ -122,10 +122,7 @@ def test_unpickle_gpuarray_as_numpy_ndarray_flag0():
fname
=
"GpuArray.pkl"
with
open
(
os
.
path
.
join
(
testfile_dir
,
fname
),
"rb"
)
as
fp
:
if
PY3
:
u
=
CompatUnpickler
(
fp
,
encoding
=
"latin1"
)
else
:
u
=
CompatUnpickler
(
fp
)
u
=
Unpickler
(
fp
,
encoding
=
"latin1"
)
mat
=
u
.
load
()
assert
isinstance
(
mat
,
pygpu
.
gpuarray
.
GpuArray
)
assert
np
.
asarray
(
mat
)[
0
]
==
-
42.0
...
...
tests/sandbox/test_multinomial.py
浏览文件 @
9f7a1b69
import
os
import
sys
import
numpy
as
np
import
theano
import
tests.unittest_tools
as
utt
from
theano
import
config
,
function
,
tensor
from
theano.sandbox
import
multinomial
from
theano.compat
import
PY3
from
theano.misc.pkl_utils
import
CompatUnpickler
def
test_n_samples_1
():
...
...
@@ -51,40 +45,6 @@ def test_n_samples_2():
assert
res
.
sum
()
==
i
def
test_n_samples_compatibility
():
# This test checks if the new change to MultinomialFromUniform is still compatible
# with old interface. Here I will load a graph created (using the old interface) as follows:
# RandomStreams = theano.sandbox.rng_mrg.MRG_RandomStreams
# th_rng = RandomStreams(12345)
# X = T.matrix('X')
# pvals = T.exp(X)
# pvals = pvals / pvals.sum(axis=1, keepdims=True)
# samples = th_rng.multinomial(pvals=pvals)
# pickle.dump([X, samples], open("multinomial_test_graph.pkl", "w"))
folder
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
with
open
(
os
.
path
.
join
(
folder
,
"multinomial_test_graph.pkl"
),
"rb"
)
as
pkl_file
:
if
PY3
:
u
=
CompatUnpickler
(
pkl_file
,
encoding
=
"latin1"
)
else
:
u
=
CompatUnpickler
(
pkl_file
)
try
:
X
,
samples
=
u
.
load
()
except
ImportError
:
# Windows sometimes fail with nonsensical errors like:
# ImportError: No module named type
# ImportError: No module named copy_reg
# when "type" and "copy_reg" are builtin modules.
if
sys
.
platform
==
"win32"
:
exc_type
,
exc_value
,
exc_trace
=
sys
.
exc_info
()
raise
raise
f
=
theano
.
function
([
X
],
samples
)
res
=
f
(
np
.
random
.
randn
(
20
,
10
))
assert
np
.
all
(
res
.
sum
(
axis
=
1
)
==
1
)
def
test_multinomial_0
():
# This tests the MultinomialFromUniform Op directly, not going through the
# multinomial() call in GPU random generation.
...
...
tests/sandbox/test_multinomial_wo_replacement.py
浏览文件 @
9f7a1b69
import
numpy
as
np
import
pytest
import
os
from
theano
import
config
,
function
,
tensor
from
theano.compat
import
PY3
from
theano.misc.pkl_utils
import
CompatUnpickler
from
theano.sandbox
import
multinomial
from
theano.sandbox.rng_mrg
import
MRG_RandomStreams
as
RandomStreams
...
...
@@ -214,14 +212,3 @@ class TestFunction:
avg_pvals
/=
avg_pvals
.
sum
()
avg_diff
=
np
.
mean
(
abs
(
avg_pvals
-
pvals
))
assert
avg_diff
<
mean_rtol
def
test_unpickle_legacy_op
(
self
):
testfile_dir
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
fname
=
"test_sandbox_multinomial_wo_replacement.pkl"
if
not
PY3
:
with
open
(
os
.
path
.
join
(
testfile_dir
,
fname
),
"r"
)
as
fp
:
u
=
CompatUnpickler
(
fp
)
m
=
u
.
load
()
print
(
m
)
assert
isinstance
(
m
,
multinomial
.
ChoiceFromUniform
)
tests/unittest_tools.py
浏览文件 @
9f7a1b69
...
...
@@ -255,7 +255,7 @@ class InferShapeTester:
else
:
shp
=
inp
.
shape
if
len
(
set
(
shp
))
!=
len
(
shp
):
_logger
.
warn
(
_logger
.
warn
ing
(
"While testing shape inference for
%
r, we received an"
" input with a shape that has some repeated values:
%
r"
", like a square matrix. This makes it impossible to"
...
...
theano/compile/debugmode.py
浏览文件 @
9f7a1b69
...
...
@@ -1437,7 +1437,7 @@ def _check_preallocated_output(
fn_attr_name
=
ops_with_inner_function
[
type
(
node
.
op
)]
fn
=
getattr
(
node
.
op
,
fn_attr_name
,
None
)
if
not
fn
or
not
hasattr
(
fn
,
"maker"
)
or
not
hasattr
(
fn
.
maker
,
"mode"
):
_logger
.
warn
(
_logger
.
warn
ing
(
"Expected theano function not found in
%
s.
%
s"
,
node
.
op
,
fn_attr_name
)
else
:
...
...
@@ -1482,7 +1482,7 @@ def _check_preallocated_output(
if
not
out_map
:
# Map is empty, there is no need to execute thunk() again
_logger
.
warn
(
"
%
s: out_map is empty"
,
name
)
_logger
.
warn
ing
(
"
%
s: out_map is empty"
,
name
)
continue
# Copy the inputs over, if they were marked as destroyed or viewed
...
...
@@ -1904,7 +1904,7 @@ class _Linker(gof.link.LocalLinker):
thunks_py
.
append
(
None
)
if
not
self
.
maker
.
mode
.
check_c_code
and
thunks_py
[
-
1
]
is
None
:
_logger
.
warn
(
_logger
.
warn
ing
(
"Op
%
s doesn't have a perform, "
"forcing check of the C code"
%
node
.
op
)
...
...
@@ -1921,7 +1921,7 @@ class _Linker(gof.link.LocalLinker):
elif
thunks_c
[
-
1
]
is
None
:
thunks_c
[
-
1
]
=
thunk_other
else
:
_logger
.
warn
(
_logger
.
warn
ing
(
"We won't check the perform function "
"of node '
%
s' but we will check its "
"make_thunk function"
%
node
...
...
theano/gof/cmodule.py
浏览文件 @
9f7a1b69
...
...
@@ -2055,7 +2055,7 @@ class GCC_compiler(Compiler):
and
"clang-omp++"
not
in
theano
.
config
.
cxx
and
"icpc"
not
in
theano
.
config
.
cxx
):
_logger
.
warn
(
_logger
.
warn
ing
(
"OPTIMIZATION WARNING: your Theano flag `cxx` seems not to be"
" the g++ compiler. So we disable the compiler optimization"
" specific to g++ that tell to compile for a specific CPU."
...
...
@@ -2124,7 +2124,7 @@ class GCC_compiler(Compiler):
)
else
:
reported_lines
=
native_lines
_logger
.
warn
(
_logger
.
warn
ing
(
"OPTIMIZATION WARNING: Theano was not able to find the"
" g++ parameters that tune the compilation to your "
" specific CPU. This can slow down the execution of Theano"
...
...
@@ -2137,7 +2137,7 @@ class GCC_compiler(Compiler):
default_lines
=
get_lines
(
"
%
s -E -v -"
%
theano
.
config
.
cxx
)
_logger
.
info
(
"g++ default lines:
%
s"
,
default_lines
)
if
len
(
default_lines
)
<
1
:
_logger
.
warn
(
_logger
.
warn
ing
(
"OPTIMIZATION WARNING: Theano was not able to find the"
" default g++ parameters. This is needed to tune"
" the compilation to your specific"
...
...
theano/gof/compilelock.py
浏览文件 @
9f7a1b69
...
...
@@ -349,7 +349,7 @@ def refresh_lock(lock_file):
# This way, only 1 test would fail.
while
get_lock
.
n_lock
>
0
:
release_lock
()
_logger
.
warn
(
_logger
.
warn
ing
(
"Refreshing lock failed, we release the"
" lock before raising again the exception"
)
...
...
theano/gof/graph.py
浏览文件 @
9f7a1b69
...
...
@@ -92,7 +92,7 @@ class Apply(Node):
def
__init__
(
self
,
op
,
inputs
,
outputs
):
self
.
op
=
op
self
.
inputs
=
[]
self
.
tag
=
utils
.
s
cratchpad
()
self
.
tag
=
utils
.
S
cratchpad
()
if
not
isinstance
(
inputs
,
(
list
,
tuple
)):
raise
TypeError
(
"The inputs of an Apply must be a list or tuple"
)
...
...
@@ -383,7 +383,8 @@ class Variable(Node):
def
__init__
(
self
,
type
,
owner
=
None
,
index
=
None
,
name
=
None
):
super
(
Variable
,
self
)
.
__init__
()
self
.
tag
=
utils
.
scratchpad
()
self
.
tag
=
utils
.
ValidatingScratchpad
(
"test_value"
,
type
.
filter
)
self
.
type
=
type
if
owner
is
not
None
and
not
isinstance
(
owner
,
Apply
):
raise
TypeError
(
"owner must be an Apply instance"
,
owner
)
...
...
theano/gof/op.py
浏览文件 @
9f7a1b69
...
...
@@ -553,29 +553,10 @@ class PureOp(object):
elif
isinstance
(
v
,
SharedVariable
):
return
v
.
get_value
(
borrow
=
True
,
return_internal_type
=
True
)
elif
isinstance
(
v
,
graph
.
Variable
)
and
hasattr
(
v
.
tag
,
"test_value"
):
# ensure that the test value is correct
try
:
ret
=
v
.
type
.
filter
(
v
.
tag
.
test_value
)
except
Exception
as
e
:
# Better error message.
detailed_err_msg
=
(
"For compute_test_value, one input test value does not"
" have the requested type.
\n
"
)
detailed_err_msg
+=
utils
.
get_variable_trace_string
(
v
)
return
v
.
tag
.
test_value
detailed_err_msg
+=
(
"
\n
The error when converting the test value to that"
" variable type:"
)
# We need to only have 1 args and it should be of type
# string. Otherwise, it print the tuple and so the
# new line do not get printed.
args
=
(
detailed_err_msg
,)
+
tuple
(
str
(
arg
)
for
arg
in
e
.
args
)
e
.
args
=
(
"
\n
"
.
join
(
args
),)
raise
return
ret
detailed_err_msg
=
utils
.
get_variable_trace_string
(
v
)
raise
AttributeError
(
"
%
s has no test value
%
s"
%
(
v
,
detailed_err_msg
))
def
__call__
(
self
,
*
inputs
,
**
kwargs
):
...
...
@@ -1057,48 +1038,13 @@ def missing_test_message(msg):
assert
action
in
[
"ignore"
,
"off"
]
def
debug_error_message
(
msg
):
"""
Displays a message saying that an error was found in some
test_values. Becomes a warning or a ValueError depending on
config.compute_test_value.
"""
action
=
config
.
compute_test_value
# this message should never be called when the debugger is off
assert
action
!=
"off"
if
action
in
[
"raise"
,
"ignore"
]:
raise
ValueError
(
msg
)
else
:
assert
action
==
"warn"
warnings
.
warn
(
msg
,
stacklevel
=
2
)
def
debug_assert
(
condition
,
msg
=
None
):
"""
Customized assert with options to ignore the assert
with just a warning
"""
if
msg
is
None
:
msg
=
"debug_assert failed"
if
not
condition
:
action
=
config
.
compute_test_value
if
action
in
[
"raise"
,
"ignore"
]:
raise
AssertionError
(
msg
)
else
:
assert
action
==
"warn"
warnings
.
warn
(
msg
,
stacklevel
=
2
)
def
get_debug_values
(
*
args
):
"""
Intended use:
for val_1, ..., val_n in get_debug_values(var_1, ..., var_n):
if some condition on val_1, ..., val_n is not met:
debug_error
_message("condition was not met")
missing_test
_message("condition was not met")
Given a list of variables, get_debug_values does one of three things:
...
...
@@ -1128,10 +1074,10 @@ def get_debug_values(*args):
except
AttributeError
:
if
hasattr
(
arg
,
"name"
)
and
arg
.
name
is
not
None
:
missing_test_message
(
"Argument
"
+
str
(
i
)
+
"('"
+
arg
.
name
+
"') has no test value"
"Argument
{} ('{}') has no test value"
.
format
(
i
,
arg
.
name
)
)
else
:
missing_test_message
(
"Argument
"
+
str
(
i
)
+
" has no test value"
)
missing_test_message
(
"Argument
{} has no test value"
.
format
(
i
)
)
return
[]
if
len
(
rval
)
==
1
:
...
...
theano/gof/utils.py
浏览文件 @
9f7a1b69
...
...
@@ -239,7 +239,7 @@ class object2(with_metaclass(MetaObject, object)):
return
not
self
==
other
class
s
cratchpad
(
object
):
class
S
cratchpad
(
object
):
def
clear
(
self
):
self
.
__dict__
.
clear
()
...
...
@@ -259,6 +259,23 @@ class scratchpad(object):
print
(
"
%
s:
%
s"
%
(
k
,
v
))
class
ValidatingScratchpad
(
Scratchpad
):
"""This `Scratchpad` validates attribute values."""
def
__init__
(
self
,
attr
,
attr_filter
):
super
()
.
__init__
()
object
.
__setattr__
(
self
,
"attr"
,
attr
)
object
.
__setattr__
(
self
,
"attr_filter"
,
attr_filter
)
def
__setattr__
(
self
,
attr
,
obj
):
if
getattr
(
self
,
"attr"
,
None
)
==
attr
:
obj
=
self
.
attr_filter
(
obj
)
return
object
.
__setattr__
(
self
,
attr
,
obj
)
class
D
:
def
__init__
(
self
,
**
d
):
self
.
__dict__
.
update
(
d
)
...
...
theano/gof/vm.py
浏览文件 @
9f7a1b69
...
...
@@ -924,7 +924,7 @@ class VM_Linker(link.LocalLinker):
if
self
.
use_cloop
and
(
self
.
callback
is
not
None
or
self
.
callback_input
is
not
None
):
logger
.
warn
(
"CVM does not support callback, using Stack VM."
)
logger
.
warn
ing
(
"CVM does not support callback, using Stack VM."
)
if
self
.
use_cloop
and
config
.
profile_memory
:
warnings
.
warn
(
"CVM does not support memory profile, using Stack VM."
)
if
not
self
.
use_cloop
and
self
.
allow_partial_eval
:
...
...
theano/misc/pkl_utils.py
浏览文件 @
9f7a1b69
...
...
@@ -12,6 +12,9 @@ import sys
import
tempfile
import
zipfile
import
warnings
import
theano
from
collections
import
defaultdict
from
contextlib
import
closing
from
pickle
import
HIGHEST_PROTOCOL
...
...
@@ -22,10 +25,7 @@ try:
except
ImportError
:
DEFAULT_PROTOCOL
=
HIGHEST_PROTOCOL
import
theano
from
theano
import
config
from
theano.compat
import
PY3
from
six
import
string_types
from
theano.compile.sharedvalue
import
SharedVariable
__docformat__
=
"restructuredtext en"
...
...
@@ -68,7 +68,7 @@ class StripPickler(Pickler):
def
save
(
self
,
obj
):
# Remove the tag.trace attribute from Variable and Apply nodes
if
isinstance
(
obj
,
theano
.
gof
.
utils
.
s
cratchpad
):
if
isinstance
(
obj
,
theano
.
gof
.
utils
.
S
cratchpad
):
for
tag
in
self
.
tag_to_remove
:
if
hasattr
(
obj
,
tag
):
del
obj
.
__dict__
[
tag
]
...
...
@@ -80,93 +80,6 @@ class StripPickler(Pickler):
return
Pickler
.
save
(
self
,
obj
)
# Make an unpickler that tries encoding byte streams before raising TypeError.
# This is useful with python 3, in order to unpickle files created with
# python 2.
# This code is taken from Pandas, https://github.com/pydata/pandas,
# under the same 3-clause BSD license.
def
load_reduce
(
self
):
stack
=
self
.
stack
args
=
stack
.
pop
()
func
=
stack
[
-
1
]
try
:
value
=
func
(
*
args
)
except
Exception
:
# try to reencode the arguments
if
self
.
encoding
is
not
None
:
new_args
=
[]
for
arg
in
args
:
if
isinstance
(
arg
,
string_types
):
new_args
.
append
(
arg
.
encode
(
self
.
encoding
))
else
:
new_args
.
append
(
arg
)
args
=
tuple
(
new_args
)
try
:
stack
[
-
1
]
=
func
(
*
args
)
return
except
Exception
:
pass
# if self.is_verbose:
# print(sys.exc_info())
# print(func, args)
raise
stack
[
-
1
]
=
value
if
PY3
:
class
CompatUnpickler
(
pickle
.
_Unpickler
):
"""
Allow to reload in python 3 some pickled numpy ndarray.
.. versionadded:: 0.8
Examples
--------
::
with open(fname, 'rb') as fp:
if PY3:
u = CompatUnpickler(fp, encoding="latin1")
else:
u = CompatUnpickler(fp)
mat = u.load()
"""
pass
# Register `load_reduce` defined above in CompatUnpickler
CompatUnpickler
.
dispatch
[
pickle
.
REDUCE
[
0
]]
=
load_reduce
else
:
class
CompatUnpickler
(
pickle
.
Unpickler
):
"""
Allow to reload in python 3 some pickled numpy ndarray.
.. versionadded:: 0.8
Examples
--------
::
with open(fname, 'rb') as fp:
if PY3:
u = CompatUnpickler(fp, encoding="latin1")
else:
u = CompatUnpickler(fp)
mat = u.load()
"""
pass
class
PersistentNdarrayID
(
object
):
"""Persist ndarrays in an object by saving them to a zip file.
...
...
theano/printing.py
浏览文件 @
9f7a1b69
...
...
@@ -371,11 +371,11 @@ class Print(Op):
return
(
1
,)
class
PrinterState
(
gof
.
utils
.
s
cratchpad
):
class
PrinterState
(
gof
.
utils
.
S
cratchpad
):
def
__init__
(
self
,
props
=
None
,
**
more_props
):
if
props
is
None
:
props
=
{}
elif
isinstance
(
props
,
gof
.
utils
.
s
cratchpad
):
elif
isinstance
(
props
,
gof
.
utils
.
S
cratchpad
):
self
.
__update__
(
props
)
else
:
self
.
__dict__
.
update
(
props
)
...
...
@@ -862,7 +862,7 @@ def pydotprint(
):
cond
=
node
if
cond
is
None
:
_logger
.
warn
(
_logger
.
warn
ing
(
"pydotprint: cond_highlight is set but there is no"
" IfElse node in the graph"
)
...
...
theano/tensor/nnet/conv.py
浏览文件 @
9f7a1b69
...
...
@@ -559,7 +559,7 @@ class ConvOp(OpenMPOp):
" bsize(
%
i). We revert it to
%
i. This"
" won't change the result, but may make it slower."
)
_logger
.
warn
(
warnstr
,
self
.
unroll_batch
,
self
.
bsize
,
new
)
_logger
.
warn
ing
(
warnstr
,
self
.
unroll_batch
,
self
.
bsize
,
new
)
self
.
unroll_batch
=
new
...
...
@@ -585,7 +585,7 @@ class ConvOp(OpenMPOp):
" nkern(
%
i). We revert it to
%
i. This"
" won't change the result, but may make it slower."
)
_logger
.
warn
(
warnstr
,
self
.
unroll_kern
,
self
.
nkern
,
new
)
_logger
.
warn
ing
(
warnstr
,
self
.
unroll_kern
,
self
.
nkern
,
new
)
self
.
unroll_kern
=
new
self
.
outshp
=
get_conv_output_shape
(
...
...
theano/tensor/opt.py
浏览文件 @
9f7a1b69
...
...
@@ -3251,7 +3251,7 @@ def merge_two_slices(slice1, len1, slice2, len2):
# sl.stop backwards
n_val
=
sl1
.
stop
-
1
-
sl2
*
sl1
.
step
if
config
.
warn
.
subtensor_merge_bug
:
warnings
.
warn
(
warnings
.
warn
ing
(
(
"Your current code is fine, but Theano versions "
"prior to 0.5rc2 might have given an incorrect result. "
...
...
@@ -3843,7 +3843,7 @@ def local_adv_sub1_adv_inc_sub1(node):
if
not
inp
.
owner
.
op
.
set_instead_of_inc
:
if
config
.
warn
.
inc_subtensor1_opt
:
warnings
.
warn
(
warnings
.
warn
ing
(
"Your current code is fine, but Theano versions "
"between 0.7rc1 and 0.10 (or development versions "
"between Nov. 2014 and May 2017) "
...
...
@@ -5851,7 +5851,7 @@ def local_sum_prod_div_dimshuffle(node):
break
if
compatible_dims
:
_logger
.
warn
(
_logger
.
warn
ing
(
"WARNING: Your current code is fine, but"
" Theano versions between "
"rev. 3bd9b789f5e8 (2010-06-16) and"
...
...
@@ -5906,7 +5906,7 @@ def local_sum_prod_div_dimshuffle(node):
if
config
.
warn
.
sum_div_dimshuffle_bug
and
isinstance
(
node
.
op
,
T
.
Sum
):
_logger
.
warn
(
_logger
.
warn
ing
(
"WARNING: Your current code is fine,"
" but Theano versions between "
"rev. 3bd9b789f5e8 (2010-06-16) and"
...
...
@@ -6016,7 +6016,7 @@ def local_op_of_op(node):
and
newaxis
!=
newaxis_old
and
len
(
newaxis
)
==
len
(
newaxis_old
)
):
_logger
.
warn
(
_logger
.
warn
ing
(
"WARNING (YOUR CURRENT CODE IS FINE): Theano "
"versions between version 9923a40c7b7a and August "
"2nd, 2010 generated bugged code in this case. "
...
...
@@ -6102,7 +6102,7 @@ def local_reduce_join(node):
# I put this warning late to don't add extra warning.
if
len
(
reduce_axis
)
!=
1
or
0
not
in
reduce_axis
:
if
theano
.
config
.
warn
.
reduce_join
:
warnings
.
warn
(
warnings
.
warn
ing
(
(
"Your current code is fine, but Theano versions "
"prior to 0.7 (or this development version Sept 2014) "
...
...
@@ -7229,19 +7229,21 @@ register_stabilize(local_erf_neg_minus_one2)
register_specialize
(
local_erf_neg_minus_one2
)
# Stability optimization
# log(erfc(x)) => when x>threashold,
# -x**2-log(x)-.5*log(pi)+log(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))
# for float64: threshold=26.641747557 was choosed with:
# [(i,numpy.log(scipy.special.erfc(numpy.asarray([i],dtype='float64'))))
# for i in numpy.arange(26.641747557,26.6417475571,.00000000001)]
# for float32: threshold=10.0541949, [(i,numpy.log(scipy.special.erfc(
# numpy.asarray([i],dtype='float32')))) for i in numpy.arange(
# 10.0541948,10.0541951,.0000001)]
@register_stabilize
@register_specialize
@gof.local_optimizer
([
T
.
log
])
def
local_log_erfc
(
node
):
"""Stability optimization for `log(erfc(x))`.
log(erfc(x)) => when x>threshold,
-x**2-log(x)-.5*log(pi)+log(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))
for float64: threshold=26.641747557 was choosed with:
[(i,numpy.log(scipy.special.erfc(numpy.asarray([i],dtype='float64'))))
for i in numpy.arange(26.641747557,26.6417475571,.00000000001)]
for float32: threshold=10.0541949, [(i,numpy.log(scipy.special.erfc(
numpy.asarray([i],dtype='float32')))) for i in numpy.arange(
10.0541948,10.0541951,.0000001)]
"""
if
node
.
op
!=
T
.
log
:
return
False
if
not
node
.
inputs
[
0
]
.
owner
or
node
.
inputs
[
0
]
.
owner
.
op
!=
T
.
erfc
:
...
...
@@ -7270,21 +7272,26 @@ def local_log_erfc(node):
return
[
ret
]
# Stability optimization of the grad of log(erfc(x))
# ([y*]exp(-(x**2)))/erfc(x) # The y* is optional
# ([y*]exp(x**2))/erfc(-x) => [y*](when x>threashold,
# sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)))
# for float64: threshold=26.63 see at the end of the fct for the explanation
# for float32: threshold=9.3 see at the end of the fct for the explanation
# TODO: remove the contraint that there are only 2 inputs to exp(x**2)
# is the second.
# TODO: at the test point 10 in float32, there is instability in the original
# value. The original gives -30.0, the stab -20.1 and in float64 -18.1.
# Make it so that the test does not generate an error in that case!
@register_stabilize
@register_specialize
@gof.local_optimizer
([
T
.
true_div
])
def
local_grad_log_erfc_neg
(
node
):
"""Stability optimization for the grad of `log(erfc(x))`.
([y*]exp(-(x**2)))/erfc(x) # The y* is optional
([y*]exp(x**2))/erfc(-x) => [y*](when x>threashold,
sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)))
for float64: threshold=26.63 see at the end of the fct for the explanation
for float32: threshold=9.3 see at the end of the fct for the explanation
TODO: remove the contraint that there are only 2 inputs to exp(x**2)
is the second.
TODO: at the test point 10 in float32, there is instability in the original
value. The original gives -30.0, the stab -20.1 and in float64 -18.1.
Make it so that the test does not generate an error in that case!
"""
if
node
.
op
!=
T
.
true_div
:
return
False
if
not
node
.
inputs
[
1
]
.
owner
or
node
.
inputs
[
1
]
.
owner
.
op
!=
T
.
erfc
:
...
...
@@ -7602,43 +7609,52 @@ for i in range(1,len(p64)): print i, 64[i]-p64[i-1]
"""
# ###############
# # Loop fusion #
# ###############
def
local_elemwise_fusion_op
(
OP
,
max_input_fct
=
lambda
node
:
32
,
maker
=
None
):
"""
We parametrize it to make it work for Elemwise and GpuElemwise op.
def
local_elemwise_fusion_op
(
op_class
,
max_input_fct
=
lambda
node
:
32
,
maker
=
None
):
"""Create a recursive function that fuses `Elemwise` `Op`s.
The basic idea is that we loop through an `Elemwise` node's inputs, find
other `Elemwise` nodes, determine the scalars input types for all of the
`Elemwise` `Op`s, construct a new scalar `Op` using the scalar input types
and each `Elemwise`'s scalar `Op`, and use the composite scalar `Op` in a
new "fused" `Elemwise`.
It's parameterized in order to work for `Elemwise` and `GpuElemwise` `Op`s.
Parameters
----------
OP
GpuElemwise or Elemwise class (the one that we want to fuse)
max_input_fct
A function that returns the maximum number of inputs
that this elemwise can take (useful for GpuElemwise).
GPU kernel currently has a limit of 256 bytes for
the size of all parameters passed to it. As currently
we pass many information only by parameter, we must
limit how many ops we fuse together to avoid busting
that 256 limit.
op_class : type
`GpuElemwise` or `Elemwise` class (the one that we want to fuse)
max_input_fct : callable
A function that returns the maximum number of inputs that this `Elemwise`
can take (useful for `GpuElemwise`). The GPU kernel currently has a
limit of 256 bytes for the size of all parameters passed to it. As
currently we pass a lot of information only by parameter, we must limit how
many `Op`s we fuse together to avoid busting that 256 limit.
On the CPU we limit to 32 input variables
since that is the maximum numpy support.
On the CPU we limit to 32 input variables since that is the maximum
NumPy support.
maker: callable
A function with the signature `(node, *args)` that constructs an
`op_class` instance (e.g. `op_class(*args)`).
"""
if
maker
is
None
:
def
maker
(
node
,
scalar_op
):
return
OP
(
scalar_op
)
return
op_class
(
scalar_op
)
def
local_fuse
(
node
):
"""
As part of specialization, we fuse two consecutive elemwise Ops of the
"""Fuse `Elemwise` `Op`s in a node.
As part of specialization, we fuse two consecutive elemwise `Op`s of the
same shape.
For mixed dtype, we let the
Composite op
do the cast. It lets the C
For mixed dtype, we let the
`Composite` `Op`
do the cast. It lets the C
compiler do the cast.
The number of dimensions is validated at call time by theano itself.
The number of dimensions is validated at call time by Theano itself.
"""
# META TODO: PUT THESE THINGS IN TRAC, NOT TODO NOTES!!
...
...
@@ -7665,12 +7681,13 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None):
# worthwhile if the summation axis doesn't line up with a
# contiguous dimension)
if
type
(
node
.
op
)
is
not
OP
:
if
type
(
node
.
op
)
is
not
op_class
:
return
False
if
len
(
node
.
outputs
)
>
1
:
# We don't support
the fusion for node
with multiple outputs.
# We don't support
fusion for nodes
with multiple outputs.
return
inputs
=
[]
# inputs of the new Elemwise op.
s_inputs
=
[]
# inputs of the new scalar op used by the Composite.
# Inputs of the new scalar op that represents the current node.
...
...
@@ -7691,7 +7708,6 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None):
for
i
in
node
.
inputs
:
do_fusion
=
False
catch
=
False
# Will store inputs of the fused node that are not currently inputs
# of the node we want to create (to avoid duplicating inputs).
tmp_input
=
[]
...
...
@@ -7704,7 +7720,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None):
# we still want to fusion. So we take the set.
if
(
i
.
owner
and
isinstance
(
i
.
owner
.
op
,
OP
)
and
isinstance
(
i
.
owner
.
op
,
op_class
)
and
len
(
set
([
n
for
n
,
idx
in
i
.
clients
]))
==
1
and
# Do not merge elemwise that don't have the same
...
...
@@ -7712,7 +7728,6 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None):
# computation due to broadcast.
i
.
owner
.
outputs
[
0
]
.
broadcastable
==
node
.
outputs
[
0
]
.
broadcastable
):
do_fusion
=
True
try
:
tmp_s_input
=
[]
# we should not put duplicate input into s_inputs and inputs
...
...
@@ -7728,12 +7743,17 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None):
if
tv
.
size
>
0
:
tmp
.
tag
.
test_value
=
tv
.
flatten
()[
0
]
else
:
tmp
.
tag
.
test_value
=
tv
_logger
.
warning
(
"Cannot construct a scalar test value"
" from a test value with no size: {}"
.
format
(
ii
)
)
except
AttributeError
:
pass
tmp_s_input
.
append
(
tmp
)
tmp_input
.
append
(
ii
)
tmp_scalar
.
append
(
tmp_s_input
[
-
1
])
s_op
=
i
.
owner
.
op
.
scalar_op
(
*
tmp_s_input
,
return_list
=
True
)
# if the scalar_op don't have a c implementation,
...
...
@@ -7746,12 +7766,11 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None):
[
"z"
for
z
in
i
.
owner
.
outputs
],
{
"fail"
:
"
%(fail)
s"
},
)
except
MethodNotDefined
:
catch
=
True
except
NotImplementedError
:
catch
=
True
if
catch
:
_logger
.
info
(
do_fusion
=
True
except
(
NotImplementedError
,
MethodNotDefined
):
_logger
.
warning
(
(
"
%
s does not implement the c_code function."
" As well as being potentially slow, this"
...
...
@@ -7782,8 +7801,8 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None):
s_inputs
.
extend
(
tmp_scalar
)
s_g
.
extend
(
s_op
)
else
:
# We must support the case where the same variable appear many
# time
in the inputs
# We must support the case where the same variable appear
s
many
# time
s with
in the inputs
if
inputs
.
count
(
i
)
==
node
.
inputs
.
count
(
i
):
s
=
s_inputs
[
inputs
.
index
(
i
)]
else
:
...
...
@@ -7819,8 +7838,8 @@ your code will run correctly, but may be slower."""
[
"z"
for
x
in
s_new_out
],
{
"fail"
:
"
%(fail)
s"
},
)
except
MethodNotDefined
:
_logger
.
info
(
except
(
NotImplementedError
,
MethodNotDefined
)
:
_logger
.
warning
(
(
"
%
s does not implement the c_code function."
" As well as being potentially slow, this disables "
...
...
@@ -7828,29 +7847,19 @@ your code will run correctly, but may be slower."""
)
%
str
(
s_new_out
[
0
]
.
owner
.
op
)
)
return
False
except
NotImplementedError
:
_logger
.
info
(
(
"
%
s does not implement the c_code function. As well"
" as being potentially slow, this disables loop"
" fusion of this op."
)
%
str
(
s_new_out
[
0
]
.
owner
.
op
)
)
return
False
# create the composite op.
C
=
scalar
.
Composite
(
s_inputs
,
s_new_out
)
composite_op
=
scalar
.
Composite
(
s_inputs
,
s_new_out
)
# create the new node.
# Do not call make_node to have test_value
n
=
maker
(
node
,
C
)(
*
inputs
)
.
owner
assert
len
(
n
.
outputs
)
==
1
assert
node
.
outputs
[
0
]
.
dtype
==
n
.
outputs
[
0
]
.
dtype
new_node
=
maker
(
node
,
composite_op
)(
*
inputs
)
.
owner
assert
len
(
new_node
.
outputs
)
==
1
assert
node
.
outputs
[
0
]
.
dtype
==
new_node
.
outputs
[
0
]
.
dtype
if
len
(
n
.
inputs
)
>
max_nb_input
:
_logger
.
info
(
if
len
(
n
ew_node
.
inputs
)
>
max_nb_input
:
_logger
.
warning
(
"loop fusion failed because Op would exceed"
" kernel argument limit."
)
return
False
...
...
@@ -7858,16 +7867,15 @@ your code will run correctly, but may be slower."""
# we fuse as many that we can at the same time to make debug mode faster
# debug mode will be faster as it won't test all intermediate step.
while
True
:
ret
=
local_fuse
(
n
)
ret
=
local_fuse
(
n
ew_node
)
if
ret
is
not
False
and
ret
is
not
None
:
# print n,ret
assert
len
(
ret
)
==
len
(
n
.
outputs
)
assert
len
(
ret
)
==
len
(
new_node
.
outputs
)
assert
len
(
ret
)
==
1
n
=
ret
[
0
]
.
owner
n
ew_node
=
ret
[
0
]
.
owner
else
:
break
return
n
.
outputs
return
n
ew_node
.
outputs
return
local_fuse
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论