提交 030fbafb authored 作者: James Bergstra's avatar James Bergstra 提交者: Amjad Almahairi

enh: infer_shapes for MatrixInverse

上级 69c86f01
...@@ -300,6 +300,14 @@ def tag_solve_triangular(node): ...@@ -300,6 +300,14 @@ def tag_solve_triangular(node):
return [Solve('lower_triangular')(A, b)] return [Solve('lower_triangular')(A, b)]
else: else:
return [Solve('upper_triangular')(A, b)] return [Solve('upper_triangular')(A, b)]
if (isinstance(A.owner.op, DimShuffle)
and A.owner.op.new_order == (1, 0)):
A_T, = A.owner.inputs
if isinstance(A_T.owner.op, type(cholesky)):
if A_T.owner.op.lower:
return [Solve('upper_triangular')(A, b)]
else:
return [Solve('lower_triangular')(A, b)]
@register_canonicalize @register_canonicalize
......
...@@ -111,6 +111,9 @@ class MatrixInverse(Op): ...@@ -111,6 +111,9 @@ class MatrixInverse(Op):
return [None] return [None]
return [-matrix_dot(xi, ev, xi)] return [-matrix_dot(xi, ev, xi)]
def infer_shape(self, node, shapes):
return shapes
matrix_inverse = MatrixInverse() matrix_inverse = MatrixInverse()
......
...@@ -171,8 +171,15 @@ class Solve(Op): ...@@ -171,8 +171,15 @@ class Solve(Op):
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
A, b = inputs A, b = inputs
#TODO: use the A_structure to go faster if self.A_structure == 'lower_triangular':
output_storage[0][0] = scipy.linalg.solve(A, b) rval = scipy.linalg.solve_triangular(
A, b, lower=True)
elif self.A_structure == 'upper_triangular':
rval = scipy.linalg.solve_triangular(
A, b, lower=False)
else:
rval = scipy.linalg.solve(A, b)
output_storage[0][0] = rval
# computes shape of x where x = inv(A) * b # computes shape of x where x = inv(A) * b
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论