提交 605b91e4 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Copied content of TheanoProposal_shared_val into theano.compile.sandbox

上级 e9f194f5
"""Provide a simple user friendly API """
__docformat__ = 'restructuredtext en'
from theano.gof import Container, Variable, generic, graph, Constant, Value
from theano.compile import function, In
from sharedvalue import SharedVariable, shared
class Param(object):
    """Describe how one `pfunc` argument should be handled.

    Plays for `pfunc` a role similar to the one `In` plays for `function`:
    it bundles a graph variable together with call-time options.
    """
    def __init__(self, variable, default=None, name=None, mutable=False,
                 strict=False, implicit=None):
        """
        :param variable: A node in an expression graph to set with each
            function call.
        :param default: The default value to use at call-time (can also be a
            Container where the function will find a value at call-time.)
        :param name: A string to identify this parameter from function kwargs.
        :param mutable: True -> function is allowed to modify this argument.
        :param strict: False -> function arguments may be copied or casted to
            match the type required by the parameter `variable`.
            True -> function arguments must exactly match the type required
            by `variable`.
        :param implicit: see help(theano.io.In)
        """
        self.variable = variable
        self.name = name
        self.default = default
        self.implicit = implicit
        self.mutable = mutable
        self.strict = strict
def pfunc(params, outputs=None, mode=None, updates=None):
    """Function-constructor for graphs with shared variables.

    :type params: list of either Variable or Param instances.
    :param params: function parameters, these are not allowed to be shared
        variables

    :type outputs: list of Variables or Out instances
    :param outputs: expressions to compute

    :param mode: compilation mode

    :type updates: iterable over pairs (shared_variable, new_expression).
        List, tuple or dict.
    :param updates: update the values for SharedVariable inputs according to
        these expressions

    :rtype: theano.compile.Function
    :returns: a callable object that will compute the outputs (given the
        inputs) and update the implicit function arguments according to the
        `updates`.
    """
    # Note: in its early design, pfunc was also meant to accept another
    # parameter, 'givens'. This was a dictionary assigning some specific
    # values to some of the Variable in the graph, so as to allow the
    # function to possibly make some optimizations at compile time.
    # In the end, this feature was not kept, because it was not obvious
    # how to implement it, nor whether it was really needed.
    # If one wants to add this feature in the future, it may be easier instead
    # to add a new parameter to 'Param' to indicate that some input of the
    # function is taking a specific constant value.

    # BUGFIX: use a None sentinel instead of a mutable `[]` default argument,
    # so the default object can never be shared between calls.
    if updates is None:
        updates = []

    # BUGFIX: `outputs=None` used to produce computed_list == [None], which
    # would then be fed to graph.inputs(). Treat None as "no outputs".
    if outputs is None:
        computed_list = []
    elif not isinstance(outputs, list):
        computed_list = [outputs]
    else:
        # Copy list (because it may be extended later).
        computed_list = list(outputs)

    # Transform params into theano.compile.In objects.
    inputs = [_pfunc_param_to_in(p) for p in params]
    set_of_param_variables = set([i.variable for i in inputs])

    # It was decided, as a first step, to prevent shared variables from being
    # used as function inputs. Although it is technically possible, it is also
    # potentially ambiguous and dangerous. This restriction may be revisited in
    # the future if there is a need for such a feature.
    # BUGFIX: the original used any([...]) and then read the leaked
    # comprehension variable `v` in the error message; since `v` was then the
    # *last* element of the set, the message could name the wrong variable.
    for v in set_of_param_variables:
        if isinstance(v, SharedVariable):
            raise TypeError('Cannot use a shared variable (%s) as explicit input '
                    % v)

    # Add update values as quantities that must be computed.
    new_updates = {}
    for (store_into, update_val) in iter_over_pairs(updates):
        if not isinstance(update_val, Variable):
            # The value for the update is not a Variable: we cast it into
            # a shared Variable so that it can be used by 'function'. Note that
            # it means the update value may change if it is mutable and its
            # value is modified after the function is created.
            update_val = shared(update_val)
        computed_list.append(update_val)
        new_updates[store_into] = update_val
    updates = new_updates

    # Obtain all inputs we need to compute what we want.
    graph_inputs = graph.inputs(computed_list,
            blockers=set_of_param_variables)
    shared_inputs = [i for i in graph_inputs if isinstance(i, SharedVariable)]

    # Add shared variables (from shared_inputs) that were not already present
    # in the list of params.
    inputs += [In(variable=si, value=si.container, mutable=False)
            for si in shared_inputs
            if si not in set_of_param_variables]

    # Iterate over the updates, which are either pairs
    # (shared_var, expressionvariable), or a similar dictionary.
    # For each shared_variable, find the In instance that we created for it in
    # the inputs list.  Give that In instance (in_sv) an update expression.
    #
    # NOTE(review): these inputs are always made mutable here; it is not clear
    # whether there are cases where that is undesirable.
    for (sv, new_val) in iter_over_pairs(updates):
        in_sv = None
        for in_sv_i in inputs:
            if in_sv_i.variable is sv:
                # Each shared variable should appear at most once in `inputs`.
                assert in_sv is None
                in_sv = in_sv_i
        if in_sv is None:
            # This variable was not used anywhere and thus is not in the input
            # list yet.
            inputs.append(In(variable=sv, value=sv.container, mutable=True,
                update=new_val))
        else:
            in_sv.update = new_val
            in_sv.mutable = True

    return function(inputs, outputs, mode, accept_inplace=False)
