提交 ec552e7b authored 作者: Pascal Lamblin's avatar Pascal Lamblin

New arg "allow_input_downcast" to function, and "allow_downcast" to filter.

Default of "allow_downcast" is True, like previous behaviour. However, default of "allow_input_downcast" is False, so no precision is silently lost when passing a parameter to a Theano function.
上级 104d3703
...@@ -12,7 +12,7 @@ from numpy import any #for to work in python 2.4 ...@@ -12,7 +12,7 @@ from numpy import any #for to work in python 2.4
def function(inputs, outputs=None, mode=None, updates=[], givens=[], def function(inputs, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict = True): rebuild_strict=True, allow_input_downcast=False):
""" """
Return a callable object that will calculate `outputs` from `inputs`. Return a callable object that will calculate `outputs` from `inputs`.
...@@ -54,6 +54,13 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], ...@@ -54,6 +54,13 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[],
inputs to outputs). If one of the new types does not make sense for one of the Ops in the inputs to outputs). If one of the new types does not make sense for one of the Ops in the
graph, an Exception will be raised. graph, an Exception will be raised.
:type allow_input_downcast: Boolean
:param allow_input_downcast: True means that the values passed as
inputs when calling the function can be silently downcast to fit
the dtype of the corresponding Variable, which may lose precision.
False means that they will only be cast to a more general, or more
precise, type.
:note: Regarding givens: Be careful to make sure that these substitutions are :note: Regarding givens: Be careful to make sure that these substitutions are
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in
another expression is undefined. Replacements specified with givens are different from another expression is undefined. Replacements specified with givens are different from
...@@ -93,7 +100,8 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], ...@@ -93,7 +100,8 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[],
givens=givens, givens=givens,
no_default_updates=no_default_updates, no_default_updates=no_default_updates,
accept_inplace=accept_inplace,name=name, accept_inplace=accept_inplace,name=name,
rebuild_strict=rebuild_strict) rebuild_strict=rebuild_strict,
allow_input_downcast=allow_input_downcast)
# We need to add the flag check_aliased inputs if we have any mutable or # We need to add the flag check_aliased inputs if we have any mutable or
# borrowed used defined inputs # borrowed used defined inputs
fn._check_for_aliased_inputs = check_for_aliased_inputs fn._check_for_aliased_inputs = check_for_aliased_inputs
......
...@@ -362,6 +362,8 @@ class Function(object): ...@@ -362,6 +362,8 @@ class Function(object):
c = containers[0] #containers is being used as a stack. Here we pop off the next one. c = containers[0] #containers is being used as a stack. Here we pop off the next one.
if input.strict: if input.strict:
c.strict = True c.strict = True
if input.allow_downcast:
c.allow_downcast = True
if value is not None: if value is not None:
# Always initialize the storage. # Always initialize the storage.
...@@ -519,7 +521,8 @@ class Function(object): ...@@ -519,7 +521,8 @@ class Function(object):
s.storage[0] = arg s.storage[0] = arg
else: else:
try: try:
s.storage[0] = s.type.filter(arg, strict=s.strict) s.storage[0] = s.type.filter(arg, strict=s.strict,
allow_downcast=s.allow_downcast)
except Exception, e: except Exception, e:
e.args = tuple(list(e.args)+["Bad input argument at index %d" % arg_index]) e.args = tuple(list(e.args)+["Bad input argument at index %d" % arg_index])
......
...@@ -28,6 +28,12 @@ class SymbolicInput(object): ...@@ -28,6 +28,12 @@ class SymbolicInput(object):
True: means that the value you pass for this input must have exactly the right type True: means that the value you pass for this input must have exactly the right type
False: the value you pass for this input may be casted automatically to the proper type False: the value you pass for this input may be casted automatically to the proper type
allow_downcast: Bool (default: False)
Only applies when `strict` is False.
True: the value you pass for this input can be silently
downcast to fit the right type, which may lose precision.
False: the value will only be cast to a more general, or more precise, type.
autoname: Bool (default: True) autoname: Bool (default: True)
See the name option. See the name option.
...@@ -36,7 +42,8 @@ class SymbolicInput(object): ...@@ -36,7 +42,8 @@ class SymbolicInput(object):
symbolic case. symbolic case.
""" """
def __init__(self, variable, name=None, update=None, mutable=None, strict=False, autoname=True, def __init__(self, variable, name=None, update=None, mutable=None,
strict=False, allow_downcast=False, autoname=True,
implicit=False): implicit=False):
assert implicit is not None # Safety check. assert implicit is not None # Safety check.
self.variable = variable self.variable = variable
...@@ -45,8 +52,6 @@ class SymbolicInput(object): ...@@ -45,8 +52,6 @@ class SymbolicInput(object):
else: else:
self.name = name self.name = name
#backport
#self.name = variable.name if (autoname and name is None) else name
if self.name is not None and not isinstance(self.name, str): if self.name is not None and not isinstance(self.name, str):
raise TypeError("name must be a string! (got: %s)" % self.name) raise TypeError("name must be a string! (got: %s)" % self.name)
self.update = update self.update = update
...@@ -55,9 +60,8 @@ class SymbolicInput(object): ...@@ -55,9 +60,8 @@ class SymbolicInput(object):
else: else:
self.mutable = (update is not None) self.mutable = (update is not None)
#backport
#self.mutable = mutable if (mutable is not None) else (update is not None)
self.strict = strict self.strict = strict
self.allow_downcast = allow_downcast
self.implicit = implicit self.implicit = implicit
def __str__(self): def __str__(self):
...@@ -156,6 +160,12 @@ class In(SymbolicInput): ...@@ -156,6 +160,12 @@ class In(SymbolicInput):
True: means that the value you pass for this input must have exactly the right type True: means that the value you pass for this input must have exactly the right type
False: the value you pass for this input may be cast automatically to the proper type False: the value you pass for this input may be cast automatically to the proper type
allow_downcast: Bool (default: False)
Only applies when `strict` is False.
True: the value you pass for this input can be silently
downcast to fit the right type, which may lose precision.
False: the value will only be cast to a more general, or more precise, type.
autoname: Bool (default: True) autoname: Bool (default: True)
See the name option. See the name option.
...@@ -173,13 +183,20 @@ class In(SymbolicInput): ...@@ -173,13 +183,20 @@ class In(SymbolicInput):
# Note: the documentation above is duplicated in doc/topics/function.txt, # Note: the documentation above is duplicated in doc/topics/function.txt,
# try to keep it synchronized. # try to keep it synchronized.
def __init__(self, variable, name=None, value=None, update=None, def __init__(self, variable, name=None, value=None, update=None,
mutable=None, strict=False, autoname=True, mutable=None, strict=False, allow_downcast=False, autoname=True,
implicit=None): implicit=None):
if implicit is None: if implicit is None:
implicit = (isinstance(value, gof.Container) or implicit = (isinstance(value, gof.Container) or
isinstance(value, SharedVariable)) isinstance(value, SharedVariable))
super(In, self).__init__(variable, name, update, mutable, strict, super(In, self).__init__(
autoname, implicit = implicit) variable=variable,
name=name,
update=update,
mutable=mutable,
strict=strict,
allow_downcast=allow_downcast,
autoname=autoname,
implicit=implicit)
self.value = value self.value = value
if self.implicit and value is None: if self.implicit and value is None:
raise TypeError('An implicit input must be given a default value') raise TypeError('An implicit input must be given a default value')
......
...@@ -7,8 +7,8 @@ from theano.compile.sharedvalue import SharedVariable, shared ...@@ -7,8 +7,8 @@ from theano.compile.sharedvalue import SharedVariable, shared
import numpy # for backport to 2.4, to get any(). import numpy # for backport to 2.4, to get any().
class Param(object): class Param(object):
def __init__(self, variable, default=None, name=None, mutable=False, strict=False, def __init__(self, variable, default=None, name=None, mutable=False,
implicit=None): strict=False, allow_downcast=False, implicit=None):
""" """
:param variable: A variable in an expression graph to use as a compiled-function parameter :param variable: A variable in an expression graph to use as a compiled-function parameter
...@@ -23,6 +23,9 @@ class Param(object): ...@@ -23,6 +23,9 @@ class Param(object):
type required by the parameter `variable`. True -> function arguments must exactly match the type type required by the parameter `variable`. True -> function arguments must exactly match the type
required by `variable`. required by `variable`.
:param allow_downcast: Only applies if `strict` is False.
True -> allows the assigned value to lose precision when cast during assignment.
:param implicit: see help(theano.io.In) :param implicit: see help(theano.io.In)
""" """
...@@ -31,10 +34,12 @@ class Param(object): ...@@ -31,10 +34,12 @@ class Param(object):
self.name = name self.name = name
self.mutable = mutable self.mutable = mutable
self.strict = strict self.strict = strict
self.allow_downcast = allow_downcast
self.implicit = implicit self.implicit = implicit
def pfunc(params, outputs=None, mode=None, updates=[], givens=[], def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None, rebuild_strict = True): no_default_updates=False, accept_inplace=False, name=None,
rebuild_strict=True, allow_input_downcast=False):
"""Function-constructor for graphs with shared variables. """Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances. :type params: list of either Variable or Param instances.
...@@ -67,6 +72,13 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -67,6 +72,13 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
:returns: a callable object that will compute the outputs (given the inputs) :returns: a callable object that will compute the outputs (given the inputs)
and update the implicit function arguments according to the `updates`. and update the implicit function arguments according to the `updates`.
:type allow_input_downcast: Boolean
:param allow_input_downcast: True means that the values passed as
inputs when calling the function can be silently downcast to fit
the dtype of the corresponding Variable, which may lose precision.
False means that they will only be cast to a more general, or more
precise, type.
:note: Regarding givens: Be careful to make sure that these substitutions are :note: Regarding givens: Be careful to make sure that these substitutions are
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in
another expression is undefined. Replacements specified with givens are different from another expression is undefined. Replacements specified with givens are different from
...@@ -165,7 +177,8 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -165,7 +177,8 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
clone_d[v_orig] = clone_v_get_shared_updates(v_repl) clone_d[v_orig] = clone_v_get_shared_updates(v_repl)
# transform params into theano.compile.In objects. # transform params into theano.compile.In objects.
inputs = [_pfunc_param_to_in(p) for p in params] inputs = [_pfunc_param_to_in(p, allow_downcast=allow_input_downcast)
for p in params]
#Switch inputs to cloned variables #Switch inputs to cloned variables
input_variables = [clone_d.setdefault(i.variable, i.variable) for i in inputs] input_variables = [clone_d.setdefault(i.variable, i.variable) for i in inputs]
...@@ -253,14 +266,14 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -253,14 +266,14 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
accept_inplace=accept_inplace, name=name) accept_inplace=accept_inplace, name=name)
def _pfunc_param_to_in(param): def _pfunc_param_to_in(param, strict=False, allow_downcast=False):
if isinstance(param, Constant): if isinstance(param, Constant):
raise TypeError('Constants not allowed in param list', param) raise TypeError('Constants not allowed in param list', param)
#if isinstance(param, Value): #if isinstance(param, Value):
#return In(variable=param) #return In(variable=param)
#raise NotImplementedError() #raise NotImplementedError()
if isinstance(param, Variable): #N.B. includes Value and SharedVariable if isinstance(param, Variable): #N.B. includes Value and SharedVariable
return In(variable=param) return In(variable=param, strict=strict, allow_downcast=allow_downcast)
elif isinstance(param, Param): elif isinstance(param, Param):
return In( return In(
variable=param.variable, variable=param.variable,
...@@ -268,6 +281,7 @@ def _pfunc_param_to_in(param): ...@@ -268,6 +281,7 @@ def _pfunc_param_to_in(param):
value=param.default, value=param.default,
mutable=param.mutable, mutable=param.mutable,
strict=param.strict, strict=param.strict,
allow_downcast=param.allow_downcast,
implicit = param.implicit) implicit = param.implicit)
raise TypeError('Unknown parameter type: %s' % type(param)) raise TypeError('Unknown parameter type: %s' % type(param))
......
...@@ -42,7 +42,7 @@ class SharedVariable(Variable): ...@@ -42,7 +42,7 @@ class SharedVariable(Variable):
# this Variable, unless another update value has been passed to "function", # this Variable, unless another update value has been passed to "function",
# or the "no_default_updates" list passed to "function" contains it. # or the "no_default_updates" list passed to "function" contains it.
def __init__(self, name, type, value, strict, container=None): def __init__(self, name, type, value, strict, allow_downcast=False, container=None):
""" """
:param name: The name for this variable (see `Variable`). :param name: The name for this variable (see `Variable`).
...@@ -53,6 +53,9 @@ class SharedVariable(Variable): ...@@ -53,6 +53,9 @@ class SharedVariable(Variable):
:param strict: True -> assignments to .value will not be casted or copied, so they must :param strict: True -> assignments to .value will not be casted or copied, so they must
have the correct type. have the correct type.
:param allow_downcast: Only applies if `strict` is False.
True -> allows the assigned value to lose precision when cast during assignment.
:param container: The container to use for this variable. Illegal to pass this as well :param container: The container to use for this variable. Illegal to pass this as well
as a value. as a value.
...@@ -69,9 +72,10 @@ class SharedVariable(Variable): ...@@ -69,9 +72,10 @@ class SharedVariable(Variable):
if container is not None: if container is not None:
raise TypeError('Error to specify both value and container') raise TypeError('Error to specify both value and container')
self.container = Container(self, self.container = Container(self,
storage=[type.filter(value, strict=strict)], storage=[type.filter(value, strict=strict, allow_downcast=allow_downcast)],
readonly=False, readonly=False,
strict=strict) strict=strict,
allow_downcast=allow_downcast)
def get_value(self, borrow=False, return_internal_type=False): def get_value(self, borrow=False, return_internal_type=False):
"""Get the non-symbolic value associated with this SharedVariable. """Get the non-symbolic value associated with this SharedVariable.
...@@ -156,7 +160,7 @@ def shared_constructor(ctor): ...@@ -156,7 +160,7 @@ def shared_constructor(ctor):
shared.constructors.append(ctor) shared.constructors.append(ctor)
return ctor return ctor
def shared(value, name=None, strict=False, **kwargs): def shared(value, name=None, strict=False, allow_downcast=False, **kwargs):
"""Return a SharedVariable Variable, initialized with a copy or reference of `value`. """Return a SharedVariable Variable, initialized with a copy or reference of `value`.
This function iterates over constructor functions (see `shared_constructor`) to find a This function iterates over constructor functions (see `shared_constructor`) to find a
...@@ -169,7 +173,8 @@ def shared(value, name=None, strict=False, **kwargs): ...@@ -169,7 +173,8 @@ def shared(value, name=None, strict=False, **kwargs):
""" """
for ctor in reversed(shared.constructors): for ctor in reversed(shared.constructors):
try: try:
return ctor(value, name=name, strict=strict, **kwargs) return ctor(value, name=name, strict=strict,
allow_downcast=allow_downcast, **kwargs)
except TypeError: except TypeError:
continue continue
# This may happen when kwargs were supplied # This may happen when kwargs were supplied
...@@ -181,7 +186,8 @@ def shared(value, name=None, strict=False, **kwargs): ...@@ -181,7 +186,8 @@ def shared(value, name=None, strict=False, **kwargs):
shared.constructors = [] shared.constructors = []
@shared_constructor @shared_constructor
def generic_constructor(value, name=None, strict=False): def generic_constructor(value, name=None, strict=False, allow_downcast=False):
"""SharedVariable Constructor""" """SharedVariable Constructor"""
return SharedVariable(type=generic, value=value, name=name, strict=strict) return SharedVariable(type=generic, value=value, name=name, strict=strict,
allow_downcast=allow_downcast)
...@@ -123,7 +123,8 @@ class Container(object): ...@@ -123,7 +123,8 @@ class Container(object):
"""This class joins a variable with its computed value. """This class joins a variable with its computed value.
It is used in linkers, especially for the inputs and outputs of a Function. It is used in linkers, especially for the inputs and outputs of a Function.
""" """
def __init__(self, r, storage, readonly = False, strict = False, name = None): def __init__(self, r, storage, readonly=False, strict=False,
allow_downcast=False, name=None):
"""WRITEME """WRITEME
:Parameters: :Parameters:
...@@ -131,6 +132,7 @@ class Container(object): ...@@ -131,6 +132,7 @@ class Container(object):
`storage`: a list of length 1, whose element is the value for `r` `storage`: a list of length 1, whose element is the value for `r`
`readonly`: True indicates that this should not be setable by Function[r] = val `readonly`: True indicates that this should not be setable by Function[r] = val
`strict`: if True, we don't allow type casting. `strict`: if True, we don't allow type casting.
`allow_downcast`: if True (and `strict` is False), allow the assigned value to be silently downcast to fit the type, which may lose precision.
`name`: A string (for pretty-printing?) `name`: A string (for pretty-printing?)
""" """
...@@ -144,13 +146,14 @@ class Container(object): ...@@ -144,13 +146,14 @@ class Container(object):
if name is None: if name is None:
self.name = r.name self.name = r.name
#backport
#self.name = r.name if name is None else name
self.storage = storage self.storage = storage
self.readonly = readonly self.readonly = readonly
self.strict = strict self.strict = strict
self.allow_downcast = allow_downcast
def __get__(self): def __get__(self):
return self.storage[0] return self.storage[0]
def __set__(self, value): def __set__(self, value):
if self.readonly: if self.readonly:
raise Exception("Cannot set readonly storage: %s" % self.name) raise Exception("Cannot set readonly storage: %s" % self.name)
...@@ -164,10 +167,14 @@ class Container(object): ...@@ -164,10 +167,14 @@ class Container(object):
#That cause 2 region allocated at the same time! #That cause 2 region allocated at the same time!
#We decrement the memory reference conter now to try to lower the memory usage. #We decrement the memory reference conter now to try to lower the memory usage.
self.storage[0] = None self.storage[0] = None
kwargs = {}
if self.strict: if self.strict:
self.storage[0] = self.type.filter(value, strict = True) kwargs['strict'] = True
else: if self.allow_downcast:
self.storage[0] = self.type.filter(value) kwargs['allow_downcast'] = True
self.storage[0] = self.type.filter(value, **kwargs)
except Exception, e: except Exception, e:
e.args = e.args + (('Container name "%s"' % self.name),) e.args = e.args + (('Container name "%s"' % self.name),)
raise raise
......
...@@ -204,14 +204,17 @@ class PureType(object): ...@@ -204,14 +204,17 @@ class PureType(object):
Variable = graph.Variable #the type that will be created by call to make_variable. Variable = graph.Variable #the type that will be created by call to make_variable.
Constant = graph.Constant #the type that will be created by call to make_constant Constant = graph.Constant #the type that will be created by call to make_constant
def filter(self, data, strict = False): def filter(self, data, strict=False, allow_downcast=False):
"""Required: Return data or an appropriately wrapped/converted data. """Required: Return data or an appropriately wrapped/converted data.
Subclass implementation should raise a TypeError exception if the data is not of an Subclass implementation should raise a TypeError exception if the data is not of an
acceptable type. acceptable type.
If strict is True, the data returned must be the same as the data passed as an If strict is True, the data returned must be the same as the
argument. If it is False, filter may cast it to an appropriate type. data passed as an argument. If it is False, and allow_downcast
is True, filter may cast it to an appropriate type. If
allow_downcast is False, filter may only upcast it, not lose
precision.
:Exceptions: :Exceptions:
- `MethodNotDefined`: subclass doesn't implement this function. - `MethodNotDefined`: subclass doesn't implement this function.
...@@ -222,7 +225,7 @@ class PureType(object): ...@@ -222,7 +225,7 @@ class PureType(object):
def is_valid_value(self, a): def is_valid_value(self, a):
"""Required: Return True for any python object `a` that would be a legal value for a Variable of this Type""" """Required: Return True for any python object `a` that would be a legal value for a Variable of this Type"""
try: try:
self.filter(a, True) self.filter(a, strict=True)
return True return True
except (TypeError, ValueError): except (TypeError, ValueError):
return False return False
...@@ -350,7 +353,7 @@ class Generic(SingletonType): ...@@ -350,7 +353,7 @@ class Generic(SingletonType):
WRITEME WRITEME
""" """
def filter(self, data, strict = False): def filter(self, data, strict=False, allow_downcast=False):
return data return data
def is_valid_value(self, a): def is_valid_value(self, a):
......
...@@ -3,8 +3,10 @@ ...@@ -3,8 +3,10 @@
import sys, os, StringIO import sys, os, StringIO
import numpy import numpy
import theano
from theano import Op, Type, Apply, Variable, Constant from theano import Op, Type, Apply, Variable, Constant
from theano import tensor, config from theano import tensor, config
from theano import scalar as scal
import cuda_ndarray.cuda_ndarray as cuda import cuda_ndarray.cuda_ndarray as cuda
import cuda_ndarray import cuda_ndarray
...@@ -50,8 +52,34 @@ class CudaNdarrayType(Type): ...@@ -50,8 +52,34 @@ class CudaNdarrayType(Type):
self.name = name self.name = name
self.dtype_specs() # error checking is done there self.dtype_specs() # error checking is done there
def filter(self, data, strict=False): def filter(self, data, strict=False, allow_downcast=False):
return cuda.filter(data, self.broadcastable, strict, None) if strict or allow_downcast or isinstance(data, cuda.CudaNdarray):
return cuda.filter(data, self.broadcastable, strict, None)
else: # (not strict) and (not allow_downcast)
# Check if data.dtype can be accurately casted to self.dtype
if isinstance(data, numpy.ndarray):
up_dtype = scal.upcast(self.dtype, data.dtype)
if up_dtype == self.dtype:
return cuda.filter(data, self.broadcastable, strict, None)
else:
raise TypeError(
'%s, with dtype %s, cannot store a value of '
'dtype %s without risking loss of precision.'
'If you do not mind, please cast your data to %s.'
% (self, self.dtype, data.dtype, self.dtype),
data)
else:
converted_data = theano._asarray(data, self.dtype)
if numpy.all(data == converted_data):
return cuda.filter(converted_data, self.broadcastable,
strict, None)
else:
raise TypeError(
'%s, with dtype %s, cannot store accurately value %s, '
'it would be represented as %s. If you do not mind, '
'you can cast your data to %s.'
% (self, self.dtype, data, converted_data, self.dtype),
data)
@staticmethod @staticmethod
def values_eq(a, b): def values_eq(a, b):
......
...@@ -69,15 +69,19 @@ class Scalar(Type): ...@@ -69,15 +69,19 @@ class Scalar(Type):
dtype = config.floatX dtype = config.floatX
self.dtype = dtype self.dtype = dtype
self.dtype_specs() # error checking self.dtype_specs() # error checking
def filter(self, data, strict = False): def filter(self, data, strict=False, allow_downcast=False):
py_type = self.dtype_specs()[0] py_type = self.dtype_specs()[0]
if strict and not isinstance(data, py_type): if strict and not isinstance(data, py_type):
raise TypeError("%s expected a %s, got %s of type %s" % (self, py_type, data, raise TypeError("%s expected a %s, got %s of type %s" % (self, py_type, data,
type(data)), type(data)),
data) data)
try: try:
return py_type(data) converted_data = py_type(data)
if allow_downcast or data == converted_data:
return py_type(data)
else:
raise TypeError('Value cannot accurately be converted to dtype (%s) and allow_downcast is False' % self.dtype)
except Exception, e: except Exception, e:
raise TypeError("Could not convert %s (value=%s) to %s" % (type(data), data, self.dtype), e) raise TypeError("Could not convert %s (value=%s) to %s" % (type(data), data, self.dtype), e)
......
...@@ -168,16 +168,20 @@ class SparseType(gof.Type): ...@@ -168,16 +168,20 @@ class SparseType(gof.Type):
else: else:
raise NotImplementedError('unsupported format "%s" not in list' % format, self.format_cls.keys()) raise NotImplementedError('unsupported format "%s" not in list' % format, self.format_cls.keys())
def filter(self, value, strict = False): def filter(self, value, strict=False, allow_downcast=False):
if isinstance(value, self.format_cls[self.format])\ if isinstance(value, self.format_cls[self.format])\
and value.dtype == self.dtype: and value.dtype == self.dtype:
return value return value
if strict: if strict:
raise TypeError("%s is not sparse" % value) raise TypeError("%s is not sparse, or not the right dtype (is %s, expected %s)"
% (value, value.dtype, self.dtype))
#The input format could be converted here #The input format could be converted here
sp = self.format_cls[self.format](value) if allow_downcast:
if str(sp.dtype) != self.dtype: sp = self.format_cls[self.format](value, dtype=self.dtype)
raise NotImplementedError("Expected %s dtype but got %s"%(self.dtype,str(sp.dtype))) else:
sp = self.format_cls[self.format](value)
if str(sp.dtype) != self.dtype:
raise NotImplementedError("Expected %s dtype but got %s"%(self.dtype,str(sp.dtype)))
if sp.format != self.format: if sp.format != self.format:
raise NotImplementedError() raise NotImplementedError()
return sp return sp
......
...@@ -8,7 +8,8 @@ class SparseTensorSharedVariable(SharedVariable, _sparse_py_operators): ...@@ -8,7 +8,8 @@ class SparseTensorSharedVariable(SharedVariable, _sparse_py_operators):
pass pass
@shared_constructor @shared_constructor
def sparse_constructor(value, name=None, strict=False, borrow=False, format = None): def sparse_constructor(value, name=None, strict=False, allow_downcast=False,
borrow=False, format = None):
"""SharedVariable Constructor for SparseType """SharedVariable Constructor for SparseType
writeme writeme
...@@ -21,6 +22,7 @@ def sparse_constructor(value, name=None, strict=False, borrow=False, format = No ...@@ -21,6 +22,7 @@ def sparse_constructor(value, name=None, strict=False, borrow=False, format = No
type = SparseType(format =format, dtype = value.dtype) type = SparseType(format =format, dtype = value.dtype)
if not borrow: if not borrow:
value = copy.deepcopy(value) value = copy.deepcopy(value)
return SparseTensorSharedVariable(type = type, value = value, name=name, strict =strict) return SparseTensorSharedVariable(type=type, value=value, name=name,
strict=strict, allow_downcast=allow_downcast)
...@@ -400,7 +400,7 @@ class TensorType(Type): ...@@ -400,7 +400,7 @@ class TensorType(Type):
self.name = name self.name = name
self.numpy_dtype = numpy.dtype(self.dtype) self.numpy_dtype = numpy.dtype(self.dtype)
def filter(self, data, strict = False): def filter(self, data, strict=False, allow_downcast=False):
"""Convert `data` to something which can be associated to a `TensorVariable`. """Convert `data` to something which can be associated to a `TensorVariable`.
This function is not meant to be called in user code. It is for This function is not meant to be called in user code. It is for
...@@ -411,7 +411,7 @@ class TensorType(Type): ...@@ -411,7 +411,7 @@ class TensorType(Type):
elif strict: elif strict:
# this is its own subcase that doesn't fall through to anything # this is its own subcase that doesn't fall through to anything
if not isinstance(data, numpy.ndarray): if not isinstance(data, numpy.ndarray):
raise TypeError("%s expected a ndarray object.", data, type(data)) raise TypeError("%s expected a ndarray object." % self, data, type(data))
if not str(data.dtype) == self.dtype: if not str(data.dtype) == self.dtype:
raise TypeError("%s expected a ndarray object with dtype = %s (got %s)." % (self, self.dtype, data.dtype)) raise TypeError("%s expected a ndarray object with dtype = %s (got %s)." % (self, self.dtype, data.dtype))
if not data.ndim == self.ndim: if not data.ndim == self.ndim:
...@@ -419,8 +419,54 @@ class TensorType(Type): ...@@ -419,8 +419,54 @@ class TensorType(Type):
return data return data
else: else:
data = theano._asarray(data, dtype = self.dtype) #TODO - consider to pad shape with ones if allow_downcast:
# to make it consistent with self.broadcastable... like vector->row type thing # Convert to self.dtype, regardless of the type of data
data = theano._asarray(data, dtype=self.dtype) #TODO - consider to pad shape with ones
# to make it consistent with self.broadcastable... like vector->row type thing
else:
if isinstance(data, numpy.ndarray):
# Check if self.dtype can accurately represent data
# (do not try to convert the data)
up_dtype = scal.upcast(self.dtype, data.dtype)
if up_dtype == self.dtype:
# Bug in the following line when data is a scalar array,
# see http://projects.scipy.org/numpy/ticket/1611
#data = data.astype(self.dtype)
data = theano._asarray(data, dtype=self.dtype)
if up_dtype != self.dtype:
err_msg = (
'%s cannot store a value of dtype %s without '
'risking loss of precision. If you do not mind '
'this loss, you can: '
'1) explicitly cast your data to %s, or '
'2) set "allow_input_downcast=True" when calling '
'"function".'
% (self, data.dtype, self.dtype))
raise TypeError(err_msg, data)
else:
# data has to be converted.
# Check that this conversion is lossless
converted_data = theano._asarray(data, self.dtype)
if numpy.all(data == converted_data):
data = converted_data
else:
# Do not print a too long description of data
# (ndarray truncates it, but it's not sure for data)
str_data = str(data)
if len(str_data) > 80:
str_data = str_data[:75] + '(...)'
err_msg = (
'%s cannot store accurately value %s, '
'it would be represented as %s. '
'If you do not mind this precision loss, you can: '
'1) explicitly convert your data to a numpy array '
'of dtype %s, or '
'2) set "allow_input_downcast=True" when calling '
'"function".'
% (self, data, converted_data, self.dtype))
raise TypeError(err_msg, data)
if self.ndim != data.ndim: if self.ndim != data.ndim:
raise TypeError("Wrong number of dimensions: expected %s, got %s with shape %s." % (self.ndim, data.ndim, data.shape), data) raise TypeError("Wrong number of dimensions: expected %s, got %s with shape %s." % (self.ndim, data.ndim, data.shape), data)
i = 0 i = 0
...@@ -434,7 +480,7 @@ class TensorType(Type): ...@@ -434,7 +480,7 @@ class TensorType(Type):
def value_validity_msg(self, a): def value_validity_msg(self, a):
try: try:
self.filter(a, True) self.filter(a, strict=True)
except Exception, e: except Exception, e:
return str(e) return str(e)
return "value is valid" return "value is valid"
...@@ -4300,11 +4346,15 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, rel_tol=No ...@@ -4300,11 +4346,15 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, rel_tol=No
raise TypeError('rng should be a valid instance of numpy.random.RandomState.', raise TypeError('rng should be a valid instance of numpy.random.RandomState.',
'You may want to use theano.tests.unittest_tools.verify_grad instead of theano.tensor.verify_grad.') 'You may want to use theano.tests.unittest_tools.verify_grad instead of theano.tensor.verify_grad.')
# We allow input downcast in function, because numeric_grad works in the
# most precise dtype used among the inputs, so we may need to cast some.
def function(inputs, output): def function(inputs, output):
if mode is None: if mode is None:
f = compile.function(inputs, output, accept_inplace=True) f = compile.function(inputs, output, accept_inplace=True,
allow_input_downcast=True)
else: else:
f = compile.function(inputs, output, accept_inplace=True, mode=mode) f = compile.function(inputs, output, accept_inplace=True,
allow_input_downcast=True, mode=mode)
return f return f
tensor_pt = [TensorType(as_tensor_variable(p).dtype, as_tensor_variable(p).broadcastable)(name='input %i'%i) for i,p in enumerate(pt)] tensor_pt = [TensorType(as_tensor_variable(p).dtype, as_tensor_variable(p).broadcastable)(name='input %i'%i) for i,p in enumerate(pt)]
......
...@@ -22,7 +22,7 @@ class RandomStateType(gof.Type): ...@@ -22,7 +22,7 @@ class RandomStateType(gof.Type):
def __str__(self): def __str__(self):
return 'RandomStateType' return 'RandomStateType'
def filter(self, data, strict=False): def filter(self, data, strict=False, allow_downcast=False):
if self.is_valid_value(data): if self.is_valid_value(data):
return data return data
else: else:
......
...@@ -12,7 +12,7 @@ class RandomStateSharedVariable(SharedVariable): ...@@ -12,7 +12,7 @@ class RandomStateSharedVariable(SharedVariable):
pass pass
@shared_constructor @shared_constructor
def randomstate_constructor(value, name=None, strict=False, borrow=False): def randomstate_constructor(value, name=None, strict=False, allow_downcast=False, borrow=False):
"""SharedVariable Constructor for RandomState""" """SharedVariable Constructor for RandomState"""
if not isinstance(value, numpy.random.RandomState): if not isinstance(value, numpy.random.RandomState):
raise TypeError raise TypeError
...@@ -20,9 +20,10 @@ def randomstate_constructor(value, name=None, strict=False, borrow=False): ...@@ -20,9 +20,10 @@ def randomstate_constructor(value, name=None, strict=False, borrow=False):
value = copy.deepcopy(value) value = copy.deepcopy(value)
return RandomStateSharedVariable( return RandomStateSharedVariable(
type=raw_random.random_state_type, type=raw_random.random_state_type,
value=value, value=value,
name=name, name=name,
strict=strict) strict=strict,
allow_downcast=allow_downcast)
class RandomStreams(raw_random.RandomStreamsBase): class RandomStreams(raw_random.RandomStreamsBase):
"""Module component with similar interface to numpy.random (numpy.random.RandomState)""" """Module component with similar interface to numpy.random (numpy.random.RandomState)"""
......
...@@ -9,7 +9,7 @@ class TensorSharedVariable(SharedVariable, _tensor_py_operators): ...@@ -9,7 +9,7 @@ class TensorSharedVariable(SharedVariable, _tensor_py_operators):
pass pass
@shared_constructor @shared_constructor
def tensor_constructor(value, name=None, strict=False, borrow=False, broadcastable=None): def tensor_constructor(value, name=None, strict=False, allow_downcast=False, borrow=False, broadcastable=None):
"""SharedVariable Constructor for TensorType """SharedVariable Constructor for TensorType
:note: Regarding the inference of the broadcastable pattern... :note: Regarding the inference of the broadcastable pattern...
...@@ -27,7 +27,11 @@ def tensor_constructor(value, name=None, strict=False, borrow=False, broadcastab ...@@ -27,7 +27,11 @@ def tensor_constructor(value, name=None, strict=False, borrow=False, broadcastab
if broadcastable is None: if broadcastable is None:
broadcastable = (False,)*len(value.shape) broadcastable = (False,)*len(value.shape)
type = TensorType(value.dtype, broadcastable=broadcastable) type = TensorType(value.dtype, broadcastable=broadcastable)
return TensorSharedVariable(type=type, value=numpy.array(value,copy=(not borrow)), name=name, strict=strict) return TensorSharedVariable(type=type,
value=numpy.array(value,copy=(not borrow)),
name=name,
strict=strict,
allow_downcast=allow_downcast)
# TensorSharedVariable brings in the tensor operators, is not ideal, but works as long as we # TensorSharedVariable brings in the tensor operators, is not ideal, but works as long as we
# dont do purely scalar-scalar operations # dont do purely scalar-scalar operations
...@@ -35,7 +39,7 @@ class ScalarSharedVariable(SharedVariable, _tensor_py_operators): ...@@ -35,7 +39,7 @@ class ScalarSharedVariable(SharedVariable, _tensor_py_operators):
pass pass
@shared_constructor @shared_constructor
def scalar_constructor(value, name=None, strict=False): def scalar_constructor(value, name=None, strict=False, allow_downcast=False):
"""SharedVariable constructor for scalar values. Default: int64 or float64. """SharedVariable constructor for scalar values. Default: int64 or float64.
:note: We implement this using 0-d tensors for now. :note: We implement this using 0-d tensors for now.
...@@ -57,7 +61,7 @@ def scalar_constructor(value, name=None, strict=False): ...@@ -57,7 +61,7 @@ def scalar_constructor(value, name=None, strict=False):
# strict is True and the types do not match. # strict is True and the types do not match.
rval = ScalarSharedVariable(type=tensor_type, rval = ScalarSharedVariable(type=tensor_type,
value=numpy.array(value, copy=True), value=numpy.array(value, copy=True),
name=name, strict=strict) name=name, strict=strict, allow_downcast=allow_downcast)
return rval return rval
except: except:
traceback.print_exc() traceback.print_exc()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论