提交 d15bce99 authored 作者: Tanjay94's avatar Tanjay94

Fixed import cycle in linalg.

上级 e04e4706
......@@ -69,8 +69,6 @@ FancyModule = Module
from theano.printing import pprint, pp
from theano import tensor
from theano.scan_module import scan, map, reduce, foldl, foldr, clone
from theano.updates import Updates, OrderedUpdates
......
......@@ -4945,8 +4945,7 @@ class Diagonal(Op):
def diagonal(a, offset=0, axis1=0, axis2=1):
if (offset, axis1, axis2) == (0, 0, 1):
from theano.tensor.nlinalg import extract_diag
return extract_diag(a)
return theano.tensor.nlinalg.extract_diag(a)
return Diagonal(offset, axis1, axis2)(a)
......
import logging
import theano
logger = logging.getLogger(__name__)
import numpy
from theano.gof import Op, Apply
from theano import tensor
from theano.tensor import as_tensor_variable, dot, DimShuffle, Dot
from theano.tensor.blas import Dot22
from theano.tensor.opt import (register_stabilize,
......@@ -13,6 +13,8 @@ from theano.tensor.opt import (register_stabilize,
from theano.gof import local_optimizer
from theano.gof.opt import Optimizer
from theano.gradient import DisconnectedType
from theano.tensor import basic
tensor = basic
class MatrixPinv(Op):
......
......@@ -410,7 +410,7 @@ class T_lstsq(unittest.TestCase):
x = tensor.lmatrix()
y = tensor.lmatrix()
z = tensor.lscalar()
b = theano.sandbox.linalg.lstsq()(x, y, z)
b = theano.tensor.nlinalg.lstsq()(x, y, z)
f = function([x, y, z], b)
TestMatrix1 = numpy.asarray([[2, 1], [3, 4]])
TestMatrix2 = numpy.asarray([[17, 20], [43, 50]])
......@@ -423,7 +423,7 @@ class T_lstsq(unittest.TestCase):
x = tensor.vector()
y = tensor.vector()
z = tensor.scalar()
b = theano.sandbox.linalg.lstsq()(x, y, z)
b = theano.tensor.nlinalg.lstsq()(x, y, z)
f = function([x, y, z], b)
self.assertRaises(numpy.linalg.linalg.LinAlgError, f, [2, 1], [2, 1], 1)
......@@ -431,7 +431,7 @@ class T_lstsq(unittest.TestCase):
x = tensor.vector()
y = tensor.vector()
z = tensor.vector()
b = theano.sandbox.linalg.lstsq()(x, y, z)
b = theano.tensor.nlinalg.lstsq()(x, y, z)
f = function([x, y, z], b)
self.assertRaises(numpy.linalg.LinAlgError, f, [2, 1], [2, 1], [2, 1])
......
......@@ -9,7 +9,6 @@ from theano.gof import Constant, Variable
from theano.gof.utils import hashtype
from theano.tensor.utils import hash_from_ndarray
from theano.tensor.type import TensorType
from theano.tensor import nlinalg
class AsTensorError(TypeError):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论