提交 50947b37 authored 作者: Tanjay94's avatar Tanjay94

Fixed qr function to fit numpy update.

上级 39e63787
...@@ -1021,8 +1021,8 @@ class QRFull(Op): ...@@ -1021,8 +1021,8 @@ class QRFull(Op):
""" """
_numop = staticmethod(numpy.linalg.qr) _numop = staticmethod(numpy.linalg.qr)
def __init__(self): def __init__(self, mode):
self.mode = "full" self.mode = mode
def __hash__(self): def __hash__(self):
return hash((type(self), self.props())) return hash((type(self), self.props()))
...@@ -1060,8 +1060,7 @@ class QRIncomplete(Op): ...@@ -1060,8 +1060,7 @@ class QRIncomplete(Op):
""" """
_numop = staticmethod(numpy.linalg.qr) _numop = staticmethod(numpy.linalg.qr)
def __init__(self, mode="raw"): def __init__(self, mode):
assert mode != "full"
self.mode = mode self.mode = mode
def __hash__(self): def __hash__(self):
...@@ -1132,8 +1131,8 @@ def qr(a, mode="full"): ...@@ -1132,8 +1131,8 @@ def qr(a, mode="full"):
""" """
x = [[2, 1], [3, 4]] x = [[2, 1], [3, 4]]
if hasattr(numpy.linalg.qr(x,mode), tuple): if isinstance(numpy.linalg.qr(x,mode), tuple):
return QRFull()(a) return QRFull(mode)(a)
else: else:
return QRIncomplete(mode)(a) return QRIncomplete(mode)(a)
......
...@@ -177,8 +177,8 @@ def test_matrix_dot(): ...@@ -177,8 +177,8 @@ def test_matrix_dot():
def test_qr(): def test_qr():
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
A = tensor.matrix("A", dtype=theano.config.floatX) A = tensor.matrix("A", dtype=theano.config.floatX)
Q, R = qr(A) Q = qr(A)
fn = function([A], [Q, R]) fn = function([A], Q)
a = rng.rand(4, 4).astype(theano.config.floatX) a = rng.rand(4, 4).astype(theano.config.floatX)
n_q, n_r = numpy.linalg.qr(a) n_q, n_r = numpy.linalg.qr(a)
t_q, t_r = fn(a) t_q, t_r = fn(a)
...@@ -190,7 +190,7 @@ def test_qr_reduced(): ...@@ -190,7 +190,7 @@ def test_qr_reduced():
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
A = tensor.matrix("A", dtype=theano.config.floatX) A = tensor.matrix("A", dtype=theano.config.floatX)
Q = qr(A, mode="reduced") Q = qr(A, mode="reduced")
fn = function([A], [Q]) fn = function([A], Q)
a = rng.rand(4, 4).astype(theano.config.floatX) a = rng.rand(4, 4).astype(theano.config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论