提交 950634de authored 作者: Olivier Breuleux's avatar Olivier Breuleux

Deployed optdb

上级 32cce9b4
...@@ -4,6 +4,7 @@ import theano ...@@ -4,6 +4,7 @@ import theano
from theano import tensor as T from theano import tensor as T
from theano.sandbox import nnet_ops from theano.sandbox import nnet_ops
from theano.sandbox import module from theano.sandbox import module
from theano.sandbox import pprint
import numpy as N import numpy as N
...@@ -32,6 +33,10 @@ class LogisticRegressionN(module.FancyModule): ...@@ -32,6 +33,10 @@ class LogisticRegressionN(module.FancyModule):
xent, y = nnet_ops.crossentropy_softmax_1hot( xent, y = nnet_ops.crossentropy_softmax_1hot(
T.dot(self.x, self.w) + self.b, self.targ) T.dot(self.x, self.w) + self.b, self.targ)
xent = T.sum(xent)
self.y = y
self.xent = xent
gparams = T.grad(xent, self.params) gparams = T.grad(xent, self.params)
...@@ -75,15 +80,25 @@ class LogisticRegression2(module.FancyModule): ...@@ -75,15 +80,25 @@ class LogisticRegression2(module.FancyModule):
self.update = module.Method([self.x, self.targ], xent, self.update = module.Method([self.x, self.targ], xent,
updates = dict((p, p - self.lr * g) for p, g in zip(self.params, gparams))) updates = dict((p, p - self.lr * g) for p, g in zip(self.params, gparams)))
self.apply = module.Method([self.x], T.argmax(T.dot(self.x, self.w) + self.b, axis=1)) self.apply = module.Method([self.x], T.argmax(T.dot(self.x, self.w) + self.b, axis=1))
if __name__ == '__main__': if __name__ == '__main__':
pprint.pp.assign(nnet_ops.crossentropy_softmax_1hot_with_bias_dx, pprint.FunctionPrinter('xsoftmaxdx'))
pprint.pp.assign(nnet_ops.crossentropy_softmax_argmax_1hot_with_bias, pprint.FunctionPrinter('nll', 'softmax', 'argmax'))
if 1: if 1:
lrc = LogisticRegressionN() lrc = LogisticRegressionN()
#lr = lrc.make(10, 2, mode='FAST_RUN') print '================'
lr = lrc.make(10, 2, mode=theano.Mode('c|py', 'merge')) #'FAST_RUN') print lrc.update.pretty()
print '================'
print lrc.update.pretty(mode = theano.Mode('py', 'fast_run'))
print '================'
# sys.exit(0)
lr = lrc.make(10, 2, mode=theano.Mode('py', 'fast_run'))
#lr = lrc.make(10, 2, mode=theano.Mode('py', 'merge')) #'FAST_RUN')
data_x = N.random.randn(5, 10) data_x = N.random.randn(5, 10)
data_y = (N.random.randn(5) > 0) data_y = (N.random.randn(5) > 0)
......
...@@ -53,7 +53,7 @@ class Supervisor: ...@@ -53,7 +53,7 @@ class Supervisor:
return True return True
for r in self.protected + list(env.outputs): for r in self.protected + list(env.outputs):
if env.destroyers(r): if env.destroyers(r):
raise gof.InconsistencyError("Trying to destroy a protected Result.") raise gof.InconsistencyError("Trying to destroy a protected Result.", r)
def std_env(input_specs, output_specs, accept_inplace = False): def std_env(input_specs, output_specs, accept_inplace = False):
...@@ -88,7 +88,7 @@ def std_env(input_specs, output_specs, accept_inplace = False): ...@@ -88,7 +88,7 @@ def std_env(input_specs, output_specs, accept_inplace = False):
break break
# We need to protect all immutable inputs from inplace operations. # We need to protect all immutable inputs from inplace operations.
env.extend(Supervisor(input for spec, input in zip(input_specs, inputs) if not spec.mutable)) env.extend(Supervisor(input for spec, input in zip(input_specs, inputs) if not (spec.mutable or (hasattr(env, 'destroyers') and env.destroyers(input)))))
return env, map(SymbolicOutput, updates) return env, map(SymbolicOutput, updates)
......
...@@ -16,6 +16,9 @@ def check_equal(x, y): ...@@ -16,6 +16,9 @@ def check_equal(x, y):
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y}) raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y})
# If a string is passed as the linker argument in the constructor for # If a string is passed as the linker argument in the constructor for
# Mode, it will be used as the key to retrieve the real linker in this # Mode, it will be used as the key to retrieve the real linker in this
# dictionary # dictionary
...@@ -35,12 +38,22 @@ def register_linker(name, linker): ...@@ -35,12 +38,22 @@ def register_linker(name, linker):
predefined_linkers[name] = linker predefined_linkers[name] = linker
# If a string is passed as the optimizer argument in the constructor # If a string is passed as the optimizer argument in the constructor
# for Mode, it will be used as the key to retrieve the real optimizer # for Mode, it will be used as the key to retrieve the real optimizer
# in this dictionary # in this dictionary
# Optimizer queries offered as predefined optimizers.  Mode.__init__ turns any
# gof.Query it is given into a concrete optimizer by calling optdb.query().
# Query whose include list is ['fast_run'].
OPT_FAST_RUN = gof.Query(include = ['fast_run'])
# OPT_FAST_RUN narrowed with requiring('stable').
OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring('stable')
# Query whose include list is ['fast_compile'].
OPT_FAST_COMPILE = gof.Query(include = ['fast_compile'])
predefined_optimizers = { predefined_optimizers = {
None : lambda env: None, None : lambda env: None,
'merge' : gof.MergeOptimizer(), 'merge' : gof.MergeOptimizer(),
'fast_run' : OPT_FAST_RUN,
'fast_run_stable' : OPT_FAST_RUN_STABLE,
'fast_compile' : OPT_FAST_COMPILE
} }
default_optimizer = 'merge' default_optimizer = 'merge'
...@@ -50,6 +63,12 @@ def register_optimizer(name, opt): ...@@ -50,6 +63,12 @@ def register_optimizer(name, opt):
raise ValueError('Optimizer name already taken: %s' % name) raise ValueError('Optimizer name already taken: %s' % name)
predefined_optimizers[name] = opt predefined_optimizers[name] = opt
# Global database of graph optimizations.  Each entry is registered with a
# name, an optimizer (or sub-database), a number and the tags it answers to.
# NOTE(review): the number is presumably the position in the sequence --
# confirm against gof.SequenceDB.
optdb = gof.SequenceDB()
optdb.register('merge1', gof.MergeOptimizer(), 0, 'fast_run', 'fast_compile')
# 'canonicalize' and 'specialize' start out empty; other modules fill them
# through optdb['canonicalize'].register(...) / optdb['specialize'].register(...).
optdb.register('canonicalize', gof.EquilibriumDB(), 1, 'fast_run')
optdb.register('specialize', gof.EquilibriumDB(), 2, 'fast_run')
# Closing merge pass.  This used to be an empty gof.EquilibriumDB() that
# nothing ever registered into, so the 'merge2' stage did nothing; it is now a
# MergeOptimizer like 'merge1'.
optdb.register('merge2', gof.MergeOptimizer(), 100, 'fast_run')
class Mode(object): class Mode(object):
""" """
...@@ -81,15 +100,32 @@ class Mode(object): ...@@ -81,15 +100,32 @@ class Mode(object):
self.linker = linker self.linker = linker
if isinstance(optimizer, str) or optimizer is None: if isinstance(optimizer, str) or optimizer is None:
optimizer = predefined_optimizers[optimizer] optimizer = predefined_optimizers[optimizer]
if isinstance(optimizer, gof.Query):
self.provided_optimizer = optimizer
optimizer = optdb.query(optimizer)
self.optimizer = optimizer self.optimizer = optimizer
def __str__(self): def __str__(self):
return "Mode(linker = %s, optimizer = %s)" % (self.provided_linker, self.provided_optimizer) return "Mode(linker = %s, optimizer = %s)" % (self.provided_linker, self.provided_optimizer)
def including(self, *tags):
    """Return a new Mode with the same provided linker and an extended optimizer.

    The new optimizer is self.provided_optimizer.including(*tags).
    NOTE(review): assumes provided_optimizer offers an including() method
    (e.g. a gof.Query) -- confirm for other optimizer kinds.
    """
    widened = self.provided_optimizer.including(*tags)
    return Mode(self.provided_linker, widened)
