提交 b7b88b1c authored 作者: Frederic's avatar Frederic

Make gpugemv and gpugemm work correctly when the back-end don't support inplace.

add assert too.
上级 2590560c
......@@ -28,12 +28,16 @@ class GpuGemv(BlasOp, Gemv):
A = as_gpuarray_variable(A)
x = as_gpuarray_variable(x)
y = as_gpuarray_variable(y)
assert A.dtype == x.dtype == y.dtype == alpha.dtype == beta.dtype
return Apply(self, [y, alpha, A, x, beta], [y.type()])
def perform(self, node, inputs, out_storage):
y, alpha, A, x, beta = inputs
out_storage[0][0] = blas.gemv(alpha, A, x, beta, y, trans=False,
overwrite_y=self.inplace)
inplace = self.inplace
if inplace and y.strides[0] < 0:
inplace = False
out_storage[0][0] = blas.gemv(alpha, A, x, beta, y,
overwrite_y=inplace)
def c_code(self, node, name, inp, out, sub):
vars = dict(out=out[0], y=inp[0], alpha=inp[1], A=inp[2], x=inp[3],
......@@ -64,7 +68,7 @@ class GpuGemv(BlasOp, Gemv):
if config.gpuarray.sync:
code += """
GpuArray_sync(&%(out)s->ga);
"""
""" % vars
return code
def c_code_cache_version(self):
......@@ -80,12 +84,16 @@ class GpuGemm(BlasOp, Gemm):
A = as_gpuarray_variable(A)
B = as_gpuarray_variable(B)
C = as_gpuarray_variable(C)
assert A.dtype == B.dtype == C.dtype == alpha.dtype == beta.dtype
return Apply(self, [C, alpha, A, B, beta], [C.type()])
def perform(self, node, inputs, outputs):
C, alpha, A, B, beta = inputs
inplace = self.inplace
if inplace and not C.flags.forc:
inplace = False
outputs[0][0] = blas.gemm(alpha, A, B, beta, C,
overwrite_c=self.inplace)
overwrite_c=inplace)
def c_code(self, node, name, inp, out, sub):
vars = dict(out=out[0], C=inp[0], alpha=inp[1], A=inp[2], B=inp[3],
......@@ -116,7 +124,7 @@ class GpuGemm(BlasOp, Gemm):
if config.gpuarray.sync:
code += """
GpuArray_sync(&%(out)s->ga);
"""
""" % vars
return code
def c_code_cache_version(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论