提交 347de69d authored 作者: Eric Laufer's avatar Eric Laufer

Finished square diagonal op

上级 5e10c1f1
...@@ -140,10 +140,25 @@ diag = Diag() ...@@ -140,10 +140,25 @@ diag = Diag()
class SquareDiagonal(Op): class SquareDiagonal(Op):
"""Return a square sparse (csc) matrix whose diagonal is given by the dense vector argument. """Return a square sparse (csc) matrix whose diagonal is given by the dense vector argument.
""" """
def __eq__(self,other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def __str__(self):
return "SquareDiagonal"
def make_node(self, diag): def make_node(self, diag):
assert isinstance(diag.type, theano.tensor.TensorType)
if diag.type.ndim != 1:
raise TypeError('data argument must be a vector', diag.type)
return gof.Apply(self, [diag], return gof.Apply(self, [diag],
[sparse.SparseType(dtype = diag.dtype, [sparse.SparseType(dtype = diag.dtype,
format = 'csc')()]) format = 'csc')()])
def perform(self, node, (diag,), (z,)): def perform(self, node, (diag,), (z,)):
N, = diag.shape N, = diag.shape
indptr = range(N+1) indptr = range(N+1)
...@@ -153,6 +168,10 @@ class SquareDiagonal(Op): ...@@ -153,6 +168,10 @@ class SquareDiagonal(Op):
def grad(self, input, (gz,)): def grad(self, input, (gz,)):
return [diag(gz)] return [diag(gz)]
def infer_shape(self,nodes,shapes):
diag_length = shapes[0][0]
return [(diag_length,diag_length)]
square_diagonal = SquareDiagonal() square_diagonal = SquareDiagonal()
class ColScaleCSC(Op): class ColScaleCSC(Op):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论