提交 8bdc44e9 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

code documentation for Canonizer

上级 2838e342
......@@ -381,19 +381,21 @@ class Canonizer(gof.LocalOptimizer):
Usage: Canonizer(main, inverse, reciprocal, calculate)
* main: a suitable Op class that is commutative, associative and takes
one to an arbitrary number of inputs, e.g. Add or Mul
* main: a suitable Op class that is commutative, associative and
takes one to an arbitrary number of inputs, e.g. add or
mul
* inverse: an Op class such that inverse(main(x, y), y) == x
e.g. Sub or Div
* reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y)
e.g. Neg or Inv
* calculate: function that takes a list of numpy.ndarray instances for
the numerator, another list for the denumerator, and calculates
inverse(main(*num), main(*denum)). It takes a keyword argument,
aslist. If True, the value should be returned as a list of one
element, unless the value is such that value = main(). In that
case, the return value should be an empty list.
e.g. sub or div
* reciprocal: a function such that main(x, reciprocal(y)) ==
inverse(x, y) e.g. neg or inv
* calculate: function that takes a list of numpy.ndarray instances
for the numerator, another list for the denominator,
and calculates inverse(main(*num), main(*denum)). It
takes a keyword argument, aslist. If True, the value
should be returned as a list of one element, unless
the value is the neutral element (i.e. value == main()).
In that case, the return value should be an empty list.
The result is a local_optimizer. It is best used with a TopoOptimizer in
in_to_out order.
......@@ -422,38 +424,114 @@ class Canonizer(gof.LocalOptimizer):
self.use_reciprocal = use_reciprocal
def tracks(self):
    """
    Return the op patterns this optimizer tracks.

    One [op, None] pair is produced for each of self.main,
    self.inverse and self.reciprocal, in that order.
    """
    # NOTE(review): None is presumably a wildcard for the producer of the
    # op's input - confirm against gof.LocalOptimizer.
    # Previous version:
    # return [[None], [None, None], [None]*3, [None]*4, [None]*5]
    tracked_ops = (self.main, self.inverse, self.reciprocal)
    return [[op, None] for op in tracked_ops]
def get_num_denum(self, input):
    """
    Extract two lists, num and denum, such that the input is equivalent
    to self.inverse(self.main(*num), self.main(*denum)). The two lists
    are returned as a (num, denum) pair.

    For example, for main, inverse and reciprocal = *, / and inv(),

    input -> returned value (num, denum)

    x*y -> ([x, y], [])
    inv(x) -> ([], [x])
    inv(x) * inv(y) -> ([], [x, y])
    x*y/z -> ([x, y], [z])
    log(x) / y * (z + x) / y -> ([log(x), z + x], [y, y])
    (((a / b) * c) / d) -> ([a, c], [b, d])
    a / (b / c) -> ([a, c], [b])
    log(x) -> ([log(x)], [])
    x**y -> ([x**y], [])

    A main() node with no inputs at all yields ([], []).
    """
    parent = input.owner
    if parent is None or parent.op not in [self.main, self.inverse, self.reciprocal]:
        if parent and isinstance(parent.op, T.DimShuffle):
            # If input is a DimShuffle of some input which does
            # something like this:
            #  * change a vector of length N into a 1xN row matrix
            #  * change a scalar into a 1x1x1 tensor
            #  * in general, complete the shape of a tensor with
            #    broadcastable 1s to the *left*
            # then we simply discard the DimShuffle and return the
            # num/denum of its input.
            dsi0 = parent.inputs[0]  # the tensor being re-dimensioned
            # The compatible order is a DimShuffle "new_order" of the form:
            #   ('x', ..., 'x', 0, 1, 2, ..., dimshuffle_input.type.ndim)
            # That kind of DimShuffle only adds broadcastable dimensions
            # on the left, without discarding any existing broadcastable
            # dimension, and is inserted automatically by Elemwise when
            # the inputs have different numbers of dimensions (hence why
            # we can discard its information - we know we can retrieve
            # it later on).
            compatible_order = ('x',) * (input.type.ndim - dsi0.type.ndim) + tuple(range(dsi0.type.ndim))
            if parent.op.new_order == compatible_order:
                # The "new_order" is the one we recognize: recurse into
                # the dimshuffled input.
                return self.get_num_denum(dsi0)
        # The input isn't produced by main, inverse or reciprocal (nor by
        # a DimShuffle we can see through): it is an opaque numerator term.
        return [input], []

    # We get the (num, denum) pair of each input of the parent node.
    pairs = [self.get_num_denum(arg) for arg in parent.inputs]

    num, denum = [], []
    if parent.op == self.main:
        # main(x, y, ...): num is concat(numx, numy, ...) and denum is
        # concat(denumx, denumy, ...). main() can have any number of
        # arguments >= 0, so the concatenation must cope with an empty
        # `pairs`; building fresh lists also avoids mutating the lists
        # returned by the recursive calls.
        num = [term for pair in pairs for term in pair[0]]
        denum = [term for pair in pairs for term in pair[1]]
    elif parent.op == self.inverse:
        # inverse(x, y) is binary: num is concat(numx, denumy) and
        # denum is concat(denumx, numy).
        (numx, denumx), (numy, denumy) = pairs[0], pairs[1]
        num = numx + denumy
        denum = denumx + numy
    elif parent.op == self.reciprocal:
        # reciprocal(x) is unary: numerator and denominator are swapped.
        denum, num = pairs[0]
    return num, denum
