提交 2b09eb72 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2916 from carriepl/fix_scan_while_on_gpuarray

In local_scan.to_gpua, don't transfer as_while condition on gpu
...@@ -713,7 +713,14 @@ def local_scan_to_gpua(node): ...@@ -713,7 +713,14 @@ def local_scan_to_gpua(node):
nw_ins += node.inputs[b:e] nw_ins += node.inputs[b:e]
nw_ins += [safe_to_gpu(x) for x in node.inputs[e:]] nw_ins += [safe_to_gpu(x) for x in node.inputs[e:]]
scan_ins = [tensor_to_gpu(x) for x in node.op.inputs] scan_ins = [tensor_to_gpu(x) for x in node.op.inputs]
scan_outs = [safe_to_gpu(x) for x in node.op.outputs]
# The inner output corresponding to the looping condition should not be
# moved to the gpu
if node.op.info['as_while']:
scan_outs = [safe_to_gpu(x) for x in node.op.outputs[:-1]]
scan_outs += [node.op.outputs[-1]]
else:
scan_outs = [safe_to_gpu(x) for x in node.op.outputs]
scan_outs = scan_utils.clone( scan_outs = scan_utils.clone(
scan_outs, scan_outs,
replace=zip(node.op.inputs, replace=zip(node.op.inputs,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论