提交 27654022 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add stuff so that scalar tests pass with the new bool type.

上级 146ef971
......@@ -477,7 +477,7 @@ def grad(cost, wrt, consider_constant=None,
# function, sure, but nonetheless one we can and should support.
# So before we try to cast it make sure it even has a dtype
if (hasattr(g_cost.type, 'dtype') and
cost.type.dtype not in tensor.discrete_dtypes):
cost.type.dtype in tensor.continuous_dtypes):
# Here we enforce the constraint that floating point variables
# have the same dtype as their gradient.
g_cost = g_cost.astype(cost.type.dtype)
......@@ -485,7 +485,7 @@ def grad(cost, wrt, consider_constant=None,
# This is to be enforced by the Op.grad method for the
# Op that outputs cost.
if hasattr(g_cost.type, 'dtype'):
assert g_cost.type.dtype not in tensor.discrete_dtypes
assert g_cost.type.dtype in tensor.continuous_dtypes
grad_dict[cost] = g_cost
......@@ -1334,12 +1334,11 @@ def _float_ones_like(x):
""" Like ones_like, but forces the object to have a
floating point dtype """
rval = tensor.ones_like(x)
dtype = x.type.dtype
if 'float' not in dtype:
dtype = theano.config.floatX
if rval.type.dtype.find('float') != -1:
return rval
return rval.astype(theano.config.floatX)
return tensor.ones_like(x, dtype=dtype)
class numeric_grad(object):
......
差异被折叠。
......@@ -1246,6 +1246,10 @@ def _conversion(real_value, name):
# what types you are casting to what. That logic is implemented by the
# `cast()` function below.
_convert_to_bool = _conversion(
elemwise.Elemwise(scal.convert_to_bool), 'bool')
"""Cast to boolean"""
_convert_to_int8 = _conversion(
elemwise.Elemwise(scal.convert_to_int8), 'int8')
"""Cast to 8-bit integer"""
......@@ -1299,6 +1303,7 @@ _convert_to_complex128 = _conversion(
"""Cast to double-precision complex"""
_cast_mapping = {
'bool': _convert_to_bool,
'int8': _convert_to_int8,
'int16': _convert_to_int16,
'int32': _convert_to_int32,
......
......@@ -255,6 +255,7 @@ class TensorType(Type):
'float16': (float, 'npy_float16', 'NPY_FLOAT16'),
'float32': (float, 'npy_float32', 'NPY_FLOAT32'),
'float64': (float, 'npy_float64', 'NPY_FLOAT64'),
'bool': (bool, 'npy_bool', 'NPY_BOOL'),
'uint8': (int, 'npy_uint8', 'NPY_UINT8'),
'int8': (int, 'npy_int8', 'NPY_INT8'),
'uint16': (int, 'npy_uint16', 'NPY_UINT16'),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论