提交 2d90d1f8 authored 作者: Roy Xue's avatar Roy Xue

Move at compute_map creation to speed things up

上级 023876a0
......@@ -774,6 +774,10 @@ class ProfileStats(object):
compute_map = defaultdict(lambda: [0])
for var in fgraph.inputs:
compute_map[var][0] = 1
for var in node_list:
for val in var.inputs:
if isinstance(val, graph.Constant):
compute_map[val][0] = 1
def check_node_state(node):
"""
......@@ -784,10 +788,6 @@ class ProfileStats(object):
inputs = node.inputs
outputs = node.outputs
deps = inputs + node.destroy_dependencies
# TODO: Move at compute_map creation to speed things up.
for node in inputs:
if isinstance(node, graph.Constant):
compute_map[node][0] = 1
computed_ins = all(compute_map[v][0] for v in deps)
return computed_ins
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论