def merge_num_denum(self, num, denum):
"""
Utility function which takes two lists, num and denum, and
returns something which is equivalent to inverse(main(*num),
main(*denum)), but depends on the length of num and the length
of denum (in order to minimize the number of operations).
Let n = len(num) and d = len(denum):
n=0, d=0: neutral element (given by self.calculate([], []))
(for example, this would be 0 if main is addition
and 1 if main is multiplication)
n=1, d=0: num[0]
n=0, d=1: reciprocal(denum[0])
n=1, d=1: inverse(num[0], denum[0])
n=0, d>1: reciprocal(main(*denum))
n>1, d=0: main(*num)
n=1, d>1: inverse(num[0], main(*denum))
n>1, d=1: inverse(main(*num), denum[0])
n>1, d>1: inverse(main(*num), main(*denum))
Given the values of n and d to which they are associated, all
of the above are equivalent to:
inverse(main(*num), main(*denum))
"""
ln, ld = len(num), len(denum)
if not ln and not ld:
return T.as_tensor(self.calculate([], []))
......@@ -475,20 +553,52 @@ class Canonizer(gof.LocalOptimizer):
@classmethod
def get_constant(cls, v):
    """
    Return the numeric value behind v, or None if v is a plain Result.

    * a numpy scalar (instance of N.generic) is returned as is
    * a gof.Constant yields its data
    * any other object that has no `owner` attribute (i.e. a raw
      numeric constant rather than a graph Result) is returned as is
    * everything else yields None
    """
    if isinstance(v, N.generic):
        # NOTE(review): the `not hasattr(v, 'owner')` test below probably
        # catches this case as well.
        return v
    if isinstance(v, gof.Constant):
        return v.data
    if not hasattr(v, 'owner'):
        return v
    # NOTE: an earlier version looked through *any* DimShuffle and returned
    # the constant of its input. That was buggy: only a DimShuffle that
    # merely adds broadcastable dimensions on the left (see the comments
    # in get_num_denum) leaves the value unchanged. A half-finished fix
    # compared the op's new_order against
    #   ('x',) * (v.type.ndim - dsi0.type.ndim) + tuple(range(dsi0.type.ndim))
    # where dsi0 = v.owner.inputs[0], but it was abandoned because
    # constant folding probably makes the whole branch useless.
    return None
def simplify(self, num, denum):
    """
    Cancel the terms common to num and denum, then fold the constants.

    Shorthand for:
    self.simplify_constants(*self.simplify_factors(num, denum))
    """
    reduced_num, reduced_denum = self.simplify_factors(num, denum)
    return self.simplify_constants(reduced_num, reduced_denum)
