提交 338384ad authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4944 from nouiz/verify_grad_sparse

Make nose-paremetrised optional for theano.sparse
......@@ -22,7 +22,6 @@ import theano
from theano import gof, tensor, scalar, config
from theano.gradient import DisconnectedType
from theano.sparse.utils import hash_from_sparse
import theano.tests.unittest_tools as utt
from theano.gradient import grad_not_implemented, grad_undefined
from theano.sparse.type import SparseType, _is_sparse
......@@ -157,87 +156,6 @@ def as_sparse_or_tensor_variable(x, name=None):
return theano.tensor.as_tensor_variable(x, name)
def verify_grad_sparse(op, pt, structured=False, *args, **kwargs):
"""
Wrapper for theano.test.unittest_tools.py:verify_grad wich
converts sparse variables back and forth.
Parameters
----------
op
Op to check.
pt
List of inputs to realize the tests.
structured
True to tests with a structured grad, False otherwise.
args
Other `verify_grad` parameters if any.
kwargs
Other `verify_grad` keywords if any.
Returns
-------
None
"""
def conv_none(x):
return x
def conv_csr(ind, indptr, shp):
def f(spdata):
return CSR(spdata, ind, indptr, shp)
return f
def conv_csc(ind, indptr, shp):
def f(spdata):
return CSC(spdata, ind, indptr, shp)
return f
iconv = []
dpt = []
for p in pt:
if _is_sparse(p):
if structured:
dpt.append(p.data)
else:
dpt.append(p.toarray())
if p.format == 'csr':
if structured:
iconv.append(conv_csr(p.indices[:p.size], p.indptr,
p.shape))
else:
iconv.append(csr_from_dense)
elif p.format == 'csc':
if structured:
iconv.append(conv_csc(p.indices[:p.size], p.indptr,
p.shape))
else:
iconv.append(csc_from_dense)
else:
raise NotImplementedError("No conv for %s" % (p.format,))
else:
dpt.append(p)
iconv.append(conv_none)
output = op(*[as_sparse_or_tensor_variable(p) for p in pt])
if isinstance(output, (list, tuple)):
raise NotImplementedError("verify_grad can't deal with "
"multiple outputs")
if _is_sparse_variable(output):
oconv = DenseFromSparse(structured=structured)
else:
oconv = conv_none
def conv_op(*inputs):
ipt = [conv(i) for i, conv in zip(inputs, iconv)]
out = op(*ipt)
return oconv(out)
return utt.verify_grad(conv_op, dpt, *args, **kwargs)
verify_grad_sparse.E_grad = utt.verify_grad.E_grad
def constant(x, name=None):
if not isinstance(x, scipy.sparse.spmatrix):
raise TypeError("sparse.constant must be called on a "
......
......@@ -20,7 +20,6 @@ from theano.compat import next
from theano.sparse.sandbox import sp
from theano.sparse.tests.test_basic import random_lil
from theano.tests import unittest_tools as utt
from theano.sparse import verify_grad_sparse
from theano.sparse.tests.test_basic import sparse_random_inputs
from theano.tests.unittest_tools import attr
......
......@@ -16,7 +16,6 @@ from theano.sparse import SparseType, dense_from_sparse, transpose
from theano.sparse.tests.test_basic import sparse_random_inputs
from theano.tests import unittest_tools as utt
from theano.sparse import verify_grad_sparse
# To maintain compatibility
from theano.sparse.basic import TrueDot, true_dot
......@@ -27,8 +27,9 @@ if not enable_sparse:
from theano.sparse.basic import _is_dense, _is_sparse, _mtypes
from theano.sparse.basic import _is_dense_variable, _is_sparse_variable
from theano.sparse import (
verify_grad_sparse, as_sparse_variable,
CSC, CSM, CSMProperties, csm_properties,
as_sparse_variable, as_sparse_or_tensor_variable,
CSR, CSC, CSM, CSMProperties, csm_properties,
DenseFromSparse,
SparseType, CSMGrad,
StructuredDot,
StructuredDotGradCSC, StructuredDotGradCSR,
......@@ -181,6 +182,87 @@ def sparse_random_inputs(format, shape, n=1, out_dtype=None, p=0.5, gap=None,
return (variable, data)
def verify_grad_sparse(op, pt, structured=False, *args, **kwargs):
"""
Wrapper for theano.test.unittest_tools.py:verify_grad wich
converts sparse variables back and forth.
Parameters
----------
op
Op to check.
pt
List of inputs to realize the tests.
structured
True to tests with a structured grad, False otherwise.
args
Other `verify_grad` parameters if any.
kwargs
Other `verify_grad` keywords if any.
Returns
-------
None
"""
def conv_none(x):
return x
def conv_csr(ind, indptr, shp):
def f(spdata):
return CSR(spdata, ind, indptr, shp)
return f
def conv_csc(ind, indptr, shp):
def f(spdata):
return CSC(spdata, ind, indptr, shp)
return f
iconv = []
dpt = []
for p in pt:
if _is_sparse(p):
if structured:
dpt.append(p.data)
else:
dpt.append(p.toarray())
if p.format == 'csr':
if structured:
iconv.append(conv_csr(p.indices[:p.size], p.indptr,
p.shape))
else:
iconv.append(csr_from_dense)
elif p.format == 'csc':
if structured:
iconv.append(conv_csc(p.indices[:p.size], p.indptr,
p.shape))
else:
iconv.append(csc_from_dense)
else:
raise NotImplementedError("No conv for %s" % (p.format,))
else:
dpt.append(p)
iconv.append(conv_none)
output = op(*[as_sparse_or_tensor_variable(p) for p in pt])
if isinstance(output, (list, tuple)):
raise NotImplementedError("verify_grad can't deal with "
"multiple outputs")
if _is_sparse_variable(output):
oconv = DenseFromSparse(structured=structured)
else:
oconv = conv_none
def conv_op(*inputs):
ipt = [conv(i) for i, conv in zip(inputs, iconv)]
out = op(*ipt)
return oconv(out)
return utt.verify_grad(conv_op, dpt, *args, **kwargs)
verify_grad_sparse.E_grad = utt.verify_grad.E_grad
class T_verify_grad_sparse(unittest.TestCase):
class FailOp(gof.op.Op):
def __init__(self, structured):
......@@ -2914,9 +2996,9 @@ class MulSVTester(unittest.TestCase):
spmat = sp_types[format](random_lil((4, 3), dtype, 3))
mat = numpy.asarray(numpy.random.rand(3), dtype=dtype)
theano.sparse.verify_grad_sparse(mul_s_v,
[spmat, mat],
structured=True)
verify_grad_sparse(mul_s_v,
[spmat, mat],
structured=True)
def test_mul_s_v(self):
sp_types = {'csc': sp.csc_matrix,
......@@ -2949,9 +3031,9 @@ class StructuredAddSVTester(unittest.TestCase):
spmat = sp_types[format](random_lil((4, 3), dtype, 3))
mat = numpy.asarray(numpy.random.rand(3), dtype=dtype)
theano.sparse.verify_grad_sparse(structured_add_s_v,
[spmat, mat],
structured=True)
verify_grad_sparse(structured_add_s_v,
[spmat, mat],
structured=True)
def test_structured_add_s_v(self):
sp_types = {'csc': sp.csc_matrix,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论