提交 8a243e89 authored 作者: Alexander Matyasko's avatar Alexander Matyasko

Fix package import order and rank check

上级 8f8dd922
......@@ -373,8 +373,7 @@ AddConfigVar('magma.library_path',
AddConfigVar('magma.enabled',
" If True, use magma for matrix computation."
" If False, disable magma"
" If no_check, assume present and the version between header and library match (so less compilation at context init)",
" If False, disable magma",
BoolParam(False),
in_c_key=False)
......
......@@ -3,8 +3,8 @@ from __future__ import absolute_import, division, print_function
import os
import warnings
import numpy as np
import pkg_resources
import numpy as np
from numpy.linalg.linalg import LinAlgError
import theano
......@@ -379,10 +379,10 @@ class GpuMagmaSVD(COp):
return []
def make_node(self, A):
if A.ndim != 2:
raise LinAlgError("Matrix rank error")
ctx_name = infer_context_name(A)
A = as_gpuarray_variable(A, ctx_name)
if A.ndim != 2:
raise LinAlgError("Matrix rank error")
return theano.Apply(self, [A],
[A.type(),
GpuArrayType(A.dtype, broadcastable=[False],
......@@ -452,10 +452,10 @@ class GpuMagmaMatrixInverse(COp):
return []
def make_node(self, x):
if x.ndim != 2:
raise LinAlgError("Matrix rank error")
ctx_name = infer_context_name(x)
x = as_gpuarray_variable(x, ctx_name)
if x.ndim != 2:
raise LinAlgError("Matrix rank error")
return theano.Apply(self, [x], [x.type()])
def get_params(self, node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论