提交 27654022 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add stuff so that scalar tests pass with the new bool type.

上级 146ef971
...@@ -477,7 +477,7 @@ def grad(cost, wrt, consider_constant=None, ...@@ -477,7 +477,7 @@ def grad(cost, wrt, consider_constant=None,
# function, sure, but nonetheless one we can and should support. # function, sure, but nonetheless one we can and should support.
# So before we try to cast it make sure it even has a dtype # So before we try to cast it make sure it even has a dtype
if (hasattr(g_cost.type, 'dtype') and if (hasattr(g_cost.type, 'dtype') and
cost.type.dtype not in tensor.discrete_dtypes): cost.type.dtype in tensor.continuous_dtypes):
# Here we enforce the constraint that floating point variables # Here we enforce the constraint that floating point variables
# have the same dtype as their gradient. # have the same dtype as their gradient.
g_cost = g_cost.astype(cost.type.dtype) g_cost = g_cost.astype(cost.type.dtype)
...@@ -485,7 +485,7 @@ def grad(cost, wrt, consider_constant=None, ...@@ -485,7 +485,7 @@ def grad(cost, wrt, consider_constant=None,
# This is to be enforced by the Op.grad method for the # This is to be enforced by the Op.grad method for the
# Op that outputs cost. # Op that outputs cost.
if hasattr(g_cost.type, 'dtype'): if hasattr(g_cost.type, 'dtype'):
assert g_cost.type.dtype not in tensor.discrete_dtypes assert g_cost.type.dtype in tensor.continuous_dtypes
grad_dict[cost] = g_cost grad_dict[cost] = g_cost
...@@ -1334,12 +1334,11 @@ def _float_ones_like(x): ...@@ -1334,12 +1334,11 @@ def _float_ones_like(x):
""" Like ones_like, but forces the object to have a """ Like ones_like, but forces the object to have a
floating point dtype """ floating point dtype """
rval = tensor.ones_like(x) dtype = x.type.dtype
if 'float' not in dtype:
dtype = theano.config.floatX
if rval.type.dtype.find('float') != -1: return tensor.ones_like(x, dtype=dtype)
return rval
return rval.astype(theano.config.floatX)
class numeric_grad(object): class numeric_grad(object):
......
...@@ -34,6 +34,7 @@ from theano.gradient import grad_undefined ...@@ -34,6 +34,7 @@ from theano.gradient import grad_undefined
from theano.printing import pprint from theano.printing import pprint
import collections import collections
builtin_bool = bool
builtin_complex = complex builtin_complex = complex
builtin_int = int builtin_int = int
builtin_float = float builtin_float = float
...@@ -161,7 +162,7 @@ class Scalar(Type): ...@@ -161,7 +162,7 @@ class Scalar(Type):
TODO: refactor to be named ScalarType for consistency with TensorType. TODO: refactor to be named ScalarType for consistency with TensorType.
""" """
__props__ = ('dtype',)
ndim = 0 ndim = 0
def __init__(self, dtype): def __init__(self, dtype):
...@@ -200,6 +201,8 @@ class Scalar(Type): ...@@ -200,6 +201,8 @@ class Scalar(Type):
def values_eq_approx(self, a, b, tolerance=1e-4): def values_eq_approx(self, a, b, tolerance=1e-4):
# The addition have risk of overflow especially with [u]int8 # The addition have risk of overflow especially with [u]int8
if self.dtype == 'bool':
return a == b
diff = a - b diff = a - b
if diff == 0: if diff == 0:
return True return True
...@@ -227,12 +230,6 @@ class Scalar(Type): ...@@ -227,12 +230,6 @@ class Scalar(Type):
else: else:
return [] return []
def __eq__(self, other):
return type(self) == type(other) and other.dtype == self.dtype
def __hash__(self):
return hash('theano.scalar.Scalar') ^ hash(self.dtype)
def dtype_specs(self): def dtype_specs(self):
try: try:
# To help debug dtype/typenum problem, here is code to get # To help debug dtype/typenum problem, here is code to get
...@@ -244,7 +241,8 @@ class Scalar(Type): ...@@ -244,7 +241,8 @@ class Scalar(Type):
# now, as Theano always expect the exact typenum that # now, as Theano always expect the exact typenum that
# correspond to our supported dtype. # correspond to our supported dtype.
""" """
for dtype in ['int8', 'uint8', 'short', 'ushort', 'intc', 'uintc', for dtype in ['bool', 'int8', 'uint8', 'short', 'ushort', 'intc',
'uintc',
'longlong', 'ulonglong', 'single', 'double', 'longlong', 'ulonglong', 'single', 'double',
'longdouble', 'csingle', 'cdouble', 'clongdouble', 'longdouble', 'csingle', 'cdouble', 'clongdouble',
'float32', 'float64', 'int8', 'int16', 'int32', 'float32', 'float64', 'int8', 'int16', 'int32',
...@@ -260,6 +258,7 @@ class Scalar(Type): ...@@ -260,6 +258,7 @@ class Scalar(Type):
'complex128': (numpy.complex128, 'theano_complex128', 'complex128': (numpy.complex128, 'theano_complex128',
'Complex128'), 'Complex128'),
'complex64': (numpy.complex64, 'theano_complex64', 'Complex64'), 'complex64': (numpy.complex64, 'theano_complex64', 'Complex64'),
'bool': (numpy.bool_, 'npy_bool', 'Bool'),
'uint8': (numpy.uint8, 'npy_uint8', 'UInt8'), 'uint8': (numpy.uint8, 'npy_uint8', 'UInt8'),
'int8': (numpy.int8, 'npy_int8', 'Int8'), 'int8': (numpy.int8, 'npy_int8', 'Int8'),
'uint16': (numpy.uint16, 'npy_uint16', 'UInt16'), 'uint16': (numpy.uint16, 'npy_uint16', 'UInt16'),
...@@ -288,12 +287,13 @@ class Scalar(Type): ...@@ -288,12 +287,13 @@ class Scalar(Type):
def c_literal(self, data): def c_literal(self, data):
if 'complex' in self.dtype: if 'complex' in self.dtype:
raise NotImplementedError("No literal for complex values.") raise NotImplementedError("No literal for complex values.")
if self.dtype == 'bool':
return '1' if b else '0'
return str(data) return str(data)
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
if(check_input): if(check_input):
pre = """ pre = """
typedef %(dtype)s %(name)s_dtype; // Deprecated use dtype_%(name)s instead.
typedef %(dtype)s dtype_%(name)s; typedef %(dtype)s dtype_%(name)s;
""" % dict(name=name, dtype=self.dtype_specs()[1]) """ % dict(name=name, dtype=self.dtype_specs()[1])
else: else:
...@@ -309,6 +309,7 @@ class Scalar(Type): ...@@ -309,6 +309,7 @@ class Scalar(Type):
def c_extract(self, name, sub, check_input=True): def c_extract(self, name, sub, check_input=True):
if self.dtype == 'float16': if self.dtype == 'float16':
# This doesn't work at the numpy level
raise NotImplementedError('float16') raise NotImplementedError('float16')
specs = self.dtype_specs() specs = self.dtype_specs()
if(check_input): if(check_input):
...@@ -517,6 +518,7 @@ theano.compile.register_view_op_c_code( ...@@ -517,6 +518,7 @@ theano.compile.register_view_op_c_code(
1) 1)
bool = get_scalar_type('bool')
int8 = get_scalar_type('int8') int8 = get_scalar_type('int8')
int16 = get_scalar_type('int16') int16 = get_scalar_type('int16')
int32 = get_scalar_type('int32') int32 = get_scalar_type('int32')
...@@ -538,7 +540,7 @@ complex_types = complex64, complex128 ...@@ -538,7 +540,7 @@ complex_types = complex64, complex128
discrete_types = int_types + uint_types discrete_types = int_types + uint_types
continuous_types = float_types + complex_types continuous_types = float_types + complex_types
all_types = discrete_types + continuous_types all_types = (bool,) + discrete_types + continuous_types
class _scalar_py_operators: class _scalar_py_operators:
...@@ -681,38 +683,35 @@ complexs64 = _multi(complex64) ...@@ -681,38 +683,35 @@ complexs64 = _multi(complex64)
complexs128 = _multi(complex128) complexs128 = _multi(complex128)
# Using a class instead of a function makes it possible to deep-copy it in # Using a class instead of a function makes it possible to deep-copy it.
# Python 2.4. # Note that currently only a few functions use this mechanism, because
# Note that currently only a few functions use this mechanism, because it is # it is enough to make the test-suite pass. However, it may prove
# enough to make the test-suite pass with Python 2.4. However, it may prove # necessary to use this same mechanism in other places as well in the
# necessary to use this same mechanism in other places as well in the future. # future.
class upcast_out(object): def upcast_out(*types):
def __new__(self, *types): dtype = Scalar.upcast(*types)
dtype = Scalar.upcast(*types) return get_scalar_type(dtype),
return get_scalar_type(dtype),
class upgrade_to_float(object):
def __new__(self, *types):
"""
Upgrade any int types to float32 or float64 to avoid losing precision.
"""
conv = {int8: float32,
int16: float32,
int32: float64,
int64: float64,
uint8: float32,
uint16: float32,
uint32: float64,
uint64: float64}
return get_scalar_type(Scalar.upcast(*[conv.get(type, type)
for type in types])),
def upgrade_to_float(*types):
"""
Upgrade any int types to float32 or float64 to avoid losing precision.
class same_out(object): """
def __new__(self, type): conv = {int8: float32,
return type, int16: float32,
int32: float64,
int64: float64,
uint8: float32,
uint16: float32,
uint32: float64,
uint64: float64}
return get_scalar_type(Scalar.upcast(*[conv.get(type, type)
for type in types])),
def same_out(type):
return type,
def upcast_out_no_complex(*types): def upcast_out_no_complex(*types):
...@@ -728,6 +727,8 @@ def same_out_float_only(type): ...@@ -728,6 +727,8 @@ def same_out_float_only(type):
class transfer_type(gof.utils.object2): class transfer_type(gof.utils.object2):
__props__ = ('transfer',)
def __init__(self, *transfer): def __init__(self, *transfer):
assert all(type(x) in [int, str] or x is None for x in transfer) assert all(type(x) in [int, str] or x is None for x in transfer)
self.transfer = transfer self.transfer = transfer
...@@ -748,26 +749,16 @@ class transfer_type(gof.utils.object2): ...@@ -748,26 +749,16 @@ class transfer_type(gof.utils.object2):
return retval return retval
# return [upcast if i is None else types[i] for i in self.transfer] # return [upcast if i is None else types[i] for i in self.transfer]
def __eq__(self, other):
return type(self) == type(other) and self.transfer == other.transfer
def __hash__(self):
return hash(self.transfer)
class specific_out(gof.utils.object2): class specific_out(gof.utils.object2):
__props__ = ('spec',)
def __init__(self, *spec): def __init__(self, *spec):
self.spec = spec self.spec = spec
def __call__(self, *types): def __call__(self, *types):
return self.spec return self.spec
def __eq__(self, other):
return type(self) == type(other) and self.spec == other.spec
def __hash__(self):
return hash(self.spec)
def int_out(*types): def int_out(*types):
return int64, return int64,
...@@ -1007,12 +998,12 @@ class BinaryScalarOp(ScalarOp): ...@@ -1007,12 +998,12 @@ class BinaryScalarOp(ScalarOp):
class LogicalComparison(BinaryScalarOp): class LogicalComparison(BinaryScalarOp):
def output_types(self, *input_dtypes): def output_types(self, *input_dtypes):
return [int8] return [bool]
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
x, y = inputs x, y = inputs
out = self(x, y) out = self(x, y)
assert str(out.type.dtype).find('int') != -1 assert out.type == bool
return [x.zeros_like().astype(theano.config.floatX), return [x.zeros_like().astype(theano.config.floatX),
y.zeros_like().astype(theano.config.floatX)] y.zeros_like().astype(theano.config.floatX)]
...@@ -1023,12 +1014,12 @@ class FixedLogicalComparison(UnaryScalarOp): ...@@ -1023,12 +1014,12 @@ class FixedLogicalComparison(UnaryScalarOp):
""" """
def output_types(self, *input_dtypes): def output_types(self, *input_dtypes):
return [int8] return [bool]
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
x, = inputs x, = inputs
out = self(x) out = self(x)
assert str(out.type.dtype).find('int') != -1 assert out.type == bool
return [x.zeros_like().astype(theano.config.floatX)] return [x.zeros_like().astype(theano.config.floatX)]
...@@ -1202,21 +1193,10 @@ class InRange(LogicalComparison): ...@@ -1202,21 +1193,10 @@ class InRange(LogicalComparison):
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x, low, hi) = inputs (x, low, hi) = inputs
(z,) = outputs (z,) = outputs
if self.openlow:
cmp1 = '>'
else:
cmp1 = '>='
# backport cmp1 = '>' if self.openlow else '>='
# cmp1 = '>' if self.openlow else '>=' cmp2 = '<' if self.openhi else '<='
if self.openhi:
cmp2 = '<'
else:
cmp2 = '<='
# backport
# cmp2 = '<' if self.openhi else '<='
return ("%(z)s = %(x)s %(cmp1)s %(low)s &&" return ("%(z)s = %(x)s %(cmp1)s %(low)s &&"
" %(x)s %(cmp2)s %(hi)s;" % locals()) " %(x)s %(cmp2)s %(hi)s;" % locals())
...@@ -1247,13 +1227,8 @@ class Switch(ScalarOp): ...@@ -1247,13 +1227,8 @@ class Switch(ScalarOp):
nfunc_spec = ('where', 3, 1) nfunc_spec = ('where', 3, 1)
def impl(self, cond, ift, iff): def impl(self, cond, ift, iff):
if cond: return ift if cond else iff
return ift
else:
return iff
# backport
# return ift if cond else iff
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(cond, ift, iff) = inputs (cond, ift, iff) = inputs
(z,) = outputs (z,) = outputs
...@@ -1290,9 +1265,9 @@ switch = Switch() ...@@ -1290,9 +1265,9 @@ switch = Switch()
class UnaryBitOp(UnaryScalarOp): class UnaryBitOp(UnaryScalarOp):
def output_types(self, *input_types): def output_types(self, *input_types):
for i in input_types[0]: for i in input_types[0]:
if i not in (int8, int16, int32, int64): if i not in ((bool,) + discrete_types):
raise TypeError('input to a BitOp must have type int8,' raise TypeError('input to a BitOp must have type (u)int8, '
' int16, int32 or int64... not %s' % i) '(u)int16, (u)int32 or (u)int64 or bool not %s' % i)
return upcast_out(*input_types[0]) return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
...@@ -1302,10 +1277,13 @@ class UnaryBitOp(UnaryScalarOp): ...@@ -1302,10 +1277,13 @@ class UnaryBitOp(UnaryScalarOp):
class BinaryBitOp(BinaryScalarOp): class BinaryBitOp(BinaryScalarOp):
def output_types(self, *input_types): def output_types(self, *input_types):
t0, t1 = input_types[0] t0, t1 = input_types[0]
if t0 == bool and t1 == bool:
return [bool]
for i in input_types[0]: for i in input_types[0]:
if i not in (int8, int16, int32, int64): if i not in discrete_types:
raise TypeError('input to a BitOp must have type int8,' raise TypeError('input to a BitOp must have type (u)int8, '
' int16, int32 or int64... not %s' % i) '(u)int16, (u)int32 or (u)int64 or '
'be all bools not %s' % i)
return upcast_out(*input_types[0]) return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
...@@ -1371,6 +1349,8 @@ class Invert(UnaryBitOp): ...@@ -1371,6 +1349,8 @@ class Invert(UnaryBitOp):
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs (x,) = inputs
(z,) = outputs (z,) = outputs
if node.outputs[0].type == bool:
return "%(z)s = (!%(x)s);" % locals()
return "%(z)s = (~%(x)s);" % locals() return "%(z)s = (~%(x)s);" % locals()
invert = Invert() invert = Invert()
...@@ -2079,6 +2059,7 @@ class Cast(UnaryScalarOp): ...@@ -2079,6 +2059,7 @@ class Cast(UnaryScalarOp):
else: else:
return s return s
convert_to_bool = Cast(bool, name='convert_to_bool')
convert_to_int8 = Cast(int8, name='convert_to_int8') convert_to_int8 = Cast(int8, name='convert_to_int8')
convert_to_int16 = Cast(int16, name='convert_to_int16') convert_to_int16 = Cast(int16, name='convert_to_int16')
convert_to_int32 = Cast(int32, name='convert_to_int32') convert_to_int32 = Cast(int32, name='convert_to_int32')
...@@ -2094,6 +2075,7 @@ convert_to_complex64 = Cast(complex64, name='convert_to_complex64') ...@@ -2094,6 +2075,7 @@ convert_to_complex64 = Cast(complex64, name='convert_to_complex64')
convert_to_complex128 = Cast(complex128, name='convert_to_complex128') convert_to_complex128 = Cast(complex128, name='convert_to_complex128')
_cast_mapping = { _cast_mapping = {
'bool': convert_to_bool,
'int8': convert_to_int8, 'int8': convert_to_int8,
'int16': convert_to_int16, 'int16': convert_to_int16,
'int32': convert_to_int32, 'int32': convert_to_int32,
......
...@@ -1246,6 +1246,10 @@ def _conversion(real_value, name): ...@@ -1246,6 +1246,10 @@ def _conversion(real_value, name):
# what types you are casting to what. That logic is implemented by the # what types you are casting to what. That logic is implemented by the
# `cast()` function below. # `cast()` function below.
_convert_to_bool = _conversion(
elemwise.Elemwise(scal.convert_to_bool), 'bool')
"""Cast to boolean"""
_convert_to_int8 = _conversion( _convert_to_int8 = _conversion(
elemwise.Elemwise(scal.convert_to_int8), 'int8') elemwise.Elemwise(scal.convert_to_int8), 'int8')
"""Cast to 8-bit integer""" """Cast to 8-bit integer"""
...@@ -1299,6 +1303,7 @@ _convert_to_complex128 = _conversion( ...@@ -1299,6 +1303,7 @@ _convert_to_complex128 = _conversion(
"""Cast to double-precision complex""" """Cast to double-precision complex"""
_cast_mapping = { _cast_mapping = {
'bool': _convert_to_bool,
'int8': _convert_to_int8, 'int8': _convert_to_int8,
'int16': _convert_to_int16, 'int16': _convert_to_int16,
'int32': _convert_to_int32, 'int32': _convert_to_int32,
......
...@@ -255,6 +255,7 @@ class TensorType(Type): ...@@ -255,6 +255,7 @@ class TensorType(Type):
'float16': (float, 'npy_float16', 'NPY_FLOAT16'), 'float16': (float, 'npy_float16', 'NPY_FLOAT16'),
'float32': (float, 'npy_float32', 'NPY_FLOAT32'), 'float32': (float, 'npy_float32', 'NPY_FLOAT32'),
'float64': (float, 'npy_float64', 'NPY_FLOAT64'), 'float64': (float, 'npy_float64', 'NPY_FLOAT64'),
'bool': (bool, 'npy_bool', 'NPY_BOOL'),
'uint8': (int, 'npy_uint8', 'NPY_UINT8'), 'uint8': (int, 'npy_uint8', 'NPY_UINT8'),
'int8': (int, 'npy_int8', 'NPY_INT8'), 'int8': (int, 'npy_int8', 'NPY_INT8'),
'uint16': (int, 'npy_uint16', 'NPY_UINT16'), 'uint16': (int, 'npy_uint16', 'NPY_UINT16'),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论