提交 3c297ecc authored 作者: Frederic's avatar Frederic

make op_lifter work for Op with multiple output.

上级 4c62e54f
......@@ -60,7 +60,6 @@ def op_lifter(OP):
def local_opt(node):
if type(node.op) in OP:
# This does not support nodes that have more than one output.
assert len(node.outputs) == 1
# either one of our inputs is on the gpu or
# all of our client are on the gpu
if (any([i.owner and i.owner.op == host_from_gpu
......@@ -71,7 +70,9 @@ def op_lifter(OP):
# This is needed as sometimes new_op inherit from OP.
if new_op and new_op != node.op:
if isinstance(new_op, theano.Op):
return [host_from_gpu(new_op(*node.inputs))]
return [host_from_gpu(o) for o in new_op(*node.inputs, return_list=True)]
elif isinstance(new_op, (tuple, list)):
return [host_from_gpu(o) for o in new_op]
else: # suppose it is a variable on the GPU
return [host_from_gpu(new_op)]
return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论