提交 c104defb authored 作者: james@X40's avatar james@X40

more changes to Module, more tests

上级 c31223aa
......@@ -32,14 +32,14 @@ import function_module as F
from mode import default_mode
def join(*args):
def name_join(*args):
"""
Creates a string representation for the given names:
join('a', 'b', 'c') => 'a.b.c'
"""
return ".".join(arg for arg in args if arg)
def split(sym, n=-1):
def name_split(sym, n=-1):
"""
Gets the names from their joined representation
split('a.b.c') => ['a', 'b', 'c']
......@@ -55,7 +55,7 @@ def canonicalize(name):
[Fred: why we return the right type? Why int only?]
"""
if isinstance(name, str):
name = split(name)
name = name_split(name)
def convert(x):
try:
return int(x)
......@@ -63,7 +63,6 @@ def canonicalize(name):
return x
return map(convert, name)
class AllocationError(Exception):
"""
Exception raised when a Result has no associated storage.
......@@ -116,7 +115,7 @@ class Component(object):
else:
raise BindError("%s is already bound to %s as %s" % (self, self.parent, self.name))
self.parent = parent
self.name = join(parent.name, name)
self.name = name_join(parent.name, name)
return self
def bound(self):
......@@ -303,29 +302,79 @@ class Member(_RComponent):
return memo[self.r].value
class Method(Component):
def __init__(self, inputs, outputs, updates = {}, kits = [], **kwupdates):
"""
Method is a declaration of a function. It contains inputs,
outputs and updates. If the Method is part of a Composite
which holds references to Members, the Method may use them
without declaring them in the inputs, outputs or updates list.
[TODO: remove references to kits, for they are not really
needed anymore]
inputs, outputs or updates may be strings. In that case, they
will be resolved in the Composite which is the parent of this
Method.
Method builds a Function (same structure as a call to
theano.function)
"""
inputs = []
"""function inputs (see `compile.function`)
If Module members are named explicitly in this list, then they will not use shared storage.
Storage must be provided either via an `io.In` value argument, or at the point of the
function call.
"""
outputs=None
"""function outputs (see `compile.function`)"""
updates = {}
"""update expressions for module members
If this method should update the shared storage value for a Module member, then the
update expression must be given in this dictionary.
Keys in this dictionary must be members of the module graph--results for which this Method
will use the shared storage.
The value associated with each key should be a Result (or a string that can be resolved to
a Result) representing the computation of a new value for this shared storage after
each function call.
"""
mode=None
"""This will override the Module compilation mode for this Method"""
def __init__(self, inputs, outputs, updates = {}, mode=None, **kwupdates):
"""Initialize attributes
:param inputs: value for `Method.inputs`
:param outputs: value for `Method.outputs`
:param updates: value for `Method.updates`
:param kwupdates: additions to `updates`
:param mode: value for `Method.mode`
:type inputs: list of (str or `Result` or `io.In`)
:type outputs: None or str or `Result` or `io.Out` or list of (str or `Result` or
`io.Out`)
:type updates: dict of `Result` or str -> `Result` or str
:type kwupdates: extra updates
:type mode: None or any mode accepted by `compile.function`
"""
super(Method, self).__init__()
self.inputs = inputs
self.outputs = outputs
self.updates = dict(updates, **kwupdates)
self.kits = list(kits)
self.mode = mode
def bind(self, parent, name, dup_ok=True):
rval = super(Method, self).bind(parent, name, dup_ok=dup_ok)
......@@ -333,8 +382,12 @@ class Method(Component):
return rval
def resolve(self, name):
"""
Resolves the name of an input or output in the parent.
"""Return the Result corresponding to a given name
:param name: the name of a Result in the Module instance containing this Method
:type name: str
:rtype: `Result`
"""
if not self.bound():
raise ValueError('Trying to resolve a name on an unbound Method.')
......@@ -343,7 +396,12 @@ class Method(Component):
raise TypeError('Expected a Component with subtype Member or External.')
return result
def resolve_result(self, x, passthrough=(gof.Result)):
def resolve_all(self):
"""Convert all inputs, outputs, and updates specified as strings to Results.
This works by searching the containing Module for Result attributes by these names.
"""
def resolve_result(x, passthrough=(gof.Result)):
if isinstance(x, passthrough):
return x
elif isinstance(x, _RComponent):
......@@ -351,41 +409,34 @@ class Method(Component):
else:
return self.resolve(x).r
def resolve_inputs(self):
def resolve_inputs():
if isinstance(self.inputs, (io.In, gof.Result, str)):
inputs = [self.inputs]
else:
inputs = list(self.inputs)
self.inputs = [self.resolve_result(input,
self.inputs = [resolve_result(input,
passthrough=(gof.Result, io.In)) for input in inputs]
def resolve_outputs(self):
if isinstance(self.outputs, (io.Out, gof.Result, str)):
def resolve_outputs():
if isinstance(self.outputs, (io.Out, gof.Result, str, None)):
output = self.outputs
self.outputs = self.resolve_result(output,
passthrough=(gof.Result, io.Out))
self.outputs = resolve_result(output,
passthrough=(gof.Result, io.Out, None))
else:
outputs = list(self.outputs)
self.outputs = [self.resolve_result(output,
self.outputs = [resolve_result(output,
passthrough=(gof.Result, io.Out)) for output in outputs]
def resolve_updates(self):
def resolve_updates():
updates = self.updates
self.updates = {}
for k, v in updates.iteritems():
k, v = self.resolve_result(k), self.resolve_result(v)
k, v = resolve_result(k), resolve_result(v)
self.updates[k] = v
def resolve_all(self):
"""
Resolves all inputs, outputs and updates that were given as
strings so that the fields contain the corresponding Result
instances instead.
"""
self.resolve_inputs()
self.resolve_outputs()
self.resolve_updates()
resolve_inputs()
resolve_outputs()
resolve_updates()
def allocate(self, memo):
"""
......@@ -394,13 +445,21 @@ class Method(Component):
return None
def build(self, mode, memo, allocate_all = False):
"""
Produces a function. If allocate_all is True, storage will be
allocated for all needed Results, even if there is no
"""Compile a function for this Method.
:param allocate_all: if True, storage will be
allocated for all needed Results even if there is no
associated storage for them in the memo. If allocate_all is
False, storage will only be allocated for Results that are
reachable from the inputs list.
:returns: a function that implements this method
:rtype: `Function` instance
"""
if self in memo:
return memo[self]
self.resolve_all() # resolve all so we don't have to mess with strings
def get_storage(r, require = False):
# If require is True, we can only get storage from the memo.
......@@ -430,7 +489,7 @@ class Method(Component):
else:
raise TypeError(input, type(input))
# Deal with updates
# Deal with updates to shared storage
for k, v in self.updates.iteritems():
assert isinstance(k, gof.Result)
assert isinstance(v, gof.Result)
......@@ -441,7 +500,7 @@ class Method(Component):
if input.result == k:
input_k = input
print 'METHOD UPDATE', k, v, input_k
#print 'METHOD UPDATE', k, v, input_k
if input_k is None:
# this is an implicit input,
# use shared storage
......@@ -452,10 +511,9 @@ class Method(Component):
mutable=True)
inputs.append(input_k)
else:
# this was an explicit input
# don't use shared storage
input_k.update=v
input_k.mutable=True
raise ValueError(('Result listed in both inputs and updates.'
' Use inputs to use your own storage, use updates to '
'work on module-shared storage'), k)
outputs = self.outputs
_inputs = [x.result for x in inputs]
......@@ -478,7 +536,10 @@ class Method(Component):
assert type(storage) is io.In
inputs.append(storage)
return F.function(inputs, outputs, mode)
effective_mode = mode if self.mode is None else self.mode
rval = F.function(inputs, outputs, effective_mode)
memo[self] = rval
return rval
def pretty(self, **kwargs):
self.resolve_all()
......@@ -507,10 +568,10 @@ class Method(Component):
def dup(self):
self.resolve_all()
return self.__class__(list(self.inputs),
list(self.outputs) if isinstance(self.outputs, list) else self.outputs,
dict(self.updates),
list(self.kits))
return self.__class__(inputs=list(self.inputs),
outputs=list(self.outputs) if isinstance(self.outputs, list) else self.outputs,
updates=dict(self.updates),
mode=self.mode)
def __call__(self, *args, **kwargs):
raise TypeError("'Method' object is not callable"
......
......@@ -267,6 +267,7 @@ class T_module(unittest.TestCase):
def get_element(i):
return [i.x,i.lx[0],i.tx[0],i.dx['x'],i.llx[0][0], i.llx[1][0], i.ltx[0][0], i.ldx[0]['x'], i.tlx[0][0], i.tlx[0][0], i.tdx[0]['x'], i.dlx['x'][0], i.dtx['x'][0], i.ddx['x']['x']]
m1=Module()
m2=Module()
x=T.dscalar()
......@@ -393,7 +394,13 @@ class T_module(unittest.TestCase):
assert isinstance(inst.dy['y'],theano.compile.function_module.Function)
assert isinstance(inst.tty[0][0],theano.compile.function_module.Function)
print >> sys.stderr, "MODULE TEST IMPLEMENTED BUT WE DON'T KNOW WHAT WE WANT AS A RESULT"
assert m1.y is m1.ly[0]
assert inst.y is inst.ly[0]
assert inst.y is inst.lly[0][0]
assert inst.y is inst.ty[0]
assert inst.y is inst.tty[0][0]
assert inst.y is inst.dy['y']
def test_member_method_inputs(self):
"""Test that module Members can be named as Method inputs, in which case the function will
......@@ -416,7 +423,6 @@ class T_module(unittest.TestCase):
assert m.y == 77
assert m.x == 1000
def test_member_input_flags(self):
"""Test that we can manipulate the mutable, strict, etc. flags (see SymbolicInput) of
Method inputs"""
......@@ -448,20 +454,30 @@ class T_module(unittest.TestCase):
m.f([3, 2])
assert numpy.all(v0 != v0_copy)
def test_sanity_check_mode(self):
"""Test that Module.make(self) can take the same list of Modes that function can, so we can
debug modules"""
print >> sys.stderr, "WARNING MODULE TEST NOT IMPLEMENTED"
def test_member_value(self):
"""Test that module Members of Value work correctly. As Result?"""
print >> sys.stderr, "WARNING MODULE TEST NOT IMPLEMENTED"
M = Module()
x = T.dscalar()
M.y = T.value(40)
M.f = Method([x], x + 2 * M.y)
m = M.make()
m.y = 80
assert m.f(20) == 180
def test_member_constant(self):
"""Test that module Members of Constant work correctly.
As Result with more optimization?"""
print >> sys.stderr, "WARNING MODULE TEST NOT IMPLEMENTED"
M = Module()
x = T.dscalar()
M.y = T.constant(40)
M.f = Method([x], x + 2 * M.y)
m = M.make()
try:
m.y = 77 #fail?
assert 0 #assign to constant should not have worked
except:
pass
assert m.f(20) == 100
def test_raise_NotImplemented(self):
c=Component()
......@@ -476,7 +492,7 @@ class T_module(unittest.TestCase):
self.assertRaises(NotImplementedError, c.get,"n")
self.assertRaises(NotImplementedError, c.set,"n",1)
def test_tuple_members(self):
def test_tuple_members():
M = Module()
M.a = (1,1)
......@@ -489,6 +505,60 @@ class T_module(unittest.TestCase):
assert isinstance(M.a, tuple)
def test_method_updates():
# updates work
M = Module()
M.x = T.dvector()
x = T.dvector()
xval= numpy.asarray([0, 0.5])
M.f = Method([x], M.x*4, updates={M.x:M.x * 2}, mode='FAST_COMPILE')
m = M.make(mode='FAST_RUN')
m.x = xval
m.f([9,9])
assert numpy.all(m.x == [0, 1])
assert numpy.all(xval == [0, 0.5])
# In(update) works
M = Module()
M.x = T.dvector()
x = T.dvector()
M.f = Method([x, io.In(M.x, value=xval, update=M.x*2)], M.x*4)
m = M.make()
m.f([9,9])
assert m.x is None
assert numpy.all(xval == [0, 1])
# when a result is listed explicitly and in an update, then there's a problem.
M = Module()
M.x = T.dvector()
x = T.dvector()
M.f = Method([x, io.In(M.x, value=xval, update=M.x*2)], M.x*4,
updates={M.x:M.x * 7})
try:
m = M.make()
assert False
except ValueError, e:
if str(e[0]).startswith('Result listed in both inputs and up'):
pass
else:
raise
def test_method_mode():
"""Test that Methods can override the module build mode"""
M = Module()
M.x = T.dvector()
M.f = Method([M.x], M.x*4, mode='FAST_COMPILE')
M.g = Method([M.x], M.x*4)
M.h = Method([M.x], M.x*4)
m = M.make(mode='FAST_RUN')
assert m.f.maker.mode != m.g.maker.mode
assert m.h.maker.mode == m.g.maker.mode
assert numpy.all(m.f([1,2]) == m.g([1,2]))
def test_pickle():
"""Test that a module can be pickled"""
M = Module()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论