提交 2805e910 authored 作者: Frederic's avatar Frederic

test Usmm* with infer_shape.

上级 75b81e3a
...@@ -719,6 +719,56 @@ class UsmmTests(unittest.TestCase): ...@@ -719,6 +719,56 @@ class UsmmTests(unittest.TestCase):
assert isinstance(topo[2].op, theano.sparse.Usmm) assert isinstance(topo[2].op, theano.sparse.Usmm)
def test_infer_shape(self):
def mat(format, name, dtype):
if format == 'dense':
return theano.tensor.matrix(name, dtype=dtype)
else:
return theano.sparse.matrix(format, name, dtype=dtype)
params = [('float32', 'float64', 'int16', 'complex64', 'csc', 'dense'),
('float32', 'float64', 'int16', 'complex64', 'csr', 'dense')]
for dtype1, dtype2, dtype3, dtype4, format1, format2 in params:
if format1 == 'dense' and format2 == 'dense':
# Usmm won't be used!
continue
x = mat(format1, 'x', dtype1)
y = mat(format2, 'y', dtype2)
a = theano.tensor.scalar('a', dtype=dtype3)
z = theano.shared(numpy.asarray(self.z, dtype=dtype4).copy())
f_b = lambda z, a, x, y: z - a * (x * y)
x_data = numpy.asarray(self.x, dtype=dtype1)
if format1 != 'dense':
x_data = as_sparse_format(x_data, format1)
y_data = numpy.asarray(self.y, dtype=dtype2)
if format2 != 'dense':
y_data = as_sparse_format(y_data, format2)
a_data = numpy.asarray(1.5, dtype=dtype3)
z_data = numpy.asarray(self.z, dtype=dtype4)
f_b_out = f_b(z_data, a_data, x_data, y_data)
# Can it work inplace?
inplace = dtype4 == theano.scalar.upcast(dtype1, dtype2, dtype3)
# To make it easier to check the toposort
mode = theano.compile.mode.get_default_mode().excluding('fusion')
# test infer_shape of Dot got applied
f_shape = theano.function([a, x, y],
(z - a * theano.sparse.dot(x, y)).shape,
mode=mode)
assert all(f_shape(a_data, x_data, y_data) == f_b_out.shape)
topo = f_shape.maker.env.toposort()
if theano.config.mode != 'FAST_COMPILE':
nb = 0
else:
nb = 1
assert sum([isinstance(node.op, (Dot, Usmm, UsmmCscDense)) for node in topo]) == nb
def test_shape_i(): def test_shape_i():
sparse_dtype = 'float32' sparse_dtype = 'float32'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论