提交 d0458048 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Raise TestValueError instead of AttributeError when test values are missing

上级 a1938f76
...@@ -6,11 +6,11 @@ import theano.tensor as tt ...@@ -6,11 +6,11 @@ import theano.tensor as tt
from six import string_types from six import string_types
from theano import scalar, shared, config from theano import scalar, shared, config
from theano.gof import utils
from theano.configparser import change_flags from theano.configparser import change_flags
from theano.gof.graph import Apply, Variable from theano.gof.graph import Apply, Variable
from theano.gof.type import Generic, Type from theano.gof.type import Generic, Type
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.utils import TestValueError, MethodNotDefined
def as_variable(x): def as_variable(x):
...@@ -175,7 +175,7 @@ class TestMakeThunk: ...@@ -175,7 +175,7 @@ class TestMakeThunk:
o = IncOnePython()(i) o = IncOnePython()(i)
# Check that the c_code function is not implemented # Check that the c_code function is not implemented
with pytest.raises((NotImplementedError, utils.MethodNotDefined)): with pytest.raises((NotImplementedError, MethodNotDefined)):
o.owner.op.c_code(o.owner, "o", ["x"], "z", {"fail": ""}) o.owner.op.c_code(o.owner, "o", ["x"], "z", {"fail": ""})
storage_map = {i: [np.int32(3)], o: [None]} storage_map = {i: [np.int32(3)], o: [None]}
...@@ -211,7 +211,7 @@ class TestMakeThunk: ...@@ -211,7 +211,7 @@ class TestMakeThunk:
o = IncOneC()(i) o = IncOneC()(i)
# Check that the perform function is not implemented # Check that the perform function is not implemented
with pytest.raises((NotImplementedError, utils.MethodNotDefined)): with pytest.raises((NotImplementedError, MethodNotDefined)):
o.owner.op.perform(o.owner, 0, [None]) o.owner.op.perform(o.owner, 0, [None])
storage_map = {i: [np.int32(3)], o: [None]} storage_map = {i: [np.int32(3)], o: [None]}
...@@ -227,7 +227,7 @@ class TestMakeThunk: ...@@ -227,7 +227,7 @@ class TestMakeThunk:
assert compute_map[o][0] assert compute_map[o][0]
assert storage_map[o][0] == 4 assert storage_map[o][0] == 4
else: else:
with pytest.raises((NotImplementedError, utils.MethodNotDefined)): with pytest.raises((NotImplementedError, MethodNotDefined)):
thunk() thunk()
def test_no_make_node(self): def test_no_make_node(self):
...@@ -326,6 +326,6 @@ def test_get_test_values_success(): ...@@ -326,6 +326,6 @@ def test_get_test_values_success():
def test_get_test_values_exc(): def test_get_test_values_exc():
"""Tests that `get_test_values` raises an exception when debugger is set to raise and a value is missing.""" """Tests that `get_test_values` raises an exception when debugger is set to raise and a value is missing."""
with pytest.raises(AttributeError): with pytest.raises(TestValueError):
x = tt.vector() x = tt.vector()
assert op.get_test_values(x) == [] assert op.get_test_values(x) == []
...@@ -14,7 +14,7 @@ from six.moves import StringIO ...@@ -14,7 +14,7 @@ from six.moves import StringIO
from theano import config from theano import config
from theano.gof import graph, utils, toolbox from theano.gof import graph, utils, toolbox
from theano.gof.utils import get_variable_trace_string from theano.gof.utils import get_variable_trace_string, TestValueError
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
NullType = None NullType = None
...@@ -511,7 +511,7 @@ class FunctionGraph(utils.object2): ...@@ -511,7 +511,7 @@ class FunctionGraph(utils.object2):
try: try:
tval = theano.gof.op.get_test_value(r) tval = theano.gof.op.get_test_value(r)
new_tval = theano.gof.op.get_test_value(new_r) new_tval = theano.gof.op.get_test_value(new_r)
except AttributeError: except TestValueError:
pass pass
else: else:
tval_shape = getattr(tval, "shape", None) tval_shape = getattr(tval, "shape", None)
......
...@@ -14,6 +14,7 @@ from six import string_types, integer_types ...@@ -14,6 +14,7 @@ from six import string_types, integer_types
from theano import config from theano import config
from theano.gof import utils from theano.gof import utils
from theano.gof.utils import TestValueError
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
...@@ -405,12 +406,12 @@ class Variable(Node): ...@@ -405,12 +406,12 @@ class Variable(Node):
Raises Raises
------ ------
AttributeError TestValueError
""" """
if not hasattr(self.tag, "test_value"): if not hasattr(self.tag, "test_value"):
detailed_err_msg = utils.get_variable_trace_string(self) detailed_err_msg = utils.get_variable_trace_string(self)
raise AttributeError( raise TestValueError(
"{} has no test value {}".format(self, detailed_err_msg) "{} has no test value {}".format(self, detailed_err_msg)
) )
...@@ -436,7 +437,7 @@ class Variable(Node): ...@@ -436,7 +437,7 @@ class Variable(Node):
overridden by classes with non printable test_value to provide a overridden by classes with non printable test_value to provide a
suitable representation of the test_value. suitable representation of the test_value.
""" """
return repr(theano.gof.op.get_test_value(self)) return repr(self.get_test_value())
def __repr__(self, firstPass=True): def __repr__(self, firstPass=True):
"""Return a repr of the Variable. """Return a repr of the Variable.
...@@ -449,7 +450,7 @@ class Variable(Node): ...@@ -449,7 +450,7 @@ class Variable(Node):
if config.print_test_value and firstPass: if config.print_test_value and firstPass:
try: try:
to_print.append(self.__repr_test_value__()) to_print.append(self.__repr_test_value__())
except AttributeError: except TestValueError:
pass pass
return "\n".join(to_print) return "\n".join(to_print)
......
...@@ -22,6 +22,7 @@ from six import PY3 ...@@ -22,6 +22,7 @@ from six import PY3
from theano import config from theano import config
from theano.gof import graph from theano.gof import graph
from theano.gof import utils from theano.gof import utils
from theano.gof.utils import TestValueError
from theano.gof.cmodule import GCC_compiler from theano.gof.cmodule import GCC_compiler
from theano.gof.fg import FunctionGraph from theano.gof.fg import FunctionGraph
...@@ -68,7 +69,7 @@ def compute_test_value(node): ...@@ -68,7 +69,7 @@ def compute_test_value(node):
try: try:
storage_map[ins] = [ins.get_test_value()] storage_map[ins] = [ins.get_test_value()]
compute_map[ins] = [True] compute_map[ins] = [True]
except AttributeError: except TestValueError:
# no test-value was specified, act accordingly # no test-value was specified, act accordingly
if config.compute_test_value == "warn": if config.compute_test_value == "warn":
warnings.warn( warnings.warn(
...@@ -1073,7 +1074,7 @@ def missing_test_message(msg): ...@@ -1073,7 +1074,7 @@ def missing_test_message(msg):
""" """
action = config.compute_test_value action = config.compute_test_value
if action == "raise": if action == "raise":
raise AttributeError(msg) raise TestValueError(msg)
elif action == "warn": elif action == "warn":
warnings.warn(msg, stacklevel=2) warnings.warn(msg, stacklevel=2)
else: else:
...@@ -1113,7 +1114,7 @@ def get_test_values(*args): ...@@ -1113,7 +1114,7 @@ def get_test_values(*args):
for i, arg in enumerate(args): for i, arg in enumerate(args):
try: try:
rval.append(get_test_value(arg)) rval.append(get_test_value(arg))
except AttributeError: except TestValueError:
if hasattr(arg, "name") and arg.name is not None: if hasattr(arg, "name") and arg.name is not None:
missing_test_message( missing_test_message(
"Argument {} ('{}') has no test value".format(i, arg.name) "Argument {} ('{}') has no test value".format(i, arg.name)
......
...@@ -156,6 +156,12 @@ def hashtype(self): ...@@ -156,6 +156,12 @@ def hashtype(self):
undef = object() undef = object()
class TestValueError(Exception):
"""Base exception class for all test value errors."""
pass
class MethodNotDefined(Exception): class MethodNotDefined(Exception):
""" """
To be raised by functions defined as part of an interface. To be raised by functions defined as part of an interface.
......
...@@ -54,6 +54,7 @@ from theano import compile, gof, tensor, config ...@@ -54,6 +54,7 @@ from theano import compile, gof, tensor, config
from theano.compile import SharedVariable, function, ops from theano.compile import SharedVariable, function, ops
from theano.tensor import opt from theano.tensor import opt
from theano.updates import OrderedUpdates from theano.updates import OrderedUpdates
from theano.gof.utils import TestValueError
from theano.scan_module import scan_op, scan_utils from theano.scan_module import scan_op, scan_utils
from theano.scan_module.scan_utils import safe_new, traverse from theano.scan_module.scan_utils import safe_new, traverse
...@@ -524,7 +525,7 @@ def scan( ...@@ -524,7 +525,7 @@ def scan(
if config.compute_test_value != "off": if config.compute_test_value != "off":
try: try:
nw_slice.tag.test_value = gof.get_test_value(_seq_val_slice) nw_slice.tag.test_value = gof.get_test_value(_seq_val_slice)
except AttributeError as e: except TestValueError as e:
if config.compute_test_value != "ignore": if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now, # No need to print a warning or raise an error now,
# it will be done when fn will be called. # it will be done when fn will be called.
...@@ -656,7 +657,7 @@ def scan( ...@@ -656,7 +657,7 @@ def scan(
if config.compute_test_value != "off": if config.compute_test_value != "off":
try: try:
arg.tag.test_value = gof.get_test_value(actual_arg) arg.tag.test_value = gof.get_test_value(actual_arg)
except AttributeError as e: except TestValueError as e:
if config.compute_test_value != "ignore": if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now, # No need to print a warning or raise an error now,
# it will be done when fn will be called. # it will be done when fn will be called.
...@@ -719,7 +720,7 @@ def scan( ...@@ -719,7 +720,7 @@ def scan(
nw_slice.tag.test_value = gof.get_test_value( nw_slice.tag.test_value = gof.get_test_value(
_init_out_var_slice _init_out_var_slice
) )
except AttributeError as e: except TestValueError as e:
if config.compute_test_value != "ignore": if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now, # No need to print a warning or raise an error now,
# it will be done when fn will be called. # it will be done when fn will be called.
......
...@@ -33,6 +33,7 @@ from six import string_types ...@@ -33,6 +33,7 @@ from six import string_types
from theano import gof, compat, tensor, scalar from theano import gof, compat, tensor, scalar
from theano.compile.pfunc import rebuild_collect_shared from theano.compile.pfunc import rebuild_collect_shared
from theano.tensor.basic import get_scalar_constant_value from theano.tensor.basic import get_scalar_constant_value
from theano.gof.utils import TestValueError
# Logging function for sending warning or info # Logging function for sending warning or info
...@@ -74,8 +75,7 @@ def safe_new(x, tag="", dtype=None): ...@@ -74,8 +75,7 @@ def safe_new(x, tag="", dtype=None):
# Copy test value, cast it if necessary # Copy test value, cast it if necessary
try: try:
x_test_value = gof.op.get_test_value(x) x_test_value = gof.op.get_test_value(x)
except AttributeError: except TestValueError:
# There is no test value
pass pass
else: else:
# This clause is executed if no exception was raised # This clause is executed if no exception was raised
...@@ -101,8 +101,7 @@ def safe_new(x, tag="", dtype=None): ...@@ -101,8 +101,7 @@ def safe_new(x, tag="", dtype=None):
if theano.config.compute_test_value != "off": if theano.config.compute_test_value != "off":
try: try:
nw_x.tag.test_value = copy.deepcopy(gof.op.get_test_value(x)) nw_x.tag.test_value = copy.deepcopy(gof.op.get_test_value(x))
except AttributeError: except TestValueError:
# This means `x` has no test value.
pass pass
return nw_x return nw_x
......
...@@ -156,6 +156,7 @@ from theano.gof import ( ...@@ -156,6 +156,7 @@ from theano.gof import (
Apply, Apply,
ReplacementDidntRemovedError, ReplacementDidntRemovedError,
) )
from theano.gof.utils import TestValueError
from theano.gof.params_type import ParamsType from theano.gof.params_type import ParamsType
from theano.gof.opt import inherit_stack_trace from theano.gof.opt import inherit_stack_trace
from theano.printing import pprint, FunctionPrinter, debugprint from theano.printing import pprint, FunctionPrinter, debugprint
...@@ -2497,7 +2498,7 @@ class BatchedDot(Op): ...@@ -2497,7 +2498,7 @@ class BatchedDot(Op):
if debugger_available: if debugger_available:
try: try:
iv0 = theano.gof.op.get_test_value(inputs[0]) iv0 = theano.gof.op.get_test_value(inputs[0])
except AttributeError: except TestValueError:
theano.gof.op.missing_test_message( theano.gof.op.missing_test_message(
"first input passed to BatchedDot.R_op has no test value" "first input passed to BatchedDot.R_op has no test value"
) )
...@@ -2505,7 +2506,7 @@ class BatchedDot(Op): ...@@ -2505,7 +2506,7 @@ class BatchedDot(Op):
try: try:
iv1 = theano.gof.op.get_test_value(inputs[1]) iv1 = theano.gof.op.get_test_value(inputs[1])
except AttributeError: except TestValueError:
theano.gof.op.missing_test_message( theano.gof.op.missing_test_message(
"second input passed to BatchedDot.R_op has no test value" "second input passed to BatchedDot.R_op has no test value"
) )
...@@ -2514,7 +2515,7 @@ class BatchedDot(Op): ...@@ -2514,7 +2515,7 @@ class BatchedDot(Op):
if eval_points[0]: if eval_points[0]:
try: try:
ev0 = theano.gof.op.get_test_value(eval_points[0]) ev0 = theano.gof.op.get_test_value(eval_points[0])
except AttributeError: except TestValueError:
theano.gof.op.missing_test_message( theano.gof.op.missing_test_message(
"first eval point passed to BatchedDot.R_op " "first eval point passed to BatchedDot.R_op "
"has no test value" "has no test value"
...@@ -2523,7 +2524,7 @@ class BatchedDot(Op): ...@@ -2523,7 +2524,7 @@ class BatchedDot(Op):
if eval_points[1]: if eval_points[1]:
try: try:
ev1 = theano.gof.op.get_test_value(eval_points[1]) ev1 = theano.gof.op.get_test_value(eval_points[1])
except AttributeError: except TestValueError:
theano.gof.op.missing_test_message( theano.gof.op.missing_test_message(
"second eval point passed to BatchedDot.R_op " "second eval point passed to BatchedDot.R_op "
"has no test value" "has no test value"
......
...@@ -42,7 +42,7 @@ from theano.gof.opt import ( ...@@ -42,7 +42,7 @@ from theano.gof.opt import (
pre_constant_merge, pre_constant_merge,
pre_greedy_local_optimizer, pre_greedy_local_optimizer,
) )
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined, TestValueError
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.tensor.elemwise import Elemwise, DimShuffle from theano.tensor.elemwise import Elemwise, DimShuffle
from theano.tensor.subtensor import ( from theano.tensor.subtensor import (
...@@ -7747,7 +7747,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None ...@@ -7747,7 +7747,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
"Cannot construct a scalar test value" "Cannot construct a scalar test value"
" from a test value with no size: {}".format(ii) " from a test value with no size: {}".format(ii)
) )
except AttributeError: except TestValueError:
pass pass
tmp_s_input.append(tmp) tmp_s_input.append(tmp)
...@@ -7812,7 +7812,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None ...@@ -7812,7 +7812,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
v = gof.op.get_test_value(i) v = gof.op.get_test_value(i)
if v.size > 0: if v.size > 0:
s.tag.test_value = v.flatten()[0] s.tag.test_value = v.flatten()[0]
except AttributeError: except TestValueError:
pass pass
inputs.append(i) inputs.append(i)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论