提交 f1d61ed9 authored 作者: Olivier Breuleux

added pprint (as printing) to the theano package root

上级 d5ae91e1
......@@ -4,7 +4,8 @@ import theano
from theano import tensor as T
from theano.tensor import nnet_ops
from theano.compile import module
from theano.sandbox import pprint
from theano import printing, pprint
from theano import compile
import numpy as N
......@@ -17,6 +18,7 @@ class LogisticRegressionN(module.FancyModule):
self.w = N.random.randn(n_in, n_out)
self.b = N.random.randn(n_out)
self.lr = 0.01
self.__hide__ = ['params']
def __init__(self, x = None, targ = None):
super(LogisticRegressionN, self).__init__() #boilerplate
......@@ -84,8 +86,8 @@ class LogisticRegression2(module.FancyModule):
if __name__ == '__main__':
pprint.pp.assign(nnet_ops.crossentropy_softmax_1hot_with_bias_dx, pprint.FunctionPrinter('xsoftmaxdx'))
pprint.pp.assign(nnet_ops.crossentropy_softmax_argmax_1hot_with_bias, pprint.FunctionPrinter('nll', 'softmax', 'argmax'))
pprint.assign(nnet_ops.crossentropy_softmax_1hot_with_bias_dx, printing.FunctionPrinter('xsoftmaxdx'))
pprint.assign(nnet_ops.crossentropy_softmax_argmax_1hot_with_bias, printing.FunctionPrinter('nll', 'softmax', 'argmax'))
if 1:
lrc = LogisticRegressionN()
......@@ -94,17 +96,21 @@ if __name__ == '__main__':
print '================'
print lrc.update.pretty(mode = theano.Mode('py', 'fast_run'))
print '================'
# print lrc.update.pretty(mode = compile.FAST_RUN.excluding('inplace'))
# print '================'
# sys.exit(0)
lr = lrc.make(10, 2, mode=theano.Mode('c|py', 'fast_run'))
#lr = lrc.make(10, 2, mode=compile.FAST_RUN.excluding('fast_run'))
#lr = lrc.make(10, 2, mode=theano.Mode('py', 'merge')) #'FAST_RUN')
data_x = N.random.randn(5, 10)
data_y = (N.random.randn(5) > 0)
for i in xrange(10000):
xe = lr.update(data_x, data_y)
lr.lr = 0.02
xe = lr.update(data_x, data_y)
if i % 100 == 0:
print i, xe
......
......@@ -41,13 +41,19 @@ from compile import \
SymbolicOutput, Out, \
Mode, \
predefined_modes, predefined_linkers, predefined_optimizers, \
FunctionMaker, function, OpFromGraph #, eval_outputs, fast_compute
FunctionMaker, function, OpFromGraph, \
Component, External, Member, KitComponent, Method, \
Composite, ComponentList, Module, FancyModule
from printing import \
pprint, pp
import tensor
import scalar
import sparse
import gradient
## import scalar_opt
import subprocess as _subprocess
......
......@@ -102,12 +102,19 @@ class Mode(object):
optimizer = predefined_optimizers[optimizer]
if isinstance(optimizer, gof.Query):
self.provided_optimizer = optimizer
optimizer = optdb.query(optimizer)
self.optimizer = optimizer
self._optimizer = optimizer
def __str__(self):
return "Mode(linker = %s, optimizer = %s)" % (self.provided_linker, self.provided_optimizer)
def __get_optimizer(self):
if isinstance(self._optimizer, gof.Query):
return optdb.query(self._optimizer)
else:
return self._optimizer
optimizer = property(__get_optimizer)
def including(self, *tags):
return Mode(self.provided_linker, self.provided_optimizer.including(*tags))
......
from .. import gof
from ..printing import pprint
from collections import defaultdict
from itertools import chain
from functools import partial
from copy import copy
import mode
import io
import function_module as F
#from ..sandbox import pprint
def join(*args):
......@@ -117,7 +117,7 @@ class External(_RComponent):
def pretty(self, **kwargs):
rval = super(External, self).pretty()
if self.r.owner:
rval += '\n= %s' % (pprint.pp2.process(self.r, dict(target = self.r)))
rval += '\n= %s' % (pprint(self.r, dict(target = self.r)))
return rval
......@@ -196,13 +196,15 @@ class Method(Component):
else:
return gof.Container(r, storage = [None])
inputs = self.inputs
inputs = [mode.In(result = input,
value = get_storage(input))
inputs = [io.In(result = input,
value = get_storage(input),
mutable = False)
for input in inputs]
inputs += [mode.In(result = k,
update = v,
value = get_storage(k, True),
strict = True)
inputs += [io.In(result = k,
update = v,
value = get_storage(k, True),
mutable = True,
strict = True)
for k, v in self.updates.iteritems()]
outputs = self.outputs
_inputs = [x.result for x in inputs]
......@@ -210,8 +212,9 @@ class Method(Component):
+ [x.update for x in inputs if getattr(x, 'update', False)],
blockers = _inputs):
if input not in _inputs and not isinstance(input, gof.Value):
inputs += [mode.In(result = input,
value = get_storage(input, True))]
inputs += [io.In(result = input,
value = get_storage(input, True),
mutable = False)]
inputs += [(kit, get_storage(kit, True)) for kit in self.kits]
return F.function(inputs, outputs, mode)
......@@ -234,11 +237,14 @@ class Method(Component):
nup = len(k)
eff_in = tuple(inputs) + tuple(k)
eff_out = tuple(outputs) + tuple(v)
env = gof.Env(*gof.graph.clone(eff_in + tuple(gof.graph.inputs(eff_out)),
supp_in = tuple(gof.graph.inputs(eff_out))
env = gof.Env(*gof.graph.clone(eff_in + supp_in,
eff_out))
sup = F.Supervisor(set(env.inputs).difference(env.inputs[len(inputs):len(eff_in)]))
env.extend(sup)
mode.optimizer.optimize(env)
inputs, outputs, updates = env.inputs[:nin], env.outputs[:nout], dict(zip(env.inputs[nin:], env.outputs[nout:]))
rval += pprint.pp.process_graph(inputs, outputs, updates, False)
rval += pprint(inputs, outputs, updates, False)
return rval
def __str__(self):
......
......@@ -101,7 +101,7 @@ class ReplaceValidate(History, Validator):
try:
env.replace(r, new_r)
except Exception, e:
if not 'The type of the replacement must be the same' in str(e) or not 'does not belong to this Env' in str(e):
if 'The type of the replacement must be the same' not in str(e) and 'does not belong to this Env' not in str(e):
print >>sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e
env.revert(chk) # this might fail if the error is in a listener: (env.replace kinda needs better internal error handling)
raise
......
from .. import tensor as T
from .. import scalar as S
from .. import gof
import gof
from copy import copy
import sys
......@@ -88,7 +86,7 @@ class FunctionPrinter:
raise TypeError("function %s cannot represent a result with no associated operation" % self.names)
idx = node.outputs.index(output)
name = self.names[idx]
return "%s(%s)" % (name, ", ".join([pprinter.process(input, pstate.clone(precedence = 1000))
return "%s(%s)" % (name, ", ".join([pprinter.process(input, pstate.clone(precedence = -1000))
for input in node.inputs]))
class MemberPrinter:
......@@ -119,65 +117,6 @@ class IgnorePrinter:
return "%s" % pprinter.process(input, pstate)
class DimShufflePrinter:
    """Pretty-printer for results produced by a DimShuffle op.

    Common shuffles get compact notation: the identity permutation prints
    as the input itself, a full axis reversal prints as ``x.T``, and each
    leading broadcastable ('x') axis wraps the output in brackets.
    """
    def __p(self, new_order, pstate, r):
        # Recursively strip leading 'x' (broadcast) axes, adding one pair
        # of brackets per stripped axis.
        if new_order != () and  new_order[0] == 'x':
            # return "%s" % self.__p(new_order[1:], pstate, r)
            return "[%s]" % self.__p(new_order[1:], pstate, r)
        # Identity permutation: print the shuffled input unchanged.
        # (Python 2: range() returns a list, so the comparison is valid.)
        if list(new_order) == range(r.type.ndim):
            return pstate.pprinter.process(r)
        # Full reversal of all axes: print as a transpose.
        if list(new_order) == list(reversed(range(r.type.ndim))):
            return "%s.T" % pstate.pprinter.process(r)
        # Fallback: explicit DimShuffle{order}(input) notation.
        return "DimShuffle{%s}(%s)" % (", ".join(map(str, new_order)), pstate.pprinter.process(r))
    def process(self, r, pstate):
        # Only results owned by a DimShuffle apply node can be printed here.
        if r.owner is None:
            raise TypeError("Can only print DimShuffle.")
        elif isinstance(r.owner.op, T.DimShuffle):
            ord = r.owner.op.new_order
            return self.__p(ord, pstate, r.owner.inputs[0])
        else:
            raise TypeError("Can only print DimShuffle.")
class SubtensorPrinter:
    """Pretty-printer rendering Subtensor (indexing) results with
    Python-style slice notation, e.g. ``x[0, 1:5]``."""
    def process(self, r, pstate):
        if r.owner is None:
            raise TypeError("Can only print Subtensor.")
        elif isinstance(r.owner.op, T.Subtensor):
            idxs = r.owner.op.idx_list
            inputs = list(r.owner.inputs)
            # NOTE(review): pop() takes the *last* owner input as the
            # indexed result -- presumably the symbolic scalar index
            # inputs precede it; confirm against Subtensor's input order.
            input = inputs.pop()
            sidxs = []
            # Index expressions print at minimal precedence so they are
            # never parenthesized inside the brackets.
            inbrack_pstate = pstate.clone(precedence = -1000)
            for entry in idxs:
                if isinstance(entry, int):
                    sidxs.append(str(entry))
                elif isinstance(entry, S.Scalar):
                    # Symbolic scalar index: consumes the next input.
                    sidxs.append(inbrack_pstate.pprinter.process(inputs.pop()))
                elif isinstance(entry, slice):
                    # Omit default bounds (start 0, stop sys.maxint) and
                    # only emit ":step" when a step is present.
                    sidxs.append("%s:%s%s" % ("" if entry.start is None or entry.start == 0 else entry.start,
                                              "" if entry.stop is None or entry.stop == sys.maxint else entry.stop,
                                              "" if entry.step is None else ":%s" % entry.step))
            return "%s[%s]" % (pstate.pprinter.process(input, pstate.clone(precedence = 1000)),
                               ", ".join(sidxs))
        else:
            raise TypeError("Can only print Subtensor.")
class MakeVectorPrinter:
    """Print a MakeVector result as a bracketed, comma-separated list of
    its element expressions, e.g. ``[a, b, c]``."""
    def process(self, r, pstate):
        node = r.owner
        if node is None or not isinstance(node.op, T.MakeVector):
            raise TypeError("Can only print make_vector.")
        # Elements print at high precedence so none of them needs parens;
        # clone the state per element, as each print may use it.
        pieces = [pstate.pprinter.process(input, pstate.clone(precedence = 1000))
                  for input in node.inputs]
        return "[%s]" % ", ".join(pieces)
class DefaultPrinter:
def __init__(self):
......@@ -263,6 +202,16 @@ class PPrinter:
strings.sort()
return "\n".join(s[1] for s in strings)
def __call__(self, *args):
    """Dispatch on argument count: a single result (optionally followed
    by a PrinterState or dict of state) goes to process(); three or more
    arguments go to process_graph()."""
    if len(args) == 1:
        return self.process(*args)
    elif len(args) == 2 and isinstance(args[1], (PrinterState, dict)):
        return self.process(*args)
    elif len(args) > 2:
        return self.process_graph(*args)
    else:
        # NOTE(review): also reached for two args whose second is not a
        # PrinterState/dict -- the message is misleading in that case.
        raise TypeError('Not enough arguments to call.')
......@@ -276,47 +225,10 @@ greek = dict(alpha = u"\u03B1",
epsilon = u"\u03B5")
# Infix printers for the arithmetic operators.
# OperatorPrinter(symbol, precedence, associativity); presumably a higher
# precedence binds tighter and associativity controls where parentheses
# may be omitted -- confirm against OperatorPrinter's implementation.
ppow = OperatorPrinter('**', 1, 'right')
pneg = OperatorPrinter('-', 0, 'either')
pmul = OperatorPrinter('*', -1, 'either')
pdiv = OperatorPrinter('/', -1, 'left')
padd = OperatorPrinter('+', -2, 'either')
psub = OperatorPrinter('-', -2, 'left')
# Unicode glyphs (middle dot, capital sigma) come from the `special` table.
pdot = OperatorPrinter(special['middle_dot'], -1, 'left')
psum = OperatorPrinter(special['big_sigma']+' ', -2, 'left')
from ..tensor import inplace as I
def pprinter():
    """Build and return a PPrinter preconfigured for the standard tensor
    ops: infix arithmetic, dot, sum, fillers, subtensor/make_vector, etc.

    Returns a fresh PPrinter each call; callers may customize it further
    with additional assign() registrations.
    """
    pp = PPrinter()
    # Catch-all fallback, used when no more specific printer matches.
    pp.assign(lambda pstate, r: True, DefaultPrinter())
    pp.assign(T.add, padd)
    pp.assign(T.mul, pmul)
    pp.assign(T.sub, psub)
    pp.assign(T.neg, pneg)
    pp.assign(T.div, pdiv)
    pp.assign(T.pow, ppow)
    pp.assign(T.dot, pdot)
    pp.assign(T.Sum(), FunctionPrinter('sum'))
    pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.DimShuffle), DimShufflePrinter())
    pp.assign(T.tensor_copy, IgnorePrinter())
    pp.assign(T.log, FunctionPrinter('log'))
    pp.assign(T.tanh, FunctionPrinter('tanh'))
    pp.assign(I.transpose_inplace, MemberPrinter('T'))
    pp.assign(T.abs_, PatternPrinter(('|%(0)s|', -1000)))
    pp.assign(T.sgn, FunctionPrinter('sgn'))
    # BUG FIX: the zero-filler printer was registered under the
    # misspelled name 'seros'; it now prints as 'zeros'.
    pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.Filler) and r.owner.op.value == 0, FunctionPrinter('zeros'))
    pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.Filler) and r.owner.op.value == 1, FunctionPrinter('ones'))
    pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.Subtensor), SubtensorPrinter())
    pp.assign(T.shape, MemberPrinter('shape'))
    pp.assign(T.fill, FunctionPrinter('fill'))
    #pp.assign(T.vertical_stack, FunctionPrinter('vstack'))
    pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.MakeVector), MakeVectorPrinter())
    return pp
