提交 d28f7c55 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

merge

......@@ -510,7 +510,7 @@ class FunctionMaker(object):
if isinstance(input, SymbolicInputKit):
if default is NODEFAULT:
_defaults.append((False, False, None))
if default is None:
elif default is None:
_defaults.append((True, True, None))
else:
_defaults.append((False, False, default))
......
......@@ -404,7 +404,7 @@ def stack_search(start, expand, mode='bfs', build_inv = False):
return rval_list
def inputs(result_list):
def inputs(result_list, blockers = None):
"""Return the inputs required to compute the given Results.
:type result_list: list of `Result` instances
......@@ -417,7 +417,7 @@ def inputs(result_list):
"""
def expand(r):
if r.owner:
if r.owner and (not blockers or r not in blockers):
l = list(r.owner.inputs)
l.reverse()
return l
......
......@@ -140,12 +140,12 @@ class Member(_RComponent):
from theano.sandbox import pprint
class Method(Component):
def __init__(self, inputs, outputs, updates = {}, **kwupdates):
def __init__(self, inputs, outputs, updates = {}, kits = [], **kwupdates):
super(Method, self).__init__()
self.inputs = inputs
self.outputs = outputs
self.updates = dict(updates, **kwupdates)
self.kits = []
self.kits = list(kits)
def bind(self, parent, name):
super(Method, self).bind(parent, name)
......@@ -208,7 +208,8 @@ class Method(Component):
outputs = self.outputs
_inputs = [x.result for x in inputs]
for input in gof.graph.inputs(outputs if isinstance(outputs, (list, tuple)) else [outputs]
+ [x.update for x in inputs if getattr(x, 'update', False)]):
+ [x.update for x in inputs if getattr(x, 'update', False)],
blockers = _inputs):
if input not in _inputs and not isinstance(input, gof.Value):
inputs += [compile.In(result = input,
value = get_storage(input, True))]
......@@ -235,6 +236,13 @@ class Method(Component):
"; " if self.updates else "",
", ".join("%s <= %s" % (old, new) for old, new in self.updates.iteritems()))
def __copy__(self):
self.resolve_all()
return self.__class__(list(self.inputs),
list(self.outputs) if isinstance(outputs, list) else outputs,
dict(self.updates),
list(self.kits))
class CompositeInstance(object):
......@@ -326,6 +334,9 @@ class ComponentListInstance(CompositeInstance):
def __str__(self):
return '[%s]' % ', '.join(map(str, self.__items__))
def __len__(self):
return len(self.__items__)
def initialize(self, init):
for i, initv in enumerate(init):
self[i] = initv
......@@ -402,7 +413,6 @@ class ComponentList(Composite):
class ModuleInstance(CompositeInstance):
__hide__ = []
def __setitem__(self, item, value):
if item not in self.__items__:
......@@ -415,7 +425,7 @@ class ModuleInstance(CompositeInstance):
for k, v in sorted(self.__items__.iteritems()):
if isinstance(v, gof.Container):
v = v.value
if not k.startswith('_') and not callable(v) and not k in self.__hide__:
if not k.startswith('_') and not callable(v) and not k in getattr(self, '__hide__', []):
pre = '%s: ' % k
strings.append('%s%s' % (pre, str(v).replace('\n', '\n' + ' '*len(pre))))
return '{%s}' % '\n'.join(strings).replace('\n', '\n ')
......
......@@ -1994,8 +1994,13 @@ def grad(cost, wrt, g_cost=None):
Tensor(dtype = p.type.dtype, broadcastable = []),
numpy.asarray(0, dtype=p.type.dtype))
if hasattr(wrt, '__iter__'): # isinstance(wrt, (list, tuple)):
return [gmap.get(p, zero(p)) for p in wrt]
try:
it = iter(wrt)
except:
it = None
if it: #hasattr(wrt, '__iter__'): # isinstance(wrt, (list, tuple)):
return [gmap.get(p, zero(p)) for p in it]
else:
return gmap.get(wrt, zero(wrt))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论