def excluding(self, *tags):
    """Return a new Mode with the same provided linker and a reduced optimizer.

    The new optimizer is self.provided_optimizer.excluding(*tags).
    NOTE(review): assumes provided_optimizer offers an excluding() method
    (e.g. a gof.Query) -- confirm for other optimizer kinds.
    """
    narrowed = self.provided_optimizer.excluding(*tags)
    return Mode(self.provided_linker, narrowed)
def requiring(self, *tags):
    """Return a new Mode with the same provided linker and a stricter optimizer.

    The new optimizer is self.provided_optimizer.requiring(*tags).
    NOTE(review): assumes provided_optimizer offers a requiring() method
    (gof.Query does, see OPT_FAST_RUN_STABLE) -- confirm for other kinds.
    """
    restricted = self.provided_optimizer.requiring(*tags)
    return Mode(self.provided_linker, restricted)
# If a string is passed as the mode argument in function or # If a string is passed as the mode argument in function or
# FunctionMaker, the Mode will be taken from this dictionary using the # FunctionMaker, the Mode will be taken from this dictionary using the
# string as the key # string as the key
predefined_modes = {'FAST_COMPILE': Mode('py', 'merge')}
# Ready-made modes, each built from a linker name and an optimizer name.
# 'py' linker with the 'fast_compile' optimizer.
FAST_COMPILE = Mode('py', 'fast_compile')
# 'c|py' linker with the 'fast_run' optimizer.
FAST_RUN = Mode('c|py', 'fast_run')
# Lookup table used when a mode is requested by name (string key -> Mode).
predefined_modes = {'FAST_COMPILE': FAST_COMPILE,
'FAST_RUN': FAST_RUN}
default_mode = 'FAST_COMPILE' default_mode = 'FAST_COMPILE'
def register_mode(name, mode): def register_mode(name, mode):
......
...@@ -70,7 +70,7 @@ class Component(object): ...@@ -70,7 +70,7 @@ class Component(object):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def pretty(self): def pretty(self, **kwargs):
raise NotImplementedError raise NotImplementedError
def __get_name__(self): def __get_name__(self):
...@@ -99,7 +99,7 @@ class _RComponent(Component): ...@@ -99,7 +99,7 @@ class _RComponent(Component):
def __str__(self): def __str__(self):
return "%s(%s)" % (self.__class__.__name__, self.r) return "%s(%s)" % (self.__class__.__name__, self.r)
def pretty(self): def pretty(self, **kwargs):
rval = '%s :: %s' % (self.__class__.__name__, self.r.type) rval = '%s :: %s' % (self.__class__.__name__, self.r.type)
return rval return rval
...@@ -114,7 +114,7 @@ class External(_RComponent): ...@@ -114,7 +114,7 @@ class External(_RComponent):
def build(self, mode, memo): def build(self, mode, memo):
return None return None
def pretty(self): def pretty(self, **kwargs):
rval = super(External, self).pretty() rval = super(External, self).pretty()
if self.r.owner: if self.r.owner:
rval += '\n= %s' % (pprint.pp2.process(self.r, dict(target = self.r))) rval += '\n= %s' % (pprint.pp2.process(self.r, dict(target = self.r)))
...@@ -216,7 +216,7 @@ class Method(Component): ...@@ -216,7 +216,7 @@ class Method(Component):
inputs += [(kit, get_storage(kit, True)) for kit in self.kits] inputs += [(kit, get_storage(kit, True)) for kit in self.kits]
return compile.function(inputs, outputs, mode) return compile.function(inputs, outputs, mode)
def pretty(self, header = True, **kwargs): def pretty(self, **kwargs):
self.resolve_all() self.resolve_all()
# cr = '\n ' if header else '\n' # cr = '\n ' if header else '\n'
# rval = '' # rval = ''
...@@ -226,7 +226,20 @@ class Method(Component): ...@@ -226,7 +226,20 @@ class Method(Component):
rval = 'inputs: %s\n' % ", ".join(map(str, self.inputs)) rval = 'inputs: %s\n' % ", ".join(map(str, self.inputs))
else: else:
rval = '' rval = ''
rval += pprint.pp.process_graph(self.inputs, self.outputs, self.updates, False) mode = kwargs.pop('mode', None)
inputs, outputs, updates = self.inputs, self.outputs if isinstance(self.outputs, (list, tuple)) else [self.outputs], self.updates
if mode:
nin = len(inputs)
nout = len(outputs)
k, v = zip(*updates.items()) if updates else ((), ())
nup = len(k)
eff_in = tuple(inputs) + tuple(k)
eff_out = tuple(outputs) + tuple(v)
env = gof.Env(*gof.graph.clone(eff_in + tuple(gof.graph.inputs(eff_out)),
eff_out))
mode.optimizer.optimize(env)
inputs, outputs, updates = env.inputs[:nin], env.outputs[:nout], dict(zip(env.inputs[nin:], env.outputs[nout:]))
rval += pprint.pp.process_graph(inputs, outputs, updates, False)
return rval return rval
def __str__(self): def __str__(self):
...@@ -395,13 +408,13 @@ class ComponentList(Composite): ...@@ -395,13 +408,13 @@ class ComponentList(Composite):
def __str__(self): def __str__(self):
return str(self._components) return str(self._components)
def pretty(self, header = True, **kwargs): def pretty(self, **kwargs):
cr = '\n ' #if header else '\n' cr = '\n ' #if header else '\n'
strings = [] strings = []
#if header: #if header:
# rval += "ComponentList:" # rval += "ComponentList:"
for i, c in self.components_map(): for i, c in self.components_map():
strings.append('%i:%s%s' % (i, cr, c.pretty().replace('\n', cr))) strings.append('%i:%s%s' % (i, cr, c.pretty(**kwargs).replace('\n', cr)))
#rval += cr + '%i -> %s' % (i, c.pretty(header = True, **kwargs).replace('\n', cr)) #rval += cr + '%i -> %s' % (i, c.pretty(header = True, **kwargs).replace('\n', cr))
return '\n'.join(strings) return '\n'.join(strings)
...@@ -469,7 +482,7 @@ class Module(Composite): ...@@ -469,7 +482,7 @@ class Module(Composite):
value.bind(self, item) value.bind(self, item)
self._components[item] = value self._components[item] = value
def pretty(self, header = True, **kwargs): def pretty(self, **kwargs):
cr = '\n ' #if header else '\n' cr = '\n ' #if header else '\n'
strings = [] strings = []
# if header: # if header:
...@@ -477,7 +490,7 @@ class Module(Composite): ...@@ -477,7 +490,7 @@ class Module(Composite):
for name, component in self.components_map(): for name, component in self.components_map():
if name.startswith('_'): if name.startswith('_'):
continue continue
strings.append('%s:%s%s' % (name, cr, component.pretty().replace('\n', cr))) strings.append('%s:%s%s' % (name, cr, component.pretty(**kwargs).replace('\n', cr)))
strings.sort() strings.sort()
return '\n'.join(strings) return '\n'.join(strings)
......
...@@ -416,7 +416,8 @@ class CrossentropySoftmaxArgmax1HotWithBias(theano.Op): ...@@ -416,7 +416,8 @@ class CrossentropySoftmaxArgmax1HotWithBias(theano.Op):
if g_sm is not None or g_am is not None: if g_sm is not None or g_am is not None:
raise NotImplementedError() raise NotImplementedError()
nll, sm = crossentropy_softmax_1hot_with_bias(x, b, y_idx) nll, sm = crossentropy_softmax_1hot_with_bias(x, b, y_idx)
dx = CrossentropySoftmax1HotWithBiasDx()(g_nll, sm, y_idx) #dx = CrossentropySoftmax1HotWithBiasDx()(g_nll, sm, y_idx)
dx = crossentropy_softmax_1hot_with_bias_dx(g_nll, sm, y_idx)
db = tensor.sum(dx, axis = [0]) db = tensor.sum(dx, axis = [0])
return dx, db, None return dx, db, None
...@@ -597,6 +598,9 @@ class CrossentropySoftmax1HotWithBiasDx (theano.Op): ...@@ -597,6 +598,9 @@ class CrossentropySoftmax1HotWithBiasDx (theano.Op):
crossentropy_softmax_argmax_1hot_with_bias = \ crossentropy_softmax_argmax_1hot_with_bias = \
CrossentropySoftmaxArgmax1HotWithBias() CrossentropySoftmaxArgmax1HotWithBias()
crossentropy_softmax_1hot_with_bias_dx = \
CrossentropySoftmax1HotWithBiasDx()
def crossentropy_softmax_1hot_with_bias(x, b, y_idx, **kwargs): def crossentropy_softmax_1hot_with_bias(x, b, y_idx, **kwargs):
return crossentropy_softmax_argmax_1hot_with_bias(x, b, y_idx, **kwargs)[0:2] return crossentropy_softmax_argmax_1hot_with_bias(x, b, y_idx, **kwargs)[0:2]
......
...@@ -88,7 +88,7 @@ class FunctionPrinter: ...@@ -88,7 +88,7 @@ class FunctionPrinter:
raise TypeError("function %s cannot represent a result with no associated operation" % self.names) raise TypeError("function %s cannot represent a result with no associated operation" % self.names)
idx = node.outputs.index(output) idx = node.outputs.index(output)
name = self.names[idx] name = self.names[idx]
return "%s(%s)" % (name, ", ".join([pprinter.process(input, pstate.clone(precedence = -1000)) return "%s(%s)" % (name, ", ".join([pprinter.process(input, pstate.clone(precedence = 1000))
for input in node.inputs])) for input in node.inputs]))
class MemberPrinter: class MemberPrinter:
...@@ -123,8 +123,8 @@ class DimShufflePrinter: ...@@ -123,8 +123,8 @@ class DimShufflePrinter:
def __p(self, new_order, pstate, r): def __p(self, new_order, pstate, r):
if new_order != () and new_order[0] == 'x': if new_order != () and new_order[0] == 'x':
return "%s" % self.__p(new_order[1:], pstate, r) # return "%s" % self.__p(new_order[1:], pstate, r)
# return "[%s]" % self.__p(new_order[1:], pstate, r) return "[%s]" % self.__p(new_order[1:], pstate, r)
if list(new_order) == range(r.type.ndim): if list(new_order) == range(r.type.ndim):
return pstate.pprinter.process(r) return pstate.pprinter.process(r)
if list(new_order) == list(reversed(range(r.type.ndim))): if list(new_order) == list(reversed(range(r.type.ndim))):
...@@ -161,7 +161,7 @@ class SubtensorPrinter: ...@@ -161,7 +161,7 @@ class SubtensorPrinter:
sidxs.append("%s:%s%s" % ("" if entry.start is None or entry.start == 0 else entry.start, sidxs.append("%s:%s%s" % ("" if entry.start is None or entry.start == 0 else entry.start,
"" if entry.stop is None or entry.stop == sys.maxint else entry.stop, "" if entry.stop is None or entry.stop == sys.maxint else entry.stop,
"" if entry.step is None else ":%s" % entry.step)) "" if entry.step is None else ":%s" % entry.step))
return "%s[%s]" % (pstate.clone(precedence = 1000).pprinter.process(input), return "%s[%s]" % (pstate.pprinter.process(input, pstate.clone(precedence = 1000)),
", ".join(sidxs)) ", ".join(sidxs))
else: else:
raise TypeError("Can only print Subtensor.") raise TypeError("Can only print Subtensor.")
...@@ -173,7 +173,7 @@ class MakeVectorPrinter: ...@@ -173,7 +173,7 @@ class MakeVectorPrinter:
if r.owner is None: if r.owner is None:
raise TypeError("Can only print make_vector.") raise TypeError("Can only print make_vector.")
elif isinstance(r.owner.op, T.MakeVector): elif isinstance(r.owner.op, T.MakeVector):
return "[%s]" % ", ".join(pstate.clone(precedence = 1000).pprinter.process(input) for input in r.owner.inputs) return "[%s]" % ", ".join(pstate.pprinter.process(input, pstate.clone(precedence = 1000)) for input in r.owner.inputs)
else: else:
raise TypeError("Can only print make_vector.") raise TypeError("Can only print make_vector.")
...@@ -311,7 +311,7 @@ def pprinter(): ...@@ -311,7 +311,7 @@ def pprinter():
pp.assign(T.shape, MemberPrinter('shape')) pp.assign(T.shape, MemberPrinter('shape'))
pp.assign(T.fill, FunctionPrinter('fill')) pp.assign(T.fill, FunctionPrinter('fill'))
#pp.assign(T.vertical_stack, FunctionPrinter('vstack')) #pp.assign(T.vertical_stack, FunctionPrinter('vstack'))
#pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.MakeVector), MakeVectorPrinter()) pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.MakeVector), MakeVectorPrinter())
return pp return pp
pp = pprinter() pp = pprinter()
......
...@@ -1198,8 +1198,11 @@ class Subtensor(Op): ...@@ -1198,8 +1198,11 @@ class Subtensor(Op):
return type(self) == type(other) and self.idx_list == other.idx_list return type(self) == type(other) and self.idx_list == other.idx_list
def __hash__(self): def __hash__(self):
# FIXME: this doesn't work if there are slices in the list because for some mysterious reason slice is unhashable idx_list = tuple((entry.start, entry.stop, entry.step)
return hash(tuple(self.idx_list)) if isinstance(entry, slice)
else entry
for entry in self.idx_list)
return hash(idx_list)
def __str__(self): def __str__(self):
indices = [] indices = []
...@@ -1605,27 +1608,28 @@ if 0: #vertical and horizontal stacking are deprecated. Better to use stack() a ...@@ -1605,27 +1608,28 @@ if 0: #vertical and horizontal stacking are deprecated. Better to use stack() a
return gz[:xs[0]], gz[xs[0]:] return gz[:xs[0]], gz[xs[0]:]
vertical_stack = VerticalStack() vertical_stack = VerticalStack()
class MakeVector(Op):
"""WRITEME"""
def __init__(self, stype):
self.stype = stype
def make_node(self, *inputs):
inputs = map(as_tensor, inputs)
assert all(a.type == self.stype for a in inputs)
return Apply(self, inputs, [Tensor(broadcastable = (False,),
dtype = self.stype.dtype)()])
def perform(self, node, inputs, (out,)):
out[0] = numpy.asarray(inputs)
def grad(self, inputs, (gout,)):
return [None]*len(inputs)
make_lvector = MakeVector(lscalar)
"""WRITEME"""
else: else:
pass pass
class MakeVector(Op):
    """Op that gathers its inputs into one tensor output.

    Every input must have the type `stype` passed to the constructor.  The
    output type is a Tensor with broadcastable pattern (False,) and the
    dtype of `stype`.
    """

    def __init__(self, stype):
        # Type that every input of a node built by this op must have.
        self.stype = stype

    def make_node(self, *inputs):
        """Build the Apply node for `inputs` after converting them with as_tensor."""
        inputs = [as_tensor(entry) for entry in inputs]
        for entry in inputs:
            assert entry.type == self.stype
        out_type = Tensor(broadcastable = (False,),
                          dtype = self.stype.dtype)
        return Apply(self, inputs, [out_type()])

    def perform(self, node, inputs, output_storage):
        """Store numpy.asarray(inputs) in the single output cell."""
        out, = output_storage
        out[0] = numpy.asarray(inputs)

    def grad(self, inputs, output_gradients):
        """Return None for every input."""
        gout, = output_gradients  # exactly one output gradient is expected
        return [None] * len(inputs)

