提交 94baaa7b authored 作者: James Bergstra's avatar James Bergstra

finished adding stuff to inplace gemm for cuda

上级 18c4e3a5
...@@ -149,9 +149,6 @@ class GpuGemm(Op): ...@@ -149,9 +149,6 @@ class GpuGemm(Op):
""" """
implement the gemm on the gpu. implement the gemm on the gpu.
..note: This probably don't work correctly for no_inplace gemm.
Need to check al least refcount.
""" """
def __init__(self, inplace): def __init__(self, inplace):
self.inplace = inplace self.inplace = inplace
...@@ -215,8 +212,8 @@ class GpuGemm(Op): ...@@ -215,8 +212,8 @@ class GpuGemm(Op):
) )
{ {
Py_XDECREF(%(z_out)s); Py_XDECREF(%(z_out)s);
%(z_out)s = CudaNdarray_Copy(%(z_in)s); %(z_out)s = (CudaNdarray*)CudaNdarray_Copy(%(z_in)s);
if (!(z_out)s) if (!%(z_out)s)
{ {
%(fail)s; %(fail)s;
} }
...@@ -236,7 +233,9 @@ class GpuGemm(Op): ...@@ -236,7 +233,9 @@ class GpuGemm(Op):
{ {
%(fail)s; %(fail)s;
} }
""" % locals() """
return sio.getvalue() % locals()
gpu_gemm_no_inplace = GpuGemm(inplace=False) gpu_gemm_no_inplace = GpuGemm(inplace=False)
gpu_gemm_inplace = GpuGemm(inplace=True) gpu_gemm_inplace = GpuGemm(inplace=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论