def _pfunc_param_to_in(param):
    """Convert a single `pfunc` parameter spec into a `theano.compile.In`.

    Accepts a bare Variable (including SharedVariable) or a `Param`;
    rejects Constants, and Values are not yet supported.
    """
    if isinstance(param, Constant):
        raise TypeError('Constants not allowed in param list', param)
    if isinstance(param, Value):
        raise NotImplementedError()
    if isinstance(param, Param):
        # Unpack the Param's call-time options into the In instance.
        return In(variable=param.variable,
                  name=param.name,
                  value=param.default,
                  mutable=param.mutable,
                  strict=param.strict,
                  implicit=param.implicit)
    if isinstance(param, Variable):  # includes SharedVariable
        return In(variable=param)
    raise NotImplementedError()
def iter_over_pairs(pairs):
    """
    Return an iterable over pairs present in the `pairs` input.

    :type pairs: dictionary or iterable
    :param pairs: The pairs to iterate upon. These may be stored either as
        (key, value) items in a dictionary, or directly as pairs in any kind
        of iterable structure.

    :rtype: iterable
    :returns: an iterable yielding pairs
    """
    if isinstance(pairs, dict):
        # Use items() rather than the Python-2-only iteritems(): it satisfies
        # the same documented contract (an iterable of (key, value) pairs)
        # and also works under Python 3.
        return pairs.items()
    else:
        return pairs
