提交 c26605d5 authored 作者: Frederic Bastien's avatar Frederic Bastien

make shared constructor for scalar respect floatX and make test that tensor…

make shared constructor for scalar respect floatX and make test that tensor shared constructor respects it.
上级 a266b4c1
......@@ -17,7 +17,7 @@ class Test_SharedVariable(unittest.TestCase):
else:
assert shared(7).type == theano.tensor.lscalar
assert shared(7.0).type == theano.tensor.dscalar
assert shared(7.0).type == theano.tensor.scalar().type
assert shared(7, dtype='float64').type == theano.tensor.dscalar
# test tensor constructor
......@@ -110,18 +110,121 @@ class Test_SharedVariable(unittest.TestCase):
u.value = uval
assert u.value is uval
def test_strict(self):
def test_scalar_strict(self):
def f(var, val): var.value = val
b = shared(numpy.int64(7), strict=True)
#assert b.type == Scalar('int64')
assert b.type == theano.tensor.lscalar
self.failUnlessRaises(TypeError, f, b, 8.23)
b = shared(numpy.float64(7.234), strict=True)
#assert b.type == Scalar('float64')
assert b.type == theano.tensor.dscalar
self.failUnlessRaises(TypeError, f, b, 8)
b = shared(numpy.float32(7.234), strict=True)
assert b.type == theano.tensor.fscalar
self.failUnlessRaises(TypeError, f, b, 8)
b = shared(numpy.float(7.234), strict=True)
assert b.type == theano.tensor.dscalar
self.failUnlessRaises(TypeError, f, b, 8)
b = shared(7.234, strict=True)
assert b.type == theano.tensor.dscalar
self.failUnlessRaises(TypeError, f, b, 8)
c = shared(numpy.zeros((5,5), dtype='float32'))
self.failUnlessRaises(TypeError, f, b, numpy.random.rand(5,5))
def test_tensor_strict(self):
def f(var, val): var.value = val
b = shared(numpy.int64([7]), strict=True)
assert b.type == theano.tensor.lvector
self.failUnlessRaises(TypeError, f, b, 8.23)
b = shared(numpy.float64([7.234]), strict=True)
assert b.type == theano.tensor.dvector
self.failUnlessRaises(TypeError, f, b, 8)
b = shared(numpy.float32([7.234]), strict=True)
assert b.type == theano.tensor.fvector
self.failUnlessRaises(TypeError, f, b, 8)
#numpy.float([7.234]) don't work
# b = shared(numpy.float([7.234]), strict=True)
# assert b.type == theano.tensor.dvector
# self.failUnlessRaises(TypeError, f, b, 8)
#This generate a generic type. Should we cast? I don't think.
# b = shared([7.234], strict=True)
# assert b.type == theano.tensor.dvector
# self.failUnlessRaises(TypeError, f, b, 8)
c = shared(numpy.zeros((5,5), dtype='float32'))
self.failUnlessRaises(TypeError, f, b, numpy.random.rand(5,5))
def test_scalar_floatX(self):
def f(var, val): var.value = val
b = shared(numpy.int64(7))
assert b.type == theano.tensor.lscalar
f(b,8.23)
b = shared(numpy.float64(7.234))
assert b.type == theano.tensor.dscalar
f(b,8)
b = shared(numpy.float32(7.234))
assert b.type == theano.tensor.fscalar
f(b,8)
b = shared(numpy.float(7.234))
assert b.dtype == theano.config.floatX
f(b,8)
b = shared(7.234)
assert b.dtype == theano.config.floatX
f(b,8)
c = shared(numpy.zeros((5,5), dtype='float32'))
self.failUnlessRaises(TypeError, f, b, numpy.random.rand(5,5))
def test_tensor_floatX(self):
def f(var, val): var.value = val
b = shared(numpy.int64([7]))
assert b.type == theano.tensor.lvector
f(b,[8.23])
b = shared(numpy.float64([7.234]))
assert b.type == theano.tensor.dvector
f(b,[8])
b = shared(numpy.float32([7.234]))
assert b.type == theano.tensor.fvector
f(b,[8])
#numpy.float([7.234]) don't work
# b = shared(numpy.float([7.234]))
# assert b.type == theano.tensor.dvector
# f(b,[8])
#This generate a generic type. Should we cast? I don't think.
# b = shared([7.234])
# assert b.type == theano.tensor.dvector
# f(b,[8])
b = shared(numpy.asarray([7.234],dtype=theano.config.floatX))
assert b.dtype == theano.config.floatX
f(b,[8])
c = shared(numpy.zeros((5,5), dtype='float32'))
self.failUnlessRaises(TypeError, f, b, numpy.random.rand(5,5))
......
......@@ -3,6 +3,7 @@ import numpy
import theano.tensor.basic
from basic import TensorType, _tensor_py_operators
from theano.compile import shared_constructor, SharedVariable
from theano import config
class TensorSharedVariable(SharedVariable, _tensor_py_operators):
pass
......@@ -43,7 +44,13 @@ def scalar_constructor(value, name=None, strict=False, dtype=None):
if not isinstance (value, (numpy.number, float, int)):
raise TypeError()
if dtype is None:
if isinstance(value, float):
if isinstance(value, numpy.float64):
dtype = 'float64'
elif isinstance(value, numpy.float32):
dtype = 'float32'
elif isinstance(value, float) and not strict:
dtype = config.floatX
elif isinstance(value, float):
dtype = 'float64'
elif isinstance(value, int):
dtype = 'int64'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论