提交 27654022 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add stuff so that scalar tests pass with the new bool type.

上级 146ef971
...@@ -477,7 +477,7 @@ def grad(cost, wrt, consider_constant=None, ...@@ -477,7 +477,7 @@ def grad(cost, wrt, consider_constant=None,
# function, sure, but nonetheless one we can and should support. # function, sure, but nonetheless one we can and should support.
# So before we try to cast it make sure it even has a dtype # So before we try to cast it make sure it even has a dtype
if (hasattr(g_cost.type, 'dtype') and if (hasattr(g_cost.type, 'dtype') and
cost.type.dtype not in tensor.discrete_dtypes): cost.type.dtype in tensor.continuous_dtypes):
# Here we enforce the constraint that floating point variables # Here we enforce the constraint that floating point variables
# have the same dtype as their gradient. # have the same dtype as their gradient.
g_cost = g_cost.astype(cost.type.dtype) g_cost = g_cost.astype(cost.type.dtype)
...@@ -485,7 +485,7 @@ def grad(cost, wrt, consider_constant=None, ...@@ -485,7 +485,7 @@ def grad(cost, wrt, consider_constant=None,
# This is to be enforced by the Op.grad method for the # This is to be enforced by the Op.grad method for the
# Op that outputs cost. # Op that outputs cost.
if hasattr(g_cost.type, 'dtype'): if hasattr(g_cost.type, 'dtype'):
assert g_cost.type.dtype not in tensor.discrete_dtypes assert g_cost.type.dtype in tensor.continuous_dtypes
grad_dict[cost] = g_cost grad_dict[cost] = g_cost
...@@ -1334,12 +1334,11 @@ def _float_ones_like(x): ...@@ -1334,12 +1334,11 @@ def _float_ones_like(x):
""" Like ones_like, but forces the object to have a """ Like ones_like, but forces the object to have a
floating point dtype """ floating point dtype """
rval = tensor.ones_like(x) dtype = x.type.dtype
if 'float' not in dtype:
dtype = theano.config.floatX
if rval.type.dtype.find('float') != -1: return tensor.ones_like(x, dtype=dtype)
return rval
return rval.astype(theano.config.floatX)
class numeric_grad(object): class numeric_grad(object):
......
差异被折叠。
...@@ -1246,6 +1246,10 @@ def _conversion(real_value, name): ...@@ -1246,6 +1246,10 @@ def _conversion(real_value, name):
# what types you are casting to what. That logic is implemented by the # what types you are casting to what. That logic is implemented by the
# `cast()` function below. # `cast()` function below.
_convert_to_bool = _conversion(
elemwise.Elemwise(scal.convert_to_bool), 'bool')
"""Cast to boolean"""
_convert_to_int8 = _conversion( _convert_to_int8 = _conversion(
elemwise.Elemwise(scal.convert_to_int8), 'int8') elemwise.Elemwise(scal.convert_to_int8), 'int8')
"""Cast to 8-bit integer""" """Cast to 8-bit integer"""
...@@ -1299,6 +1303,7 @@ _convert_to_complex128 = _conversion( ...@@ -1299,6 +1303,7 @@ _convert_to_complex128 = _conversion(
"""Cast to double-precision complex""" """Cast to double-precision complex"""
_cast_mapping = { _cast_mapping = {
'bool': _convert_to_bool,
'int8': _convert_to_int8, 'int8': _convert_to_int8,
'int16': _convert_to_int16, 'int16': _convert_to_int16,
'int32': _convert_to_int32, 'int32': _convert_to_int32,
......
...@@ -255,6 +255,7 @@ class TensorType(Type): ...@@ -255,6 +255,7 @@ class TensorType(Type):
'float16': (float, 'npy_float16', 'NPY_FLOAT16'), 'float16': (float, 'npy_float16', 'NPY_FLOAT16'),
'float32': (float, 'npy_float32', 'NPY_FLOAT32'), 'float32': (float, 'npy_float32', 'NPY_FLOAT32'),
'float64': (float, 'npy_float64', 'NPY_FLOAT64'), 'float64': (float, 'npy_float64', 'NPY_FLOAT64'),
'bool': (bool, 'npy_bool', 'NPY_BOOL'),
'uint8': (int, 'npy_uint8', 'NPY_UINT8'), 'uint8': (int, 'npy_uint8', 'NPY_UINT8'),
'int8': (int, 'npy_int8', 'NPY_INT8'), 'int8': (int, 'npy_int8', 'NPY_INT8'),
'uint16': (int, 'npy_uint16', 'NPY_UINT16'), 'uint16': (int, 'npy_uint16', 'NPY_UINT16'),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论