Unverified 提交 3861b86c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: GitHub

Merge pull request #112 from brandonwillard/move-get_test_value-to-classes

Move get_test_value method to the Variable (sub)classes
......@@ -5,14 +5,12 @@ import theano.gof.op as op
import theano.tensor as tt
from six import string_types
from theano import scalar, shared
from theano import scalar, shared, config
from theano.configparser import change_flags
from theano.gof.graph import Apply, Variable
from theano.gof.type import Generic, Type
config = theano.config
Op = op.Op
utils = op.utils
from theano.gof.op import Op
from theano.gof.utils import TestValueError, MethodNotDefined
def as_variable(x):
......@@ -177,7 +175,7 @@ class TestMakeThunk:
o = IncOnePython()(i)
# Check that the c_code function is not implemented
with pytest.raises((NotImplementedError, utils.MethodNotDefined)):
with pytest.raises((NotImplementedError, MethodNotDefined)):
o.owner.op.c_code(o.owner, "o", ["x"], "z", {"fail": ""})
storage_map = {i: [np.int32(3)], o: [None]}
......@@ -213,7 +211,7 @@ class TestMakeThunk:
o = IncOneC()(i)
# Check that the perform function is not implemented
with pytest.raises((NotImplementedError, utils.MethodNotDefined)):
with pytest.raises((NotImplementedError, MethodNotDefined)):
o.owner.op.perform(o.owner, 0, [None])
storage_map = {i: [np.int32(3)], o: [None]}
......@@ -229,7 +227,7 @@ class TestMakeThunk:
assert compute_map[o][0]
assert storage_map[o][0] == 4
else:
with pytest.raises((NotImplementedError, utils.MethodNotDefined)):
with pytest.raises((NotImplementedError, MethodNotDefined)):
thunk()
def test_no_make_node(self):
......@@ -288,23 +286,23 @@ def test_test_value_op():
@change_flags(compute_test_value="off")
def test_get_debug_values_no_debugger():
"""Tests that `get_debug_values` returns `[]` when debugger is off."""
def test_get_test_values_no_debugger():
"""Tests that `get_test_values` returns `[]` when debugger is off."""
x = tt.vector()
assert op.get_debug_values(x) == []
assert op.get_test_values(x) == []
@change_flags(compute_test_value="ignore")
def test_get_det_debug_values_ignore():
"""Tests that `get_debug_values` returns `[]` when debugger is set to "ignore" and some values are missing."""
def test_get_test_values_ignore():
"""Tests that `get_test_values` returns `[]` when debugger is set to "ignore" and some values are missing."""
x = tt.vector()
assert op.get_debug_values(x) == []
assert op.get_test_values(x) == []
def test_get_debug_values_success():
"""Tests that `get_debug_value` returns values when available (and the debugger is on)."""
def test_get_test_values_success():
"""Tests that `get_test_values` returns values when available (and the debugger is on)."""
for mode in ["ignore", "warn", "raise"]:
with change_flags(compute_test_value=mode):
......@@ -314,7 +312,7 @@ def test_get_debug_values_success():
iters = 0
for x_val, y_val in op.get_debug_values(x, y):
for x_val, y_val in op.get_test_values(x, y):
assert x_val.shape == (4,)
assert y_val.shape == (5, 5)
......@@ -325,9 +323,9 @@ def test_get_debug_values_success():
@change_flags(compute_test_value="raise")
def test_get_debug_values_exc():
"""Tests that `get_debug_value` raises an exception when debugger is set to raise and a value is missing."""
def test_get_test_values_exc():
"""Tests that `get_test_values` raises an exception when debugger is set to raise and a value is missing."""
with pytest.raises(AttributeError):
with pytest.raises(TestValueError):
x = tt.vector()
assert op.get_debug_values(x) == []
assert op.get_test_values(x) == []
......@@ -2799,11 +2799,11 @@ class TestAsTensorVariable:
as_tensor_variable(good_apply_var)
bad_apply_var = ApplyDefaultTestOp(-1).make_node(self.x)
with pytest.raises(AttributeError):
with pytest.raises(ValueError):
_ = as_tensor_variable(bad_apply_var)
bad_apply_var = ApplyDefaultTestOp(2).make_node(self.x)
with pytest.raises(AttributeError):
with pytest.raises(ValueError):
_ = as_tensor_variable(bad_apply_var)
def test_list(self):
......@@ -2816,7 +2816,7 @@ class TestAsTensorVariable:
_ = as_tensor_variable(y)
bad_apply_var = ApplyDefaultTestOp([0, 1]).make_node(self.x)
with pytest.raises(AttributeError):
with pytest.raises(ValueError):
as_tensor_variable(bad_apply_var)
def test_strip_leading_broadcastable(self):
......
......@@ -172,28 +172,26 @@ else:
np.seterr(all=_all, divide=_divide, over=_over, under=_under, invalid=_invalid)
del _all, _divide, _over, _under, _invalid
# This is defined here because it is designed to work across symbolic
# datatypes (Sparse and Tensor)
def dot(l, r):
"""Return a symbolic matrix/dot product between l and r """
rval = NotImplemented
e0, e1 = None, None
if rval == NotImplemented and hasattr(l, "__dot__"):
try:
rval = l.__dot__(r)
except Exception as e0:
rval = NotImplemented
if rval == NotImplemented and hasattr(r, "__rdot__"):
try:
rval = r.__rdot__(l)
except Exception as e1:
rval = NotImplemented
if rval == NotImplemented:
raise NotImplementedError("Dot failed for the following reasons:", (e0, e1))
return rval
"""Return a symbolic dot product.
This is designed to work with both sparse and dense tensors types.
"""
try:
res = l.__dot__(r)
if res is NotImplemented:
raise NotImplementedError()
return res
except (NotImplementedError, AttributeError, TypeError):
res = r.__rdot__(l)
if res is NotImplemented:
raise NotImplementedError()
return res
def get_scalar_constant_value(v):
......
......@@ -150,6 +150,9 @@ class SharedVariable(Variable):
else:
self.container.value = copy.deepcopy(new_value)
def get_test_value(self):
return self.get_value(borrow=True, return_internal_type=True)
def zero(self, borrow=False):
"""
Set the values of a shared variable to 0.
......
......@@ -58,7 +58,14 @@ from theano.gof.link import (
WrapLinkerMany,
)
from theano.gof.op import Op, OpenMPOp, PureOp, COp, ops_with_inner_function
from theano.gof.op import (
Op,
OpenMPOp,
PureOp,
COp,
ops_with_inner_function,
get_test_value,
)
from theano.gof.type import EnumType, EnumList, CEnumType
......
......@@ -14,7 +14,7 @@ from six.moves import StringIO
from theano import config
from theano.gof import graph, utils, toolbox
from theano.gof.utils import get_variable_trace_string
from theano.gof.utils import get_variable_trace_string, TestValueError
from theano.misc.ordered_set import OrderedSet
NullType = None
......@@ -511,7 +511,7 @@ class FunctionGraph(utils.object2):
try:
tval = theano.gof.op.get_test_value(r)
new_tval = theano.gof.op.get_test_value(new_r)
except AttributeError:
except TestValueError:
pass
else:
tval_shape = getattr(tval, "shape", None)
......
......@@ -14,6 +14,7 @@ from six import string_types, integer_types
from theano import config
from theano.gof import utils
from theano.gof.utils import TestValueError
from theano.misc.ordered_set import OrderedSet
__docformat__ = "restructuredtext en"
......@@ -164,13 +165,13 @@ class Apply(Node):
if len(self.outputs) == 1:
return self.outputs[0]
else:
raise AttributeError(
raise ValueError(
"%s.default_output should be an output index." % self.op
)
elif not isinstance(do, integer_types):
raise AttributeError("%s.default_output should be an int or long" % self.op)
raise ValueError("%s.default_output should be an int or long" % self.op)
elif do < 0 or do >= len(self.outputs):
raise AttributeError("%s.default_output is out of range." % self.op)
raise ValueError("%s.default_output is out of range." % self.op)
return self.outputs[do]
out = property(default_output, doc="alias for self.default_output()")
......@@ -383,19 +384,39 @@ class Variable(Node):
self.tag = utils.ValidatingScratchpad("test_value", type.filter)
self.type = type
if owner is not None and not isinstance(owner, Apply):
raise TypeError("owner must be an Apply instance", owner)
self.owner = owner
if index is not None and not isinstance(index, integer_types):
raise TypeError("index must be an int", index)
self.index = index
if name is not None and not isinstance(name, string_types):
raise TypeError("name must be a string", name)
self.name = name
self.auto_name = "auto_" + str(next(self.__count__))
Variable.notify_construction_observers(self)
def get_test_value(self):
"""Get the test value.
Raises
------
TestValueError
"""
if not hasattr(self.tag, "test_value"):
detailed_err_msg = utils.get_variable_trace_string(self)
raise TestValueError(
"{} has no test value {}".format(self, detailed_err_msg)
)
return self.tag.test_value
def __str__(self):
"""Return a str representation of the Variable."""
if self.name is not None:
......@@ -416,7 +437,7 @@ class Variable(Node):
overridden by classes with non printable test_value to provide a
suitable representation of the test_value.
"""
return repr(theano.gof.op.get_test_value(self))
return repr(self.get_test_value())
def __repr__(self, firstPass=True):
"""Return a repr of the Variable.
......@@ -429,7 +450,7 @@ class Variable(Node):
if config.print_test_value and firstPass:
try:
to_print.append(self.__repr_test_value__())
except AttributeError:
except TestValueError:
pass
return "\n".join(to_print)
......@@ -583,6 +604,9 @@ class Constant(Variable):
self.data = type.filter(data)
utils.add_tag_trace(self)
def get_test_value(self):
return self.data
def equals(self, other):
# this does what __eq__ should do, but Variable and Apply should always be hashable by id
return isinstance(other, Constant) and self.signature() == other.signature()
......
......@@ -22,6 +22,7 @@ from six import PY3
from theano import config
from theano.gof import graph
from theano.gof import utils
from theano.gof.utils import TestValueError
from theano.gof.cmodule import GCC_compiler
from theano.gof.fg import FunctionGraph
......@@ -47,6 +48,89 @@ else:
return open(file, "U")
def compute_test_value(node):
"""Computes the test value of a node.
Parameters
----------
node : Apply
The `Apply` node for which the test value is computed.
Returns
-------
None
The `tag.test_value`s are updated in each `Variable` in `node.outputs`.
"""
# Gather the test values for each input of the node
storage_map = {}
compute_map = {}
for i, ins in enumerate(node.inputs):
try:
storage_map[ins] = [ins.get_test_value()]
compute_map[ins] = [True]
except TestValueError:
# no test-value was specified, act accordingly
if config.compute_test_value == "warn":
warnings.warn(
"Warning, Cannot compute test value: input %i (%s) of Op %s missing default value"
% (i, ins, node),
stacklevel=2,
)
return
elif config.compute_test_value == "raise":
detailed_err_msg = utils.get_variable_trace_string(ins)
raise ValueError(
"Cannot compute test value: input %i (%s) of Op %s missing default value. %s"
% (i, ins, node, detailed_err_msg)
)
elif config.compute_test_value == "ignore":
return
elif config.compute_test_value == "pdb":
import pdb
pdb.post_mortem(sys.exc_info()[2])
else:
raise ValueError(
"%s is invalid for option config.compute_test_value"
% config.compute_test_value
)
# All inputs have test-values; perform the `Op`'s computation
# The original values should not be destroyed, so we copy the values of the
# inputs in `destroy_map`
destroyed_inputs_idx = set()
if getattr(node.op, "destroy_map", None):
for i_pos_list in node.op.destroy_map.values():
destroyed_inputs_idx.update(i_pos_list)
for inp_idx in destroyed_inputs_idx:
inp = node.inputs[inp_idx]
storage_map[inp] = [storage_map[inp][0].copy()]
# Prepare `storage_map` and `compute_map` for the outputs
for o in node.outputs:
storage_map[o] = [None]
compute_map[o] = [False]
# Create a thunk that performs the computation
thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
thunk.inputs = [storage_map[v] for v in node.inputs]
thunk.outputs = [storage_map[v] for v in node.outputs]
required = thunk()
assert not required # We provided all inputs
for output in node.outputs:
# Check that the output has been computed
assert compute_map[output][0], (output, storage_map[output][0])
# Add 'test_value' to output tag, so that downstream `Op`s can use
# these numerical values as test values
output.tag.test_value = storage_map[output][0]
class CLinkerObject(object):
"""
Standard elements of an Op or Type used with the CLinker.
......@@ -529,31 +613,6 @@ class PureOp(object):
"""
raise utils.MethodNotDefined("make_node", type(self), self.__class__.__name__)
@classmethod
def _get_test_value(cls, v):
"""
Extract test value from variable v.
Raises AttributeError if there is none.
For a Constant, the test value is v.value.
For a Shared variable, it is the internal value.
For another Variable, it is the content of v.tag.test_value.
"""
# avoid circular import
from theano.compile.sharedvalue import SharedVariable
if isinstance(v, graph.Constant):
return v.value
elif isinstance(v, SharedVariable):
return v.get_value(borrow=True, return_internal_type=True)
elif isinstance(v, graph.Variable) and hasattr(v.tag, "test_value"):
return v.tag.test_value
detailed_err_msg = utils.get_variable_trace_string(v)
raise AttributeError("%s has no test value %s" % (v, detailed_err_msg))
def __call__(self, *inputs, **kwargs):
"""Construct an `Apply` node using `self.make_node` and return its outputs.
......@@ -565,13 +624,16 @@ class PureOp(object):
x = tensor.matrix()
# tensor.exp is an Op instance, calls
# Op.__call__(self=<instance of exp>, inputs=(x,))
y = tensor.exp(x)
This class implements a convenience function (for graph-building) which
uses `default_output`, but subclasses are free to override this function
and ignore `default_output`.
`tensor.exp` is an Op instance, so `tensor.exp(x)` calls
`tensor.exp.__call__` (i.e. this method) and returns its single output
`Variable`, `y`. The `Apply` node constructed by `self.make_node`
behind the scenes is available via `y.owner`.
`PureOp` authors are able to determine which output is returned by this method
via the `PureOp.default_output` property., but subclasses are free to override this
function and ignore `default_output`.
Parameters
----------
......@@ -580,88 +642,25 @@ class PureOp(object):
kwargs
Additional keyword arguments to be forwarded to
`make_node()` *except* for optional argument `return_list` (which
defaults to False). If `return_list` is True, then the returned
value is always a list. Otherwise it is either a single Variable
defaults to `False`). If `return_list` is `True`, then the returned
value is always a `list`. Otherwise it is either a single `Variable`
when the output of `make_node()` contains a single element, or this
output (unchanged) when it contains multiple elements.
Returns
-------
outputs : list of Variable or Variable
Either a list of output `Variable`s, or a single `Variable`.
This is determined by the number of outputs produced by the
`PureOp`, the value of the keyword `return_list`, and the value of
the `PureOp.default_output` property.
"""
return_list = kwargs.pop("return_list", False)
node = self.make_node(*inputs, **kwargs)
if config.compute_test_value != "off":
run_perform = True
# build test input-values
storage_map = {}
compute_map = {}
for i, ins in enumerate(node.inputs):
try:
storage_map[ins] = [self._get_test_value(ins)]
compute_map[ins] = [True]
except AttributeError:
# no test-value was specified, act accordingly
if config.compute_test_value == "warn":
warnings.warn(
"Warning, Cannot compute test value: input %i (%s) of Op %s missing default value"
% (i, ins, node),
stacklevel=2,
)
run_perform = False
elif config.compute_test_value == "raise":
detailed_err_msg = utils.get_variable_trace_string(ins)
raise ValueError(
"Cannot compute test value: input %i (%s) of Op %s missing default value. %s"
% (i, ins, node, detailed_err_msg)
)
elif config.compute_test_value == "ignore":
# silently skip test
run_perform = False
elif config.compute_test_value == "pdb":
import pdb
pdb.post_mortem(sys.exc_info()[2])
else:
raise ValueError(
"%s is invalid for option config.compute_test_value"
% config.compute_test_value
)
# if all inputs have test-values, run the actual op
if run_perform:
# Original values should not be destroyed:
# copy the values of the inputs in destroy_map
destroyed_inputs_idx = set()
if getattr(node.op, "destroy_map", None):
for i_pos_list in node.op.destroy_map.values():
destroyed_inputs_idx.update(i_pos_list)
for inp_idx in destroyed_inputs_idx:
inp = node.inputs[inp_idx]
storage_map[inp] = [storage_map[inp][0].copy()]
# Prepare storage_map and compute_map for the outputs
for o in node.outputs:
storage_map[o] = [None]
compute_map[o] = [False]
# compute output value once with test inputs to validate graph
thunk = node.op.make_thunk(
node, storage_map, compute_map, no_recycling=[]
)
thunk.inputs = [storage_map[v] for v in node.inputs]
thunk.outputs = [storage_map[v] for v in node.outputs]
required = thunk()
assert not required # We provided all inputs
for output in node.outputs:
# Check that the output has been computed
assert compute_map[output][0], (output, storage_map[output][0])
# add 'test_value' to output tag, so that downstream ops can use these
# numerical values as inputs to their perform method.
output.tag.test_value = storage_map[output][0]
compute_test_value(node)
if self.default_output is not None:
rval = node.outputs[self.default_output]
......@@ -1052,10 +1051,9 @@ def get_test_value(v):
"""
if not isinstance(v, graph.Variable):
v_var = theano.tensor.as_tensor_variable(v)
else:
v_var = v
return PureOp._get_test_value(v_var)
v = theano.tensor.as_tensor_variable(v)
return v.get_test_value()
def missing_test_message(msg):
......@@ -1076,14 +1074,14 @@ def missing_test_message(msg):
"""
action = config.compute_test_value
if action == "raise":
raise AttributeError(msg)
raise TestValueError(msg)
elif action == "warn":
warnings.warn(msg, stacklevel=2)
else:
assert action in ["ignore", "off"]
def get_debug_values(*args):
def get_test_values(*args):
"""
Intended use:
......@@ -1116,7 +1114,7 @@ def get_debug_values(*args):
for i, arg in enumerate(args):
try:
rval.append(get_test_value(arg))
except AttributeError:
except TestValueError:
if hasattr(arg, "name") and arg.name is not None:
missing_test_message(
"Argument {} ('{}') has no test value".format(i, arg.name)
......
......@@ -156,6 +156,12 @@ def hashtype(self):
undef = object()
class TestValueError(Exception):
"""Base exception class for all test value errors."""
pass
class MethodNotDefined(Exception):
"""
To be raised by functions defined as part of an interface.
......
......@@ -15,7 +15,7 @@ from theano import gof
from theano.gof import utils, Variable
from theano.gof.null_type import NullType, null_type
from theano.gof.op import get_debug_values
from theano.gof.op import get_test_values
from theano.compile import ViewOp, FAST_RUN, DebugMode, get_mode
__authors__ = "James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow"
......@@ -1217,7 +1217,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
continue
if isinstance(new_output_grad.type, DisconnectedType):
continue
for orig_output_v, new_output_grad_v in get_debug_values(*packed):
for orig_output_v, new_output_grad_v in get_test_values(*packed):
o_shape = orig_output_v.shape
g_shape = new_output_grad_v.shape
if o_shape != g_shape:
......@@ -1310,7 +1310,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
# has the right shape
if hasattr(term, "shape"):
orig_ipt = inputs[i]
for orig_ipt_v, term_v in get_debug_values(orig_ipt, term):
for orig_ipt_v, term_v in get_test_values(orig_ipt, term):
i_shape = orig_ipt_v.shape
t_shape = term_v.shape
if i_shape != t_shape:
......
......@@ -54,6 +54,7 @@ from theano import compile, gof, tensor, config
from theano.compile import SharedVariable, function, ops
from theano.tensor import opt
from theano.updates import OrderedUpdates
from theano.gof.utils import TestValueError
from theano.scan_module import scan_op, scan_utils
from theano.scan_module.scan_utils import safe_new, traverse
......@@ -523,18 +524,17 @@ def scan(
# Try to transfer test_value to the new variable
if config.compute_test_value != "off":
try:
nw_slice.tag.test_value = gof.Op._get_test_value(_seq_val_slice)
except AttributeError as e:
nw_slice.tag.test_value = gof.get_test_value(_seq_val_slice)
except TestValueError:
if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
_logger.info(
_logger.warning(
(
"Cannot compute test value for "
"the inner function of scan, input value "
"missing %s"
),
e,
"missing {}"
).format(_seq_val_slice)
)
# Add names to slices for debugging and pretty printing ..
......@@ -655,17 +655,14 @@ def scan(
# Try to transfer test_value to the new variable
if config.compute_test_value != "off":
try:
arg.tag.test_value = gof.Op._get_test_value(actual_arg)
except AttributeError as e:
arg.tag.test_value = gof.get_test_value(actual_arg)
except TestValueError:
if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
_logger.info(
_logger.warning(
(
"Cannot compute test value for the "
"inner function of scan, input value missing %s"
),
e,
"inner function of scan, test value missing: {}"
).format(actual_arg)
)
if getattr(init_out["initial"], "name", None) is not None:
......@@ -716,20 +713,17 @@ def scan(
# Try to transfer test_value to the new variable
if config.compute_test_value != "off":
try:
nw_slice.tag.test_value = gof.Op._get_test_value(
nw_slice.tag.test_value = gof.get_test_value(
_init_out_var_slice
)
except AttributeError as e:
except TestValueError:
if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
_logger.info(
_logger.warning(
(
"Cannot compute test value for "
"the inner function of scan, input value "
"missing. %s"
),
e,
"the inner function of scan, test value "
"missing: {}"
).format(_init_out_var_slice)
)
# give it a name or debugging and pretty printing
......@@ -808,10 +802,8 @@ def scan(
_logger.warning(
(
"When the number of steps is fixed and equal "
"to 1, the provided stopping condition, ",
str(condition),
" is ignored",
)
"to 1, the provided stopping condition, {} is ignored",
).format(condition)
)
for pos, inner_out in enumerate(outputs):
......
......@@ -33,6 +33,7 @@ from six import string_types
from theano import gof, compat, tensor, scalar
from theano.compile.pfunc import rebuild_collect_shared
from theano.tensor.basic import get_scalar_constant_value
from theano.gof.utils import TestValueError
# Logging function for sending warning or info
......@@ -74,8 +75,7 @@ def safe_new(x, tag="", dtype=None):
# Copy test value, cast it if necessary
try:
x_test_value = gof.op.get_test_value(x)
except AttributeError:
# There is no test value
except TestValueError:
pass
else:
# This clause is executed if no exception was raised
......@@ -101,8 +101,7 @@ def safe_new(x, tag="", dtype=None):
if theano.config.compute_test_value != "off":
try:
nw_x.tag.test_value = copy.deepcopy(gof.op.get_test_value(x))
except AttributeError:
# This means `x` has no test value.
except TestValueError:
pass
return nw_x
......
......@@ -156,6 +156,7 @@ from theano.gof import (
Apply,
ReplacementDidntRemovedError,
)
from theano.gof.utils import TestValueError
from theano.gof.params_type import ParamsType
from theano.gof.opt import inherit_stack_trace
from theano.printing import pprint, FunctionPrinter, debugprint
......@@ -2492,45 +2493,45 @@ class BatchedDot(Op):
if eval_points[0] is None and eval_points[1] is None:
return [None]
debugger_available = config.compute_test_value != "off"
test_values_enabled = config.compute_test_value != "off"
if debugger_available:
if test_values_enabled:
try:
iv0 = theano.gof.op.get_test_value(inputs[0])
except AttributeError:
except TestValueError:
theano.gof.op.missing_test_message(
"first input passed to BatchedDot.R_op has no test value"
)
debugger_available = False
test_values_enabled = False
try:
iv1 = theano.gof.op.get_test_value(inputs[1])
except AttributeError:
except TestValueError:
theano.gof.op.missing_test_message(
"second input passed to BatchedDot.R_op has no test value"
)
debugger_available = False
test_values_enabled = False
if eval_points[0]:
try:
ev0 = theano.gof.op.get_test_value(eval_points[0])
except AttributeError:
except TestValueError:
theano.gof.op.missing_test_message(
"first eval point passed to BatchedDot.R_op "
"has no test value"
)
debugger_available = False
test_values_enabled = False
if eval_points[1]:
try:
ev1 = theano.gof.op.get_test_value(eval_points[1])
except AttributeError:
except TestValueError:
theano.gof.op.missing_test_message(
"second eval point passed to BatchedDot.R_op "
"has no test value"
)
debugger_available = False
test_values_enabled = False
if debugger_available:
if test_values_enabled:
input_values = [iv0, iv1]
eval_point_values = [ev0, ev1]
......
......@@ -42,7 +42,7 @@ from theano.gof.opt import (
pre_constant_merge,
pre_greedy_local_optimizer,
)
from theano.gof.utils import MethodNotDefined
from theano.gof.utils import MethodNotDefined, TestValueError
from theano.gradient import DisconnectedType
from theano.tensor.elemwise import Elemwise, DimShuffle
from theano.tensor.subtensor import (
......@@ -7747,7 +7747,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
"Cannot construct a scalar test value"
" from a test value with no size: {}".format(ii)
)
except AttributeError:
except TestValueError:
pass
tmp_s_input.append(tmp)
......@@ -7812,7 +7812,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
v = gof.op.get_test_value(i)
if v.size > 0:
s.tag.test_value = v.flatten()[0]
except AttributeError:
except TestValueError:
pass
inputs.append(i)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论