提交 b7b88b1c authored 作者: Frederic's avatar Frederic

Make GpuGemv and GpuGemm work correctly when the back-end doesn't support in-place operation.

Also add dtype asserts.
上级 2590560c
...@@ -28,12 +28,16 @@ class GpuGemv(BlasOp, Gemv): ...@@ -28,12 +28,16 @@ class GpuGemv(BlasOp, Gemv):
A = as_gpuarray_variable(A) A = as_gpuarray_variable(A)
x = as_gpuarray_variable(x) x = as_gpuarray_variable(x)
y = as_gpuarray_variable(y) y = as_gpuarray_variable(y)
assert A.dtype == x.dtype == y.dtype == alpha.dtype == beta.dtype
return Apply(self, [y, alpha, A, x, beta], [y.type()]) return Apply(self, [y, alpha, A, x, beta], [y.type()])
def perform(self, node, inputs, out_storage): def perform(self, node, inputs, out_storage):
y, alpha, A, x, beta = inputs y, alpha, A, x, beta = inputs
out_storage[0][0] = blas.gemv(alpha, A, x, beta, y, trans=False, inplace = self.inplace
overwrite_y=self.inplace) if inplace and y.strides[0] < 0:
inplace = False
out_storage[0][0] = blas.gemv(alpha, A, x, beta, y,
overwrite_y=inplace)
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
vars = dict(out=out[0], y=inp[0], alpha=inp[1], A=inp[2], x=inp[3], vars = dict(out=out[0], y=inp[0], alpha=inp[1], A=inp[2], x=inp[3],
...@@ -64,7 +68,7 @@ class GpuGemv(BlasOp, Gemv): ...@@ -64,7 +68,7 @@ class GpuGemv(BlasOp, Gemv):
if config.gpuarray.sync: if config.gpuarray.sync:
code += """ code += """
GpuArray_sync(&%(out)s->ga); GpuArray_sync(&%(out)s->ga);
""" """ % vars
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -80,12 +84,16 @@ class GpuGemm(BlasOp, Gemm): ...@@ -80,12 +84,16 @@ class GpuGemm(BlasOp, Gemm):
A = as_gpuarray_variable(A) A = as_gpuarray_variable(A)
B = as_gpuarray_variable(B) B = as_gpuarray_variable(B)
C = as_gpuarray_variable(C) C = as_gpuarray_variable(C)
assert A.dtype == B.dtype == C.dtype == alpha.dtype == beta.dtype
return Apply(self, [C, alpha, A, B, beta], [C.type()]) return Apply(self, [C, alpha, A, B, beta], [C.type()])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
C, alpha, A, B, beta = inputs C, alpha, A, B, beta = inputs
inplace = self.inplace
if inplace and not C.flags.forc:
inplace = False
outputs[0][0] = blas.gemm(alpha, A, B, beta, C, outputs[0][0] = blas.gemm(alpha, A, B, beta, C,
overwrite_c=self.inplace) overwrite_c=inplace)
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
vars = dict(out=out[0], C=inp[0], alpha=inp[1], A=inp[2], B=inp[3], vars = dict(out=out[0], C=inp[0], alpha=inp[1], A=inp[2], B=inp[3],
...@@ -116,7 +124,7 @@ class GpuGemm(BlasOp, Gemm): ...@@ -116,7 +124,7 @@ class GpuGemm(BlasOp, Gemm):
if config.gpuarray.sync: if config.gpuarray.sync:
code += """ code += """
GpuArray_sync(&%(out)s->ga); GpuArray_sync(&%(out)s->ga);
""" """ % vars
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论