提交 ae398862 authored 作者: Frederic's avatar Frederic

MulSS and AddSS now support different input dtype.

上级 ced94276
......@@ -1654,13 +1654,12 @@ class AddSS(gof.op.Op):
def make_node(self, x, y):
x, y = map(as_sparse_variable, [x, y])
if x.type.dtype != y.type.dtype:
raise NotImplementedError()
out_dtype = scalar.upcast(x.type.dtype, y.type.dtype)
if x.type.format != y.type.format:
raise NotImplementedError()
return gof.Apply(self,
[x, y],
[SparseType(dtype=x.type.dtype,
[SparseType(dtype=out_dtype,
format=x.type.format
).make_variable()])
......@@ -1923,11 +1922,16 @@ class MulSS(gof.op.Op):
def make_node(self, x, y):
x, y = as_sparse_variable(x), as_sparse_variable(y)
if x.type != y.type:
out_dtype = scalar.upcast(x.type.dtype, y.type.dtype)
if x.type.format != y.type.format:
raise NotImplementedError(
"MulSS not supported for differing types. "
"Got %s and %s." % (str(x.type), str(y.type)))
return gof.Apply(self, [x, y], [x.type()])
return gof.Apply(self, [x, y],
[SparseType(dtype=out_dtype,
format=x.type.format
)()])
def perform(self, node, (x, y), (out, )):
assert _is_sparse(x) and _is_sparse(y)
......
......@@ -536,35 +536,38 @@ class T_AddMul(unittest.TestCase):
def _testSS(self, op, array1=numpy.array([[1., 0], [3, 0], [0, 6]]),
array2=numpy.asarray([[0, 2.], [0, 4], [5, 0]])):
for mtype in _mtypes:
a = mtype(array1)
aR = as_sparse_variable(a)
self.assertFalse(aR.data is a)
self.assertTrue(_is_sparse(a))
self.assertTrue(_is_sparse_variable(aR))
for dtype1, dtype2 in [('float64', 'int8'),
('int8', 'float64'),
]:
a = mtype(array1).astype(dtype1)
aR = as_sparse_variable(a)
self.assertFalse(aR.data is a)
self.assertTrue(_is_sparse(a))
self.assertTrue(_is_sparse_variable(aR))
b = mtype(array2)
bR = as_sparse_variable(b)
self.assertFalse(bR.data is b)
self.assertTrue(_is_sparse(b))
self.assertTrue(_is_sparse_variable(bR))
apb = op(aR, bR)
self.assertTrue(_is_sparse_variable(apb))
self.assertTrue(apb.type.dtype == aR.type.dtype, apb.type.dtype)
self.assertTrue(apb.type.dtype == bR.type.dtype, apb.type.dtype)
self.assertTrue(apb.type.format == aR.type.format, apb.type.format)
self.assertTrue(apb.type.format == bR.type.format, apb.type.format)
val = eval_outputs([apb])
self.assertTrue(val.shape == (3, 2))
if op is add:
self.assertTrue(numpy.all(val.todense() == (array1 + array2)))
verify_grad_sparse(op, [a, b], structured=False)
elif op is mul:
self.assertTrue(numpy.all(val.todense()
== (array1 * array2)))
verify_grad_sparse(op, [a, b], structured=False)
b = mtype(array2).astype(dtype2)
bR = as_sparse_variable(b)
self.assertFalse(bR.data is b)
self.assertTrue(_is_sparse(b))
self.assertTrue(_is_sparse_variable(bR))
apb = op(aR, bR)
self.assertTrue(_is_sparse_variable(apb))
self.assertTrue(apb.type.format == aR.type.format, apb.type.format)
self.assertTrue(apb.type.format == bR.type.format, apb.type.format)
val = eval_outputs([apb])
self.assertTrue(val.shape == (3, 2))
if op is add:
self.assertTrue(numpy.all(val.todense() == (array1 + array2)))
if dtype1.startswith('float') and dtype2.startswith('float'):
verify_grad_sparse(op, [a, b], structured=False)
elif op is mul:
self.assertTrue(numpy.all(val.todense()
== (array1 * array2)))
if dtype1.startswith('float') and dtype2.startswith('float'):
verify_grad_sparse(op, [a, b], structured=False)
def _testSD(self, op, array1=numpy.array([[1., 0], [3, 0], [0, 6]]),
array2=numpy.asarray([[0, 2.], [0, 4], [5, 0]])):
......@@ -639,29 +642,7 @@ class T_AddMul(unittest.TestCase):
array2 = numpy.array([[1, 0], [3, 0], [0, 6]], dtype='int32')
array3 = numpy.array([[1, 0], [3, 0], [0, 6]], dtype='int8')
# AddSS and MulSS
for mtype in _mtypes:
a = mtype(array1)
aR = as_sparse_variable(a)
b = mtype(array2)
bR = as_sparse_variable(b)
c = mtype(array3)
cR = as_sparse_variable(c)
# Ops that do not upcast
self.assertRaises(NotImplementedError, add, aR, bR)
self.assertRaises(NotImplementedError, add, bR, aR)
self.assertRaises(NotImplementedError, add, bR, cR)
self.assertRaises(NotImplementedError, add, cR, bR)
self.assertRaises(NotImplementedError, add, aR, cR)
self.assertRaises(NotImplementedError, add, cR, aR)
self.assertRaises(NotImplementedError, mul, aR, bR)
self.assertRaises(NotImplementedError, mul, bR, aR)
self.assertRaises(NotImplementedError, mul, bR, cR)
self.assertRaises(NotImplementedError, mul, cR, bR)
self.assertRaises(NotImplementedError, mul, aR, cR)
self.assertRaises(NotImplementedError, mul, cR, aR)
# AddSS and MulSS upcated tested in _testSS
# AddSD and MulSD
for mtype in _mtypes:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论