提交 b831de6d authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Change the default of "allow_downcast" and "allow_input_downcast" to None.

None is almost like False, except it is OK do downcast a Python float to a scalar floatX array. Other specialized case might be added later.
上级 aec0c068
......@@ -433,7 +433,7 @@ Final version
class Double(gof.Type):
def filter(self, x, strict=False, allow_downcast = False):
def filter(self, x, strict=False, allow_downcast=None):
if strict and not isinstance(x, float):
raise TypeError('Expected a float!')
return float(x)
......
......@@ -113,7 +113,7 @@ must define ``filter`` and shall override ``values_eq_approx``.
# Note that we shadow Python's function ``filter`` with this
# definition.
def filter(x, strict=False, allow_downcast = False):
def filter(x, strict=False, allow_downcast=None):
if strict:
if isinstance(x, float):
return x
......@@ -278,7 +278,7 @@ Final version
class Double(gof.Type):
def filter(self, x, strict=False, allow_downcast = False):
def filter(self, x, strict=False, allow_downcast=None):
if strict and not isinstance(x, float):
raise TypeError('Expected a float!')
return float(x)
......
......@@ -12,7 +12,7 @@ from numpy import any #for to work in python 2.4
def function(inputs, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=False):
rebuild_strict=True, allow_input_downcast=None):
"""
Return a callable object that will calculate `outputs` from `inputs`.
......@@ -54,12 +54,13 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[],
inputs to outputs). If one of the new types does not make sense for one of the Ops in the
graph, an Exception will be raised.
:type allow_input_downcast: Boolean
:type allow_input_downcast: Boolean or None
:param allow_input_downcast: True means that the values passed as
inputs when calling the function can be silently downcasted to fit
the dtype of the corresponding Variable, which may lose precision.
False means that it will only be casted to a more general, or
precise, type.
precise, type. None (default) is almost like False, but allows
downcasting of Python float scalars to floatX.
:note: Regarding givens: Be careful to make sure that these substitutions are
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in
......
......@@ -360,10 +360,8 @@ class Function(object):
for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(zip(self.indices, defaults)):
if indices is None: # this is true iff input is not a SymbolicInputKit
c = containers[0] #containers is being used as a stack. Here we pop off the next one.
if input.strict:
c.strict = True
if input.allow_downcast:
c.allow_downcast = True
c.strict = getattr(input, 'strict', False)
c.allow_downcast = getattr(input, 'allow_downcast', None)
if value is not None:
# Always initialize the storage.
......
......@@ -35,11 +35,12 @@ class SymbolicInput(object):
True: means that the value you pass for this input must have exactly the right type
False: the value you pass for this input may be casted automatically to the proper type
allow_downcast: Bool (default: False)
allow_downcast: Bool or None (default: None)
Only applies when `strict` is False.
True: the value you pass for this input can be silently
downcasted to fit the right type, which may lose precision.
False: the value will only be casted to a more general, or precise, type.
None: Almost like False, but allows downcast of Python floats to floatX.
autoname: Bool (default: True)
See the name option.
......@@ -50,7 +51,7 @@ class SymbolicInput(object):
"""
def __init__(self, variable, name=None, update=None, mutable=None,
strict=False, allow_downcast=False, autoname=True,
strict=False, allow_downcast=None, autoname=True,
implicit=False):
assert implicit is not None # Safety check.
self.variable = variable
......@@ -167,11 +168,12 @@ class In(SymbolicInput):
True: means that the value you pass for this input must have exactly the right type
False: the value you pass for this input may be cast automatically to the proper type
allow_downcast: Bool (default: False)
allow_downcast: Bool or None (default: None)
Only applies when `strict` is False.
True: the value you pass for this input can be silently
downcasted to fit the right type, which may lose precision.
False: the value will only be casted to a more general, or precise, type.
None: Almost like False, but allows downcast of Python floats to floatX.
autoname: Bool (default: True)
See the name option.
......@@ -190,7 +192,7 @@ class In(SymbolicInput):
# Note: the documentation above is duplicated in doc/topics/function.txt,
# try to keep it synchronized.
def __init__(self, variable, name=None, value=None, update=None,
mutable=None, strict=False, allow_downcast=False, autoname=True,
mutable=None, strict=False, allow_downcast=None, autoname=True,
implicit=None, borrow=None):
# mutable implies the output can be both aliased to the input and that the input can be
......
......@@ -8,7 +8,7 @@ import numpy # for backport to 2.4, to get any().
class Param(object):
def __init__(self, variable, default=None, name=None, mutable=False,
strict=False, allow_downcast=False, implicit=None):
strict=False, allow_downcast=None, implicit=None):
"""
:param variable: A variable in an expression graph to use as a compiled-function parameter
......@@ -24,7 +24,9 @@ class Param(object):
required by `variable`.
:param allow_downcast: Only applies if `strict` is False.
True -> allows assigned value to lose precision when casted during assignment.
True -> allow assigned value to lose precision when casted during assignment.
False -> never allow precision loss.
None -> only allow downcasting of a Python float to a scalar floatX.
:param implicit: see help(theano.io.In)
......@@ -39,7 +41,7 @@ class Param(object):
def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=False):
rebuild_strict=True, allow_input_downcast=None):
"""Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances.
......@@ -77,7 +79,8 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
inputs when calling the function can be silently downcasted to fit
the dtype of the corresponding Variable, which may lose precision.
False means that it will only be casted to a more general, or
precise, type.
precise, type. None (default) is almost like False, but allows
downcasting of Python float scalars to floatX.
:note: Regarding givens: Be careful to make sure that these substitutions are
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in
......@@ -267,7 +270,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
accept_inplace=accept_inplace, name=name)
def _pfunc_param_to_in(param, strict=False, allow_downcast=False):
def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
if isinstance(param, Constant):
raise TypeError('Constants not allowed in param list', param)
#if isinstance(param, Value):
......
......@@ -42,7 +42,7 @@ class SharedVariable(Variable):
# this Variable, unless another update value has been passed to "function",
# or the "no_default_updates" list passed to "function" contains it.
def __init__(self, name, type, value, strict, allow_downcast=False, container=None):
def __init__(self, name, type, value, strict, allow_downcast=None, container=None):
"""
:param name: The name for this variable (see `Variable`).
......@@ -54,7 +54,9 @@ class SharedVariable(Variable):
have the correct type.
:param allow_downcast: Only applies if `strict` is False.
True -> allows assigned value to lose precision when casted during assignment.
True -> allow assigned value to lose precision when casted during assignment.
False -> never allow precision loss.
None -> only allow downcasting of a Python float to a scalar floatX.
:param container: The container to use for this variable. Illegal to pass this as well
as a value.
......@@ -160,7 +162,7 @@ def shared_constructor(ctor):
shared.constructors.append(ctor)
return ctor
def shared(value, name=None, strict=False, allow_downcast=False, **kwargs):
def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
"""Return a SharedVariable Variable, initialized with a copy or reference of `value`.
This function iterates over constructor functions (see `shared_constructor`) to find a
......@@ -186,7 +188,7 @@ def shared(value, name=None, strict=False, allow_downcast=False, **kwargs):
shared.constructors = []
@shared_constructor
def generic_constructor(value, name=None, strict=False, allow_downcast=False):
def generic_constructor(value, name=None, strict=False, allow_downcast=None):
"""SharedVariable Constructor"""
return SharedVariable(type=generic, value=value, name=name, strict=strict,
allow_downcast=allow_downcast)
......
......@@ -124,7 +124,7 @@ class Container(object):
It is used in linkers, especially for the inputs and outputs of a Function.
"""
def __init__(self, r, storage, readonly=False, strict=False,
allow_downcast=False, name=None):
allow_downcast=None, name=None):
"""WRITEME
:Parameters:
......@@ -132,7 +132,9 @@ class Container(object):
`storage`: a list of length 1, whose element is the value for `r`
`readonly`: True indicates that this should not be setable by Function[r] = val
`strict`: if True, we don't allow type casting.
`allow_downcast`: if True (and `strict` is False), allow type upcasting, but not downcasting.
`allow_downcast`: if True (and `strict` is False), allow upcasting
of type, but not downcasting. If False, prevent it. If None
(default), allows only downcasting of float to floatX scalar.
`name`: A string (for pretty-printing?)
"""
......@@ -171,8 +173,8 @@ class Container(object):
kwargs = {}
if self.strict:
kwargs['strict'] = True
if self.allow_downcast:
kwargs['allow_downcast'] = True
if self.allow_downcast is not None:
kwargs['allow_downcast'] = self.allow_downcast
self.storage[0] = self.type.filter(value, **kwargs)
except Exception, e:
......
......@@ -204,7 +204,7 @@ class PureType(object):
Variable = graph.Variable #the type that will be created by call to make_variable.
Constant = graph.Constant #the type that will be created by call to make_constant
def filter(self, data, strict=False, allow_downcast=False):
def filter(self, data, strict=False, allow_downcast=None):
"""Required: Return data or an appropriately wrapped/converted data.
Subclass implementation should raise a TypeError exception if the data is not of an
......@@ -214,7 +214,8 @@ class PureType(object):
data passed as an argument. If it is False, and allow_downcast
is True, filter may cast it to an appropriate type. If
allow_downcast is False, filter may only upcast it, not lose
precision.
precision. If allow_downcast is None, only Python float can be
downcasted, and only to a floatX scalar.
:Exceptions:
- `MethodNotDefined`: subclass doesn't implement this function.
......@@ -353,7 +354,7 @@ class Generic(SingletonType):
WRITEME
"""
def filter(self, data, strict=False, allow_downcast=False):
def filter(self, data, strict=False, allow_downcast=None):
return data
def is_valid_value(self, a):
......
......@@ -52,7 +52,7 @@ class CudaNdarrayType(Type):
self.name = name
self.dtype_specs() # error checking is done there
def filter(self, data, strict=False, allow_downcast=False):
def filter(self, data, strict=False, allow_downcast=None):
if strict or allow_downcast or isinstance(data, cuda.CudaNdarray):
return cuda.filter(data, self.broadcastable, strict, None)
else: # (not strict) and (not allow_downcast)
......@@ -70,7 +70,13 @@ class CudaNdarrayType(Type):
data)
else:
converted_data = theano._asarray(data, self.dtype)
if numpy.all(data == converted_data):
if (allow_downcast is None and
type(data) is float and
self.dtype==theano.config.floatX):
return cuda.filter(converted_data, self.broadcastable,
strict, None)
elif numpy.all(data == converted_data):
return cuda.filter(converted_data, self.broadcastable,
strict, None)
else:
......
......@@ -70,7 +70,7 @@ class Scalar(Type):
self.dtype = dtype
self.dtype_specs() # error checking
def filter(self, data, strict=False, allow_downcast=False):
def filter(self, data, strict=False, allow_downcast=None):
py_type = self.dtype_specs()[0]
if strict and not isinstance(data, py_type):
raise TypeError("%s expected a %s, got %s of type %s" % (self, py_type, data,
......@@ -78,10 +78,14 @@ class Scalar(Type):
data)
try:
converted_data = py_type(data)
if allow_downcast or data == converted_data:
if (allow_downcast or
(allow_downcast is None and
type(data) is float and
self.dtype==theano.config.floatX) or
data == converted_data):
return py_type(data)
else:
raise TypeError('Value cannot accurately be converted to dtype (%s) and allow_downcast is False' % self.dtype)
raise TypeError('Value cannot accurately be converted to dtype (%s) and allow_downcast is not True' % self.dtype)
except Exception, e:
raise TypeError("Could not convert %s (value=%s) to %s" % (type(data), data, self.dtype), e)
......
......@@ -168,7 +168,7 @@ class SparseType(gof.Type):
else:
raise NotImplementedError('unsupported format "%s" not in list' % format, self.format_cls.keys())
def filter(self, value, strict=False, allow_downcast=False):
def filter(self, value, strict=False, allow_downcast=None):
if isinstance(value, self.format_cls[self.format])\
and value.dtype == self.dtype:
return value
......
......@@ -8,7 +8,7 @@ class SparseTensorSharedVariable(SharedVariable, _sparse_py_operators):
pass
@shared_constructor
def sparse_constructor(value, name=None, strict=False, allow_downcast=False,
def sparse_constructor(value, name=None, strict=False, allow_downcast=None,
borrow=False, format = None):
"""SharedVariable Constructor for SparseType
......
......@@ -400,7 +400,7 @@ class TensorType(Type):
self.name = name
self.numpy_dtype = numpy.dtype(self.dtype)
def filter(self, data, strict=False, allow_downcast=False):
def filter(self, data, strict=False, allow_downcast=None):
"""Convert `data` to something which can be associated to a `TensorVariable`.
This function is not meant to be called in user code. It is for
......@@ -443,6 +443,12 @@ class TensorType(Type):
'"function".'
% (self, data.dtype, self.dtype))
raise TypeError(err_msg, data)
elif (allow_downcast is None and
type(data) is float and
self.dtype == theano.config.floatX):
# Special case where we allow downcasting of Python float
# literals to floatX, even when floatX=='float32'
data = theano._asarray(data, self.dtype)
else:
# data has to be converted.
# Check that this conversion is lossless
......
......@@ -22,7 +22,7 @@ class RandomStateType(gof.Type):
def __str__(self):
return 'RandomStateType'
def filter(self, data, strict=False, allow_downcast=False):
def filter(self, data, strict=False, allow_downcast=None):
if self.is_valid_value(data):
return data
else:
......
......@@ -12,7 +12,7 @@ class RandomStateSharedVariable(SharedVariable):
pass
@shared_constructor
def randomstate_constructor(value, name=None, strict=False, allow_downcast=False, borrow=False):
def randomstate_constructor(value, name=None, strict=False, allow_downcast=None, borrow=False):
"""SharedVariable Constructor for RandomState"""
if not isinstance(value, numpy.random.RandomState):
raise TypeError
......
......@@ -10,7 +10,7 @@ class TensorSharedVariable(_tensor_py_operators, SharedVariable):
pass
@shared_constructor
def tensor_constructor(value, name=None, strict=False, allow_downcast=False, borrow=False, broadcastable=None):
def tensor_constructor(value, name=None, strict=False, allow_downcast=None, borrow=False, broadcastable=None):
"""SharedVariable Constructor for TensorType
:note: Regarding the inference of the broadcastable pattern...
......@@ -41,7 +41,7 @@ class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
pass
@shared_constructor
def scalar_constructor(value, name=None, strict=False, allow_downcast=False):
def scalar_constructor(value, name=None, strict=False, allow_downcast=None):
"""SharedVariable constructor for scalar values. Default: int64 or float64.
:note: We implement this using 0-d tensors for now.
......
......@@ -22,7 +22,7 @@ class T_extending(unittest.TestCase):
# Note that we shadow Python's function ``filter`` with this
# definition.
def filter(x, strict=False, allow_downcast = False):
def filter(x, strict=False, allow_downcast=None):
if strict:
if isinstance(x, float):
return x
......@@ -66,7 +66,7 @@ class T_extending(unittest.TestCase):
class Double(gof.Type):
def filter(self, x, strict=False, allow_downcast = False):
def filter(self, x, strict=False, allow_downcast=None):
if strict and not isinstance(x, float):
raise TypeError('Expected a float!')
return float(x)
......@@ -168,7 +168,7 @@ class T_extending(unittest.TestCase):
class Double(gof.Type):
def filter(self, x, strict=False, allow_downcast = False):
def filter(self, x, strict=False, allow_downcast=None):
if strict and not isinstance(x, float):
raise TypeError('Expected a float!')
return float(x)
......@@ -274,7 +274,7 @@ class T_extending(unittest.TestCase):
from theano import gof
class Double(gof.Type):
def filter(self, x, strict=False, allow_downcast = False):
def filter(self, x, strict=False, allow_downcast=None):
if strict and not isinstance(x, float):
raise TypeError('Expected a float!')
return float(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论