======================================
Proposal for pfunc Function Interface
======================================
Following discussion on theano-dev (titled TheanoObject), the following
changes are proposed to make function-construction calls more
readable and intuitive, and to make it easier to share values between
functions.
The strategy is to
- introduce a new kind of ``Variable`` (``SharedVariable``) that has a container
associated with it, and can allow multiple functions to share a value.
- introduce a class called ``Param`` to serve a role similar to that of ``In``,
- introduce a friendlier version of function (tentative name ``pfunc``),
The following code gives a very quick idea of what is being proposed:
.. code-block:: python
a = lscalar()
b = shared(1) #NEW: create a shared variable
f1 = pfunc([a], a+b)
f2 = pfunc([Param(a, default=44)], a + b, updates={b: b + 1})
b.value # -> 1
f1(3) # -> 4
f2(3) # -> 4 (but update b.value with += 1)
b.value # -> 2
f1(3) # -> 5
b.value = 0
f1(3) # -> 3
Declaring a Shared Variable
===========================
The proposal is for two new ways of creating a *shared* variable:
.. code-block:: python
class SharedVariable(Variable):
"""
Variable with a value that is (defaults to being) shared between functions that it appears in.
"""
def __init__(self, name, type, value, strict):
"""
:param name: The name for this variable (see `Variable`).
:param type: The type for this variable (see `Variable`).
:param value: A value to associate with this variable (a new container will be created).
:param strict: True -> assignments to .value will not be casted or copied, so they must
have the correct type.
:param container: The container to use for this variable. Illegal to pass this as well
as a value.
For a more user-friendly constructor, see `shared`.
"""
...
value = property(...)
"""Read/write the non-symbolic value associated with this SharedVariable.
If the SharedVariable is shared, changes to this value will be visible to all functions using
this SharedVariable. If this SharedVariable is not shared, a change will not be visible to
functions that were created before the change.
"""
def shared(value, name=None, strict=False, **kwargs):
"""Return a SharedVariable Variable, initialized with a copy or reference of `value`.
This function iterates over constructor functions (see `shared_constructor`) to find a
suitable SharedVariable subclass.
:note:
By passing kwargs, you effectively limit the set of potential constructors to those that
can accept those kwargs.
"""
...
The function `shared` is a factory-method intended for end-users.
Direct construction of a ``SharedVariable`` is probably not going to be a common
pattern, it will be more common to subclass it (i.e. ``TensorSharedVariable``,
``SparseSharedVariable``, etc.) and to register a constructor so that these
subclasses will be instantiated by the `shared` factory method.
A ``SharedVariable`` instance is meant to change over the duration of a program,
either because of the updates of a function call, or because of direct
assignment to its ``.value`` field.
At any time, the ``.value`` field can be used to access the current value
associated with the shared value.
Using SharedVariables as pfunc Parameters
=========================================
A ``SharedVariable`` instance has a ``value`` property that can be used to get and
set the value associated with that shared variable in all the ``pfunc``
functions that use it.
.. code-block:: python
a = tensor.lscalar()
b = shared(7)
# create two functions that use `b` as an implicit input
f1 = pfunc([a], a + b)
f2 = pfunc([a], a * b)
f1(5) # -> 12
b.value = 8 # modify the shared variable's value
f1(5) # -> 13 # the new value is reflected in any compiled functions
f2(4) # -> 32 # f2 uses the latest value in b's container
However, SharedVariables cannot be used as inputs to theano functions.
This is because doing it may yield code that would be either ambiguous, or
prone to easy mistakes (e.g. accidentally overwriting the content of a shared
variable).
Param and pfunc
===============
The examples above give the general flavour of what pfunc and Param are for.
Their signatures are below.
Corner cases and exotic examples can be found in the tests.
.. code-block:: python
def pfunc(params, outputs, mode=None, givens={}, updates=[])
"""Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances.
:param params: function parameters, these are not allowed to be shared
variables
:type outputs: list of Variables or Out instances
:param outputs: expressions to compute
:param mode: compilation mode
:type updates: iterable over pairs (shared_variable, new_expression). List, tuple or dict.
:param updates: update the values for SharedVariable inputs according to these expressions
:rtype: theano.compile.Function
:returns: a callable object that will compute the outputs (given the inputs)
and update the implicit function arguments according to the `updates`.
"""
...
.. code-block:: python
class Param(object):
def __init__(self, variable, default=None, mutable=False, strict=False):
"""
:param variable: A node in an expression graph to set with each function call.
:param default: The default value to use at call-time (can also be a Container where
the function will find a value at call-time.)
:param name: A string to identify this parameter from function kwargs.
:param mutable: True -> function is allowed to modify this argument.
:param strict: False -> function arguments may be copied or casted to match the
type required by the parameter `variable`. True -> function arguments must exactly match the type
required by `variable`.
:param implicit: see help(theano.io.In)
"""
Note that if some update value is not a variable, it will be cast into
a ``SharedVariable`` using the ``shared`` function. This ensures it is
properly taken into account to build the Theano function underlying the
``pfunc``. A consequence of this is that if this update value is mutable
(e.g. a Numpy array), it may be modified after the function is created.
NNet Example
============
Of course there are lots of ways to write the following code, but this is one
simple one.
.. code-block:: python
import numpy, theano
from pfunc import pfunc
from sharedvalue import shared
from theano import tensor
from theano.tensor.nnet import sigmoid
class NNet(object):
def __init__(self,
input = tensor.dvector('input'),
target = tensor.dvector('target'),
n_input=1, n_hidden=1, n_output=1, lr=1e-3, **kw):
super(NNet, self).__init__(**kw)
self.input = input
self.target = target
self.lr = shared(lr, 'learning_rate')
self.w1 = shared(numpy.zeros((n_hidden, n_input)), 'w1')
self.w2 = shared(numpy.zeros((n_output, n_hidden)), 'w2')
self.hidden = sigmoid(tensor.dot(self.w1, self.input))
self.output = tensor.dot(self.w2, self.hidden)
self.cost = tensor.sum((self.output - self.target)**2)
self.sgd_updates = {
self.w1: self.w1 - self.lr * tensor.grad(self.cost, self.w1),
self.w2: self.w2 - self.lr * tensor.grad(self.cost, self.w2)}
self.sgd_step = pfunc(
params = [self.input, self.target],
outputs = [self.output, self.cost],
updates = self.sgd_updates)
self.compute_output = pfunc([self.input], self.output)
self.output_from_hidden = pfunc([self.hidden], self.output)
"""Provide a simple user friendly API """
__docformat__ = 'restructuredtext en'
import copy
from theano.gof import Container, Variable, generic
import theano.tensor.basic
from theano.tensor import TensorType
from theano.scalar import Scalar
from theano.compile import function
import numpy
class SharedVariable(Variable):
    """
    Variable that is (defaults to being) shared between functions that it
    appears in.
    """

    # Container holding this variable's value, used when it is an implicit
    # function parameter (:type: `Container`).
    container = None

    def __init__(self, name, type, value, strict, container=None):
        """
        :param name: The name for this variable (see `Variable`).
        :param type: The type for this variable (see `Variable`).
        :param value: A value to associate with this variable (a new
            container will be created).  Must be None when `container` is
            given.
        :param strict: True -> assignments to .value will not be casted or
            copied, so they must have the correct type.  Must be None when
            `container` is given.
        :param container: The container to use for this variable. Illegal to
            pass this as well as a value.

        For a more user-friendly constructor, see `shared`.
        """
        super(SharedVariable, self).__init__(type=type, name=name,
                owner=None, index=None)
        if container is not None:
            # Reuse an existing container (this is how `clone` shares the
            # value).  value/strict would be silently ignored, so forbid
            # passing them alongside a container.
            if (value is not None) or (strict is not None):
                raise TypeError(
                        'value and strict are ignored if you pass a container here')
            self.container = container
        else:
            # Build a fresh container holding `value` filtered through the
            # type.  (The original code re-checked `container is not None`
            # inside this branch, which was unreachable dead code.)
            self.container = Container(self,
                    storage=[type.filter(value, strict=strict)],
                    readonly=False,
                    strict=strict)

    def __get(self):
        return self.container.value

    def __set(self, new_value):
        self.container.value = new_value

    value = property(__get, __set,
            doc="Read/write the non-symbolic value associated with this "
                "SharedVariable. If the SharedVariable is shared, changes to "
                "this value will be visible to all functions using this "
                "SharedVariable. If this SharedVariable is not shared, a "
                "change will not be visible to functions that were created "
                "before the change.")

    def clone(self):
        """Return a copy of this variable that shares the same container
        (and therefore the same value)."""
        cp = self.__class__(
                name=self.name,
                type=self.type,
                value=None,
                strict=None,
                container=self.container)
        cp.tag = copy.copy(self.tag)
        return cp
