提交 ae398862 authored 作者: Frederic's avatar Frederic

MulSS and AddSS now support different input dtype.

上级 ced94276
...@@ -1654,13 +1654,12 @@ class AddSS(gof.op.Op): ...@@ -1654,13 +1654,12 @@ class AddSS(gof.op.Op):
def make_node(self, x, y): def make_node(self, x, y):
x, y = map(as_sparse_variable, [x, y]) x, y = map(as_sparse_variable, [x, y])
if x.type.dtype != y.type.dtype: out_dtype = scalar.upcast(x.type.dtype, y.type.dtype)
raise NotImplementedError()
if x.type.format != y.type.format: if x.type.format != y.type.format:
raise NotImplementedError() raise NotImplementedError()
return gof.Apply(self, return gof.Apply(self,
[x, y], [x, y],
[SparseType(dtype=x.type.dtype, [SparseType(dtype=out_dtype,
format=x.type.format format=x.type.format
).make_variable()]) ).make_variable()])
...@@ -1923,11 +1922,16 @@ class MulSS(gof.op.Op): ...@@ -1923,11 +1922,16 @@ class MulSS(gof.op.Op):
def make_node(self, x, y): def make_node(self, x, y):
x, y = as_sparse_variable(x), as_sparse_variable(y) x, y = as_sparse_variable(x), as_sparse_variable(y)
if x.type != y.type: out_dtype = scalar.upcast(x.type.dtype, y.type.dtype)
if x.type.format != y.type.format:
raise NotImplementedError( raise NotImplementedError(
"MulSS not supported for differing types. " "MulSS not supported for differing types. "
"Got %s and %s." % (str(x.type), str(y.type))) "Got %s and %s." % (str(x.type), str(y.type)))
return gof.Apply(self, [x, y], [x.type()]) return gof.Apply(self, [x, y],
[SparseType(dtype=out_dtype,
format=x.type.format
)()])
def perform(self, node, (x, y), (out, )): def perform(self, node, (x, y), (out, )):
assert _is_sparse(x) and _is_sparse(y) assert _is_sparse(x) and _is_sparse(y)
......
...@@ -536,35 +536,38 @@ class T_AddMul(unittest.TestCase): ...@@ -536,35 +536,38 @@ class T_AddMul(unittest.TestCase):
def _testSS(self, op, array1=numpy.array([[1., 0], [3, 0], [0, 6]]), def _testSS(self, op, array1=numpy.array([[1., 0], [3, 0], [0, 6]]),
array2=numpy.asarray([[0, 2.], [0, 4], [5, 0]])): array2=numpy.asarray([[0, 2.], [0, 4], [5, 0]])):
for mtype in _mtypes: for mtype in _mtypes:
a = mtype(array1) for dtype1, dtype2 in [('float64', 'int8'),
aR = as_sparse_variable(a) ('int8', 'float64'),
self.assertFalse(aR.data is a) ]:
self.assertTrue(_is_sparse(a)) a = mtype(array1).astype(dtype1)
self.assertTrue(_is_sparse_variable(aR)) aR = as_sparse_variable(a)
self.assertFalse(aR.data is a)
self.assertTrue(_is_sparse(a))
self.assertTrue(_is_sparse_variable(aR))
b = mtype(array2) b = mtype(array2).astype(dtype2)
bR = as_sparse_variable(b) bR = as_sparse_variable(b)
self.assertFalse(bR.data is b) self.assertFalse(bR.data is b)
self.assertTrue(_is_sparse(b)) self.assertTrue(_is_sparse(b))
self.assertTrue(_is_sparse_variable(bR)) self.assertTrue(_is_sparse_variable(bR))
apb = op(aR, bR) apb = op(aR, bR)
self.assertTrue(_is_sparse_variable(apb)) self.assertTrue(_is_sparse_variable(apb))
self.assertTrue(apb.type.dtype == aR.type.dtype, apb.type.dtype) self.assertTrue(apb.type.format == aR.type.format, apb.type.format)
self.assertTrue(apb.type.dtype == bR.type.dtype, apb.type.dtype) self.assertTrue(apb.type.format == bR.type.format, apb.type.format)
self.assertTrue(apb.type.format == aR.type.format, apb.type.format)
self.assertTrue(apb.type.format == bR.type.format, apb.type.format) val = eval_outputs([apb])
self.assertTrue(val.shape == (3, 2))
val = eval_outputs([apb]) if op is add:
self.assertTrue(val.shape == (3, 2)) self.assertTrue(numpy.all(val.todense() == (array1 + array2)))
if op is add: if dtype1.startswith('float') and dtype2.startswith('float'):
self.assertTrue(numpy.all(val.todense() == (array1 + array2))) verify_grad_sparse(op, [a, b], structured=False)
verify_grad_sparse(op, [a, b], structured=False) elif op is mul:
elif op is mul: self.assertTrue(numpy.all(val.todense()
self.assertTrue(numpy.all(val.todense() == (array1 * array2)))
== (array1 * array2))) if dtype1.startswith('float') and dtype2.startswith('float'):
verify_grad_sparse(op, [a, b], structured=False) verify_grad_sparse(op, [a, b], structured=False)
def _testSD(self, op, array1=numpy.array([[1., 0], [3, 0], [0, 6]]), def _testSD(self, op, array1=numpy.array([[1., 0], [3, 0], [0, 6]]),
array2=numpy.asarray([[0, 2.], [0, 4], [5, 0]])): array2=numpy.asarray([[0, 2.], [0, 4], [5, 0]])):
...@@ -639,29 +642,7 @@ class T_AddMul(unittest.TestCase): ...@@ -639,29 +642,7 @@ class T_AddMul(unittest.TestCase):
array2 = numpy.array([[1, 0], [3, 0], [0, 6]], dtype='int32') array2 = numpy.array([[1, 0], [3, 0], [0, 6]], dtype='int32')
array3 = numpy.array([[1, 0], [3, 0], [0, 6]], dtype='int8') array3 = numpy.array([[1, 0], [3, 0], [0, 6]], dtype='int8')
# AddSS and MulSS # AddSS and MulSS upcated tested in _testSS
for mtype in _mtypes:
a = mtype(array1)
aR = as_sparse_variable(a)
b = mtype(array2)
bR = as_sparse_variable(b)
c = mtype(array3)
cR = as_sparse_variable(c)
# Ops that do not upcast
self.assertRaises(NotImplementedError, add, aR, bR)
self.assertRaises(NotImplementedError, add, bR, aR)
self.assertRaises(NotImplementedError, add, bR, cR)
self.assertRaises(NotImplementedError, add, cR, bR)
self.assertRaises(NotImplementedError, add, aR, cR)
self.assertRaises(NotImplementedError, add, cR, aR)
self.assertRaises(NotImplementedError, mul, aR, bR)
self.assertRaises(NotImplementedError, mul, bR, aR)
self.assertRaises(NotImplementedError, mul, bR, cR)
self.assertRaises(NotImplementedError, mul, cR, bR)
self.assertRaises(NotImplementedError, mul, aR, cR)
self.assertRaises(NotImplementedError, mul, cR, aR)
# AddSD and MulSD # AddSD and MulSD
for mtype in _mtypes: for mtype in _mtypes:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论