提交 56718bc7 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

removed bind, bound, dup and resolve from the Component system

上级 18deb7a0
...@@ -98,35 +98,12 @@ def name_split(sym, n=-1): ...@@ -98,35 +98,12 @@ def name_split(sym, n=-1):
""" """
return sym.split('.', n) return sym.split('.', n)
def canonicalize(name):
"""
Splits the name and converts each name to the
right type (e.g. "2" -> 2)
[Fred: why we return the right type? Why int only?]
"""
if isinstance(name, str):
name = name_split(name)
def convert(x):
try:
return int(x)
except (ValueError, TypeError):
return x
return map(convert, name)
class AllocationError(Exception): class AllocationError(Exception):
""" """
Exception raised when a Result has no associated storage. Exception raised when a Result has no associated storage.
""" """
pass pass
class BindError(Exception):
"""
Exception raised when a Component is already bound and we try to
bound it again.
see Component.bind() help for more information.
"""
pass
class Component(object): class Component(object):
""" """
Base class for the various kinds of components which are not Base class for the various kinds of components which are not
...@@ -138,43 +115,6 @@ class Component(object): ...@@ -138,43 +115,6 @@ class Component(object):
self.__dict__['_name'] = '' self.__dict__['_name'] = ''
self.__dict__['parent'] = None self.__dict__['parent'] = None
def bind(self, parent, name, dup_ok=True):
"""
Marks this component as belonging to the parent (the parent is
typically a Composite instance). The component can be accessed
through the parent with the specified name. If dup_ok is True
and that this Component is already bound, a duplicate of the
component will be made using the dup() method and the
duplicate will be bound instead of this Component. If dup_ok
is False and this Component is already bound, a BindError wil
be raised.
bind() returns the Component instance which has been bound to
the parent. For an unbound instance, this will usually be
self.
"""
if self.bound():
if dup_ok:
try:
return self.dup().bind(parent, name, False)
except BindError, e:
e.args = (e.args[0] +
' ; This seems to have been caused by an implementation of dup'
' that keeps the previous binding (%s)' % self.dup,) + e.args[1:]
raise
else:
raise BindError("%s is already bound to %s as %s" % (self, self.parent, self.name))
self.parent = parent
self.name = name_join(parent.name, name)
return self
def bound(self):
"""
Returns True if this Component instance is bound to a
Composite.
"""
return self.parent is not None
def allocate(self, memo): def allocate(self, memo):
""" """
Populates the memo dictionary with gof.Result -> io.In Populates the memo dictionary with gof.Result -> io.In
...@@ -238,17 +178,6 @@ class Component(object): ...@@ -238,17 +178,6 @@ class Component(object):
""" """
raise NotImplementedError raise NotImplementedError
def dup(self):
"""
Returns a Component identical to this one, but which is not
bound to anything and does not retain the original's name.
This is useful to make Components that are slight variations
of another or to have Components that behave identically but
are accessed in different ways.
"""
raise NotImplementedError()
def __get_name__(self): def __get_name__(self):
""" """
Getter for self.name Getter for self.name
...@@ -297,9 +226,6 @@ class _RComponent(Component): ...@@ -297,9 +226,6 @@ class _RComponent(Component):
rval = '%s :: %s' % (self.__class__.__name__, self.r.type) rval = '%s :: %s' % (self.__class__.__name__, self.r.type)
return rval return rval
def dup(self):
return self.__class__(self.r)
class External(_RComponent): class External(_RComponent):
""" """
...@@ -345,8 +271,8 @@ class Member(_RComponent): ...@@ -345,8 +271,8 @@ class Member(_RComponent):
rval = gof.Container(r, storage = [getattr(r, 'data', None)], rval = gof.Container(r, storage = [getattr(r, 'data', None)],
readonly=isinstance(r, gof.Constant)) readonly=isinstance(r, gof.Constant))
memo[r] = io.In(result=r, memo[r] = io.In(result=r,
value=rval, value=rval,
mutable=False) mutable=False)
return memo[r] return memo[r]
def build(self, mode, memo): def build(self, mode, memo):
...@@ -430,28 +356,6 @@ class Method(Component): ...@@ -430,28 +356,6 @@ class Method(Component):
self.updates = dict(updates, **kwupdates) self.updates = dict(updates, **kwupdates)
self.mode = mode self.mode = mode
def bind(self, parent, name, dup_ok=True):
"""Implement`Component.bind`"""
rval = super(Method, self).bind(parent, name, dup_ok=dup_ok)
rval.resolve_all()
return rval
def resolve(self, name):
"""Return the Result corresponding to a given name
:param name: the name of a Result in the Module to which this Method is bound
:type name: str
:rtype: `Result`
"""
if not self.bound():
raise ValueError('Trying to resolve a name on an unbound Method.')
result = self.parent.resolve(name)
if not hasattr(result, 'r'):
raise TypeError('Expected a Component with subtype Member or External.')
return result
def resolve_all(self): def resolve_all(self):
"""Convert all inputs, outputs, and updates specified as strings to Results. """Convert all inputs, outputs, and updates specified as strings to Results.
...@@ -463,7 +367,8 @@ class Method(Component): ...@@ -463,7 +367,8 @@ class Method(Component):
elif isinstance(x, _RComponent): elif isinstance(x, _RComponent):
return x.r return x.r
else: else:
return self.resolve(x).r raise Exception('damnit looks like we need this')
# return self.resolve(x).r
def resolve_inputs(): def resolve_inputs():
if isinstance(self.inputs, (io.In, gof.Result, str)): if isinstance(self.inputs, (io.In, gof.Result, str)):
...@@ -633,13 +538,6 @@ class Method(Component): ...@@ -633,13 +538,6 @@ class Method(Component):
"; " if self.updates else "", "; " if self.updates else "",
", ".join("%s <= %s" % (old, new) for old, new in self.updates.iteritems())) ", ".join("%s <= %s" % (old, new) for old, new in self.updates.iteritems()))
def dup(self):
self.resolve_all()
return self.__class__(inputs=list(self.inputs),
outputs=list(self.outputs) if isinstance(self.outputs, list) else self.outputs,
updates=dict(self.updates),
mode=self.mode)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
raise TypeError("'Method' object is not callable" raise TypeError("'Method' object is not callable"
" (Hint: compile your module first. See Component.make())") " (Hint: compile your module first. See Component.make())")
...@@ -797,23 +695,6 @@ class ComponentList(Composite): ...@@ -797,23 +695,6 @@ class ComponentList(Composite):
raise TypeError(c, type(c)) raise TypeError(c, type(c))
self.append(c) self.append(c)
def resolve(self, name):
# resolves # to the #th number in the list
# resolves name string to parent.resolve(name)
# TODO: eliminate canonicalize
name = canonicalize(name)
try:
item = self.get(name[0])
except TypeError:
# if name[0] is not a number, we check in the parent
if not self.bound():
raise TypeError('Cannot resolve a non-integer name on an unbound ComponentList.')
return self.parent.resolve(name)
if len(name) > 1:
# TODO: eliminate
return item.resolve(name[1:])
return item
def components(self): def components(self):
return self._components return self._components
...@@ -821,8 +702,12 @@ class ComponentList(Composite): ...@@ -821,8 +702,12 @@ class ComponentList(Composite):
return enumerate(self._components) return enumerate(self._components)
def build(self, mode, memo): def build(self, mode, memo):
if self in memo:
return memo[self]
builds = [c.build(mode, memo) for c in self._components] builds = [c.build(mode, memo) for c in self._components]
return ComponentListInstance(self, builds) rval = ComponentListInstance(self, builds)
memo[self] = rval
return rval
def get(self, item): def get(self, item):
return self._components[item] return self._components[item]
...@@ -832,7 +717,8 @@ class ComponentList(Composite): ...@@ -832,7 +717,8 @@ class ComponentList(Composite):
value = Member(value) value = Member(value)
elif not isinstance(value, Component): elif not isinstance(value, Component):
raise TypeError('ComponentList may only contain Components.', value, type(value)) raise TypeError('ComponentList may only contain Components.', value, type(value))
value = value.bind(self, str(item)) #value = value.bind(self, str(item))
value.name = name_join(self.name, str(item))
self._components[item] = value self._components[item] = value
def append(self, c): def append(self, c):
...@@ -883,9 +769,6 @@ class ComponentList(Composite): ...@@ -883,9 +769,6 @@ class ComponentList(Composite):
for i, member in enumerate(self._components): for i, member in enumerate(self._components):
member.name = '%s.%i' % (name, i) member.name = '%s.%i' % (name, i)
def dup(self):
return self.__class__(*[c.dup() for c in self._components])
def default_initialize(self, init = {}, **kwinit): def default_initialize(self, init = {}, **kwinit):
for k, initv in dict(init, **kwinit).iteritems(): for k, initv in dict(init, **kwinit).iteritems():
...@@ -935,14 +818,6 @@ class ComponentDict(Composite): ...@@ -935,14 +818,6 @@ class ComponentDict(Composite):
self.__dict__['_components'] = components self.__dict__['_components'] = components
def resolve(self, name):
name = canonicalize(name)
item = self.get(name[0])
if len(name) > 1:
return item.resolve(name[1:])
return item
def components(self): def components(self):
return self._components.itervalues() return self._components.itervalues()
...@@ -950,11 +825,14 @@ class ComponentDict(Composite): ...@@ -950,11 +825,14 @@ class ComponentDict(Composite):
return self._components.iteritems() return self._components.iteritems()
def build(self, mode, memo): def build(self, mode, memo):
if self in memo:
return self[memo]
inst = self.InstanceType(self, {}) inst = self.InstanceType(self, {})
for name, c in self._components.iteritems(): for name, c in self._components.iteritems():
x = c.build(mode, memo) x = c.build(mode, memo)
if x is not None: if x is not None:
inst[name] = x inst[name] = x
memo[self] = inst
return inst return inst
def get(self, item): def get(self, item):
...@@ -963,7 +841,8 @@ class ComponentDict(Composite): ...@@ -963,7 +841,8 @@ class ComponentDict(Composite):
def set(self, item, value): def set(self, item, value):
if not isinstance(value, Component): if not isinstance(value, Component):
raise TypeError('ComponentDict may only contain Components.', value, type(value)) raise TypeError('ComponentDict may only contain Components.', value, type(value))
value = value.bind(self, item) #value = value.bind(self, item)
value.name = name_join(self.name, str(item))
self._components[item] = value self._components[item] = value
def pretty(self, **kwargs): def pretty(self, **kwargs):
......
...@@ -492,14 +492,76 @@ class T_module(unittest.TestCase): ...@@ -492,14 +492,76 @@ class T_module(unittest.TestCase):
self.assertRaises(NotImplementedError, c.allocate,"") self.assertRaises(NotImplementedError, c.allocate,"")
self.assertRaises(NotImplementedError, c.build,"","") self.assertRaises(NotImplementedError, c.build,"","")
self.assertRaises(NotImplementedError, c.pretty) self.assertRaises(NotImplementedError, c.pretty)
self.assertRaises(NotImplementedError, c.dup)
c=Composite() c=Composite()
self.assertRaises(NotImplementedError, c.resolve,"n")
self.assertRaises(NotImplementedError, c.components) self.assertRaises(NotImplementedError, c.components)
self.assertRaises(NotImplementedError, c.components_map) self.assertRaises(NotImplementedError, c.components_map)
self.assertRaises(NotImplementedError, c.get,"n") self.assertRaises(NotImplementedError, c.get,"n")
self.assertRaises(NotImplementedError, c.set,"n",1) self.assertRaises(NotImplementedError, c.set,"n",1)
def test_multiple_references():
class A(theano.Module):
def __init__(self, sub_module):
super(A, self).__init__()
self.sub_module = sub_module
def _instance_initialize(self, obj):
print 'Initializing A'
class B(theano.Module):
def __init__(self, sub_module):
super(B, self).__init__()
self.sub_module = sub_module
def _instance_initialize(self, obj):
print 'Initializing B'
class C(theano.Module):
def __init__(self):
super(C, self).__init__()
self.value = theano.tensor.scalar()
def _instance_initialize(self, obj):
print 'Initializing C'
obj.value = 0
def _instance_set(self, obj, value):
print 'Setting C'
obj.value = value
class D(theano.Module):
def __init__(self):
super(D, self).__init__()
self.c = C()
self.a = A(self.c)
self.b = B(self.c)
# Workaround for bug exhibited in a previous email.
self.bug = theano.tensor.scalar()
def _instance_initialize(self, obj):
print 'Initializing D'
obj.c.set(1)
d = D()
d_instance = d.make(mode = 'FAST_COMPILE')
assert d_instance.c.value == 1
assert d_instance.a.sub_module.value == 1
assert d_instance.b.sub_module.value == 1
def test_tuple_members(): def test_tuple_members():
M = Module() M = Module()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论