def shared_constructor(ctor):
    """Decorator registering `ctor` as a constructor candidate for the
    `shared` factory; returns `ctor` unchanged."""
    registry = shared.constructors
    registry.append(ctor)
    return ctor
def shared(value, name=None, strict=False, **kwargs):
    """Return a SharedVariable Variable, initialized with a copy or
    reference of `value`.

    This function iterates over constructor functions (see
    `shared_constructor`) to find a suitable SharedVariable subclass.

    :note:
        By passing kwargs, you effectively limit the set of potential
        constructors to those that can accept those kwargs: a constructor
        that cannot handle the value or the kwargs raises TypeError and is
        skipped.  This is on purpose -- if kwargs were supplied, the user
        did not want them to be silently ignored.
    """
    # Most recently registered constructors take precedence.
    for ctor in reversed(shared.constructors):
        try:
            return ctor(value, name=name, strict=strict, **kwargs)
        except TypeError:
            # This constructor rejected the value (or the kwargs);
            # fall through to the next candidate.
            continue
    raise TypeError('No suitable SharedVariable constructor could be found',
            (value, kwargs))

# Registry of constructor candidates, filled by the `shared_constructor`
# decorator.
shared.constructors = []
@shared_constructor
def generic_constructor(value, name=None, strict=False):
    """Fallback SharedVariable constructor: wraps any value using the
    `generic` type."""
    return SharedVariable(type=generic, name=name, value=value, strict=strict)
