提交 0960e8b8 authored 作者: Frederic Bastien's avatar Frederic Bastien

added the theano dtype floatX with THEANO_FLAGS=floatX=float{32,64} to allow…

added the theano dtype floatX with THEANO_FLAGS=floatX=float{32,64} to allow selection of the type at execution time. Useful for GPU work, and speeds up some computations.
上级 b8a03dad
......@@ -10,6 +10,7 @@ default_={
'lib.amdlibm':False,
'op.set_flops':False,#currently used only in ConvOp. The profile mode will print the flops/s for the op.
'nvcc.fastmath':False,
'scalar.floatX':'float64',
}
# default values taken from environment variables
......@@ -147,3 +148,5 @@ class TheanoConfig(object):
# Module-level singleton configuration object.
config = TheanoConfig()

# Validate the floatX setting eagerly so a bad THEANO_FLAGS value fails
# at import time instead of deep inside graph construction.
if config.get('scalar.floatX') not in ['float32', 'float64']:
    # BUG FIX: the message previously named the option "scalar.real",
    # but the option actually checked (and declared in the default_
    # dict above) is scalar.floatX.
    raise Exception(
        "the configuration scalar.floatX must have value float32 or float64")
......@@ -228,6 +228,11 @@ class Kouh2008(object):
class Config(object):
    # Benchmark/experiment configuration constants.
    # NOTE: the class body continues past this view (next hunk resumes
    # the same file further down).
    use_gpu = True
    dtype='float32'
    # dtype2 resolves a 'floatX' request through the theano config at
    # class-definition time.
    # NOTE(review): dtype is hard-coded to 'float32' above, so dtype2 can
    # never equal 'floatX' and this branch is dead as written — presumably
    # dtype was meant to be settable to 'floatX'. TODO confirm intent.
    dtype2=dtype
    if dtype2=='floatX':
        import theano.config as c
        dtype2 = c.config.get('scalar.floatX')
    # Seed for the experiment's random number generator.
    rng_seed = 23498
    # Number of hidden units.
    n_hid = 300
......@@ -296,7 +301,7 @@ def test_bench_elemwise(n_iter=1000, **kwargs):
xval = numpy.asarray(
rng.uniform(size=(conf.ft_batchsize, x.type.shape[1])),
dtype=conf.dtype,
dtype=conf.dtype2,
)
yval = numpy.arange(conf.ft_batchsize)
for i in xrange(n_iter):
......
......@@ -56,6 +56,8 @@ def constant(x):
class Scalar(Type):
def __init__(self, dtype):
    """Create a Scalar type for the given numpy-style dtype string.

    'floatX' is a symbolic alias resolved here, at construction time,
    to the concrete dtype configured under the 'scalar.floatX' option.
    """
    if dtype == 'floatX':
        # BUG FIX: the configuration key is 'scalar.floatX' — as used by
        # the default_ dict, the import-time validation, and
        # TensorType.__init__ — not bare 'floatX', which is not a
        # registered option.
        dtype = config.config.get('scalar.floatX')
    self.dtype = dtype
    self.dtype_specs()  # error checking
