提交 2b076111 authored 作者: Frederic's avatar Frederic

Make MulSS and AddSS support mixed input format.

I do not force the grad to be the same format. This would require conversion. I do not think we always enforce this already, so I do not want to enforce that now.
上级 2f29cf86
......@@ -1877,8 +1877,6 @@ class AddSS(gof.op.Op):
assert x.format in ["csr", "csc"]
assert y.format in ["csr", "csc"]
out_dtype = scalar.upcast(x.type.dtype, y.type.dtype)
if x.type.format != y.type.format:
raise NotImplementedError()
return gof.Apply(self,
[x, y],
[SparseType(dtype=out_dtype,
......@@ -2130,10 +2128,6 @@ class MulSS(gof.op.Op):
assert x.format in ["csr", "csc"]
assert y.format in ["csr", "csc"]
out_dtype = scalar.upcast(x.type.dtype, y.type.dtype)
if x.type.format != y.type.format:
raise NotImplementedError(
"MulSS not supported for differing types. "
"Got %s and %s." % (str(x.type), str(y.type)))
return gof.Apply(self, [x, y],
[SparseType(dtype=out_dtype,
format=x.type.format
......
......@@ -16,6 +16,7 @@ from theano import sparse
from theano import compile, config, gof
from theano.sparse import enable_sparse
from theano.gof.python25 import all, any, product
from theano.gof.python25 import product as itertools_product
from theano.tensor.basic import _allclose
if not enable_sparse:
......@@ -520,17 +521,18 @@ class T_AddMul(unittest.TestCase):
def _testSS(self, op, array1=numpy.array([[1., 0], [3, 0], [0, 6]]),
array2=numpy.asarray([[0, 2.], [0, 4], [5, 0]])):
for mtype in _mtypes:
for mtype1, mtype2 in itertools_product(_mtypes, _mtypes):
for dtype1, dtype2 in [('float64', 'int8'),
('int8', 'float64'),
('float64', 'float64'),
]:
a = mtype(array1).astype(dtype1)
a = mtype1(array1).astype(dtype1)
aR = as_sparse_variable(a)
self.assertFalse(aR.data is a)
self.assertTrue(_is_sparse(a))
self.assertTrue(_is_sparse_variable(aR))
b = mtype(array2).astype(dtype2)
b = mtype2(array2).astype(dtype2)
bR = as_sparse_variable(b)
self.assertFalse(bR.data is b)
self.assertTrue(_is_sparse(b))
......@@ -540,7 +542,6 @@ class T_AddMul(unittest.TestCase):
self.assertTrue(_is_sparse_variable(apb))
self.assertTrue(apb.type.format == aR.type.format, apb.type.format)
self.assertTrue(apb.type.format == bR.type.format, apb.type.format)
val = eval_outputs([apb])
self.assertTrue(val.shape == (3, 2))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论