提交 86be5809 作者: Thomas George

This fixes a bug in the cuSOLVER solve op (GpuCusolverSolve) when m > k.

上级 09318e0d
......@@ -7,7 +7,7 @@ from theano import Op
from theano.gpuarray import basic_ops, GpuArrayType
try:
from pygpu import gpuarray
import pygpu
except ImportError:
pass
......@@ -109,12 +109,12 @@ class GpuCusolverSolve(Op):
raise ValueError('A and b must be aligned.')
lda = max(1, n)
ldb = max(1, k, m)
ldb = max(1, k)
# We copy A and b as cusolver operates inplace
b = gpuarray.array(b, copy=True, order='F')
b = pygpu.array(b, copy=True, order='F')
if not self.inplace:
A = gpuarray.array(A, copy=True)
A = pygpu.array(A, copy=True)
A_ptr = A.gpudata
b_ptr = b.gpudata
......@@ -129,19 +129,19 @@ class GpuCusolverSolve(Op):
if (thunk.workspace is None or
thunk.workspace.size != workspace_size):
thunk.workspace = gpuarray.zeros((workspace_size,),
dtype='float32',
context=context)
if thunk.pivots is None or thunk.pivots.size != min(n, n):
thunk.pivots = gpuarray.zeros((min(n, n),),
thunk.workspace = pygpu.zeros(workspace_size,
dtype='float32',
context=context)
if thunk.pivots is None or thunk.pivots.size != min(n, n):
thunk.pivots = pygpu.zeros(n,
dtype='int32',
context=context)
if thunk.dev_info is None:
thunk.dev_info = gpuarray.zeros((1,),
dtype='float32',
context=context)
thunk.dev_info = pygpu.zeros((1,),
dtype='int32',
context=context)
workspace_ptr = thunk.workspace.gpudata
pivots_ptr = thunk.pivots.gpudata
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论