提交 412098a0 authored 作者: David Warde-Farley's avatar David Warde-Farley

Add infer_shape for Cholesky and CholeskyGrad.

上级 7740d3fa
...@@ -327,6 +327,9 @@ class Cholesky(Op): ...@@ -327,6 +327,9 @@ class Cholesky(Op):
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and self.props() == other.props()) return (type(self) == type(other) and self.props() == other.props())
def infer_shape(self, node, shapes):
return [shapes[0]]
def __str__(self): def __str__(self):
if self.lower: if self.lower:
lu = 'lower' lu = 'lower'
...@@ -431,6 +434,9 @@ class CholeskyGrad(Op): ...@@ -431,6 +434,9 @@ class CholeskyGrad(Op):
F[k, k] /= (2 * L[k, k]) F[k, k] /= (2 * L[k, k])
dx[0] = F dx[0] = F
def infer_shape(self, node, shapes):
return [shapes[0]]
class MatrixInverse(Op): class MatrixInverse(Op):
"""Computes the inverse of a matrix :math:`A`. """Computes the inverse of a matrix :math:`A`.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论