提交 ba9343e0 authored 作者: Thomas George's avatar Thomas George

fixed special type cast cases in slinalg.py and added corresponding testsw

上级 1be571fc
......@@ -202,10 +202,16 @@ class Solve(Op):
b = as_tensor_variable(b)
assert A.ndim == 2
assert b.ndim in [1, 2]
otype = tensor.tensor(
if ((A.dtype == 'float32' and b.dtype == 'float32')
or (A.dtype in ['int8', 'int16'] and b.dtype == 'float32')
or (b.dtype in ['int8', 'int16'] and A.dtype == 'float32')):
o_dtype = 'float32'
else:
o_dtype = 'float64'
x = tensor.tensor(
broadcastable=b.broadcastable,
dtype=(A * b).dtype)
return Apply(self, [A, b], [otype])
dtype=o_dtype)
return Apply(self, [A, b], [x])
def perform(self, node, inputs, output_storage):
A, b = inputs
......
......@@ -7,6 +7,8 @@ from numpy.testing import assert_array_almost_equal
from numpy.testing import dec, assert_array_equal, assert_allclose
from numpy import inf
import itertools
import theano
from theano import tensor, function
from theano.tensor.basic import _allclose
......@@ -232,41 +234,23 @@ class test_Solve(utt.InferShapeTester):
def test_solve_dtype(self):
if not imported_scipy:
raise SkipTest("Scipy needed for the Solve op.")
A_32 = tensor.fmatrix()
b_32 = tensor.fmatrix()
A_64 = tensor.dmatrix()
b_64 = tensor.dmatrix()
dtypes = ['int8', 'int16', 'int32', 'int64',
'float32', 'float64']
A_val = numpy.eye(2)
b_val = numpy.ones((2, 1))
# A 64, b 64
x_6464_out = solve(A_64, b_64)
fn_6464 = function([A_64, b_64], x_6464_out)
x_6464_result = fn_6464(A_val.astype('float64'), b_val.astype('float64'))
assert x_6464_out.dtype == 'float64'
assert x_6464_result.dtype == 'float64'
# A 32, b 32
x_3232_out = solve(A_32, b_32)
fn_3232 = function([A_32, b_32], x_3232_out)
x_3232_result = fn_3232(A_val.astype('float32'), b_val.astype('float32'))
assert x_3232_out.dtype == 'float32'
assert x_3232_result.dtype == 'float32'
# A 64, b 32
x_6432_out = solve(A_64, b_32)
fn_6432 = function([A_64, b_32], x_6432_out)
x_6432_result = fn_6432(A_val.astype('float64'), b_val.astype('float32'))
assert x_6432_out.dtype == 'float64'
assert x_6432_result.dtype == 'float64'
# A 32, b 64
x_3264_out = solve(A_32, b_64)
fn_3264 = function([A_32, b_64], x_3264_out)
x_3264_result = fn_3264(A_val.astype('float32'), b_val.astype('float64'))
assert x_3264_out.dtype == 'float64'
assert x_3264_result.dtype == 'float64'
for A_dtype, b_dtype in itertools.product(dtypes, dtypes):
A = tensor.matrix(dtype=A_dtype)
b = tensor.matrix(dtype=b_dtype)
x = solve(A, b)
fn = function([A, b], x)
x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype))
print(A_dtype, b_dtype)
print('x', x.dtype, x_result.dtype, x.dtype == x_result.dtype)
assert x.dtype == x_result.dtype
def verify_solve_grad(self, m, n, A_structure, lower, rng):
# ensure diagonal elements of A relatively large to avoid numerical
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论