提交 2906e215 authored 作者: Alexander Matyasko's avatar Alexander Matyasko

Throw error if input is not supported

上级 da8c2a12
...@@ -415,7 +415,8 @@ class GpuMagmaSVD(GpuMagmaBase): ...@@ -415,7 +415,8 @@ class GpuMagmaSVD(GpuMagmaBase):
A = gpu_contiguous(A) A = gpu_contiguous(A)
if A.ndim != 2: if A.ndim != 2:
raise LinAlgError("Matrix rank error") raise LinAlgError("Matrix rank error")
assert A.dtype == 'float32' if A.dtype != 'float32':
raise TypeError("only `float32` is supported for now")
if self.compute_uv: if self.compute_uv:
return theano.Apply(self, [A], return theano.Apply(self, [A],
# return S, U, VT # return S, U, VT
...@@ -504,6 +505,8 @@ class GpuMagmaMatrixInverse(GpuMagmaBase): ...@@ -504,6 +505,8 @@ class GpuMagmaMatrixInverse(GpuMagmaBase):
A = gpu_contiguous(A) A = gpu_contiguous(A)
if A.ndim != 2: if A.ndim != 2:
raise LinAlgError("Matrix rank error") raise LinAlgError("Matrix rank error")
if A.dtype != 'float32':
raise TypeError("only `float32` is supported for now")
return theano.Apply(self, [A], [A.type()]) return theano.Apply(self, [A], [A.type()])
def get_params(self, node): def get_params(self, node):
...@@ -548,6 +551,8 @@ class GpuMagmaCholesky(GpuMagmaBase, CGpuKernelBase): ...@@ -548,6 +551,8 @@ class GpuMagmaCholesky(GpuMagmaBase, CGpuKernelBase):
A = gpu_contiguous(A) A = gpu_contiguous(A)
if A.ndim != 2: if A.ndim != 2:
raise LinAlgError("Matrix rank error") raise LinAlgError("Matrix rank error")
if A.dtype != 'float32':
raise TypeError("only `float32` is supported for now")
return theano.Apply(self, [A], [A.type()]) return theano.Apply(self, [A], [A.type()])
def get_op_params(self): def get_op_params(self):
...@@ -583,6 +588,8 @@ class GpuMagmaQR(GpuMagmaBase, CGpuKernelBase): ...@@ -583,6 +588,8 @@ class GpuMagmaQR(GpuMagmaBase, CGpuKernelBase):
A = gpu_contiguous(A) A = gpu_contiguous(A)
if A.ndim != 2: if A.ndim != 2:
raise LinAlgError("Matrix rank error") raise LinAlgError("Matrix rank error")
if A.dtype != 'float32':
raise TypeError("only `float32` is supported for now")
if self.complete: if self.complete:
return theano.Apply(self, [A], [A.type(), A.type()]) return theano.Apply(self, [A], [A.type(), A.type()])
else: else:
...@@ -620,6 +627,8 @@ class GpuMagmaEigh(GpuMagmaBase): ...@@ -620,6 +627,8 @@ class GpuMagmaEigh(GpuMagmaBase):
A = gpu_contiguous(A) A = gpu_contiguous(A)
if A.ndim != 2: if A.ndim != 2:
raise LinAlgError("Matrix rank error") raise LinAlgError("Matrix rank error")
if A.dtype != 'float32':
raise TypeError("only `float32` is supported for now")
if self.compute_v: if self.compute_v:
return theano.Apply(self, [A], return theano.Apply(self, [A],
[GpuArrayType(A.dtype, broadcastable=[False], [GpuArrayType(A.dtype, broadcastable=[False],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论