提交 c104defb authored 作者: james@X40's avatar james@X40

more changes to Module, more tests

上级 c31223aa
...@@ -32,14 +32,14 @@ import function_module as F ...@@ -32,14 +32,14 @@ import function_module as F
from mode import default_mode from mode import default_mode
def join(*args): def name_join(*args):
""" """
Creates a string representation for the given names: Creates a string representation for the given names:
join('a', 'b', 'c') => 'a.b.c' join('a', 'b', 'c') => 'a.b.c'
""" """
return ".".join(arg for arg in args if arg) return ".".join(arg for arg in args if arg)
def split(sym, n=-1): def name_split(sym, n=-1):
""" """
Gets the names from their joined representation Gets the names from their joined representation
split('a.b.c') => ['a', 'b', 'c'] split('a.b.c') => ['a', 'b', 'c']
...@@ -55,7 +55,7 @@ def canonicalize(name): ...@@ -55,7 +55,7 @@ def canonicalize(name):
[Fred: why we return the right type? Why int only?] [Fred: why we return the right type? Why int only?]
""" """
if isinstance(name, str): if isinstance(name, str):
name = split(name) name = name_split(name)
def convert(x): def convert(x):
try: try:
return int(x) return int(x)
...@@ -63,7 +63,6 @@ def canonicalize(name): ...@@ -63,7 +63,6 @@ def canonicalize(name):
return x return x
return map(convert, name) return map(convert, name)
class AllocationError(Exception): class AllocationError(Exception):
""" """
Exception raised when a Result has no associated storage. Exception raised when a Result has no associated storage.
...@@ -116,7 +115,7 @@ class Component(object): ...@@ -116,7 +115,7 @@ class Component(object):
else: else:
raise BindError("%s is already bound to %s as %s" % (self, self.parent, self.name)) raise BindError("%s is already bound to %s as %s" % (self, self.parent, self.name))
self.parent = parent self.parent = parent
self.name = join(parent.name, name) self.name = name_join(parent.name, name)
return self return self
def bound(self): def bound(self):
...@@ -303,29 +302,79 @@ class Member(_RComponent): ...@@ -303,29 +302,79 @@ class Member(_RComponent):
return memo[self.r].value return memo[self.r].value
class Method(Component): class Method(Component):
"""
Method is a declaration of a function. It contains inputs,
outputs and updates. If the Method is part of a Composite
which holds references to Members, the Method may use them
without declaring them in the inputs, outputs or updates list.
def __init__(self, inputs, outputs, updates = {}, kits = [], **kwupdates): inputs, outputs or updates may be strings. In that case, they
""" will be resolved in the Composite which is the parent of this
Method is a declaration of a function. It contains inputs, Method.
outputs and updates. If the Method is part of a Composite
which holds references to Members, the Method may use them Method builds a Function (same structure as a call to
without declaring them in the inputs, outputs or updates list. theano.function)
"""
inputs = []
"""function inputs (see `compile.function`)
If Module members are named explicitly in this list, then they will not use shared storage.
Storage must be provided either via an `io.In` value argument, or at the point of the
function call.
"""
outputs=None
"""function outputs (see `compile.function`)"""
updates = {}
"""update expressions for module members
If this method should update the shared storage value for a Module member, then the
update expression must be given in this dictionary.
Keys in this dictionary must be members of the module graph--results for which this Method
will use the shared storage.
The value associated with each key should be a Result (or a string that can be resolved to
a Result) representing the computation of a new value for this shared storage after
each function call.
"""
mode=None
"""This will override the Module compilation mode for this Method"""
def __init__(self, inputs, outputs, updates = {}, mode=None, **kwupdates):
"""Initialize attributes
:param inputs: value for `Method.inputs`
:param outputs: value for `Method.outputs`
:param updates: value for `Method.updates`
:param kwupdates: additions to `updates`
:param mode: value for `Method.mode`
:type inputs: list of (str or `Result` or `io.In`)
:type outputs: None or str or `Result` or `io.Out` or list of (str or `Result` or
`io.Out`)
[TODO: remove references to kits, for they are not really :type updates: dict of `Result` or str -> `Result` or str
needed anymore]
inputs, outputs or updates may be strings. In that case, they :type kwupdates: extra updates
will be resolved in the Composite which is the parent of this
Method. :type mode: None or any mode accepted by `compile.function`
Method builds a Function (same structure as a call to
theano.function)
""" """
super(Method, self).__init__() super(Method, self).__init__()
self.inputs = inputs self.inputs = inputs
self.outputs = outputs self.outputs = outputs
self.updates = dict(updates, **kwupdates) self.updates = dict(updates, **kwupdates)
self.kits = list(kits) self.mode = mode
def bind(self, parent, name, dup_ok=True): def bind(self, parent, name, dup_ok=True):
rval = super(Method, self).bind(parent, name, dup_ok=dup_ok) rval = super(Method, self).bind(parent, name, dup_ok=dup_ok)
...@@ -333,8 +382,12 @@ class Method(Component): ...@@ -333,8 +382,12 @@ class Method(Component):
return rval return rval
def resolve(self, name): def resolve(self, name):
""" """Return the Result corresponding to a given name
Resolves the name of an input or output in the parent.
:param name: the name of a Result in the Module instance containing this Method
:type name: str
:rtype: `Result`
""" """
if not self.bound(): if not self.bound():
raise ValueError('Trying to resolve a name on an unbound Method.') raise ValueError('Trying to resolve a name on an unbound Method.')
...@@ -343,49 +396,47 @@ class Method(Component): ...@@ -343,49 +396,47 @@ class Method(Component):
raise TypeError('Expected a Component with subtype Member or External.') raise TypeError('Expected a Component with subtype Member or External.')
return result return result
def resolve_result(self, x, passthrough=(gof.Result)): def resolve_all(self):
if isinstance(x, passthrough): """Convert all inputs, outputs, and updates specified as strings to Results.
return x
elif isinstance(x, _RComponent):
return x.r
else:
return self.resolve(x).r
def resolve_inputs(self): This works by searching the containing Module for Result attributes by these names.
if isinstance(self.inputs, (io.In, gof.Result, str)): """
inputs = [self.inputs] def resolve_result(x, passthrough=(gof.Result)):
else: if isinstance(x, passthrough):
inputs = list(self.inputs) return x
self.inputs = [self.resolve_result(input, elif isinstance(x, _RComponent):
passthrough=(gof.Result, io.In)) for input in inputs] return x.r
else:
def resolve_outputs(self): return self.resolve(x).r
if isinstance(self.outputs, (io.Out, gof.Result, str)):
output = self.outputs
self.outputs = self.resolve_result(output,
passthrough=(gof.Result, io.Out))
else:
outputs = list(self.outputs)
self.outputs = [self.resolve_result(output,
passthrough=(gof.Result, io.Out)) for output in outputs]
def resolve_updates(self): def resolve_inputs():
updates = self.updates if isinstance(self.inputs, (io.In, gof.Result, str)):
self.updates = {} inputs = [self.inputs]
for k, v in updates.iteritems(): else:
k, v = self.resolve_result(k), self.resolve_result(v) inputs = list(self.inputs)
self.updates[k] = v self.inputs = [resolve_result(input,
passthrough=(gof.Result, io.In)) for input in inputs]
def resolve_outputs():
if isinstance(self.outputs, (io.Out, gof.Result, str, None)):
output = self.outputs
self.outputs = resolve_result(output,
passthrough=(gof.Result, io.Out, None))
else:
outputs = list(self.outputs)
self.outputs = [resolve_result(output,
passthrough=(gof.Result, io.Out)) for output in outputs]
def resolve_all(self): def resolve_updates():
""" updates = self.updates
Resolves all inputs, outputs and updates that were given as self.updates = {}
strings so that the fields contain the corresponding Result for k, v in updates.iteritems():
instances instead. k, v = resolve_result(k), resolve_result(v)
""" self.updates[k] = v
self.resolve_inputs()
self.resolve_outputs()
self.resolve_updates()
resolve_inputs()
resolve_outputs()
resolve_updates()
def allocate(self, memo): def allocate(self, memo):
""" """
...@@ -394,13 +445,21 @@ class Method(Component): ...@@ -394,13 +445,21 @@ class Method(Component):
return None return None
def build(self, mode, memo, allocate_all = False): def build(self, mode, memo, allocate_all = False):
""" """Compile a function for this Method.
Produces a function. If allocate_all is True, storage will be
allocated for all needed Results, even if there is no :param allocate_all: if True, storage will be
allocated for all needed Results even if there is no
associated storage for them in the memo. If allocate_all is associated storage for them in the memo. If allocate_all is
False, storage will only be allocated for Results that are False, storage will only be allocated for Results that are
reachable from the inputs list. reachable from the inputs list.
:returns: a function that implements this method
:rtype: `Function` instance
""" """
if self in memo:
return memo[self]
self.resolve_all() # resolve all so we don't have to mess with strings self.resolve_all() # resolve all so we don't have to mess with strings
def get_storage(r, require = False): def get_storage(r, require = False):
# If require is True, we can only get storage from the memo. # If require is True, we can only get storage from the memo.
...@@ -430,7 +489,7 @@ class Method(Component): ...@@ -430,7 +489,7 @@ class Method(Component):
else: else:
raise TypeError(input, type(input)) raise TypeError(input, type(input))
# Deal with updates # Deal with updates to shared storage
for k, v in self.updates.iteritems(): for k, v in self.updates.iteritems():
assert isinstance(k, gof.Result) assert isinstance(k, gof.Result)
assert isinstance(v, gof.Result) assert isinstance(v, gof.Result)
...@@ -441,7 +500,7 @@ class Method(Component): ...@@ -441,7 +500,7 @@ class Method(Component):
if input.result == k: if input.result == k:
input_k = input input_k = input
print 'METHOD UPDATE', k, v, input_k #print 'METHOD UPDATE', k, v, input_k
if input_k is None: if input_k is None:
# this is an implicit input, # this is an implicit input,
# use shared storage # use shared storage
...@@ -452,10 +511,9 @@ class Method(Component): ...@@ -452,10 +511,9 @@ class Method(Component):
mutable=True) mutable=True)
inputs.append(input_k) inputs.append(input_k)
else: else:
# this was an explicit input raise ValueError(('Result listed in both inputs and updates.'
# don't use shared storage ' Use inputs to use your own storage, use updates to '
input_k.update=v 'work on module-shared storage'), k)
input_k.mutable=True
outputs = self.outputs outputs = self.outputs
_inputs = [x.result for x in inputs] _inputs = [x.result for x in inputs]
...@@ -478,7 +536,10 @@ class Method(Component): ...@@ -478,7 +536,10 @@ class Method(Component):
assert type(storage) is io.In assert type(storage) is io.In
inputs.append(storage) inputs.append(storage)
return F.function(inputs, outputs, mode) effective_mode = mode if self.mode is None else self.mode
rval = F.function(inputs, outputs, effective_mode)
memo[self] = rval
return rval
def pretty(self, **kwargs): def pretty(self, **kwargs):
self.resolve_all() self.resolve_all()
...@@ -507,10 +568,10 @@ class Method(Component): ...@@ -507,10 +568,10 @@ class Method(Component):
def dup(self): def dup(self):
self.resolve_all() self.resolve_all()
return self.__class__(list(self.inputs), return self.__class__(inputs=list(self.inputs),
list(self.outputs) if isinstance(self.outputs, list) else self.outputs, outputs=list(self.outputs) if isinstance(self.outputs, list) else self.outputs,
dict(self.updates), updates=dict(self.updates),
list(self.kits)) mode=self.mode)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
raise TypeError("'Method' object is not callable" raise TypeError("'Method' object is not callable"
......
...@@ -267,6 +267,7 @@ class T_module(unittest.TestCase): ...@@ -267,6 +267,7 @@ class T_module(unittest.TestCase):
def get_element(i): def get_element(i):
return [i.x,i.lx[0],i.tx[0],i.dx['x'],i.llx[0][0], i.llx[1][0], i.ltx[0][0], i.ldx[0]['x'], i.tlx[0][0], i.tlx[0][0], i.tdx[0]['x'], i.dlx['x'][0], i.dtx['x'][0], i.ddx['x']['x']] return [i.x,i.lx[0],i.tx[0],i.dx['x'],i.llx[0][0], i.llx[1][0], i.ltx[0][0], i.ldx[0]['x'], i.tlx[0][0], i.tlx[0][0], i.tdx[0]['x'], i.dlx['x'][0], i.dtx['x'][0], i.ddx['x']['x']]
m1=Module() m1=Module()
m2=Module() m2=Module()
x=T.dscalar() x=T.dscalar()
...@@ -393,7 +394,13 @@ class T_module(unittest.TestCase): ...@@ -393,7 +394,13 @@ class T_module(unittest.TestCase):
assert isinstance(inst.dy['y'],theano.compile.function_module.Function) assert isinstance(inst.dy['y'],theano.compile.function_module.Function)
assert isinstance(inst.tty[0][0],theano.compile.function_module.Function) assert isinstance(inst.tty[0][0],theano.compile.function_module.Function)
print >> sys.stderr, "MODULE TEST IMPLEMENTED BUT WE DON'T KNOW WHAT WE WANT AS A RESULT"
assert m1.y is m1.ly[0]
assert inst.y is inst.ly[0]
assert inst.y is inst.lly[0][0]
assert inst.y is inst.ty[0]
assert inst.y is inst.tty[0][0]
assert inst.y is inst.dy['y']
def test_member_method_inputs(self): def test_member_method_inputs(self):
"""Test that module Members can be named as Method inputs, in which case the function will """Test that module Members can be named as Method inputs, in which case the function will
...@@ -416,7 +423,6 @@ class T_module(unittest.TestCase): ...@@ -416,7 +423,6 @@ class T_module(unittest.TestCase):
assert m.y == 77 assert m.y == 77
assert m.x == 1000 assert m.x == 1000
def test_member_input_flags(self): def test_member_input_flags(self):
"""Test that we can manipulate the mutable, strict, etc. flags (see SymbolicInput) of """Test that we can manipulate the mutable, strict, etc. flags (see SymbolicInput) of
Method inputs""" Method inputs"""
...@@ -448,20 +454,30 @@ class T_module(unittest.TestCase): ...@@ -448,20 +454,30 @@ class T_module(unittest.TestCase):
m.f([3, 2]) m.f([3, 2])
assert numpy.all(v0 != v0_copy) assert numpy.all(v0 != v0_copy)
def test_sanity_check_mode(self):
"""Test that Module.make(self) can take the same list of Modes that function can, so we can
debug modules"""
print >> sys.stderr, "WARNING MODULE TEST NOT IMPLEMENTED"
def test_member_value(self): def test_member_value(self):
"""Test that module Members of Value work correctly. As Result?""" """Test that module Members of Value work correctly. As Result?"""
print >> sys.stderr, "WARNING MODULE TEST NOT IMPLEMENTED" M = Module()
x = T.dscalar()
M.y = T.value(40)
M.f = Method([x], x + 2 * M.y)
m = M.make()
m.y = 80
assert m.f(20) == 180
def test_member_constant(self): def test_member_constant(self):
"""Test that module Members of Constant work correctly. """Test that module Members of Constant work correctly.
As Result with more optimization?""" As Result with more optimization?"""
print >> sys.stderr, "WARNING MODULE TEST NOT IMPLEMENTED" M = Module()
x = T.dscalar()
M.y = T.constant(40)
M.f = Method([x], x + 2 * M.y)
m = M.make()
try:
m.y = 77 #fail?
assert 0 #assign to constant should not have worked
except:
pass
assert m.f(20) == 100
def test_raise_NotImplemented(self): def test_raise_NotImplemented(self):
c=Component() c=Component()
...@@ -476,18 +492,72 @@ class T_module(unittest.TestCase): ...@@ -476,18 +492,72 @@ class T_module(unittest.TestCase):
self.assertRaises(NotImplementedError, c.get,"n") self.assertRaises(NotImplementedError, c.get,"n")
self.assertRaises(NotImplementedError, c.set,"n",1) self.assertRaises(NotImplementedError, c.set,"n",1)
def test_tuple_members(self): def test_tuple_members():
M = Module() M = Module()
M.a = (1,1) M.a = (1,1)
assert isinstance(M.a, tuple) assert isinstance(M.a, tuple)
class Temp(Module):
def __init__(self):
self.a = (1,1)
M = Temp()
assert isinstance(M.a, tuple)
class Temp(Module):
def __init__(self):
self.a = (1,1)
M = Temp()
assert isinstance(M.a, tuple)
def test_method_updates():
# updates work
M = Module()
M.x = T.dvector()
x = T.dvector()
xval= numpy.asarray([0, 0.5])
M.f = Method([x], M.x*4, updates={M.x:M.x * 2}, mode='FAST_COMPILE')
m = M.make(mode='FAST_RUN')
m.x = xval
m.f([9,9])
assert numpy.all(m.x == [0, 1])
assert numpy.all(xval == [0, 0.5])
# In(update) works
M = Module()
M.x = T.dvector()
x = T.dvector()
M.f = Method([x, io.In(M.x, value=xval, update=M.x*2)], M.x*4)
m = M.make()
m.f([9,9])
assert m.x is None
assert numpy.all(xval == [0, 1])
# when a result is listed explicitly and in an update, then there's a problem.
M = Module()
M.x = T.dvector()
x = T.dvector()
M.f = Method([x, io.In(M.x, value=xval, update=M.x*2)], M.x*4,
updates={M.x:M.x * 7})
try:
m = M.make()
assert False
except ValueError, e:
if str(e[0]).startswith('Result listed in both inputs and up'):
pass
else:
raise
def test_method_mode():
"""Test that Methods can override the module build mode"""
M = Module()
M.x = T.dvector()
M.f = Method([M.x], M.x*4, mode='FAST_COMPILE')
M.g = Method([M.x], M.x*4)
M.h = Method([M.x], M.x*4)
m = M.make(mode='FAST_RUN')
assert m.f.maker.mode != m.g.maker.mode
assert m.h.maker.mode == m.g.maker.mode
assert numpy.all(m.f([1,2]) == m.g([1,2]))
def test_pickle(): def test_pickle():
"""Test that a module can be pickled""" """Test that a module can be pickled"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论