提交 41d3e005 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Refactor the code to extract test_value from different types of variables.

It enables using that code in scan.
上级 1fdd4c84
......@@ -308,6 +308,28 @@ class PureOp(object):
"""
raise utils.MethodNotDefined("make_node", type(self), self.__class__.__name__)
@classmethod
def _get_test_value(cls, v):
"""
Extract test value from variable v. Raises AttributeError if there is none.
For a Constant, the test value is v.value.
For a Shared variable, it is the internal value.
For another Variable, it is the content of v.tag.test_value.
"""
# avoid circular import
from theano.compile.sharedvalue import SharedVariable
if isinstance(v, graph.Constant):
return v.value
elif isinstance(v, SharedVariable):
return v.get_value(borrow=True, return_internal_type=True)
elif isinstance(v, graph.Variable) and hasattr(v.tag, 'test_value'):
# ensure that the test value is correct
return v.type.filter(v.tag.test_value)
raise AttributeError('%s has not test value' % v)
def __call__(self, *inputs, **kwargs):
"""Optional: Return some or all output[s] of `make_node`.
......@@ -329,21 +351,14 @@ class PureOp(object):
self.add_tag_trace(node)
if config.compute_test_value != 'off':
# avoid circular import
from theano.compile.sharedvalue import SharedVariable
run_perform = True
# build test input-values
input_vals = []
for i, ins in enumerate(node.inputs):
if isinstance(ins, graph.Constant):
input_vals.append(ins.value)
elif isinstance(ins,SharedVariable):
input_vals.append(ins.get_value(borrow=True, return_internal_type=True))
elif isinstance(ins,graph.Variable) and hasattr(ins.tag, 'test_value'):
# ensure that the test value is correct
input_vals.append(ins.type.filter(ins.tag.test_value))
else:
try:
input_vals.append(self._get_test_value(ins))
except AttributeError:
# no test-value was specified, act accordingly
if config.compute_test_value == 'warn':
warnings.warn('Warning, Cannot compute test value: input %i (%s) of Op %s missing default value' % (i, ins, node), stacklevel=2)
......@@ -373,7 +388,7 @@ class PureOp(object):
try:
node.op.perform(node, input_vals, output_storage)
# add 'test_value' to output tags, so that downstream ops can use these
# add 'test_value' to output tag, so that downstream ops can use these
# numerical values as inputs to their perform method.
for (outval, node_output) in zip(output_storage, node.outputs):
node_output.tag.test_value = outval[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论