提交 b30ebcdb authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Implement verify_grad_sparse as a wrapper around verify_grad.

上级 7ccfbb01
...@@ -17,6 +17,7 @@ from theano import gof, tensor, compile, scalar, config ...@@ -17,6 +17,7 @@ from theano import gof, tensor, compile, scalar, config
from theano.gof.python25 import all from theano.gof.python25 import all
from theano.tensor import blas from theano.tensor import blas
from theano.sparse.utils import hash_from_sparse from theano.sparse.utils import hash_from_sparse
import theano.tests.unittest_tools as utt
sparse_formats = ['csc', 'csr'] sparse_formats = ['csc', 'csr']
...@@ -140,6 +141,51 @@ def as_sparse_or_tensor_variable(x, name=None): ...@@ -140,6 +141,51 @@ def as_sparse_or_tensor_variable(x, name=None):
return theano.tensor.as_tensor_variable(x, name) return theano.tensor.as_tensor_variable(x, name)
def verify_grad_sparse(op, pt, *args, **kwargs):
"""
Wrapper for theano.test.unittest_tools.py:verify_grad
Converts sparse variables back and forth.
"""
conv_none = lambda x: x
def conv_csr(ind, indptr, shp):
def f(spdata):
return CSR(spdata, ind, indptr, shp)
return f
def conv_csc(ind, indptr, shp):
def f(spdata):
return CSC(spdata, ind, indptr, shp)
return f
iconv = []
dpt = []
for p in pt:
if _is_sparse(p):
dpt.append(p.data)
if p.format == 'csr':
iconv.append(conv_csr(p.indices[:p.size], p.indptr, p.shape))
elif p.format == 'csc':
iconv.append(conv_csc(p.indices[:p.size], p.indptr, p.shape))
else:
raise NotImplementedError("No conv for %s" % (p.format,))
else:
dpt.append(p)
iconv.append(conv_none)
output = op(*[as_sparse_or_tensor_variable(p) for p in pt])
if isinstance(output, (list, tuple)):
raise NotImplementedError("verify_grad can't deal with "
"multiple outputs")
if _is_sparse_variable(output):
oconv = dense_from_sparse
else:
oconv = conv_none
def conv_op(*inputs):
ipt = [conv(i) for i, conv in zip(inputs, iconv)]
out = op(*ipt)
return oconv(out)
return utt.verify_grad(conv_op, dpt, *args, **kwargs)
verify_grad_sparse.E_grad = utt.verify_grad.E_grad
def constant(x, name=None): def constant(x, name=None):
if not isinstance(x, scipy.sparse.spmatrix): if not isinstance(x, scipy.sparse.spmatrix):
raise TypeError("sparse.constant must be called on a " raise TypeError("sparse.constant must be called on a "
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论