提交 0ada7dff authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make verify_grad_sparse work for non-structured op and add a test that it fails as expected.

上级 a4813f37
......@@ -141,7 +141,7 @@ def as_sparse_or_tensor_variable(x, name=None):
return theano.tensor.as_tensor_variable(x, name)
def verify_grad_sparse(op, pt, *args, **kwargs):
def verify_grad_sparse(op, pt, structured=False, *args, **kwargs):
"""
Wrapper for theano.test.unittest_tools.py:verify_grad
Converts sparse variables back and forth.
......@@ -160,11 +160,22 @@ def verify_grad_sparse(op, pt, *args, **kwargs):
for p in pt:
if _is_sparse(p):
dpt.append(p.data)
if structured:
dpt.append(p.data)
else:
dpt.append(p.toarray())
if p.format == 'csr':
iconv.append(conv_csr(p.indices[:p.size], p.indptr, p.shape))
if structured:
iconv.append(conv_csr(p.indices[:p.size], p.indptr,
p.shape))
else:
iconv.append(csr_from_dense)
elif p.format == 'csc':
iconv.append(conv_csc(p.indices[:p.size], p.indptr, p.shape))
if structured:
iconv.append(conv_csc(p.indices[:p.size], p.indptr,
p.shape))
else:
iconv.append(csc_from_dense)
else:
raise NotImplementedError("No conv for %s" % (p.format,))
else:
......@@ -175,7 +186,7 @@ def verify_grad_sparse(op, pt, *args, **kwargs):
raise NotImplementedError("verify_grad can't deal with "
"multiple outputs")
if _is_sparse_variable(output):
oconv = dense_from_sparse
oconv = DenseFromSparse(structured=False)
else:
oconv = conv_none
def conv_op(*inputs):
......@@ -740,13 +751,15 @@ class DenseFromSparse(gof.op.Op):
"""
Convert a sparse matrix to an `ndarray`.
"""
sparse_grad = True
"""WRITEME"""
def __init__(self, structured=True):
self.sparse_grad = structured
def __eq__(self, other):
return (type(self) == type(other))
return (type(self) == type(other)) and \
(self.sparse_grad == other.sparse_grad)
def __hash__(self):
return hash(type(self))
return hash(type(self))+hash(self.sparse_grad)
def make_node(self, x):
x = as_sparse_variable(x)
......
......@@ -10,7 +10,7 @@ except ImportError:
pass # The variable enable_sparse will be used to disable the test file.
import theano
from theano import compile, config
from theano import compile, config, gof
from theano.sparse import enable_sparse
from theano.gof.python25 import all, any, product
......@@ -62,6 +62,48 @@ def random_lil(shape, dtype, nnz):
value)
return rval
class T_verify_grad_sparse(unittest.TestCase):
class FailOp(gof.op.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, y):
x, y = map(as_sparse_variable, [x, y])
if x.type.dtype != y.type.dtype:
raise NotImplementedError()
if x.type.format != y.type.format:
raise NotImplementedError()
return gof.Apply(self,
[x, y],
[SparseType(dtype=x.type.dtype,
format=x.type.format
).make_variable()])
def perform(self, node, (x, y), (out, )):
assert _is_sparse(x) and _is_sparse(y)
assert x.shape == y.shape
out[0] = x + y
def grad(self, (x, y), (gz,)):
assert _is_sparse_variable(x) and _is_sparse_variable(y)
assert _is_sparse_variable(gz)
return 2*gz, gz
def infer_shape(self, node, shapes):
return [shapes[0]]
def test_grad_fail(self):
self.assertRaises(verify_grad_sparse.E_grad,
verify_grad_sparse,
self.FailOp(),
[sp.csr_matrix(random_lil((10, 40),
config.floatX, 3)),
sp.csr_matrix(random_lil((10, 40),
config.floatX, 3))])
class T_transpose(unittest.TestCase):
def setUp(self):
......@@ -463,19 +505,12 @@ class test_structureddot(unittest.TestCase):
mat = numpy.asarray(numpy.random.randn(3, 2), 'float32')
verify_grad_sparse(structured_dot, [spmat, mat])
def buildgraphCSC_T(spdata, sym_mat):
csc = CSC(spdata, spmat.indices[:spmat.size],
spmat.indptr, spmat.shape)
assert csc.type.dtype == 'float32'
rval = structured_dot(sym_mat.T, csc.T)
assert rval.type.dtype == 'float32'
return rval
verify_grad_sparse(structured_dot, [spmat, mat], structured=True)
def buildgraph_T(spmat, mat):
return structured_dot(mat.T, spmat.T)
utt.verify_grad(buildgraphCSC_T,
[spmat.data, mat])
verify_grad_sparse(buildgraph_T, [spmat, mat], structured=True)
def test_structureddot_csr_grad(self):
......@@ -486,18 +521,12 @@ class test_structureddot(unittest.TestCase):
mat = numpy.asarray(numpy.random.randn(3, 2), 'float64')
verify_grad_sparse(structured_dot, [spmat, mat])
verify_grad_sparse(structured_dot, [spmat, mat], structured=True)
def buildgraph_T(spdata, sym_mat):
csr = CSR(spdata, spmat.indices[:spmat.size],
spmat.indptr, spmat.shape)
assert csr.type.dtype == 'float64'
rval = structured_dot(sym_mat.T, csr.T)
assert rval.type.dtype == 'float64'
return rval
def buildgraph_T(spmat, mat):
rval = structured_dot(mat.T, spmat.T)
utt.verify_grad(buildgraph,
[spmat.data, mat])
verify_grad_sparse(buildgraph_T, [spmat.data, mat], structured=True)
def test_infer_shape_csr_csc_grad(self):
for sparsetype in ('csr', 'csc'):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论