make_lvector = MakeVector(lscalar)
"""MakeVector instance whose inputs must all be of type lscalar."""
######################### #########################
# Linalg : Dot # Linalg : Dot
######################### #########################
......
...@@ -87,17 +87,25 @@ def _insert_inplace_optimizer(env): ...@@ -87,17 +87,25 @@ def _insert_inplace_optimizer(env):
break break
insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer) insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer)
# In-place stage: a SeqOptimizer made of out2in(gemm_pattern_1),
# out2in(dot_to_gemm) and insert_inplace_optimizer, wrapped in
# gof.InplaceOptimizer.
# NOTE(review): failure_callback = gof.keep_going presumably lets the sequence
# continue when one of its members fails -- confirm against gof.
inplace_optimizer = gof.InplaceOptimizer(
gof.SeqOptimizer(out2in(gemm_pattern_1),
out2in(dot_to_gemm),
insert_inplace_optimizer,
failure_callback = gof.keep_going))
# Registered in the global optdb as 'inplace', number 99, tag 'fast_run'.
compile.optdb.register('inplace', inplace_optimizer, 99, 'fast_run')
inplace_optimizer = gof.SeqOptimizer(out2in(gemm_pattern_1),
out2in(dot_to_gemm),
insert_inplace_optimizer)
def register_canonicalize(lopt, *tags, **kwargs):
    """Register `lopt` in the 'canonicalize' entry of compile.optdb.

    lopt -- the optimizer to register.
    tags -- extra tags to register it under, in addition to 'fast_run'.
    name -- (keyword only) registration name; when omitted or empty,
            lopt.__name__ is used instead.

    Raises TypeError if a keyword other than `name` is passed.
    """
    # The old `kwargs and kwargs.pop('name')` popped 'name' whenever kwargs was
    # non-empty, so any other keyword surfaced as a confusing KeyError: 'name'.
    # `or` keeps the short-circuit: lopt.__name__ is only read when no name
    # was supplied (instances passed with name=... may not have __name__).
    name = kwargs.pop('name', None) or lopt.__name__
    if kwargs:
        raise TypeError("register_canonicalize() got unexpected keyword argument(s): %s"
                        % ", ".join(sorted(kwargs)))
    compile.optdb['canonicalize'].register(name, lopt, 'fast_run', *tags)