class TensorSharedVariable(SharedVariable, theano.tensor.basic._tensor_py_operators):
    # Mixes the tensor operator overloads into SharedVariable, so instances
    # can be used directly in tensor expressions (e.g. `w + x`).
    pass
@shared_constructor
def tensor_constructor(value, name=None, strict=False):
    """SharedVariable constructor for TensorType values (numpy arrays only)."""
    if not isinstance(value, numpy.ndarray):
        # Let the `shared` factory try another constructor.
        raise TypeError
    # Dimensions of size 1 are marked broadcastable.
    broadcastable = [dim == 1 for dim in value.shape]
    tensor_type = TensorType(value.dtype, broadcastable=broadcastable)
    return TensorSharedVariable(type=tensor_type, value=value, name=name,
            strict=strict)
@shared_constructor
def scalar_constructor(value, name=None, dtype=None, strict=False):
    """SharedVariable constructor for scalar values. Defaults to int64 or float64"""
    if not isinstance(value, (float, int)):
        # Let the `shared` factory try another constructor.
        raise TypeError
    if not dtype:
        # The user can override the dtype; otherwise ints map to int64 and
        # floats to float64.
        dtype = 'int64' if isinstance(value, int) else 'float64'
    scalar_type = Scalar(dtype)
    # NOTE(review): this returns a TensorSharedVariable even though the type
    # is Scalar -- presumably so scalar shared values get the tensor operator
    # overloads; confirm this is intentional.
    return TensorSharedVariable(type=scalar_type, value=numpy.asarray(value),
            name=name, strict=strict)
import numpy, theano, unittest
from pfunc import pfunc
from sharedvalue import shared
from theano import tensor
from theano.tensor.nnet import sigmoid
class NNet(object):
    """A one-hidden-layer neural network trained by stochastic gradient
    descent, built on `pfunc` and shared variables."""

    def __init__(self, input=None, target=None,
                 n_input=1, n_hidden=1, n_output=1, lr=1e-3, **kw):
        """
        :param input: symbolic input vector (a fresh ``tensor.dvector`` is
            created when not given).
        :param target: symbolic target vector (a fresh ``tensor.dvector`` is
            created when not given).
        :param n_input: number of input units.
        :param n_hidden: number of hidden units.
        :param n_output: number of output units.
        :param lr: learning rate, stored in the shared variable ``self.lr``.
        """
        super(NNet, self).__init__(**kw)

        # BUGFIX: the defaults used to be ``input=tensor.dvector('input')``
        # (and similarly for target), evaluated once at class-definition
        # time -- so every NNet instance relying on the defaults shared the
        # very same graph Variable objects.  Create them per instance.
        if input is None:
            input = tensor.dvector('input')
        if target is None:
            target = tensor.dvector('target')
        self.input = input
        self.target = target

        # Parameters live in shared variables so all compiled functions see
        # (and can update) the same values.
        self.lr = shared(lr, 'learning_rate')
        self.w1 = shared(numpy.zeros((n_hidden, n_input)), 'w1')
        self.w2 = shared(numpy.zeros((n_output, n_hidden)), 'w2')

        # Forward pass: sigmoid hidden layer, linear output layer.
        self.hidden = sigmoid(tensor.dot(self.w1, self.input))
        self.output = tensor.dot(self.w2, self.hidden)
        # Sum-of-squares error.
        self.cost = tensor.sum((self.output - self.target) ** 2)

        # One SGD step updates both weight matrices via `pfunc` updates.
        self.sgd_updates = {
            self.w1: self.w1 - self.lr * tensor.grad(self.cost, self.w1),
            self.w2: self.w2 - self.lr * tensor.grad(self.cost, self.w2)}
        self.sgd_step = pfunc(
            params=[self.input, self.target],
            outputs=[self.output, self.cost],
            updates=self.sgd_updates)
        self.compute_output = pfunc([self.input], self.output)
        self.output_from_hidden = pfunc([self.hidden], self.output)
