Unverified 提交 3861b86c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: GitHub

Merge pull request #112 from brandonwillard/move-get_test_value-to-classes

Move get_test_value method to the Variable (sub)classes
......@@ -5,14 +5,12 @@ import theano.gof.op as op
import theano.tensor as tt
from six import string_types
from theano import scalar, shared
from theano import scalar, shared, config
from theano.configparser import change_flags
from theano.gof.graph import Apply, Variable
from theano.gof.type import Generic, Type
config = theano.config
Op = op.Op
utils = op.utils
from theano.gof.op import Op
from theano.gof.utils import TestValueError, MethodNotDefined
def as_variable(x):
......@@ -177,7 +175,7 @@ class TestMakeThunk:
o = IncOnePython()(i)
# Check that the c_code function is not implemented
with pytest.raises((NotImplementedError, utils.MethodNotDefined)):
with pytest.raises((NotImplementedError, MethodNotDefined)):
o.owner.op.c_code(o.owner, "o", ["x"], "z", {"fail": ""})
storage_map = {i: [np.int32(3)], o: [None]}
......@@ -213,7 +211,7 @@ class TestMakeThunk:
o = IncOneC()(i)
# Check that the perform function is not implemented
with pytest.raises((NotImplementedError, utils.MethodNotDefined)):
with pytest.raises((NotImplementedError, MethodNotDefined)):
o.owner.op.perform(o.owner, 0, [None])
storage_map = {i: [np.int32(3)], o: [None]}
......@@ -229,7 +227,7 @@ class TestMakeThunk:
assert compute_map[o][0]
assert storage_map[o][0] == 4
else:
with pytest.raises((NotImplementedError, utils.MethodNotDefined)):
with pytest.raises((NotImplementedError, MethodNotDefined)):
thunk()
def test_no_make_node(self):
......@@ -288,23 +286,23 @@ def test_test_value_op():
@change_flags(compute_test_value="off")
def test_get_debug_values_no_debugger():
"""Tests that `get_debug_values` returns `[]` when debugger is off."""
def test_get_test_values_no_debugger():
"""Tests that `get_test_values` returns `[]` when debugger is off."""
x = tt.vector()
assert op.get_debug_values(x) == []
assert op.get_test_values(x) == []
@change_flags(compute_test_value="ignore")
def test_get_det_debug_values_ignore():
"""Tests that `get_debug_values` returns `[]` when debugger is set to "ignore" and some values are missing."""
def test_get_test_values_ignore():
"""Tests that `get_test_values` returns `[]` when debugger is set to "ignore" and some values are missing."""
x = tt.vector()
assert op.get_debug_values(x) == []
assert op.get_test_values(x) == []
def test_get_debug_values_success():
"""Tests that `get_debug_value` returns values when available (and the debugger is on)."""
def test_get_test_values_success():
"""Tests that `get_test_values` returns values when available (and the debugger is on)."""
for mode in ["ignore", "warn", "raise"]:
with change_flags(compute_test_value=mode):
......@@ -314,7 +312,7 @@ def test_get_debug_values_success():
iters = 0
for x_val, y_val in op.get_debug_values(x, y):
for x_val, y_val in op.get_test_values(x, y):
assert x_val.shape == (4,)
assert y_val.shape == (5, 5)
......@@ -325,9 +323,9 @@ def test_get_debug_values_success():
@change_flags(compute_test_value="raise")
def test_get_debug_values_exc():
"""Tests that `get_debug_value` raises an exception when debugger is set to raise and a value is missing."""
def test_get_test_values_exc():
"""Tests that `get_test_values` raises an exception when debugger is set to raise and a value is missing."""
with pytest.raises(AttributeError):
with pytest.raises(TestValueError):
x = tt.vector()
assert op.get_debug_values(x) == []
assert op.get_test_values(x) == []
......@@ -2799,11 +2799,11 @@ class TestAsTensorVariable:
as_tensor_variable(good_apply_var)
bad_apply_var = ApplyDefaultTestOp(-1).make_node(self.x)
with pytest.raises(AttributeError):
with pytest.raises(ValueError):
_ = as_tensor_variable(bad_apply_var)
bad_apply_var = ApplyDefaultTestOp(2).make_node(self.x)
with pytest.raises(AttributeError):
with pytest.raises(ValueError):
_ = as_tensor_variable(bad_apply_var)
def test_list(self):
......@@ -2816,7 +2816,7 @@ class TestAsTensorVariable:
_ = as_tensor_variable(y)
bad_apply_var = ApplyDefaultTestOp([0, 1]).make_node(self.x)
with pytest.raises(AttributeError):
with pytest.raises(ValueError):
as_tensor_variable(bad_apply_var)
def test_strip_leading_broadcastable(self):
......
......@@ -172,28 +172,26 @@ else:
np.seterr(all=_all, divide=_divide, over=_over, under=_under, invalid=_invalid)
del _all, _divide, _over, _under, _invalid
# This is defined here because it is designed to work across symbolic
# datatypes (Sparse and Tensor)
def dot(l, r):
"""Return a symbolic matrix/dot product between l and r """
rval = NotImplemented
e0, e1 = None, None
if rval == NotImplemented and hasattr(l, "__dot__"):
try:
rval = l.__dot__(r)
except Exception as e0:
rval = NotImplemented
if rval == NotImplemented and hasattr(r, "__rdot__"):
try:
rval = r.__rdot__(l)
except Exception as e1:
rval = NotImplemented
if rval == NotImplemented:
raise NotImplementedError("Dot failed for the following reasons:", (e0, e1))
return rval
"""Return a symbolic dot product.
This is designed to work with both sparse and dense tensors types.
"""
try:
res = l.__dot__(r)
if res is NotImplemented:
raise NotImplementedError()
return res
except (NotImplementedError, AttributeError, TypeError):
res = r.__rdot__(l)
if res is NotImplemented:
raise NotImplementedError()
return res
def get_scalar_constant_value(v):
......
......@@ -150,6 +150,9 @@ class SharedVariable(Variable):
else:
self.container.value = copy.deepcopy(new_value)
def get_test_value(self):
return self.get_value(borrow=True, return_internal_type=True)
def zero(self, borrow=False):
"""
Set the values of a shared variable to 0.
......
......@@ -58,7 +58,14 @@ from theano.gof.link import (
WrapLinkerMany,
)
from theano.gof.op import Op, OpenMPOp, PureOp, COp, ops_with_inner_function
from theano.gof.op import (
Op,
OpenMPOp,
PureOp,
COp,
ops_with_inner_function,
get_test_value,
)
from theano.gof.type import EnumType, EnumList, CEnumType
......
......@@ -14,7 +14,7 @@ from six.moves import StringIO
from theano import config
from theano.gof import graph, utils, toolbox
from theano.gof.utils import get_variable_trace_string
from theano.gof.utils import get_variable_trace_string, TestValueError
from theano.misc.ordered_set import OrderedSet
NullType = None
......@@ -511,7 +511,7 @@ class FunctionGraph(utils.object2):
try:
tval = theano.gof.op.get_test_value(r)
new_tval = theano.gof.op.get_test_value(new_r)
except AttributeError:
except TestValueError:
pass
else:
tval_shape = getattr(tval, "shape", None)
......
......@@ -14,6 +14,7 @@ from six import string_types, integer_types
from theano import config
from theano.gof import utils
from theano.gof.utils import TestValueError
from theano.misc.ordered_set import OrderedSet
__docformat__ = "restructuredtext en"
......@@ -164,13 +165,13 @@ class Apply(Node):
if len(self.outputs) == 1:
return self.outputs[0]
else:
raise AttributeError(
raise ValueError(
"%s.default_output should be an output index." % self.op
)
elif not isinstance(do, integer_types):
raise AttributeError("%s.default_output should be an int or long" % self.op)
raise ValueError("%s.default_output should be an int or long" % self.op)
elif do < 0 or do >= len(self.outputs):
raise AttributeError("%s.default_output is out of range." % self.op)
raise ValueError("%s.default_output is out of range." % self.op)
return self.outputs[do]
out = property(default_output, doc="alias for self.default_output()")
......@@ -383,19 +384,39 @@ class Variable(Node):
self.tag = utils.ValidatingScratchpad("test_value", type.filter)
self.type = type
if owner is not None and not isinstance(owner, Apply):
raise TypeError("owner must be an Apply instance", owner)
self.owner = owner
if index is not None and not isinstance(index, integer_types):
raise TypeError("index must be an int", index)
self.index = index
if name is not None and not isinstance(name, string_types):
raise TypeError("name must be a string", name)
self.name = name
self.auto_name = "auto_" + str(next(self.__count__))
Variable.notify_construction_observers(self)
def get_test_value(self):
"""Get the test value.
Raises
------
TestValueError
"""
if not hasattr(self.tag, "test_value"):
detailed_err_msg = utils.get_variable_trace_string(self)
raise TestValueError(
"{} has no test value {}".format(self, detailed_err_msg)
)
return self.tag.test_value
def __str__(self):
"""Return a str representation of the Variable."""
if self.name is not None:
......@@ -416,7 +437,7 @@ class Variable(Node):
overridden by classes with non printable test_value to provide a
suitable representation of the test_value.
"""
return repr(theano.gof.op.get_test_value(self))
return repr(self.get_test_value())
def __repr__(self, firstPass=True):
"""Return a repr of the Variable.
......@@ -429,7 +450,7 @@ class Variable(Node):
if config.print_test_value and firstPass:
try:
to_print.append(self.__repr_test_value__())
except AttributeError:
except TestValueError:
pass
return "\n".join(to_print)
......@@ -583,6 +604,9 @@ class Constant(Variable):
self.data = type.filter(data)
utils.add_tag_trace(self)
def get_test_value(self):
return self.data
def equals(self, other):
# this does what __eq__ should do, but Variable and Apply should always be hashable by id
return isinstance(other, Constant) and self.signature() == other.signature()
......
差异被折叠。
......@@ -156,6 +156,12 @@ def hashtype(self):
undef = object()
class TestValueError(Exception):
"""Base exception class for all test value errors."""
pass
class MethodNotDefined(Exception):
"""
To be raised by functions defined as part of an interface.
......
......@@ -15,7 +15,7 @@ from theano import gof
from theano.gof import utils, Variable
from theano.gof.null_type import NullType, null_type
from theano.gof.op import get_debug_values
from theano.gof.op import get_test_values
from theano.compile import ViewOp, FAST_RUN, DebugMode, get_mode
__authors__ = "James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow"
......@@ -1217,7 +1217,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
continue
if isinstance(new_output_grad.type, DisconnectedType):
continue
for orig_output_v, new_output_grad_v in get_debug_values(*packed):
for orig_output_v, new_output_grad_v in get_test_values(*packed):
o_shape = orig_output_v.shape
g_shape = new_output_grad_v.shape
if o_shape != g_shape:
......@@ -1310,7 +1310,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
# has the right shape
if hasattr(term, "shape"):
orig_ipt = inputs[i]
for orig_ipt_v, term_v in get_debug_values(orig_ipt, term):
for orig_ipt_v, term_v in get_test_values(orig_ipt, term):
i_shape = orig_ipt_v.shape
t_shape = term_v.shape
if i_shape != t_shape:
......
......@@ -54,6 +54,7 @@ from theano import compile, gof, tensor, config
from theano.compile import SharedVariable, function, ops
from theano.tensor import opt
from theano.updates import OrderedUpdates
from theano.gof.utils import TestValueError
from theano.scan_module import scan_op, scan_utils
from theano.scan_module.scan_utils import safe_new, traverse
......@@ -523,18 +524,17 @@ def scan(
# Try to transfer test_value to the new variable
if config.compute_test_value != "off":
try:
nw_slice.tag.test_value = gof.Op._get_test_value(_seq_val_slice)
except AttributeError as e:
nw_slice.tag.test_value = gof.get_test_value(_seq_val_slice)
except TestValueError:
if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
_logger.info(
_logger.warning(
(
"Cannot compute test value for "
"the inner function of scan, input value "
"missing %s"
),
e,
"missing {}"
).format(_seq_val_slice)
)
# Add names to slices for debugging and pretty printing ..
......@@ -655,17 +655,14 @@ def scan(
# Try to transfer test_value to the new variable
if config.compute_test_value != "off":
try:
arg.tag.test_value = gof.Op._get_test_value(actual_arg)
except AttributeError as e:
arg.tag.test_value = gof.get_test_value(actual_arg)
except TestValueError:
if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
_logger.info(
_logger.warning(
(
"Cannot compute test value for the "
"inner function of scan, input value missing %s"
),
e,
"inner function of scan, test value missing: {}"
).format(actual_arg)
)
if getattr(init_out["initial"], "name", None) is not None:
......@@ -716,20 +713,17 @@ def scan(
# Try to transfer test_value to the new variable
if config.compute_test_value != "off":
try:
nw_slice.tag.test_value = gof.Op._get_test_value(
nw_slice.tag.test_value = gof.get_test_value(
_init_out_var_slice
)
except AttributeError as e:
except TestValueError:
if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now,
# it will be done when fn will be called.
_logger.info(
_logger.warning(
(
"Cannot compute test value for "
"the inner function of scan, input value "
"missing. %s"
),
e,
"the inner function of scan, test value "
"missing: {}"
).format(_init_out_var_slice)
)
# give it a name or debugging and pretty printing
......@@ -808,10 +802,8 @@ def scan(
_logger.warning(
(
"When the number of steps is fixed and equal "
"to 1, the provided stopping condition, ",
str(condition),
" is ignored",
)
"to 1, the provided stopping condition, {} is ignored",
).format(condition)
)
for pos, inner_out in enumerate(outputs):
......
......@@ -33,6 +33,7 @@ from six import string_types
from theano import gof, compat, tensor, scalar
from theano.compile.pfunc import rebuild_collect_shared
from theano.tensor.basic import get_scalar_constant_value
from theano.gof.utils import TestValueError
# Logging function for sending warning or info
......@@ -74,8 +75,7 @@ def safe_new(x, tag="", dtype=None):
# Copy test value, cast it if necessary
try:
x_test_value = gof.op.get_test_value(x)
except AttributeError:
# There is no test value
except TestValueError:
pass
else:
# This clause is executed if no exception was raised
......@@ -101,8 +101,7 @@ def safe_new(x, tag="", dtype=None):
if theano.config.compute_test_value != "off":
try:
nw_x.tag.test_value = copy.deepcopy(gof.op.get_test_value(x))
except AttributeError:
# This means `x` has no test value.
except TestValueError:
pass
return nw_x
......
......@@ -156,6 +156,7 @@ from theano.gof import (
Apply,
ReplacementDidntRemovedError,
)
from theano.gof.utils import TestValueError
from theano.gof.params_type import ParamsType
from theano.gof.opt import inherit_stack_trace
from theano.printing import pprint, FunctionPrinter, debugprint
......@@ -2492,45 +2493,45 @@ class BatchedDot(Op):
if eval_points[0] is None and eval_points[1] is None:
return [None]
debugger_available = config.compute_test_value != "off"
test_values_enabled = config.compute_test_value != "off"
if debugger_available:
if test_values_enabled:
try:
iv0 = theano.gof.op.get_test_value(inputs[0])
except AttributeError:
except TestValueError:
theano.gof.op.missing_test_message(
"first input passed to BatchedDot.R_op has no test value"
)
debugger_available = False
test_values_enabled = False
try:
iv1 = theano.gof.op.get_test_value(inputs[1])
except AttributeError:
except TestValueError:
theano.gof.op.missing_test_message(
"second input passed to BatchedDot.R_op has no test value"
)
debugger_available = False
test_values_enabled = False
if eval_points[0]:
try:
ev0 = theano.gof.op.get_test_value(eval_points[0])
except AttributeError:
except TestValueError:
theano.gof.op.missing_test_message(
"first eval point passed to BatchedDot.R_op "
"has no test value"
)
debugger_available = False
test_values_enabled = False
if eval_points[1]:
try:
ev1 = theano.gof.op.get_test_value(eval_points[1])
except AttributeError:
except TestValueError:
theano.gof.op.missing_test_message(
"second eval point passed to BatchedDot.R_op "
"has no test value"
)
debugger_available = False
test_values_enabled = False
if debugger_available:
if test_values_enabled:
input_values = [iv0, iv1]
eval_point_values = [ev0, ev1]
......
......@@ -42,7 +42,7 @@ from theano.gof.opt import (
pre_constant_merge,
pre_greedy_local_optimizer,
)
from theano.gof.utils import MethodNotDefined
from theano.gof.utils import MethodNotDefined, TestValueError
from theano.gradient import DisconnectedType
from theano.tensor.elemwise import Elemwise, DimShuffle
from theano.tensor.subtensor import (
......@@ -7747,7 +7747,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
"Cannot construct a scalar test value"
" from a test value with no size: {}".format(ii)
)
except AttributeError:
except TestValueError:
pass
tmp_s_input.append(tmp)
......@@ -7812,7 +7812,7 @@ def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None
v = gof.op.get_test_value(i)
if v.size > 0:
s.tag.test_value = v.flatten()[0]
except AttributeError:
except TestValueError:
pass
inputs.append(i)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论