提交 061e875d authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Use the new make_thunk mechanism in compute_test_value

上级 913462e7
......@@ -370,10 +370,12 @@ class PureOp(object):
run_perform = True
# build test input-values
input_vals = []
storage_map = {}
compute_map = {}
for i, ins in enumerate(node.inputs):
try:
input_vals.append(self._get_test_value(ins))
storage_map[ins] = [self._get_test_value(ins)]
compute_map[ins] = [True]
except AttributeError:
# no test-value was specified, act accordingly
if config.compute_test_value == 'warn':
......@@ -389,33 +391,34 @@ class PureOp(object):
# if all inputs have test-values, run the actual op
if run_perform:
# Original values should not be destroyed:
# copy the values of the inputs in destroy_map
destroyed_inputs_idx = []
destroyed_inputs = []
if getattr(node.op, 'destroy_map', None):
for i_pos_list in node.op.destroy_map.itervalues():
destroyed_inputs_idx.extend(i_pos_list)
for i in destroyed_inputs_idx:
input_vals[i] = input_vals[i].copy()
destroyed_inputs_idx.extend(node.inputs[i_pos_list])
for inp in destroyed_inputs:
storage_map[inp] = [storage_map[inp][0].copy()]
# Prepare storage_map and compute_map for the outputs
for o in node.outputs:
storage_map[o] = [None]
compute_map[o] = [False]
# compute output value once with test inputs to validate graph
output_storage = [[None]] * len(node.outputs)
try:
node.op.perform(node, input_vals, output_storage)
thunk = node.op.make_thunk(node, storage_map, compute_map,
no_recycling=[])
required = thunk()
assert not required # We provided all inputs
for output in node.outputs:
# Check that the output has been computed
assert compute_map[output][0], (output, storage_map[output][0])
# add 'test_value' to output tag, so that downstream ops can use these
# numerical values as inputs to their perform method.
for (outval, node_output) in zip(output_storage, node.outputs):
node_output.tag.test_value = outval[0]
except utils.MethodNotDefined, e:
# This case happens when the perform method is not defined
# for a certain Op.
#TODO: use the c_thunk?
if config.compute_test_value == 'warn':
warnings.warn('Warning, in compute_test_value:' + type(e), stacklevel=2)
elif config.compute_test_value == 'raise':
raise
output.tag.test_value = storage_map[output][0]
if self.default_output is not None:
return node.outputs[self.default_output]
......
......@@ -226,7 +226,7 @@ class TestComputeTestValue(unittest.TestCase):
# Get traceback
tb = sys.exc_info()[2]
# Get frame info 3 layers up
frame_info = traceback.extract_tb(tb)[-3]
frame_info = traceback.extract_tb(tb)[-4]
# We should be in the "fx" function defined above
assert os.path.split(frame_info[0])[1] == 'test_compute_test_value.py'
assert frame_info[2] == 'fx'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论