Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d0458048
提交
d0458048
authored
10月 18, 2020
作者:
Brandon T. Willard
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Raise TestValueError instead of AttributeError when test values are missing
上级
a1938f76
显示空白字符变更
内嵌
并排
正在显示
9 个修改的文件
包含
37 行增加
和
28 行删除
+37
-28
test_op.py
tests/gof/test_op.py
+5
-5
fg.py
theano/gof/fg.py
+2
-2
graph.py
theano/gof/graph.py
+5
-4
op.py
theano/gof/op.py
+4
-3
utils.py
theano/gof/utils.py
+6
-0
scan.py
theano/scan_module/scan.py
+4
-3
scan_utils.py
theano/scan_module/scan_utils.py
+3
-4
blas.py
theano/tensor/blas.py
+5
-4
opt.py
theano/tensor/opt.py
+3
-3
没有找到文件。
tests/gof/test_op.py
浏览文件 @
d0458048
...
@@ -6,11 +6,11 @@ import theano.tensor as tt
...
@@ -6,11 +6,11 @@ import theano.tensor as tt
from
six
import
string_types
from
six
import
string_types
from
theano
import
scalar
,
shared
,
config
from
theano
import
scalar
,
shared
,
config
from
theano.gof
import
utils
from
theano.configparser
import
change_flags
from
theano.configparser
import
change_flags
from
theano.gof.graph
import
Apply
,
Variable
from
theano.gof.graph
import
Apply
,
Variable
from
theano.gof.type
import
Generic
,
Type
from
theano.gof.type
import
Generic
,
Type
from
theano.gof.op
import
Op
from
theano.gof.op
import
Op
from
theano.gof.utils
import
TestValueError
,
MethodNotDefined
def
as_variable
(
x
):
def
as_variable
(
x
):
...
@@ -175,7 +175,7 @@ class TestMakeThunk:
...
@@ -175,7 +175,7 @@ class TestMakeThunk:
o
=
IncOnePython
()(
i
)
o
=
IncOnePython
()(
i
)
# Check that the c_code function is not implemented
# Check that the c_code function is not implemented
with
pytest
.
raises
((
NotImplementedError
,
utils
.
MethodNotDefined
)):
with
pytest
.
raises
((
NotImplementedError
,
MethodNotDefined
)):
o
.
owner
.
op
.
c_code
(
o
.
owner
,
"o"
,
[
"x"
],
"z"
,
{
"fail"
:
""
})
o
.
owner
.
op
.
c_code
(
o
.
owner
,
"o"
,
[
"x"
],
"z"
,
{
"fail"
:
""
})
storage_map
=
{
i
:
[
np
.
int32
(
3
)],
o
:
[
None
]}
storage_map
=
{
i
:
[
np
.
int32
(
3
)],
o
:
[
None
]}
...
@@ -211,7 +211,7 @@ class TestMakeThunk:
...
@@ -211,7 +211,7 @@ class TestMakeThunk:
o
=
IncOneC
()(
i
)
o
=
IncOneC
()(
i
)
# Check that the perform function is not implemented
# Check that the perform function is not implemented
with
pytest
.
raises
((
NotImplementedError
,
utils
.
MethodNotDefined
)):
with
pytest
.
raises
((
NotImplementedError
,
MethodNotDefined
)):
o
.
owner
.
op
.
perform
(
o
.
owner
,
0
,
[
None
])
o
.
owner
.
op
.
perform
(
o
.
owner
,
0
,
[
None
])
storage_map
=
{
i
:
[
np
.
int32
(
3
)],
o
:
[
None
]}
storage_map
=
{
i
:
[
np
.
int32
(
3
)],
o
:
[
None
]}
...
@@ -227,7 +227,7 @@ class TestMakeThunk:
...
@@ -227,7 +227,7 @@ class TestMakeThunk:
assert
compute_map
[
o
][
0
]
assert
compute_map
[
o
][
0
]
assert
storage_map
[
o
][
0
]
==
4
assert
storage_map
[
o
][
0
]
==
4
else
:
else
:
with
pytest
.
raises
((
NotImplementedError
,
utils
.
MethodNotDefined
)):
with
pytest
.
raises
((
NotImplementedError
,
MethodNotDefined
)):
thunk
()
thunk
()
def
test_no_make_node
(
self
):
def
test_no_make_node
(
self
):
...
@@ -326,6 +326,6 @@ def test_get_test_values_success():
...
@@ -326,6 +326,6 @@ def test_get_test_values_success():
def
test_get_test_values_exc
():
def
test_get_test_values_exc
():
"""Tests that `get_test_values` raises an exception when debugger is set to raise and a value is missing."""
"""Tests that `get_test_values` raises an exception when debugger is set to raise and a value is missing."""
with
pytest
.
raises
(
Attribut
eError
):
with
pytest
.
raises
(
TestValu
eError
):
x
=
tt
.
vector
()
x
=
tt
.
vector
()
assert
op
.
get_test_values
(
x
)
==
[]
assert
op
.
get_test_values
(
x
)
==
[]
theano/gof/fg.py
浏览文件 @
d0458048
...
@@ -14,7 +14,7 @@ from six.moves import StringIO
...
@@ -14,7 +14,7 @@ from six.moves import StringIO
from
theano
import
config
from
theano
import
config
from
theano.gof
import
graph
,
utils
,
toolbox
from
theano.gof
import
graph
,
utils
,
toolbox
from
theano.gof.utils
import
get_variable_trace_string
from
theano.gof.utils
import
get_variable_trace_string
,
TestValueError
from
theano.misc.ordered_set
import
OrderedSet
from
theano.misc.ordered_set
import
OrderedSet
NullType
=
None
NullType
=
None
...
@@ -511,7 +511,7 @@ class FunctionGraph(utils.object2):
...
@@ -511,7 +511,7 @@ class FunctionGraph(utils.object2):
try
:
try
:
tval
=
theano
.
gof
.
op
.
get_test_value
(
r
)
tval
=
theano
.
gof
.
op
.
get_test_value
(
r
)
new_tval
=
theano
.
gof
.
op
.
get_test_value
(
new_r
)
new_tval
=
theano
.
gof
.
op
.
get_test_value
(
new_r
)
except
Attribut
eError
:
except
TestValu
eError
:
pass
pass
else
:
else
:
tval_shape
=
getattr
(
tval
,
"shape"
,
None
)
tval_shape
=
getattr
(
tval
,
"shape"
,
None
)
...
...
theano/gof/graph.py
浏览文件 @
d0458048
...
@@ -14,6 +14,7 @@ from six import string_types, integer_types
...
@@ -14,6 +14,7 @@ from six import string_types, integer_types
from
theano
import
config
from
theano
import
config
from
theano.gof
import
utils
from
theano.gof
import
utils
from
theano.gof.utils
import
TestValueError
from
theano.misc.ordered_set
import
OrderedSet
from
theano.misc.ordered_set
import
OrderedSet
__docformat__
=
"restructuredtext en"
__docformat__
=
"restructuredtext en"
...
@@ -405,12 +406,12 @@ class Variable(Node):
...
@@ -405,12 +406,12 @@ class Variable(Node):
Raises
Raises
------
------
Attribut
eError
TestValu
eError
"""
"""
if
not
hasattr
(
self
.
tag
,
"test_value"
):
if
not
hasattr
(
self
.
tag
,
"test_value"
):
detailed_err_msg
=
utils
.
get_variable_trace_string
(
self
)
detailed_err_msg
=
utils
.
get_variable_trace_string
(
self
)
raise
Attribut
eError
(
raise
TestValu
eError
(
"{} has no test value {}"
.
format
(
self
,
detailed_err_msg
)
"{} has no test value {}"
.
format
(
self
,
detailed_err_msg
)
)
)
...
@@ -436,7 +437,7 @@ class Variable(Node):
...
@@ -436,7 +437,7 @@ class Variable(Node):
overridden by classes with non printable test_value to provide a
overridden by classes with non printable test_value to provide a
suitable representation of the test_value.
suitable representation of the test_value.
"""
"""
return
repr
(
theano
.
gof
.
op
.
get_test_value
(
self
))
return
repr
(
self
.
get_test_value
(
))
def
__repr__
(
self
,
firstPass
=
True
):
def
__repr__
(
self
,
firstPass
=
True
):
"""Return a repr of the Variable.
"""Return a repr of the Variable.
...
@@ -449,7 +450,7 @@ class Variable(Node):
...
@@ -449,7 +450,7 @@ class Variable(Node):
if
config
.
print_test_value
and
firstPass
:
if
config
.
print_test_value
and
firstPass
:
try
:
try
:
to_print
.
append
(
self
.
__repr_test_value__
())
to_print
.
append
(
self
.
__repr_test_value__
())
except
Attribut
eError
:
except
TestValu
eError
:
pass
pass
return
"
\n
"
.
join
(
to_print
)
return
"
\n
"
.
join
(
to_print
)
...
...
theano/gof/op.py
浏览文件 @
d0458048
...
@@ -22,6 +22,7 @@ from six import PY3
...
@@ -22,6 +22,7 @@ from six import PY3
from
theano
import
config
from
theano
import
config
from
theano.gof
import
graph
from
theano.gof
import
graph
from
theano.gof
import
utils
from
theano.gof
import
utils
from
theano.gof.utils
import
TestValueError
from
theano.gof.cmodule
import
GCC_compiler
from
theano.gof.cmodule
import
GCC_compiler
from
theano.gof.fg
import
FunctionGraph
from
theano.gof.fg
import
FunctionGraph
...
@@ -68,7 +69,7 @@ def compute_test_value(node):
...
@@ -68,7 +69,7 @@ def compute_test_value(node):
try
:
try
:
storage_map
[
ins
]
=
[
ins
.
get_test_value
()]
storage_map
[
ins
]
=
[
ins
.
get_test_value
()]
compute_map
[
ins
]
=
[
True
]
compute_map
[
ins
]
=
[
True
]
except
Attribut
eError
:
except
TestValu
eError
:
# no test-value was specified, act accordingly
# no test-value was specified, act accordingly
if
config
.
compute_test_value
==
"warn"
:
if
config
.
compute_test_value
==
"warn"
:
warnings
.
warn
(
warnings
.
warn
(
...
@@ -1073,7 +1074,7 @@ def missing_test_message(msg):
...
@@ -1073,7 +1074,7 @@ def missing_test_message(msg):
"""
"""
action
=
config
.
compute_test_value
action
=
config
.
compute_test_value
if
action
==
"raise"
:
if
action
==
"raise"
:
raise
Attribut
eError
(
msg
)
raise
TestValu
eError
(
msg
)
elif
action
==
"warn"
:
elif
action
==
"warn"
:
warnings
.
warn
(
msg
,
stacklevel
=
2
)
warnings
.
warn
(
msg
,
stacklevel
=
2
)
else
:
else
:
...
@@ -1113,7 +1114,7 @@ def get_test_values(*args):
...
@@ -1113,7 +1114,7 @@ def get_test_values(*args):
for
i
,
arg
in
enumerate
(
args
):
for
i
,
arg
in
enumerate
(
args
):
try
:
try
:
rval
.
append
(
get_test_value
(
arg
))
rval
.
append
(
get_test_value
(
arg
))
except
Attribut
eError
:
except
TestValu
eError
:
if
hasattr
(
arg
,
"name"
)
and
arg
.
name
is
not
None
:
if
hasattr
(
arg
,
"name"
)
and
arg
.
name
is
not
None
:
missing_test_message
(
missing_test_message
(
"Argument {} ('{}') has no test value"
.
format
(
i
,
arg
.
name
)
"Argument {} ('{}') has no test value"
.
format
(
i
,
arg
.
name
)
...
...
theano/gof/utils.py
浏览文件 @
d0458048
...
@@ -156,6 +156,12 @@ def hashtype(self):
...
@@ -156,6 +156,12 @@ def hashtype(self):
undef
=
object
()
undef
=
object
()
class
TestValueError
(
Exception
):
"""Base exception class for all test value errors."""
pass
class
MethodNotDefined
(
Exception
):
class
MethodNotDefined
(
Exception
):
"""
"""
To be raised by functions defined as part of an interface.
To be raised by functions defined as part of an interface.
...
...
theano/scan_module/scan.py
浏览文件 @
d0458048
...
@@ -54,6 +54,7 @@ from theano import compile, gof, tensor, config
...
@@ -54,6 +54,7 @@ from theano import compile, gof, tensor, config
from
theano.compile
import
SharedVariable
,
function
,
ops
from
theano.compile
import
SharedVariable
,
function
,
ops
from
theano.tensor
import
opt
from
theano.tensor
import
opt
from
theano.updates
import
OrderedUpdates
from
theano.updates
import
OrderedUpdates
from
theano.gof.utils
import
TestValueError
from
theano.scan_module
import
scan_op
,
scan_utils
from
theano.scan_module
import
scan_op
,
scan_utils
from
theano.scan_module.scan_utils
import
safe_new
,
traverse
from
theano.scan_module.scan_utils
import
safe_new
,
traverse
...
@@ -524,7 +525,7 @@ def scan(
...
@@ -524,7 +525,7 @@ def scan(
if
config
.
compute_test_value
!=
"off"
:
if
config
.
compute_test_value
!=
"off"
:
try
:
try
:
nw_slice
.
tag
.
test_value
=
gof
.
get_test_value
(
_seq_val_slice
)
nw_slice
.
tag
.
test_value
=
gof
.
get_test_value
(
_seq_val_slice
)
except
Attribut
eError
as
e
:
except
TestValu
eError
as
e
:
if
config
.
compute_test_value
!=
"ignore"
:
if
config
.
compute_test_value
!=
"ignore"
:
# No need to print a warning or raise an error now,
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
# it will be done when fn will be called.
...
@@ -656,7 +657,7 @@ def scan(
...
@@ -656,7 +657,7 @@ def scan(
if
config
.
compute_test_value
!=
"off"
:
if
config
.
compute_test_value
!=
"off"
:
try
:
try
:
arg
.
tag
.
test_value
=
gof
.
get_test_value
(
actual_arg
)
arg
.
tag
.
test_value
=
gof
.
get_test_value
(
actual_arg
)
except
Attribut
eError
as
e
:
except
TestValu
eError
as
e
:
if
config
.
compute_test_value
!=
"ignore"
:
if
config
.
compute_test_value
!=
"ignore"
:
# No need to print a warning or raise an error now,
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
# it will be done when fn will be called.
...
@@ -719,7 +720,7 @@ def scan(
...
@@ -719,7 +720,7 @@ def scan(
nw_slice
.
tag
.
test_value
=
gof
.
get_test_value
(
nw_slice
.
tag
.
test_value
=
gof
.
get_test_value
(
_init_out_var_slice
_init_out_var_slice
)
)
except
Attribut
eError
as
e
:
except
TestValu
eError
as
e
:
if
config
.
compute_test_value
!=
"ignore"
:
if
config
.
compute_test_value
!=
"ignore"
:
# No need to print a warning or raise an error now,
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
# it will be done when fn will be called.
...
...
theano/scan_module/scan_utils.py
浏览文件 @
d0458048
...
@@ -33,6 +33,7 @@ from six import string_types
...
@@ -33,6 +33,7 @@ from six import string_types
from
theano
import
gof
,
compat
,
tensor
,
scalar
from
theano
import
gof
,
compat
,
tensor
,
scalar
from
theano.compile.pfunc
import
rebuild_collect_shared
from
theano.compile.pfunc
import
rebuild_collect_shared
from
theano.tensor.basic
import
get_scalar_constant_value
from
theano.tensor.basic
import
get_scalar_constant_value
from
theano.gof.utils
import
TestValueError
# Logging function for sending warning or info
# Logging function for sending warning or info
...
@@ -74,8 +75,7 @@ def safe_new(x, tag="", dtype=None):
...
@@ -74,8 +75,7 @@ def safe_new(x, tag="", dtype=None):
# Copy test value, cast it if necessary
# Copy test value, cast it if necessary
try
:
try
:
x_test_value
=
gof
.
op
.
get_test_value
(
x
)
x_test_value
=
gof
.
op
.
get_test_value
(
x
)
except
AttributeError
:
except
TestValueError
:
# There is no test value
pass
pass
else
:
else
:
# This clause is executed if no exception was raised
# This clause is executed if no exception was raised
...
@@ -101,8 +101,7 @@ def safe_new(x, tag="", dtype=None):
...
@@ -101,8 +101,7 @@ def safe_new(x, tag="", dtype=None):
if
theano
.
config
.
compute_test_value
!=
"off"
:
if
theano
.
config
.
compute_test_value
!=
"off"
:
try
:
try
:
nw_x
.
tag
.
test_value
=
copy
.
deepcopy
(
gof
.
op
.
get_test_value
(
x
))
nw_x
.
tag
.
test_value
=
copy
.
deepcopy
(
gof
.
op
.
get_test_value
(
x
))
except
AttributeError
:
except
TestValueError
:
# This means `x` has no test value.
pass
pass
return
nw_x
return
nw_x
...
...
theano/tensor/blas.py
浏览文件 @
d0458048
...
@@ -156,6 +156,7 @@ from theano.gof import (
...
@@ -156,6 +156,7 @@ from theano.gof import (
Apply
,
Apply
,
ReplacementDidntRemovedError
,
ReplacementDidntRemovedError
,
)
)
from
theano.gof.utils
import
TestValueError
from
theano.gof.params_type
import
ParamsType
from
theano.gof.params_type
import
ParamsType
from
theano.gof.opt
import
inherit_stack_trace
from
theano.gof.opt
import
inherit_stack_trace
from
theano.printing
import
pprint
,
FunctionPrinter
,
debugprint
from
theano.printing
import
pprint
,
FunctionPrinter
,
debugprint
...
@@ -2497,7 +2498,7 @@ class BatchedDot(Op):
...
@@ -2497,7 +2498,7 @@ class BatchedDot(Op):
if
debugger_available
:
if
debugger_available
:
try
:
try
:
iv0
=
theano
.
gof
.
op
.
get_test_value
(
inputs
[
0
])
iv0
=
theano
.
gof
.
op
.
get_test_value
(
inputs
[
0
])
except
Attribut
eError
:
except
TestValu
eError
:
theano
.
gof
.
op
.
missing_test_message
(
theano
.
gof
.
op
.
missing_test_message
(
"first input passed to BatchedDot.R_op has no test value"
"first input passed to BatchedDot.R_op has no test value"
)
)
...
@@ -2505,7 +2506,7 @@ class BatchedDot(Op):
...
@@ -2505,7 +2506,7 @@ class BatchedDot(Op):
try
:
try
:
iv1
=
theano
.
gof
.
op
.
get_test_value
(
inputs
[
1
])
iv1
=
theano
.
gof
.
op
.
get_test_value
(
inputs
[
1
])
except
Attribut
eError
:
except
TestValu
eError
:
theano
.
gof
.
op
.
missing_test_message
(
theano
.
gof
.
op
.
missing_test_message
(
"second input passed to BatchedDot.R_op has no test value"
"second input passed to BatchedDot.R_op has no test value"
)
)
...
@@ -2514,7 +2515,7 @@ class BatchedDot(Op):
...
@@ -2514,7 +2515,7 @@ class BatchedDot(Op):
if
eval_points
[
0
]:
if
eval_points
[
0
]:
try
:
try
:
ev0
=
theano
.
gof
.
op
.
get_test_value
(
eval_points
[
0
])
ev0
=
theano
.
gof
.
op
.
get_test_value
(
eval_points
[
0
])
except
Attribut
eError
:
except
TestValu
eError
:
theano
.
gof
.
op
.
missing_test_message
(
theano
.
gof
.
op
.
missing_test_message
(
"first eval point passed to BatchedDot.R_op "
"first eval point passed to BatchedDot.R_op "
"has no test value"
"has no test value"
...
@@ -2523,7 +2524,7 @@ class BatchedDot(Op):
...
@@ -2523,7 +2524,7 @@ class BatchedDot(Op):
if
eval_points
[
1
]:
if
eval_points
[
1
]:
try
:
try
:
ev1
=
theano
.
gof
.
op
.
get_test_value
(
eval_points
[
1
])
ev1
=
theano
.
gof
.
op
.
get_test_value
(
eval_points
[
1
])
except
Attribut
eError
:
except
TestValu
eError
:
theano
.
gof
.
op
.
missing_test_message
(
theano
.
gof
.
op
.
missing_test_message
(
"second eval point passed to BatchedDot.R_op "
"second eval point passed to BatchedDot.R_op "
"has no test value"
"has no test value"
...
...
theano/tensor/opt.py
浏览文件 @
d0458048
...
@@ -42,7 +42,7 @@ from theano.gof.opt import (
...
@@ -42,7 +42,7 @@ from theano.gof.opt import (
pre_constant_merge
,
pre_constant_merge
,
pre_greedy_local_optimizer
,
pre_greedy_local_optimizer
,
)
)
from
theano.gof.utils
import
MethodNotDefined
from
theano.gof.utils
import
MethodNotDefined
,
TestValueError
from
theano.gradient
import
DisconnectedType
from
theano.gradient
import
DisconnectedType
from
theano.tensor.elemwise
import
Elemwise
,
DimShuffle
from
theano.tensor.elemwise
import
Elemwise
,
DimShuffle
from
theano.tensor.subtensor
import
(
from
theano.tensor.subtensor
import
(
...
@@ -7747,7 +7747,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
...
@@ -7747,7 +7747,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
"Cannot construct a scalar test value"
"Cannot construct a scalar test value"
" from a test value with no size: {}"
.
format
(
ii
)
" from a test value with no size: {}"
.
format
(
ii
)
)
)
except
Attribut
eError
:
except
TestValu
eError
:
pass
pass
tmp_s_input
.
append
(
tmp
)
tmp_s_input
.
append
(
tmp
)
...
@@ -7812,7 +7812,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
...
@@ -7812,7 +7812,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
v
=
gof
.
op
.
get_test_value
(
i
)
v
=
gof
.
op
.
get_test_value
(
i
)
if
v
.
size
>
0
:
if
v
.
size
>
0
:
s
.
tag
.
test_value
=
v
.
flatten
()[
0
]
s
.
tag
.
test_value
=
v
.
flatten
()[
0
]
except
Attribut
eError
:
except
TestValu
eError
:
pass
pass
inputs
.
append
(
i
)
inputs
.
append
(
i
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论