提交 4856a655 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

dot2gemm + Canonizer

上级 a901473e
...@@ -5,6 +5,7 @@ import scalar ...@@ -5,6 +5,7 @@ import scalar
import tensor as T import tensor as T
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c) # gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
gemm_pattern_1 = gof.PatternSub((T.sub_inplace, gemm_pattern_1 = gof.PatternSub((T.sub_inplace,
'd', 'd',
(T.mul, (T.mul,
...@@ -14,6 +15,16 @@ gemm_pattern_1 = gof.PatternSub((T.sub_inplace, ...@@ -14,6 +15,16 @@ gemm_pattern_1 = gof.PatternSub((T.sub_inplace,
(T.gemm, 'd', (T.neg, 'a'), 'b', 'c', T.constant(1.0)), (T.gemm, 'd', (T.neg, 'a'), 'b', 'c', T.constant(1.0)),
allow_multiple_clients = False) allow_multiple_clients = False)
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms dot(a, b) into
#   gemm(Zeros(2)(vertical_stack(shape(a)[0:1], shape(b)[1:2])), 1.0, a, b, 1.0)
# The destination 'd' is a Zeros(2) node fed with the first entry of shape(a)
# stacked with the second entry of shape(b); both scale factors are the
# constant 1.0, so per the gemm formula above the result is d + dot(a, b).
# NOTE(review): presumably Zeros(2) yields a zero-filled 2-d tensor of the
# requested shape, which would make the result equal dot(a, b) -- confirm.
# NOTE(review): confirm that T.vertical_stack of the two length-1 slices gives
# the flat length-2 shape vector that Zeros expects.
dot_to_gemm = gof.PatternSub((T.dot, 'a', 'b'),
                             (T.gemm, (T.Zeros(2),
                                       (T.vertical_stack,
                                        (T.Subtensor([slice(0, 1)]), (T.shape, 'a')),
                                        (T.Subtensor([slice(1, 2)]), (T.shape, 'b')))),
                              T.constant(1.0), 'a', 'b', T.constant(1.0)),
                             allow_multiple_clients = False)
class InplaceOptimizer(gof.Optimizer): class InplaceOptimizer(gof.Optimizer):
""" """
...@@ -97,6 +108,182 @@ lift_dimshuffle = gof.TopoOptimizer(DimShuffleLifter(), order = 'out_to_in') ...@@ -97,6 +108,182 @@ lift_dimshuffle = gof.TopoOptimizer(DimShuffleLifter(), order = 'out_to_in')
class Canonizer(gof.Optimizer):
    """
    Simplification tool.

    Usage: Canonizer(main, inverse, reciprocal, mainfn, invfn, recfn, transform)

    * main: a suitable Op class that is commutative, associative and takes
            one to an arbitrary number of inputs, e.g. Add or Mul
    * inverse: an Op class such that inverse(main(x, y), y) == x
               e.g. Sub or Div
    * reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y)
                  e.g. Neg or Inv
    * mainfn, invfn, recfn: functions that behave just like the previous three
                            Ops, but on true scalars (e.g. their impl)
    * transform: an optional function called as
                 transform(env, numerator, denominator), where numerator
                 and denominator are lists of Result instances; it returns
                 new (numerator, denominator) lists where further
                 simplifications may have been applied.

    Examples:
      add_canonizer = Canonizer(Add, Sub, Neg, lambda *inputs: sum(inputs), ...)
      mul_canonizer = Canonizer(Mul, Div, Inv, lambda *inputs: product(inputs), ...)

    Examples of optimizations mul_canonizer can perform:
      x / x -> 1
      (x * y) / x -> y
      x / y / x -> 1 / y
      x / y / z -> x / (y * z)
      x / (y / z) -> (x * z) / y
      (a / b) * (b / c) * (c / d) -> a / d
      (2.0 * x) / (4.0 * y) -> (0.5 * x) / y
      2 * x / 2 -> x
    """

    def __init__(self, main, inverse, reciprocal, mainfn, invfn, recfn, transform = None):
        """Store the Ops and their scalar counterparts.

        The neutral element is obtained by calling mainfn with no arguments.
        """
        self.main = main
        self.inverse = inverse
        self.reciprocal = reciprocal
        self.mainfn = mainfn
        self.invfn = invfn
        self.recfn = recfn
        self.neutral = mainfn()
        self.transform = transform

    def apply(self, env):
        """Canonize every result reachable from env.outputs.

        Each main/inverse/reciprocal tree is flattened into a numerator list
        and a denominator list, common terms are cancelled, constants are
        folded, and the rebuilt expression is substituted with env.replace.
        """

        def edge(r):
            # A result without an owner has nothing upstream to follow.
            return r.owner is None

        def follow(r):
            return None if r.owner is None else r.owner.inputs

        def concat(lists):
            # Concatenate a list of lists; an empty list of lists yields [].
            return [item for sub in lists for item in sub]

        def flatten(r, nclients_check = True):
            # Collapses a tree of main/inverse/reciprocal Ops (aka Mul/Div/Inv or Add/Sub/Neg)
            # into a list of numerators and a list of denominators
            # e.g. (x*(1/y))*(x/(z/a)) aka Mul(Mul(x, (Inv, y)), Div(x, Div(z, a))) -> [x, x, a], [z, y]
            if edge(r):
                return [r], []
            node = r.owner
            op = node.op
            handled = op == self.main or op == self.inverse or op == self.reciprocal
            # An unhandled op, or (when checking) a result that does not have
            # exactly one client, is kept as an opaque numerator term.
            if not handled or (nclients_check and env.nclients(r) != 1):
                return [r], []
            # Only inputs of the same type as r are flattened further.
            results = [flatten(r2) if r2.type == r.type else ([r2], [])
                       for r2 in node.inputs]
            if op == self.main:
                nums = [x[0] for x in results]
                denums = [x[1] for x in results]
            elif op == self.inverse:
                # num, denum of the second argument are added to the denum, num respectively
                nums = [results[0][0], results[1][1]]
                denums = [results[0][1], results[1][0]]
            else:
                # reciprocal: num, denum of the sole argument are added to the denum, num respectively
                nums = [results[0][1]]
                denums = [results[0][0]]
            return concat(nums), concat(denums)

        def make(factors):
            # Combines the factors using self.main (aka Mul) depending
            # on the number of elements.
            n = len(factors)
            if n == 0:
                return None
            elif n == 1:
                return factors[0]
            else:
                return self.main(*factors)

        def canonize(r):
            inputs = follow(r)
            if inputs is None:
                return

            # The root of the tree is flattened regardless of its client count.
            num, denum = flatten(r, False)

            if (num, denum) == ([r], []):
                # Nothing to simplify at r itself: look further upstream.
                for upstream in inputs:
                    canonize(upstream)
                return

            # Terms that are both in the num and denum lists cancel each other
            for d in list(denum):
                if d in num:
                    # list.remove only removes the element once
                    num.remove(d)
                    denum.remove(d)

            # We identify the constants in num and denum
            numct, num = gof.utils.partition(lambda factor: isinstance(factor, gof.Constant) and factor.data is not None, num)
            denumct, denum = gof.utils.partition(lambda factor: isinstance(factor, gof.Constant) and factor.data is not None, denum)

            # All constants in num and denum are combined into a single constant which we add to num (unless it's a neutral constant)
            v = self.invfn(self.mainfn(*[x.data for x in numct]), self.mainfn(*[x.data for x in denumct]))
            if v != self.neutral:
                # NOTE(review): `C` is not defined in the visible part of this
                # file -- confirm it is the intended constant constructor.
                num.insert(0, C(v))

            # We optimize the num and denum lists further if requested
            if self.transform is not None:
                num, denum = self.transform(env, num, denum)

            numr, denumr = make(num), make(denum)

            if numr is None:
                if denumr is None:
                    # Everything cancelled each other so we're left with
                    # the neutral element.
                    new_r = gof.Constant(r.type, self.neutral)
                else:
                    # There's no numerator so we use reciprocal
                    new_r = self.reciprocal(denumr)
            else:
                if denumr is None:
                    new_r = numr
                else:
                    new_r = self.inverse(numr, denumr)

            # Hopefully this won't complain!
            env.replace(r, new_r)

            # The surviving terms may themselves contain trees to canonize.
            for factor in num + denum:
                canonize(factor)

        for output in env.outputs:
            canonize(output)
_mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
_divfn = lambda x, y: x / y
_invfn = lambda x: 1 / x
# Canonizer for multiplicative expressions: T.mul / T.div / T.inv are the
# main / inverse / reciprocal Ops, and _mulfn / _divfn / _invfn are passed as
# the corresponding mainfn / invfn / recfn. No transform is supplied.
mul_canonizer = Canonizer(T.mul, T.div, T.inv, _mulfn, _divfn, _invfn)
# class DimShuffleLifter(opt.Optimizer): # class DimShuffleLifter(opt.Optimizer):
# """ # """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论