提交 fd7e6d59 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

merge

...@@ -7,6 +7,19 @@ Graph optimization ...@@ -7,6 +7,19 @@ Graph optimization
In this section we will define a couple optimizations on doubles. In this section we will define a couple optimizations on doubles.
.. todo::
This tutorial goes way too far under the hood, for someone who just wants
to add yet another pattern to the libraries in tensor.opt for example.
We need another tutorial that covers the decorator syntax, and explains how
to register your optimization right away. That's what you need to get
going.
Later, the rest is more useful for when that decorator syntax type thing
doesn't work. (There are optimizations that don't fit that model).
Global and local optimizations Global and local optimizations
============================== ==============================
...@@ -119,6 +132,11 @@ simplification described above: ...@@ -119,6 +132,11 @@ simplification described above:
simplify = Simplify() simplify = Simplify()
.. todo::
What is add_requirements? Why would we know to do this? Are there other
requirements we might want to know about?
Here's how it works: first, in ``add_requirements``, we add the Here's how it works: first, in ``add_requirements``, we add the
``ReplaceValidate`` :ref:`envfeature` located in ``ReplaceValidate`` :ref:`envfeature` located in
:api:`theano.gof.toolbox`. This feature adds the ``replace_validate`` :api:`theano.gof.toolbox`. This feature adds the ``replace_validate``
...@@ -150,6 +168,7 @@ and :ref:`apply` to get a better understanding of the ...@@ -150,6 +168,7 @@ and :ref:`apply` to get a better understanding of the
pointer-following game you need to get ahold of the nodes of interest pointer-following game you need to get ahold of the nodes of interest
for the simplification (``x``, ``y``, ``z``, ``a``, ``b``, etc.). for the simplification (``x``, ``y``, ``z``, ``a``, ``b``, etc.).
Test time: Test time:
>>> x = double('x') >>> x = double('x')
...@@ -238,6 +257,10 @@ The local version of the above code would be the following: ...@@ -238,6 +257,10 @@ The local version of the above code would be the following:
local_simplify = LocalSimplify() local_simplify = LocalSimplify()
.. todo::
Fix up previous example... it's bad and incomplete.
The definition of transform is the inner loop of the global optimizer, The definition of transform is the inner loop of the global optimizer,
where the node is given as argument. If no changes are to be made, where the node is given as argument. If no changes are to be made,
``False`` must be returned. Else, a list of what to replace the node's ``False`` must be returned. Else, a list of what to replace the node's
...@@ -310,6 +333,9 @@ Theano defines some shortcuts to make LocalOptimizers: ...@@ -310,6 +333,9 @@ Theano defines some shortcuts to make LocalOptimizers:
means that everything we said previously about local optimizers means that everything we said previously about local optimizers
apply: they need to be wrapped in a Navigator, etc. apply: they need to be wrapped in a Navigator, etc.
.. todo::
wtf is a navigator?
When an optimization can be naturally expressed using ``OpSub``, ``OpRemove`` When an optimization can be naturally expressed using ``OpSub``, ``OpRemove``
or ``PatternSub``, it is highly recommended to use them. or ``PatternSub``, it is highly recommended to use them.
...@@ -319,6 +345,7 @@ use constraints, etc. - there's some decent doc at ...@@ -319,6 +345,7 @@ use constraints, etc. - there's some decent doc at
:api:`theano.gof.opt.PatternSub` for those interested) :api:`theano.gof.opt.PatternSub` for those interested)
.. _optdb: .. _optdb:
The optimization database (optdb) The optimization database (optdb)
......
...@@ -1781,36 +1781,44 @@ def div_proxy(x, y): ...@@ -1781,36 +1781,44 @@ def div_proxy(x, y):
return true_div(x, y) return true_div(x, y)
@_scal_elemwise @_scal_elemwise
def add(a, b): def add(a, *other_terms):
"""elementwise addition""" """elementwise addition"""
# see decorator for function body
@_scal_elemwise @_scal_elemwise
def sub(a, b): def sub(a, b):
"""elementwise subtraction""" """elementwise subtraction"""
# see decorator for function body
@_scal_elemwise @_scal_elemwise
def mul(a, b): def mul(a, *other_terms):
"""elementwise multiplication""" """elementwise multiplication"""
# see decorator for function body
@_scal_elemwise @_scal_elemwise
def true_div(a, b): def true_div(a, b):
"""elementwise [true] division (inverse of multiplication)""" """elementwise [true] division (inverse of multiplication)"""
# see decorator for function body
@_scal_elemwise @_scal_elemwise
def int_div(a, b): def int_div(a, b):
"""elementwise integer-division""" """elementwise integer-division"""
# see decorator for function body
@_scal_elemwise @_scal_elemwise
def mod(a, b): def mod(a, b):
"""elementwise modulo""" """elementwise modulo"""
# see decorator for function body
@_scal_elemwise @_scal_elemwise
def pow(a, b): def pow(a, b):
"""elementwise power""" """elementwise power"""
# see decorator for function body
@_scal_elemwise @_scal_elemwise
def clip(x, min, max): def clip(x, min, max):
"""clip x to be between min and max""" """clip x to be between min and max"""
# see decorator for function body
pprint.assign(add, printing.OperatorPrinter('+', -2, 'either')) pprint.assign(add, printing.OperatorPrinter('+', -2, 'either'))
pprint.assign(mul, printing.OperatorPrinter('*', -1, 'either')) pprint.assign(mul, printing.OperatorPrinter('*', -1, 'either'))
...@@ -3007,6 +3015,8 @@ class AdvancedSubtensor(Op): ...@@ -3007,6 +3015,8 @@ class AdvancedSubtensor(Op):
#TODO: see what's the best solution #TODO: see what's the best solution
self.args = args #? self.args = args #?
#FIXME: do not store variables in the class instance
#FIXME #FIXME
#if len(args) != 2: #if len(args) != 2:
# print >>sys.stderr, 'WARNING: Advanced indexing with %i arguments not supported yet' % len(args) # print >>sys.stderr, 'WARNING: Advanced indexing with %i arguments not supported yet' % len(args)
...@@ -3018,6 +3028,11 @@ class AdvancedSubtensor(Op): ...@@ -3018,6 +3028,11 @@ class AdvancedSubtensor(Op):
if x.ndim == 2 and len(inputs) == 2: if x.ndim == 2 and len(inputs) == 2:
ind1 = as_tensor_variable(inputs[0]) ind1 = as_tensor_variable(inputs[0])
ind2 = as_tensor_variable(inputs[1]) ind2 = as_tensor_variable(inputs[1])
if not (ind1.type.dtype.startswith('int') or ind1.type.dtype.startswith('uint')):
raise TypeError()
if not (ind2.type.dtype.startswith('int') or ind2.type.dtype.startswith('uint')):
raise TypeError()
if ind1.ndim == 1 and ind2.ndim == 1: if ind1.ndim == 1 and ind2.ndim == 1:
return gof.Apply(self, return gof.Apply(self,
(x,) + inputs, (x,) + inputs,
...@@ -3029,7 +3044,11 @@ class AdvancedSubtensor(Op): ...@@ -3029,7 +3044,11 @@ class AdvancedSubtensor(Op):
% ','.join(str(input) for input in inputs)) % ','.join(str(input) for input in inputs))
def perform(self, node, inputs, (out,)): def perform(self, node, inputs, (out,)):
pass # TODO: in general, we need to re-pack the inputs into a valid index, just like
# subtensor
out[0] = inputs[0].__getitem__(inputs[1:])
#return
#raise NotImplementedError()
def grad(self, inputs, (gz,)): def grad(self, inputs, (gz,)):
x = inputs[0] x = inputs[0]
...@@ -3061,9 +3080,14 @@ class AdvancedIncSubtensor(Op): ...@@ -3061,9 +3080,14 @@ class AdvancedIncSubtensor(Op):
% ','.join(str(input) for input in inputs)) % ','.join(str(input) for input in inputs))
def perform(self, node, inputs, (out,)): def perform(self, node, inputs, (out,)):
pass # TODO: same thing as in AdvancedSubtensor's perform TODO
out[0] = inputs[0].copy()
out[0][inputs[2:]] += inputs[1]
#def grad? #def grad?
# grad on x is grad on output
# grad on y is grad_output[idx_list]
# grad on rest is None
......
差异被折叠。
...@@ -14,7 +14,8 @@ from elemwise import Elemwise, DimShuffle ...@@ -14,7 +14,8 @@ from elemwise import Elemwise, DimShuffle
from theano import scalar from theano import scalar
import basic as T import basic as T
import inplace as I import inplace as I
import numpy as N import numpy
import numpy as N #guys... please don't do this in the library :(
import operator import operator
import itertools import itertools
import sys, os import sys, os
...@@ -62,7 +63,6 @@ def get_constant_value(v): ...@@ -62,7 +63,6 @@ def get_constant_value(v):
return get_constant_value(v.owner.inputs[0]) return get_constant_value(v.owner.inputs[0])
raise TypeError(v) raise TypeError(v)
@gof.optimizer @gof.optimizer
def insert_inplace_optimizer(env): def insert_inplace_optimizer(env):
""" """
...@@ -108,10 +108,12 @@ compile.optdb.register('inplace_opt', insert_inplace_optimizer, 75, 'fast_run', ...@@ -108,10 +108,12 @@ compile.optdb.register('inplace_opt', insert_inplace_optimizer, 75, 'fast_run',
def register_canonicalize(lopt, *tags, **kwargs): def register_canonicalize(lopt, *tags, **kwargs):
name = (kwargs and kwargs.pop('name')) or lopt.__name__ name = (kwargs and kwargs.pop('name')) or lopt.__name__
compile.optdb['canonicalize'].register(name, lopt, 'fast_run', *tags) compile.optdb['canonicalize'].register(name, lopt, 'fast_run', *tags)
return lopt
def register_specialize(lopt, *tags, **kwargs): def register_specialize(lopt, *tags, **kwargs):
name = (kwargs and kwargs.pop('name')) or lopt.__name__ name = (kwargs and kwargs.pop('name')) or lopt.__name__
compile.optdb['specialize'].register(name, lopt, 'fast_run', *tags) compile.optdb['specialize'].register(name, lopt, 'fast_run', *tags)
return lopt
###################### ######################
# DimShuffle lifters # # DimShuffle lifters #
...@@ -876,10 +878,39 @@ register_canonicalize(local_mul_canonizer, name = 'local_mul_canonizer') ...@@ -876,10 +878,39 @@ register_canonicalize(local_mul_canonizer, name = 'local_mul_canonizer')
def local_neg_to_mul(node): def local_neg_to_mul(node):
if node.op == T.neg: if node.op == T.neg:
return [T.mul(-1, node.inputs[0])] return [T.mul(-1, node.inputs[0])]
else:
return False
register_canonicalize(local_neg_to_mul) register_canonicalize(local_neg_to_mul)
@register_specialize
@gof.local_optimizer([])
def local_sum_mul_by_scalar(node):
"""sum(scalar * smth) -> scalar * sum(smth)
"""
# TODO: if the the thing inside the Sum is a division,
# we should get at the numerator....
if isinstance(node.op, T.Sum):
thing_summed, = node.inputs
if thing_summed.owner and thing_summed.owner.op == T.mul:
terms = thing_summed.owner.inputs
scalars = [t.dimshuffle() for t in terms if numpy.all(t.type.broadcastable)]
non_scalars = [t for t in terms if not numpy.all(t.broadcastable)]
if scalars:
if len(scalars) > 1:
if len(non_scalars) > 1:
return [T.mul(T.mul(*scalars), node.op(T.mul(*non_scalars)))]
elif len(non_scalars) == 1:
return [T.mul(T.mul(*scalars), node.op(non_scalars[0]))]
else:
return [T.mul(*scalars)]
else:
if len(non_scalars) > 1:
return [T.mul(scalars[0], node.op(T.mul(*non_scalars)))]
elif len(non_scalars) == 1:
return [T.mul(scalars[0], node.op(non_scalars[0]))]
else:
return [scalars[0]]
if thing_summed.owner and thing_summed.owner.op == T.neg:
return [T.neg(node.op(thing_summed.owner.inputs[0]))]
@gof.local_optimizer([T.mul]) @gof.local_optimizer([T.mul])
def local_mul_to_neg(node): def local_mul_to_neg(node):
if node.op == T.mul and N.all(local_mul_canonizer.get_constant(node.inputs[0]) == -1.0): if node.op == T.mul and N.all(local_mul_canonizer.get_constant(node.inputs[0]) == -1.0):
...@@ -888,6 +919,16 @@ def local_mul_to_neg(node): ...@@ -888,6 +919,16 @@ def local_mul_to_neg(node):
return False return False
register_specialize(local_mul_to_neg) register_specialize(local_mul_to_neg)
@register_specialize
@gof.local_optimizer([T.neg])
def local_neg_neg(node):
# other specializations shouldn't put this in,
# but sometimes they do
if node.op == T.neg:
if node.inputs[0].owner and node.inputs[0].owner.op == T.neg:
return [node.inputs[0].owner.inputs[0]]
@gof.local_optimizer([T.mul]) @gof.local_optimizer([T.mul])
def local_mul_zero(node): def local_mul_zero(node):
"""As part of canonicalization, we replace multiplication by zero with zero. """As part of canonicalization, we replace multiplication by zero with zero.
......
...@@ -317,6 +317,44 @@ def test_asymptotic_32(): ...@@ -317,6 +317,44 @@ def test_asymptotic_32():
assert gxval[0,1] == 0.25 assert gxval[0,1] == 0.25
def test_get_rid_of_advanced_indexing_version_of_xent():
rng = numpy.random.RandomState(utt.fetch_seed())
x_val = rng.randn(3,5)
y_val = numpy.asarray([2,4,1])
x = T.dmatrix('x')
y = T.lvector('y')
expressions_to_test = [
T.sum(-T.log(softmax(x)[T.arange(y.shape[0]), y])),
-T.sum(T.log(softmax(x)[T.arange(y.shape[0]), y])),
-T.sum(T.log(softmax(x))[T.arange(y.shape[0]), y]),
T.sum(-T.log(softmax(x))[T.arange(y.shape[0]), y])]
def assert_optimizer_worked(expr):
f = theano.function([x,y], expr)
for i, node in enumerate(f.maker.env.toposort()):
print i, node
f(x_val, y_val)
assert len(f.maker.env.toposort()) == 4
for expr in expressions_to_test:
assert_optimizer_worked(expr)
## Gradient wrt x
for expr in expressions_to_test:
grad_x = T.grad(expr, x)
g = theano.function([x, y], grad_x)
for i, node in enumerate(g.maker.env.toposort()):
print i, node
g(x_val, y_val)
assert len(g.maker.env.toposort()) == 4
#TODO: Case with bias
# hint - call local_softmax_with_bias from within the other optimization
# hint - call the argmax push-down optimization first too
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论