提交 abcfa04e authored 作者: Frederic Bastien's avatar Frederic Bastien

Make one scan function not recurse and have the last line put more node in the…

Make one scan function not recurse and have the last line put more node in the self.valid. So this will help investigate less the graph.
上级 9835fbdb
...@@ -855,50 +855,76 @@ class Validator(object): ...@@ -855,50 +855,76 @@ class Validator(object):
If out is not valid and has no equivalent, None is returned. If out is not valid and has no equivalent, None is returned.
""" """
if out in self.valid:
return out, True def get_value(out):
elif out in self.valid_equivalent: if out in self.valid:
return self.valid_equivalent[out], False return out, True
elif out in self.invalid: elif out in self.valid_equivalent:
return None return self.valid_equivalent[out], False
elif out in self.invalid:
if out.owner is None: return None
if isinstance(out, tensor.TensorConstant): else:
# This might be a constant from the outer graph or a constant raise RuntimeError("This should not happen")
# from the inner graph. In all cases, we can clone it to be
# certain we have a valid constant q = [out]
cloned_out = out.clone() while q:
self.valid.add(cloned_out) out = q.pop()
if out in self.valid:
continue
elif out in self.invalid:
continue
if out.owner is None:
if isinstance(out, tensor.TensorConstant):
# This might be a constant from the outer graph or a constant
# from the inner graph. In all cases, we can clone it to be
# certain we have a valid constant
# TODO: FRED I think the clone is not needed.
cloned_out = out.clone()
self.valid.add(cloned_out)
self.invalid.add(out)
self.valid_equivalent[out] = cloned_out
continue
else:
# This is an input node and it has not been explicitly marked
# as invalid so we can use it
self.valid.add(out)
continue
# Process the input if needed
continue_while = False
for inp in out.owner.inputs:
if inp not in self.valid and inp not in self.invalid:
q.append(out)
q.extend(out.owner.inputs)
continue_while = True
break
if continue_while:
continue
inputs = [get_value(i) for i in out.owner.inputs]
# If some inputs are invalid without equivalent, so is out
if None in inputs:
self.invalid.add(out)
continue
# If some inputs are invalid with equivalent,
# an equivalent out should be built and returned
all_inputs = [inp for (inp, is_valid) in inputs]
equiv_inputs = [inp for (inp, is_valid) in inputs if not is_valid]
if equiv_inputs:
cloned_node = out.owner.clone_with_new_inputs(all_inputs)
cloned_out = cloned_node.outputs[out.index]
self.invalid.add(out) self.invalid.add(out)
self.valid.add(cloned_out)
self.valid_equivalent[out] = cloned_out self.valid_equivalent[out] = cloned_out
return cloned_out, False continue
else:
# This is an input node and it has not been explicitly marked # All inputs are valid, so is out
# as invalid so we can use it self.valid.add(out)
return out, True
# Recurse over inputs return get_value(out)
inputs = [self.check(i) for i in out.owner.inputs]
# If some inputs are invalid without equivalent, so is out
if None in inputs:
self.invalid.add(out)
return None
# If some inputs are invalid with equivalent,
# an equivalent out should be built and returned
all_inputs = [inp for (inp, is_valid) in inputs]
equiv_inputs = [inp for (inp, is_valid) in inputs if not is_valid]
if equiv_inputs:
cloned_node = out.owner.clone_with_new_inputs(all_inputs)
cloned_out = cloned_node.outputs[out.index]
self.invalid.add(out)
self.valid.add(cloned_out)
self.valid_equivalent[out] = cloned_out
return cloned_out, False
# All inputs are valid, so is out
return out, True
def scan_can_remove_outs(op, out_idxs): def scan_can_remove_outs(op, out_idxs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论