提交 b31fda26 authored 作者: Frederic's avatar Frederic

check all outputs clients.

上级 67448f81
...@@ -369,10 +369,12 @@ def local_gpu_lazy_ifelse(node): ...@@ -369,10 +369,12 @@ def local_gpu_lazy_ifelse(node):
""" """
if isinstance(node.op, theano.ifelse.IfElse) and not node.op.gpu: if isinstance(node.op, theano.ifelse.IfElse) and not node.op.gpu:
gpu_ifelse = theano.ifelse.IfElse(node.op.n_outs, gpu=True) gpu_ifelse = theano.ifelse.IfElse(node.op.n_outs, gpu=True)
outs_clients = reduce(list.__add__,
[out.clients for out in node.outputs])
if numpy.any([(i.owner and i.owner.op == host_from_gpu) if numpy.any([(i.owner and i.owner.op == host_from_gpu)
for i in node.inputs]) or numpy.any( for i in node.inputs]) or numpy.any(
[c != 'output' and c.op == gpu_from_host for c, idx [c != 'output' and c.op == gpu_from_host for c, idx
in node.outputs[0].clients]): in outs_clients]):
c = node.inputs[0] c = node.inputs[0]
outs = node.inputs[1:] outs = node.inputs[1:]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论