提交 5a04235b authored 作者: Ian Goodfellow's avatar Ian Goodfellow

made some sparse.basic ops work with DisconnectedType

上级 43e64647
......@@ -7,7 +7,6 @@ http://www-users.cs.umn.edu/~saad/software/SPARSKIT/paper.ps
# TODO
# Automatic methods for determining best sparse format?
from itertools import izip
import sys
import numpy
......@@ -16,7 +15,7 @@ import scipy.sparse
from theano import gof, tensor, compile, scalar, config
from theano.gof.python25 import all
from theano.tensor import blas
from theano.gradient import DisconnectedType
from theano.sparse.utils import hash_from_sparse
import theano.tests.unittest_tools as utt
......@@ -626,7 +625,18 @@ class CSMProperties(gof.Op):
out[3][0] = theano._asarray(csm.shape, dtype='int32')
def grad(self, (csm,), g):
assert [gg is None for gg in g[1:]]
#g[1:] is all integers, so their Jacobian in this op
#is 0. We thus don't need to worry about what their values
#are.
#if g[0] is disconnected, then this op doesn't contribute
#any gradient anywhere. but we know that at least one of
#g[1:] is connected, or this grad method wouldn't have been
#called, so we should report zeros
if isinstance(g[0].type, DisconnectedType):
return [csm.zeros_like()]
data, indices, indptr, shape = csm_properties(csm)
return [CSM(csm.format)(g[0], indices, indptr, shape)]
# don't make this a function or it breaks some optimizations below
......@@ -662,10 +672,10 @@ class CSM(gof.Op):
:param data: One dimensionnal tensor representing
the data of the sparse to construct.
:param indices: One dimensionnal tensor of integers
:param indices: One dimensional tensor of integers
representing the indices of the sparse
matrix to construct.
:param indptr: One dimensionnal tensor of integers
:param indptr: One dimensional tensor of integers
representing the indice pointer for
the sparse matrix to construct.
:param shape: One dimensionnal tensor of integers
......@@ -673,9 +683,9 @@ class CSM(gof.Op):
matrix to construct.
:return: A sparse matrix having the properties
speficied by the inputs.
specified by the inputs.
:note: The grad method returns a dense vector, so it provide
:note: The grad method returns a dense vector, so it provides
a regular grad.
"""
......@@ -777,7 +787,7 @@ class CSM(gof.Op):
#unpack the data vector and wrap it as a 1d TensorType
g_data = csm_grad(self.kmap)(x_data, x_indices, x_indptr, x_shape,
g_data, g_indices, g_indptr, g_shape)
return [g_data, None, None, None]
return [g_data, DisconnectedType()(), DisconnectedType()(), DisconnectedType()]
def infer_shape(self, node, shapes):
if self.kmap is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论