Unverified 提交 3861b86c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: GitHub

Merge pull request #112 from brandonwillard/move-get_test_value-to-classes

Move get_test_value method to the Variable (sub)classes
...@@ -5,14 +5,12 @@ import theano.gof.op as op ...@@ -5,14 +5,12 @@ import theano.gof.op as op
import theano.tensor as tt import theano.tensor as tt
from six import string_types from six import string_types
from theano import scalar, shared from theano import scalar, shared, config
from theano.configparser import change_flags from theano.configparser import change_flags
from theano.gof.graph import Apply, Variable from theano.gof.graph import Apply, Variable
from theano.gof.type import Generic, Type from theano.gof.type import Generic, Type
from theano.gof.op import Op
config = theano.config from theano.gof.utils import TestValueError, MethodNotDefined
Op = op.Op
utils = op.utils
def as_variable(x): def as_variable(x):
...@@ -177,7 +175,7 @@ class TestMakeThunk: ...@@ -177,7 +175,7 @@ class TestMakeThunk:
o = IncOnePython()(i) o = IncOnePython()(i)
# Check that the c_code function is not implemented # Check that the c_code function is not implemented
with pytest.raises((NotImplementedError, utils.MethodNotDefined)): with pytest.raises((NotImplementedError, MethodNotDefined)):
o.owner.op.c_code(o.owner, "o", ["x"], "z", {"fail": ""}) o.owner.op.c_code(o.owner, "o", ["x"], "z", {"fail": ""})
storage_map = {i: [np.int32(3)], o: [None]} storage_map = {i: [np.int32(3)], o: [None]}
...@@ -213,7 +211,7 @@ class TestMakeThunk: ...@@ -213,7 +211,7 @@ class TestMakeThunk:
o = IncOneC()(i) o = IncOneC()(i)
# Check that the perform function is not implemented # Check that the perform function is not implemented
with pytest.raises((NotImplementedError, utils.MethodNotDefined)): with pytest.raises((NotImplementedError, MethodNotDefined)):
o.owner.op.perform(o.owner, 0, [None]) o.owner.op.perform(o.owner, 0, [None])
storage_map = {i: [np.int32(3)], o: [None]} storage_map = {i: [np.int32(3)], o: [None]}
...@@ -229,7 +227,7 @@ class TestMakeThunk: ...@@ -229,7 +227,7 @@ class TestMakeThunk:
assert compute_map[o][0] assert compute_map[o][0]
assert storage_map[o][0] == 4 assert storage_map[o][0] == 4
else: else:
with pytest.raises((NotImplementedError, utils.MethodNotDefined)): with pytest.raises((NotImplementedError, MethodNotDefined)):
thunk() thunk()
def test_no_make_node(self): def test_no_make_node(self):
...@@ -288,23 +286,23 @@ def test_test_value_op(): ...@@ -288,23 +286,23 @@ def test_test_value_op():
@change_flags(compute_test_value="off") @change_flags(compute_test_value="off")
def test_get_debug_values_no_debugger(): def test_get_test_values_no_debugger():
"""Tests that `get_debug_values` returns `[]` when debugger is off.""" """Tests that `get_test_values` returns `[]` when debugger is off."""
x = tt.vector() x = tt.vector()
assert op.get_debug_values(x) == [] assert op.get_test_values(x) == []
@change_flags(compute_test_value="ignore") @change_flags(compute_test_value="ignore")
def test_get_det_debug_values_ignore(): def test_get_test_values_ignore():
"""Tests that `get_debug_values` returns `[]` when debugger is set to "ignore" and some values are missing.""" """Tests that `get_test_values` returns `[]` when debugger is set to "ignore" and some values are missing."""
x = tt.vector() x = tt.vector()
assert op.get_debug_values(x) == [] assert op.get_test_values(x) == []
def test_get_debug_values_success(): def test_get_test_values_success():
"""Tests that `get_debug_value` returns values when available (and the debugger is on).""" """Tests that `get_test_values` returns values when available (and the debugger is on)."""
for mode in ["ignore", "warn", "raise"]: for mode in ["ignore", "warn", "raise"]:
with change_flags(compute_test_value=mode): with change_flags(compute_test_value=mode):
...@@ -314,7 +312,7 @@ def test_get_debug_values_success(): ...@@ -314,7 +312,7 @@ def test_get_debug_values_success():
iters = 0 iters = 0
for x_val, y_val in op.get_debug_values(x, y): for x_val, y_val in op.get_test_values(x, y):
assert x_val.shape == (4,) assert x_val.shape == (4,)
assert y_val.shape == (5, 5) assert y_val.shape == (5, 5)
...@@ -325,9 +323,9 @@ def test_get_debug_values_success(): ...@@ -325,9 +323,9 @@ def test_get_debug_values_success():
@change_flags(compute_test_value="raise") @change_flags(compute_test_value="raise")
def test_get_debug_values_exc(): def test_get_test_values_exc():
"""Tests that `get_debug_value` raises an exception when debugger is set to raise and a value is missing.""" """Tests that `get_test_values` raises an exception when debugger is set to raise and a value is missing."""
with pytest.raises(AttributeError): with pytest.raises(TestValueError):
x = tt.vector() x = tt.vector()
assert op.get_debug_values(x) == [] assert op.get_test_values(x) == []
...@@ -2799,11 +2799,11 @@ class TestAsTensorVariable: ...@@ -2799,11 +2799,11 @@ class TestAsTensorVariable:
as_tensor_variable(good_apply_var) as_tensor_variable(good_apply_var)
bad_apply_var = ApplyDefaultTestOp(-1).make_node(self.x) bad_apply_var = ApplyDefaultTestOp(-1).make_node(self.x)
with pytest.raises(AttributeError): with pytest.raises(ValueError):
_ = as_tensor_variable(bad_apply_var) _ = as_tensor_variable(bad_apply_var)
bad_apply_var = ApplyDefaultTestOp(2).make_node(self.x) bad_apply_var = ApplyDefaultTestOp(2).make_node(self.x)
with pytest.raises(AttributeError): with pytest.raises(ValueError):
_ = as_tensor_variable(bad_apply_var) _ = as_tensor_variable(bad_apply_var)
def test_list(self): def test_list(self):
...@@ -2816,7 +2816,7 @@ class TestAsTensorVariable: ...@@ -2816,7 +2816,7 @@ class TestAsTensorVariable:
_ = as_tensor_variable(y) _ = as_tensor_variable(y)
bad_apply_var = ApplyDefaultTestOp([0, 1]).make_node(self.x) bad_apply_var = ApplyDefaultTestOp([0, 1]).make_node(self.x)
with pytest.raises(AttributeError): with pytest.raises(ValueError):
as_tensor_variable(bad_apply_var) as_tensor_variable(bad_apply_var)
def test_strip_leading_broadcastable(self): def test_strip_leading_broadcastable(self):
......
...@@ -172,28 +172,26 @@ else: ...@@ -172,28 +172,26 @@ else:
np.seterr(all=_all, divide=_divide, over=_over, under=_under, invalid=_invalid) np.seterr(all=_all, divide=_divide, over=_over, under=_under, invalid=_invalid)
del _all, _divide, _over, _under, _invalid del _all, _divide, _over, _under, _invalid
# This is defined here because it is designed to work across symbolic
# datatypes (Sparse and Tensor)
def dot(l, r): def dot(l, r):
"""Return a symbolic matrix/dot product between l and r """ """Return a symbolic dot product.
rval = NotImplemented
e0, e1 = None, None
if rval == NotImplemented and hasattr(l, "__dot__"): This is designed to work with both sparse and dense tensors types.
try: """
rval = l.__dot__(r)
except Exception as e0:
rval = NotImplemented
if rval == NotImplemented and hasattr(r, "__rdot__"):
try: try:
rval = r.__rdot__(l) res = l.__dot__(r)
except Exception as e1:
rval = NotImplemented if res is NotImplemented:
if rval == NotImplemented: raise NotImplementedError()
raise NotImplementedError("Dot failed for the following reasons:", (e0, e1))
return rval return res
except (NotImplementedError, AttributeError, TypeError):
res = r.__rdot__(l)
if res is NotImplemented:
raise NotImplementedError()
return res
def get_scalar_constant_value(v): def get_scalar_constant_value(v):
......
...@@ -150,6 +150,9 @@ class SharedVariable(Variable): ...@@ -150,6 +150,9 @@ class SharedVariable(Variable):
else: else:
self.container.value = copy.deepcopy(new_value) self.container.value = copy.deepcopy(new_value)
def get_test_value(self):
return self.get_value(borrow=True, return_internal_type=True)
def zero(self, borrow=False): def zero(self, borrow=False):
""" """
Set the values of a shared variable to 0. Set the values of a shared variable to 0.
......
...@@ -58,7 +58,14 @@ from theano.gof.link import ( ...@@ -58,7 +58,14 @@ from theano.gof.link import (
WrapLinkerMany, WrapLinkerMany,
) )
from theano.gof.op import Op, OpenMPOp, PureOp, COp, ops_with_inner_function from theano.gof.op import (
Op,
OpenMPOp,
PureOp,
COp,
ops_with_inner_function,
get_test_value,
)
from theano.gof.type import EnumType, EnumList, CEnumType from theano.gof.type import EnumType, EnumList, CEnumType
......
...@@ -14,7 +14,7 @@ from six.moves import StringIO ...@@ -14,7 +14,7 @@ from six.moves import StringIO
from theano import config from theano import config
from theano.gof import graph, utils, toolbox from theano.gof import graph, utils, toolbox
from theano.gof.utils import get_variable_trace_string from theano.gof.utils import get_variable_trace_string, TestValueError
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
NullType = None NullType = None
...@@ -511,7 +511,7 @@ class FunctionGraph(utils.object2): ...@@ -511,7 +511,7 @@ class FunctionGraph(utils.object2):
try: try:
tval = theano.gof.op.get_test_value(r) tval = theano.gof.op.get_test_value(r)
new_tval = theano.gof.op.get_test_value(new_r) new_tval = theano.gof.op.get_test_value(new_r)
except AttributeError: except TestValueError:
pass pass
else: else:
tval_shape = getattr(tval, "shape", None) tval_shape = getattr(tval, "shape", None)
......
...@@ -14,6 +14,7 @@ from six import string_types, integer_types ...@@ -14,6 +14,7 @@ from six import string_types, integer_types
from theano import config from theano import config
from theano.gof import utils from theano.gof import utils
from theano.gof.utils import TestValueError
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
...@@ -164,13 +165,13 @@ class Apply(Node): ...@@ -164,13 +165,13 @@ class Apply(Node):
if len(self.outputs) == 1: if len(self.outputs) == 1:
return self.outputs[0] return self.outputs[0]
else: else:
raise AttributeError( raise ValueError(
"%s.default_output should be an output index." % self.op "%s.default_output should be an output index." % self.op
) )
elif not isinstance(do, integer_types): elif not isinstance(do, integer_types):
raise AttributeError("%s.default_output should be an int or long" % self.op) raise ValueError("%s.default_output should be an int or long" % self.op)
elif do < 0 or do >= len(self.outputs): elif do < 0 or do >= len(self.outputs):
raise AttributeError("%s.default_output is out of range." % self.op) raise ValueError("%s.default_output is out of range." % self.op)
return self.outputs[do] return self.outputs[do]
out = property(default_output, doc="alias for self.default_output()") out = property(default_output, doc="alias for self.default_output()")
...@@ -383,19 +384,39 @@ class Variable(Node): ...@@ -383,19 +384,39 @@ class Variable(Node):
self.tag = utils.ValidatingScratchpad("test_value", type.filter) self.tag = utils.ValidatingScratchpad("test_value", type.filter)
self.type = type self.type = type
if owner is not None and not isinstance(owner, Apply): if owner is not None and not isinstance(owner, Apply):
raise TypeError("owner must be an Apply instance", owner) raise TypeError("owner must be an Apply instance", owner)
self.owner = owner self.owner = owner
if index is not None and not isinstance(index, integer_types): if index is not None and not isinstance(index, integer_types):
raise TypeError("index must be an int", index) raise TypeError("index must be an int", index)
self.index = index self.index = index
if name is not None and not isinstance(name, string_types): if name is not None and not isinstance(name, string_types):
raise TypeError("name must be a string", name) raise TypeError("name must be a string", name)
self.name = name self.name = name
self.auto_name = "auto_" + str(next(self.__count__)) self.auto_name = "auto_" + str(next(self.__count__))
Variable.notify_construction_observers(self) Variable.notify_construction_observers(self)
def get_test_value(self):
"""Get the test value.
Raises
------
TestValueError
"""
if not hasattr(self.tag, "test_value"):
detailed_err_msg = utils.get_variable_trace_string(self)
raise TestValueError(
"{} has no test value {}".format(self, detailed_err_msg)
)
return self.tag.test_value
def __str__(self): def __str__(self):
"""Return a str representation of the Variable.""" """Return a str representation of the Variable."""
if self.name is not None: if self.name is not None:
...@@ -416,7 +437,7 @@ class Variable(Node): ...@@ -416,7 +437,7 @@ class Variable(Node):
overridden by classes with non printable test_value to provide a overridden by classes with non printable test_value to provide a
suitable representation of the test_value. suitable representation of the test_value.
""" """
return repr(theano.gof.op.get_test_value(self)) return repr(self.get_test_value())
def __repr__(self, firstPass=True): def __repr__(self, firstPass=True):
"""Return a repr of the Variable. """Return a repr of the Variable.
...@@ -429,7 +450,7 @@ class Variable(Node): ...@@ -429,7 +450,7 @@ class Variable(Node):
if config.print_test_value and firstPass: if config.print_test_value and firstPass:
try: try:
to_print.append(self.__repr_test_value__()) to_print.append(self.__repr_test_value__())
except AttributeError: except TestValueError:
pass pass
return "\n".join(to_print) return "\n".join(to_print)
...@@ -583,6 +604,9 @@ class Constant(Variable): ...@@ -583,6 +604,9 @@ class Constant(Variable):
self.data = type.filter(data) self.data = type.filter(data)
utils.add_tag_trace(self) utils.add_tag_trace(self)
def get_test_value(self):
return self.data
def equals(self, other): def equals(self, other):
# this does what __eq__ should do, but Variable and Apply should always be hashable by id # this does what __eq__ should do, but Variable and Apply should always be hashable by id
return isinstance(other, Constant) and self.signature() == other.signature() return isinstance(other, Constant) and self.signature() == other.signature()
......
差异被折叠。
...@@ -156,6 +156,12 @@ def hashtype(self): ...@@ -156,6 +156,12 @@ def hashtype(self):
undef = object() undef = object()
class TestValueError(Exception):
"""Base exception class for all test value errors."""
pass
class MethodNotDefined(Exception): class MethodNotDefined(Exception):
""" """
To be raised by functions defined as part of an interface. To be raised by functions defined as part of an interface.
......
...@@ -15,7 +15,7 @@ from theano import gof ...@@ -15,7 +15,7 @@ from theano import gof
from theano.gof import utils, Variable from theano.gof import utils, Variable
from theano.gof.null_type import NullType, null_type from theano.gof.null_type import NullType, null_type
from theano.gof.op import get_debug_values from theano.gof.op import get_test_values
from theano.compile import ViewOp, FAST_RUN, DebugMode, get_mode from theano.compile import ViewOp, FAST_RUN, DebugMode, get_mode
__authors__ = "James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow" __authors__ = "James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow"
...@@ -1217,7 +1217,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None): ...@@ -1217,7 +1217,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
continue continue
if isinstance(new_output_grad.type, DisconnectedType): if isinstance(new_output_grad.type, DisconnectedType):
continue continue
for orig_output_v, new_output_grad_v in get_debug_values(*packed): for orig_output_v, new_output_grad_v in get_test_values(*packed):
o_shape = orig_output_v.shape o_shape = orig_output_v.shape
g_shape = new_output_grad_v.shape g_shape = new_output_grad_v.shape
if o_shape != g_shape: if o_shape != g_shape:
...@@ -1310,7 +1310,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None): ...@@ -1310,7 +1310,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
# has the right shape # has the right shape
if hasattr(term, "shape"): if hasattr(term, "shape"):
orig_ipt = inputs[i] orig_ipt = inputs[i]
for orig_ipt_v, term_v in get_debug_values(orig_ipt, term): for orig_ipt_v, term_v in get_test_values(orig_ipt, term):
i_shape = orig_ipt_v.shape i_shape = orig_ipt_v.shape
t_shape = term_v.shape t_shape = term_v.shape
if i_shape != t_shape: if i_shape != t_shape:
......
...@@ -54,6 +54,7 @@ from theano import compile, gof, tensor, config ...@@ -54,6 +54,7 @@ from theano import compile, gof, tensor, config
from theano.compile import SharedVariable, function, ops from theano.compile import SharedVariable, function, ops
from theano.tensor import opt from theano.tensor import opt
from theano.updates import OrderedUpdates from theano.updates import OrderedUpdates
from theano.gof.utils import TestValueError
from theano.scan_module import scan_op, scan_utils from theano.scan_module import scan_op, scan_utils
from theano.scan_module.scan_utils import safe_new, traverse from theano.scan_module.scan_utils import safe_new, traverse
...@@ -523,18 +524,17 @@ def scan( ...@@ -523,18 +524,17 @@ def scan(
# Try to transfer test_value to the new variable # Try to transfer test_value to the new variable
if config.compute_test_value != "off": if config.compute_test_value != "off":
try: try:
nw_slice.tag.test_value = gof.Op._get_test_value(_seq_val_slice) nw_slice.tag.test_value = gof.get_test_value(_seq_val_slice)
except AttributeError as e: except TestValueError:
if config.compute_test_value != "ignore": if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now, # No need to print a warning or raise an error now,
# it will be done when fn will be called. # it will be done when fn will be called.
_logger.info( _logger.warning(
( (
"Cannot compute test value for " "Cannot compute test value for "
"the inner function of scan, input value " "the inner function of scan, input value "
"missing %s" "missing {}"
), ).format(_seq_val_slice)
e,
) )
# Add names to slices for debugging and pretty printing .. # Add names to slices for debugging and pretty printing ..
...@@ -655,17 +655,14 @@ def scan( ...@@ -655,17 +655,14 @@ def scan(
# Try to transfer test_value to the new variable # Try to transfer test_value to the new variable
if config.compute_test_value != "off": if config.compute_test_value != "off":
try: try:
arg.tag.test_value = gof.Op._get_test_value(actual_arg) arg.tag.test_value = gof.get_test_value(actual_arg)
except AttributeError as e: except TestValueError:
if config.compute_test_value != "ignore": if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now, _logger.warning(
# it will be done when fn will be called.
_logger.info(
( (
"Cannot compute test value for the " "Cannot compute test value for the "
"inner function of scan, input value missing %s" "inner function of scan, test value missing: {}"
), ).format(actual_arg)
e,
) )
if getattr(init_out["initial"], "name", None) is not None: if getattr(init_out["initial"], "name", None) is not None:
...@@ -716,20 +713,17 @@ def scan( ...@@ -716,20 +713,17 @@ def scan(
# Try to transfer test_value to the new variable # Try to transfer test_value to the new variable
if config.compute_test_value != "off": if config.compute_test_value != "off":
try: try:
nw_slice.tag.test_value = gof.Op._get_test_value( nw_slice.tag.test_value = gof.get_test_value(
_init_out_var_slice _init_out_var_slice
) )
except AttributeError as e: except TestValueError:
if config.compute_test_value != "ignore": if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now, _logger.warning(
# it will be done when fn will be called.
_logger.info(
( (
"Cannot compute test value for " "Cannot compute test value for "
"the inner function of scan, input value " "the inner function of scan, test value "
"missing. %s" "missing: {}"
), ).format(_init_out_var_slice)
e,
) )
# give it a name or debugging and pretty printing # give it a name or debugging and pretty printing
...@@ -808,10 +802,8 @@ def scan( ...@@ -808,10 +802,8 @@ def scan(
_logger.warning( _logger.warning(
( (
"When the number of steps is fixed and equal " "When the number of steps is fixed and equal "
"to 1, the provided stopping condition, ", "to 1, the provided stopping condition, {} is ignored",
str(condition), ).format(condition)
" is ignored",
)
) )
for pos, inner_out in enumerate(outputs): for pos, inner_out in enumerate(outputs):
......
...@@ -33,6 +33,7 @@ from six import string_types ...@@ -33,6 +33,7 @@ from six import string_types
from theano import gof, compat, tensor, scalar from theano import gof, compat, tensor, scalar
from theano.compile.pfunc import rebuild_collect_shared from theano.compile.pfunc import rebuild_collect_shared
from theano.tensor.basic import get_scalar_constant_value from theano.tensor.basic import get_scalar_constant_value
from theano.gof.utils import TestValueError
# Logging function for sending warning or info # Logging function for sending warning or info
...@@ -74,8 +75,7 @@ def safe_new(x, tag="", dtype=None): ...@@ -74,8 +75,7 @@ def safe_new(x, tag="", dtype=None):
# Copy test value, cast it if necessary # Copy test value, cast it if necessary
try: try:
x_test_value = gof.op.get_test_value(x) x_test_value = gof.op.get_test_value(x)
except AttributeError: except TestValueError:
# There is no test value
pass pass
else: else:
# This clause is executed if no exception was raised # This clause is executed if no exception was raised
...@@ -101,8 +101,7 @@ def safe_new(x, tag="", dtype=None): ...@@ -101,8 +101,7 @@ def safe_new(x, tag="", dtype=None):
if theano.config.compute_test_value != "off": if theano.config.compute_test_value != "off":
try: try:
nw_x.tag.test_value = copy.deepcopy(gof.op.get_test_value(x)) nw_x.tag.test_value = copy.deepcopy(gof.op.get_test_value(x))
except AttributeError: except TestValueError:
# This means `x` has no test value.
pass pass
return nw_x return nw_x
......
...@@ -156,6 +156,7 @@ from theano.gof import ( ...@@ -156,6 +156,7 @@ from theano.gof import (
Apply, Apply,
ReplacementDidntRemovedError, ReplacementDidntRemovedError,
) )
from theano.gof.utils import TestValueError
from theano.gof.params_type import ParamsType from theano.gof.params_type import ParamsType
from theano.gof.opt import inherit_stack_trace from theano.gof.opt import inherit_stack_trace
from theano.printing import pprint, FunctionPrinter, debugprint from theano.printing import pprint, FunctionPrinter, debugprint
...@@ -2492,45 +2493,45 @@ class BatchedDot(Op): ...@@ -2492,45 +2493,45 @@ class BatchedDot(Op):
if eval_points[0] is None and eval_points[1] is None: if eval_points[0] is None and eval_points[1] is None:
return [None] return [None]
debugger_available = config.compute_test_value != "off" test_values_enabled = config.compute_test_value != "off"
if debugger_available: if test_values_enabled:
try: try:
iv0 = theano.gof.op.get_test_value(inputs[0]) iv0 = theano.gof.op.get_test_value(inputs[0])
except AttributeError: except TestValueError:
theano.gof.op.missing_test_message( theano.gof.op.missing_test_message(
"first input passed to BatchedDot.R_op has no test value" "first input passed to BatchedDot.R_op has no test value"
) )
debugger_available = False test_values_enabled = False
try: try:
iv1 = theano.gof.op.get_test_value(inputs[1]) iv1 = theano.gof.op.get_test_value(inputs[1])
except AttributeError: except TestValueError:
theano.gof.op.missing_test_message( theano.gof.op.missing_test_message(
"second input passed to BatchedDot.R_op has no test value" "second input passed to BatchedDot.R_op has no test value"
) )
debugger_available = False test_values_enabled = False
if eval_points[0]: if eval_points[0]:
try: try:
ev0 = theano.gof.op.get_test_value(eval_points[0]) ev0 = theano.gof.op.get_test_value(eval_points[0])
except AttributeError: except TestValueError:
theano.gof.op.missing_test_message( theano.gof.op.missing_test_message(
"first eval point passed to BatchedDot.R_op " "first eval point passed to BatchedDot.R_op "
"has no test value" "has no test value"
) )
debugger_available = False test_values_enabled = False
if eval_points[1]: if eval_points[1]:
try: try:
ev1 = theano.gof.op.get_test_value(eval_points[1]) ev1 = theano.gof.op.get_test_value(eval_points[1])
except AttributeError: except TestValueError:
theano.gof.op.missing_test_message( theano.gof.op.missing_test_message(
"second eval point passed to BatchedDot.R_op " "second eval point passed to BatchedDot.R_op "
"has no test value" "has no test value"
) )
debugger_available = False test_values_enabled = False
if debugger_available: if test_values_enabled:
input_values = [iv0, iv1] input_values = [iv0, iv1]
eval_point_values = [ev0, ev1] eval_point_values = [ev0, ev1]
......
...@@ -42,7 +42,7 @@ from theano.gof.opt import ( ...@@ -42,7 +42,7 @@ from theano.gof.opt import (
pre_constant_merge, pre_constant_merge,
pre_greedy_local_optimizer, pre_greedy_local_optimizer,
) )
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined, TestValueError
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.tensor.elemwise import Elemwise, DimShuffle from theano.tensor.elemwise import Elemwise, DimShuffle
from theano.tensor.subtensor import ( from theano.tensor.subtensor import (
...@@ -7747,7 +7747,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None ...@@ -7747,7 +7747,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
"Cannot construct a scalar test value" "Cannot construct a scalar test value"
" from a test value with no size: {}".format(ii) " from a test value with no size: {}".format(ii)
) )
except AttributeError: except TestValueError:
pass pass
tmp_s_input.append(tmp) tmp_s_input.append(tmp)
...@@ -7812,7 +7812,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None ...@@ -7812,7 +7812,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
v = gof.op.get_test_value(i) v = gof.op.get_test_value(i)
if v.size > 0: if v.size > 0:
s.tag.test_value = v.flatten()[0] s.tag.test_value = v.flatten()[0]
except AttributeError: except TestValueError:
pass pass
inputs.append(i) inputs.append(i)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论