提交 c16684cc authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Use theano.floatX rather than fixed types.

上级 f07172d1
...@@ -4,6 +4,7 @@ import theano ...@@ -4,6 +4,7 @@ import theano
from theano import tensor, function from theano import tensor, function
from theano.tensor.basic import _allclose from theano.tensor.basic import _allclose
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano import config
utt.seed_rng() utt.seed_rng()
...@@ -94,19 +95,17 @@ def test_det_grad(): ...@@ -94,19 +95,17 @@ def test_det_grad():
def test_extract_diag(): def test_extract_diag():
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
x = theano.tensor.fmatrix() x = theano.tensor.matrix()
g = extract_diag(x) g = extract_diag(x)
f = theano.function([x], g) f = theano.function([x], g)
v = numpy.array([0.1, 0.2, 0.4]).astype('float32') m = rng.rand(3,3).astype(config.floatX)
m = numpy.diag(v) v = numpy.diag(m)
r = f(m) r = f(m)
# The right diagonal is extracted # The right diagonal is extracted
assert (r == v).all() assert (r == v).all()
# Make sure we don't have a view
assert r.base is None
m = rng.rand(2, 3).astype('float32') m = rng.rand(2, 3).astype(config.floatX)
ok = False ok = False
try: try:
r = f(m) r = f(m)
...@@ -114,10 +113,10 @@ def test_extract_diag(): ...@@ -114,10 +113,10 @@ def test_extract_diag():
ok = True ok = True
assert ok assert ok
xx = theano.tensor.fvector() xx = theano.tensor.vector()
ok = False ok = False
try: try:
g = extract_diag(xx) extract_diag(xx)
except TypeError: except TypeError:
ok = True ok = True
assert ok assert ok
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论