提交 eebb3ef0 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix linalg tests in float32.

上级 05cdd8bc
import numpy import numpy
import theano
from theano import tensor, function from theano import tensor, function
from theano.tensor.basic import _allclose
from theano.sandbox.linalg.ops import * from theano.sandbox.linalg.ops import *
...@@ -33,7 +35,7 @@ def test_inverse_correctness(): ...@@ -33,7 +35,7 @@ def test_inverse_correctness():
#todo: unittest randomseed #todo: unittest randomseed
rng = numpy.random.RandomState(12345) rng = numpy.random.RandomState(12345)
r = rng.randn(4,4) r = rng.randn(4,4).astype(theano.config.floatX)
x = tensor.matrix() x = tensor.matrix()
xi = matrix_inverse(x) xi = matrix_inverse(x)
...@@ -45,8 +47,8 @@ def test_inverse_correctness(): ...@@ -45,8 +47,8 @@ def test_inverse_correctness():
rir = numpy.dot(ri,r) rir = numpy.dot(ri,r)
rri = numpy.dot(r,ri) rri = numpy.dot(r,ri)
assert numpy.allclose(numpy.identity(4), rir), rir assert _allclose(numpy.identity(4), rir), rir
assert numpy.allclose(numpy.identity(4), rri), rri assert _allclose(numpy.identity(4), rri), rri
def test_inverse_grad(): def test_inverse_grad():
......
...@@ -949,6 +949,10 @@ SecondBroadcastTester = makeTester( ...@@ -949,6 +949,10 @@ SecondBroadcastTester = makeTester(
) )
) )
def_mode = get_default_mode()
print >>sys.stderr,'default mode:', def_mode
grmbl_mode = def_mode.excluding('local_fill_to_alloc')
print >>sys.stderr,'grmbl mode:', grmbl_mode
SecondSameRankTester = makeTester( SecondSameRankTester = makeTester(
name='SecondSameRankTester', name='SecondSameRankTester',
op=second, op=second,
...@@ -963,7 +967,9 @@ SecondSameRankTester = makeTester( ...@@ -963,7 +967,9 @@ SecondSameRankTester = makeTester(
bad_runtime=dict(itertools.chain( bad_runtime=dict(itertools.chain(
multi_dtype_checks((4, 5), (5, 4)), multi_dtype_checks((4, 5), (5, 4)),
multi_dtype_checks((1, 5), (5, 4)), multi_dtype_checks((1, 5), (5, 4)),
)) )),
#mode=get_default_mode().excluding('local_fill_to_alloc')
mode=grmbl_mode
) )
def test_eye(): def test_eye():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论