提交 f1d61ed9 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added pprint (as printing) to the theano package root

上级 d5ae91e1
...@@ -4,7 +4,8 @@ import theano ...@@ -4,7 +4,8 @@ import theano
from theano import tensor as T from theano import tensor as T
from theano.tensor import nnet_ops from theano.tensor import nnet_ops
from theano.compile import module from theano.compile import module
from theano.sandbox import pprint from theano import printing, pprint
from theano import compile
import numpy as N import numpy as N
...@@ -17,6 +18,7 @@ class LogisticRegressionN(module.FancyModule): ...@@ -17,6 +18,7 @@ class LogisticRegressionN(module.FancyModule):
self.w = N.random.randn(n_in, n_out) self.w = N.random.randn(n_in, n_out)
self.b = N.random.randn(n_out) self.b = N.random.randn(n_out)
self.lr = 0.01 self.lr = 0.01
self.__hide__ = ['params']
def __init__(self, x = None, targ = None): def __init__(self, x = None, targ = None):
super(LogisticRegressionN, self).__init__() #boilerplate super(LogisticRegressionN, self).__init__() #boilerplate
...@@ -84,8 +86,8 @@ class LogisticRegression2(module.FancyModule): ...@@ -84,8 +86,8 @@ class LogisticRegression2(module.FancyModule):
if __name__ == '__main__': if __name__ == '__main__':
pprint.pp.assign(nnet_ops.crossentropy_softmax_1hot_with_bias_dx, pprint.FunctionPrinter('xsoftmaxdx')) pprint.assign(nnet_ops.crossentropy_softmax_1hot_with_bias_dx, printing.FunctionPrinter('xsoftmaxdx'))
pprint.pp.assign(nnet_ops.crossentropy_softmax_argmax_1hot_with_bias, pprint.FunctionPrinter('nll', 'softmax', 'argmax')) pprint.assign(nnet_ops.crossentropy_softmax_argmax_1hot_with_bias, printing.FunctionPrinter('nll', 'softmax', 'argmax'))
if 1: if 1:
lrc = LogisticRegressionN() lrc = LogisticRegressionN()
...@@ -94,16 +96,20 @@ if __name__ == '__main__': ...@@ -94,16 +96,20 @@ if __name__ == '__main__':
print '================' print '================'
print lrc.update.pretty(mode = theano.Mode('py', 'fast_run')) print lrc.update.pretty(mode = theano.Mode('py', 'fast_run'))
print '================' print '================'
# print lrc.update.pretty(mode = compile.FAST_RUN.excluding('inplace'))
# print '================'
# sys.exit(0) # sys.exit(0)
lr = lrc.make(10, 2, mode=theano.Mode('c|py', 'fast_run')) lr = lrc.make(10, 2, mode=theano.Mode('c|py', 'fast_run'))
#lr = lrc.make(10, 2, mode=compile.FAST_RUN.excluding('fast_run'))
#lr = lrc.make(10, 2, mode=theano.Mode('py', 'merge')) #'FAST_RUN') #lr = lrc.make(10, 2, mode=theano.Mode('py', 'merge')) #'FAST_RUN')
data_x = N.random.randn(5, 10) data_x = N.random.randn(5, 10)
data_y = (N.random.randn(5) > 0) data_y = (N.random.randn(5) > 0)
for i in xrange(10000): for i in xrange(10000):
lr.lr = 0.02
xe = lr.update(data_x, data_y) xe = lr.update(data_x, data_y)
if i % 100 == 0: if i % 100 == 0:
print i, xe print i, xe
......
...@@ -41,13 +41,19 @@ from compile import \ ...@@ -41,13 +41,19 @@ from compile import \
SymbolicOutput, Out, \ SymbolicOutput, Out, \
Mode, \ Mode, \
predefined_modes, predefined_linkers, predefined_optimizers, \ predefined_modes, predefined_linkers, predefined_optimizers, \
FunctionMaker, function, OpFromGraph #, eval_outputs, fast_compute FunctionMaker, function, OpFromGraph, \
Component, External, Member, KitComponent, Method, \
Composite, ComponentList, Module, FancyModule
from printing import \
pprint, pp
import tensor import tensor
import scalar import scalar
import sparse import sparse
import gradient import gradient
## import scalar_opt ## import scalar_opt
import subprocess as _subprocess import subprocess as _subprocess
......
...@@ -102,12 +102,19 @@ class Mode(object): ...@@ -102,12 +102,19 @@ class Mode(object):
optimizer = predefined_optimizers[optimizer] optimizer = predefined_optimizers[optimizer]
if isinstance(optimizer, gof.Query): if isinstance(optimizer, gof.Query):
self.provided_optimizer = optimizer self.provided_optimizer = optimizer
optimizer = optdb.query(optimizer) self._optimizer = optimizer
self.optimizer = optimizer
def __str__(self): def __str__(self):
return "Mode(linker = %s, optimizer = %s)" % (self.provided_linker, self.provided_optimizer) return "Mode(linker = %s, optimizer = %s)" % (self.provided_linker, self.provided_optimizer)
def __get_optimizer(self):
if isinstance(self._optimizer, gof.Query):
return optdb.query(self._optimizer)
else:
return self._optimizer
optimizer = property(__get_optimizer)
def including(self, *tags): def including(self, *tags):
return Mode(self.provided_linker, self.provided_optimizer.including(*tags)) return Mode(self.provided_linker, self.provided_optimizer.including(*tags))
......
from .. import gof from .. import gof
from ..printing import pprint
from collections import defaultdict from collections import defaultdict
from itertools import chain from itertools import chain
from functools import partial from functools import partial
from copy import copy from copy import copy
import mode import io
import function_module as F import function_module as F
#from ..sandbox import pprint
def join(*args): def join(*args):
...@@ -117,7 +117,7 @@ class External(_RComponent): ...@@ -117,7 +117,7 @@ class External(_RComponent):
def pretty(self, **kwargs): def pretty(self, **kwargs):
rval = super(External, self).pretty() rval = super(External, self).pretty()
if self.r.owner: if self.r.owner:
rval += '\n= %s' % (pprint.pp2.process(self.r, dict(target = self.r))) rval += '\n= %s' % (pprint(self.r, dict(target = self.r)))
return rval return rval
...@@ -196,12 +196,14 @@ class Method(Component): ...@@ -196,12 +196,14 @@ class Method(Component):
else: else:
return gof.Container(r, storage = [None]) return gof.Container(r, storage = [None])
inputs = self.inputs inputs = self.inputs
inputs = [mode.In(result = input, inputs = [io.In(result = input,
value = get_storage(input)) value = get_storage(input),
mutable = False)
for input in inputs] for input in inputs]
inputs += [mode.In(result = k, inputs += [io.In(result = k,
update = v, update = v,
value = get_storage(k, True), value = get_storage(k, True),
mutable = True,
strict = True) strict = True)
for k, v in self.updates.iteritems()] for k, v in self.updates.iteritems()]
outputs = self.outputs outputs = self.outputs
...@@ -210,8 +212,9 @@ class Method(Component): ...@@ -210,8 +212,9 @@ class Method(Component):
+ [x.update for x in inputs if getattr(x, 'update', False)], + [x.update for x in inputs if getattr(x, 'update', False)],
blockers = _inputs): blockers = _inputs):
if input not in _inputs and not isinstance(input, gof.Value): if input not in _inputs and not isinstance(input, gof.Value):
inputs += [mode.In(result = input, inputs += [io.In(result = input,
value = get_storage(input, True))] value = get_storage(input, True),
mutable = False)]
inputs += [(kit, get_storage(kit, True)) for kit in self.kits] inputs += [(kit, get_storage(kit, True)) for kit in self.kits]
return F.function(inputs, outputs, mode) return F.function(inputs, outputs, mode)
...@@ -234,11 +237,14 @@ class Method(Component): ...@@ -234,11 +237,14 @@ class Method(Component):
nup = len(k) nup = len(k)
eff_in = tuple(inputs) + tuple(k) eff_in = tuple(inputs) + tuple(k)
eff_out = tuple(outputs) + tuple(v) eff_out = tuple(outputs) + tuple(v)
env = gof.Env(*gof.graph.clone(eff_in + tuple(gof.graph.inputs(eff_out)), supp_in = tuple(gof.graph.inputs(eff_out))
env = gof.Env(*gof.graph.clone(eff_in + supp_in,
eff_out)) eff_out))
sup = F.Supervisor(set(env.inputs).difference(env.inputs[len(inputs):len(eff_in)]))
env.extend(sup)
mode.optimizer.optimize(env) mode.optimizer.optimize(env)
inputs, outputs, updates = env.inputs[:nin], env.outputs[:nout], dict(zip(env.inputs[nin:], env.outputs[nout:])) inputs, outputs, updates = env.inputs[:nin], env.outputs[:nout], dict(zip(env.inputs[nin:], env.outputs[nout:]))
rval += pprint.pp.process_graph(inputs, outputs, updates, False) rval += pprint(inputs, outputs, updates, False)
return rval return rval
def __str__(self): def __str__(self):
......
...@@ -101,7 +101,7 @@ class ReplaceValidate(History, Validator): ...@@ -101,7 +101,7 @@ class ReplaceValidate(History, Validator):
try: try:
env.replace(r, new_r) env.replace(r, new_r)
except Exception, e: except Exception, e:
if not 'The type of the replacement must be the same' in str(e) or not 'does not belong to this Env' in str(e): if 'The type of the replacement must be the same' not in str(e) and 'does not belong to this Env' not in str(e):
print >>sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e print >>sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e
env.revert(chk) # this might fail if the error is in a listener: (env.replace kinda needs better internal error handling) env.revert(chk) # this might fail if the error is in a listener: (env.replace kinda needs better internal error handling)
raise raise
......
from .. import tensor as T import gof
from .. import scalar as S
from .. import gof
from copy import copy from copy import copy
import sys import sys
...@@ -88,7 +86,7 @@ class FunctionPrinter: ...@@ -88,7 +86,7 @@ class FunctionPrinter:
raise TypeError("function %s cannot represent a result with no associated operation" % self.names) raise TypeError("function %s cannot represent a result with no associated operation" % self.names)
idx = node.outputs.index(output) idx = node.outputs.index(output)
name = self.names[idx] name = self.names[idx]
return "%s(%s)" % (name, ", ".join([pprinter.process(input, pstate.clone(precedence = 1000)) return "%s(%s)" % (name, ", ".join([pprinter.process(input, pstate.clone(precedence = -1000))
for input in node.inputs])) for input in node.inputs]))
class MemberPrinter: class MemberPrinter:
...@@ -119,65 +117,6 @@ class IgnorePrinter: ...@@ -119,65 +117,6 @@ class IgnorePrinter:
return "%s" % pprinter.process(input, pstate) return "%s" % pprinter.process(input, pstate)
class DimShufflePrinter:
def __p(self, new_order, pstate, r):
if new_order != () and new_order[0] == 'x':
# return "%s" % self.__p(new_order[1:], pstate, r)
return "[%s]" % self.__p(new_order[1:], pstate, r)
if list(new_order) == range(r.type.ndim):
return pstate.pprinter.process(r)
if list(new_order) == list(reversed(range(r.type.ndim))):
return "%s.T" % pstate.pprinter.process(r)
return "DimShuffle{%s}(%s)" % (", ".join(map(str, new_order)), pstate.pprinter.process(r))
def process(self, r, pstate):
if r.owner is None:
raise TypeError("Can only print DimShuffle.")
elif isinstance(r.owner.op, T.DimShuffle):
ord = r.owner.op.new_order
return self.__p(ord, pstate, r.owner.inputs[0])
else:
raise TypeError("Can only print DimShuffle.")
class SubtensorPrinter:
def process(self, r, pstate):
if r.owner is None:
raise TypeError("Can only print Subtensor.")
elif isinstance(r.owner.op, T.Subtensor):
idxs = r.owner.op.idx_list
inputs = list(r.owner.inputs)
input = inputs.pop()
sidxs = []
inbrack_pstate = pstate.clone(precedence = -1000)
for entry in idxs:
if isinstance(entry, int):
sidxs.append(str(entry))
elif isinstance(entry, S.Scalar):
sidxs.append(inbrack_pstate.pprinter.process(inputs.pop()))
elif isinstance(entry, slice):
sidxs.append("%s:%s%s" % ("" if entry.start is None or entry.start == 0 else entry.start,
"" if entry.stop is None or entry.stop == sys.maxint else entry.stop,
"" if entry.step is None else ":%s" % entry.step))
return "%s[%s]" % (pstate.pprinter.process(input, pstate.clone(precedence = 1000)),
", ".join(sidxs))
else:
raise TypeError("Can only print Subtensor.")
class MakeVectorPrinter:
def process(self, r, pstate):
if r.owner is None:
raise TypeError("Can only print make_vector.")
elif isinstance(r.owner.op, T.MakeVector):
return "[%s]" % ", ".join(pstate.pprinter.process(input, pstate.clone(precedence = 1000)) for input in r.owner.inputs)
else:
raise TypeError("Can only print make_vector.")
class DefaultPrinter: class DefaultPrinter:
def __init__(self): def __init__(self):
...@@ -263,6 +202,16 @@ class PPrinter: ...@@ -263,6 +202,16 @@ class PPrinter:
strings.sort() strings.sort()
return "\n".join(s[1] for s in strings) return "\n".join(s[1] for s in strings)
def __call__(self, *args):
if len(args) == 1:
return self.process(*args)
elif len(args) == 2 and isinstance(args[1], (PrinterState, dict)):
return self.process(*args)
elif len(args) > 2:
return self.process_graph(*args)
else:
raise TypeError('Not enough arguments to call.')
...@@ -276,47 +225,10 @@ greek = dict(alpha = u"\u03B1", ...@@ -276,47 +225,10 @@ greek = dict(alpha = u"\u03B1",
epsilon = u"\u03B5") epsilon = u"\u03B5")
ppow = OperatorPrinter('**', 1, 'right') pprint = PPrinter()
pneg = OperatorPrinter('-', 0, 'either') pprint.assign(lambda pstate, r: True, DefaultPrinter())
pmul = OperatorPrinter('*', -1, 'either') pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is not r and r.name is not None,
pdiv = OperatorPrinter('/', -1, 'left')
padd = OperatorPrinter('+', -2, 'either')
psub = OperatorPrinter('-', -2, 'left')
pdot = OperatorPrinter(special['middle_dot'], -1, 'left')
psum = OperatorPrinter(special['big_sigma']+' ', -2, 'left')
from ..tensor import inplace as I
def pprinter():
pp = PPrinter()
pp.assign(lambda pstate, r: True, DefaultPrinter())
pp.assign(T.add, padd)
pp.assign(T.mul, pmul)
pp.assign(T.sub, psub)
pp.assign(T.neg, pneg)
pp.assign(T.div, pdiv)
pp.assign(T.pow, ppow)
pp.assign(T.dot, pdot)
pp.assign(T.Sum(), FunctionPrinter('sum'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.DimShuffle), DimShufflePrinter())
pp.assign(T.tensor_copy, IgnorePrinter())
pp.assign(T.log, FunctionPrinter('log'))
pp.assign(T.tanh, FunctionPrinter('tanh'))
pp.assign(I.transpose_inplace, MemberPrinter('T'))
pp.assign(T.abs_, PatternPrinter(('|%(0)s|', -1000)))
pp.assign(T.sgn, FunctionPrinter('sgn'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.Filler) and r.owner.op.value == 0, FunctionPrinter('seros'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.Filler) and r.owner.op.value == 1, FunctionPrinter('ones'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.Subtensor), SubtensorPrinter())
pp.assign(T.shape, MemberPrinter('shape'))
pp.assign(T.fill, FunctionPrinter('fill'))
#pp.assign(T.vertical_stack, FunctionPrinter('vstack'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.MakeVector), MakeVectorPrinter())
return pp
pp = pprinter()
pp2 = pprinter()
pp2.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is not r and r.name is not None,
LeafPrinter()) LeafPrinter())
pp = pprint
...@@ -252,17 +252,26 @@ def upcast_out(*types): ...@@ -252,17 +252,26 @@ def upcast_out(*types):
return Scalar(dtype = Scalar.upcast(*types)), return Scalar(dtype = Scalar.upcast(*types)),
def same_out(type): def same_out(type):
return type, return type,
class transfer_type: class transfer_type(gof.utils.object2):
def __init__(self, i): def __init__(self, *transfer):
assert type(i) == int assert all(type(x) == int for x in transfer)
self.i = i self.transfer = transfer
def __call__(self, *types): def __call__(self, *types):
return types[self.i], upcast = upcast_out(*types)
class specific_out: return [upcast if i is None else types[i] for i in self.transfer]
def __eq__(self, other):
return type(self) == type(other) and self.transfer == other.transfer
def __hash__(self):
return hash(self.transfer)
class specific_out(gof.utils.object2):
def __init__(self, *spec): def __init__(self, *spec):
self.spec = spec self.spec = spec
def __call__(self, *types): def __call__(self, *types):
return self.spec return self.spec
def __eq__(self, other):
return type(self) == type(other) and self.spec == other.spec
def __hash__(self):
return hash(self.spec)
def int_out(*types): def int_out(*types):
return int64, return int64,
def float_out(*types): def float_out(*types):
...@@ -328,9 +337,10 @@ class ScalarOp(Op): ...@@ -328,9 +337,10 @@ class ScalarOp(Op):
raise AbstractFunctionError() raise AbstractFunctionError()
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) \ test = type(self) == type(other) \
and getattr(self, 'output_types_preference', None) \ and getattr(self, 'output_types_preference', None) \
== getattr(other, 'output_types_preference', None) == getattr(other, 'output_types_preference', None)
return test
def __hash__(self): def __hash__(self):
return hash(getattr(self, 'output_types_preference', 0)) return hash(getattr(self, 'output_types_preference', 0))
......
...@@ -20,7 +20,8 @@ import elemwise ...@@ -20,7 +20,8 @@ import elemwise
from .. import scalar as scal from .. import scalar as scal
from ..gof.python25 import partial from ..gof.python25 import partial
from .. import compile from .. import compile, printing
from ..printing import pprint
### set up the external interface ### set up the external interface
...@@ -614,6 +615,8 @@ def _scal_elemwise(symbol): ...@@ -614,6 +615,8 @@ def _scal_elemwise(symbol):
rval.__epydoc_asRoutine = symbol rval.__epydoc_asRoutine = symbol
rval.__module__ = 'tensor' rval.__module__ = 'tensor'
pprint.assign(rval, printing.FunctionPrinter(symbolname))
return rval return rval
...@@ -661,33 +664,34 @@ def cast(t, dtype): ...@@ -661,33 +664,34 @@ def cast(t, dtype):
return mapping[dtype](t) return mapping[dtype](t)
#to be removed as we get the epydoc routine-documenting thing going -JB 20080924 #to be removed as we get the epydoc routine-documenting thing going -JB 20080924
def _conversion(real_value): def _conversion(real_value, name):
__oplist_tag(real_value, 'casting') __oplist_tag(real_value, 'casting')
real_value.__module__='tensor.basic' real_value.__module__='tensor.basic'
pprint.assign(real_value, printing.FunctionPrinter(name))
return real_value return real_value
convert_to_int8 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int8)))) convert_to_int8 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int8))), 'int8')
"""Cast to 8-bit integer""" """Cast to 8-bit integer"""
convert_to_int16 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int16)))) convert_to_int16 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int16))), 'int16')
"""Cast to 16-bit integer""" """Cast to 16-bit integer"""
convert_to_int32 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int32)))) convert_to_int32 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int32))), 'int32')
"""Cast to 32-bit integer""" """Cast to 32-bit integer"""
convert_to_int64 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int64)))) convert_to_int64 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int64))), 'int64')
"""Cast to 64-bit integer""" """Cast to 64-bit integer"""
convert_to_float32 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.float32)))) convert_to_float32 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.float32))), 'float32')
"""Cast to single-precision floating point""" """Cast to single-precision floating point"""
convert_to_float64 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.float64)))) convert_to_float64 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.float64))), 'float64')
"""Cast to double-precision floating point""" """Cast to double-precision floating point"""
convert_to_complex64 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.complex64)))) convert_to_complex64 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.complex64))), 'complex64')
"""Cast to single-precision complex""" """Cast to single-precision complex"""
convert_to_complex128 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.complex128)))) convert_to_complex128 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.complex128))), 'complex128')
"""Cast to double-precision complex""" """Cast to double-precision complex"""
...@@ -713,6 +717,9 @@ class Shape(Op): ...@@ -713,6 +717,9 @@ class Shape(Op):
def shape(a): def shape(a):
pass pass
pprint.assign(shape, printing.MemberPrinter('shape'))
class MaxAndArgmax(Op): class MaxAndArgmax(Op):
"""Calculate the max and argmax over a given axis""" """Calculate the max and argmax over a given axis"""
nin=2 # tensor, axis nin=2 # tensor, axis
...@@ -834,6 +841,9 @@ def abs_(a): ...@@ -834,6 +841,9 @@ def abs_(a):
""" """
pprint.assign(abs_, printing.PatternPrinter(('|%(0)s|', -1000)))
@_scal_elemwise @_scal_elemwise
def exp(a): def exp(a):
"""e^`a`""" """e^`a`"""
...@@ -902,6 +912,8 @@ def second(a, b): ...@@ -902,6 +912,8 @@ def second(a, b):
"""Create a matrix by filling the shape of a with b""" """Create a matrix by filling the shape of a with b"""
fill = second fill = second
pprint.assign(fill, printing.FunctionPrinter('fill'))
@constructor @constructor
def ones_like(model): def ones_like(model):
...@@ -967,10 +979,15 @@ def one(): ...@@ -967,10 +979,15 @@ def one():
"""WRITEME""" """WRITEME"""
return Ones(0)([]) return Ones(0)([])
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Filler) and r.owner.op.value == 0, printing.FunctionPrinter('zeros'))
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Filler) and r.owner.op.value == 1, printing.FunctionPrinter('ones'))
@_redefine(elemwise.Elemwise(scal.identity)) @_redefine(elemwise.Elemwise(scal.identity))
def tensor_copy(a): def tensor_copy(a):
"""Create a duplicate of `a` (with duplicated storage)""" """Create a duplicate of `a` (with duplicated storage)"""
pprint.assign(tensor_copy, printing.IgnorePrinter())
@_redefine(elemwise.Elemwise(scal.identity, inplace_pattern = {0: [0]})) @_redefine(elemwise.Elemwise(scal.identity, inplace_pattern = {0: [0]}))
def view(a): def view(a):
...@@ -981,6 +998,9 @@ def sum(input, axis = None): ...@@ -981,6 +998,9 @@ def sum(input, axis = None):
"""WRITEME""" """WRITEME"""
return elemwise.Sum(axis)(input) return elemwise.Sum(axis)(input)
pprint.assign(Sum(), printing.FunctionPrinter('sum'))
@constructor @constructor
def mean(input, axis = None): def mean(input, axis = None):
"""WRITEME""" """WRITEME"""
...@@ -1043,6 +1063,14 @@ def mod(a, b): ...@@ -1043,6 +1063,14 @@ def mod(a, b):
def pow(a, b): def pow(a, b):
"""elementwise power""" """elementwise power"""
pprint.assign(add, printing.OperatorPrinter('+', -2, 'either'))
pprint.assign(mul, printing.OperatorPrinter('*', -1, 'either'))
pprint.assign(sub, printing.OperatorPrinter('-', -2, 'left'))
pprint.assign(neg, printing.OperatorPrinter('-', 0, 'either'))
pprint.assign(div, printing.OperatorPrinter('/', -1, 'left'))
pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
########################## ##########################
# View Operations # View Operations
...@@ -1214,6 +1242,36 @@ class Subtensor(Op): ...@@ -1214,6 +1242,36 @@ class Subtensor(Op):
return "%s{%s}" % (self.__class__.__name__, ", ".join(indices)) return "%s{%s}" % (self.__class__.__name__, ", ".join(indices))
class SubtensorPrinter:
def process(self, r, pstate):
if r.owner is None:
raise TypeError("Can only print Subtensor.")
elif isinstance(r.owner.op, Subtensor):
idxs = r.owner.op.idx_list
inputs = list(r.owner.inputs)
input = inputs.pop()
sidxs = []
inbrack_pstate = pstate.clone(precedence = -1000)
for entry in idxs:
if isinstance(entry, int):
sidxs.append(str(entry))
elif isinstance(entry, scal.Scalar):
sidxs.append(inbrack_pstate.pprinter.process(inputs.pop()))
elif isinstance(entry, slice):
sidxs.append("%s:%s%s" % ("" if entry.start is None or entry.start == 0 else entry.start,
"" if entry.stop is None or entry.stop == sys.maxint else entry.stop,
"" if entry.step is None else ":%s" % entry.step))
return "%s[%s]" % (pstate.pprinter.process(input, pstate.clone(precedence = 1000)),
", ".join(sidxs))
else:
raise TypeError("Can only print Subtensor.")
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), SubtensorPrinter())
class SetSubtensor(Subtensor): class SetSubtensor(Subtensor):
"""WRITEME""" """WRITEME"""
view_map = {} view_map = {}
...@@ -1474,6 +1532,11 @@ def join(axis, *tensors): ...@@ -1474,6 +1532,11 @@ def join(axis, *tensors):
""" """
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Join),
printing.FunctionPrinter('join'))
@constructor @constructor
def shape_padleft(tensor, n_ones): def shape_padleft(tensor, n_ones):
"""Reshape `tensor` by left-padding the shape with `n_ones` 1s """Reshape `tensor` by left-padding the shape with `n_ones` 1s
...@@ -1630,6 +1693,21 @@ make_lvector = MakeVector(lscalar) ...@@ -1630,6 +1693,21 @@ make_lvector = MakeVector(lscalar)
"""WRITEME""" """WRITEME"""
class MakeVectorPrinter:
def process(self, r, pstate):
if r.owner is None:
raise TypeError("Can only print make_vector.")
elif isinstance(r.owner.op, MakeVector):
return "[%s]" % ", ".join(pstate.pprinter.process(input, pstate.clone(precedence = 1000)) for input in r.owner.inputs)
else:
raise TypeError("Can only print make_vector.")
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, MakeVector), MakeVectorPrinter())
######################### #########################
# Linalg : Dot # Linalg : Dot
######################### #########################
...@@ -1696,6 +1774,7 @@ class Dot(Op): ...@@ -1696,6 +1774,7 @@ class Dot(Op):
def __str__(self): def __str__(self):
return "dot" return "dot"
dot = Dot() dot = Dot()
pprint.assign(dot, printing.OperatorPrinter(printing.special['middle_dot'], -1, 'left'))
class Outer(Op): class Outer(Op):
""" Compute vector-vector outer product """ Compute vector-vector outer product
...@@ -1964,6 +2043,8 @@ class Gemm(Op): ...@@ -1964,6 +2043,8 @@ class Gemm(Op):
gemm = Gemm() gemm = Gemm()
pprint.assign(gemm, printing.FunctionPrinter('gemm'))
######################### #########################
# Gradient # Gradient
......
...@@ -6,6 +6,8 @@ from .. import gof ...@@ -6,6 +6,8 @@ from .. import gof
from ..gof import Op, Apply from ..gof import Op, Apply
from .. import scalar from .. import scalar
from ..scalar import Scalar from ..scalar import Scalar
from .. import printing
from ..printing import pprint
from ..gof.python25 import all from ..gof.python25 import all
from copy import copy from copy import copy
...@@ -182,6 +184,31 @@ class DimShuffle(Op): ...@@ -182,6 +184,31 @@ class DimShuffle(Op):
return DimShuffle(gz.type.broadcastable, grad_order)(gz), return DimShuffle(gz.type.broadcastable, grad_order)(gz),
class DimShufflePrinter:
def __p(self, new_order, pstate, r):
if new_order != () and new_order[0] == 'x':
# return "%s" % self.__p(new_order[1:], pstate, r)
return "[%s]" % self.__p(new_order[1:], pstate, r)
if list(new_order) == range(r.type.ndim):
return pstate.pprinter.process(r)
if list(new_order) == list(reversed(range(r.type.ndim))):
return "%s.T" % pstate.pprinter.process(r)
return "DimShuffle{%s}(%s)" % (", ".join(map(str, new_order)), pstate.pprinter.process(r))
def process(self, r, pstate):
if r.owner is None:
raise TypeError("Can only print DimShuffle.")
elif isinstance(r.owner.op, DimShuffle):
ord = r.owner.op.new_order
return self.__p(ord, pstate, r.owner.inputs[0])
else:
raise TypeError("Can only print DimShuffle.")
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, DimShuffle), DimShufflePrinter())
################ ################
### Elemwise ### ### Elemwise ###
################ ################
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
from basic import _scal_elemwise, _transpose_inplace from basic import _scal_elemwise, _transpose_inplace
from .. import scalar as scal from .. import scalar as scal
import elemwise import elemwise
from .. import printing
from ..printing import pprint
def _scal_inplace(symbol): def _scal_inplace(symbol):
"""Replace a symbol definition with an elementwise version of the corresponding scalar Op""" """Replace a symbol definition with an elementwise version of the corresponding scalar Op"""
...@@ -24,6 +26,16 @@ def _scal_inplace(symbol): ...@@ -24,6 +26,16 @@ def _scal_inplace(symbol):
rval.__epydoc_asRoutine = symbol rval.__epydoc_asRoutine = symbol
rval.__module__ = 'theano.tensor.inplace' rval.__module__ = 'theano.tensor.inplace'
def chk(pstate, r):
if not r.owner:
return False
op = r.owner.op
# print op, rval, r.owner and op == rval
# print op.inplace_pattern, rval.inplace_pattern, op.inplace_pattern == rval.inplace_pattern
# print op.scalar_op, rval.scalar_op, op.scalar_op == rval.scalar_op
return r.owner.op == rval
pprint.assign(chk, printing.FunctionPrinter(symbolname.replace('_inplace', '=')))
return rval return rval
...@@ -132,6 +144,8 @@ def second_inplace(a): ...@@ -132,6 +144,8 @@ def second_inplace(a):
"""Fill `a` with `b`""" """Fill `a` with `b`"""
fill_inplace = second_inplace fill_inplace = second_inplace
pprint.assign(fill_inplace, printing.FunctionPrinter('fill='))
@_scal_inplace @_scal_inplace
def add_inplace(a, b): def add_inplace(a, b):
...@@ -157,7 +171,17 @@ def mod_inplace(a, b): ...@@ -157,7 +171,17 @@ def mod_inplace(a, b):
def pow_inplace(a, b): def pow_inplace(a, b):
"""elementwise power (inplace on `a`)""" """elementwise power (inplace on `a`)"""
pprint.assign(add_inplace, printing.OperatorPrinter('+=', -2, 'either'))
pprint.assign(mul_inplace, printing.OperatorPrinter('*=', -1, 'either'))
pprint.assign(sub_inplace, printing.OperatorPrinter('-=', -2, 'left'))
pprint.assign(neg_inplace, printing.OperatorPrinter('-=', 0, 'either'))
pprint.assign(div_inplace, printing.OperatorPrinter('/=', -1, 'left'))
pprint.assign(pow_inplace, printing.OperatorPrinter('**=', 1, 'right'))
transpose_inplace = _transpose_inplace transpose_inplace = _transpose_inplace
"""WRITEME""" """WRITEME"""
pprint.assign(transpose_inplace, printing.MemberPrinter('T'))
...@@ -31,7 +31,7 @@ def in2out(*local_opts): ...@@ -31,7 +31,7 @@ def in2out(*local_opts):
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c) # gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0) # Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
gemm_pattern_1 = gof.PatternSub((I.sub_inplace, gemm_pattern_1 = gof.PatternSub((T.sub,
'd', 'd',
(T.mul, (T.mul,
dict(pattern = (T.DimShuffle((), ['x', 'x'], inplace = True), 'a'), dict(pattern = (T.DimShuffle((), ['x', 'x'], inplace = True), 'a'),
...@@ -77,7 +77,10 @@ def _insert_inplace_optimizer(env): ...@@ -77,7 +77,10 @@ def _insert_inplace_optimizer(env):
for candidate_input in candidate_inputs: for candidate_input in candidate_inputs:
inplace_pattern = dict(baseline, **{candidate_output: candidate_input}) inplace_pattern = dict(baseline, **{candidate_output: candidate_input})
try: try:
new = Elemwise(op.scalar_op, inplace_pattern).make_node(*node.inputs) new = Elemwise(
op.scalar_op.__class__(
scalar.transfer_type(*[inplace_pattern.get(i, None) for i in xrange(len(node.outputs))])),
inplace_pattern).make_node(*node.inputs)
env.replace_all_validate(zip(node.outputs, new.outputs)) env.replace_all_validate(zip(node.outputs, new.outputs))
except Exception, e: except Exception, e:
continue continue
...@@ -89,7 +92,7 @@ insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer) ...@@ -89,7 +92,7 @@ insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer)
inplace_optimizer = gof.InplaceOptimizer( inplace_optimizer = gof.InplaceOptimizer(
gof.SeqOptimizer(out2in(gemm_pattern_1), gof.SeqOptimizer(out2in(gemm_pattern_1),
out2in(dot_to_gemm), #out2in(dot_to_gemm),
insert_inplace_optimizer, insert_inplace_optimizer,
failure_callback = gof.keep_going)) failure_callback = gof.keep_going))
compile.optdb.register('inplace', inplace_optimizer, 99, 'fast_run') compile.optdb.register('inplace', inplace_optimizer, 99, 'fast_run')
...@@ -537,7 +540,7 @@ def mul_calculate(num, denum, aslist = False): ...@@ -537,7 +540,7 @@ def mul_calculate(num, denum, aslist = False):
return [v] return [v]
return v return v
local_mul_canonizer = Canonizer(T.mul, T.div, T.inv, mul_calculate, False) local_mul_canonizer = Canonizer(T.mul, T.div, T.inv, mul_calculate)
@gof.local_optimizer([T.neg]) @gof.local_optimizer([T.neg])
def local_neg_to_mul(node): def local_neg_to_mul(node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论