提交 ba9343e0 authored 作者: Thomas George's avatar Thomas George

fixed special type cast cases in slinalg.py and added corresponding testsw

上级 1be571fc
...@@ -202,10 +202,16 @@ class Solve(Op): ...@@ -202,10 +202,16 @@ class Solve(Op):
b = as_tensor_variable(b) b = as_tensor_variable(b)
assert A.ndim == 2 assert A.ndim == 2
assert b.ndim in [1, 2] assert b.ndim in [1, 2]
otype = tensor.tensor( if ((A.dtype == 'float32' and b.dtype == 'float32')
or (A.dtype in ['int8', 'int16'] and b.dtype == 'float32')
or (b.dtype in ['int8', 'int16'] and A.dtype == 'float32')):
o_dtype = 'float32'
else:
o_dtype = 'float64'
x = tensor.tensor(
broadcastable=b.broadcastable, broadcastable=b.broadcastable,
dtype=(A * b).dtype) dtype=o_dtype)
return Apply(self, [A, b], [otype]) return Apply(self, [A, b], [x])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
A, b = inputs A, b = inputs
......
...@@ -7,6 +7,8 @@ from numpy.testing import assert_array_almost_equal ...@@ -7,6 +7,8 @@ from numpy.testing import assert_array_almost_equal
from numpy.testing import dec, assert_array_equal, assert_allclose from numpy.testing import dec, assert_array_equal, assert_allclose
from numpy import inf from numpy import inf
import itertools
import theano import theano
from theano import tensor, function from theano import tensor, function
from theano.tensor.basic import _allclose from theano.tensor.basic import _allclose
...@@ -232,41 +234,23 @@ class test_Solve(utt.InferShapeTester): ...@@ -232,41 +234,23 @@ class test_Solve(utt.InferShapeTester):
def test_solve_dtype(self): def test_solve_dtype(self):
if not imported_scipy: if not imported_scipy:
raise SkipTest("Scipy needed for the Solve op.") raise SkipTest("Scipy needed for the Solve op.")
A_32 = tensor.fmatrix()
b_32 = tensor.fmatrix() dtypes = ['int8', 'int16', 'int32', 'int64',
A_64 = tensor.dmatrix() 'float32', 'float64']
b_64 = tensor.dmatrix()
A_val = numpy.eye(2) A_val = numpy.eye(2)
b_val = numpy.ones((2, 1)) b_val = numpy.ones((2, 1))
# A 64, b 64 for A_dtype, b_dtype in itertools.product(dtypes, dtypes):
x_6464_out = solve(A_64, b_64) A = tensor.matrix(dtype=A_dtype)
fn_6464 = function([A_64, b_64], x_6464_out) b = tensor.matrix(dtype=b_dtype)
x_6464_result = fn_6464(A_val.astype('float64'), b_val.astype('float64')) x = solve(A, b)
assert x_6464_out.dtype == 'float64' fn = function([A, b], x)
assert x_6464_result.dtype == 'float64' x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype))
print(A_dtype, b_dtype)
# A 32, b 32 print('x', x.dtype, x_result.dtype, x.dtype == x_result.dtype)
x_3232_out = solve(A_32, b_32)
fn_3232 = function([A_32, b_32], x_3232_out)
x_3232_result = fn_3232(A_val.astype('float32'), b_val.astype('float32'))
assert x_3232_out.dtype == 'float32'
assert x_3232_result.dtype == 'float32'
# A 64, b 32
x_6432_out = solve(A_64, b_32)
fn_6432 = function([A_64, b_32], x_6432_out)
x_6432_result = fn_6432(A_val.astype('float64'), b_val.astype('float32'))
assert x_6432_out.dtype == 'float64'
assert x_6432_result.dtype == 'float64'
# A 32, b 64
x_3264_out = solve(A_32, b_64)
fn_3264 = function([A_32, b_64], x_3264_out)
x_3264_result = fn_3264(A_val.astype('float32'), b_val.astype('float64'))
assert x_3264_out.dtype == 'float64'
assert x_3264_result.dtype == 'float64'
assert x.dtype == x_result.dtype
def verify_solve_grad(self, m, n, A_structure, lower, rng): def verify_solve_grad(self, m, n, A_structure, lower, rng):
# ensure diagonal elements of A relatively large to avoid numerical # ensure diagonal elements of A relatively large to avoid numerical
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论