pp = pprinter()
pp2 = pprinter()
pp2.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is not r and r.name is not None,
LeafPrinter())
# Module-level pretty-printer instance shared by the whole package.
pprint = PPrinter()
# Catch-all fallback for results with no specific printer.
pprint.assign(lambda pstate, r: True, DefaultPrinter())
# Named results other than the current print target render as bare
# leaves (their name) instead of being expanded.
pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is not r and r.name is not None,
              LeafPrinter())
# Backwards-compatible alias for the old name.
pp = pprint
......@@ -252,17 +252,26 @@ def upcast_out(*types):
return Scalar(dtype = Scalar.upcast(*types)),
def same_out(type):
return type,
class transfer_type:
def __init__(self, i):
assert type(i) == int
self.i = i
class transfer_type(gof.utils.object2):
def __init__(self, *transfer):
assert all(type(x) == int for x in transfer)
self.transfer = transfer
def __call__(self, *types):
return types[self.i],
class specific_out:
upcast = upcast_out(*types)
return [upcast if i is None else types[i] for i in self.transfer]
def __eq__(self, other):
return type(self) == type(other) and self.transfer == other.transfer
def __hash__(self):
return hash(self.transfer)
class specific_out(gof.utils.object2):
    """Output-type policy that ignores its input types and always yields
    the fixed tuple of types given at construction."""
    def __init__(self, *spec):
        # The fixed output types, returned regardless of the input types.
        self.spec = spec
    def __call__(self, *types):
        return self.spec
    def __eq__(self, other):
        # Equal only to another specific_out of the exact same class
        # carrying the same spec tuple.
        if type(self) != type(other):
            return False
        return self.spec == other.spec
    def __hash__(self):
        # Consistent with __eq__: hash on the spec tuple.
        return hash(self.spec)
