Commit 3ff7e92b, authored by abergeron

Merge pull request #2294 from nouiz/mixed

Fix crashes related to sparse
...@@ -1088,6 +1088,17 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1088,6 +1088,17 @@ def _populate_grad_dict(var_to_app_to_idx,
if len(input_grads) != len(inputs): if len(input_grads) != len(inputs):
raise ValueError(("%s returned the wrong number of" + raise ValueError(("%s returned the wrong number of" +
" gradient terms.") % str(node.op)) " gradient terms.") % str(node.op))
# We can not enforce this, as AdvancedSubtensor1 has an option to
# return the sparse grad for optimization reason.
# for ig, i in zip(input_grads, inputs):
# if (not isinstance(ig.type, (DisconnectedType, NullType)) and
# type(ig.type) != type(i.type)):
# raise ValueError(
# "%s returned the wrong type for gradient terms."
# " Sparse inputs must have sparse grads and dense"
# " inputs must have dense grad. Got %s, expected %s" % (
# str(node.op), ig.type, i.type))
# must convert to list in case the op returns a tuple # must convert to list in case the op returns a tuple
# we won't be able to post-process out the Nones if it does that # we won't be able to post-process out the Nones if it does that
......
...@@ -1877,8 +1877,6 @@ class AddSS(gof.op.Op): ...@@ -1877,8 +1877,6 @@ class AddSS(gof.op.Op):
assert x.format in ["csr", "csc"] assert x.format in ["csr", "csc"]
assert y.format in ["csr", "csc"] assert y.format in ["csr", "csc"]
out_dtype = scalar.upcast(x.type.dtype, y.type.dtype) out_dtype = scalar.upcast(x.type.dtype, y.type.dtype)
if x.type.format != y.type.format:
raise NotImplementedError()
return gof.Apply(self, return gof.Apply(self,
[x, y], [x, y],
[SparseType(dtype=out_dtype, [SparseType(dtype=out_dtype,
...@@ -2130,10 +2128,6 @@ class MulSS(gof.op.Op): ...@@ -2130,10 +2128,6 @@ class MulSS(gof.op.Op):
assert x.format in ["csr", "csc"] assert x.format in ["csr", "csc"]
assert y.format in ["csr", "csc"] assert y.format in ["csr", "csc"]
out_dtype = scalar.upcast(x.type.dtype, y.type.dtype) out_dtype = scalar.upcast(x.type.dtype, y.type.dtype)
if x.type.format != y.type.format:
raise NotImplementedError(
"MulSS not supported for differing types. "
"Got %s and %s." % (str(x.type), str(y.type)))
return gof.Apply(self, [x, y], return gof.Apply(self, [x, y],
[SparseType(dtype=out_dtype, [SparseType(dtype=out_dtype,
format=x.type.format format=x.type.format
...@@ -2244,7 +2238,7 @@ class MulSD(gof.op.Op): ...@@ -2244,7 +2238,7 @@ class MulSD(gof.op.Op):
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
assert _is_sparse_variable(x) and _is_dense_variable(y) assert _is_sparse_variable(x) and _is_dense_variable(y)
assert _is_sparse_variable(gz) assert _is_sparse_variable(gz)
return y * gz, x * gz return y * gz, dense_from_sparse(x * gz)
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
return [shapes[0]] return [shapes[0]]
...@@ -3715,11 +3709,15 @@ class Dot(gof.op.Op): ...@@ -3715,11 +3709,15 @@ class Dot(gof.op.Op):
raise NotImplementedError() raise NotImplementedError()
def make_node(self, x, y): def make_node(self, x, y):
dtype_out = scalar.upcast(x.type.dtype, y.type.dtype) dtype_out = scalar.upcast(x.dtype, y.dtype)
# Sparse dot product should have at least one sparse variable # Sparse dot product should have at least one sparse variable
# as input. If the other one is not sparse, it has to be converted # as input. If the other one is not sparse, it has to be converted
# into a tensor. # into a tensor.
if isinstance(x, scipy.sparse.spmatrix):
x = as_sparse_variable(x)
if isinstance(y, scipy.sparse.spmatrix):
y = as_sparse_variable(y)
x_is_sparse_var = _is_sparse_variable(x) x_is_sparse_var = _is_sparse_variable(x)
y_is_sparse_var = _is_sparse_variable(y) y_is_sparse_var = _is_sparse_variable(y)
......
...@@ -16,6 +16,7 @@ from theano import sparse ...@@ -16,6 +16,7 @@ from theano import sparse
from theano import compile, config, gof from theano import compile, config, gof
from theano.sparse import enable_sparse from theano.sparse import enable_sparse
from theano.gof.python25 import all, any, product from theano.gof.python25 import all, any, product
from theano.gof.python25 import product as itertools_product
from theano.tensor.basic import _allclose from theano.tensor.basic import _allclose
if not enable_sparse: if not enable_sparse:
...@@ -520,17 +521,18 @@ class T_AddMul(unittest.TestCase): ...@@ -520,17 +521,18 @@ class T_AddMul(unittest.TestCase):
def _testSS(self, op, array1=numpy.array([[1., 0], [3, 0], [0, 6]]), def _testSS(self, op, array1=numpy.array([[1., 0], [3, 0], [0, 6]]),
array2=numpy.asarray([[0, 2.], [0, 4], [5, 0]])): array2=numpy.asarray([[0, 2.], [0, 4], [5, 0]])):
for mtype in _mtypes: for mtype1, mtype2 in itertools_product(_mtypes, _mtypes):
for dtype1, dtype2 in [('float64', 'int8'), for dtype1, dtype2 in [('float64', 'int8'),
('int8', 'float64'), ('int8', 'float64'),
('float64', 'float64'),
]: ]:
a = mtype(array1).astype(dtype1) a = mtype1(array1).astype(dtype1)
aR = as_sparse_variable(a) aR = as_sparse_variable(a)
self.assertFalse(aR.data is a) self.assertFalse(aR.data is a)
self.assertTrue(_is_sparse(a)) self.assertTrue(_is_sparse(a))
self.assertTrue(_is_sparse_variable(aR)) self.assertTrue(_is_sparse_variable(aR))
b = mtype(array2).astype(dtype2) b = mtype2(array2).astype(dtype2)
bR = as_sparse_variable(b) bR = as_sparse_variable(b)
self.assertFalse(bR.data is b) self.assertFalse(bR.data is b)
self.assertTrue(_is_sparse(b)) self.assertTrue(_is_sparse(b))
...@@ -540,7 +542,6 @@ class T_AddMul(unittest.TestCase): ...@@ -540,7 +542,6 @@ class T_AddMul(unittest.TestCase):
self.assertTrue(_is_sparse_variable(apb)) self.assertTrue(_is_sparse_variable(apb))
self.assertTrue(apb.type.format == aR.type.format, apb.type.format) self.assertTrue(apb.type.format == aR.type.format, apb.type.format)
self.assertTrue(apb.type.format == bR.type.format, apb.type.format)
val = eval_outputs([apb]) val = eval_outputs([apb])
self.assertTrue(val.shape == (3, 2)) self.assertTrue(val.shape == (3, 2))
...@@ -561,6 +562,8 @@ class T_AddMul(unittest.TestCase): ...@@ -561,6 +562,8 @@ class T_AddMul(unittest.TestCase):
theano.shared(array1)]: theano.shared(array1)]:
for dtype1, dtype2 in [('float64', 'int8'), for dtype1, dtype2 in [('float64', 'int8'),
('int8', 'float64'), ('int8', 'float64'),
# Needed to test the grad
('float32', 'float64'),
]: ]:
a = a.astype(dtype1) a = a.astype(dtype1)
b = mtype(array2).astype(dtype2) b = mtype(array2).astype(dtype2)
...@@ -580,6 +583,8 @@ class T_AddMul(unittest.TestCase): ...@@ -580,6 +583,8 @@ class T_AddMul(unittest.TestCase):
self.assertTrue(numpy.all(val == ans)) self.assertTrue(numpy.all(val == ans))
if isinstance(a, theano.Constant): if isinstance(a, theano.Constant):
a = a.data a = a.data
if getattr(a, 'owner', None):
continue
if dtype1.startswith('float') and dtype2.startswith('float'): if dtype1.startswith('float') and dtype2.startswith('float'):
verify_grad_sparse(op, [a, b], structured=True) verify_grad_sparse(op, [a, b], structured=True)
elif op is mul: elif op is mul:
...@@ -589,6 +594,8 @@ class T_AddMul(unittest.TestCase): ...@@ -589,6 +594,8 @@ class T_AddMul(unittest.TestCase):
[[1, 0], [9, 0], [0, 36]]))) [[1, 0], [9, 0], [0, 36]])))
if isinstance(a, theano.Constant): if isinstance(a, theano.Constant):
a = a.data a = a.data
if getattr(a, 'owner', None):
continue
if dtype1.startswith('float') and dtype2.startswith('float'): if dtype1.startswith('float') and dtype2.startswith('float'):
verify_grad_sparse(op, [a, b], structured=False) verify_grad_sparse(op, [a, b], structured=False)
...@@ -1330,6 +1337,20 @@ class DotTests(utt.InferShapeTester): ...@@ -1330,6 +1337,20 @@ class DotTests(utt.InferShapeTester):
dtype=intX) dtype=intX)
f(i, a) f(i, a)
def test_csr_dense_grad(self):
    """Check the gradient of Dot() for a dense matrix times a CSR sparse matrix."""
    # Random CSR right operand and a shape-compatible dense left operand.
    sparse_mat = sp.csr_matrix(random_lil((4, 3), 'float64', 3))
    dense_mat = numpy.asarray(numpy.random.randn(2, 4), 'float64')

    def dot_with_sparse(m):
        # Only the dense input is differentiated; the sparse matrix is a
        # closure constant, so verify_grad checks d(Dot)/d(dense) only.
        return Dot()(m, sparse_mat)

    theano.tests.unittest_tools.verify_grad(dot_with_sparse, [dense_mat])
class UsmmTests(unittest.TestCase): class UsmmTests(unittest.TestCase):
""" Test the Usmm and UsmmCscDense class and related optimization """ """ Test the Usmm and UsmmCscDense class and related optimization """
......
Markdown is supported
0%
You are about to add 0 people to this discussion. Please proceed with caution.
Finish editing this message first!
Register or sign in to post a comment