提交 abcfa04e authored 作者: Frederic Bastien's avatar Frederic Bastien

Make one scan function not recurse and have the last line put more node in the…

Make one scan function not recurse and have the last line put more node in the self.valid. So this will help investigate less the graph.
上级 9835fbdb
...@@ -855,35 +855,59 @@ class Validator(object): ...@@ -855,35 +855,59 @@ class Validator(object):
If out is not valid and has no equivalent, None is returned. If out is not valid and has no equivalent, None is returned.
""" """
def get_value(out):
if out in self.valid: if out in self.valid:
return out, True return out, True
elif out in self.valid_equivalent: elif out in self.valid_equivalent:
return self.valid_equivalent[out], False return self.valid_equivalent[out], False
elif out in self.invalid: elif out in self.invalid:
return None return None
else:
raise RuntimeError("This should not happen")
q = [out]
while q:
out = q.pop()
if out in self.valid:
continue
elif out in self.invalid:
continue
if out.owner is None: if out.owner is None:
if isinstance(out, tensor.TensorConstant): if isinstance(out, tensor.TensorConstant):
# This might be a constant from the outer graph or a constant # This might be a constant from the outer graph or a constant
# from the inner graph. In all cases, we can clone it to be # from the inner graph. In all cases, we can clone it to be
# certain we have a valid constant # certain we have a valid constant
# TODO: FRED I think the clone is not needed.
cloned_out = out.clone() cloned_out = out.clone()
self.valid.add(cloned_out) self.valid.add(cloned_out)
self.invalid.add(out) self.invalid.add(out)
self.valid_equivalent[out] = cloned_out self.valid_equivalent[out] = cloned_out
return cloned_out, False continue
else: else:
# This is an input node and it has not been explicitly marked # This is an input node and it has not been explicitly marked
# as invalid so we can use it # as invalid so we can use it
return out, True self.valid.add(out)
continue
# Recurse over inputs # Process the input if needed
inputs = [self.check(i) for i in out.owner.inputs] continue_while = False
for inp in out.owner.inputs:
if inp not in self.valid and inp not in self.invalid:
q.append(out)
q.extend(out.owner.inputs)
continue_while = True
break
if continue_while:
continue
inputs = [get_value(i) for i in out.owner.inputs]
# If some inputs are invalid without equivalent, so is out # If some inputs are invalid without equivalent, so is out
if None in inputs: if None in inputs:
self.invalid.add(out) self.invalid.add(out)
return None continue
# If some inputs are invalid with equivalent, # If some inputs are invalid with equivalent,
# an equivalent out should be built and returned # an equivalent out should be built and returned
...@@ -895,10 +919,12 @@ class Validator(object): ...@@ -895,10 +919,12 @@ class Validator(object):
self.invalid.add(out) self.invalid.add(out)
self.valid.add(cloned_out) self.valid.add(cloned_out)
self.valid_equivalent[out] = cloned_out self.valid_equivalent[out] = cloned_out
return cloned_out, False continue
# All inputs are valid, so is out # All inputs are valid, so is out
return out, True self.valid.add(out)
return get_value(out)
def scan_can_remove_outs(op, out_idxs): def scan_can_remove_outs(op, out_idxs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论