def int_out(*types):
return int64,
def float_out(*types):
......@@ -328,9 +337,10 @@ class ScalarOp(Op):
raise AbstractFunctionError()
def __eq__(self, other):
return type(self) == type(other) \
test = type(self) == type(other) \
and getattr(self, 'output_types_preference', None) \
== getattr(other, 'output_types_preference', None)
return test
def __hash__(self):
return hash(getattr(self, 'output_types_preference', 0))
......
......@@ -20,7 +20,8 @@ import elemwise
from .. import scalar as scal
from ..gof.python25 import partial
from .. import compile
from .. import compile, printing
from ..printing import pprint
### set up the external interface
......@@ -614,6 +615,8 @@ def _scal_elemwise(symbol):
rval.__epydoc_asRoutine = symbol
rval.__module__ = 'tensor'
pprint.assign(rval, printing.FunctionPrinter(symbolname))
return rval
......@@ -661,33 +664,34 @@ def cast(t, dtype):
return mapping[dtype](t)
#to be removed as we get the epydoc routine-documenting thing going -JB 20080924
def _conversion(real_value):
def _conversion(real_value, name):
__oplist_tag(real_value, 'casting')
real_value.__module__='tensor.basic'
pprint.assign(real_value, printing.FunctionPrinter(name))
return real_value
convert_to_int8 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int8))))
convert_to_int8 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int8))), 'int8')
"""Cast to 8-bit integer"""
convert_to_int16 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int16))))
convert_to_int16 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int16))), 'int16')
"""Cast to 16-bit integer"""
convert_to_int32 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int32))))
convert_to_int32 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int32))), 'int32')
"""Cast to 32-bit integer"""
convert_to_int64 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int64))))
convert_to_int64 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.int64))), 'int64')
"""Cast to 64-bit integer"""
convert_to_float32 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.float32))))
convert_to_float32 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.float32))), 'float32')
"""Cast to single-precision floating point"""
convert_to_float64 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.float64))))
convert_to_float64 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.float64))), 'float64')
"""Cast to double-precision floating point"""
convert_to_complex64 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.complex64))))
convert_to_complex64 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.complex64))), 'complex64')
"""Cast to single-precision complex"""
convert_to_complex128 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.complex128))))
convert_to_complex128 = _conversion(elemwise.Elemwise(scal.Identity(scal.specific_out(scal.complex128))), 'complex128')
"""Cast to double-precision complex"""
......@@ -713,6 +717,9 @@ class Shape(Op):
def shape(a):
pass
pprint.assign(shape, printing.MemberPrinter('shape'))
class MaxAndArgmax(Op):
"""Calculate the max and argmax over a given axis"""
nin=2 # tensor, axis
......@@ -834,6 +841,9 @@ def abs_(a):
"""
pprint.assign(abs_, printing.PatternPrinter(('|%(0)s|', -1000)))
@_scal_elemwise
def exp(a):
"""e^`a`"""
......@@ -902,6 +912,8 @@ def second(a, b):
"""Create a matrix by filling the shape of a with b"""
fill = second
pprint.assign(fill, printing.FunctionPrinter('fill'))
@constructor
def ones_like(model):
......@@ -967,10 +979,15 @@ def one():
"""WRITEME"""
return Ones(0)([])
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Filler) and r.owner.op.value == 0, printing.FunctionPrinter('zeros'))
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Filler) and r.owner.op.value == 1, printing.FunctionPrinter('ones'))
@_redefine(elemwise.Elemwise(scal.identity))
def tensor_copy(a):
"""Create a duplicate of `a` (with duplicated storage)"""
pprint.assign(tensor_copy, printing.IgnorePrinter())
@_redefine(elemwise.Elemwise(scal.identity, inplace_pattern = {0: [0]}))
def view(a):
......@@ -981,6 +998,9 @@ def sum(input, axis = None):
"""WRITEME"""
return elemwise.Sum(axis)(input)
pprint.assign(Sum(), printing.FunctionPrinter('sum'))
@constructor
def mean(input, axis = None):
"""WRITEME"""
......@@ -1043,6 +1063,14 @@ def mod(a, b):
def pow(a, b):
"""elementwise power"""
# Register infix printers for the elementwise arithmetic ops.
# OperatorPrinter(symbol, precedence, associativity); presumably higher
# precedence binds tighter -- confirm against printing.OperatorPrinter.
pprint.assign(add, printing.OperatorPrinter('+', -2, 'either'))
pprint.assign(mul, printing.OperatorPrinter('*', -1, 'either'))
pprint.assign(sub, printing.OperatorPrinter('-', -2, 'left'))
pprint.assign(neg, printing.OperatorPrinter('-', 0, 'either'))
pprint.assign(div, printing.OperatorPrinter('/', -1, 'left'))
pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
##########################
# View Operations
......@@ -1214,6 +1242,36 @@ class Subtensor(Op):
return "%s{%s}" % (self.__class__.__name__, ", ".join(indices))
class SubtensorPrinter:
    """Pretty-printer rendering Subtensor (indexing) results with
    Python-style slice notation, e.g. ``x[0, 1:5]``."""
    def process(self, r, pstate):
        if r.owner is None:
            raise TypeError("Can only print Subtensor.")
        elif isinstance(r.owner.op, Subtensor):
            idxs = r.owner.op.idx_list
            inputs = list(r.owner.inputs)
            # NOTE(review): pop() takes the *last* owner input as the
            # indexed result -- presumably the symbolic scalar index
            # inputs precede it; confirm against Subtensor's input order.
            input = inputs.pop()
            sidxs = []
            # Index expressions print at minimal precedence so they are
            # never parenthesized inside the brackets.
            inbrack_pstate = pstate.clone(precedence = -1000)
            for entry in idxs:
                if isinstance(entry, int):
                    sidxs.append(str(entry))
                elif isinstance(entry, scal.Scalar):
                    # Symbolic scalar index: consumes the next input.
                    sidxs.append(inbrack_pstate.pprinter.process(inputs.pop()))
                elif isinstance(entry, slice):
                    # Omit default bounds (start 0, stop sys.maxint) and
                    # only emit ":step" when a step is present.
                    sidxs.append("%s:%s%s" % ("" if entry.start is None or entry.start == 0 else entry.start,
                                              "" if entry.stop is None or entry.stop == sys.maxint else entry.stop,
                                              "" if entry.step is None else ":%s" % entry.step))
            return "%s[%s]" % (pstate.pprinter.process(input, pstate.clone(precedence = 1000)),
                               ", ".join(sidxs))
        else:
            raise TypeError("Can only print Subtensor.")

# Use this printer for any result produced by a Subtensor op.
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), SubtensorPrinter())
class SetSubtensor(Subtensor):
"""WRITEME"""
view_map = {}
......@@ -1474,6 +1532,11 @@ def join(axis, *tensors):
"""
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Join),
printing.FunctionPrinter('join'))
@constructor
def shape_padleft(tensor, n_ones):
"""Reshape `tensor` by left-padding the shape with `n_ones` 1s
......@@ -1630,6 +1693,21 @@ make_lvector = MakeVector(lscalar)
"""WRITEME"""
class MakeVectorPrinter:
    """Print a MakeVector result as a bracketed, comma-separated list of
    its element expressions, e.g. ``[a, b, c]``."""
    def process(self, r, pstate):
        if r.owner is None:
            raise TypeError("Can only print make_vector.")
        elif isinstance(r.owner.op, MakeVector):
            # Elements print at high precedence so none needs parentheses.
            return "[%s]" % ", ".join(pstate.pprinter.process(input, pstate.clone(precedence = 1000)) for input in r.owner.inputs)
        else:
            raise TypeError("Can only print make_vector.")

# Use this printer for any result produced by a MakeVector op.
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, MakeVector), MakeVectorPrinter())
#########################
# Linalg : Dot
#########################
......@@ -1696,6 +1774,7 @@ class Dot(Op):
def __str__(self):
return "dot"
dot = Dot()
pprint.assign(dot, printing.OperatorPrinter(printing.special['middle_dot'], -1, 'left'))
class Outer(Op):
""" Compute vector-vector outer product
......@@ -1964,6 +2043,8 @@ class Gemm(Op):
gemm = Gemm()
pprint.assign(gemm, printing.FunctionPrinter('gemm'))
#########################
# Gradient
......
......@@ -6,6 +6,8 @@ from .. import gof
from ..gof import Op, Apply
from .. import scalar
from ..scalar import Scalar
from .. import printing
from ..printing import pprint
from ..gof.python25 import all
from copy import copy
......@@ -182,6 +184,31 @@ class DimShuffle(Op):
return DimShuffle(gz.type.broadcastable, grad_order)(gz),
class DimShufflePrinter:
    """Pretty-printer for results produced by a DimShuffle op.

    Common shuffles get compact notation: the identity permutation prints
    as the input itself, a full axis reversal prints as ``x.T``, and each
    leading broadcastable ('x') axis wraps the output in brackets.
    """
    def __p(self, new_order, pstate, r):
        # Recursively strip leading 'x' (broadcast) axes, adding one pair
        # of brackets per stripped axis.
        if new_order != () and  new_order[0] == 'x':
            # return "%s" % self.__p(new_order[1:], pstate, r)
            return "[%s]" % self.__p(new_order[1:], pstate, r)
        # Identity permutation: print the shuffled input unchanged.
        # (Python 2: range() returns a list, so the comparison is valid.)
        if list(new_order) == range(r.type.ndim):
            return pstate.pprinter.process(r)
        # Full reversal of all axes: print as a transpose.
        if list(new_order) == list(reversed(range(r.type.ndim))):
            return "%s.T" % pstate.pprinter.process(r)
        # Fallback: explicit DimShuffle{order}(input) notation.
        return "DimShuffle{%s}(%s)" % (", ".join(map(str, new_order)), pstate.pprinter.process(r))
    def process(self, r, pstate):
        # Only results owned by a DimShuffle apply node can be printed here.
        if r.owner is None:
            raise TypeError("Can only print DimShuffle.")
        elif isinstance(r.owner.op, DimShuffle):
            ord = r.owner.op.new_order
            return self.__p(ord, pstate, r.owner.inputs[0])
        else:
            raise TypeError("Can only print DimShuffle.")

