提交 69c86f01 authored 作者: James Bergstra's avatar James Bergstra 提交者: Amjad Almahairi

opt: solve(chol) -> solve(triangular, chol)

Cholesky outputs a triangular matrix, so if a general solve is applied to a cholesky output, then transform it into a triangular solve.
上级 2817ba56
...@@ -284,6 +284,24 @@ def inv_as_solve(node): ...@@ -284,6 +284,24 @@ def inv_as_solve(node):
return [solve(r.owner.inputs[0].T, l.T).T] return [solve(r.owner.inputs[0].T, l.T).T]
@register_stabilize
@register_canonicalize
@local_optimizer(None) # XXX: solve is defined later and can't be used here
def tag_solve_triangular(node):
"""
If a general solve() is applied to the output of a cholesky op, then
replace it with a triangular solve.
"""
if node.op == solve:
if node.op.A_structure == 'general':
A, b = node.inputs # result is solution Ax=b
if isinstance(A.owner.op, type(cholesky)):
if A.owner.op.lower:
return [Solve('lower_triangular')(A, b)]
else:
return [Solve('upper_triangular')(A, b)]
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@register_specialize @register_specialize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论