提交 b831de6d authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Change the default of "allow_downcast" and "allow_input_downcast" to None.

None is almost like False, except it is OK do downcast a Python float to a scalar floatX array. Other specialized case might be added later.
上级 aec0c068
...@@ -433,7 +433,7 @@ Final version ...@@ -433,7 +433,7 @@ Final version
class Double(gof.Type): class Double(gof.Type):
def filter(self, x, strict=False, allow_downcast = False): def filter(self, x, strict=False, allow_downcast=None):
if strict and not isinstance(x, float): if strict and not isinstance(x, float):
raise TypeError('Expected a float!') raise TypeError('Expected a float!')
return float(x) return float(x)
......
...@@ -113,7 +113,7 @@ must define ``filter`` and shall override ``values_eq_approx``. ...@@ -113,7 +113,7 @@ must define ``filter`` and shall override ``values_eq_approx``.
# Note that we shadow Python's function ``filter`` with this # Note that we shadow Python's function ``filter`` with this
# definition. # definition.
def filter(x, strict=False, allow_downcast = False): def filter(x, strict=False, allow_downcast=None):
if strict: if strict:
if isinstance(x, float): if isinstance(x, float):
return x return x
...@@ -278,7 +278,7 @@ Final version ...@@ -278,7 +278,7 @@ Final version
class Double(gof.Type): class Double(gof.Type):
def filter(self, x, strict=False, allow_downcast = False): def filter(self, x, strict=False, allow_downcast=None):
if strict and not isinstance(x, float): if strict and not isinstance(x, float):
raise TypeError('Expected a float!') raise TypeError('Expected a float!')
return float(x) return float(x)
......
...@@ -12,7 +12,7 @@ from numpy import any #for to work in python 2.4 ...@@ -12,7 +12,7 @@ from numpy import any #for to work in python 2.4
def function(inputs, outputs=None, mode=None, updates=[], givens=[], def function(inputs, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=False): rebuild_strict=True, allow_input_downcast=None):
""" """
Return a callable object that will calculate `outputs` from `inputs`. Return a callable object that will calculate `outputs` from `inputs`.
...@@ -54,12 +54,13 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], ...@@ -54,12 +54,13 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[],
inputs to outputs). If one of the new types does not make sense for one of the Ops in the inputs to outputs). If one of the new types does not make sense for one of the Ops in the
graph, an Exception will be raised. graph, an Exception will be raised.
:type allow_input_downcast: Boolean :type allow_input_downcast: Boolean or None
:param allow_input_downcast: True means that the values passed as :param allow_input_downcast: True means that the values passed as
inputs when calling the function can be silently downcasted to fit inputs when calling the function can be silently downcasted to fit
the dtype of the corresponding Variable, which may lose precision. the dtype of the corresponding Variable, which may lose precision.
False means that it will only be casted to a more general, or False means that it will only be casted to a more general, or
precise, type. precise, type. None (default) is almost like False, but allows
downcasting of Python float scalars to floatX.
:note: Regarding givens: Be careful to make sure that these substitutions are :note: Regarding givens: Be careful to make sure that these substitutions are
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in
......
...@@ -360,10 +360,8 @@ class Function(object): ...@@ -360,10 +360,8 @@ class Function(object):
for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(zip(self.indices, defaults)): for i, ((input, indices, sinputs), (required, refeed, value)) in enumerate(zip(self.indices, defaults)):
if indices is None: # this is true iff input is not a SymbolicInputKit if indices is None: # this is true iff input is not a SymbolicInputKit
c = containers[0] #containers is being used as a stack. Here we pop off the next one. c = containers[0] #containers is being used as a stack. Here we pop off the next one.
if input.strict: c.strict = getattr(input, 'strict', False)
c.strict = True c.allow_downcast = getattr(input, 'allow_downcast', None)
if input.allow_downcast:
c.allow_downcast = True
if value is not None: if value is not None:
# Always initialize the storage. # Always initialize the storage.
......
...@@ -35,11 +35,12 @@ class SymbolicInput(object): ...@@ -35,11 +35,12 @@ class SymbolicInput(object):
True: means that the value you pass for this input must have exactly the right type True: means that the value you pass for this input must have exactly the right type
False: the value you pass for this input may be casted automatically to the proper type False: the value you pass for this input may be casted automatically to the proper type
allow_downcast: Bool (default: False) allow_downcast: Bool or None (default: None)
Only applies when `strict` is False. Only applies when `strict` is False.
True: the value you pass for this input can be silently True: the value you pass for this input can be silently
downcasted to fit the right type, which may lose precision. downcasted to fit the right type, which may lose precision.
False: the value will only be casted to a more general, or precise, type. False: the value will only be casted to a more general, or precise, type.
None: Almost like False, but allows downcast of Python floats to floatX.
autoname: Bool (default: True) autoname: Bool (default: True)
See the name option. See the name option.
...@@ -50,7 +51,7 @@ class SymbolicInput(object): ...@@ -50,7 +51,7 @@ class SymbolicInput(object):
""" """
def __init__(self, variable, name=None, update=None, mutable=None, def __init__(self, variable, name=None, update=None, mutable=None,
strict=False, allow_downcast=False, autoname=True, strict=False, allow_downcast=None, autoname=True,
implicit=False): implicit=False):
assert implicit is not None # Safety check. assert implicit is not None # Safety check.
self.variable = variable self.variable = variable
...@@ -167,11 +168,12 @@ class In(SymbolicInput): ...@@ -167,11 +168,12 @@ class In(SymbolicInput):
True: means that the value you pass for this input must have exactly the right type True: means that the value you pass for this input must have exactly the right type
False: the value you pass for this input may be cast automatically to the proper type False: the value you pass for this input may be cast automatically to the proper type
allow_downcast: Bool (default: False) allow_downcast: Bool or None (default: None)
Only applies when `strict` is False. Only applies when `strict` is False.
True: the value you pass for this input can be silently True: the value you pass for this input can be silently
downcasted to fit the right type, which may lose precision. downcasted to fit the right type, which may lose precision.
False: the value will only be casted to a more general, or precise, type. False: the value will only be casted to a more general, or precise, type.
None: Almost like False, but allows downcast of Python floats to floatX.
autoname: Bool (default: True) autoname: Bool (default: True)
See the name option. See the name option.
...@@ -190,7 +192,7 @@ class In(SymbolicInput): ...@@ -190,7 +192,7 @@ class In(SymbolicInput):
# Note: the documentation above is duplicated in doc/topics/function.txt, # Note: the documentation above is duplicated in doc/topics/function.txt,
# try to keep it synchronized. # try to keep it synchronized.
def __init__(self, variable, name=None, value=None, update=None, def __init__(self, variable, name=None, value=None, update=None,
mutable=None, strict=False, allow_downcast=False, autoname=True, mutable=None, strict=False, allow_downcast=None, autoname=True,
implicit=None, borrow=None): implicit=None, borrow=None):
# mutable implies the output can be both aliased to the input and that the input can be # mutable implies the output can be both aliased to the input and that the input can be
......
...@@ -8,7 +8,7 @@ import numpy # for backport to 2.4, to get any(). ...@@ -8,7 +8,7 @@ import numpy # for backport to 2.4, to get any().
class Param(object): class Param(object):
def __init__(self, variable, default=None, name=None, mutable=False, def __init__(self, variable, default=None, name=None, mutable=False,
strict=False, allow_downcast=False, implicit=None): strict=False, allow_downcast=None, implicit=None):
""" """
:param variable: A variable in an expression graph to use as a compiled-function parameter :param variable: A variable in an expression graph to use as a compiled-function parameter
...@@ -24,7 +24,9 @@ class Param(object): ...@@ -24,7 +24,9 @@ class Param(object):
required by `variable`. required by `variable`.
:param allow_downcast: Only applies if `strict` is False. :param allow_downcast: Only applies if `strict` is False.
True -> allows assigned value to lose precision when casted during assignment. True -> allow assigned value to lose precision when casted during assignment.
False -> never allow precision loss.
None -> only allow downcasting of a Python float to a scalar floatX.
:param implicit: see help(theano.io.In) :param implicit: see help(theano.io.In)
...@@ -39,7 +41,7 @@ class Param(object): ...@@ -39,7 +41,7 @@ class Param(object):
def pfunc(params, outputs=None, mode=None, updates=[], givens=[], def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=False): rebuild_strict=True, allow_input_downcast=None):
"""Function-constructor for graphs with shared variables. """Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances. :type params: list of either Variable or Param instances.
...@@ -77,7 +79,8 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -77,7 +79,8 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
inputs when calling the function can be silently downcasted to fit inputs when calling the function can be silently downcasted to fit
the dtype of the corresponding Variable, which may lose precision. the dtype of the corresponding Variable, which may lose precision.
False means that it will only be casted to a more general, or False means that it will only be casted to a more general, or
precise, type. precise, type. None (default) is almost like False, but allows
downcasting of Python float scalars to floatX.
:note: Regarding givens: Be careful to make sure that these substitutions are :note: Regarding givens: Be careful to make sure that these substitutions are
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in
...@@ -267,7 +270,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -267,7 +270,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
accept_inplace=accept_inplace, name=name) accept_inplace=accept_inplace, name=name)
def _pfunc_param_to_in(param, strict=False, allow_downcast=False): def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
if isinstance(param, Constant): if isinstance(param, Constant):
raise TypeError('Constants not allowed in param list', param) raise TypeError('Constants not allowed in param list', param)
#if isinstance(param, Value): #if isinstance(param, Value):
......
...@@ -42,7 +42,7 @@ class SharedVariable(Variable): ...@@ -42,7 +42,7 @@ class SharedVariable(Variable):
# this Variable, unless another update value has been passed to "function", # this Variable, unless another update value has been passed to "function",
# or the "no_default_updates" list passed to "function" contains it. # or the "no_default_updates" list passed to "function" contains it.
def __init__(self, name, type, value, strict, allow_downcast=False, container=None): def __init__(self, name, type, value, strict, allow_downcast=None, container=None):
""" """
:param name: The name for this variable (see `Variable`). :param name: The name for this variable (see `Variable`).
...@@ -54,7 +54,9 @@ class SharedVariable(Variable): ...@@ -54,7 +54,9 @@ class SharedVariable(Variable):
have the correct type. have the correct type.
:param allow_downcast: Only applies if `strict` is False. :param allow_downcast: Only applies if `strict` is False.
True -> allows assigned value to lose precision when casted during assignment. True -> allow assigned value to lose precision when casted during assignment.
False -> never allow precision loss.
None -> only allow downcasting of a Python float to a scalar floatX.
:param container: The container to use for this variable. Illegal to pass this as well :param container: The container to use for this variable. Illegal to pass this as well
as a value. as a value.
...@@ -160,7 +162,7 @@ def shared_constructor(ctor): ...@@ -160,7 +162,7 @@ def shared_constructor(ctor):
shared.constructors.append(ctor) shared.constructors.append(ctor)
return ctor return ctor
def shared(value, name=None, strict=False, allow_downcast=False, **kwargs): def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
"""Return a SharedVariable Variable, initialized with a copy or reference of `value`. """Return a SharedVariable Variable, initialized with a copy or reference of `value`.
This function iterates over constructor functions (see `shared_constructor`) to find a This function iterates over constructor functions (see `shared_constructor`) to find a
...@@ -186,7 +188,7 @@ def shared(value, name=None, strict=False, allow_downcast=False, **kwargs): ...@@ -186,7 +188,7 @@ def shared(value, name=None, strict=False, allow_downcast=False, **kwargs):
shared.constructors = [] shared.constructors = []
@shared_constructor @shared_constructor
def generic_constructor(value, name=None, strict=False, allow_downcast=False): def generic_constructor(value, name=None, strict=False, allow_downcast=None):
"""SharedVariable Constructor""" """SharedVariable Constructor"""
return SharedVariable(type=generic, value=value, name=name, strict=strict, return SharedVariable(type=generic, value=value, name=name, strict=strict,
allow_downcast=allow_downcast) allow_downcast=allow_downcast)
......
...@@ -124,7 +124,7 @@ class Container(object): ...@@ -124,7 +124,7 @@ class Container(object):
It is used in linkers, especially for the inputs and outputs of a Function. It is used in linkers, especially for the inputs and outputs of a Function.
""" """
def __init__(self, r, storage, readonly=False, strict=False, def __init__(self, r, storage, readonly=False, strict=False,
allow_downcast=False, name=None): allow_downcast=None, name=None):
"""WRITEME """WRITEME
:Parameters: :Parameters:
...@@ -132,7 +132,9 @@ class Container(object): ...@@ -132,7 +132,9 @@ class Container(object):
`storage`: a list of length 1, whose element is the value for `r` `storage`: a list of length 1, whose element is the value for `r`
`readonly`: True indicates that this should not be setable by Function[r] = val `readonly`: True indicates that this should not be setable by Function[r] = val
`strict`: if True, we don't allow type casting. `strict`: if True, we don't allow type casting.
`allow_downcast`: if True (and `strict` is False), allow type upcasting, but not downcasting. `allow_downcast`: if True (and `strict` is False), allow upcasting
of type, but not downcasting. If False, prevent it. If None
(default), allows only downcasting of float to floatX scalar.
`name`: A string (for pretty-printing?) `name`: A string (for pretty-printing?)
""" """
...@@ -171,8 +173,8 @@ class Container(object): ...@@ -171,8 +173,8 @@ class Container(object):
kwargs = {} kwargs = {}
if self.strict: if self.strict:
kwargs['strict'] = True kwargs['strict'] = True
if self.allow_downcast: if self.allow_downcast is not None:
kwargs['allow_downcast'] = True kwargs['allow_downcast'] = self.allow_downcast
self.storage[0] = self.type.filter(value, **kwargs) self.storage[0] = self.type.filter(value, **kwargs)
except Exception, e: except Exception, e:
......
...@@ -204,7 +204,7 @@ class PureType(object): ...@@ -204,7 +204,7 @@ class PureType(object):
Variable = graph.Variable #the type that will be created by call to make_variable. Variable = graph.Variable #the type that will be created by call to make_variable.
Constant = graph.Constant #the type that will be created by call to make_constant Constant = graph.Constant #the type that will be created by call to make_constant
def filter(self, data, strict=False, allow_downcast=False): def filter(self, data, strict=False, allow_downcast=None):
"""Required: Return data or an appropriately wrapped/converted data. """Required: Return data or an appropriately wrapped/converted data.
Subclass implementation should raise a TypeError exception if the data is not of an Subclass implementation should raise a TypeError exception if the data is not of an
...@@ -214,7 +214,8 @@ class PureType(object): ...@@ -214,7 +214,8 @@ class PureType(object):
data passed as an argument. If it is False, and allow_downcast data passed as an argument. If it is False, and allow_downcast
is True, filter may cast it to an appropriate type. If is True, filter may cast it to an appropriate type. If
allow_downcast is False, filter may only upcast it, not lose allow_downcast is False, filter may only upcast it, not lose
precision. precision. If allow_downcast is None, only Python float can be
downcasted, and only to a floatX scalar.
:Exceptions: :Exceptions:
- `MethodNotDefined`: subclass doesn't implement this function. - `MethodNotDefined`: subclass doesn't implement this function.
...@@ -353,7 +354,7 @@ class Generic(SingletonType): ...@@ -353,7 +354,7 @@ class Generic(SingletonType):
WRITEME WRITEME
""" """
def filter(self, data, strict=False, allow_downcast=False): def filter(self, data, strict=False, allow_downcast=None):
return data return data
def is_valid_value(self, a): def is_valid_value(self, a):
......
...@@ -52,7 +52,7 @@ class CudaNdarrayType(Type): ...@@ -52,7 +52,7 @@ class CudaNdarrayType(Type):
self.name = name self.name = name
self.dtype_specs() # error checking is done there self.dtype_specs() # error checking is done there
def filter(self, data, strict=False, allow_downcast=False): def filter(self, data, strict=False, allow_downcast=None):
if strict or allow_downcast or isinstance(data, cuda.CudaNdarray): if strict or allow_downcast or isinstance(data, cuda.CudaNdarray):
return cuda.filter(data, self.broadcastable, strict, None) return cuda.filter(data, self.broadcastable, strict, None)
else: # (not strict) and (not allow_downcast) else: # (not strict) and (not allow_downcast)
...@@ -70,7 +70,13 @@ class CudaNdarrayType(Type): ...@@ -70,7 +70,13 @@ class CudaNdarrayType(Type):
data) data)
else: else:
converted_data = theano._asarray(data, self.dtype) converted_data = theano._asarray(data, self.dtype)
if numpy.all(data == converted_data):
if (allow_downcast is None and
type(data) is float and
self.dtype==theano.config.floatX):
return cuda.filter(converted_data, self.broadcastable,
strict, None)
elif numpy.all(data == converted_data):
return cuda.filter(converted_data, self.broadcastable, return cuda.filter(converted_data, self.broadcastable,
strict, None) strict, None)
else: else:
......
...@@ -70,7 +70,7 @@ class Scalar(Type): ...@@ -70,7 +70,7 @@ class Scalar(Type):
self.dtype = dtype self.dtype = dtype
self.dtype_specs() # error checking self.dtype_specs() # error checking
def filter(self, data, strict=False, allow_downcast=False): def filter(self, data, strict=False, allow_downcast=None):
py_type = self.dtype_specs()[0] py_type = self.dtype_specs()[0]
if strict and not isinstance(data, py_type): if strict and not isinstance(data, py_type):
raise TypeError("%s expected a %s, got %s of type %s" % (self, py_type, data, raise TypeError("%s expected a %s, got %s of type %s" % (self, py_type, data,
...@@ -78,10 +78,14 @@ class Scalar(Type): ...@@ -78,10 +78,14 @@ class Scalar(Type):
data) data)
try: try:
converted_data = py_type(data) converted_data = py_type(data)
if allow_downcast or data == converted_data: if (allow_downcast or
(allow_downcast is None and
type(data) is float and
self.dtype==theano.config.floatX) or
data == converted_data):
return py_type(data) return py_type(data)
else: else:
raise TypeError('Value cannot accurately be converted to dtype (%s) and allow_downcast is False' % self.dtype) raise TypeError('Value cannot accurately be converted to dtype (%s) and allow_downcast is not True' % self.dtype)
except Exception, e: except Exception, e:
raise TypeError("Could not convert %s (value=%s) to %s" % (type(data), data, self.dtype), e) raise TypeError("Could not convert %s (value=%s) to %s" % (type(data), data, self.dtype), e)
......
...@@ -168,7 +168,7 @@ class SparseType(gof.Type): ...@@ -168,7 +168,7 @@ class SparseType(gof.Type):
else: else:
raise NotImplementedError('unsupported format "%s" not in list' % format, self.format_cls.keys()) raise NotImplementedError('unsupported format "%s" not in list' % format, self.format_cls.keys())
def filter(self, value, strict=False, allow_downcast=False): def filter(self, value, strict=False, allow_downcast=None):
if isinstance(value, self.format_cls[self.format])\ if isinstance(value, self.format_cls[self.format])\
and value.dtype == self.dtype: and value.dtype == self.dtype:
return value return value
......
...@@ -8,7 +8,7 @@ class SparseTensorSharedVariable(SharedVariable, _sparse_py_operators): ...@@ -8,7 +8,7 @@ class SparseTensorSharedVariable(SharedVariable, _sparse_py_operators):
pass pass
@shared_constructor @shared_constructor
def sparse_constructor(value, name=None, strict=False, allow_downcast=False, def sparse_constructor(value, name=None, strict=False, allow_downcast=None,
borrow=False, format = None): borrow=False, format = None):
"""SharedVariable Constructor for SparseType """SharedVariable Constructor for SparseType
......
...@@ -400,7 +400,7 @@ class TensorType(Type): ...@@ -400,7 +400,7 @@ class TensorType(Type):
self.name = name self.name = name
self.numpy_dtype = numpy.dtype(self.dtype) self.numpy_dtype = numpy.dtype(self.dtype)
def filter(self, data, strict=False, allow_downcast=False): def filter(self, data, strict=False, allow_downcast=None):
"""Convert `data` to something which can be associated to a `TensorVariable`. """Convert `data` to something which can be associated to a `TensorVariable`.
This function is not meant to be called in user code. It is for This function is not meant to be called in user code. It is for
...@@ -443,6 +443,12 @@ class TensorType(Type): ...@@ -443,6 +443,12 @@ class TensorType(Type):
'"function".' '"function".'
% (self, data.dtype, self.dtype)) % (self, data.dtype, self.dtype))
raise TypeError(err_msg, data) raise TypeError(err_msg, data)
elif (allow_downcast is None and
type(data) is float and
self.dtype == theano.config.floatX):
# Special case where we allow downcasting of Python float
# literals to floatX, even when floatX=='float32'
data = theano._asarray(data, self.dtype)
else: else:
# data has to be converted. # data has to be converted.
# Check that this conversion is lossless # Check that this conversion is lossless
......
...@@ -22,7 +22,7 @@ class RandomStateType(gof.Type): ...@@ -22,7 +22,7 @@ class RandomStateType(gof.Type):
def __str__(self): def __str__(self):
return 'RandomStateType' return 'RandomStateType'
def filter(self, data, strict=False, allow_downcast=False): def filter(self, data, strict=False, allow_downcast=None):
if self.is_valid_value(data): if self.is_valid_value(data):
return data return data
else: else:
......
...@@ -12,7 +12,7 @@ class RandomStateSharedVariable(SharedVariable): ...@@ -12,7 +12,7 @@ class RandomStateSharedVariable(SharedVariable):
pass pass
@shared_constructor @shared_constructor
def randomstate_constructor(value, name=None, strict=False, allow_downcast=False, borrow=False): def randomstate_constructor(value, name=None, strict=False, allow_downcast=None, borrow=False):
"""SharedVariable Constructor for RandomState""" """SharedVariable Constructor for RandomState"""
if not isinstance(value, numpy.random.RandomState): if not isinstance(value, numpy.random.RandomState):
raise TypeError raise TypeError
......
...@@ -10,7 +10,7 @@ class TensorSharedVariable(_tensor_py_operators, SharedVariable): ...@@ -10,7 +10,7 @@ class TensorSharedVariable(_tensor_py_operators, SharedVariable):
pass pass
@shared_constructor @shared_constructor
def tensor_constructor(value, name=None, strict=False, allow_downcast=False, borrow=False, broadcastable=None): def tensor_constructor(value, name=None, strict=False, allow_downcast=None, borrow=False, broadcastable=None):
"""SharedVariable Constructor for TensorType """SharedVariable Constructor for TensorType
:note: Regarding the inference of the broadcastable pattern... :note: Regarding the inference of the broadcastable pattern...
...@@ -41,7 +41,7 @@ class ScalarSharedVariable(_tensor_py_operators, SharedVariable): ...@@ -41,7 +41,7 @@ class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
pass pass
@shared_constructor @shared_constructor
def scalar_constructor(value, name=None, strict=False, allow_downcast=False): def scalar_constructor(value, name=None, strict=False, allow_downcast=None):
"""SharedVariable constructor for scalar values. Default: int64 or float64. """SharedVariable constructor for scalar values. Default: int64 or float64.
:note: We implement this using 0-d tensors for now. :note: We implement this using 0-d tensors for now.
......
...@@ -22,7 +22,7 @@ class T_extending(unittest.TestCase): ...@@ -22,7 +22,7 @@ class T_extending(unittest.TestCase):
# Note that we shadow Python's function ``filter`` with this # Note that we shadow Python's function ``filter`` with this
# definition. # definition.
def filter(x, strict=False, allow_downcast = False): def filter(x, strict=False, allow_downcast=None):
if strict: if strict:
if isinstance(x, float): if isinstance(x, float):
return x return x
...@@ -66,7 +66,7 @@ class T_extending(unittest.TestCase): ...@@ -66,7 +66,7 @@ class T_extending(unittest.TestCase):
class Double(gof.Type): class Double(gof.Type):
def filter(self, x, strict=False, allow_downcast = False): def filter(self, x, strict=False, allow_downcast=None):
if strict and not isinstance(x, float): if strict and not isinstance(x, float):
raise TypeError('Expected a float!') raise TypeError('Expected a float!')
return float(x) return float(x)
...@@ -168,7 +168,7 @@ class T_extending(unittest.TestCase): ...@@ -168,7 +168,7 @@ class T_extending(unittest.TestCase):
class Double(gof.Type): class Double(gof.Type):
def filter(self, x, strict=False, allow_downcast = False): def filter(self, x, strict=False, allow_downcast=None):
if strict and not isinstance(x, float): if strict and not isinstance(x, float):
raise TypeError('Expected a float!') raise TypeError('Expected a float!')
return float(x) return float(x)
...@@ -274,7 +274,7 @@ class T_extending(unittest.TestCase): ...@@ -274,7 +274,7 @@ class T_extending(unittest.TestCase):
from theano import gof from theano import gof
class Double(gof.Type): class Double(gof.Type):
def filter(self, x, strict=False, allow_downcast = False): def filter(self, x, strict=False, allow_downcast=None):
if strict and not isinstance(x, float): if strict and not isinstance(x, float):
raise TypeError('Expected a float!') raise TypeError('Expected a float!')
return float(x) return float(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论