def register_specialize(lopt, *tags, **kwargs):
    """Register `lopt` in the 'specialize' entry of compile.optdb.

    lopt -- the optimizer to register.
    tags -- extra tags to register it under, in addition to 'fast_run'.
    name -- (keyword only) registration name; when omitted or empty,
            lopt.__name__ is used instead.

    Raises TypeError if a keyword other than `name` is passed.
    """
    # The old `kwargs and kwargs.pop('name')` popped 'name' whenever kwargs was
    # non-empty, so any other keyword surfaced as a confusing KeyError: 'name'.
    # `or` keeps the short-circuit: lopt.__name__ is only read when no name
    # was supplied (instances passed with name=... may not have __name__).
    name = kwargs.pop('name', None) or lopt.__name__
    if kwargs:
        raise TypeError("register_specialize() got unexpected keyword argument(s): %s"
                        % ", ".join(sorted(kwargs)))
    compile.optdb['specialize'].register(name, lopt, 'fast_run', *tags)
###################### ######################
# DimShuffle lifters # # DimShuffle lifters #
###################### ######################
@gof.local_optimizer @gof.local_optimizer([None, None])
def local_dimshuffle_lift(node): def local_dimshuffle_lift(node):
""" """
"Lifts" DimShuffle through Elemwise operations and merges "Lifts" DimShuffle through Elemwise operations and merges
...@@ -129,14 +137,15 @@ def local_dimshuffle_lift(node): ...@@ -129,14 +137,15 @@ def local_dimshuffle_lift(node):
else: else:
return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs
dimshuffle_lift = out2in(local_dimshuffle_lift) register_canonicalize(local_dimshuffle_lift)
################# #################
# Shape lifters # # Shape lifters #
################# #################
@gof.local_optimizer @gof.local_optimizer([T.shape, None])
def local_shape_lift_elemwise(node): def local_shape_lift_elemwise(node):
""" """
shape(elemwise_op(..., x, ...)) -> shape(x) shape(elemwise_op(..., x, ...)) -> shape(x)
...@@ -155,7 +164,10 @@ def local_shape_lift_elemwise(node): ...@@ -155,7 +164,10 @@ def local_shape_lift_elemwise(node):
return False return False
@gof.local_optimizer register_canonicalize(local_shape_lift_elemwise)
@gof.local_optimizer([T.shape, None])
def local_shape_lift_sum(node): def local_shape_lift_sum(node):
""" """
shape(sum{n}(x)) -> [shape(x)[0], ..., shape(x)[n-1], shape(x)[n+1], ...] shape(sum{n}(x)) -> [shape(x)[0], ..., shape(x)[n-1], shape(x)[n+1], ...]
...@@ -173,7 +185,10 @@ def local_shape_lift_sum(node): ...@@ -173,7 +185,10 @@ def local_shape_lift_sum(node):
return T.make_lvector.make_node(*(ish[i] for i in xrange(input.type.ndim) if i not in axis)).outputs return T.make_lvector.make_node(*(ish[i] for i in xrange(input.type.ndim) if i not in axis)).outputs
# return T.vertical_stack.make_node(ish[:axis], ish[axis+1:]).outputs # return T.vertical_stack.make_node(ish[:axis], ish[axis+1:]).outputs
@gof.local_optimizer register_canonicalize(local_shape_lift_sum)
@gof.local_optimizer([T.shape, T.dot])
def local_shape_lift_dot(node): def local_shape_lift_dot(node):
""" """
shape(dot(a, b)) -> [shape(a)[0], shape(b)[1]] shape(dot(a, b)) -> [shape(a)[0], shape(b)[1]]
...@@ -183,9 +198,12 @@ def local_shape_lift_dot(node): ...@@ -183,9 +198,12 @@ def local_shape_lift_dot(node):
a, b = node.inputs[0].owner.inputs a, b = node.inputs[0].owner.inputs
return T.make_lvector.make_node(T.shape(a)[0], T.shape(b)[1]).outputs return T.make_lvector.make_node(T.shape(a)[0], T.shape(b)[1]).outputs
local_shape_lift = opt.LocalOptGroup(local_shape_lift_elemwise, register_canonicalize(local_shape_lift_dot)
local_shape_lift_sum,
local_shape_lift_dot)
# local_shape_lift = opt.LocalOptGroup(local_shape_lift_elemwise,
# local_shape_lift_sum,
# local_shape_lift_dot)
################ ################
...@@ -201,7 +219,7 @@ def encompasses_broadcastable(b1, b2): ...@@ -201,7 +219,7 @@ def encompasses_broadcastable(b1, b2):
def merge_broadcastables(broadcastables): def merge_broadcastables(broadcastables):
return [all(bcast) for bcast in zip(*broadcastables)] return [all(bcast) for bcast in zip(*broadcastables)]
@gof.local_optimizer @gof.local_optimizer([T.fill, None])
def local_fill_lift(node): def local_fill_lift(node):
""" """
fill(f(a), b) -> fill(a, b) fill(f(a), b) -> fill(a, b)
...@@ -217,10 +235,10 @@ def local_fill_lift(node): ...@@ -217,10 +235,10 @@ def local_fill_lift(node):
mb, fb = model.type.broadcastable, filling.type.broadcastable mb, fb = model.type.broadcastable, filling.type.broadcastable
if model.type.dtype == filling.type.dtype and encompasses_broadcastable(fb, mb): if model.type.dtype == filling.type.dtype and encompasses_broadcastable(fb, mb):
return [filling] return False# [filling]
parent = model.owner parent = model.owner
if parent is None: if parent is None or not isinstance(parent, T.Elemwise):
return False return False
for input in parent.inputs: for input in parent.inputs:
if input.type == model.type: if input.type == model.type:
...@@ -228,13 +246,15 @@ def local_fill_lift(node): ...@@ -228,13 +246,15 @@ def local_fill_lift(node):
return False return False
register_canonicalize(local_fill_lift)
################## ##################
# Subtensor opts # # Subtensor opts #
################## ##################
@gof.local_optimizer @gof.local_optimizer([None, None])
def local_subtensor_make_vector(node): def local_subtensor_make_vector(node):
""" """
[a,b,c][0] -> a [a,b,c][0] -> a
...@@ -242,7 +262,7 @@ def local_subtensor_make_vector(node): ...@@ -242,7 +262,7 @@ def local_subtensor_make_vector(node):
If the index or slice is constant. If the index or slice is constant.
""" """
if not opt.check_chain(node, T.Subtensor, T.Join): if not opt.check_chain(node, T.Subtensor, T.MakeVector):
return False return False
joined_r = node.inputs[0] joined_r = node.inputs[0]
...@@ -263,13 +283,15 @@ def local_subtensor_make_vector(node): ...@@ -263,13 +283,15 @@ def local_subtensor_make_vector(node):
return T.make_vector(*(node.owner.inputs[0].owner.inputs.__getslice__(idx))) return T.make_vector(*(node.owner.inputs[0].owner.inputs.__getslice__(idx)))
except TypeError: except TypeError:
return False return False
register_canonicalize(local_subtensor_make_vector)
################## ##################
# Middleman cuts # # Middleman cuts #
################## ##################
@gof.local_optimizer @gof.local_optimizer([None, T.fill])
def local_fill_cut(node): def local_fill_cut(node):
""" """
f(fill(a,b), c) -> f(b, c) f(fill(a,b), c) -> f(b, c)
...@@ -301,7 +323,10 @@ def local_fill_cut(node): ...@@ -301,7 +323,10 @@ def local_fill_cut(node):
return False return False
return node.op.make_node(*new_inputs).outputs return node.op.make_node(*new_inputs).outputs
@gof.local_optimizer register_canonicalize(local_fill_cut)
@gof.local_optimizer([None, T.fill])
def local_fill_sink(node): def local_fill_sink(node):
""" """
f(fill(a, b), fill(c, d), e) -> fill(a, fill(c, f(b, d, e))) f(fill(a, b), fill(c, d), e) -> fill(a, fill(c, f(b, d, e)))
...@@ -323,6 +348,8 @@ def local_fill_sink(node): ...@@ -323,6 +348,8 @@ def local_fill_sink(node):
c = T.fill(model, c) c = T.fill(model, c)
return [c] return [c]
register_canonicalize(local_fill_sink)
################ ################
# Canonization # # Canonization #
...@@ -367,15 +394,23 @@ class Canonizer(gof.LocalOptimizer): ...@@ -367,15 +394,23 @@ class Canonizer(gof.LocalOptimizer):
2 * x / 2 -> x 2 * x / 2 -> x
""" """
def __init__(self, main, inverse, reciprocal, calculate): def __init__(self, main, inverse, reciprocal, calculate, use_reciprocal = True):
self.main = main self.main = main
self.inverse = inverse self.inverse = inverse
self.reciprocal = reciprocal self.reciprocal = reciprocal
self.calculate = calculate self.calculate = calculate
self.use_reciprocal = use_reciprocal
def tracks(self):
    """Return the patterns this optimizer tracks.

    One two-element list [op, None] is produced for each of self.main,
    self.inverse and self.reciprocal, in that order.
    NOTE(review): None presumably acts as a wildcard for the second
    position -- confirm against the tracks() contract in gof.
    """
    return [[op, None] for op in (self.main, self.inverse, self.reciprocal)]
