提交 e0fbb48b authored 作者: Philippe  Hamel's avatar Philippe Hamel

det OP

reviewed code... everything seems OK added infer_shape added a bit of doc
上级 fa896acb
...@@ -548,8 +548,10 @@ def diag(x): ...@@ -548,8 +548,10 @@ def diag(x):
raise TypeError('diag requires vector or matrix argument', x) raise TypeError('diag requires vector or matrix argument', x)
class Det(Op): class Det(Op):
"""matrix determinant """Matrix determinant
Input should be a square matrix
:note: Requires scipy
TODO: move this op to another file that request scipy. TODO: move this op to another file that request scipy.
""" """
def make_node(self, x): def make_node(self, x):
...@@ -566,6 +568,8 @@ class Det(Op): ...@@ -566,6 +568,8 @@ class Det(Op):
gz, = g_outputs gz, = g_outputs
x, = inputs x, = inputs
return [gz * self(x) * matrix_inverse(x).T] return [gz * self(x) * matrix_inverse(x).T]
def infer_shape(self, node, shapes):
return [(1, )]
def __str__(self): def __str__(self):
return "Det" return "Det"
det = Det() det = Det()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论