提交 fd7655aa authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make float16 a recognized type (no casting as it would break on CPU).

上级 abd0a0fc
...@@ -42,7 +42,7 @@ AddConfigVar('cast_policy', ...@@ -42,7 +42,7 @@ AddConfigVar('cast_policy',
EnumStr('custom', 'numpy+floatX', EnumStr('custom', 'numpy+floatX',
# The 'numpy' policy was originally planned to provide a # The 'numpy' policy was originally planned to provide a
# smooth transition from numpy. It was meant to behave the # smooth transition from numpy. It was meant to behave the
# same asnumpy+floatX, but keeping float64 when numpy # same as numpy+floatX, but keeping float64 when numpy
# would. However the current implementation of some cast # would. However the current implementation of some cast
# mechanisms makes it a bit more complex to add than what # mechanisms makes it a bit more complex to add than what
# was expected, so it is currently not available. # was expected, so it is currently not available.
......
...@@ -232,6 +232,7 @@ class Scalar(Type): ...@@ -232,6 +232,7 @@ class Scalar(Type):
print(dtype, np.zeros(1, dtype=dtype).dtype.num) print(dtype, np.zeros(1, dtype=dtype).dtype.num)
""" """
return { # dtype: (py_type, c_type, cls_name) return { # dtype: (py_type, c_type, cls_name)
'float16': (numpy.float16, 'npy_float16', 'Float16'),
'float32': (numpy.float32, 'npy_float32', 'Float32'), 'float32': (numpy.float32, 'npy_float32', 'Float32'),
'float64': (numpy.float64, 'npy_float64', 'Float64'), 'float64': (numpy.float64, 'npy_float64', 'Float64'),
'complex128': (numpy.complex128, 'theano_complex128', 'complex128': (numpy.complex128, 'theano_complex128',
...@@ -501,6 +502,7 @@ uint8 = get_scalar_type('uint8') ...@@ -501,6 +502,7 @@ uint8 = get_scalar_type('uint8')
uint16 = get_scalar_type('uint16') uint16 = get_scalar_type('uint16')
uint32 = get_scalar_type('uint32') uint32 = get_scalar_type('uint32')
uint64 = get_scalar_type('uint64') uint64 = get_scalar_type('uint64')
float16 = get_scalar_type('float16')
float32 = get_scalar_type('float32') float32 = get_scalar_type('float32')
float64 = get_scalar_type('float64') float64 = get_scalar_type('float64')
complex64 = get_scalar_type('complex64') complex64 = get_scalar_type('complex64')
...@@ -508,7 +510,7 @@ complex128 = get_scalar_type('complex128') ...@@ -508,7 +510,7 @@ complex128 = get_scalar_type('complex128')
int_types = int8, int16, int32, int64 int_types = int8, int16, int32, int64
uint_types = uint8, uint16, uint32, uint64 uint_types = uint8, uint16, uint32, uint64
float_types = float32, float64 float_types = float16, float32, float64
complex_types = complex64, complex128 complex_types = complex64, complex128
discrete_types = int_types + uint_types discrete_types = int_types + uint_types
......
...@@ -234,6 +234,7 @@ class TensorType(Type): ...@@ -234,6 +234,7 @@ class TensorType(Type):
# complex64, etc. # complex64, etc.
try: try:
return { return {
'float16': (float, 'npy_float16', 'NPY_FLOAT16'),
'float32': (float, 'npy_float32', 'NPY_FLOAT32'), 'float32': (float, 'npy_float32', 'NPY_FLOAT32'),
'float64': (float, 'npy_float64', 'NPY_FLOAT64'), 'float64': (float, 'npy_float64', 'NPY_FLOAT64'),
'uint8': (int, 'npy_uint8', 'NPY_UINT8'), 'uint8': (int, 'npy_uint8', 'NPY_UINT8'),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论