提交 b20cca75 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix a bunch of casting policy problems.

上级 0fea0093
...@@ -79,8 +79,7 @@ def upcast(dtype, *dtypes): ...@@ -79,8 +79,7 @@ def upcast(dtype, *dtypes):
elif rval == 'float32': elif rval == 'float32':
if keep_float16[0]: if keep_float16[0]:
return 'float16' return 'float16'
else: return rval
return rval
def get_scalar_type(dtype): def get_scalar_type(dtype):
......
...@@ -252,9 +252,9 @@ class NumpyAutocaster(object): ...@@ -252,9 +252,9 @@ class NumpyAutocaster(object):
return numpy.asarray(x) return numpy.asarray(x)
elif config.cast_policy == 'numpy+floatX': elif config.cast_policy == 'numpy+floatX':
rval = numpy.asarray(x) rval = numpy.asarray(x)
if ((rval.dtype.startswith('float') and if ((not hasattr(x, 'dtype') and
rval.dtype != config.floatX and rval.dtype in ('float64', 'float32') and
not hasattr(x, 'dtype'))): rval.dtype != config.floatX)):
rval = theano._asarray(rval, dtype=config.floatX) rval = theano._asarray(rval, dtype=config.floatX)
return rval return rval
...@@ -277,7 +277,8 @@ class NumpyAutocaster(object): ...@@ -277,7 +277,8 @@ class NumpyAutocaster(object):
# unsafe downcast of float64 variables when config.floatX == 'float32' # unsafe downcast of float64 variables when config.floatX == 'float32'
# recall: float is numpy.float # recall: float is numpy.float
if ((isinstance(x, float) and if ((isinstance(x, float) and
config.floatX in self.dtypes)): config.floatX in self.dtypes and
config.floatX != 'float64')):
return theano._asarray(x, dtype=config.floatX) return theano._asarray(x, dtype=config.floatX)
for dtype in self.dtypes: for dtype in self.dtypes:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论