提交 aa9590f9 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Only tranfer to gpu if dtype is float32

上级 86965f6d
......@@ -404,6 +404,7 @@ def gpu_images2neibs(ten4, neib_shape, neib_step=None, mode='valid'):
@local_optimizer()
def use_gpu_images2neibs(node):
if (type(node.op) is Images2Neibs and
node.inputs[0].dtype == 'float32' and
node.op.mode in ['valid', 'wrap_centered']):
return [host_from_gpu(gpu_images2neibs(gpu_from_host(node.inputs[0]),
node.inputs[1], node.inputs[2],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论