提交 a4ba5401 authored 作者: Roy Xue's avatar Roy Xue

modify viewed_by check part

上级 d9c3b034
...@@ -675,9 +675,8 @@ class ProfileStats(object): ...@@ -675,9 +675,8 @@ class ProfileStats(object):
viewed_by = {}# {var1: [vars that view var1]} viewed_by = {}# {var1: [vars that view var1]}
# The len of the list is the value of python ref count. But we use a list, not just the ref count value. # The len of the list is the value of python ref count. But we use a list, not just the ref count value.
# This is more safe to help detect potential bug in the algo # This is more safe to help detect potential bug in the algo
for node in fgraph.apply_nodes: for var in fgraph.variables:
for var in node.inputs: viewed_by[var] = []
viewed_by[var] = []
view_of = {} # {var1: original var viewed by var1} view_of = {} # {var1: original var viewed by var1}
# The orignal mean that we don't keep trac of all the intermediate relationship in the view. # The orignal mean that we don't keep trac of all the intermediate relationship in the view.
...@@ -715,7 +714,7 @@ class ProfileStats(object): ...@@ -715,7 +714,7 @@ class ProfileStats(object):
# Mimic the combination of Theano and Python gc # Mimic the combination of Theano and Python gc
for ins in node.inputs: for ins in node.inputs:
assert not (ins in view_of and ins in viewed_by) assert not (ins in view_of and viewed_by[ins])
# we keep trac of the original var, so this shouldn't happen # we keep trac of the original var, so this shouldn't happen
if dependencies[ins] and ins not in fgraph.outputs and ins.owner: if dependencies[ins] and ins not in fgraph.outputs and ins.owner:
if ins not in view_of and not viewed_by.get(ins, []): if ins not in view_of and not viewed_by.get(ins, []):
...@@ -774,9 +773,8 @@ class ProfileStats(object): ...@@ -774,9 +773,8 @@ class ProfileStats(object):
viewed_by = {}# {var1: [vars that view var1]} viewed_by = {}# {var1: [vars that view var1]}
# The len of the list is the value of python ref count. But we use a list, not just the ref count value. # The len of the list is the value of python ref count. But we use a list, not just the ref count value.
# This is more safe to help detect potential bug in the algo # This is more safe to help detect potential bug in the algo
for node in fgraph.apply_nodes: for var in fgraph.apply_nodes:
for var in node.inputs: viewed_by[var] = []
viewed_by[var] = []
view_of = {}# {var1: original var viewed by var1} view_of = {}# {var1: original var viewed by var1}
# The orignal mean that we don't keep trac of all the intermediate relationship in the view. # The orignal mean that we don't keep trac of all the intermediate relationship in the view.
...@@ -826,7 +824,7 @@ class ProfileStats(object): ...@@ -826,7 +824,7 @@ class ProfileStats(object):
# Mimic the combination of Theano and Python gc. # Mimic the combination of Theano and Python gc.
for ins in node.inputs: for ins in node.inputs:
assert not (ins in view_of and ins in viewed_by) assert not (ins in view_of and viewed_by[ins])
# we keep track of the original var, so this shouldn't happen # we keep track of the original var, so this shouldn't happen
if dependencies[ins] and ins not in fgraph.outputs and ins.owner: if dependencies[ins] and ins not in fgraph.outputs and ins.owner:
if all(compute_map[v] for v in dependencies[ins]): if all(compute_map[v] for v in dependencies[ins]):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论