提交 724d0d32 authored 作者: Frederic's avatar Frederic

Do not move scalar elemwise to the gpu.

上级 5e97ac68
...@@ -170,6 +170,8 @@ def local_gpuflatten(node): ...@@ -170,6 +170,8 @@ def local_gpuflatten(node):
def local_gpu_elemwise(node): def local_gpu_elemwise(node):
op = node.op op = node.op
name = op.name name = op.name
if node.outputs[0].ndim == 0:
return
if name: if name:
name = 'Gpu'+name name = 'Gpu'+name
res = GpuElemwise(op.scalar_op, name=name, res = GpuElemwise(op.scalar_op, name=name,
...@@ -242,11 +244,13 @@ def local_gpua_careduce(node): ...@@ -242,11 +244,13 @@ def local_gpua_careduce(node):
def local_gpua_gemv(node): def local_gpua_gemv(node):
return GpuGemv(inplace=node.op.inplace) return GpuGemv(inplace=node.op.inplace)
@register_opt() @register_opt()
@op_lifter([tensor.blas_c.CGemv]) @op_lifter([tensor.blas_c.CGemv])
def local_gpua_gemv2(node): def local_gpua_gemv2(node):
return GpuGemv(inplace=node.op.inplace) return GpuGemv(inplace=node.op.inplace)
@register_opt() @register_opt()
@op_lifter([tensor.blas.Gemm]) @op_lifter([tensor.blas.Gemm])
def local_gpua_gemm(node): def local_gpua_gemm(node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论