def get_num_denum(self, input): def get_num_denum(self, input):
if input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]: if input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]:
return [input], [] if input.owner and isinstance(input.owner.op, T.DimShuffle):
return self.get_num_denum(input.owner.inputs[0])
else:
return [input], []
num = [] num = []
denum = [] denum = []
parent = input.owner parent = input.owner
...@@ -396,7 +431,10 @@ class Canonizer(gof.LocalOptimizer): ...@@ -396,7 +431,10 @@ class Canonizer(gof.LocalOptimizer):
if not ln and not ld: if not ln and not ld:
return T.as_tensor(self.calculate([], [])) return T.as_tensor(self.calculate([], []))
if not ln: if not ln:
return self.reciprocal(self.merge_num_denum(denum, [])) if self.use_reciprocal:
return self.reciprocal(self.merge_num_denum(denum, []))
else:
ln = [self.calculate([], [], aslist = False)]
if not ld: if not ld:
if ln == 1: if ln == 1:
if isinstance(num[0], gof.Result): if isinstance(num[0], gof.Result):
...@@ -444,10 +482,13 @@ class Canonizer(gof.LocalOptimizer): ...@@ -444,10 +482,13 @@ class Canonizer(gof.LocalOptimizer):
dcc += 1 dcc += 1
denum.remove(v) denum.remove(v)
denumct.append(ct) denumct.append(ct)
ct = self.calculate(numct, denumct, aslist = True) if self.use_reciprocal:
ct = self.calculate(numct, denumct, aslist = True)
else:
ct = [self.calculate(numct, denumct, aslist = False)]
# if len(ct) and ncc == 1 and dcc == 0: # if len(ct) and ncc == 1 and dcc == 0:
# return orig_num, orig_denum # return orig_num, orig_denum
if orig_num and ct == self.get_constant(orig_num[0]): if orig_num and N.all(ct == self.get_constant(orig_num[0])):
return orig_num, orig_denum return orig_num, orig_denum
return ct + num, denum return ct + num, denum
...@@ -471,14 +512,21 @@ class Canonizer(gof.LocalOptimizer): ...@@ -471,14 +512,21 @@ class Canonizer(gof.LocalOptimizer):
num, denum = list(orig_num), list(orig_denum) num, denum = list(orig_num), list(orig_denum)
num, denum = self.simplify(num, denum) num, denum = self.simplify(num, denum)
if not reorg and orig_num == num and orig_denum == denum: def same(x, y):
return len(x) == len(y) and all(N.all(xe == ye) for xe, ye in zip(x, y))
if not reorg and same(orig_num, num) and same(orig_denum, denum):
return False return False
new = self.merge_num_denum(num, denum) new = self.merge_num_denum(num, denum)
if new.type != out.type: if new.type != out.type:
new = T.fill(out, new) #new = T.fill(out, new)
new = T.fill(out, T.Elemwise(scalar.Identity(scalar.specific_out(getattr(scalar, out.type.dtype))))(new))
return [new] return [new]
def __str__(self):
    """Return self.name when it is set, else 'Canonizer(main, inverse, reciprocal)'."""
    # The fallback is formatted unconditionally, exactly as before, so main,
    # inverse and reciprocal are always read even when a name exists.
    fallback = 'Canonizer(%s, %s, %s)' % (self.main, self.inverse, self.reciprocal)
    return getattr(self, 'name', fallback)
