提交 27a1361d authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix warning and add dtype check in ops that only support float32

上级 99363898
......@@ -261,7 +261,7 @@ class GpuCholesky(Op):
raise RuntimeError('CUSOLVER is not available and '
'GpuCholesky Op can not be constructed.')
if skcuda.__version__ <= '0.5.1':
warnings.warn('The GpuSolve op requires scikit-cuda > 0.5.1 to work with CUDA 8')
warnings.warn('The GpuCholesky op requires scikit-cuda > 0.5.1 to work with CUDA 8')
if not pygpu_available:
raise RuntimeError('Missing pygpu or triu/tril functions.'
'Install or update libgpuarray.')
......@@ -382,6 +382,7 @@ class GpuMagmaSVD(COp):
A = as_gpuarray_variable(A, ctx_name)
if A.ndim != 2:
raise LinAlgError("Matrix rank error")
assert A.dtype == 'float32'
if self.compute_uv:
return theano.Apply(self, [A],
[A.type(),
......@@ -476,6 +477,7 @@ class GpuMagmaMatrixInverse(COp):
def make_node(self, x):
ctx_name = infer_context_name(x)
x = as_gpuarray_variable(x, ctx_name)
assert x.dtype == 'float32'
if x.ndim != 2:
raise LinAlgError("Matrix rank error")
return theano.Apply(self, [x], [x.type()])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论