def simplify_factors(self, num, denum):
"""
For any Result r which is both in num and denum, removes it
from both lists. Modifies the lists inplace. Returns the
modified lists. For example:
[x], [x] -> [], []
[x, y], [x] -> [y], []
[a, b], [c, d] -> [a, b], [c, d]
"""
for v in list(num):
if v in denum:
num.remove(v)
......@@ -496,28 +606,64 @@ class Canonizer(gof.LocalOptimizer):
return num, denum
def simplify_constants(self, orig_num, orig_denum):
    """
    Find all constants in orig_num and orig_denum (using get_constant)
    and put them together into a single constant. The constant is
    inserted as the first element of the numerator. If the constant is
    the neutral element, it is removed from the numerator. Examples:

    Let main be multiplication:

    [2, 3, x], [] -> [6, x], []
    [x, y, 2], [4, z] -> [0.5, x, y], [z]
    [x, 2, y], [z, 2] -> [x, y], [z]

    The input lists are not modified. When nothing would change, the
    very same (orig_num, orig_denum) objects are returned.
    """
    def split(terms):
        # Partition terms into (non-constant terms, constant values),
        # preserving the relative order of each.
        variables, constants = [], []
        for v in terms:
            ct = self.get_constant(v)
            if ct is None:
                variables.append(v)
            else:
                constants.append(ct)
        return variables, constants

    # num/denum: the non-constant terms of the numerator/denominator
    # numct/denumct: the values of the *constant* terms
    num, numct = split(orig_num)
    denum, denumct = split(orig_denum)

    if self.use_reciprocal or num:
        # This will calculate either:
        #   [inverse(main(*numct), main(*denumct))]
        #   [] - if inverse(main(*numct), main(*denumct)) is the neutral element
        ct = self.calculate(numct, denumct, aslist=True)
    else:
        # This happens if we don't allow the reciprocal and the
        # numerator is empty. That means we will need to represent
        # reciprocal(x) like inverse(neutral_element, x) so
        # we can't allow ct == []
        # TODO: why is this branch needed when merge_num_denum does it for us?
        ct = [self.calculate(numct, denumct, aslist=False)]

    # TODO: why are we not wrapping ct in a gof.Constant right now?
    if (orig_num and len(numct) == 1 and len(denumct) == 0 and ct
            and N.all(ct == self.get_constant(orig_num[0]))):
        # This is an important trick :( if it so happens that:
        #  * there's exactly one constant on the numerator and none on
        #    the denominator
        #  * it's not the neutral element (ct is an empty list in that case)
        #  * the constant is the same as the first argument in the numerator
        # then we return exactly the original num/denum. If we don't do
        # that the optimizer will just loop infinitely, because it will
        # not catch on that there are no changes to be made and every
        # time it will want to replace something by the same thing...
        return orig_num, orig_denum
    return ct + num, denum
......@@ -528,6 +674,11 @@ class Canonizer(gof.LocalOptimizer):
if op not in [self.main, self.inverse, self.reciprocal]:
return False
# I'm not sure if this is actually needed but the following
# block of code puts into "reorg" whether or not we are going
# to change the structure of the graph. For example if we have
# inverse operating on an inverse, we can make it so that only
# one inverse is used, so we'll reorganize that.
iops = set(input.owner.op for input in inputs if input.owner)
reorg = False
if op == self.main:
......@@ -537,8 +688,11 @@ class Canonizer(gof.LocalOptimizer):
elif op == self.reciprocal:
reorg = len(iops.intersection([self.inverse, self.reciprocal])) != 0
# just in case
assert len(node.outputs) == 1
# Here we make the canonical version of the graph around this node
# See the documentation of get_num_denum and simplify
orig_num, orig_denum = self.get_num_denum(node.outputs[0])
num, denum = list(orig_num), list(orig_denum)
num, denum = self.simplify(num, denum)
......@@ -547,6 +701,8 @@ class Canonizer(gof.LocalOptimizer):
return len(x) == len(y) and all(N.all(xe == ye) for xe, ye in zip(x, y))
if not reorg and same(orig_num, num) and same(orig_denum, denum):
# We return False if there are no changes
# TODO: what's the purpose of reorg? isn't same() sufficient?
return False
new = self.merge_num_denum(num, denum)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论