提交 d70f4daa authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add test for AddSS (structured=False), and a fail test for structured=True

上级 75ec37f3
...@@ -186,7 +186,7 @@ def verify_grad_sparse(op, pt, structured=False, *args, **kwargs): ...@@ -186,7 +186,7 @@ def verify_grad_sparse(op, pt, structured=False, *args, **kwargs):
raise NotImplementedError("verify_grad can't deal with " raise NotImplementedError("verify_grad can't deal with "
"multiple outputs") "multiple outputs")
if _is_sparse_variable(output): if _is_sparse_variable(output):
oconv = DenseFromSparse(structured=False) oconv = DenseFromSparse(structured=structured)
else: else:
oconv = conv_none oconv = conv_none
def conv_op(*inputs): def conv_op(*inputs):
......
...@@ -26,7 +26,7 @@ from theano.sparse import AddSS, AddSD, MulSS, MulSD, Transpose, Neg ...@@ -26,7 +26,7 @@ from theano.sparse import AddSS, AddSD, MulSS, MulSD, Transpose, Neg
from theano.sparse import add, mul, structured_dot, transpose from theano.sparse import add, mul, structured_dot, transpose
from theano.sparse import (csc_from_dense, csr_from_dense, dense_from_sparse, from theano.sparse import (csc_from_dense, csr_from_dense, dense_from_sparse,
SparseFromDense) SparseFromDense)
from theano.sparse import Dot, Usmm, UsmmCscDense from theano.sparse import Dot, Usmm, UsmmCscDense, sp_ones_like
#from theano.sparse import get_item_2d, get_item_scalar #from theano.sparse import get_item_2d, get_item_scalar
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -64,44 +64,45 @@ def random_lil(shape, dtype, nnz): ...@@ -64,44 +64,45 @@ def random_lil(shape, dtype, nnz):
class T_verify_grad_sparse(unittest.TestCase): class T_verify_grad_sparse(unittest.TestCase):
class FailOp(gof.op.Op): class FailOp(gof.op.Op):
def __init__(self, structured):
self.structured = structured
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other)) and \
self.structured == other.structured
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self)) + hash(self.structured)
def make_node(self, x, y): def make_node(self, x):
x, y = map(as_sparse_variable, [x, y]) x = as_sparse_variable(x)
if x.type.dtype != y.type.dtype: return gof.Apply(self, [x], [x.type()])
raise NotImplementedError()
if x.type.format != y.type.format: def perform(self, node, (x, ), (out, )):
raise NotImplementedError() assert _is_sparse(x)
return gof.Apply(self, out[0] = -x
[x, y],
[SparseType(dtype=x.type.dtype, def grad(self, (x,), (gz,)):
format=x.type.format assert _is_sparse_variable(x) and _is_sparse_variable(gz)
).make_variable()]) if self.structured:
return sp_ones_like(x)*dense_from_sparse(gz),
def perform(self, node, (x, y), (out, )): else:
assert _is_sparse(x) and _is_sparse(y) return gz,
assert x.shape == y.shape
out[0] = x + y
def grad(self, (x, y), (gz,)):
assert _is_sparse_variable(x) and _is_sparse_variable(y)
assert _is_sparse_variable(gz)
return 2*gz, gz
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
return [shapes[0]] return [shapes[0]]
def test_grad_fail(self): def test_grad_fail(self):
self.assertRaises(verify_grad_sparse.E_grad, self.assertRaises(verify_grad_sparse.E_grad,
verify_grad_sparse, verify_grad_sparse,
self.FailOp(), self.FailOp(structured=False),
[sp.csr_matrix(random_lil((10, 40),
config.floatX, 3))])
self.assertRaises(verify_grad_sparse.E_grad,
verify_grad_sparse,
self.FailOp(structured=True),
[sp.csr_matrix(random_lil((10, 40), [sp.csr_matrix(random_lil((10, 40),
config.floatX, 3)),
sp.csr_matrix(random_lil((10, 40),
config.floatX, 3))]) config.floatX, 3))])
...@@ -294,6 +295,7 @@ class T_AddMul(unittest.TestCase): ...@@ -294,6 +295,7 @@ class T_AddMul(unittest.TestCase):
self.assertTrue(numpy.all(val.todense() == (a + b).todense())) self.assertTrue(numpy.all(val.todense() == (a + b).todense()))
ans = numpy.array([[1., 2], [3, 4], [5, 6]]) ans = numpy.array([[1., 2], [3, 4], [5, 6]])
self.assertTrue(numpy.all(val.todense() == ans)) self.assertTrue(numpy.all(val.todense() == ans))
verify_grad_sparse(op, [a, b], structured=False)
elif op is mul: elif op is mul:
self.assertTrue(numpy.all(val.todense() self.assertTrue(numpy.all(val.todense()
== (a.multiply(b)).todense())) == (a.multiply(b)).todense()))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论