提交 c2fdaad0 authored 作者: Matthew Rocklin's avatar Matthew Rocklin

Solve.infer_shape now handles 1d array case

上级 2d5f3b14
......@@ -655,9 +655,15 @@ class Solve(Op):
#TODO: use the A_structure to go faster
output_storage[0][0] = scipy.linalg.solve(A, b)
# computes shape of x where x = inv(A) * b
def infer_shape(self, node, shapes):
(Ar, Ac), (Br, Bc) = shapes
return [(Ac, Bc)]
Ashape, Bshape = shapes
rows = Ashape[1]
if len(Bshape) == 1: # b is a Vector
return [(rows,)]
else:
cols = Bshape[1] # b is a Matrix
return [(rows, cols)]
solve = Solve() # general solve
......
......@@ -456,3 +456,14 @@ class test_Solve(utt.InferShapeTester):
numpy.asarray(rng.rand(5, 1),
dtype=config.floatX)],
self.op_class)
rng = numpy.random.RandomState(utt.fetch_seed())
A = theano.tensor.matrix()
b = theano.tensor.vector()
self._compile_and_check([A, b], # theano.function inputs
[self.op(A, b)], # theano.function outputs
# A must be square
[numpy.asarray(rng.rand(5, 5),
dtype=config.floatX),
numpy.asarray(rng.rand(5),
dtype=config.floatX)],
self.op_class)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论