提交 16a0b11c authored 作者: Thomas George's avatar Thomas George

general dtype inference using solve

上级 ba9343e0
...@@ -202,12 +202,12 @@ class Solve(Op): ...@@ -202,12 +202,12 @@ class Solve(Op):
b = as_tensor_variable(b) b = as_tensor_variable(b)
assert A.ndim == 2 assert A.ndim == 2
assert b.ndim in [1, 2] assert b.ndim in [1, 2]
if ((A.dtype == 'float32' and b.dtype == 'float32')
or (A.dtype in ['int8', 'int16'] and b.dtype == 'float32') # infer dtype by solving the most simple
or (b.dtype in ['int8', 'int16'] and A.dtype == 'float32')): # case with (1, 1) matrices
o_dtype = 'float32' o_dtype = scipy.linalg.solve(
else: numpy.eye(1).astype(A.dtype),
o_dtype = 'float64' numpy.eye(1).astype(b.dtype)).dtype
x = tensor.tensor( x = tensor.tensor(
broadcastable=b.broadcastable, broadcastable=b.broadcastable,
dtype=o_dtype) dtype=o_dtype)
......
...@@ -235,12 +235,14 @@ class test_Solve(utt.InferShapeTester): ...@@ -235,12 +235,14 @@ class test_Solve(utt.InferShapeTester):
if not imported_scipy: if not imported_scipy:
raise SkipTest("Scipy needed for the Solve op.") raise SkipTest("Scipy needed for the Solve op.")
dtypes = ['int8', 'int16', 'int32', 'int64', dtypes = ['uint8', 'uint16', 'uint32', 'uint64',
'float32', 'float64'] 'int8', 'int16', 'int32', 'int64',
'float16', 'float32', 'float64']
A_val = numpy.eye(2) A_val = numpy.eye(2)
b_val = numpy.ones((2, 1)) b_val = numpy.ones((2, 1))
# try all dtype combinations
for A_dtype, b_dtype in itertools.product(dtypes, dtypes): for A_dtype, b_dtype in itertools.product(dtypes, dtypes):
A = tensor.matrix(dtype=A_dtype) A = tensor.matrix(dtype=A_dtype)
b = tensor.matrix(dtype=b_dtype) b = tensor.matrix(dtype=b_dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论