提交 1dd5787a authored 作者: Thomas George's avatar Thomas George

added test for output dtype (solve)

上级 630876f4
......@@ -229,6 +229,45 @@ class test_Solve(utt.InferShapeTester):
assert numpy.allclose(scipy.linalg.solve_triangular(U_val, b_val, lower=False),
upper_solve_func(U_val, b_val))
def test_solve_dtype(self):
if not imported_scipy:
raise SkipTest("Scipy needed for the Solve op.")
A_32 = tensor.fmatrix()
b_32 = tensor.fmatrix()
A_64 = tensor.dmatrix()
b_64 = tensor.dmatrix()
A_val = numpy.eye(2)
b_val = numpy.ones((2, 1))
# A 64, b 64
x_6464_out = solve(A_64, b_64)
fn_6464 = function([A_64, b_64], x_6464_out)
x_6464_result = fn_6464(A_val.astype('float64'), b_val.astype('float64'))
assert x_6464_out.dtype == 'float64'
assert x_6464_result.dtype == 'float64'
# A 32, b 32
x_3232_out = solve(A_32, b_32)
fn_3232 = function([A_32, b_32], x_3232_out)
x_3232_result = fn_3232(A_val.astype('float32'), b_val.astype('float32'))
assert x_3232_out.dtype == 'float32'
assert x_3232_result.dtype == 'float32'
# A 64, b 32
x_6432_out = solve(A_64, b_32)
fn_6432 = function([A_64, b_32], x_6432_out)
x_6432_result = fn_6432(A_val.astype('float64'), b_val.astype('float32'))
assert x_6432_out.dtype == 'float64'
assert x_6432_result.dtype == 'float64'
# A 32, b 64
x_3264_out = solve(A_32, b_64)
fn_3264 = function([A_32, b_64], x_3264_out)
x_3264_result = fn_3264(A_val.astype('float32'), b_val.astype('float64'))
assert x_3264_out.dtype == 'float64'
assert x_3264_result.dtype == 'float64'
def verify_solve_grad(self, m, n, A_structure, lower, rng):
# ensure diagonal elements of A relatively large to avoid numerical
# precision issues
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论