提交 1da248ea authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Update to the blas ops following API change in compyte.

上级 5420caac
...@@ -61,7 +61,7 @@ class GpuGemv(BlasOp, Gemv): ...@@ -61,7 +61,7 @@ class GpuGemv(BlasOp, Gemv):
((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0], ((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0],
%(A)s, %(x)s, %(A)s, %(x)s,
((dtype_%(beta)s *)PyArray_DATA(%(beta)s))[0], ((dtype_%(beta)s *)PyArray_DATA(%(beta)s))[0],
%(out)s) == NULL) { %(out)s, 0) == -1) {
%(fail)s %(fail)s
} }
""" % vars """ % vars
...@@ -72,7 +72,7 @@ class GpuGemv(BlasOp, Gemv): ...@@ -72,7 +72,7 @@ class GpuGemv(BlasOp, Gemv):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (0,) return (1,)
gpugemv_no_inplace = GpuGemv(inplace=False) gpugemv_no_inplace = GpuGemv(inplace=False)
gpugemv_inplace = GpuGemv(inplace=True) gpugemv_inplace = GpuGemv(inplace=True)
...@@ -117,7 +117,7 @@ class GpuGemm(BlasOp, Gemm): ...@@ -117,7 +117,7 @@ class GpuGemm(BlasOp, Gemm):
((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0], ((dtype_%(alpha)s *)PyArray_DATA(%(alpha)s))[0],
%(A)s, %(B)s, %(A)s, %(B)s,
((dtype_%(beta)s *)PyArray_DATA(%(beta)s))[0], ((dtype_%(beta)s *)PyArray_DATA(%(beta)s))[0],
%(out)s) == NULL) { %(out)s, 0) == -1) {
%(fail)s %(fail)s
} }
""" % vars """ % vars
...@@ -128,7 +128,7 @@ class GpuGemm(BlasOp, Gemm): ...@@ -128,7 +128,7 @@ class GpuGemm(BlasOp, Gemm):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (0,) return (1,)
gpugemm_no_inplace = GpuGemm(inplace=False) gpugemm_no_inplace = GpuGemm(inplace=False)
...@@ -176,7 +176,7 @@ class GpuDot22(BlasOp, Dot22): ...@@ -176,7 +176,7 @@ class GpuDot22(BlasOp, Dot22):
one, one,
%(A)s, %(B)s, %(A)s, %(B)s,
zero, zero,
%(out)s) == NULL) { %(out)s, 0) == -1) {
%(fail)s %(fail)s
} }
""" % vars """ % vars
...@@ -187,7 +187,7 @@ class GpuDot22(BlasOp, Dot22): ...@@ -187,7 +187,7 @@ class GpuDot22(BlasOp, Dot22):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (0,) return (1,)
def c_headers(self): def c_headers(self):
ret = super(GpuDot22, self).c_headers() ret = super(GpuDot22, self).c_headers()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论