提交 fed8e3a6 authored 作者: Eric Laufer's avatar Eric Laufer

fixed input assert

上级 347de69d
...@@ -137,11 +137,13 @@ class Diag(Op): ...@@ -137,11 +137,13 @@ class Diag(Op):
diag = Diag() diag = Diag()
class SquareDiagonal(Op): class SquareDiagonal(Op):
"""Return a square sparse (csc) matrix whose diagonal is given by the dense vector argument.
""" """
Return a square sparse (csc) matrix whose diagonal
def __eq__(self,other): is given by the dense vector argument.
"""
def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other))
def __hash__(self): def __hash__(self):
...@@ -151,29 +153,30 @@ class SquareDiagonal(Op): ...@@ -151,29 +153,30 @@ class SquareDiagonal(Op):
return "SquareDiagonal" return "SquareDiagonal"
def make_node(self, diag): def make_node(self, diag):
assert isinstance(diag.type, theano.tensor.TensorType) diag = tensor.as_tensor_variable(diag)
if diag.type.ndim != 1: if diag.type.ndim != 1:
raise TypeError('data argument must be a vector', diag.type) raise TypeError('data argument must be a vector', diag.type)
return gof.Apply(self, [diag], return gof.Apply(self, [diag],
[sparse.SparseType(dtype = diag.dtype, [sparse.SparseType(dtype=diag.dtype, format='csc')()])
format = 'csc')()])
def perform(self, node, (diag,), (z,)): def perform(self, node, (diag,), (z,)):
N, = diag.shape N, = diag.shape
indptr = range(N+1) indptr = range(N + 1)
indices = indptr[0:N] indices = indptr[0:N]
z[0] = scipy_sparse.csc_matrix((diag, indices, indptr), (N,N), copy=True) z[0] = scipy_sparse.csc_matrix((diag, indices, indptr),
(N, N), copy=True)
def grad(self, input, (gz,)): def grad(self, input, (gz,)):
return [diag(gz)] return [diag(gz)]
def infer_shape(self,nodes,shapes): def infer_shape(self, nodes, shapes):
diag_length = shapes[0][0] diag_length = shapes[0][0]
return [(diag_length,diag_length)] return [(diag_length, diag_length)]
square_diagonal = SquareDiagonal() square_diagonal = SquareDiagonal()
class ColScaleCSC(Op): class ColScaleCSC(Op):
""" """
Scale each columns of a sparse matrix by the corresponding element of a dense vector Scale each columns of a sparse matrix by the corresponding element of a dense vector
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论