提交 2d90d1f8 authored 作者: Roy Xue's avatar Roy Xue

Move at compute_map creation to speed things up

上级 023876a0
...@@ -774,6 +774,10 @@ class ProfileStats(object): ...@@ -774,6 +774,10 @@ class ProfileStats(object):
compute_map = defaultdict(lambda: [0]) compute_map = defaultdict(lambda: [0])
for var in fgraph.inputs: for var in fgraph.inputs:
compute_map[var][0] = 1 compute_map[var][0] = 1
for var in node_list:
for val in var.inputs:
if isinstance(val, graph.Constant):
compute_map[val][0] = 1
def check_node_state(node): def check_node_state(node):
""" """
...@@ -784,10 +788,6 @@ class ProfileStats(object): ...@@ -784,10 +788,6 @@ class ProfileStats(object):
inputs = node.inputs inputs = node.inputs
outputs = node.outputs outputs = node.outputs
deps = inputs + node.destroy_dependencies deps = inputs + node.destroy_dependencies
# TODO: Move at compute_map creation to speed things up.
for node in inputs:
if isinstance(node, graph.Constant):
compute_map[node][0] = 1
computed_ins = all(compute_map[v][0] for v in deps) computed_ins = all(compute_map[v][0] for v in deps)
return computed_ins return computed_ins
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论