提交 2a4ef8da authored 作者: Caglar's avatar Caglar

added lamblin's changes + flake8

上级 4dffed73
......@@ -3,6 +3,7 @@ from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda import GpuOp
from theano.sandbox.cuda.basic_ops import as_cuda_ndarray_variable
from theano.sandbox.cuda import cuda_ndarray
dimshuffle = cuda_ndarray.cuda_ndarray.dimshuffle
......@@ -19,19 +20,15 @@ class GpuSolve(GpuOp):
"""
CULA GPU solver OP.
trans: Whether to take the transpose of the input matrix
:param trans: Whether to take the transpose of the input matrix
or not.
"""
# Names of the attributes that identify this Op.  This has to be a real
# tuple: ``('trans')`` without the trailing comma is just the string
# 'trans', and iterating over it yields single characters instead of the
# attribute name.
__props__ = ('trans',)
def __init__(self, trans='N'):
    """Build the solver Op.

    :param trans: Whether to take the transpose of the input matrix
        or not. Stored unchanged on ``self.trans``.
    """
    # Record the flag first, then let the parent class initialise itself.
    self.trans = trans
    super(GpuSolve, self).__init__()
def __eq__(self, other):
    """Return True when *other* has the same type and the same ``trans``.

    ``trans`` is forwarded to the solver call, so two instances that
    differ only in ``trans`` must not be treated as interchangeable.
    """
    return type(other) is type(self) and other.trans == self.trans

def __hash__(self):
    """Hash on the type and ``trans`` so the hash agrees with ``__eq__``."""
    return hash((type(self), self.trans))
def output_type(self, inp):
    """Return a ``CudaNdarrayType`` with as many dimensions as *inp*.

    Every dimension of the returned type is flagged as non-broadcastable.
    """
    ndim = inp.type.ndim
    no_broadcast = [False] * ndim
    return CudaNdarrayType(broadcastable=no_broadcast)
......@@ -39,8 +36,6 @@ class GpuSolve(GpuOp):
inp1 = as_cuda_ndarray_variable(inp1)
inp2 = as_cuda_ndarray_variable(inp2)
assert inp1.dtype == "float32"
assert inp2.dtype == "float32"
assert inp1.ndim == 2
assert inp2.ndim == 2
return theano.Apply(self, [inp1, inp2], [self.output_type(inp1)()])
......@@ -49,7 +44,6 @@ class GpuSolve(GpuOp):
node,
storage_map, _,
no_recycling=[]):
from theano.misc.pycuda_utils import to_gpuarray
# Initialize CULA the first time it is needed
global cula_initialized
......@@ -85,9 +79,6 @@ class GpuSolve(GpuOp):
A_cpy = A.copy()
b_cpy = b_cpy.copy()
A_pycuda = to_gpuarray(A_cpy)
b_pycuda = to_gpuarray(b_cpy)
def cula_gpu_solve(A_, b_, trans='T'):
A_shape = A_.shape
......@@ -120,9 +111,9 @@ class GpuSolve(GpuOp):
cula.culaDeviceSgels(trans, n, l, m, A_ptr, lda, b_ptr, ldb)
return A_, b_
A_pycuda, b_pycuda = cula_gpu_solve(A_pycuda, b_pycuda, trans)
A_pycuda, b_pycuda = cula_gpu_solve(A_cpy, b_cpy, trans)
#Convert b to F-order from c-order and assign it to output:
# Convert b to F-order from c-order and assign it to output:
b_cpy = b_cpy.reshape(b.shape[::-1])
b_cpy = dimshuffle(b_cpy, (1, 0))
z[0] = b_cpy
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论