提交 5961805c authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added pretty printing to module

上级 d2658923
......@@ -71,6 +71,9 @@ class Component(object):
def __str__(self):
return self.__class__.__name__
def pretty(self):
raise NotImplementedError
def __get_name__(self):
return self._name
......@@ -103,6 +106,10 @@ class External(Component):
def __str__(self):
return "%s(%s)" % (self.__class__.__name__, self.r)
def pretty(self):
rval = 'External :: %s' % self.r.type
return rval
class Member(Component):
......@@ -128,8 +135,21 @@ class Member(Component):
def __str__(self):
return "%s(%s)" % (self.__class__.__name__, self.r)
def pretty(self):
rval = 'Member :: %s' % self.r.type
return rval
# def pretty(self, header = False, **kwargs):
# cr = '\n ' if header else '\n'
# rval = ''
# if header:
# rval += 'Member:%s' % cr
# rval += '%s :: %s' % ((self.r.name if self.r.name else '<unnamed>'), self.r.type)
# return rval
from theano.sandbox import pprint
class Method(Component):
def __init__(self, inputs, outputs, updates = {}, **kwupdates):
......@@ -206,6 +226,19 @@ class Method(Component):
inputs += [(kit, get_storage(kit, True)) for kit in self.kits]
return compile.function(inputs, outputs, mode)
def pretty(self, header = True, **kwargs):
self.resolve_all()
# cr = '\n ' if header else '\n'
# rval = ''
# if header:
# rval += "Method(%s):" % ", ".join(map(str, self.inputs))
if self.inputs:
rval = 'inputs: %s\n' % ", ".join(map(str, self.inputs))
else:
rval = ''
rval += pprint.pp.process_graph(self.inputs, self.outputs, self.updates, False)
return rval
def __str__(self):
return "Method(%s -> %s%s%s)" % \
(self.inputs,
......@@ -350,6 +383,16 @@ class ComponentList(Composite):
def __str__(self):
return str(self._components)
def pretty(self, header = True, **kwargs):
cr = '\n ' #if header else '\n'
strings = []
#if header:
# rval += "ComponentList:"
for i, c in self.components_map():
strings.append('%i:%s%s' % (i, cr, c.pretty().replace('\n', cr)))
#rval += cr + '%i -> %s' % (i, c.pretty(header = True, **kwargs).replace('\n', cr))
return '\n'.join(strings)
def __set_name__(self, name):
super(ComponentList, self).__set_name__(name)
for i, member in enumerate(self._components):
......@@ -406,6 +449,21 @@ class Module(Composite):
value.bind(self, item)
self._components[item] = value
def pretty(self, header = True, **kwargs):
cr = '\n ' #if header else '\n'
strings = []
# if header:
# rval += "Module:"
for name, component in self.components_map():
if name.startswith('_'):
continue
strings.append('%s:%s%s' % (name, cr, component.pretty().replace('\n', cr)))
strings.sort()
return '\n'.join(strings)
def __str__(self):
return "Module(%s)" % ', '.join(x for x in sorted(map(str, self._components)) if x[0] != '_')
def __set_name__(self, name):
super(Module, self).__set_name__(name)
for mname, member in self._components.iteritems():
......@@ -463,6 +521,12 @@ class FancyModule(Module):
return rval
def __setattr__(self, attr, value):
if attr == 'parent':
self.__dict__[attr] = value
return
elif attr == 'name':
self.__set_name__(value)
return
value = self.__wrapper__(value)
try:
self[attr] = value
......@@ -579,10 +643,21 @@ if __name__ == '__main__':
mod.whatever = 123
print mod._components
mod2 = RModule()
mod2.submodule = mod
#print mod._components
#print mod
#print mod.inc.pretty()
print mod2.pretty()
inst = mod.make(s = 2, list = [900, 9000])
print '---'
print inst.test1()
print '---'
inst.seed(10)
print inst.test1()
print inst.test1()
......
......@@ -227,17 +227,39 @@ class PPrinter:
cp.assign(condition, printer)
return cp
def process_graph(self, inputs, outputs):
strings = ["inputs: " + ", ".join(map(str, inputs))]
def process_graph(self, inputs, outputs, updates = {}, display_inputs = False):
if not isinstance(inputs, (list, tuple)): inputs = [inputs]
if not isinstance(outputs, (list, tuple)): outputs = [outputs]
current = None
if display_inputs:
strings = [(0, "inputs: " + ", ".join(map(str, list(inputs) + updates.keys())))]
else:
strings = []
pprinter = self.clone_assign(lambda pstate, r: r.name is not None and r is not current,
LeafPrinter())
for node in gof.graph.io_toposort(inputs, outputs):
inv_updates = dict((b, a) for (a, b) in updates.iteritems())
i = 1
for node in gof.graph.io_toposort(list(inputs) + updates.keys(),
list(outputs) + updates.values()):
for output in node.outputs:
if output in inv_updates:
name = str(inv_updates[output])
strings.append((i + 1000, "%s <- %s" % (name, pprinter.process(output))))
i += 1
if output.name is not None or output in outputs:
name = 'outputs[%i]' % outputs.index(output) if output.name is None else output.name
name = 'out[%i]' % outputs.index(output) if output.name is None else output.name
current = output
strings.append("%s = %s" % (name, pprinter.process(output)))
return "\n".join(strings)
try:
idx = 2000 + outputs.index(output)
except ValueError:
idx = i
if len(outputs) == 1 and outputs[0] is output:
strings.append((idx, "return %s" % pprinter.process(output)))
else:
strings.append((idx, "%s = %s" % (name, pprinter.process(output))))
i += 1
strings.sort()
return "\n".join(s[1] for s in strings)
......@@ -261,6 +283,8 @@ psub = OperatorPrinter('-', -2, 'left')
pdot = OperatorPrinter(special['middle_dot'], -1, 'left')
psum = OperatorPrinter(special['big_sigma']+' ', -2, 'left')
from ..tensor import inplace as I
def pprinter():
pp = PPrinter()
pp.assign(lambda pstate, r: True, DefaultPrinter())
......@@ -276,16 +300,16 @@ def pprinter():
pp.assign(T.tensor_copy, IgnorePrinter())
pp.assign(T.log, FunctionPrinter('log'))
pp.assign(T.tanh, FunctionPrinter('tanh'))
pp.assign(T.transpose_inplace, MemberPrinter('T'))
pp.assign(T._abs, PatternPrinter(('|%(0)s|', -1000)))
pp.assign(I.transpose_inplace, MemberPrinter('T'))
pp.assign(T.abs_, PatternPrinter(('|%(0)s|', -1000)))
pp.assign(T.sgn, FunctionPrinter('sgn'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.Filler) and r.owner.op.value == 0, FunctionPrinter('seros'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.Filler) and r.owner.op.value == 1, FunctionPrinter('ones'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.Subtensor), SubtensorPrinter())
pp.assign(T.shape, MemberPrinter('shape'))
pp.assign(T.fill, FunctionPrinter('fill'))
pp.assign(T.vertical_stack, FunctionPrinter('vstack'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.MakeVector), MakeVectorPrinter())
#pp.assign(T.vertical_stack, FunctionPrinter('vstack'))
#pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.MakeVector), MakeVectorPrinter())
return pp
pp = pprinter()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论