提交 9f36d45a authored 作者: Caglar's avatar Caglar

fixed fortran order outputs.

上级 7c95b025
......@@ -55,7 +55,6 @@ class GpuSolve(GpuOp):
node,
storage_map, _,
no_recycling=[]):
from theano.misc.pycuda_utils import to_gpuarray
inputs = [storage_map[v] for v in node.inputs]
......@@ -73,11 +72,14 @@ class GpuSolve(GpuOp):
#Solution vectors
b = inputs[1][0]
b = cuda_ndarray.dimshuffle(b, 1, 0)
A_cpy = A.copy()
b_cpy = b.copy()
A_pycuda = to_gpuarray(A)
b_pycuda = to_gpuarray(b)
#Convert b to F-order from c-order.
b_cpy = b_cpy.dimshuffle(1, 0).reshape((b.shape[0], b.shape[1]))
A_pycuda = to_gpuarray(A_cpy)
b_pycuda = to_gpuarray(b_cpy)
def cula_gpu_solve(A_, b_, trans='T'):
......@@ -90,7 +92,7 @@ class GpuSolve(GpuOp):
if trans in ['T', 'C']:
l, n = A_shape
k, m = b_shape
if n != m:
if n != k:
raise ValueError('A and b must be aligned.')
elif trans in ['N']:
n, l = A_shape
......@@ -110,10 +112,12 @@ class GpuSolve(GpuOp):
b_ptr = b_.gpudata
cula.culaDeviceSgels(trans, n, l, m, A_ptr, lda, b_ptr, ldb)
return A, b
return A_, b_
A_pycuda, b_pycuda = cula_gpu_solve(A_pycuda, b_pycuda, self.trans)
z[0] = b
#Convert b to F-order from c-order and assign it to output:
z[0] = b_cpy.reshape((b.shape[0], b.shape[1])).dimshuffle(1, 0)
thunk.inputs = inputs
thunk.outputs = outputs
......
......@@ -28,7 +28,6 @@ class TestCula(unittest.TestCase):
def run_gpu_solve(self, A_val, x_val):
b_val = numpy.dot(A_val, x_val)
b_val = b_val.T.reshape((b_val.shape[0], b_val.shape[1]))
A = theano.tensor.matrix("A", dtype="float32")
b = theano.tensor.matrix("b", dtype="float32")
......@@ -36,7 +35,6 @@ class TestCula(unittest.TestCase):
fn = theano.function([A, b], [solver])
res = fn(A_val, b_val)
x_res = numpy.array(res[0])
x_res = x_res.reshape((x_res.shape[1], x_res.shape[0])).T
utt.assert_allclose(x_res, x_val)
def test_diag_solve(self):
......
......@@ -555,7 +555,7 @@ def test_local_gpu_solve():
assert numpy.allclose(numpy.dot(a0, out), b0)
cmp((6, 6), (6, 1))
cmp((5, 5), (5, 3))
cmp((5, 5), (5, 1))
def test_local_gpu_dot_to_dot22dot():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论