提交 8a243e89 authored 作者: Alexander Matyasko's avatar Alexander Matyasko

Fix package import order and rank check

上级 8f8dd922
...@@ -373,8 +373,7 @@ AddConfigVar('magma.library_path', ...@@ -373,8 +373,7 @@ AddConfigVar('magma.library_path',
AddConfigVar('magma.enabled', AddConfigVar('magma.enabled',
" If True, use magma for matrix computation." " If True, use magma for matrix computation."
" If False, disable magma" " If False, disable magma",
" If no_check, assume present and the version between header and library match (so less compilation at context init)",
BoolParam(False), BoolParam(False),
in_c_key=False) in_c_key=False)
......
...@@ -3,8 +3,8 @@ from __future__ import absolute_import, division, print_function ...@@ -3,8 +3,8 @@ from __future__ import absolute_import, division, print_function
import os import os
import warnings import warnings
import numpy as np
import pkg_resources import pkg_resources
import numpy as np
from numpy.linalg.linalg import LinAlgError from numpy.linalg.linalg import LinAlgError
import theano import theano
...@@ -379,10 +379,10 @@ class GpuMagmaSVD(COp): ...@@ -379,10 +379,10 @@ class GpuMagmaSVD(COp):
return [] return []
def make_node(self, A): def make_node(self, A):
if A.ndim != 2:
raise LinAlgError("Matrix rank error")
ctx_name = infer_context_name(A) ctx_name = infer_context_name(A)
A = as_gpuarray_variable(A, ctx_name) A = as_gpuarray_variable(A, ctx_name)
if A.ndim != 2:
raise LinAlgError("Matrix rank error")
return theano.Apply(self, [A], return theano.Apply(self, [A],
[A.type(), [A.type(),
GpuArrayType(A.dtype, broadcastable=[False], GpuArrayType(A.dtype, broadcastable=[False],
...@@ -452,10 +452,10 @@ class GpuMagmaMatrixInverse(COp): ...@@ -452,10 +452,10 @@ class GpuMagmaMatrixInverse(COp):
return [] return []
def make_node(self, x): def make_node(self, x):
if x.ndim != 2:
raise LinAlgError("Matrix rank error")
ctx_name = infer_context_name(x) ctx_name = infer_context_name(x)
x = as_gpuarray_variable(x, ctx_name) x = as_gpuarray_variable(x, ctx_name)
if x.ndim != 2:
raise LinAlgError("Matrix rank error")
return theano.Apply(self, [x], [x.type()]) return theano.Apply(self, [x], [x.type()])
def get_params(self, node): def get_params(self, node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论