提交 5961805c authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added pretty printing to module

上级 d2658923
...@@ -71,6 +71,9 @@ class Component(object): ...@@ -71,6 +71,9 @@ class Component(object):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def pretty(self):
raise NotImplementedError
def __get_name__(self): def __get_name__(self):
return self._name return self._name
...@@ -103,6 +106,10 @@ class External(Component): ...@@ -103,6 +106,10 @@ class External(Component):
def __str__(self): def __str__(self):
return "%s(%s)" % (self.__class__.__name__, self.r) return "%s(%s)" % (self.__class__.__name__, self.r)
def pretty(self):
rval = 'External :: %s' % self.r.type
return rval
class Member(Component): class Member(Component):
...@@ -128,8 +135,21 @@ class Member(Component): ...@@ -128,8 +135,21 @@ class Member(Component):
def __str__(self): def __str__(self):
return "%s(%s)" % (self.__class__.__name__, self.r) return "%s(%s)" % (self.__class__.__name__, self.r)
def pretty(self):
rval = 'Member :: %s' % self.r.type
return rval
# def pretty(self, header = False, **kwargs):
# cr = '\n ' if header else '\n'
# rval = ''
# if header:
# rval += 'Member:%s' % cr
# rval += '%s :: %s' % ((self.r.name if self.r.name else '<unnamed>'), self.r.type)
# return rval
from theano.sandbox import pprint
class Method(Component): class Method(Component):
def __init__(self, inputs, outputs, updates = {}, **kwupdates): def __init__(self, inputs, outputs, updates = {}, **kwupdates):
...@@ -206,6 +226,19 @@ class Method(Component): ...@@ -206,6 +226,19 @@ class Method(Component):
inputs += [(kit, get_storage(kit, True)) for kit in self.kits] inputs += [(kit, get_storage(kit, True)) for kit in self.kits]
return compile.function(inputs, outputs, mode) return compile.function(inputs, outputs, mode)
def pretty(self, header = True, **kwargs):
self.resolve_all()
# cr = '\n ' if header else '\n'
# rval = ''
# if header:
# rval += "Method(%s):" % ", ".join(map(str, self.inputs))
if self.inputs:
rval = 'inputs: %s\n' % ", ".join(map(str, self.inputs))
else:
rval = ''
rval += pprint.pp.process_graph(self.inputs, self.outputs, self.updates, False)
return rval
def __str__(self): def __str__(self):
return "Method(%s -> %s%s%s)" % \ return "Method(%s -> %s%s%s)" % \
(self.inputs, (self.inputs,
...@@ -350,6 +383,16 @@ class ComponentList(Composite): ...@@ -350,6 +383,16 @@ class ComponentList(Composite):
def __str__(self): def __str__(self):
return str(self._components) return str(self._components)
def pretty(self, header = True, **kwargs):
cr = '\n ' #if header else '\n'
strings = []
#if header:
# rval += "ComponentList:"
for i, c in self.components_map():
strings.append('%i:%s%s' % (i, cr, c.pretty().replace('\n', cr)))
#rval += cr + '%i -> %s' % (i, c.pretty(header = True, **kwargs).replace('\n', cr))
return '\n'.join(strings)
def __set_name__(self, name): def __set_name__(self, name):
super(ComponentList, self).__set_name__(name) super(ComponentList, self).__set_name__(name)
for i, member in enumerate(self._components): for i, member in enumerate(self._components):
...@@ -406,6 +449,21 @@ class Module(Composite): ...@@ -406,6 +449,21 @@ class Module(Composite):
value.bind(self, item) value.bind(self, item)
self._components[item] = value self._components[item] = value
def pretty(self, header = True, **kwargs):
cr = '\n ' #if header else '\n'
strings = []
# if header:
# rval += "Module:"
for name, component in self.components_map():
if name.startswith('_'):
continue
strings.append('%s:%s%s' % (name, cr, component.pretty().replace('\n', cr)))
strings.sort()
return '\n'.join(strings)
def __str__(self):
return "Module(%s)" % ', '.join(x for x in sorted(map(str, self._components)) if x[0] != '_')
def __set_name__(self, name): def __set_name__(self, name):
super(Module, self).__set_name__(name) super(Module, self).__set_name__(name)
for mname, member in self._components.iteritems(): for mname, member in self._components.iteritems():
...@@ -463,6 +521,12 @@ class FancyModule(Module): ...@@ -463,6 +521,12 @@ class FancyModule(Module):
return rval return rval
def __setattr__(self, attr, value): def __setattr__(self, attr, value):
if attr == 'parent':
self.__dict__[attr] = value
return
elif attr == 'name':
self.__set_name__(value)
return
value = self.__wrapper__(value) value = self.__wrapper__(value)
try: try:
self[attr] = value self[attr] = value
...@@ -579,10 +643,21 @@ if __name__ == '__main__': ...@@ -579,10 +643,21 @@ if __name__ == '__main__':
mod.whatever = 123 mod.whatever = 123
print mod._components mod2 = RModule()
mod2.submodule = mod
#print mod._components
#print mod
#print mod.inc.pretty()
print mod2.pretty()
inst = mod.make(s = 2, list = [900, 9000]) inst = mod.make(s = 2, list = [900, 9000])
print '---'
print inst.test1()
print '---'
inst.seed(10) inst.seed(10)
print inst.test1() print inst.test1()
print inst.test1() print inst.test1()
......
...@@ -227,17 +227,39 @@ class PPrinter: ...@@ -227,17 +227,39 @@ class PPrinter:
cp.assign(condition, printer) cp.assign(condition, printer)
return cp return cp
def process_graph(self, inputs, outputs): def process_graph(self, inputs, outputs, updates = {}, display_inputs = False):
strings = ["inputs: " + ", ".join(map(str, inputs))] if not isinstance(inputs, (list, tuple)): inputs = [inputs]
if not isinstance(outputs, (list, tuple)): outputs = [outputs]
current = None
if display_inputs:
strings = [(0, "inputs: " + ", ".join(map(str, list(inputs) + updates.keys())))]
else:
strings = []
pprinter = self.clone_assign(lambda pstate, r: r.name is not None and r is not current, pprinter = self.clone_assign(lambda pstate, r: r.name is not None and r is not current,
LeafPrinter()) LeafPrinter())
for node in gof.graph.io_toposort(inputs, outputs): inv_updates = dict((b, a) for (a, b) in updates.iteritems())
i = 1
for node in gof.graph.io_toposort(list(inputs) + updates.keys(),
list(outputs) + updates.values()):
for output in node.outputs: for output in node.outputs:
if output in inv_updates:
name = str(inv_updates[output])
strings.append((i + 1000, "%s <- %s" % (name, pprinter.process(output))))
i += 1
if output.name is not None or output in outputs: if output.name is not None or output in outputs:
name = 'outputs[%i]' % outputs.index(output) if output.name is None else output.name name = 'out[%i]' % outputs.index(output) if output.name is None else output.name
current = output current = output
strings.append("%s = %s" % (name, pprinter.process(output))) try:
return "\n".join(strings) idx = 2000 + outputs.index(output)
except ValueError:
idx = i
if len(outputs) == 1 and outputs[0] is output:
strings.append((idx, "return %s" % pprinter.process(output)))
else:
strings.append((idx, "%s = %s" % (name, pprinter.process(output))))
i += 1
strings.sort()
return "\n".join(s[1] for s in strings)
...@@ -261,6 +283,8 @@ psub = OperatorPrinter('-', -2, 'left') ...@@ -261,6 +283,8 @@ psub = OperatorPrinter('-', -2, 'left')
pdot = OperatorPrinter(special['middle_dot'], -1, 'left') pdot = OperatorPrinter(special['middle_dot'], -1, 'left')
psum = OperatorPrinter(special['big_sigma']+' ', -2, 'left') psum = OperatorPrinter(special['big_sigma']+' ', -2, 'left')
from ..tensor import inplace as I
def pprinter(): def pprinter():
pp = PPrinter() pp = PPrinter()
pp.assign(lambda pstate, r: True, DefaultPrinter()) pp.assign(lambda pstate, r: True, DefaultPrinter())
...@@ -276,16 +300,16 @@ def pprinter(): ...@@ -276,16 +300,16 @@ def pprinter():
pp.assign(T.tensor_copy, IgnorePrinter()) pp.assign(T.tensor_copy, IgnorePrinter())
pp.assign(T.log, FunctionPrinter('log')) pp.assign(T.log, FunctionPrinter('log'))
pp.assign(T.tanh, FunctionPrinter('tanh')) pp.assign(T.tanh, FunctionPrinter('tanh'))
pp.assign(T.transpose_inplace, MemberPrinter('T')) pp.assign(I.transpose_inplace, MemberPrinter('T'))
pp.assign(T._abs, PatternPrinter(('|%(0)s|', -1000))) pp.assign(T.abs_, PatternPrinter(('|%(0)s|', -1000)))
pp.assign(T.sgn, FunctionPrinter('sgn')) pp.assign(T.sgn, FunctionPrinter('sgn'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.Filler) and r.owner.op.value == 0, FunctionPrinter('seros')) pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.Filler) and r.owner.op.value == 0, FunctionPrinter('seros'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.Filler) and r.owner.op.value == 1, FunctionPrinter('ones')) pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.Filler) and r.owner.op.value == 1, FunctionPrinter('ones'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.Subtensor), SubtensorPrinter()) pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.Subtensor), SubtensorPrinter())
pp.assign(T.shape, MemberPrinter('shape')) pp.assign(T.shape, MemberPrinter('shape'))
pp.assign(T.fill, FunctionPrinter('fill')) pp.assign(T.fill, FunctionPrinter('fill'))
pp.assign(T.vertical_stack, FunctionPrinter('vstack')) #pp.assign(T.vertical_stack, FunctionPrinter('vstack'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.MakeVector), MakeVectorPrinter()) #pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.MakeVector), MakeVectorPrinter())
return pp return pp
pp = pprinter() pp = pprinter()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论