提交 1b4a9e7e authored 作者: Olivier Breuleux's avatar Olivier Breuleux

improvements

上级 486b71f2
...@@ -510,7 +510,7 @@ class FunctionMaker(object): ...@@ -510,7 +510,7 @@ class FunctionMaker(object):
if isinstance(input, SymbolicInputKit): if isinstance(input, SymbolicInputKit):
if default is NODEFAULT: if default is NODEFAULT:
_defaults.append((False, False, None)) _defaults.append((False, False, None))
if default is None: elif default is None:
_defaults.append((True, True, None)) _defaults.append((True, True, None))
else: else:
_defaults.append((False, False, default)) _defaults.append((False, False, default))
......
...@@ -404,7 +404,7 @@ def stack_search(start, expand, mode='bfs', build_inv = False): ...@@ -404,7 +404,7 @@ def stack_search(start, expand, mode='bfs', build_inv = False):
return rval_list return rval_list
def inputs(result_list): def inputs(result_list, blockers = None):
"""Return the inputs required to compute the given Results. """Return the inputs required to compute the given Results.
:type result_list: list of `Result` instances :type result_list: list of `Result` instances
...@@ -417,7 +417,7 @@ def inputs(result_list): ...@@ -417,7 +417,7 @@ def inputs(result_list):
""" """
def expand(r): def expand(r):
if r.owner: if r.owner and (not blockers or r not in blockers):
l = list(r.owner.inputs) l = list(r.owner.inputs)
l.reverse() l.reverse()
return l return l
......
...@@ -140,12 +140,12 @@ class Member(_RComponent): ...@@ -140,12 +140,12 @@ class Member(_RComponent):
from theano.sandbox import pprint from theano.sandbox import pprint
class Method(Component): class Method(Component):
def __init__(self, inputs, outputs, updates = {}, **kwupdates): def __init__(self, inputs, outputs, updates = {}, kits = [], **kwupdates):
super(Method, self).__init__() super(Method, self).__init__()
self.inputs = inputs self.inputs = inputs
self.outputs = outputs self.outputs = outputs
self.updates = dict(updates, **kwupdates) self.updates = dict(updates, **kwupdates)
self.kits = [] self.kits = list(kits)
def bind(self, parent, name): def bind(self, parent, name):
super(Method, self).bind(parent, name) super(Method, self).bind(parent, name)
...@@ -208,7 +208,8 @@ class Method(Component): ...@@ -208,7 +208,8 @@ class Method(Component):
outputs = self.outputs outputs = self.outputs
_inputs = [x.result for x in inputs] _inputs = [x.result for x in inputs]
for input in gof.graph.inputs(outputs if isinstance(outputs, (list, tuple)) else [outputs] for input in gof.graph.inputs(outputs if isinstance(outputs, (list, tuple)) else [outputs]
+ [x.update for x in inputs if getattr(x, 'update', False)]): + [x.update for x in inputs if getattr(x, 'update', False)],
blockers = _inputs):
if input not in _inputs and not isinstance(input, gof.Value): if input not in _inputs and not isinstance(input, gof.Value):
inputs += [compile.In(result = input, inputs += [compile.In(result = input,
value = get_storage(input, True))] value = get_storage(input, True))]
...@@ -235,6 +236,13 @@ class Method(Component): ...@@ -235,6 +236,13 @@ class Method(Component):
"; " if self.updates else "", "; " if self.updates else "",
", ".join("%s <= %s" % (old, new) for old, new in self.updates.iteritems())) ", ".join("%s <= %s" % (old, new) for old, new in self.updates.iteritems()))
def __copy__(self):
self.resolve_all()
return self.__class__(list(self.inputs),
list(self.outputs) if isinstance(outputs, list) else outputs,
dict(self.updates),
list(self.kits))
class CompositeInstance(object): class CompositeInstance(object):
...@@ -326,6 +334,9 @@ class ComponentListInstance(CompositeInstance): ...@@ -326,6 +334,9 @@ class ComponentListInstance(CompositeInstance):
def __str__(self): def __str__(self):
return '[%s]' % ', '.join(map(str, self.__items__)) return '[%s]' % ', '.join(map(str, self.__items__))
def __len__(self):
return len(self.__items__)
def initialize(self, init): def initialize(self, init):
for i, initv in enumerate(init): for i, initv in enumerate(init):
self[i] = initv self[i] = initv
...@@ -402,7 +413,6 @@ class ComponentList(Composite): ...@@ -402,7 +413,6 @@ class ComponentList(Composite):
class ModuleInstance(CompositeInstance): class ModuleInstance(CompositeInstance):
__hide__ = []
def __setitem__(self, item, value): def __setitem__(self, item, value):
if item not in self.__items__: if item not in self.__items__:
...@@ -415,7 +425,7 @@ class ModuleInstance(CompositeInstance): ...@@ -415,7 +425,7 @@ class ModuleInstance(CompositeInstance):
for k, v in sorted(self.__items__.iteritems()): for k, v in sorted(self.__items__.iteritems()):
if isinstance(v, gof.Container): if isinstance(v, gof.Container):
v = v.value v = v.value
if not k.startswith('_') and not callable(v) and not k in self.__hide__: if not k.startswith('_') and not callable(v) and not k in getattr(self, '__hide__', []):
pre = '%s: ' % k pre = '%s: ' % k
strings.append('%s%s' % (pre, str(v).replace('\n', '\n' + ' '*len(pre)))) strings.append('%s%s' % (pre, str(v).replace('\n', '\n' + ' '*len(pre))))
return '{%s}' % '\n'.join(strings).replace('\n', '\n ') return '{%s}' % '\n'.join(strings).replace('\n', '\n ')
......
...@@ -1994,8 +1994,13 @@ def grad(cost, wrt, g_cost=None): ...@@ -1994,8 +1994,13 @@ def grad(cost, wrt, g_cost=None):
Tensor(dtype = p.type.dtype, broadcastable = []), Tensor(dtype = p.type.dtype, broadcastable = []),
numpy.asarray(0, dtype=p.type.dtype)) numpy.asarray(0, dtype=p.type.dtype))
if hasattr(wrt, '__iter__'): # isinstance(wrt, (list, tuple)): try:
return [gmap.get(p, zero(p)) for p in wrt] it = iter(wrt)
except:
it = None
if it: #hasattr(wrt, '__iter__'): # isinstance(wrt, (list, tuple)):
return [gmap.get(p, zero(p)) for p in it]
else: else:
return gmap.get(wrt, zero(wrt)) return gmap.get(wrt, zero(wrt))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论