提交 030fbafb authored 作者: James Bergstra's avatar James Bergstra 提交者: Amjad Almahairi

enh: infer_shapes for MatrixInverse

上级 69c86f01
......@@ -264,8 +264,8 @@ def transinv_to_invtrans(node):
A, = node.inputs
if A.owner:
if isinstance(A.owner.op, MatrixInverse):
X, = A.owner.inputs
return [A.owner.op(node.op(X))]
X, = A.owner.inputs
return [A.owner.op(node.op(X))]
@register_stabilize
......@@ -286,7 +286,7 @@ def inv_as_solve(node):
@register_stabilize
@register_canonicalize
@local_optimizer(None) # XXX: solve is defined later and can't be used here
@local_optimizer(None) # XXX: solve is defined later and can't be used here
def tag_solve_triangular(node):
"""
If a general solve() is applied to the output of a cholesky op, then
......@@ -300,8 +300,16 @@ def tag_solve_triangular(node):
return [Solve('lower_triangular')(A, b)]
else:
return [Solve('upper_triangular')(A, b)]
if (isinstance(A.owner.op, DimShuffle)
and A.owner.op.new_order == (1, 0)):
A_T, = A.owner.inputs
if isinstance(A_T.owner.op, type(cholesky)):
if A_T.owner.op.lower:
return [Solve('upper_triangular')(A, b)]
else:
return [Solve('lower_triangular')(A, b)]
@register_canonicalize
@register_stabilize
@register_specialize
......
......@@ -111,6 +111,9 @@ class MatrixInverse(Op):
return [None]
return [-matrix_dot(xi, ev, xi)]
def infer_shape(self, node, shapes):
return shapes
matrix_inverse = MatrixInverse()
......
......@@ -171,9 +171,16 @@ class Solve(Op):
def perform(self, node, inputs, output_storage):
A, b = inputs
#TODO: use the A_structure to go faster
output_storage[0][0] = scipy.linalg.solve(A, b)
if self.A_structure == 'lower_triangular':
rval = scipy.linalg.solve_triangular(
A, b, lower=True)
elif self.A_structure == 'upper_triangular':
rval = scipy.linalg.solve_triangular(
A, b, lower=False)
else:
rval = scipy.linalg.solve(A, b)
output_storage[0][0] = rval
# computes shape of x where x = inv(A) * b
def infer_shape(self, node, shapes):
Ashape, Bshape = shapes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论