提交 4fef5a6e authored 作者: Guillaume Desjardins's avatar Guillaume Desjardins 提交者: Frederic

Reformatted docstrings and spaces between funcs.

上级 d50c1eba
...@@ -561,7 +561,7 @@ solve = Solve() # general solve ...@@ -561,7 +561,7 @@ solve = Solve() # general solve
#TODO: Optimizations to replace multiplication by matrix inverse with solve() Op (still unwritten) #TODO: Optimizations to replace multiplication by matrix inverse with solve() Op (still unwritten)
class ExtractDiag(Op): class ExtractDiag(Op):
""" Return the diagonal of a matrix """ """ Return the diagonal of a matrix. """
def __init__(self, view=False): def __init__(self, view=False):
self.view = view self.view = view
if self.view: if self.view:
...@@ -580,13 +580,13 @@ class ExtractDiag(Op): ...@@ -580,13 +580,13 @@ class ExtractDiag(Op):
return Apply(self, [x], [tensor.vector(dtype=x.type.dtype)]) return Apply(self, [x], [tensor.vector(dtype=x.type.dtype)])
def perform(self, node, ins, outs): def perform(self, node, ins, outs):
""" For some reason numpy.diag(x) is really slow, so we implemented our own """ """ For some reason numpy.diag(x) is really slow, so we implemented our own. """
x, = ins x, = ins
z, = outs z, = outs
# zero-dimensional matrices ... # zero-dimensional matrices ...
if x.shape[0] == 0 or x.shape[1] == 0: if x.shape[0] == 0 or x.shape[1] == 0:
z[0] = x z[0] = numpy.zeros(0)
return return
if x.shape[0] < x.shape [1]: if x.shape[0] < x.shape [1]:
...@@ -606,7 +606,7 @@ class ExtractDiag(Op): ...@@ -606,7 +606,7 @@ class ExtractDiag(Op):
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
x = tensor.zeros_like(inputs[0]) x = tensor.zeros_like(inputs[0])
xdiag = alloc_diag(g_outputs[0]) xdiag = alloc_diag(g_outputs[0])
return [tensor.set_subtensor(x[:xdiag.shape[0], :xdiag.shape[1]], xdiag, inplace=True)] return [tensor.set_subtensor(x[:xdiag.shape[0], :xdiag.shape[1]], xdiag)]
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
x_s, = shapes x_s, = shapes
...@@ -616,6 +616,7 @@ class ExtractDiag(Op): ...@@ -616,6 +616,7 @@ class ExtractDiag(Op):
extract_diag = ExtractDiag() extract_diag = ExtractDiag()
#TODO: optimization to insert ExtractDiag with view=True #TODO: optimization to insert ExtractDiag with view=True
class AllocDiag(Op): class AllocDiag(Op):
""" """
Allocates a square matrix with the given vector as its diagonal. Allocates a square matrix with the given vector as its diagonal.
...@@ -646,9 +647,10 @@ class AllocDiag(Op): ...@@ -646,9 +647,10 @@ class AllocDiag(Op):
alloc_diag = AllocDiag() alloc_diag = AllocDiag()
def diag(x):
"""Numpy-compatibility method
def diag(x):
"""
Numpy-compatibility method
If `x` is a matrix, return its diagonal. If `x` is a matrix, return its diagonal.
If `x` is a vector return a matrix with it as its diagonal. If `x` is a vector return a matrix with it as its diagonal.
...@@ -662,6 +664,7 @@ def diag(x): ...@@ -662,6 +664,7 @@ def diag(x):
else: else:
raise TypeError('diag requires vector or matrix argument', x) raise TypeError('diag requires vector or matrix argument', x)
class Det(Op): class Det(Op):
"""Matrix determinant """Matrix determinant
Input should be a square matrix Input should be a square matrix
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论