提交 0582554d authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix crash in GpuGemv.perform with zero-size inputs

上级 5cb343e6
......@@ -73,8 +73,12 @@ class GpuGemv(BlasOp):
inplace = self.inplace
if inplace and y.strides[0] < 0:
inplace = False
out_storage[0][0] = blas.gemv(alpha, A, x, beta, y,
overwrite_y=inplace)
if A.shape[1] == 0:
out_storage[0][0] = pygpu.zeros(y.shape, dtype=y.dtype,
context=y.context)
else:
out_storage[0][0] = blas.gemv(alpha, A, x, beta, y,
overwrite_y=inplace)
def c_code(self, node, name, inp, out, sub):
vars = dict(out=out[0], y=inp[0], alpha=inp[1], A=inp[2], x=inp[3],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论