提交 6d93c2fc authored 作者: James Bergstra's avatar James Bergstra

removed floatX module

上级 7b50faa0
"""Provide xscalar, xvector, xmatrix, etc. pseudo-types
"""
if 0:
from .configparser import config
from theano.scalar import float64, float32
from theano.tensor import (fscalar, fvector, fmatrix, frow, fcol, ftensor3, ftensor4, dscalar,
dvector, dmatrix, drow, dcol, dtensor3, dtensor4)
#
# !!! set_floatX adds symbols directly to the module's symbol table !!!
#
def set_floatX(dtype = config.floatX):
""" add the xmatrix, xvector, xscalar etc. aliases to theano.tensor
"""
config.floatX = dtype
if dtype == 'float32': prefix = 'f'
elif dtype == 'float64' : prefix = 'd'
else: raise Exception("Bad param in set_floatX(%s). Only float32 and float64 are supported"%config.floatX)
#tensor.scalar stuff
globals()['floatX'] = globals()[dtype]
# convert_to_floatX = Cast(floatX, name='convert_to_floatX')
#tensor.tensor stuff
for symbol in ('scalar', 'vector', 'matrix', 'row', 'col','tensor3','tensor4'):
globals()['x'+symbol] = globals()[prefix+symbol]
#_convert_to_floatX = _conversion(elemwise.Elemwise(scal.convert_to_floatX), 'floatX')
from theano.tensor import *
import theano.config as config
from theano import function
#from theano.floatx import set_floatX, xscalar, xmatrix, xrow, xcol, xvector, xtensor3, xtensor4
import theano.floatX as FX
def test_floatX():
def test():
floatx=config.floatX
#TODO test other fct then ?vector
#float64 cast to float64 should not generate an op
x = dvector('x')
f = function([x],[cast(x,'float64')])
# print f.maker.env.toposort()
assert len(f.maker.env.toposort())==0
#float32 cast to float32 should not generate an op
x = fvector('x')
f = function([x],[cast(x,'float32')])
# print f.maker.env.toposort()
assert len(f.maker.env.toposort())==0
#floatX cast to float64
x = FX.xvector('x')
f = function([x],[cast(x,'float64')])
# print f.maker.env.toposort()
if floatx=='float64':
assert len(f.maker.env.toposort()) == 0
else:
assert len(f.maker.env.toposort()) == 1
#floatX cast to float32
x = FX.xvector('x')
f = function([x],[cast(x,'float32')])
# print f.maker.env.toposort()
if floatx=='float32':
assert len(f.maker.env.toposort()) == 0
else:
assert len(f.maker.env.toposort()) == 1
#float64 cast to floatX
x = dvector('x')
f = function([x],[cast(x,'floatX')])
# print f.maker.env.toposort()
if floatx=='float64':
assert len(f.maker.env.toposort()) == 0
else:
assert len(f.maker.env.toposort()) == 1
#float32 cast to floatX
x = fvector('x')
f = function([x],[cast(x,'floatX')])
# print f.maker.env.toposort()
if floatx=='float32':
assert len(f.maker.env.toposort()) == 0
else:
assert len(f.maker.env.toposort()) == 1
#floatX cast to floatX
x = FX.xvector('x')
f = function([x],[cast(x,'floatX')])
# print f.maker.env.toposort()
assert len(f.maker.env.toposort()) == 0
orig_floatx = config.floatX
try:
print 'float32'
FX.set_floatX('float32')
test()
print 'float64'
FX.set_floatX('float64')
test()
finally:
pass
FX.set_floatX(orig_floatx)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论