def mul_calculate(num, denum, aslist = False): def mul_calculate(num, denum, aslist = False):
v = reduce(N.multiply, num, 1.0) / reduce(N.multiply, denum, 1.0) v = reduce(N.multiply, num, 1.0) / reduce(N.multiply, denum, 1.0)
...@@ -489,28 +537,42 @@ def mul_calculate(num, denum, aslist = False): ...@@ -489,28 +537,42 @@ def mul_calculate(num, denum, aslist = False):
return [v] return [v]
return v return v
local_mul_canonizer = Canonizer(T.mul, T.div, T.inv, mul_calculate) local_mul_canonizer = Canonizer(T.mul, T.div, T.inv, mul_calculate, False)
@gof.local_optimizer @gof.local_optimizer([T.neg])
def local_neg_to_mul(node): def local_neg_to_mul(node):
if node.op == T.neg: if node.op == T.neg:
return [-1.0 * node.inputs[0]] return [-1 * node.inputs[0]]
else: else:
return False return False
@gof.local_optimizer @gof.local_optimizer([T.mul])
def local_mul_to_neg(node): def local_mul_to_neg(node):
if node.op == T.mul and local_mul_canonizer.get_constant(node.inputs[0]) == -1.0: if node.op == T.mul and N.all(local_mul_canonizer.get_constant(node.inputs[0]) == -1.0):
return [-local_mul_canonizer.merge_num_denum(node.inputs[1:], [])] return [-local_mul_canonizer.merge_num_denum(node.inputs[1:], [])]
else: else:
return False return False
neg_to_mul = out2in(gof.LocalOptGroup(local_neg_to_mul)) @gof.local_optimizer([T.div])
mul_to_neg = out2in(gof.LocalOptGroup(local_mul_to_neg)) def local_div_to_inv(node):
if node.op == T.div and N.all(local_mul_canonizer.get_constant(node.inputs[0]) == 1.0):
return [T.inv(local_mul_canonizer.merge_num_denum(node.inputs[1:], []))]
else:
return False
register_canonicalize(local_neg_to_mul)
register_specialize(local_mul_to_neg)
register_specialize(local_div_to_inv)
register_canonicalize(local_mul_canonizer, name = 'local_mul_canonizer')
# neg_to_mul = out2in(gof.LocalOptGroup(local_neg_to_mul))
# mul_to_neg = out2in(gof.LocalOptGroup(local_mul_to_neg))
mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut, local_fill_sink)) mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut, local_fill_sink))
def add_calculate(num, denum, aslist = False): def add_calculate(num, denum, aslist = False):
v = reduce(N.add, num, 0.0) - reduce(N.add, denum, 0.0) v = reduce(N.add, num, 0.0) - reduce(N.add, denum, 0.0)
if aslist: if aslist:
...@@ -523,6 +585,8 @@ def add_calculate(num, denum, aslist = False): ...@@ -523,6 +585,8 @@ def add_calculate(num, denum, aslist = False):
local_add_canonizer = Canonizer(T.add, T.sub, T.neg, add_calculate) local_add_canonizer = Canonizer(T.add, T.sub, T.neg, add_calculate)
add_canonizer = in2out(gof.LocalOptGroup(local_add_canonizer, local_fill_cut, local_fill_sink)) add_canonizer = in2out(gof.LocalOptGroup(local_add_canonizer, local_fill_cut, local_fill_sink))
register_canonicalize(local_add_canonizer, name = 'local_add_canonizer')
################## ##################
# Distributivity # # Distributivity #
...@@ -583,7 +647,8 @@ def attempt_distribution(factor, num, denum): ...@@ -583,7 +647,8 @@ def attempt_distribution(factor, num, denum):
list(itertools.starmap(local_mul_canonizer.merge_num_denum, pos_pairs)), list(itertools.starmap(local_mul_canonizer.merge_num_denum, pos_pairs)),
list(itertools.starmap(local_mul_canonizer.merge_num_denum, neg_pairs))), num, denum list(itertools.starmap(local_mul_canonizer.merge_num_denum, neg_pairs))), num, denum
@gof.local_optimizer @gof.local_optimizer([T.mul, T.add, T.mul], [T.mul, T.sub, T.mul],
[T.mul, T.add, T.div], [T.mul, T.sub, T.div])
def local_greedy_distributor(node): def local_greedy_distributor(node):
""" """
This optimization tries to apply distributivity of multiplication This optimization tries to apply distributivity of multiplication
...@@ -638,41 +703,48 @@ def local_greedy_distributor(node): ...@@ -638,41 +703,48 @@ def local_greedy_distributor(node):
return [local_mul_canonizer.merge_num_denum(new_num, new_denum)] return [local_mul_canonizer.merge_num_denum(new_num, new_denum)]
register_canonicalize(local_greedy_distributor)
def _math_optimizer():
pass_1 = in2out(local_fill_sink) # def _math_optimizer():
pass_2 = out2in(local_dimshuffle_lift, local_shape_lift, local_fill_lift)#, local_fill_cut) # pass_1 = in2out(local_fill_sink)
pass_3 = out2in(local_subtensor_make_vector, local_fill_cut) # pass_2 = out2in(local_dimshuffle_lift, local_shape_lift, local_fill_lift)#, local_fill_cut)
# pass_3 = out2in(local_subtensor_make_vector, local_fill_cut)
canonizer = in2out(local_add_canonizer, # canonizer = in2out(local_add_canonizer,
local_mul_canonizer, # local_mul_canonizer,
local_fill_sink) # local_fill_sink)
# pass_4 = out2in(local_greedy_distributor)
# return gof.SeqOptimizer(pass_1,
# pass_2,
# pass_3,
# neg_to_mul,
# canonizer,
# pass_4,
# mul_to_neg)
# math_optimizer = _math_optimizer()
pass_4 = out2in(local_greedy_distributor)
return gof.SeqOptimizer(pass_1,
pass_2,
pass_3,
neg_to_mul,
canonizer,
pass_4,
mul_to_neg)
math_optimizer = _math_optimizer()
compile.register_optimizer('math', # compile.register_optimizer('math',
gof.MergeOptMerge( # gof.MergeOptMerge(
gof.PureThenInplaceOptimizer( # gof.PureThenInplaceOptimizer(
math_optimizer, # math_optimizer,
inplace_optimizer))) # inplace_optimizer)))
compile.register_mode('SANITY_CHECK', compile.Mode('c&py', 'math')) # compile.register_mode('SANITY_CHECK', compile.Mode('c&py', 'math'))
compile.register_mode('FAST_RUN', compile.Mode('c|py', 'math')) # compile.register_mode('FAST_RUN', compile.Mode('c|py', 'math'))
compile.register_mode('EXPENSIVE_OPTIMIZATIONS', compile.Mode('c|py', 'math')) # compile.register_mode('EXPENSIVE_OPTIMIZATIONS', compile.Mode('c|py', 'math'))
# @gof.local_optimizer # @gof.local_optimizer
......
...@@ -1808,7 +1808,7 @@ class T_op_cache(unittest.TestCase): ...@@ -1808,7 +1808,7 @@ class T_op_cache(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
if len(sys.argv) >= 2 and sys.argv[1] == 'OPT': if len(sys.argv) >= 2 and sys.argv[1] == 'OPT':
default_mode = compile.Mode(linker = 'c&py', default_mode = compile.Mode(linker = 'c&py',
optimizer = 'math') optimizer = 'fast_run')
sys.argv[1:] = sys.argv[2:] sys.argv[1:] = sys.argv[2:]
if 1: if 1:
unittest.main() unittest.main()
......
...@@ -9,6 +9,7 @@ from theano import tensor ...@@ -9,6 +9,7 @@ from theano import tensor
from theano.tensor import Tensor from theano.tensor import Tensor
from theano.gof import Env from theano.gof import Env
from theano.tensor.elemwise import DimShuffle from theano.tensor.elemwise import DimShuffle
from theano.sandbox import pprint
import numpy import numpy
#import scalar_opt #import scalar_opt
...@@ -129,12 +130,13 @@ class test_canonize(unittest.TestCase): ...@@ -129,12 +130,13 @@ class test_canonize(unittest.TestCase):
# e = (a * b) / (b * c) / (c * d) # e = (a * b) / (b * c) / (c * d)
# e = 2 * x / 2 # e = 2 * x / 2
# e = x / y / x # e = x / y / x
e = (x / x) * (y / y) # e = (x / x) * (y / y)
e = (-1 * x) / y / (-2 * z)
g = Env([x, y, z, a, b, c, d], [e]) g = Env([x, y, z, a, b, c, d], [e])
##print pprint.pp.process(g.outputs[0]) print pprint.pp.process(g.outputs[0])
mul_canonizer.optimize(g) mul_canonizer.optimize(g)
gof.TopoOptimizer(gof.LocalOptGroup(local_fill_cut, local_fill_lift), order = 'out_to_in').optimize(g) gof.TopoOptimizer(gof.LocalOptGroup(local_fill_cut, local_fill_lift), order = 'out_to_in').optimize(g)
##print pprint.pp.process(g.outputs[0]) print pprint.pp.process(g.outputs[0])
# def test_plusmin(self): # def test_plusmin(self):
# x, y, z = inputs() # x, y, z = inputs()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论