提交 b096f1ff authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix the CUDA gemm optimizer following a change in how the gemm optimizer works.

上级 5fde8166
...@@ -148,10 +148,10 @@ def local_gpu_gemm(node): ...@@ -148,10 +148,10 @@ def local_gpu_gemm(node):
""" """
if node.op == gpu_from_host: if node.op == gpu_from_host:
host_input = node.inputs[0] host_input = node.inputs[0]
if host_input.owner and host_input.owner.op == tensor.blas.gemm: if host_input.owner and host_input.owner.op == tensor.blas.gemm_inplace:
z, a, x, y, b = host_input.owner.inputs z, a, x, y, b = host_input.owner.inputs
return [gpu_gemm(gpu_from_host(z), a, gpu_from_host(x), gpu_from_host(y), b)] return [gpu_gemm(gpu_from_host(z), a, gpu_from_host(x), gpu_from_host(y), b)]
if node.op == tensor.blas.gemm: if node.op == tensor.blas.gemm_inplace:
z, a, x, y, b = node.inputs z, a, x, y, b = node.inputs
x_on_gpu = (x.owner and x.owner.op == host_from_gpu) x_on_gpu = (x.owner and x.owner.op == host_from_gpu)
y_on_gpu = (y.owner and y.owner.op == host_from_gpu) y_on_gpu = (y.owner and y.owner.op == host_from_gpu)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论