提交 27a1361d authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix warning and add dtype check in ops that only support float32

上级 99363898
...@@ -261,7 +261,7 @@ class GpuCholesky(Op): ...@@ -261,7 +261,7 @@ class GpuCholesky(Op):
raise RuntimeError('CUSOLVER is not available and ' raise RuntimeError('CUSOLVER is not available and '
'GpuCholesky Op can not be constructed.') 'GpuCholesky Op can not be constructed.')
if skcuda.__version__ <= '0.5.1': if skcuda.__version__ <= '0.5.1':
warnings.warn('The GpuSolve op requires scikit-cuda > 0.5.1 to work with CUDA 8') warnings.warn('The GpuCholesky op requires scikit-cuda > 0.5.1 to work with CUDA 8')
if not pygpu_available: if not pygpu_available:
raise RuntimeError('Missing pygpu or triu/tril functions.' raise RuntimeError('Missing pygpu or triu/tril functions.'
'Install or update libgpuarray.') 'Install or update libgpuarray.')
...@@ -382,6 +382,7 @@ class GpuMagmaSVD(COp): ...@@ -382,6 +382,7 @@ class GpuMagmaSVD(COp):
A = as_gpuarray_variable(A, ctx_name) A = as_gpuarray_variable(A, ctx_name)
if A.ndim != 2: if A.ndim != 2:
raise LinAlgError("Matrix rank error") raise LinAlgError("Matrix rank error")
assert A.dtype == 'float32'
if self.compute_uv: if self.compute_uv:
return theano.Apply(self, [A], return theano.Apply(self, [A],
[A.type(), [A.type(),
...@@ -476,6 +477,7 @@ class GpuMagmaMatrixInverse(COp): ...@@ -476,6 +477,7 @@ class GpuMagmaMatrixInverse(COp):
def make_node(self, x): def make_node(self, x):
ctx_name = infer_context_name(x) ctx_name = infer_context_name(x)
x = as_gpuarray_variable(x, ctx_name) x = as_gpuarray_variable(x, ctx_name)
assert x.dtype == 'float32'
if x.ndim != 2: if x.ndim != 2:
raise LinAlgError("Matrix rank error") raise LinAlgError("Matrix rank error")
return theano.Apply(self, [x], [x.type()]) return theano.Apply(self, [x], [x.type()])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论