提交 d2de7531 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Use isinstance() instead of equality.

上级 662ea98e
......@@ -194,7 +194,7 @@ def local_cut_gpu_transfers(node):
# gpu[ab] -> host -> gpub
if (isinstance(node.op, GpuFromHost) and
node.inputs[0].owner and
node.inputs[0].owner.op == host_from_gpu):
isinstance(node.inputs[0].owner.op, HostFromGpu)):
other = node.inputs[0].owner.inputs[0]
if node.op.context_name == other.type.context_name:
return [other]
......@@ -202,7 +202,7 @@ def local_cut_gpu_transfers(node):
return [GpuToGpu(node.op.context_name)(other)]
# ? -> gpua -> host
elif (node.op == host_from_gpu and
elif (isinstance(node.op, HostFromGpu) and
node.inputs[0].owner):
n2 = node.inputs[0].owner
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论