提交 cf22caab authored 作者: James Bergstra's avatar James Bergstra

merge

...@@ -65,7 +65,12 @@ Software Requirements ...@@ -65,7 +65,12 @@ Software Requirements
- python 2.5 - python 2.5
- SciPy (specifically numpy, sparse, weave). Numpy version >= 1.1 fixes memory leak. Numpy version >=1.2 fixes more memory leak. - SciPy (specifically numpy, sparse, weave). We recommend scipy >=0.7 if you
are using sparse matrices, because scipy.sparse is buggy in 0.6.
(scipy.csc_matrix dot has a bug with singleton dimensions. There may be more
bugs.)
Numpy version >= 1.1 fixes
memory leak. Numpy version >=1.2 fixes more memory leak.
- docutils, pygments (optional, to build documentation) - docutils, pygments (optional, to build documentation)
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
# The list of objects to document. Objects can be named using # The list of objects to document. Objects can be named using
# dotted names, module filenames, or package directory names. # dotted names, module filenames, or package directory names.
# Alases for this option include "objects" and "values". # Alases for this option include "objects" and "values".
modules: theano modules: theano scipy.sparse
# The type of output that should be generated. Should be one # The type of output that should be generated. Should be one
# of: html, text, latex, dvi, ps, pdf. # of: html, text, latex, dvi, ps, pdf.
......
from .. import gof from theano import gof
from ..printing import pprint from theano.printing import pprint
from collections import defaultdict from collections import defaultdict
from itertools import chain from itertools import chain
from functools import partial from functools import partial
...@@ -791,31 +791,29 @@ def wrap(x): ...@@ -791,31 +791,29 @@ def wrap(x):
return wrapper(x) return wrapper(x)
return x return x
def dict_wrap(d):
for k,v in d.iteritems():
d[k]=wrap(v)
return d
# Result -> Member
register_wrapper(lambda x: isinstance(x, gof.Result) and not x.owner,
lambda x: Member(x))
# Result -> External # Result -> External
register_wrapper(lambda x: isinstance(x, gof.Result), register_wrapper(lambda x: isinstance(x, gof.Result) and x.owner,
lambda x: External(x)) lambda x: External(x))
# [Component1, Component2, ...] -> ComponentList(Component1, Component2, ...) # [[Result1], {Result2}, Result3...] -> ComponentList(Member(Result1), Member(Result2), ...)
register_wrapper(lambda x: isinstance(x, (list, tuple)) and all(isinstance(r, Component) for r in x),
lambda x: ComponentList(*x))
# [Result1, Result2, ...] -> ComponentList(Member(Result1), Member(Result2), ...)
register_wrapper(lambda x: isinstance(x, (list, tuple)) \ register_wrapper(lambda x: isinstance(x, (list, tuple)) \
and all(isinstance(r, gof.Result) and not r.owner for r in x), and all(isinstance(r, (gof.Result,Component,list,
lambda x: ComponentList(*map(Member, x))) tuple, dict)) for r in x),
#{ "name1":Result1,...} -> ComponentDict(Member(Result1),...) lambda x: ComponentList(*map(wrap, x)))
def dict_member(d):
nd={} #{ "name1":{Component,Result,list,tuple,dict},...} -> ComponentDict({Component,Result,list,tuple,dict},...)
for k,v in d.iteritems():
nd[k]=Member(v)
return nd
register_wrapper(lambda x: isinstance(x, dict) \
and all(isinstance(r,gof.Result) \
and not r.owner for r in x.itervalues()),
lambda x: ComponentDict(dict_member(x)))
register_wrapper(lambda x: isinstance(x, dict) \ register_wrapper(lambda x: isinstance(x, dict) \
and all(isinstance(r,Component) for r in x.itervalues()), and all(isinstance(r,(Component,gof.Result,list,tuple,dict)) for r in x.itervalues()),
lambda x: ComponentDict(x)) lambda x: ComponentDict(dict_wrap(x)))
class Curry: class Curry:
def __init__(self, obj, name, arg): def __init__(self, obj, name, arg):
...@@ -869,13 +867,9 @@ class Module(ComponentDict): ...@@ -869,13 +867,9 @@ class Module(ComponentDict):
if attr == '_components' and '_components' not in self.__dict__: if attr == '_components' and '_components' not in self.__dict__:
self.__dict__['_components'] = {} self.__dict__['_components'] = {}
try: try:
rval = self[attr] rval = self.__dict__["local_attr"][attr]
except KeyError: except KeyError:
raise AttributeError('%s has no %s attribute.' % (self.__class__, attr)) raise AttributeError('%s has no %s attribute.' % (self.__class__, attr))
if isinstance(rval, (External, Member)):
# Special treatment for External and Member, so that
# the user may use them to build graphs more easily.
return rval.r
return rval return rval
def __setattr__(self, attr, value): def __setattr__(self, attr, value):
...@@ -886,16 +880,34 @@ class Module(ComponentDict): ...@@ -886,16 +880,34 @@ class Module(ComponentDict):
self.__set_name__(value) self.__set_name__(value)
return return
value = self.__wrapper__(value) def remove_member(v):
try: if isinstance(v, (Member, External)):
self[attr] = value return v.r
except: elif isinstance(v, (gof.Result,Method,Module)):
if isinstance(value, Component): return v
raise elif isinstance(v,(int,bool)):
return v
elif isinstance(v, (list,tuple)):
return map(remove_member,v)
elif isinstance(v,dict):
for k,vv in v.iteritems():
v[k]=remove_member(vv)
return v
else: else:
self.__dict__[attr] = value # raise NotImplementedError
# print "WARNING: unknow:",v
return v
value=remove_member(value)
if not hasattr(self,"local_attr"):
self.__dict__["local_attr"]={}
self.__dict__["local_attr"][attr]=value
def build(self, mode, memo): def build(self, mode, memo):
for k,v in self.local_attr.iteritems():
self.__setattr__(k,v)
inst = super(Module, self).build(mode, memo) inst = super(Module, self).build(mode, memo)
for method in dir(self): for method in dir(self):
# Any method with a name like '_instance_XXX' is added to # Any method with a name like '_instance_XXX' is added to
...@@ -911,6 +923,64 @@ class Module(ComponentDict): ...@@ -911,6 +923,64 @@ class Module(ComponentDict):
for name, value in chain(init.iteritems(), kwinit.iteritems()): for name, value in chain(init.iteritems(), kwinit.iteritems()):
inst[name] = value inst[name] = value
def make_mi(self, *args, **kwargs):
meth=[]#we put the method after the member to be sure of the ordering.
for k,v in self.local_attr.iteritems():
if isinstance(v,Module):
v=v.make_mi(args,kwargs)
if isinstance(v,Method):
meth.append((k,v))
else:
v = self.__wrapper__(v)
try:
self[k] = v
except:
if isinstance(v, Component):
raise
else:
self.__dict__[k] = v
# self.__setitem__(k,v)
for k,v in meth:
self.__setitem__(k,v)
return self
def make(self, *args, **kwargs):
"""
Allocates the necessary containers using allocate() and uses
build() to make an instance which will be returned. The
initialize() method of the instance will be called with the
arguments and the keyword arguments. If 'mode' is in the
keyword arguments it will be passed to build().
"""
self.make_mi(args,kwargs)
mode = kwargs.pop('mode', 'FAST_COMPILE')
rval = self.make_no_init(mode)
if hasattr(rval, 'initialize'):
rval.initialize(*args, **kwargs)
return rval
def __str__(self):
return self.__class__.__name__+"(%s)" % ', '.join(x for x in sorted(map(str, self.local_attr)) if x[0] != '_')
def __get_name__(self):
"""
Getter for self.name
"""
return self._name
def __set_name__(self, name):
"""
Setter for self.name
"""
self._name = name
name = property(lambda self: self.__get_name__(),
lambda self, value: self.__set_name__(value),
"Contains the name of this Component")
FancyModule = Module FancyModule = Module
FancyModuleInstance = ModuleInstance FancyModuleInstance = ModuleInstance
......
...@@ -10,7 +10,7 @@ class T_test_module(unittest.TestCase): ...@@ -10,7 +10,7 @@ class T_test_module(unittest.TestCase):
def test_whats_up_with_submembers(self): def test_whats_up_with_submembers(self):
class Blah(FancyModule): class Blah(FancyModule):
def __init__(self, stepsize): def __init__(self, stepsize):
super(Blah, self).__init__(self) super(Blah, self).__init__()
self.stepsize = Member(T.value(stepsize)) self.stepsize = Member(T.value(stepsize))
x = T.dscalar() x = T.dscalar()
......
...@@ -22,6 +22,8 @@ def register_specialize(lopt, *tags, **kwargs): ...@@ -22,6 +22,8 @@ def register_specialize(lopt, *tags, **kwargs):
""" Types of sparse matrices to use for testing """ """ Types of sparse matrices to use for testing """
_mtypes = [sparse.csc_matrix, sparse.csr_matrix] _mtypes = [sparse.csc_matrix, sparse.csr_matrix]
#_mtypes = [sparse.csc_matrix, sparse.csr_matrix, sparse.dok_matrix, sparse.lil_matrix, sparse.coo_matrix] #_mtypes = [sparse.csc_matrix, sparse.csr_matrix, sparse.dok_matrix, sparse.lil_matrix, sparse.coo_matrix]
#* new class ``dia_matrix`` : the sparse DIAgonal format
#* new class ``bsr_matrix`` : the Block CSR format
_mtype_to_str = {sparse.csc_matrix: "csc", sparse.csr_matrix: "csr"} _mtype_to_str = {sparse.csc_matrix: "csc", sparse.csr_matrix: "csr"}
...@@ -673,16 +675,14 @@ class StructuredDot(gof.Op): ...@@ -673,16 +675,14 @@ class StructuredDot(gof.Op):
""" """
def make_node(self, a, b): def make_node(self, a, b):
assert a.type.dtype == b.type.dtype assert a.type.dtype == b.type.dtype
if type(a) is not SparseResult: if type(a) is not SparseResult and type(a) is not SparseConstant:
raise TypeError('First argument must be of type SparseResult'); raise TypeError('First argument must be of type SparseResult or SparseConstant');
return gof.Apply(self, [a,b], [tensor.tensor(a.type.dtype, (False, False))]) return gof.Apply(self, [a,b], [tensor.tensor(a.type.dtype, (False, False))])
def perform(self, node, (a,b), (out,)): def perform(self, node, (a,b), (out,)):
if a.shape[1] != b.shape[0]: if a.shape[1] != b.shape[0]:
raise ValueError('shape mismatch in StructuredDot.perform', (a.shape, b.shape)) raise ValueError('shape mismatch in StructuredDot.perform', (a.shape, b.shape))
if b.shape[0] == 1:
raise NotImplemented('ERROR: scipy.csc_matrix dot has bug with singleton dimensions')
result = a.dot(b) result = a.dot(b)
...@@ -699,6 +699,12 @@ class StructuredDot(gof.Op): ...@@ -699,6 +699,12 @@ class StructuredDot(gof.Op):
assert result.ndim == 2 assert result.ndim == 2
if result.shape != (a.shape[0], b.shape[1]):
if b.shape[0] == 1:
raise Exception("a.shape=%s, b.shape=%s, result.shape=%s ??? This is probably because scipy.csc_matrix dot has a bug with singleton dimensions (i.e. b.shape[0]=1), for scipy 0.6. Use scipy 0.7" % (a.shape, b.shape, result.shape))
else:
raise Exception("a.shape=%s, b.shape=%s, result.shape=%s ??? I have no idea why")
## Commenting this out because result should be a numpy.ndarray since the assert above ## Commenting this out because result should be a numpy.ndarray since the assert above
## (JB 20090109) ## (JB 20090109)
#out[0] = numpy.asarray(result) #TODO: fix this really bad implementation #out[0] = numpy.asarray(result) #TODO: fix this really bad implementation
......
...@@ -245,7 +245,7 @@ class RModule(compile.Module): ...@@ -245,7 +245,7 @@ class RModule(compile.Module):
def __init__(self, components = {}, **kwcomponents): def __init__(self, components = {}, **kwcomponents):
super(RModule, self).__init__(components, **kwcomponents) super(RModule, self).__init__(components, **kwcomponents)
self.random = RandomKit('rkit') self.random = RandomKit('rkit')
self._components['_rkit'] = compile.KitComponent(self.random) self._rkit = compile.KitComponent(self.random)
def __wrapper__(self, x): def __wrapper__(self, x):
x = compile.module.wrap(x) x = compile.module.wrap(x)
......
...@@ -32,7 +32,7 @@ class T_test_module(unittest.TestCase): ...@@ -32,7 +32,7 @@ class T_test_module(unittest.TestCase):
""" """
class B(RModule): class B(RModule):
def __init__(self): def __init__(self):
super(B, self).__init__(self) super(B, self).__init__()
self.x = compile.Member(tensor.dvector()) self.x = compile.Member(tensor.dvector())
self.r = self.random.uniform(tensor.shape(self.x)) self.r = self.random.uniform(tensor.shape(self.x))
...@@ -40,7 +40,7 @@ class T_test_module(unittest.TestCase): ...@@ -40,7 +40,7 @@ class T_test_module(unittest.TestCase):
self.f = compile.Method([self.x], self.r) self.f = compile.Method([self.x], self.r)
class E(RModule): class E(RModule):
def __init__(self): def __init__(self):
super(E, self).__init__(self) super(E, self).__init__()
self.b = B() self.b = B()
self.f = compile.Method([self.b.x], self.b.r) self.f = compile.Method([self.b.x], self.b.r)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论