提交 0ea10588 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Move get_test_value method to the Variable (sub)classes

The test value computation steps were also extracted from `PureOp.__call__` and put into a stand-alone `compute_test_value` function.
上级 03b82c26
......@@ -150,6 +150,9 @@ class SharedVariable(Variable):
else:
self.container.value = copy.deepcopy(new_value)
def get_test_value(self):
return self.get_value(borrow=True, return_internal_type=True)
def zero(self, borrow=False):
"""
Set the values of a shared variable to 0.
......
......@@ -58,7 +58,14 @@ from theano.gof.link import (
WrapLinkerMany,
)
from theano.gof.op import Op, OpenMPOp, PureOp, COp, ops_with_inner_function
from theano.gof.op import (
Op,
OpenMPOp,
PureOp,
COp,
ops_with_inner_function,
get_test_value,
)
from theano.gof.type import EnumType, EnumList, CEnumType
......
......@@ -383,19 +383,39 @@ class Variable(Node):
self.tag = utils.ValidatingScratchpad("test_value", type.filter)
self.type = type
if owner is not None and not isinstance(owner, Apply):
raise TypeError("owner must be an Apply instance", owner)
self.owner = owner
if index is not None and not isinstance(index, integer_types):
raise TypeError("index must be an int", index)
self.index = index
if name is not None and not isinstance(name, string_types):
raise TypeError("name must be a string", name)
self.name = name
self.auto_name = "auto_" + str(next(self.__count__))
Variable.notify_construction_observers(self)
def get_test_value(self):
"""Get the test value.
Raises
------
AttributeError
"""
if not hasattr(self.tag, "test_value"):
detailed_err_msg = utils.get_variable_trace_string(self)
raise AttributeError(
"{} has no test value {}".format(self, detailed_err_msg)
)
return self.tag.test_value
def __str__(self):
"""Return a str representation of the Variable."""
if self.name is not None:
......@@ -583,6 +603,9 @@ class Constant(Variable):
self.data = type.filter(data)
utils.add_tag_trace(self)
def get_test_value(self):
return self.data
def equals(self, other):
# this does what __eq__ should do, but Variable and Apply should always be hashable by id
return isinstance(other, Constant) and self.signature() == other.signature()
......
......@@ -47,6 +47,89 @@ else:
return open(file, "U")
def compute_test_value(node):
"""Computes the test value of a node.
Parameters
----------
node : Apply
The `Apply` node for which the test value is computed.
Returns
-------
None
The `tag.test_value`s are updated in each `Variable` in `node.outputs`.
"""
# Gather the test values for each input of the node
storage_map = {}
compute_map = {}
for i, ins in enumerate(node.inputs):
try:
storage_map[ins] = [ins.get_test_value()]
compute_map[ins] = [True]
except AttributeError:
# no test-value was specified, act accordingly
if config.compute_test_value == "warn":
warnings.warn(
"Warning, Cannot compute test value: input %i (%s) of Op %s missing default value"
% (i, ins, node),
stacklevel=2,
)
return
elif config.compute_test_value == "raise":
detailed_err_msg = utils.get_variable_trace_string(ins)
raise ValueError(
"Cannot compute test value: input %i (%s) of Op %s missing default value. %s"
% (i, ins, node, detailed_err_msg)
)
elif config.compute_test_value == "ignore":
return
elif config.compute_test_value == "pdb":
import pdb
pdb.post_mortem(sys.exc_info()[2])
else:
raise ValueError(
"%s is invalid for option config.compute_test_value"
% config.compute_test_value
)
# All inputs have test-values; perform the `Op`'s computation
# The original values should not be destroyed, so we copy the values of the
# inputs in `destroy_map`
destroyed_inputs_idx = set()
if getattr(node.op, "destroy_map", None):
for i_pos_list in node.op.destroy_map.values():
destroyed_inputs_idx.update(i_pos_list)
for inp_idx in destroyed_inputs_idx:
inp = node.inputs[inp_idx]
storage_map[inp] = [storage_map[inp][0].copy()]
# Prepare `storage_map` and `compute_map` for the outputs
for o in node.outputs:
storage_map[o] = [None]
compute_map[o] = [False]
# Create a thunk that performs the computation
thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
thunk.inputs = [storage_map[v] for v in node.inputs]
thunk.outputs = [storage_map[v] for v in node.outputs]
required = thunk()
assert not required # We provided all inputs
for output in node.outputs:
# Check that the output has been computed
assert compute_map[output][0], (output, storage_map[output][0])
# Add 'test_value' to output tag, so that downstream `Op`s can use
# these numerical values as test values
output.tag.test_value = storage_map[output][0]
class CLinkerObject(object):
"""
Standard elements of an Op or Type used with the CLinker.
......@@ -529,31 +612,6 @@ class PureOp(object):
"""
raise utils.MethodNotDefined("make_node", type(self), self.__class__.__name__)
@classmethod
def _get_test_value(cls, v):
"""
Extract test value from variable v.
Raises AttributeError if there is none.
For a Constant, the test value is v.value.
For a Shared variable, it is the internal value.
For another Variable, it is the content of v.tag.test_value.
"""
# avoid circular import
from theano.compile.sharedvalue import SharedVariable
if isinstance(v, graph.Constant):
return v.value
elif isinstance(v, SharedVariable):
return v.get_value(borrow=True, return_internal_type=True)
elif isinstance(v, graph.Variable) and hasattr(v.tag, "test_value"):
return v.tag.test_value
detailed_err_msg = utils.get_variable_trace_string(v)
raise AttributeError("%s has no test value %s" % (v, detailed_err_msg))
def __call__(self, *inputs, **kwargs):
"""Construct an `Apply` node using `self.make_node` and return its outputs.
......@@ -565,13 +623,16 @@ class PureOp(object):
x = tensor.matrix()
# tensor.exp is an Op instance, calls
# Op.__call__(self=<instance of exp>, inputs=(x,))
y = tensor.exp(x)
This class implements a convenience function (for graph-building) which
uses `default_output`, but subclasses are free to override this function
and ignore `default_output`.
`tensor.exp` is an Op instance, so `tensor.exp(x)` calls
`tensor.exp.__call__` (i.e. this method) and returns its single output
`Variable`, `y`. The `Apply` node constructed by `self.make_node`
behind the scenes is available via `y.owner`.
`PureOp` authors are able to determine which output is returned by this method
via the `PureOp.default_output` property., but subclasses are free to override this
function and ignore `default_output`.
Parameters
----------
......@@ -580,88 +641,25 @@ class PureOp(object):
kwargs
Additional keyword arguments to be forwarded to
`make_node()` *except* for optional argument `return_list` (which
defaults to False). If `return_list` is True, then the returned
value is always a list. Otherwise it is either a single Variable
defaults to `False`). If `return_list` is `True`, then the returned
value is always a `list`. Otherwise it is either a single `Variable`
when the output of `make_node()` contains a single element, or this
output (unchanged) when it contains multiple elements.
Returns
-------
outputs : list of Variable or Variable
Either a list of output `Variable`s, or a single `Variable`.
This is determined by the number of outputs produced by the
`PureOp`, the value of the keyword `return_list`, and the value of
the `PureOp.default_output` property.
"""
return_list = kwargs.pop("return_list", False)
node = self.make_node(*inputs, **kwargs)
if config.compute_test_value != "off":
run_perform = True
# build test input-values
storage_map = {}
compute_map = {}
for i, ins in enumerate(node.inputs):
try:
storage_map[ins] = [self._get_test_value(ins)]
compute_map[ins] = [True]
except AttributeError:
# no test-value was specified, act accordingly
if config.compute_test_value == "warn":
warnings.warn(
"Warning, Cannot compute test value: input %i (%s) of Op %s missing default value"
% (i, ins, node),
stacklevel=2,
)
run_perform = False
elif config.compute_test_value == "raise":
detailed_err_msg = utils.get_variable_trace_string(ins)
raise ValueError(
"Cannot compute test value: input %i (%s) of Op %s missing default value. %s"
% (i, ins, node, detailed_err_msg)
)
elif config.compute_test_value == "ignore":
# silently skip test
run_perform = False
elif config.compute_test_value == "pdb":
import pdb
pdb.post_mortem(sys.exc_info()[2])
else:
raise ValueError(
"%s is invalid for option config.compute_test_value"
% config.compute_test_value
)
# if all inputs have test-values, run the actual op
if run_perform:
# Original values should not be destroyed:
# copy the values of the inputs in destroy_map
destroyed_inputs_idx = set()
if getattr(node.op, "destroy_map", None):
for i_pos_list in node.op.destroy_map.values():
destroyed_inputs_idx.update(i_pos_list)
for inp_idx in destroyed_inputs_idx:
inp = node.inputs[inp_idx]
storage_map[inp] = [storage_map[inp][0].copy()]
# Prepare storage_map and compute_map for the outputs
for o in node.outputs:
storage_map[o] = [None]
compute_map[o] = [False]
# compute output value once with test inputs to validate graph
thunk = node.op.make_thunk(
node, storage_map, compute_map, no_recycling=[]
)
thunk.inputs = [storage_map[v] for v in node.inputs]
thunk.outputs = [storage_map[v] for v in node.outputs]
required = thunk()
assert not required # We provided all inputs
for output in node.outputs:
# Check that the output has been computed
assert compute_map[output][0], (output, storage_map[output][0])
# add 'test_value' to output tag, so that downstream ops can use these
# numerical values as inputs to their perform method.
output.tag.test_value = storage_map[output][0]
compute_test_value(node)
if self.default_output is not None:
rval = node.outputs[self.default_output]
......@@ -1052,10 +1050,9 @@ def get_test_value(v):
"""
if not isinstance(v, graph.Variable):
v_var = theano.tensor.as_tensor_variable(v)
else:
v_var = v
return PureOp._get_test_value(v_var)
v = theano.tensor.as_tensor_variable(v)
return v.get_test_value()
def missing_test_message(msg):
......
......@@ -523,7 +523,7 @@ def scan(
# Try to transfer test_value to the new variable
if config.compute_test_value != "off":
try:
nw_slice.tag.test_value = gof.Op._get_test_value(_seq_val_slice)
nw_slice.tag.test_value = gof.get_test_value(_seq_val_slice)
except AttributeError as e:
if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now,
......@@ -655,7 +655,7 @@ def scan(
# Try to transfer test_value to the new variable
if config.compute_test_value != "off":
try:
arg.tag.test_value = gof.Op._get_test_value(actual_arg)
arg.tag.test_value = gof.get_test_value(actual_arg)
except AttributeError as e:
if config.compute_test_value != "ignore":
# No need to print a warning or raise an error now,
......@@ -716,7 +716,7 @@ def scan(
# Try to transfer test_value to the new variable
if config.compute_test_value != "off":
try:
nw_slice.tag.test_value = gof.Op._get_test_value(
nw_slice.tag.test_value = gof.get_test_value(
_init_out_var_slice
)
except AttributeError as e:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论