......@@ -252,6 +254,7 @@ uint32 = Scalar('uint32')
uint64 = Scalar('uint64')
float32 = Scalar('float32')
float64 = Scalar('float64')
floatX = Scalar(config.config.get('scalar.floatX'))
complex64 = Scalar('complex64')
complex128 = Scalar('complex128')
......@@ -934,6 +937,7 @@ convert_to_uint32 = Cast(uint32, name='convert_to_uint32')
convert_to_uint64 = Cast(uint64, name='convert_to_uint64')
convert_to_float32 = Cast(float32, name='convert_to_float32')
convert_to_float64 = Cast(float64, name='convert_to_float64')
convert_to_floatX = Cast(floatX, name='convert_to_floatX')
convert_to_complex64 = Cast(complex64, name='convert_to_complex64')
convert_to_complex128 = Cast(complex128, name='convert_to_complex128')
......@@ -948,6 +952,7 @@ _cast_mapping = {
'uint64': convert_to_uint64,
'float32': convert_to_float32,
'float64': convert_to_float64,
'floatX': convert_to_floatX,
'complex64': convert_to_complex64,
'complex128': convert_to_complex128}
def cast(x, dtype):
......
......@@ -282,6 +282,8 @@ class TensorType(Type):
Optional name for this type.
"""
self.dtype = str(dtype)
if self.dtype=='floatX':
self.dtype=config.config.get('scalar.floatX')
self.broadcastable = tuple(broadcastable)
self.dtype_specs() # error checking is done there
self.name = name
......@@ -608,6 +610,7 @@ cscalar = TensorType('complex64', ())
zscalar = TensorType('complex128', ())
fscalar = TensorType('float32', ())
dscalar = TensorType('float64', ())
xscalar = TensorType('floatX',())
bscalar = TensorType('int8', ())
wscalar = TensorType('int16', ())
iscalar = TensorType('int32', ())
......@@ -628,6 +631,7 @@ cvector = TensorType('complex64', (False, ))
zvector = TensorType('complex128', (False, ))
fvector = TensorType('float32', (False, ))
dvector = TensorType('float64', (False, ))
xvector = TensorType('floatX', (False, ))
bvector = TensorType('int8', (False,))
wvector = TensorType('int16', (False,))
ivector = TensorType('int32', (False, ))
......@@ -645,6 +649,7 @@ cmatrix = TensorType('complex64', (False, False))
zmatrix = TensorType('complex128', (False, False))
fmatrix = TensorType('float32', (False, False))
dmatrix = TensorType('float64', (False, False))
xmatrix = TensorType('floatX', (False, False))
bmatrix = TensorType('int8', (False, False))
wmatrix = TensorType('int16', (False, False))
imatrix = TensorType('int32', (False, False))
......@@ -662,6 +667,7 @@ crow = TensorType('complex64', (True, False))
zrow = TensorType('complex128', (True, False))
frow = TensorType('float32', (True, False))
drow = TensorType('float64', (True, False))
xrow = TensorType('floatX', (True, False))
brow = TensorType('int8', (True, False))
wrow = TensorType('int16', (True, False))
irow = TensorType('int32', (True, False))
......@@ -675,6 +681,7 @@ ccol = TensorType('complex64', (False, True))
zcol = TensorType('complex128', (False, True))
fcol = TensorType('float32', (False, True))
dcol = TensorType('float64', (False, True))
xcol = TensorType('floatX', (False, True))
bcol = TensorType('int8', (False, True))
wcol = TensorType('int16', (False, True))
icol = TensorType('int32', (False, True))
......@@ -688,6 +695,7 @@ ctensor3 = TensorType('complex64', (False,)*3)
ztensor3 = TensorType('complex128', (False,)*3)
ftensor3 = TensorType('float32', (False,)*3)
dtensor3 = TensorType('float64', (False,)*3)
xtensor3 = TensorType('floatX', (False,)*3)
btensor3 = TensorType('int8', (False,)*3)
wtensor3 = TensorType('int16', (False,)*3)
itensor3 = TensorType('int32', (False,)*3)
......@@ -697,6 +705,7 @@ ctensor4 = TensorType('complex64', (False,)*4)
ztensor4 = TensorType('complex128', (False,)*4)
ftensor4 = TensorType('float32', (False,)*4)
dtensor4 = TensorType('float64', (False,)*4)
xtensor4 = TensorType('floatX', (False,)*4)
btensor4 = TensorType('int8', (False,)*4)
wtensor4 = TensorType('int16', (False,)*4)
itensor4 = TensorType('int32', (False,)*4)
......@@ -1093,6 +1102,9 @@ _convert_to_float32 = _conversion(elemwise.Elemwise(scal.convert_to_float32), 'f
_convert_to_float64 = _conversion(elemwise.Elemwise(scal.convert_to_float64), 'float64')
"""Cast to double-precision floating point"""
_convert_to_floatX = _conversion(elemwise.Elemwise(scal.convert_to_floatX), 'floatX')
"""Cast to floatX floating point"""
_convert_to_complex64 = _conversion(elemwise.Elemwise(scal.convert_to_complex64), 'complex64')
"""Cast to single-precision complex"""
......@@ -1110,6 +1122,7 @@ _cast_mapping = {
'uint64': _convert_to_uint64,
'float32': _convert_to_float32,
'float64': _convert_to_float64,
'floatX': _convert_to_floatX,
'complex64': _convert_to_complex64,
'complex128': _convert_to_complex128}
@constructor
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论