提交 27654022 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add stuff so that scalar tests pass with the new bool type.

上级 146ef971
......@@ -477,7 +477,7 @@ def grad(cost, wrt, consider_constant=None,
# function, sure, but nonetheless one we can and should support.
# So before we try to cast it make sure it even has a dtype
if (hasattr(g_cost.type, 'dtype') and
cost.type.dtype not in tensor.discrete_dtypes):
cost.type.dtype in tensor.continuous_dtypes):
# Here we enforce the constraint that floating point variables
# have the same dtype as their gradient.
g_cost = g_cost.astype(cost.type.dtype)
......@@ -485,7 +485,7 @@ def grad(cost, wrt, consider_constant=None,
# This is to be enforced by the Op.grad method for the
# Op that outputs cost.
if hasattr(g_cost.type, 'dtype'):
assert g_cost.type.dtype not in tensor.discrete_dtypes
assert g_cost.type.dtype in tensor.continuous_dtypes
grad_dict[cost] = g_cost
......@@ -1334,12 +1334,11 @@ def _float_ones_like(x):
""" Like ones_like, but forces the object to have a
floating point dtype """
rval = tensor.ones_like(x)
dtype = x.type.dtype
if 'float' not in dtype:
dtype = theano.config.floatX
if rval.type.dtype.find('float') != -1:
return rval
return rval.astype(theano.config.floatX)
return tensor.ones_like(x, dtype=dtype)
class numeric_grad(object):
......
......@@ -34,6 +34,7 @@ from theano.gradient import grad_undefined
from theano.printing import pprint
import collections
builtin_bool = bool
builtin_complex = complex
builtin_int = int
builtin_float = float
......@@ -161,7 +162,7 @@ class Scalar(Type):
TODO: refactor to be named ScalarType for consistency with TensorType.
"""
__props__ = ('dtype',)
ndim = 0
def __init__(self, dtype):
......@@ -200,6 +201,8 @@ class Scalar(Type):
def values_eq_approx(self, a, b, tolerance=1e-4):
# The addition have risk of overflow especially with [u]int8
if self.dtype == 'bool':
return a == b
diff = a - b
if diff == 0:
return True
......@@ -227,12 +230,6 @@ class Scalar(Type):
else:
return []
def __eq__(self, other):
return type(self) == type(other) and other.dtype == self.dtype
def __hash__(self):
return hash('theano.scalar.Scalar') ^ hash(self.dtype)
def dtype_specs(self):
try:
# To help debug dtype/typenum problem, here is code to get
......@@ -244,7 +241,8 @@ class Scalar(Type):
# now, as Theano always expect the exact typenum that
# correspond to our supported dtype.
"""
for dtype in ['int8', 'uint8', 'short', 'ushort', 'intc', 'uintc',
for dtype in ['bool', 'int8', 'uint8', 'short', 'ushort', 'intc',
'uintc',
'longlong', 'ulonglong', 'single', 'double',
'longdouble', 'csingle', 'cdouble', 'clongdouble',
'float32', 'float64', 'int8', 'int16', 'int32',
......@@ -260,6 +258,7 @@ class Scalar(Type):
'complex128': (numpy.complex128, 'theano_complex128',
'Complex128'),
'complex64': (numpy.complex64, 'theano_complex64', 'Complex64'),
'bool': (numpy.bool_, 'npy_bool', 'Bool'),
'uint8': (numpy.uint8, 'npy_uint8', 'UInt8'),
'int8': (numpy.int8, 'npy_int8', 'Int8'),
'uint16': (numpy.uint16, 'npy_uint16', 'UInt16'),
......@@ -288,12 +287,13 @@ class Scalar(Type):
def c_literal(self, data):
if 'complex' in self.dtype:
raise NotImplementedError("No literal for complex values.")
if self.dtype == 'bool':
return '1' if b else '0'
return str(data)
def c_declare(self, name, sub, check_input=True):
if(check_input):
pre = """
typedef %(dtype)s %(name)s_dtype; // Deprecated use dtype_%(name)s instead.
typedef %(dtype)s dtype_%(name)s;
""" % dict(name=name, dtype=self.dtype_specs()[1])
else:
......@@ -309,6 +309,7 @@ class Scalar(Type):
def c_extract(self, name, sub, check_input=True):
if self.dtype == 'float16':
# This doesn't work at the numpy level
raise NotImplementedError('float16')
specs = self.dtype_specs()
if(check_input):
......@@ -517,6 +518,7 @@ theano.compile.register_view_op_c_code(
1)
bool = get_scalar_type('bool')
int8 = get_scalar_type('int8')
int16 = get_scalar_type('int16')
int32 = get_scalar_type('int32')
......@@ -538,7 +540,7 @@ complex_types = complex64, complex128
discrete_types = int_types + uint_types
continuous_types = float_types + complex_types
all_types = discrete_types + continuous_types
all_types = (bool,) + discrete_types + continuous_types
class _scalar_py_operators:
......@@ -681,38 +683,35 @@ complexs64 = _multi(complex64)
complexs128 = _multi(complex128)
# Using a class instead of a function makes it possible to deep-copy it in
# Python 2.4.
# Note that currently only a few functions use this mechanism, because it is
# enough to make the test-suite pass with Python 2.4. However, it may prove
# necessary to use this same mechanism in other places as well in the future.
class upcast_out(object):
def __new__(self, *types):
dtype = Scalar.upcast(*types)
return get_scalar_type(dtype),
# Using a class instead of a function makes it possible to deep-copy it.
# Note that currently only a few functions use this mechanism, because
# it is enough to make the test-suite pass. However, it may prove
# necessary to use this same mechanism in other places as well in the
# future.
def upcast_out(*types):
dtype = Scalar.upcast(*types)
return get_scalar_type(dtype),
class upgrade_to_float(object):
def __new__(self, *types):
"""
Upgrade any int types to float32 or float64 to avoid losing precision.
"""
conv = {int8: float32,
int16: float32,
int32: float64,
int64: float64,
uint8: float32,
uint16: float32,
uint32: float64,
uint64: float64}
return get_scalar_type(Scalar.upcast(*[conv.get(type, type)
for type in types])),
def upgrade_to_float(*types):
"""
Upgrade any int types to float32 or float64 to avoid losing precision.
class same_out(object):
def __new__(self, type):
return type,
"""
conv = {int8: float32,
int16: float32,
int32: float64,
int64: float64,
uint8: float32,
uint16: float32,
uint32: float64,
uint64: float64}
return get_scalar_type(Scalar.upcast(*[conv.get(type, type)
for type in types])),
def same_out(type):
return type,
def upcast_out_no_complex(*types):
......@@ -728,6 +727,8 @@ def same_out_float_only(type):
class transfer_type(gof.utils.object2):
__props__ = ('transfer',)
def __init__(self, *transfer):
assert all(type(x) in [int, str] or x is None for x in transfer)
self.transfer = transfer
......@@ -748,26 +749,16 @@ class transfer_type(gof.utils.object2):
return retval
# return [upcast if i is None else types[i] for i in self.transfer]
def __eq__(self, other):
return type(self) == type(other) and self.transfer == other.transfer
def __hash__(self):
return hash(self.transfer)
class specific_out(gof.utils.object2):
__props__ = ('spec',)
def __init__(self, *spec):
self.spec = spec
def __call__(self, *types):
return self.spec
def __eq__(self, other):
return type(self) == type(other) and self.spec == other.spec
def __hash__(self):
return hash(self.spec)
def int_out(*types):
return int64,
......@@ -1007,12 +998,12 @@ class BinaryScalarOp(ScalarOp):
class LogicalComparison(BinaryScalarOp):
def output_types(self, *input_dtypes):
return [int8]
return [bool]
def grad(self, inputs, output_gradients):
x, y = inputs
out = self(x, y)
assert str(out.type.dtype).find('int') != -1
assert out.type == bool
return [x.zeros_like().astype(theano.config.floatX),
y.zeros_like().astype(theano.config.floatX)]
......@@ -1023,12 +1014,12 @@ class FixedLogicalComparison(UnaryScalarOp):
"""
def output_types(self, *input_dtypes):
return [int8]
return [bool]
def grad(self, inputs, output_gradients):
x, = inputs
out = self(x)
assert str(out.type.dtype).find('int') != -1
assert out.type == bool
return [x.zeros_like().astype(theano.config.floatX)]
......@@ -1202,21 +1193,10 @@ class InRange(LogicalComparison):
def c_code(self, node, name, inputs, outputs, sub):
(x, low, hi) = inputs
(z,) = outputs
if self.openlow:
cmp1 = '>'
else:
cmp1 = '>='
# backport
# cmp1 = '>' if self.openlow else '>='
cmp1 = '>' if self.openlow else '>='
cmp2 = '<' if self.openhi else '<='
if self.openhi:
cmp2 = '<'
else:
cmp2 = '<='
# backport
# cmp2 = '<' if self.openhi else '<='
return ("%(z)s = %(x)s %(cmp1)s %(low)s &&"
" %(x)s %(cmp2)s %(hi)s;" % locals())
......@@ -1247,13 +1227,8 @@ class Switch(ScalarOp):
nfunc_spec = ('where', 3, 1)
def impl(self, cond, ift, iff):
if cond:
return ift
else:
return iff
return ift if cond else iff
# backport
# return ift if cond else iff
def c_code(self, node, name, inputs, outputs, sub):
(cond, ift, iff) = inputs
(z,) = outputs
......@@ -1290,9 +1265,9 @@ switch = Switch()
class UnaryBitOp(UnaryScalarOp):
def output_types(self, *input_types):
for i in input_types[0]:
if i not in (int8, int16, int32, int64):
raise TypeError('input to a BitOp must have type int8,'
' int16, int32 or int64... not %s' % i)
if i not in ((bool,) + discrete_types):
raise TypeError('input to a BitOp must have type (u)int8, '
'(u)int16, (u)int32 or (u)int64 or bool not %s' % i)
return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients):
......@@ -1302,10 +1277,13 @@ class UnaryBitOp(UnaryScalarOp):
class BinaryBitOp(BinaryScalarOp):
def output_types(self, *input_types):
t0, t1 = input_types[0]
if t0 == bool and t1 == bool:
return [bool]
for i in input_types[0]:
if i not in (int8, int16, int32, int64):
raise TypeError('input to a BitOp must have type int8,'
' int16, int32 or int64... not %s' % i)
if i not in discrete_types:
raise TypeError('input to a BitOp must have type (u)int8, '
'(u)int16, (u)int32 or (u)int64 or '
'be all bools not %s' % i)
return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients):
......@@ -1371,6 +1349,8 @@ class Invert(UnaryBitOp):
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
(z,) = outputs
if node.outputs[0].type == bool:
return "%(z)s = (!%(x)s);" % locals()
return "%(z)s = (~%(x)s);" % locals()
invert = Invert()
......@@ -2079,6 +2059,7 @@ class Cast(UnaryScalarOp):
else:
return s
convert_to_bool = Cast(bool, name='convert_to_bool')
convert_to_int8 = Cast(int8, name='convert_to_int8')
convert_to_int16 = Cast(int16, name='convert_to_int16')
convert_to_int32 = Cast(int32, name='convert_to_int32')
......@@ -2094,6 +2075,7 @@ convert_to_complex64 = Cast(complex64, name='convert_to_complex64')
convert_to_complex128 = Cast(complex128, name='convert_to_complex128')
_cast_mapping = {
'bool': convert_to_bool,
'int8': convert_to_int8,
'int16': convert_to_int16,
'int32': convert_to_int32,
......
......@@ -1246,6 +1246,10 @@ def _conversion(real_value, name):
# what types you are casting to what. That logic is implemented by the
# `cast()` function below.
_convert_to_bool = _conversion(
elemwise.Elemwise(scal.convert_to_bool), 'bool')
"""Cast to boolean"""
_convert_to_int8 = _conversion(
elemwise.Elemwise(scal.convert_to_int8), 'int8')
"""Cast to 8-bit integer"""
......@@ -1299,6 +1303,7 @@ _convert_to_complex128 = _conversion(
"""Cast to double-precision complex"""
_cast_mapping = {
'bool': _convert_to_bool,
'int8': _convert_to_int8,
'int16': _convert_to_int16,
'int32': _convert_to_int32,
......
......@@ -255,6 +255,7 @@ class TensorType(Type):
'float16': (float, 'npy_float16', 'NPY_FLOAT16'),
'float32': (float, 'npy_float32', 'NPY_FLOAT32'),
'float64': (float, 'npy_float64', 'NPY_FLOAT64'),
'bool': (bool, 'npy_bool', 'NPY_BOOL'),
'uint8': (int, 'npy_uint8', 'NPY_UINT8'),
'int8': (int, 'npy_int8', 'NPY_INT8'),
'uint16': (int, 'npy_uint16', 'NPY_UINT16'),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论