class TestNnet(unittest.TestCase):
    # Smoke/regression test for the NNet example above.

    def test_nnet(self):
        # Train for 3 epochs on fixed-seed random data and compare the last
        # epoch's mean cost against a previously recorded reference value.
        #theano.compile.default_mode = 'FAST_RUN'
        rng = numpy.random.RandomState(1827)
        data = rng.rand(10, 4)
        nnet = NNet(n_input = 3, n_hidden = 10)
        for epoch in range(3):
            mean_cost = 0
            for x in data:
                # First 3 columns are the input, the last one the target.
                input = x[0:3]
                target = x[3:]
                output, cost = nnet.sgd_step(input, target)
                mean_cost += cost
            mean_cost /= float(len(data))
            print 'Mean cost at epoch %s: %s' % (epoch, mean_cost)
        # Reference value recorded from a known-good run.
        self.failUnless(abs(mean_cost - 0.20588975452) < 1e-6)
        # Just call functions to make sure they do not crash.
        # NOTE(review): `input` here is the variable leaked from the inner
        # loop above (the last row's first 3 columns) -- confirm this reuse
        # is intentional rather than a fresh input being meant.
        out = nnet.compute_output(input)
        out = nnet.output_from_hidden(numpy.ones(10))
import numpy
import unittest
import copy
import theano
from theano.tensor import Tensor, dmatrix, dvector, lscalar
from theano import tensor
from sharedvalue import *
from pfunc import *
class Test_pfunc(unittest.TestCase):
def test_doc(self):
"""Ensure the code given in pfunc.txt works as expected"""
# Example #1.
a = lscalar()
b = shared(1)
f1 = pfunc([a], a+b)
f2 = pfunc([Param(a, default=44)], a + b, updates={b: b + 1})
self.failUnless(b.value == 1)
self.failUnless(f1(3) == 4)
self.failUnless(f2(3) == 4)
self.failUnless(b.value == 2)
self.failUnless(f1(3) == 5)
b.value = 0
self.failUnless(f1(3) == 3)
# Example #2.
a = tensor.lscalar()
b = shared(7)
f1 = pfunc([a], a + b)
f2 = pfunc([a], a * b)
self.failUnless(f1(5) == 12)
b.value = 8
self.failUnless(f1(5) == 13)
self.failUnless(f2(4) == 32)
def test_shared(self):
# CHECK: two functions (f1 and f2) can share w
w = shared(numpy.random.rand(2,2), 'w')
wval = copy.copy(w.value)
x = dmatrix()
out1 = w + x
out2 = w * x
f1 = pfunc([x],[out1])
f2 = pfunc([x],[out2])
xval = numpy.random.rand(2,2)
assert numpy.all(f1(xval) == xval + wval)
assert numpy.all(f2(xval) == xval * wval)
# CHECK: updating a shared value
f3 = pfunc([x], out1, updates=[(w, w-1)])
assert numpy.all(f3(xval) == xval + wval) # f3 changes the value of w
assert numpy.all(f1(xval) == xval + (wval-1)) # this same value is read by f1
w.value *= 10
assert numpy.all(f1(xval) == xval + w.value) # this same value is read by f1
def test_no_shared_as_input(self):
"""Test that shared variables cannot be used as function inputs."""
w_init = numpy.random.rand(2,2)
w = shared(w_init.copy(), 'w')
try:
f = pfunc([w], theano.tensor.sum(w * w))
assert False
except TypeError, e:
msg = 'Cannot use a shared variable (w) as explicit input'
if str(e).find(msg) < 0:
raise
def test_default_container(self):
# Ensure it is possible to (implicitly) use a shared variable in a
# function, as a 'state' that can be updated at will.
rng = numpy.random.RandomState(1827)
w_init = rng.rand(5)
w = shared(w_init.copy(), 'w')
reg = theano.tensor.sum(w*w)
f = pfunc([], reg)
assert f() == numpy.sum(w_init * w_init)
# Change the value of w and ensure the output changes accordingly.
w.value += 1.0
assert f() == numpy.sum((w_init+1)**2)
def test_default_scalar_container(self):
# Similar in spirit to test_default_container, but updating a scalar
# variable. This is a sanity check for non mutable types.
x = shared(0.0, 'x')
f = pfunc([], x)
assert f() == 0
x.value += 1
assert f() == 1
def test_param_strict(self):
a = tensor.dvector()
b = shared(7)
out = a + b
f = pfunc([Param(a, strict=False)], [out])
f(numpy.random.rand(8)) # works, rand generates float64 by default
f(numpy.array([1,2,3,4], dtype='int32')) # works, casting is allowed
f = pfunc([Param(a, strict=True)], [out])
try:
f(numpy.array([1,2,3,4], dtype='int32')) # fails, f expects float64
except TypeError:
pass
def test_param_mutable(self):
a = tensor.dvector()
b = shared(7)
out = a + b
a_out = a * 2 # assuming the op which makes this "in place" triggers
# using mutable=True will let fip change the value in aval
fip = pfunc([Param(a, mutable=True)], [a_out], mode='FAST_RUN')
aval = numpy.random.rand(10)
aval2 = aval.copy()
assert numpy.all( fip(aval) == aval2*2 )
assert not numpy.all( aval == aval2 )
# using mutable=False should leave the input untouched
f = pfunc([Param(a, mutable=False)], [a_out], mode='FAST_RUN')
aval = numpy.random.rand(10)
aval2 = aval.copy()
assert numpy.all( f(aval) == aval2*2 )
assert numpy.all( aval == aval2 )
def test_shared_mutable(self):
bval = numpy.arange(5)
b = shared(bval)
assert b.value is bval
b_out = b * 2
# by default, shared are not mutable unless doing an explicit update
f = pfunc([], [b_out], mode='FAST_RUN')
assert (f() == numpy.arange(5) * 2).all()
assert all( b.value == numpy.arange(5))
# using updates, b is now a mutable parameter
f = pfunc([], [b_out], updates=[(b, b_out)], mode='FAST_RUN')
assert (f() == numpy.arange(5)*2 ).all()
assert all( b.value == numpy.arange(5)*2) # because of the update
assert all( bval == numpy.arange(5)*2) # because of mutable=True
# do not depend on updates being in-place though!
bval = numpy.arange(5)
b.value = bval
f = pfunc([], [b_out], updates=[(b, b_out+3)], mode='FAST_RUN')
assert ( f() == numpy.arange(5)*2 ).all()
assert (b.value == ((numpy.arange(5)*2)+3)).all() # because of the update
# bval got modified to something...
assert not all(bval == numpy.arange(5))
# ... but not to b.value !
assert not (bval == b.value).all()
def test_update(self):
"""Test update mechanism in different settings."""
# Simple value assignment.
x = shared(0)
assign = pfunc([], [], updates = {x: 3})
assign()
self.failUnless(x.value == 3)
# Same but using a mutable constant to show how it can be used to
# modify the update value after the function is created.
x.value = 0
y = numpy.ones(())
assign_mutable = pfunc([], [], updates = {x: y})
assign_mutable()
self.failUnless(x.value == 1)
y.fill(4)
assign_mutable()
self.failUnless(x.value == 4)
# Basic increment function.
x.value = 0
inc = pfunc([], [], updates = {x: x + 1})
inc()
self.failUnless(x.value == 1)
# Increment by a constant value.
x.value = -1
y = shared(2)
inc_by_y = pfunc([], [], updates = {x: x + y})
inc_by_y()
self.failUnless(x.value == 1)
if __name__ == '__main__':
    # Debug entry point: runs a single hand-picked test with a fast-compiling
    # mode.  Use a unittest runner to execute the whole suite.
    theano.compile.mode.default_mode = 'FAST_COMPILE'
    Test_pfunc().test_default_scalar_container()
