提交 e6b9c271 authored 作者: sentient07's avatar sentient07

Printing the nodes that aren't transferred

上级 658563c8
...@@ -62,6 +62,9 @@ gpu_optimizer = EquilibriumDB() ...@@ -62,6 +62,9 @@ gpu_optimizer = EquilibriumDB()
gpu_optimizer2 = EquilibriumDB() gpu_optimizer2 = EquilibriumDB()
gpu_cut_copies = EquilibriumDB() gpu_cut_copies = EquilibriumDB()
old_not_transferred = []
new_not_transferred = []
class GraphToGPUDB(DB): class GraphToGPUDB(DB):
""" """
...@@ -191,6 +194,8 @@ def op_lifter(OP, cuda_only=False): ...@@ -191,6 +194,8 @@ def op_lifter(OP, cuda_only=False):
i.tag.context_name = context_name i.tag.context_name = context_name
new_op = maker(node.op, context_name, node.inputs, node.outputs) new_op = maker(node.op, context_name, node.inputs, node.outputs)
if not new_op:
old_not_transferred.append(node)
# This is needed as sometimes new_op inherits from OP. # This is needed as sometimes new_op inherits from OP.
if new_op and new_op != node.op: if new_op and new_op != node.op:
if isinstance(new_op, theano.Op): if isinstance(new_op, theano.Op):
...@@ -342,6 +347,8 @@ class GraphToGPU(NavigatorOptimizer): ...@@ -342,6 +347,8 @@ class GraphToGPU(NavigatorOptimizer):
if not new_ops: if not new_ops:
newnode = node.clone_with_new_inputs([mapping.get(i) newnode = node.clone_with_new_inputs([mapping.get(i)
for i in node.inputs]) for i in node.inputs])
new_not_transferred.append(newnode)
outputs = newnode.outputs outputs = newnode.outputs
elif isinstance(new_ops, (tuple, list)): elif isinstance(new_ops, (tuple, list)):
outputs = [] outputs = []
...@@ -406,6 +413,13 @@ class GraphToGPU(NavigatorOptimizer): ...@@ -406,6 +413,13 @@ class GraphToGPU(NavigatorOptimizer):
not_used = [] not_used = []
not_used_time = 0 not_used_time = 0
for s in list(set(old_not_transferred)):
print(blanc, 'Nodes not transferred by old opt : ' + str(s), file=stream)
for n in list(set(new_not_transferred)):
print(blanc, 'Nodes not transferred by new optimizer : ' +str(n), file=stream)
for d in list(set(set(new_not_transferred) - set(old_not_transferred))):
print(blanc, 'Not transferred difference : ' , str(d), file=stream)
for o, count in iteritems(process_count): for o, count in iteritems(process_count):
if count > 0: if count > 0:
count_opt.append((time_opts[o], count, count_opt.append((time_opts[o], count,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论