提交 1819b71d authored 作者: Tanjay94's avatar Tanjay94

Fixed small mistake in code.

上级 94148810
...@@ -37,7 +37,6 @@ from theano.tensor import sharedvar # adds shared-variable constructors ...@@ -37,7 +37,6 @@ from theano.tensor import sharedvar # adds shared-variable constructors
from theano.tensor.sharedvar import tensor_constructor as _shared from theano.tensor.sharedvar import tensor_constructor as _shared
from theano.tensor.io import * from theano.tensor.io import *
from theano.tensor import nlinalg
def shared(*args, **kw): def shared(*args, **kw):
......
...@@ -5,7 +5,7 @@ import numpy ...@@ -5,7 +5,7 @@ import numpy
from theano.gof import Op, Apply from theano.gof import Op, Apply
import theano.tensor from theano import tensor
from theano.tensor import as_tensor_variable, dot, DimShuffle, Dot from theano.tensor import as_tensor_variable, dot, DimShuffle, Dot
from theano.tensor.blas import Dot22 from theano.tensor.blas import Dot22
from theano.tensor.opt import (register_stabilize, from theano.tensor.opt import (register_stabilize,
...@@ -773,3 +773,17 @@ def norm(x,ord): ...@@ -773,3 +773,17 @@ def norm(x,ord):
elif ndim > 2: elif ndim > 2:
raise NotImplementedError("We don't support norm witn ndim > 2") raise NotImplementedError("We don't support norm witn ndim > 2")
"""
import theano
from theano import tensor as T
from theano import function
from theano.tensor import nlinalg
x = T.matrix()
y = nlinalg.norm(x,1)
"""
...@@ -34,7 +34,8 @@ from theano.tensor.nlinalg import ( MatrixInverse, ...@@ -34,7 +34,8 @@ from theano.tensor.nlinalg import ( MatrixInverse,
_zero_disconnected, _zero_disconnected,
qr, qr,
matrix_power, matrix_power,
norm norm,
svd
) )
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
...@@ -409,7 +410,7 @@ class T_lstsq(unittest.TestCase): ...@@ -409,7 +410,7 @@ class T_lstsq(unittest.TestCase):
x = tensor.lmatrix() x = tensor.lmatrix()
y = tensor.lmatrix() y = tensor.lmatrix()
z = tensor.lscalar() z = tensor.lscalar()
b = theano.sandbox.linalg.ops.lstsq()(x, y, z) b = theano.sandbox.linalg.lstsq()(x, y, z)
f = function([x, y, z], b) f = function([x, y, z], b)
TestMatrix1 = numpy.asarray([[2, 1], [3, 4]]) TestMatrix1 = numpy.asarray([[2, 1], [3, 4]])
TestMatrix2 = numpy.asarray([[17, 20], [43, 50]]) TestMatrix2 = numpy.asarray([[17, 20], [43, 50]])
...@@ -422,7 +423,7 @@ class T_lstsq(unittest.TestCase): ...@@ -422,7 +423,7 @@ class T_lstsq(unittest.TestCase):
x = tensor.vector() x = tensor.vector()
y = tensor.vector() y = tensor.vector()
z = tensor.scalar() z = tensor.scalar()
b = theano.sandbox.linalg.ops.lstsq()(x, y, z) b = theano.sandbox.linalg.lstsq()(x, y, z)
f = function([x, y, z], b) f = function([x, y, z], b)
self.assertRaises(numpy.linalg.linalg.LinAlgError, f, [2, 1], [2, 1], 1) self.assertRaises(numpy.linalg.linalg.LinAlgError, f, [2, 1], [2, 1], 1)
...@@ -430,7 +431,7 @@ class T_lstsq(unittest.TestCase): ...@@ -430,7 +431,7 @@ class T_lstsq(unittest.TestCase):
x = tensor.vector() x = tensor.vector()
y = tensor.vector() y = tensor.vector()
z = tensor.vector() z = tensor.vector()
b = theano.sandbox.linalg.ops.lstsq()(x, y, z) b = theano.sandbox.linalg.lstsq()(x, y, z)
f = function([x, y, z], b) f = function([x, y, z], b)
self.assertRaises(numpy.linalg.LinAlgError, f, [2, 1], [2, 1], [2, 1]) self.assertRaises(numpy.linalg.LinAlgError, f, [2, 1], [2, 1], [2, 1])
......
...@@ -27,6 +27,12 @@ from nose.plugins.skip import SkipTest ...@@ -27,6 +27,12 @@ from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
from nose.tools import assert_raises from nose.tools import assert_raises
try:
import scipy.linalg
imported_scipy = True
except ImportError:
# some ops (e.g. Cholesky, Solve, A_Xinv_b) won't work
imported_scipy = False
def check_lower_triangular(pd, ch_f): def check_lower_triangular(pd, ch_f):
ch = ch_f(pd) ch = ch_f(pd)
...@@ -179,7 +185,3 @@ class test_Solve(utt.InferShapeTester): ...@@ -179,7 +185,3 @@ class test_Solve(utt.InferShapeTester):
dtype=config.floatX)], dtype=config.floatX)],
self.op_class, self.op_class,
warn=False) warn=False)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论