提交 d4ff7188 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Small updates to compute_test_values

上级 71a8a1c9
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
The `Op` class is the base interface for all operations The `Op` class is the base interface for all operations
compatible with `gof`'s :doc:`graph` routines. compatible with `gof`'s :doc:`graph` routines.
""" """
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
...@@ -331,7 +329,7 @@ class PureOp(object): ...@@ -331,7 +329,7 @@ class PureOp(object):
# build test input-values # build test input-values
input_vals = [] input_vals = []
for ins in node.inputs: for i, ins in enumerate(node.inputs):
if isinstance(ins, graph.Constant): if isinstance(ins, graph.Constant):
input_vals.append(ins.value) input_vals.append(ins.value)
elif isinstance(ins,SharedVariable): elif isinstance(ins,SharedVariable):
...@@ -342,10 +340,11 @@ class PureOp(object): ...@@ -342,10 +340,11 @@ class PureOp(object):
else: else:
# no test-value was specified, act accordingly # no test-value was specified, act accordingly
if config.compute_test_value == 'warn': if config.compute_test_value == 'warn':
raise Warning('Cannot compute test value: input %s of Op %s missing default value') # TODO: use warnings.warn, http://docs.python.org/library/warnings.html#warnings.warn
print >>sys.stderr, ('Warning, Cannot compute test value: input %i (%s) of Op %s missing default value' % (i, ins, node))
run_perform = False run_perform = False
elif config.compute_test_value == 'err': elif config.compute_test_value == 'err':
raise ValueError('Cannot compute test value: input %s of Op %s missing default value') raise ValueError('Cannot compute test value: input %i (%s) of Op %s missing default value' % (i, ins, node))
else: else:
# silently skip test # silently skip test
run_perform = False run_perform = False
...@@ -355,12 +354,23 @@ class PureOp(object): ...@@ -355,12 +354,23 @@ class PureOp(object):
# compute output value once with test inputs to validate graph # compute output value once with test inputs to validate graph
output_storage = [[None]] * len(node.outputs) output_storage = [[None]] * len(node.outputs)
node.op.perform(node, input_vals, output_storage) try:
node.op.perform(node, input_vals, output_storage)
# add 'test_value' to output tags, so that downstream ops can use these
# numerical values as inputs to their perform method. # add 'test_value' to output tags, so that downstream ops can use these
for (outval, node_output) in zip(output_storage, node.outputs): # numerical values as inputs to their perform method.
node_output.tag.test_value = outval[0] for (outval, node_output) in zip(output_storage, node.outputs):
node_output.tag.test_value = outval[0]
except utils.MethodNotDefined, e:
# This case happens when the perform method is not defined
# for a certain Op.
#TODO: use the c_thunk?
if config.compute_test_value == 'warn':
# TODO: use warnings.warn
print >>sys.stderr, 'Warning, in compute_test_value:', type(e)
print >>sys.stderr, e
elif config.compute_test_value == 'err':
raise
if self.default_output is not None: if self.default_output is not None:
return node.outputs[self.default_output] return node.outputs[self.default_output]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论