提交 ee29c769 authored 作者: Frederic Bastien's avatar Frederic Bastien

make local_gpu_solve only work on float32

上级 7693a0f6
...@@ -700,6 +700,8 @@ def local_gpu_solve(node): ...@@ -700,6 +700,8 @@ def local_gpu_solve(node):
CpuSolve(host_from_gpu) -> host_from_gpu(GpuSolve) CpuSolve(host_from_gpu) -> host_from_gpu(GpuSolve)
""" """
if node.outputs[0].dtype != 'float32':
return
if isinstance(node.op, GpuFromHost): if isinstance(node.op, GpuFromHost):
host_input = node.inputs[0] host_input = node.inputs[0]
if (host_input.owner and if (host_input.owner and
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论