提交 c104defb authored 作者: james@X40's avatar james@X40

more changes to Module, more tests

上级 c31223aa
......@@ -32,14 +32,14 @@ import function_module as F
from mode import default_mode
def join(*args):
def name_join(*args):
"""
Creates a string representation for the given names:
join('a', 'b', 'c') => 'a.b.c'
"""
return ".".join(arg for arg in args if arg)
def split(sym, n=-1):
def name_split(sym, n=-1):
"""
Gets the names from their joined representation
split('a.b.c') => ['a', 'b', 'c']
......@@ -55,7 +55,7 @@ def canonicalize(name):
[Fred: why we return the right type? Why int only?]
"""
if isinstance(name, str):
name = split(name)
name = name_split(name)
def convert(x):
try:
return int(x)
......@@ -63,7 +63,6 @@ def canonicalize(name):
return x
return map(convert, name)
class AllocationError(Exception):
"""
Exception raised when a Result has no associated storage.
......@@ -116,7 +115,7 @@ class Component(object):
else:
raise BindError("%s is already bound to %s as %s" % (self, self.parent, self.name))
self.parent = parent
self.name = join(parent.name, name)
self.name = name_join(parent.name, name)
return self
def bound(self):
......@@ -303,29 +302,79 @@ class Member(_RComponent):
return memo[self.r].value
class Method(Component):
"""
Method is a declaration of a function. It contains inputs,
outputs and updates. If the Method is part of a Composite
which holds references to Members, the Method may use them
without declaring them in the inputs, outputs or updates list.
def __init__(self, inputs, outputs, updates = {}, kits = [], **kwupdates):
"""
Method is a declaration of a function. It contains inputs,
outputs and updates. If the Method is part of a Composite
which holds references to Members, the Method may use them
without declaring them in the inputs, outputs or updates list.
inputs, outputs or updates may be strings. In that case, they
will be resolved in the Composite which is the parent of this
Method.
Method builds a Function (same structure as a call to
theano.function)
"""
inputs = []
"""function inputs (see `compile.function`)
If Module members are named explicitly in this list, then they will not use shared storage.
Storage must be provided either via an `io.In` value argument, or at the point of the
function call.
"""
outputs=None
"""function outputs (see `compile.function`)"""
updates = {}
"""update expressions for module members
If this method should update the shared storage value for a Module member, then the
update expression must be given in this dictionary.
Keys in this dictionary must be members of the module graph--results for which this Method
will use the shared storage.
The value associated with each key should be a Result (or a string that can be resolved to
a Result) representing the computation of a new value for this shared storage after
each function call.
"""
mode=None
"""This will override the Module compilation mode for this Method"""
def __init__(self, inputs, outputs, updates = {}, mode=None, **kwupdates):
"""Initialize attributes
:param inputs: value for `Method.inputs`
:param outputs: value for `Method.outputs`
:param updates: value for `Method.updates`
:param kwupdates: additions to `updates`
:param mode: value for `Method.mode`
:type inputs: list of (str or `Result` or `io.In`)
:type outputs: None or str or `Result` or `io.Out` or list of (str or `Result` or
`io.Out`)
[TODO: remove references to kits, for they are not really
needed anymore]
:type updates: dict of `Result` or str -> `Result` or str
inputs, outputs or updates may be strings. In that case, they
will be resolved in the Composite which is the parent of this
Method.
:type kwupdates: extra updates
:type mode: None or any mode accepted by `compile.function`
Method builds a Function (same structure as a call to
theano.function)
"""
super(Method, self).__init__()
self.inputs = inputs
self.outputs = outputs
self.updates = dict(updates, **kwupdates)
self.kits = list(kits)
self.mode = mode
def bind(self, parent, name, dup_ok=True):
rval = super(Method, self).bind(parent, name, dup_ok=dup_ok)
......@@ -333,8 +382,12 @@ class Method(Component):
return rval
def resolve(self, name):
"""
Resolves the name of an input or output in the parent.
"""Return the Result corresponding to a given name
:param name: the name of a Result in the Module instance containing this Method
:type name: str
:rtype: `Result`
"""
if not self.bound():
raise ValueError('Trying to resolve a name on an unbound Method.')
......@@ -343,49 +396,47 @@ class Method(Component):
raise TypeError('Expected a Component with subtype Member or External.')
return result
def resolve_result(self, x, passthrough=(gof.Result)):
if isinstance(x, passthrough):
return x
elif isinstance(x, _RComponent):
return x.r
else:
return self.resolve(x).r
def resolve_all(self):
"""Convert all inputs, outputs, and updates specified as strings to Results.
def resolve_inputs(self):
if isinstance(self.inputs, (io.In, gof.Result, str)):
inputs = [self.inputs]
else:
inputs = list(self.inputs)
self.inputs = [self.resolve_result(input,
passthrough=(gof.Result, io.In)) for input in inputs]
def resolve_outputs(self):
if isinstance(self.outputs, (io.Out, gof.Result, str)):
output = self.outputs
self.outputs = self.resolve_result(output,
passthrough=(gof.Result, io.Out))
else:
outputs = list(self.outputs)
self.outputs = [self.resolve_result(output,
passthrough=(gof.Result, io.Out)) for output in outputs]
This works by searching the containing Module for Result attributes by these names.
"""
def resolve_result(x, passthrough=(gof.Result)):
if isinstance(x, passthrough):
return x
elif isinstance(x, _RComponent):
return x.r
else:
return self.resolve(x).r
def resolve_updates(self):
updates = self.updates
self.updates = {}
for k, v in updates.iteritems():
k, v = self.resolve_result(k), self.resolve_result(v)
self.updates[k] = v
def resolve_inputs():
if isinstance(self.inputs, (io.In, gof.Result, str)):
inputs = [self.inputs]
else:
inputs = list(self.inputs)
self.inputs = [resolve_result(input,
passthrough=(gof.Result, io.In)) for input in inputs]
def resolve_outputs():
if isinstance(self.outputs, (io.Out, gof.Result, str, None)):
output = self.outputs
self.outputs = resolve_result(output,
passthrough=(gof.Result, io.Out, None))
else:
outputs = list(self.outputs)
self.outputs = [resolve_result(output,
passthrough=(gof.Result, io.Out)) for output in outputs]
def resolve_all(self):
"""
Resolves all inputs, outputs and updates that were given as
strings so that the fields contain the corresponding Result
instances instead.
"""
self.resolve_inputs()
self.resolve_outputs()
self.resolve_updates()
def resolve_updates():
updates = self.updates
self.updates = {}
for k, v in updates.iteritems():
k, v = resolve_result(k), resolve_result(v)
self.updates[k] = v
resolve_inputs()
resolve_outputs()
resolve_updates()
def allocate(self, memo):
"""
......@@ -394,13 +445,21 @@ class Method(Component):
return None
def build(self, mode, memo, allocate_all = False):
"""
Produces a function. If allocate_all is True, storage will be
allocated for all needed Results, even if there is no
"""Compile a function for this Method.
:param allocate_all: if True, storage will be
allocated for all needed Results even if there is no
associated storage for them in the memo. If allocate_all is
False, storage will only be allocated for Results that are
reachable from the inputs list.
:returns: a function that implements this method
:rtype: `Function` instance
"""
if self in memo:
return memo[self]
self.resolve_all() # resolve all so we don't have to mess with strings
def get_storage(r, require = False):
# If require is True, we can only get storage from the memo.
......@@ -430,7 +489,7 @@ class Method(Component):
else:
raise TypeError(input, type(input))
# Deal with updates
# Deal with updates to shared storage
for k, v in self.updates.iteritems():
assert isinstance(k, gof.Result)
assert isinstance(v, gof.Result)
......@@ -441,7 +500,7 @@ class Method(Component):
if input.result == k:
input_k = input
print 'METHOD UPDATE', k, v, input_k
#print 'METHOD UPDATE', k, v, input_k
if input_k is None:
# this is an implicit input,
# use shared storage
......@@ -452,10 +511,9 @@ class Method(Component):
mutable=True)
inputs.append(input_k)
else:
# this was an explicit input
# don't use shared storage
input_k.update=v
input_k.mutable=True
raise ValueError(('Result listed in both inputs and updates.'
' Use inputs to use your own storage, use updates to '
'work on module-shared storage'), k)
outputs = self.outputs
_inputs = [x.result for x in inputs]
......@@ -478,7 +536,10 @@ class Method(Component):
assert type(storage) is io.In
inputs.append(storage)
return F.function(inputs, outputs, mode)
effective_mode = mode if self.mode is None else self.mode
rval = F.function(inputs, outputs, effective_mode)
memo[self] = rval
return rval
def pretty(self, **kwargs):
self.resolve_all()
......@@ -507,10 +568,10 @@ class Method(Component):
def dup(self):
self.resolve_all()
return self.__class__(list(self.inputs),
list(self.outputs) if isinstance(self.outputs, list) else self.outputs,
dict(self.updates),
list(self.kits))
return self.__class__(inputs=list(self.inputs),
outputs=list(self.outputs) if isinstance(self.outputs, list) else self.outputs,
updates=dict(self.updates),
mode=self.mode)
def __call__(self, *args, **kwargs):
raise TypeError("'Method' object is not callable"
......
......@@ -267,6 +267,7 @@ class T_module(unittest.TestCase):
def get_element(i):
return [i.x,i.lx[0],i.tx[0],i.dx['x'],i.llx[0][0], i.llx[1][0], i.ltx[0][0], i.ldx[0]['x'], i.tlx[0][0], i.tlx[0][0], i.tdx[0]['x'], i.dlx['x'][0], i.dtx['x'][0], i.ddx['x']['x']]
m1=Module()
m2=Module()
x=T.dscalar()
......@@ -393,7 +394,13 @@ class T_module(unittest.TestCase):
assert isinstance(inst.dy['y'],theano.compile.function_module.Function)
assert isinstance(inst.tty[0][0],theano.compile.function_module.Function)
print >> sys.stderr, "MODULE TEST IMPLEMENTED BUT WE DON'T KNOW WHAT WE WANT AS A RESULT"
assert m1.y is m1.ly[0]
assert inst.y is inst.ly[0]
assert inst.y is inst.lly[0][0]
assert inst.y is inst.ty[0]
assert inst.y is inst.tty[0][0]
assert inst.y is inst.dy['y']
def test_member_method_inputs(self):
"""Test that module Members can be named as Method inputs, in which case the function will
......@@ -416,7 +423,6 @@ class T_module(unittest.TestCase):
assert m.y == 77
assert m.x == 1000
def test_member_input_flags(self):
"""Test that we can manipulate the mutable, strict, etc. flags (see SymbolicInput) of
Method inputs"""
......@@ -448,20 +454,30 @@ class T_module(unittest.TestCase):
m.f([3, 2])
assert numpy.all(v0 != v0_copy)
def test_sanity_check_mode(self):
"""Test that Module.make(self) can take the same list of Modes that function can, so we can
debug modules"""
print >> sys.stderr, "WARNING MODULE TEST NOT IMPLEMENTED"
def test_member_value(self):
"""Test that module Members of Value work correctly. As Result?"""
print >> sys.stderr, "WARNING MODULE TEST NOT IMPLEMENTED"
M = Module()
x = T.dscalar()
M.y = T.value(40)
M.f = Method([x], x + 2 * M.y)
m = M.make()
m.y = 80
assert m.f(20) == 180
def test_member_constant(self):
"""Test that module Members of Constant work correctly.
As Result with more optimization?"""
print >> sys.stderr, "WARNING MODULE TEST NOT IMPLEMENTED"
M = Module()
x = T.dscalar()
M.y = T.constant(40)
M.f = Method([x], x + 2 * M.y)
m = M.make()
try:
m.y = 77 #fail?
assert 0 #assign to constant should not have worked
except:
pass
assert m.f(20) == 100
def test_raise_NotImplemented(self):
c=Component()
......@@ -476,18 +492,72 @@ class T_module(unittest.TestCase):
self.assertRaises(NotImplementedError, c.get,"n")
self.assertRaises(NotImplementedError, c.set,"n",1)
def test_tuple_members(self):
def test_tuple_members():
M = Module()
M.a = (1,1)
assert isinstance(M.a, tuple)
M = Module()
M.a = (1,1)
assert isinstance(M.a, tuple)
class Temp(Module):
def __init__(self):
self.a = (1,1)
M = Temp()
assert isinstance(M.a, tuple)
class Temp(Module):
def __init__(self):
self.a = (1,1)
M = Temp()
assert isinstance(M.a, tuple)
def test_method_updates():
# updates work
M = Module()
M.x = T.dvector()
x = T.dvector()
xval= numpy.asarray([0, 0.5])
M.f = Method([x], M.x*4, updates={M.x:M.x * 2}, mode='FAST_COMPILE')
m = M.make(mode='FAST_RUN')
m.x = xval
m.f([9,9])
assert numpy.all(m.x == [0, 1])
assert numpy.all(xval == [0, 0.5])
# In(update) works
M = Module()
M.x = T.dvector()
x = T.dvector()
M.f = Method([x, io.In(M.x, value=xval, update=M.x*2)], M.x*4)
m = M.make()
m.f([9,9])
assert m.x is None
assert numpy.all(xval == [0, 1])
# when a result is listed explicitly and in an update, then there's a problem.
M = Module()
M.x = T.dvector()
x = T.dvector()
M.f = Method([x, io.In(M.x, value=xval, update=M.x*2)], M.x*4,
updates={M.x:M.x * 7})
try:
m = M.make()
assert False
except ValueError, e:
if str(e[0]).startswith('Result listed in both inputs and up'):
pass
else:
raise
def test_method_mode():
"""Test that Methods can override the module build mode"""
M = Module()
M.x = T.dvector()
M.f = Method([M.x], M.x*4, mode='FAST_COMPILE')
M.g = Method([M.x], M.x*4)
M.h = Method([M.x], M.x*4)
m = M.make(mode='FAST_RUN')
assert m.f.maker.mode != m.g.maker.mode
assert m.h.maker.mode == m.g.maker.mode
assert numpy.all(m.f([1,2]) == m.g([1,2]))
def test_pickle():
"""Test that a module can be pickled"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论