# Use this printer for any result produced by a DimShuffle op.
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, DimShuffle), DimShufflePrinter())
################
### Elemwise ###
################
......
......@@ -2,6 +2,8 @@
from basic import _scal_elemwise, _transpose_inplace
from .. import scalar as scal
import elemwise
from .. import printing
from ..printing import pprint
def _scal_inplace(symbol):
"""Replace a symbol definition with an elementwise version of the corresponding scalar Op"""
......@@ -24,6 +26,16 @@ def _scal_inplace(symbol):
rval.__epydoc_asRoutine = symbol
rval.__module__ = 'theano.tensor.inplace'
def chk(pstate, r):
    """Match results produced by this exact inplace Elemwise op (the
    enclosing `rval`), so the printer below applies only to it."""
    if not r.owner:
        return False
    # Op equality covers both the scalar op and the inplace pattern.
    return r.owner.op == rval
# Print e.g. `add_inplace(a, b)` as `add=(a, b)`.
pprint.assign(chk, printing.FunctionPrinter(symbolname.replace('_inplace', '=')))
return rval
......@@ -132,6 +144,8 @@ def second_inplace(a):
"""Fill `a` with `b`"""
fill_inplace = second_inplace
pprint.assign(fill_inplace, printing.FunctionPrinter('fill='))
@_scal_inplace
def add_inplace(a, b):
......@@ -157,7 +171,17 @@ def mod_inplace(a, b):
def pow_inplace(a, b):
"""elementwise power (inplace on `a`)"""
pprint.assign(add_inplace, printing.OperatorPrinter('+=', -2, 'either'))
pprint.assign(mul_inplace, printing.OperatorPrinter('*=', -1, 'either'))
pprint.assign(sub_inplace, printing.OperatorPrinter('-=', -2, 'left'))
pprint.assign(neg_inplace, printing.OperatorPrinter('-=', 0, 'either'))
pprint.assign(div_inplace, printing.OperatorPrinter('/=', -1, 'left'))
pprint.assign(pow_inplace, printing.OperatorPrinter('**=', 1, 'right'))
transpose_inplace = _transpose_inplace
"""WRITEME"""
pprint.assign(transpose_inplace, printing.MemberPrinter('T'))
......@@ -31,7 +31,7 @@ def in2out(*local_opts):
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
gemm_pattern_1 = gof.PatternSub((I.sub_inplace,
gemm_pattern_1 = gof.PatternSub((T.sub,
'd',
(T.mul,
dict(pattern = (T.DimShuffle((), ['x', 'x'], inplace = True), 'a'),
......@@ -77,7 +77,10 @@ def _insert_inplace_optimizer(env):
for candidate_input in candidate_inputs:
inplace_pattern = dict(baseline, **{candidate_output: candidate_input})
try:
new = Elemwise(op.scalar_op, inplace_pattern).make_node(*node.inputs)
new = Elemwise(
op.scalar_op.__class__(
scalar.transfer_type(*[inplace_pattern.get(i, None) for i in xrange(len(node.outputs))])),
inplace_pattern).make_node(*node.inputs)
env.replace_all_validate(zip(node.outputs, new.outputs))
except Exception, e:
continue
......@@ -89,7 +92,7 @@ insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer)
inplace_optimizer = gof.InplaceOptimizer(
gof.SeqOptimizer(out2in(gemm_pattern_1),
out2in(dot_to_gemm),
#out2in(dot_to_gemm),
insert_inplace_optimizer,
failure_callback = gof.keep_going))
compile.optdb.register('inplace', inplace_optimizer, 99, 'fast_run')
......@@ -537,7 +540,7 @@ def mul_calculate(num, denum, aslist = False):
return [v]
return v
local_mul_canonizer = Canonizer(T.mul, T.div, T.inv, mul_calculate, False)
local_mul_canonizer = Canonizer(T.mul, T.div, T.inv, mul_calculate)
@gof.local_optimizer([T.neg])
def local_neg_to_mul(node):
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论