提交 f4f40bbd authored 作者: Frederic's avatar Frederic

fix Usmm test dtype for a and z variable.

上级 9e0e8bbf
......@@ -649,9 +649,10 @@ class UsmmTests(unittest.TestCase):
y_data = numpy.asarray(self.y, dtype=dtype2)
if format2 != 'dense':
y_data = as_sparse_format(y_data, format2)
z_data = numpy.asarray(self.z, dtype=dtype3)
a_data = numpy.asarray(1.5, dtype=dtype3)
z_data = numpy.asarray(self.z, dtype=dtype4)
f_b_out = f_b(z_data, 1, x_data, y_data)
f_b_out = f_b(z_data, a_data, x_data, y_data)
# Can it work inplace?
inplace = dtype4 == theano.scalar.upcast(dtype1, dtype2, dtype3)
......@@ -664,8 +665,8 @@ class UsmmTests(unittest.TestCase):
f_a = theano.function([a, x, y], [],
updates=updates,
mode=mode)
f_a(1, x_data, y_data)
assert abs(z.get_value(borrow=True) - f_b_out).max() < 1e-4
f_a(a_data, x_data, y_data)
f_a_out = z.get_value(borrow=True)
else:
f_a = theano.function([a, x, y],
z - a * theano.sparse.dot(x, y),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论