提交 2c845d5f authored 作者: Alassane Ndiaye's avatar Alassane Ndiaye

Small fixes to the svd and lstsq ops

上级 92ad39a1
......@@ -657,7 +657,7 @@ class SVD(Op):
x = as_tensor_variable(x)
assert x.ndim == 2, "The input of svd function should be a matrix."
w = theano.tensor.matrix(dtype=x.dtype)
u = theano.tensor.matrix(dtype=x.dtype)
u = theano.tensor.vector(dtype=x.dtype)
v = theano.tensor.matrix(dtype=x.dtype)
return Apply(self, [x], [w, u, v])
......@@ -718,16 +718,18 @@ class lstsq(Op):
x = theano.tensor.as_tensor_variable(x)
y = theano.tensor.as_tensor_variable(y)
rcond = theano.tensor.as_tensor_variable(rcond)
return theano.Apply(self, [x, y, rcond], [y.type(), theano.tensor.dvector(), theano.tensor.lscalar(), theano.tensor.dvector()])
return theano.Apply(self, [x, y, rcond],
[theano.tensor.matrix(), theano.tensor.dvector(),
theano.tensor.lscalar(), theano.tensor.dvector()])
def perform(self, node, inputs, outputs):
x = inputs[0]
y = inputs[1]
rcond = inputs[2]
zz = numpy.linalg.lstsq(inputs[0], inputs[1], inputs[2])
zz = numpy.linalg.lstsq(inputs[0], inputs[1], inputs[2])
outputs[0][0] = zz[0]
outputs[1][0] = zz[1]
outputs[2][0] = zz[2]
outputs[2][0] = numpy.array(zz[2])
outputs[3][0] = zz[3]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论