提交 56718bc7 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

removed bind, bound, dup and resolve from the Component system

上级 18deb7a0
......@@ -98,35 +98,12 @@ def name_split(sym, n=-1):
"""
return sym.split('.', n)
def canonicalize(name):
"""
Splits the name and converts each name to the
right type (e.g. "2" -> 2)
[Fred: why we return the right type? Why int only?]
"""
if isinstance(name, str):
name = name_split(name)
def convert(x):
try:
return int(x)
except (ValueError, TypeError):
return x
return map(convert, name)
class AllocationError(Exception):
"""
Exception raised when a Result has no associated storage.
"""
pass
class BindError(Exception):
"""
Exception raised when a Component is already bound and we try to
bound it again.
see Component.bind() help for more information.
"""
pass
class Component(object):
"""
Base class for the various kinds of components which are not
......@@ -138,43 +115,6 @@ class Component(object):
self.__dict__['_name'] = ''
self.__dict__['parent'] = None
def bind(self, parent, name, dup_ok=True):
"""
Marks this component as belonging to the parent (the parent is
typically a Composite instance). The component can be accessed
through the parent with the specified name. If dup_ok is True
and that this Component is already bound, a duplicate of the
component will be made using the dup() method and the
duplicate will be bound instead of this Component. If dup_ok
is False and this Component is already bound, a BindError wil
be raised.
bind() returns the Component instance which has been bound to
the parent. For an unbound instance, this will usually be
self.
"""
if self.bound():
if dup_ok:
try:
return self.dup().bind(parent, name, False)
except BindError, e:
e.args = (e.args[0] +
' ; This seems to have been caused by an implementation of dup'
' that keeps the previous binding (%s)' % self.dup,) + e.args[1:]
raise
else:
raise BindError("%s is already bound to %s as %s" % (self, self.parent, self.name))
self.parent = parent
self.name = name_join(parent.name, name)
return self
def bound(self):
"""
Returns True if this Component instance is bound to a
Composite.
"""
return self.parent is not None
def allocate(self, memo):
"""
Populates the memo dictionary with gof.Result -> io.In
......@@ -238,17 +178,6 @@ class Component(object):
"""
raise NotImplementedError
def dup(self):
"""
Returns a Component identical to this one, but which is not
bound to anything and does not retain the original's name.
This is useful to make Components that are slight variations
of another or to have Components that behave identically but
are accessed in different ways.
"""
raise NotImplementedError()
def __get_name__(self):
"""
Getter for self.name
......@@ -297,9 +226,6 @@ class _RComponent(Component):
rval = '%s :: %s' % (self.__class__.__name__, self.r.type)
return rval
def dup(self):
return self.__class__(self.r)
class External(_RComponent):
"""
......@@ -345,8 +271,8 @@ class Member(_RComponent):
rval = gof.Container(r, storage = [getattr(r, 'data', None)],
readonly=isinstance(r, gof.Constant))
memo[r] = io.In(result=r,
value=rval,
mutable=False)
value=rval,
mutable=False)
return memo[r]
def build(self, mode, memo):
......@@ -430,28 +356,6 @@ class Method(Component):
self.updates = dict(updates, **kwupdates)
self.mode = mode
def bind(self, parent, name, dup_ok=True):
"""Implement`Component.bind`"""
rval = super(Method, self).bind(parent, name, dup_ok=dup_ok)
rval.resolve_all()
return rval
def resolve(self, name):
"""Return the Result corresponding to a given name
:param name: the name of a Result in the Module to which this Method is bound
:type name: str
:rtype: `Result`
"""
if not self.bound():
raise ValueError('Trying to resolve a name on an unbound Method.')
result = self.parent.resolve(name)
if not hasattr(result, 'r'):
raise TypeError('Expected a Component with subtype Member or External.')
return result
def resolve_all(self):
"""Convert all inputs, outputs, and updates specified as strings to Results.
......@@ -463,7 +367,8 @@ class Method(Component):
elif isinstance(x, _RComponent):
return x.r
else:
return self.resolve(x).r
raise Exception('damnit looks like we need this')
# return self.resolve(x).r
def resolve_inputs():
if isinstance(self.inputs, (io.In, gof.Result, str)):
......@@ -633,13 +538,6 @@ class Method(Component):
"; " if self.updates else "",
", ".join("%s <= %s" % (old, new) for old, new in self.updates.iteritems()))
def dup(self):
self.resolve_all()
return self.__class__(inputs=list(self.inputs),
outputs=list(self.outputs) if isinstance(self.outputs, list) else self.outputs,
updates=dict(self.updates),
mode=self.mode)
def __call__(self, *args, **kwargs):
raise TypeError("'Method' object is not callable"
" (Hint: compile your module first. See Component.make())")
......@@ -797,23 +695,6 @@ class ComponentList(Composite):
raise TypeError(c, type(c))
self.append(c)
def resolve(self, name):
# resolves # to the #th number in the list
# resolves name string to parent.resolve(name)
# TODO: eliminate canonicalize
name = canonicalize(name)
try:
item = self.get(name[0])
except TypeError:
# if name[0] is not a number, we check in the parent
if not self.bound():
raise TypeError('Cannot resolve a non-integer name on an unbound ComponentList.')
return self.parent.resolve(name)
if len(name) > 1:
# TODO: eliminate
return item.resolve(name[1:])
return item
def components(self):
return self._components
......@@ -821,8 +702,12 @@ class ComponentList(Composite):
return enumerate(self._components)
def build(self, mode, memo):
if self in memo:
return memo[self]
builds = [c.build(mode, memo) for c in self._components]
return ComponentListInstance(self, builds)
rval = ComponentListInstance(self, builds)
memo[self] = rval
return rval
def get(self, item):
return self._components[item]
......@@ -832,7 +717,8 @@ class ComponentList(Composite):
value = Member(value)
elif not isinstance(value, Component):
raise TypeError('ComponentList may only contain Components.', value, type(value))
value = value.bind(self, str(item))
#value = value.bind(self, str(item))
value.name = name_join(self.name, str(item))
self._components[item] = value
def append(self, c):
......@@ -883,9 +769,6 @@ class ComponentList(Composite):
for i, member in enumerate(self._components):
member.name = '%s.%i' % (name, i)
def dup(self):
return self.__class__(*[c.dup() for c in self._components])
def default_initialize(self, init = {}, **kwinit):
for k, initv in dict(init, **kwinit).iteritems():
......@@ -935,14 +818,6 @@ class ComponentDict(Composite):
self.__dict__['_components'] = components
def resolve(self, name):
name = canonicalize(name)
item = self.get(name[0])
if len(name) > 1:
return item.resolve(name[1:])
return item
def components(self):
return self._components.itervalues()
......@@ -950,11 +825,14 @@ class ComponentDict(Composite):
return self._components.iteritems()
def build(self, mode, memo):
if self in memo:
return self[memo]
inst = self.InstanceType(self, {})
for name, c in self._components.iteritems():
x = c.build(mode, memo)
if x is not None:
inst[name] = x
memo[self] = inst
return inst
def get(self, item):
......@@ -963,7 +841,8 @@ class ComponentDict(Composite):
def set(self, item, value):
if not isinstance(value, Component):
raise TypeError('ComponentDict may only contain Components.', value, type(value))
value = value.bind(self, item)
#value = value.bind(self, item)
value.name = name_join(self.name, str(item))
self._components[item] = value
def pretty(self, **kwargs):
......
......@@ -492,14 +492,76 @@ class T_module(unittest.TestCase):
self.assertRaises(NotImplementedError, c.allocate,"")
self.assertRaises(NotImplementedError, c.build,"","")
self.assertRaises(NotImplementedError, c.pretty)
self.assertRaises(NotImplementedError, c.dup)
c=Composite()
self.assertRaises(NotImplementedError, c.resolve,"n")
self.assertRaises(NotImplementedError, c.components)
self.assertRaises(NotImplementedError, c.components_map)
self.assertRaises(NotImplementedError, c.get,"n")
self.assertRaises(NotImplementedError, c.set,"n",1)
def test_multiple_references():
class A(theano.Module):
def __init__(self, sub_module):
super(A, self).__init__()
self.sub_module = sub_module
def _instance_initialize(self, obj):
print 'Initializing A'
class B(theano.Module):
def __init__(self, sub_module):
super(B, self).__init__()
self.sub_module = sub_module
def _instance_initialize(self, obj):
print 'Initializing B'
class C(theano.Module):
def __init__(self):
super(C, self).__init__()
self.value = theano.tensor.scalar()
def _instance_initialize(self, obj):
print 'Initializing C'
obj.value = 0
def _instance_set(self, obj, value):
print 'Setting C'
obj.value = value
class D(theano.Module):
def __init__(self):
super(D, self).__init__()
self.c = C()
self.a = A(self.c)
self.b = B(self.c)
# Workaround for bug exhibited in a previous email.
self.bug = theano.tensor.scalar()
def _instance_initialize(self, obj):
print 'Initializing D'
obj.c.set(1)
d = D()
d_instance = d.make(mode = 'FAST_COMPILE')
assert d_instance.c.value == 1
assert d_instance.a.sub_module.value == 1
assert d_instance.b.sub_module.value == 1
def test_tuple_members():
M = Module()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论