提交 9a8eda51 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Update casting under numpy+floatX since that is tested.

上级 ff085441
......@@ -56,18 +56,29 @@ def upcast(dtype, *dtypes):
# modified within `make_array`.
keep_float32 = [(config.cast_policy == 'numpy+floatX' and
config.floatX == 'float32')]
keep_float16 = [(config.cast_policy == 'numpy+floatX' and
config.floatX == 'float16')]
def make_array(dt):
if dt == 'float64':
# There is an explicit float64 dtype: we cannot keep float32.
keep_float32[0] = False
keep_float16[0] = False
if dt == 'float32':
keep_float16[0] = False
return numpy.zeros((), dtype=dt)
z = make_array(dtype)
for dt in dtypes:
z = z + make_array(dt=dt)
rval = str(z.dtype)
if rval == 'float64' and keep_float32[0]:
return 'float32'
if rval == 'float64':
if keep_float16[0]:
return 'float16'
if keep_float32[0]:
return 'float32'
elif rval == 'float32':
if keep_float16[0]:
return 'float16'
else:
return rval
......
......@@ -252,10 +252,10 @@ class NumpyAutocaster(object):
return numpy.asarray(x)
elif config.cast_policy == 'numpy+floatX':
rval = numpy.asarray(x)
if ((rval.dtype == 'float64' and # numpy wants float64
config.floatX == 'float32' and # but we prefer float32
not hasattr(x, 'dtype'))): # and `x` was not typed
rval = theano._asarray(rval, dtype='float32')
if ((rval.dtype.startswith('float') and
rval.dtype != config.floatX and
not hasattr(x, 'dtype'))):
rval = theano._asarray(rval, dtype=config.floatX)
return rval
# The following is the original code, corresponding to the 'custom'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论