提交 41d3e005 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Refactor the code to extract test_value from different types of variables.

It enables using that code in scan.
上级 1fdd4c84
...@@ -308,6 +308,28 @@ class PureOp(object): ...@@ -308,6 +308,28 @@ class PureOp(object):
""" """
raise utils.MethodNotDefined("make_node", type(self), self.__class__.__name__) raise utils.MethodNotDefined("make_node", type(self), self.__class__.__name__)
@classmethod
def _get_test_value(cls, v):
"""
Extract test value from variable v. Raises AttributeError if there is none.
For a Constant, the test value is v.value.
For a Shared variable, it is the internal value.
For another Variable, it is the content of v.tag.test_value.
"""
# avoid circular import
from theano.compile.sharedvalue import SharedVariable
if isinstance(v, graph.Constant):
return v.value
elif isinstance(v, SharedVariable):
return v.get_value(borrow=True, return_internal_type=True)
elif isinstance(v, graph.Variable) and hasattr(v.tag, 'test_value'):
# ensure that the test value is correct
return v.type.filter(v.tag.test_value)
raise AttributeError('%s has not test value' % v)
def __call__(self, *inputs, **kwargs): def __call__(self, *inputs, **kwargs):
"""Optional: Return some or all output[s] of `make_node`. """Optional: Return some or all output[s] of `make_node`.
...@@ -329,21 +351,14 @@ class PureOp(object): ...@@ -329,21 +351,14 @@ class PureOp(object):
self.add_tag_trace(node) self.add_tag_trace(node)
if config.compute_test_value != 'off': if config.compute_test_value != 'off':
# avoid circular import
from theano.compile.sharedvalue import SharedVariable
run_perform = True run_perform = True
# build test input-values # build test input-values
input_vals = [] input_vals = []
for i, ins in enumerate(node.inputs): for i, ins in enumerate(node.inputs):
if isinstance(ins, graph.Constant): try:
input_vals.append(ins.value) input_vals.append(self._get_test_value(ins))
elif isinstance(ins,SharedVariable): except AttributeError:
input_vals.append(ins.get_value(borrow=True, return_internal_type=True))
elif isinstance(ins,graph.Variable) and hasattr(ins.tag, 'test_value'):
# ensure that the test value is correct
input_vals.append(ins.type.filter(ins.tag.test_value))
else:
# no test-value was specified, act accordingly # no test-value was specified, act accordingly
if config.compute_test_value == 'warn': if config.compute_test_value == 'warn':
warnings.warn('Warning, Cannot compute test value: input %i (%s) of Op %s missing default value' % (i, ins, node), stacklevel=2) warnings.warn('Warning, Cannot compute test value: input %i (%s) of Op %s missing default value' % (i, ins, node), stacklevel=2)
...@@ -373,7 +388,7 @@ class PureOp(object): ...@@ -373,7 +388,7 @@ class PureOp(object):
try: try:
node.op.perform(node, input_vals, output_storage) node.op.perform(node, input_vals, output_storage)
# add 'test_value' to output tags, so that downstream ops can use these # add 'test_value' to output tag, so that downstream ops can use these
# numerical values as inputs to their perform method. # numerical values as inputs to their perform method.
for (outval, node_output) in zip(output_storage, node.outputs): for (outval, node_output) in zip(output_storage, node.outputs):
node_output.tag.test_value = outval[0] node_output.tag.test_value = outval[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论