提交 d15bce99 authored 作者: Tanjay94's avatar Tanjay94

Fixed import cycle in linalg.

上级 e04e4706
...@@ -69,8 +69,6 @@ FancyModule = Module ...@@ -69,8 +69,6 @@ FancyModule = Module
from theano.printing import pprint, pp from theano.printing import pprint, pp
from theano import tensor
from theano.scan_module import scan, map, reduce, foldl, foldr, clone from theano.scan_module import scan, map, reduce, foldl, foldr, clone
from theano.updates import Updates, OrderedUpdates from theano.updates import Updates, OrderedUpdates
......
...@@ -4945,8 +4945,7 @@ class Diagonal(Op): ...@@ -4945,8 +4945,7 @@ class Diagonal(Op):
def diagonal(a, offset=0, axis1=0, axis2=1): def diagonal(a, offset=0, axis1=0, axis2=1):
if (offset, axis1, axis2) == (0, 0, 1): if (offset, axis1, axis2) == (0, 0, 1):
from theano.tensor.nlinalg import extract_diag return theano.tensor.nlinalg.extract_diag(a)
return extract_diag(a)
return Diagonal(offset, axis1, axis2)(a) return Diagonal(offset, axis1, axis2)(a)
......
import logging import logging
import theano
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import numpy import numpy
from theano.gof import Op, Apply from theano.gof import Op, Apply
from theano import tensor
from theano.tensor import as_tensor_variable, dot, DimShuffle, Dot from theano.tensor import as_tensor_variable, dot, DimShuffle, Dot
from theano.tensor.blas import Dot22 from theano.tensor.blas import Dot22
from theano.tensor.opt import (register_stabilize, from theano.tensor.opt import (register_stabilize,
...@@ -13,6 +13,8 @@ from theano.tensor.opt import (register_stabilize, ...@@ -13,6 +13,8 @@ from theano.tensor.opt import (register_stabilize,
from theano.gof import local_optimizer from theano.gof import local_optimizer
from theano.gof.opt import Optimizer from theano.gof.opt import Optimizer
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.tensor import basic
tensor = basic
class MatrixPinv(Op): class MatrixPinv(Op):
......
...@@ -410,7 +410,7 @@ class T_lstsq(unittest.TestCase): ...@@ -410,7 +410,7 @@ class T_lstsq(unittest.TestCase):
x = tensor.lmatrix() x = tensor.lmatrix()
y = tensor.lmatrix() y = tensor.lmatrix()
z = tensor.lscalar() z = tensor.lscalar()
b = theano.sandbox.linalg.lstsq()(x, y, z) b = theano.tensor.nlinalg.lstsq()(x, y, z)
f = function([x, y, z], b) f = function([x, y, z], b)
TestMatrix1 = numpy.asarray([[2, 1], [3, 4]]) TestMatrix1 = numpy.asarray([[2, 1], [3, 4]])
TestMatrix2 = numpy.asarray([[17, 20], [43, 50]]) TestMatrix2 = numpy.asarray([[17, 20], [43, 50]])
...@@ -423,7 +423,7 @@ class T_lstsq(unittest.TestCase): ...@@ -423,7 +423,7 @@ class T_lstsq(unittest.TestCase):
x = tensor.vector() x = tensor.vector()
y = tensor.vector() y = tensor.vector()
z = tensor.scalar() z = tensor.scalar()
b = theano.sandbox.linalg.lstsq()(x, y, z) b = theano.tensor.nlinalg.lstsq()(x, y, z)
f = function([x, y, z], b) f = function([x, y, z], b)
self.assertRaises(numpy.linalg.linalg.LinAlgError, f, [2, 1], [2, 1], 1) self.assertRaises(numpy.linalg.linalg.LinAlgError, f, [2, 1], [2, 1], 1)
...@@ -431,7 +431,7 @@ class T_lstsq(unittest.TestCase): ...@@ -431,7 +431,7 @@ class T_lstsq(unittest.TestCase):
x = tensor.vector() x = tensor.vector()
y = tensor.vector() y = tensor.vector()
z = tensor.vector() z = tensor.vector()
b = theano.sandbox.linalg.lstsq()(x, y, z) b = theano.tensor.nlinalg.lstsq()(x, y, z)
f = function([x, y, z], b) f = function([x, y, z], b)
self.assertRaises(numpy.linalg.LinAlgError, f, [2, 1], [2, 1], [2, 1]) self.assertRaises(numpy.linalg.LinAlgError, f, [2, 1], [2, 1], [2, 1])
......
...@@ -9,7 +9,6 @@ from theano.gof import Constant, Variable ...@@ -9,7 +9,6 @@ from theano.gof import Constant, Variable
from theano.gof.utils import hashtype from theano.gof.utils import hashtype
from theano.tensor.utils import hash_from_ndarray from theano.tensor.utils import hash_from_ndarray
from theano.tensor.type import TensorType from theano.tensor.type import TensorType
from theano.tensor import nlinalg
class AsTensorError(TypeError): class AsTensorError(TypeError):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论