提交 55c5d0b3 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

merge

...@@ -211,12 +211,49 @@ for this somewhere in the future. ...@@ -211,12 +211,49 @@ for this somewhere in the future.
Local optimization Local optimization
------------------ ------------------
The local version of the above code would be the following:
.. code-block:: python
class LocalSimplify(gof.LocalOptimizer):
    """Local optimizer that cancels a factor shared by a product and its divisor.

    Rewrites div(mul(a, b), a) into b and div(mul(a, b), b) into a.
    """

    def transform(self, node):
        """Return the replacement for node's outputs, or False if nothing applies.

        node: the node of the graph currently being visited.
        Returns a one-element list holding the remaining factor when the
        divisor matches one of the two factors of the numerator, and
        False otherwise.
        """
        # Only divisions are candidates for this simplification.
        if not node.op == div:
            return False
        numerator, divisor = node.inputs
        producer = numerator.owner
        # The numerator must itself be the output of a multiplication.
        if not (producer and producer.op == mul):
            return False
        left, right = producer.inputs
        if divisor == left:
            return [right]
        if divisor == right:
            return [left]
        return False


local_simplify = LocalSimplify()
The definition of transform corresponds to the inner loop of the global
optimizer, where the node is given as an argument. If no changes are to be
made, False must be returned. Otherwise, transform must return a list of
the results that replace the node's outputs.
In order to apply the local optimizer we must use it in conjunction
with a :ref:`navigator`. You can follow this :ref:`link <navigator>`
for further documentation, but basically a Navigator is a global
optimizer that loops through all nodes in the graph (or a well-defined
subset of them) and applies one or several local optimizers on them.
>>> x = double('x')
>>> y = double('y')
>>> z = double('z')
>>> a = add(z, mul(div(mul(y, x), y), div(z, x)))
>>> e = gof.Env([x, y, z], [a])
>>> e
[add(z, mul(div(mul(y, x), y), div(z, x)))]
>>> simplify = gof.TopoOptimizer([local_simplify])
>>> simplify.optimize(e)
>>> e
[add(z, mul(x, div(z, x)))]
TODO: test this.
The optimization database (optdb) The optimization database (optdb)
......
...@@ -381,19 +381,21 @@ class Canonizer(gof.LocalOptimizer): ...@@ -381,19 +381,21 @@ class Canonizer(gof.LocalOptimizer):
Usage: Canonizer(main, inverse, reciprocal, calculate) Usage: Canonizer(main, inverse, reciprocal, calculate)
* main: a suitable Op class that is commutative, associative and takes * main: a suitable Op class that is commutative, associative and
one to an arbitrary number of inputs, e.g. Add or Mul takes one to an arbitrary number of inputs, e.g. add or
mul
* inverse: an Op class such that inverse(main(x, y), y) == x * inverse: an Op class such that inverse(main(x, y), y) == x
e.g. Sub or Div e.g. sub or div
* reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y) * reciprocal: a function such that main(x, reciprocal(y)) ==
e.g. Neg or Inv inverse(x, y) e.g. neg or inv
* calculate: function that takes a list of numpy.ndarray instances for * calculate: function that takes a list of numpy.ndarray instances
the numerator, another list for the denominator, and calculates for the numerator, another list for the denominator,
inverse(main(*num), main(*denum)). It takes a keyword argument, and calculates inverse(main(*num), main(*denum)). It
aslist. If True, the value should be returned as a list of one takes a keyword argument, aslist. If True, the value
element, unless the value is such that value = main(). In that should be returned as a list of one element, unless
case, the return value should be an empty list. the value is such that value = main(). In that case,
the return value should be an empty list.
The result is a local_optimizer. It is best used with a TopoOptimizer in The result is a local_optimizer. It is best used with a TopoOptimizer in
in_to_out order. in_to_out order.
...@@ -422,38 +424,114 @@ class Canonizer(gof.LocalOptimizer): ...@@ -422,38 +424,114 @@ class Canonizer(gof.LocalOptimizer):
self.use_reciprocal = use_reciprocal self.use_reciprocal = use_reciprocal
def tracks(self): def tracks(self):
#return [[None], [None, None], [None]*3, [None]*4, [None]*5]
return [[self.main, None], [self.inverse, None], [self.reciprocal, None]] return [[self.main, None], [self.inverse, None], [self.reciprocal, None]]
def get_num_denum(self, input): def get_num_denum(self, input):
"""
This extracts two lists, num and denum, such that the input is:
self.inverse(self.main(*num), self.main(*denum)). It returns
the two lists in a (num, denum) pair.
For example, for main, inverse and reciprocal = *, / and inv(),
input -> returned value (num, denum)
x*y -> ([x, y], [])
inv(x) -> ([], [x])
inv(x) * inv(y) -> ([], [x, y])
x*y/z -> ([x, y], [z])
log(x) / y * (z + x) / y -> ([log(x), z + x], [y, y])
(((a / b) * c) / d) -> ([a, c], [b, d])
a / (b / c) -> ([a, c], [b])
log(x) -> ([log(x)], [])
x**y -> ([x**y], [])
"""
if input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]: if input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]:
if input.owner and isinstance(input.owner.op, T.DimShuffle): if input.owner and isinstance(input.owner.op, T.DimShuffle):
dsn = input.owner # If input is a DimShuffle of some input which does something like this:
dsop = dsn.op # * change a vector of length N into a 1xN row matrix
dsi0 = dsn.inputs[0] # * change a scalar into a 1x1x1 tensor
# * in general, complete the shape of a tensor with broadcastable 1s to the *left*
# Then we will simply discard the DimShuffle and return the num/denum of its input
dsn = input.owner # dimshuffle node
dsop = dsn.op # dimshuffle op
dsi0 = dsn.inputs[0] # the first input of the dimshuffle i.e. the ndarray to redim
# The compatible order is a DimShuffle "new_order" of the form:
# ('x', ..., 'x', 0, 1, 2, ..., dimshuffle_input.type.ndim - 1)
# That kind of DimShuffle only adds broadcastable
# dimensions on the left, without discarding any
# existing broadcastable dimension and is inserted
# automatically by Elemwise when the inputs have
# different numbers of dimensions (hence why we can
# discard its information - we know we can retrieve it
# later on).
compatible_order = ('x',) * (input.type.ndim - dsi0.type.ndim) + tuple(range(dsi0.type.ndim)) compatible_order = ('x',) * (input.type.ndim - dsi0.type.ndim) + tuple(range(dsi0.type.ndim))
if dsop.new_order == compatible_order: if dsop.new_order == compatible_order:
# If the "new_order" is the one we recognize,
# we return the num_denum of the dimshuffled input.
return self.get_num_denum(input.owner.inputs[0]) return self.get_num_denum(input.owner.inputs[0])
else: else:
# This is when the input isn't produced by main, inverse or reciprocal.
return [input], [] return [input], []
else: else:
return [input], [] return [input], []
num = [] num = []
denum = [] denum = []
parent = input.owner parent = input.owner
# We get the (num, denum) pairs for each input
pairs = [self.get_num_denum(input) for input in parent.inputs] pairs = [self.get_num_denum(input) for input in parent.inputs]
if parent.op == self.main: if parent.op == self.main:
# If we have main(x, y), numx, denumx, numy and denumy
# then num is concat(numx, numy) and denum is concat(denumx, denumy)
# note that main() can have any number of arguments >= 0
# concat is list concatenation
num = reduce(list.__iadd__, map(operator.itemgetter(0), pairs)) num = reduce(list.__iadd__, map(operator.itemgetter(0), pairs))
denum = reduce(list.__iadd__, map(operator.itemgetter(1), pairs)) denum = reduce(list.__iadd__, map(operator.itemgetter(1), pairs))
elif parent.op == self.inverse: elif parent.op == self.inverse:
# If we have inverse(x, y), numx, denumx, numy and denumy
# then num is concat(numx, denumy) and denum is concat(denumx, numy)
# note that inverse() is binary
num = pairs[0][0] + pairs[1][1] num = pairs[0][0] + pairs[1][1]
denum = pairs[0][1] + pairs[1][0] denum = pairs[0][1] + pairs[1][0]
elif parent.op == self.reciprocal: elif parent.op == self.reciprocal:
# If we have reciprocal(x), numx, denumx
# then num is denumx and denum is numx
# note that reciprocal() is unary
num = pairs[0][1] num = pairs[0][1]
denum = pairs[0][0] denum = pairs[0][0]
return num, denum return num, denum
def merge_num_denum(self, num, denum): def merge_num_denum(self, num, denum):
"""
Utility function which takes two lists, num and denum, and
returns something which is equivalent to inverse(main(*num),
main(*denum)), but depends on the length of num and the length
of denum (in order to minimize the number of operations).
Let n = len(num) and d = len(denum):
n=0, d=0: neutral element (given by self.calculate([], []))
(for example, this would be 0 if main is addition
and 1 if main is multiplication)
n=1, d=0: num[0]
n=0, d=1: reciprocal(denum[0])
n=1, d=1: inverse(num[0], denum[0])
n=0, d>1: reciprocal(main(*denum))
n>1, d=0: main(*num)
n=1, d>1: inverse(num[0], main(*denum))
n>1, d=1: inverse(main(*num), denum[0])
n>1, d>1: inverse(main(*num), main(*denum))
Given the values of n and d to which they are associated, all
of the above are equivalent to:
inverse(main(*num), main(*denum))
"""
ln, ld = len(num), len(denum) ln, ld = len(num), len(denum)
if not ln and not ld: if not ln and not ld:
return T.as_tensor(self.calculate([], [])) return T.as_tensor(self.calculate([], []))
...@@ -475,20 +553,52 @@ class Canonizer(gof.LocalOptimizer): ...@@ -475,20 +553,52 @@ class Canonizer(gof.LocalOptimizer):
@classmethod @classmethod
def get_constant(cls, v): def get_constant(cls, v):
"""
Returns a numeric constant if v is a gof.Constant or, well, a
numeric constant. If v is a plain Result, returns None.
"""
if isinstance(v, N.generic): if isinstance(v, N.generic):
return v return v # doesn't the not hasattr() condition below catch this?
if isinstance(v, gof.Constant): if isinstance(v, gof.Constant):
return v.data return v.data
if not hasattr(v, 'owner'): if not hasattr(v, 'owner'):
return v return v
if v.owner and isinstance(v.owner.op, DimShuffle):
return cls.get_constant(v.owner.inputs[0]) # NOTE: the following code was buggy, but while I was fixing
# it I realized it is probably made useless by constant
# folding, so screw that. Commented-out code is the half-fixed
# version.
# if v.owner and isinstance(v.owner.op, DimShuffle):
# # see the comments in get_num_denum
# # TODO: this should apply the
# dsn = v.owner
# dsop = dsn.op
# dsi0 = dsn.inputs[0]
# compatible_order = ('x',) * (input.type.ndim - dsi0.type.ndim) + tuple(range(dsi0.type.ndim))
# if dsop.new_order == compatible_order:
# return cls.get_constant(v.owner.inputs[0])
return None return None
def simplify(self, num, denum): def simplify(self, num, denum):
"""
Shorthand for: self.simplify_constants(*self.simplify_factors(num, denum))
"""
return self.simplify_constants(*self.simplify_factors(num, denum)) return self.simplify_constants(*self.simplify_factors(num, denum))
def simplify_factors(self, num, denum): def simplify_factors(self, num, denum):
"""
For any Result r which is both in num and denum, removes it
from both lists. Modifies the lists inplace. Returns the
modified lists. For example:
[x], [x] -> [], []
[x, y], [x] -> [y], []
[a, b], [c, d] -> [a, b], [c, d]
"""
for v in list(num): for v in list(num):
if v in denum: if v in denum:
num.remove(v) num.remove(v)
...@@ -496,28 +606,64 @@ class Canonizer(gof.LocalOptimizer): ...@@ -496,28 +606,64 @@ class Canonizer(gof.LocalOptimizer):
return num, denum return num, denum
def simplify_constants(self, orig_num, orig_denum): def simplify_constants(self, orig_num, orig_denum):
"""
Finds all constants in orig_num and orig_denum (using
get_constant) and puts them together into a single
constant. The constant is inserted as the first element of the
numerator. If the constant is the neutral element, it is
removed from the numerator. Examples:
Let main be multiplication:
[2, 3, x], [] -> [6, x], []
[x, y, 2], [4, z] -> [0.5, x, y], [z]
[x, 2, y], [z, 2] -> [x, y], [z]
"""
# Lists representing the numerator and denominator
num, denum = list(orig_num), list(orig_denum) num, denum = list(orig_num), list(orig_denum)
# Lists representing the *constant* elements of num and denum
numct, denumct = [], [] numct, denumct = [], []
ncc, dcc = 0, 0
for v in orig_num: for v in orig_num:
ct = self.get_constant(v) ct = self.get_constant(v)
if ct is not None: if ct is not None:
ncc += 1 # We found a constant in the numerator!
# We remove it from num
num.remove(v) num.remove(v)
# We add it to numct
numct.append(ct) numct.append(ct)
for v in orig_denum: for v in orig_denum:
ct = self.get_constant(v) ct = self.get_constant(v)
if ct is not None: if ct is not None:
dcc += 1
denum.remove(v) denum.remove(v)
denumct.append(ct) denumct.append(ct)
if self.use_reciprocal or num: if self.use_reciprocal or num:
# This will calculate either:
# [inverse(main(*numct), main(*denumct))]
# [] - if inverse(main(*numct), main(*denumct)) is the neutral element
ct = self.calculate(numct, denumct, aslist = True) ct = self.calculate(numct, denumct, aslist = True)
else: else:
# This happens if we don't allow the reciprocal and the
# numerator is empty. That means we will need to represent
# reciprocal(x) like inverse(neutral_element, x) so
# we can't allow ct == []
# TODO: why is this branch needed when merge_num_denum does it for us?
ct = [self.calculate(numct, denumct, aslist = False)] ct = [self.calculate(numct, denumct, aslist = False)]
# if len(ct) and ncc == 1 and dcc == 0: # TODO: why are we not wrapping ct in a gof.Constant right now?
# return orig_num, orig_denum
if orig_num and len(numct) == 1 and ct and N.all(ct == self.get_constant(orig_num[0])): if orig_num and len(numct) == 1 and len(denumct) == 0 and ct and N.all(ct == self.get_constant(orig_num[0])):
# this is an important trick :( if it so happens that:
# * there's exactly one constant on the numerator and none on the denominator
# * it's not the neutral element (ct is an empty list in that case)
# * the constant is the same as the first argument in the numerator
# Then we return exactly the original num/denum
# If we don't do that the optimizer will just loop infinitely because
# it will not notice that there are no changes to be made, and every time
# it will want to replace something with the same thing...
return orig_num, orig_denum return orig_num, orig_denum
return ct + num, denum return ct + num, denum
...@@ -528,6 +674,11 @@ class Canonizer(gof.LocalOptimizer): ...@@ -528,6 +674,11 @@ class Canonizer(gof.LocalOptimizer):
if op not in [self.main, self.inverse, self.reciprocal]: if op not in [self.main, self.inverse, self.reciprocal]:
return False return False
# I'm not sure if this is actually needed but the following
# block of code puts into "reorg" whether or not we are going
# to change the structure of the graph. For example if we have
# inverse operating on an inverse, we can make it so that only
# one inverse is used, so we'll reorganize that.
iops = set(input.owner.op for input in inputs if input.owner) iops = set(input.owner.op for input in inputs if input.owner)
reorg = False reorg = False
if op == self.main: if op == self.main:
...@@ -537,8 +688,11 @@ class Canonizer(gof.LocalOptimizer): ...@@ -537,8 +688,11 @@ class Canonizer(gof.LocalOptimizer):
elif op == self.reciprocal: elif op == self.reciprocal:
reorg = len(iops.intersection([self.inverse, self.reciprocal])) != 0 reorg = len(iops.intersection([self.inverse, self.reciprocal])) != 0
# just in case
assert len(node.outputs) == 1 assert len(node.outputs) == 1
# Here we make the canonical version of the graph around this node
# See the documentation of get_num_denum and simplify
orig_num, orig_denum = self.get_num_denum(node.outputs[0]) orig_num, orig_denum = self.get_num_denum(node.outputs[0])
num, denum = list(orig_num), list(orig_denum) num, denum = list(orig_num), list(orig_denum)
num, denum = self.simplify(num, denum) num, denum = self.simplify(num, denum)
...@@ -547,6 +701,8 @@ class Canonizer(gof.LocalOptimizer): ...@@ -547,6 +701,8 @@ class Canonizer(gof.LocalOptimizer):
return len(x) == len(y) and all(N.all(xe == ye) for xe, ye in zip(x, y)) return len(x) == len(y) and all(N.all(xe == ye) for xe, ye in zip(x, y))
if not reorg and same(orig_num, num) and same(orig_denum, denum): if not reorg and same(orig_num, num) and same(orig_denum, denum):
# We return False if there are no changes
# TODO: what's the purpose of reorg? isn't same() sufficient?
return False return False
new = self.merge_num_denum(num, denum) new = self.merge_num_denum(num, denum)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论