提交 9a8eda51 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Update casting under numpy+floatX since that is tested.

上级 ff085441
...@@ -56,18 +56,29 @@ def upcast(dtype, *dtypes): ...@@ -56,18 +56,29 @@ def upcast(dtype, *dtypes):
# modified within `make_array`. # modified within `make_array`.
keep_float32 = [(config.cast_policy == 'numpy+floatX' and keep_float32 = [(config.cast_policy == 'numpy+floatX' and
config.floatX == 'float32')] config.floatX == 'float32')]
keep_float16 = [(config.cast_policy == 'numpy+floatX' and
config.floatX == 'float16')]
def make_array(dt): def make_array(dt):
if dt == 'float64': if dt == 'float64':
# There is an explicit float64 dtype: we cannot keep float32. # There is an explicit float64 dtype: we cannot keep float32.
keep_float32[0] = False keep_float32[0] = False
keep_float16[0] = False
if dt == 'float32':
keep_float16[0] = False
return numpy.zeros((), dtype=dt) return numpy.zeros((), dtype=dt)
z = make_array(dtype) z = make_array(dtype)
for dt in dtypes: for dt in dtypes:
z = z + make_array(dt=dt) z = z + make_array(dt=dt)
rval = str(z.dtype) rval = str(z.dtype)
if rval == 'float64' and keep_float32[0]: if rval == 'float64':
if keep_float16[0]:
return 'float16'
if keep_float32[0]:
return 'float32' return 'float32'
elif rval == 'float32':
if keep_float16[0]:
return 'float16'
else: else:
return rval return rval
......
...@@ -252,10 +252,10 @@ class NumpyAutocaster(object): ...@@ -252,10 +252,10 @@ class NumpyAutocaster(object):
return numpy.asarray(x) return numpy.asarray(x)
elif config.cast_policy == 'numpy+floatX': elif config.cast_policy == 'numpy+floatX':
rval = numpy.asarray(x) rval = numpy.asarray(x)
if ((rval.dtype == 'float64' and # numpy wants float64 if ((rval.dtype.startswith('float') and
config.floatX == 'float32' and # but we prefer float32 rval.dtype != config.floatX and
not hasattr(x, 'dtype'))): # and `x` was not typed not hasattr(x, 'dtype'))):
rval = theano._asarray(rval, dtype='float32') rval = theano._asarray(rval, dtype=config.floatX)
return rval return rval
# The following is the original code, corresponding to the 'custom' # The following is the original code, corresponding to the 'custom'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论