提交 7c4c87f8 authored 作者: Thomas George's avatar Thomas George

Fix previous commit

Added test for case for b
上级 5622fc2a
......@@ -61,7 +61,7 @@ class GpuCusolverSolve(Op):
self, [inp1, inp2],
[GpuArrayType('float32',
broadcastable=inp1.broadcastable,
context_name=self.context)()])
context_name=context_name)()])
def prepare_node(self, node, storage_map, compute_map, impl):
ctx = node.inputs[0].type.context
......@@ -118,7 +118,7 @@ class GpuCusolverSolve(Op):
with context:
workspace_size = cusolver.cusolverDnSgetrf_bufferSize(
cusolver_handle, n, n, A_ptr, lda)
context.cusolver_handle, n, n, A_ptr, lda)
workspace = pygpu.zeros(workspace_size, dtype='float32',
context=context)
......@@ -127,17 +127,17 @@ class GpuCusolverSolve(Op):
dev_info = pygpu.zeros((1,), dtype='int32', context=context)
workspace_ptr = thunk.workspace.gpudata
pivots_ptr = thunk.pivots.gpudata
dev_info_ptr = thunk.dev_info.gpudata
workspace_ptr = workspace.gpudata
pivots_ptr = pivots.gpudata
dev_info_ptr = dev_info.gpudata
with context:
cusolver.cusolverDnSgetrf(
cusolver_handle, n, n, A_ptr, lda, workspace_ptr,
context.cusolver_handle, n, n, A_ptr, lda, workspace_ptr,
pivots_ptr, dev_info_ptr)
cusolver.cusolverDnSgetrs(
cusolver_handle, trans, n, m, A_ptr, lda,
context.cusolver_handle, trans, n, m, A_ptr, lda,
pivots_ptr, b_ptr, ldb, dev_info_ptr)
z[0] = b
......
......@@ -41,6 +41,17 @@ class TestCusolver(unittest.TestCase):
1)).astype("float32")
self.run_gpu_solve(A_val, x_val)
def test_bshape_solve(self):
"""
Test when shape of b (k, m) is such as m > k
"""
numpy.random.seed(1)
A_val = numpy.asarray([[2, 0, 0], [0, 1, 0], [0, 0, 1]],
dtype="float32")
x_val = numpy.random.uniform(-0.4, 0.4, (A_val.shape[1],
A_val.shape[1] + 1)).astype("float32")
self.run_gpu_solve(A_val, x_val)
def test_sym_solve(self):
numpy.random.seed(1)
A_val = numpy.random.uniform(-0.4, 0.4, (5, 5)).astype("float32")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论