提交 061e875d authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Use the new make_thunk mechanism in compute_test_value

上级 913462e7
...@@ -370,10 +370,12 @@ class PureOp(object): ...@@ -370,10 +370,12 @@ class PureOp(object):
run_perform = True run_perform = True
# build test input-values # build test input-values
input_vals = [] storage_map = {}
compute_map = {}
for i, ins in enumerate(node.inputs): for i, ins in enumerate(node.inputs):
try: try:
input_vals.append(self._get_test_value(ins)) storage_map[ins] = [self._get_test_value(ins)]
compute_map[ins] = [True]
except AttributeError: except AttributeError:
# no test-value was specified, act accordingly # no test-value was specified, act accordingly
if config.compute_test_value == 'warn': if config.compute_test_value == 'warn':
...@@ -389,33 +391,34 @@ class PureOp(object): ...@@ -389,33 +391,34 @@ class PureOp(object):
# if all inputs have test-values, run the actual op # if all inputs have test-values, run the actual op
if run_perform: if run_perform:
# Original values should not be destroyed: # Original values should not be destroyed:
# copy the values of the inputs in destroy_map # copy the values of the inputs in destroy_map
destroyed_inputs_idx = [] destroyed_inputs = []
if getattr(node.op, 'destroy_map', None): if getattr(node.op, 'destroy_map', None):
for i_pos_list in node.op.destroy_map.itervalues(): for i_pos_list in node.op.destroy_map.itervalues():
destroyed_inputs_idx.extend(i_pos_list) destroyed_inputs_idx.extend(node.inputs[i_pos_list])
for i in destroyed_inputs_idx: for inp in destroyed_inputs:
input_vals[i] = input_vals[i].copy() storage_map[inp] = [storage_map[inp][0].copy()]
# Prepare storage_map and compute_map for the outputs
for o in node.outputs:
storage_map[o] = [None]
compute_map[o] = [False]
# compute output value once with test inputs to validate graph # compute output value once with test inputs to validate graph
output_storage = [[None]] * len(node.outputs) thunk = node.op.make_thunk(node, storage_map, compute_map,
try: no_recycling=[])
node.op.perform(node, input_vals, output_storage)
required = thunk()
assert not required # We provided all inputs
for output in node.outputs:
# Check that the output has been computed
assert compute_map[output][0], (output, storage_map[output][0])
# add 'test_value' to output tag, so that downstream ops can use these # add 'test_value' to output tag, so that downstream ops can use these
# numerical values as inputs to their perform method. # numerical values as inputs to their perform method.
for (outval, node_output) in zip(output_storage, node.outputs): output.tag.test_value = storage_map[output][0]
node_output.tag.test_value = outval[0]
except utils.MethodNotDefined, e:
# This case happens when the perform method is not defined
# for a certain Op.
#TODO: use the c_thunk?
if config.compute_test_value == 'warn':
warnings.warn('Warning, in compute_test_value:' + type(e), stacklevel=2)
elif config.compute_test_value == 'raise':
raise
if self.default_output is not None: if self.default_output is not None:
return node.outputs[self.default_output] return node.outputs[self.default_output]
......
...@@ -226,7 +226,7 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -226,7 +226,7 @@ class TestComputeTestValue(unittest.TestCase):
# Get traceback # Get traceback
tb = sys.exc_info()[2] tb = sys.exc_info()[2]
# Get frame info 3 layers up # Get frame info 3 layers up
frame_info = traceback.extract_tb(tb)[-3] frame_info = traceback.extract_tb(tb)[-4]
# We should be in the "fx" function defined above # We should be in the "fx" function defined above
assert os.path.split(frame_info[0])[1] == 'test_compute_test_value.py' assert os.path.split(frame_info[0])[1] == 'test_compute_test_value.py'
assert frame_info[2] == 'fx' assert frame_info[2] == 'fx'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论