提交 d0458048 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Raise TestValueError instead of AttributeError when test values are missing

上级 a1938f76
......@@ -6,11 +6,11 @@ import theano.tensor as tt
from six import string_types
from theano import scalar, shared, config
from theano.gof import utils
from theano.configparser import change_flags
from theano.gof.graph import Apply, Variable
from theano.gof.type import Generic, Type
from theano.gof.op import Op
from theano.gof.utils import TestValueError, MethodNotDefined
def as_variable(x):
......@@ -175,7 +175,7 @@ class TestMakeThunk:
o = IncOnePython()(i)
# Check that the c_code function is not implemented
with pytest.raises((NotImplementedError, utils.MethodNotDefined)):
with pytest.raises((NotImplementedError, MethodNotDefined)):
o.owner.op.c_code(o.owner, "o", ["x"], "z", {"fail": ""})
storage_map = {i: [np.int32(3)], o: [None]}
......@@ -211,7 +211,7 @@ class TestMakeThunk:
o = IncOneC()(i)
# Check that the perform function is not implemented
with pytest.raises((NotImplementedError, utils.MethodNotDefined)):
with pytest.raises((NotImplementedError, MethodNotDefined)):
o.owner.op.perform(o.owner, 0, [None])
storage_map = {i: [np.int32(3)], o: [None]}
......@@ -227,7 +227,7 @@ class TestMakeThunk:
assert compute_map[o][0]
assert storage_map[o][0] == 4
else:
with pytest.raises((NotImplementedError, utils.MethodNotDefined)):
with pytest.raises((NotImplementedError, MethodNotDefined)):
thunk()
def test_no_make_node(self):
......@@ -326,6 +326,6 @@ def test_get_test_values_success():
def test_get_test_values_exc():
"""Tests that `get_test_values` raises an exception when debugger is set to raise and a value is missing."""
with pytest.raises(AttributeError):
with pytest.raises(TestValueError):
x = tt.vector()
assert op.get_test_values(x) == []
......@@ -14,7 +14,7 @@ from six.moves import StringIO
from theano import config
from theano.gof import graph, utils, toolbox
from theano.gof.utils import get_variable_trace_string
from theano.gof.utils import get_variable_trace_string, TestValueError
from theano.misc.ordered_set import OrderedSet
NullType = None
......@@ -511,7 +511,7 @@ class FunctionGraph(utils.object2):
try:
tval = theano.gof.op.get_test_value(r)
new_tval = theano.gof.op.get_test_value(new_r)
except AttributeError:
except TestValueError:
pass
else:
tval_shape = getattr(tval, "shape", None)
......
......@@ -14,6 +14,7 @@ from six import string_types, integer_types
from theano import config
from theano.gof import utils
from theano.gof.utils import TestValueError
from theano.misc.ordered_set import OrderedSet
__docformat__ = "restructuredtext en"
......@@ -405,12 +406,12 @@ class Variable(Node):
Raises
------
AttributeError
TestValueError
"""
if not hasattr(self.tag, "test_value"):
detailed_err_msg = utils.get_variable_trace_string(self)
raise AttributeError(
raise TestValueError(
"{} has no test value {}".format(self, detailed_err_msg)
)
......@@ -436,7 +437,7 @@ class Variable(Node):
overridden by classes with non printable test_value to provide a
suitable representation of the test_value.
"""
return repr(theano.gof.op.get_test_value(self))
return repr(self.get_test_value())
def __repr__(self, firstPass=True):
"""Return a repr of the Variable.
......@@ -449,7 +450,7 @@ class Variable(Node):
if config.print_test_value and firstPass:
try:
to_print.append(self.__repr_test_value__())
except AttributeError:
except TestValueError:
pass
return "\n".join(to_print)
......
......@@ -22,6 +22,7 @@ from six import PY3
from theano import config
from theano.gof import graph
from theano.gof import utils
from theano.gof.utils import TestValueError
from theano.gof.cmodule import GCC_compiler
from theano.gof.fg import FunctionGraph
......@@ -68,7 +69,7 @@ def compute_test_value(node):
try:
storage_map[ins] = [ins.get_test_value()]
compute_map[ins] = [True]
except AttributeError:
except TestValueError:
# no test-value was specified, act accordingly
if config.compute_test_value == "warn":
warnings.warn(
......@@ -1073,7 +1074,7 @@ def missing_test_message(msg):
"""
action = config.compute_test_value
if action == "raise":
raise AttributeError(msg)
raise TestValueError(msg)
elif action == "warn":
warnings.warn(msg, stacklevel=2)
else:
......@@ -1113,7 +1114,7 @@ def get_test_values(*args):
for i, arg in enumerate(args):
try:
rval.append(get_test_value(arg))
except AttributeError:
except TestValueError:
if hasattr(arg, "name") and arg.name is not None:
missing_test_message(
"Argument {} ('{}') has no test value".format(i, arg.name)
......
......@@ -156,6 +156,12 @@ def hashtype(self):
undef = object()
class TestValueError(Exception):
"""Base exception class for all test value errors."""
pass
class MethodNotDefined(Exception):
"""
To be raised by functions defined as part of an interface.
......
......@@ -54,6 +54,7 @@ from theano import compile, gof, tensor, config
from theano.compile import SharedVariable, function, ops
from theano.tensor import opt
from theano.updates import OrderedUpdates
from theano.gof.utils import TestValueError
from theano.scan_module import scan_op, scan_utils
from theano.scan_module.scan_utils import safe_new, traverse
......@@ -524,7 +525,7 @@ def scan(
if config.compute_test_value != "off":
try:
nw_slice.tag.test_value = gof.get_test_value(_seq_val_slice)
except AttributeError as e:
except TestValueError as e:
if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
......@@ -656,7 +657,7 @@ def scan(
if config.compute_test_value != "off":
try:
arg.tag.test_value = gof.get_test_value(actual_arg)
except AttributeError as e:
except TestValueError as e:
if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
......@@ -719,7 +720,7 @@ def scan(
nw_slice.tag.test_value = gof.get_test_value(
_init_out_var_slice
)
except AttributeError as e:
except TestValueError as e:
if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
......
......@@ -33,6 +33,7 @@ from six import string_types
from theano import gof, compat, tensor, scalar
from theano.compile.pfunc import rebuild_collect_shared
from theano.tensor.basic import get_scalar_constant_value
from theano.gof.utils import TestValueError
# Logging function for sending warning or info
......@@ -74,8 +75,7 @@ def safe_new(x, tag="", dtype=None):
# Copy test value, cast it if necessary
try:
x_test_value = gof.op.get_test_value(x)
except AttributeError:
# There is no test value
except TestValueError:
pass
else:
# This clause is executed if no exception was raised
......@@ -101,8 +101,7 @@ def safe_new(x, tag="", dtype=None):
if theano.config.compute_test_value != "off":
try:
nw_x.tag.test_value = copy.deepcopy(gof.op.get_test_value(x))
except AttributeError:
# This means `x` has no test value.
except TestValueError:
pass
return nw_x
......
......@@ -156,6 +156,7 @@ from theano.gof import (
Apply,
ReplacementDidntRemovedError,
)
from theano.gof.utils import TestValueError
from theano.gof.params_type import ParamsType
from theano.gof.opt import inherit_stack_trace
from theano.printing import pprint, FunctionPrinter, debugprint
......@@ -2497,7 +2498,7 @@ class BatchedDot(Op):
if debugger_available:
try:
iv0 = theano.gof.op.get_test_value(inputs[0])
except AttributeError:
except TestValueError:
theano.gof.op.missing_test_message(
"first input passed to BatchedDot.R_op has no test value"
)
......@@ -2505,7 +2506,7 @@ class BatchedDot(Op):
try:
iv1 = theano.gof.op.get_test_value(inputs[1])
except AttributeError:
except TestValueError:
theano.gof.op.missing_test_message(
"second input passed to BatchedDot.R_op has no test value"
)
......@@ -2514,7 +2515,7 @@ class BatchedDot(Op):
if eval_points[0]:
try:
ev0 = theano.gof.op.get_test_value(eval_points[0])
except AttributeError:
except TestValueError:
theano.gof.op.missing_test_message(
"first eval point passed to BatchedDot.R_op "
"has no test value"
......@@ -2523,7 +2524,7 @@ class BatchedDot(Op):
if eval_points[1]:
try:
ev1 = theano.gof.op.get_test_value(eval_points[1])
except AttributeError:
except TestValueError:
theano.gof.op.missing_test_message(
"second eval point passed to BatchedDot.R_op "
"has no test value"
......
......@@ -42,7 +42,7 @@ from theano.gof.opt import (
pre_constant_merge,
pre_greedy_local_optimizer,
)
from theano.gof.utils import MethodNotDefined
from theano.gof.utils import MethodNotDefined, TestValueError
from theano.gradient import DisconnectedType
from theano.tensor.elemwise import Elemwise, DimShuffle
from theano.tensor.subtensor import (
......@@ -7747,7 +7747,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
"Cannot construct a scalar test value"
" from a test value with no size: {}".format(ii)
)
except AttributeError:
except TestValueError:
pass
tmp_s_input.append(tmp)
......@@ -7812,7 +7812,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
v = gof.op.get_test_value(i)
if v.size > 0:
s.tag.test_value = v.flatten()[0]
except AttributeError:
except TestValueError:
pass
inputs.append(i)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论