提交 abcfa04e authored 作者: Frederic Bastien's avatar Frederic Bastien

Make one scan function not recurse and have the last line put more node in the…

Make one scan function not recurse and have the last line put more node in the self.valid. So this will help investigate less the graph.
上级 9835fbdb
......@@ -855,50 +855,76 @@ class Validator(object):
If out is not valid and has no equivalent, None is returned.
"""
if out in self.valid:
return out, True
elif out in self.valid_equivalent:
return self.valid_equivalent[out], False
elif out in self.invalid:
return None
if out.owner is None:
if isinstance(out, tensor.TensorConstant):
# This might be a constant from the outer graph or a constant
# from the inner graph. In all cases, we can clone it to be
# certain we have a valid constant
cloned_out = out.clone()
self.valid.add(cloned_out)
def get_value(out):
if out in self.valid:
return out, True
elif out in self.valid_equivalent:
return self.valid_equivalent[out], False
elif out in self.invalid:
return None
else:
raise RuntimeError("This should not happen")
q = [out]
while q:
out = q.pop()
if out in self.valid:
continue
elif out in self.invalid:
continue
if out.owner is None:
if isinstance(out, tensor.TensorConstant):
# This might be a constant from the outer graph or a constant
# from the inner graph. In all cases, we can clone it to be
# certain we have a valid constant
# TODO: FRED I think the clone is not needed.
cloned_out = out.clone()
self.valid.add(cloned_out)
self.invalid.add(out)
self.valid_equivalent[out] = cloned_out
continue
else:
# This is an input node and it has not been explicitly marked
# as invalid so we can use it
self.valid.add(out)
continue
# Process the input if needed
continue_while = False
for inp in out.owner.inputs:
if inp not in self.valid and inp not in self.invalid:
q.append(out)
q.extend(out.owner.inputs)
continue_while = True
break
if continue_while:
continue
inputs = [get_value(i) for i in out.owner.inputs]
# If some inputs are invalid without equivalent, so is out
if None in inputs:
self.invalid.add(out)
continue
# If some inputs are invalid with equivalent,
# an equivalent out should be built and returned
all_inputs = [inp for (inp, is_valid) in inputs]
equiv_inputs = [inp for (inp, is_valid) in inputs if not is_valid]
if equiv_inputs:
cloned_node = out.owner.clone_with_new_inputs(all_inputs)
cloned_out = cloned_node.outputs[out.index]
self.invalid.add(out)
self.valid.add(cloned_out)
self.valid_equivalent[out] = cloned_out
return cloned_out, False
else:
# This is an input node and it has not been explicitly marked
# as invalid so we can use it
return out, True
continue
# All inputs are valid, so is out
self.valid.add(out)
# Recurse over inputs
inputs = [self.check(i) for i in out.owner.inputs]
# If some inputs are invalid without equivalent, so is out
if None in inputs:
self.invalid.add(out)
return None
# If some inputs are invalid with equivalent,
# an equivalent out should be built and returned
all_inputs = [inp for (inp, is_valid) in inputs]
equiv_inputs = [inp for (inp, is_valid) in inputs if not is_valid]
if equiv_inputs:
cloned_node = out.owner.clone_with_new_inputs(all_inputs)
cloned_out = cloned_node.outputs[out.index]
self.invalid.add(out)
self.valid.add(cloned_out)
self.valid_equivalent[out] = cloned_out
return cloned_out, False
# All inputs are valid, so is out
return out, True
return get_value(out)
def scan_can_remove_outs(op, out_idxs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论