提交 55c5d0b3 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

merge

......@@ -211,12 +211,49 @@ for this somewhere in the future.
Local optimization
------------------
The local version of the above code would be the following:
.. code-block:: python
class LocalSimplify(gof.LocalOptimizer):
def transform(self, node):
if node.op == div:
x, y = node.inputs
if x.owner and x.owner.op == mul:
a, b = x.owner.inputs
if y == a:
return [b]
elif y == b:
return [a]
return False
local_simplify = LocalSimplify()
The definition of transform is the inner loop of the global optimizer,
where the node is given as argument. If no changes are to be made,
False must be returned. Else, a list of what to replace the node's
outputs with must be returned.
In order to apply the local optimizer we must use it in conjunction
with a :ref:`navigator`. You can follow this :ref:`link <navigator>`
for further documentation, but basically a Navigator is a global
optimizer that loops through all nodes in the graph (or a well-defined
subset of them) and applies one or several local optimizers on them.
>>> x = double('x')
>>> y = double('y')
>>> z = double('z')
>>> a = add(z, mul(div(mul(y, x), y), div(z, x)))
>>> e = gof.Env([x, y, z], [a])
>>> e
[add(z, mul(div(mul(y, x), y), div(z, x)))]
>>> simplify = gof.TopoOptimizer([local_simplify])
>>> simplify.optimize(e)
>>> e
[add(z, mul(x, div(z, x)))]
TODO: test this.
The optimization database (optdb)
......
......@@ -381,19 +381,21 @@ class Canonizer(gof.LocalOptimizer):
Usage: Canonizer(main, inverse, reciprocal, calculate)
* main: a suitable Op class that is commutative, associative and takes
one to an arbitrary number of inputs, e.g. Add or Mul
* main: a suitable Op class that is commutative, associative and
takes one to an arbitrary number of inputs, e.g. add or
mul
* inverse: an Op class such that inverse(main(x, y), y) == x
e.g. Sub or Div
* reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y)
e.g. Neg or Inv
* calculate: function that takes a list of numpy.ndarray instances for
the numerator, another list for the denumerator, and calculates
inverse(main(*num), main(*denum)). It takes a keyword argument,
aslist. If True, the value should be returned as a list of one
element, unless the value is such that value = main(). In that
case, the return value should be an empty list.
e.g. sub or div
* reciprocal: a function such that main(x, reciprocal(y)) ==
inverse(x, y) e.g. neg or inv
* calculate: function that takes a list of numpy.ndarray instances
for the numerator, another list for the denumerator,
and calculates inverse(main(*num), main(*denum)). It
takes a keyword argument, aslist. If True, the value
should be returned as a list of one element, unless
the value is such that value = main(). In that case,
the return value should be an empty list.
The result is a local_optimizer. It is best used with a TopoOptimizer in
in_to_out order.
......@@ -422,38 +424,114 @@ class Canonizer(gof.LocalOptimizer):
self.use_reciprocal = use_reciprocal
def tracks(self):
#return [[None], [None, None], [None]*3, [None]*4, [None]*5]
return [[self.main, None], [self.inverse, None], [self.reciprocal, None]]
def get_num_denum(self, input):
"""
This extract two lists, num and denum, such that the input is:
self.inverse(self.main(*num), self.main(*denum)). It returns
the two lists in a (num, denum) pair.
For example, for main, inverse and reciprocal = *, / and inv(),
input -> returned value (num, denum)
x*y -> ([x, y], [])
inv(x) -> ([], [x])
inv(x) * inv(y) -> ([], [x, y])
x*y/z -> ([x, y], [z])
log(x) / y * (z + x) / y -> ([log(x), z + x], [y, y])
(((a / b) * c) / d) -> ([a, c], [b, d])
a / (b / c) -> ([a, c], [b])
log(x) -> ([log(x)], [])
x**y -> ([x**y], [])
"""
if input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]:
if input.owner and isinstance(input.owner.op, T.DimShuffle):
dsn = input.owner
dsop = dsn.op
dsi0 = dsn.inputs[0]
# If input is a DimShuffle of some input which does something like this:
# * change a vector of length N into a 1xN row matrix
# * change a scalar into a 1x1x1 tensor
# * in general, complete the shape of a tensor with broadcastable 1s to the *left*
# Then we will simply discard the DimShuffle and return the num/denum of its input
dsn = input.owner # dimshuffle node
dsop = dsn.op # dimshuffle op
dsi0 = dsn.inputs[0] # the first input of the dimshuffle i.e. the ndarray to redim
# The compatible order is a DimShuffle "new_order" of the form:
# ('x', ..., 'x', 0, 1, 2, ..., dimshuffle_input.type.ndim)
# That kind of DimShuffle only adds broadcastable
# dimensions on the left, without discarding any
# existing broadcastable dimension and is inserted
# automatically by Elemwise when the inputs have
# different numbers of dimensions (hence why we can
# discard its information - we know we can retrieve it
# later on).
compatible_order = ('x',) * (input.type.ndim - dsi0.type.ndim) + tuple(range(dsi0.type.ndim))
if dsop.new_order == compatible_order:
# If the "new_order" is the one we recognize,
# we return the num_denum of the dimshuffled input.
return self.get_num_denum(input.owner.inputs[0])
else:
# This is when the input isn't produced by main, inverse or reciprocal.
return [input], []
else:
return [input], []
num = []
denum = []
parent = input.owner
# We get the (num, denum) pairs for each input
pairs = [self.get_num_denum(input) for input in parent.inputs]
if parent.op == self.main:
# If we have main(x, y), numx, denumx, numy and denumy
# then num is concat(numx, numy) and denum is concat(denumx, denumy)
# note that main() can have any number of arguments >= 0
# concat is list concatenation
num = reduce(list.__iadd__, map(operator.itemgetter(0), pairs))
denum = reduce(list.__iadd__, map(operator.itemgetter(1), pairs))
elif parent.op == self.inverse:
# If we have inverse(x, y), numx, denumx, numy and denumy
# then num is concat(numx, denumy) and denum is concat(denumx, numy)
# note that inverse() is binary
num = pairs[0][0] + pairs[1][1]
denum = pairs[0][1] + pairs[1][0]
elif parent.op == self.reciprocal:
# If we have reciprocal(x), numx, denumx
# then num is denumx and denum is numx
# note that reciprocal() is unary
num = pairs[0][1]
denum = pairs[0][0]
return num, denum
def merge_num_denum(self, num, denum):
"""
Utility function which takes two lists, num and denum, and
returns something which is equivalent to inverse(main(*num),
main(*denum)), but depends on the length of num and the length
of denum (in order to minimize the number of operations).
Let n = len(num) and d = len(denum):
n=0, d=0: neutral element (given by self.calculate([], []))
(for example, this would be 0 if main is addition
and 1 if main is multiplication)
n=1, d=0: num[0]
n=0, d=1: reciprocal(denum[0])
n=1, d=1: inverse(num[0], denum[0])
n=0, d>1: reciprocal(main(*denum))
n>1, d=0: main(*num)
n=1, d>1: inverse(num[0], main(*denum))
n>1, d=1: inverse(main(*num), denum[0])
n>1, d>1: inverse(main(*num), main(*denum))
Given the values of n and d to which they are associated, all
of the above are equivalent to:
inverse(main(*num), main(*denum))
"""
ln, ld = len(num), len(denum)
if not ln and not ld:
return T.as_tensor(self.calculate([], []))
......@@ -475,20 +553,52 @@ class Canonizer(gof.LocalOptimizer):
@classmethod
def get_constant(cls, v):
"""
Returns a numeric constant if v is a gof.Constant or, well, a
numeric constant. If v is a plain Result, returns None.
"""
if isinstance(v, N.generic):
return v
return v # doesn't the not hasattr() condition below catch this?
if isinstance(v, gof.Constant):
return v.data
if not hasattr(v, 'owner'):
return v
if v.owner and isinstance(v.owner.op, DimShuffle):
return cls.get_constant(v.owner.inputs[0])
# NOTE: the following code was buggy, but while I was fixing
# it I realized it is probably made useless by constant
# folding, so screw that. Commented-out code is the half-fixed
# version.
# if v.owner and isinstance(v.owner.op, DimShuffle):
# # see the comments in get_num_denum
# # TODO: this should apply the
# dsn = v.owner
# dsop = dsn.op
# dsi0 = dsn.inputs[0]
# compatible_order = ('x',) * (input.type.ndim - dsi0.type.ndim) + tuple(range(dsi0.type.ndim))
# if dsop.new_order == compatible_order:
# return cls.get_constant(v.owner.inputs[0])
return None
def simplify(self, num, denum):
"""
Shorthand for: self.simplify_constants(*self.simplify_factors(num, denum))
"""
return self.simplify_constants(*self.simplify_factors(num, denum))
def simplify_factors(self, num, denum):
"""
For any Result r which is both in num and denum, removes it
from both lists. Modifies the lists inplace. Returns the
modified lists. For example:
[x], [x] -> [], []
[x, y], [x] -> [y], []
[a, b], [c, d] -> [a, b], [c, d]
"""
for v in list(num):
if v in denum:
num.remove(v)
......@@ -496,28 +606,64 @@ class Canonizer(gof.LocalOptimizer):
return num, denum
def simplify_constants(self, orig_num, orig_denum):
"""
Finds all constants in orig_num and orig_denum (using
get_constant) and puts them together into a single
constant. The constant is inserted as the first element of the
numerator. If the constant is the neutral element, it is
removed from the numerator. Examples:
Let main be multiplication:
[2, 3, x], [] -> [6, x], []
[x, y, 2], [4, z] -> [0.5, x, y], [z]
[x, 2, y], [z, 2] -> [x, y], [z]
"""
# Lists representing the numerator and denumerator
num, denum = list(orig_num), list(orig_denum)
# Lists representing the *constant* elements of num and denum
numct, denumct = [], []
ncc, dcc = 0, 0
for v in orig_num:
ct = self.get_constant(v)
if ct is not None:
ncc += 1
# We found a constant in the numerator!
# We remove it from num
num.remove(v)
# We add it to numct
numct.append(ct)
for v in orig_denum:
ct = self.get_constant(v)
if ct is not None:
dcc += 1
denum.remove(v)
denumct.append(ct)
if self.use_reciprocal or num:
# This will calculate either:
# [inverse(main(*numct), main(*denumct))]
# [] - if inverse(main(*numct), main(*denumct)) is the neutral element
ct = self.calculate(numct, denumct, aslist = True)
else:
# This happens if we don't allow the reciprocal and the
# numerator is empty. That means we will need to represent
# reciprocal(x) like inverse(neutral_element, x) so
# we can't allow ct == []
# TODO: why is this branch needed when merge_num_denum does it for us?
ct = [self.calculate(numct, denumct, aslist = False)]
# if len(ct) and ncc == 1 and dcc == 0:
# return orig_num, orig_denum
if orig_num and len(numct) == 1 and ct and N.all(ct == self.get_constant(orig_num[0])):
# TODO: why are we not wrapping ct in a gof.Constant right now?
if orig_num and len(numct) == 1 and len(denumct) == 0 and ct and N.all(ct == self.get_constant(orig_num[0])):
# this is an important trick :( if it so happens that:
# * there's exactly one constant on the numerator and none on the denominator
# * it's not the neutral element (ct is an empty list in that case)
# * the constant is the same as the first argument in the numerator
# Then we return very exactly the original num/denum
# If we don't do that the optimizer will just loop infinitely because
# it will not catch on that there are no changes to be made and everytime
# it will want to replace something by the same thing...
return orig_num, orig_denum
return ct + num, denum
......@@ -528,6 +674,11 @@ class Canonizer(gof.LocalOptimizer):
if op not in [self.main, self.inverse, self.reciprocal]:
return False
# I'm not sure if this is actually needed but the following
# block of code puts into "reorg" whether or not we are going
# to change the structure of the graph. For example if we have
# inverse operating on an inverse, we can make it so that only
# one inverse is used, so we'll reorganize that.
iops = set(input.owner.op for input in inputs if input.owner)
reorg = False
if op == self.main:
......@@ -537,8 +688,11 @@ class Canonizer(gof.LocalOptimizer):
elif op == self.reciprocal:
reorg = len(iops.intersection([self.inverse, self.reciprocal])) != 0
# just in case
assert len(node.outputs) == 1
# Here we make the canonical version of the graph around this node
# See the documentation of get_num_denum and simplify
orig_num, orig_denum = self.get_num_denum(node.outputs[0])
num, denum = list(orig_num), list(orig_denum)
num, denum = self.simplify(num, denum)
......@@ -547,6 +701,8 @@ class Canonizer(gof.LocalOptimizer):
return len(x) == len(y) and all(N.all(xe == ye) for xe, ye in zip(x, y))
if not reorg and same(orig_num, num) and same(orig_denum, denum):
# We return False if there are no changes
# TODO: what's the purpose of reorg? isn't same() sufficient?
return False
new = self.merge_num_denum(num, denum)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论