import numpy
import unittest
import copy
import theano
from theano.tensor import Tensor
from sharedvalue import *
class Test_SharedVariable(unittest.TestCase):
    """Tests for the `SharedVariable` class and the `shared` factory."""

    def test_ctors(self):
        # Scalar constructor: int -> int64, float -> float64, dtype override.
        assert shared(7).type == Scalar('int64')
        assert shared(7.0).type == Scalar('float64')
        assert shared(7, dtype='float64').type == Scalar('float64')

        # test tensor constructor
        b = shared(numpy.zeros((5,5), dtype='int32'))
        assert b.type == TensorType('int32', broadcastable=[False,False])
        b = shared(numpy.random.rand(4,5))
        assert b.type == TensorType('float64', broadcastable=[False,False])
        b = shared(numpy.random.rand(5,1,2))
        assert b.type == TensorType('float64', broadcastable=[False,True,False])

        # Anything else falls back to the generic constructor.
        assert shared([]).type == generic

        # Unknown kwargs disqualify every constructor.
        def badfunc():
            shared(7, bad_kw=False)
        self.failUnlessRaises(TypeError, badfunc)

    def test_strict_generic(self):
        # this should work, because
        # generic can hold anything even when strict=True
        u = shared('asdf', strict=False)
        v = shared('asdf', strict=True)
        u.value = 88
        v.value = 88

    def test_create_numpy_strict_false(self):
        # here the value is perfect, and we're not strict about it,
        # so creation should work
        SharedVariable(
                name='u',
                type=Tensor(broadcastable=[False], dtype='float64'),
                value=numpy.asarray([1., 2.]),
                strict=False)

        # here the value is castable, and we're not strict about it,
        # so creation should work
        SharedVariable(
                name='u',
                type=Tensor(broadcastable=[False], dtype='float64'),
                value=[1., 2.],
                strict=False)

        # here the value is castable, and we're not strict about it,
        # so creation should work
        SharedVariable(
                name='u',
                type=Tensor(broadcastable=[False], dtype='float64'),
                value=[1, 2],   # different dtype and not a numpy array
                strict=False)

        # here the value is not castable at all --
        # this is beyond strictness, it must fail
        try:
            SharedVariable(
                    name='u',
                    type=Tensor(broadcastable=[False], dtype='float64'),
                    value=dict(),   # not an array by any stretch
                    strict=False)
            assert 0
        except TypeError:
            pass

    def test_use_numpy_strict_false(self):
        # here the value is perfect, and we're not strict about it,
        # so creation should work
        u = SharedVariable(
                name='u',
                type=Tensor(broadcastable=[False], dtype='float64'),
                value=numpy.asarray([1., 2.]),
                strict=False)

        # check that assignments to value are casted properly
        u.value = [3,4]
        assert type(u.value) is numpy.ndarray
        assert str(u.value.dtype) == 'float64'
        assert numpy.all(u.value == [3,4])

        # check that assignments of nonsense fail
        try:
            u.value = 'adsf'
            assert 0
        except ValueError:
            pass

        # check that an assignment of a perfect value results in no copying
        uval = numpy.asarray([5,6,7,8], dtype='float64')
        u.value = uval
        assert u.value is uval

    def test_strict(self):
        def f(var, val):
            var.value = val

        # BUGFIX: the original called f(b, ...) eagerly inside the
        # failUnlessRaises(...) argument list, so the TypeError was raised
        # before the assertion machinery ever ran (the tests errored instead
        # of passing).  Pass the callable and its arguments separately.
        b = shared(7, strict=True)
        self.failUnlessRaises(TypeError, f, b, 8.23)
        b = shared(7.234, strict=True)
        self.failUnlessRaises(TypeError, f, b, 8)
        c = shared(numpy.zeros((5,5), dtype='float32'))
        # NOTE(review): `c` is created above but the check below still uses
        # `b` (a strict float64 scalar); `c` was presumably intended --
        # confirm (note that `c` is not strict, so it would not raise).
        self.failUnlessRaises(TypeError, f, b, numpy.random.rand(5,5))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论