提交 a250ccc8 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

cleanup

上级 58e51d18
...@@ -8,6 +8,16 @@ import numpy as N ...@@ -8,6 +8,16 @@ import numpy as N
import operator import operator
import itertools import itertools
# Utilities
def out2in(*local_opts):
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts), order = 'out_to_in')
def in2out(*local_opts):
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts), order = 'in_to_out')
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c) # gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0) # Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
gemm_pattern_1 = gof.PatternSub((T.sub_inplace, gemm_pattern_1 = gof.PatternSub((T.sub_inplace,
...@@ -31,7 +41,7 @@ dot_to_gemm = gof.PatternSub((T.dot, 'a', 'b'), ...@@ -31,7 +41,7 @@ dot_to_gemm = gof.PatternSub((T.dot, 'a', 'b'),
@gof.optimizer @gof.optimizer
def inplace_optimizer(self, env): def insert_inplace_optimizer(self, env):
""" """
Usage: inplace_optimizer.optimize(env) Usage: inplace_optimizer.optimize(env)
...@@ -65,6 +75,10 @@ def inplace_optimizer(self, env): ...@@ -65,6 +75,10 @@ def inplace_optimizer(self, env):
baseline = inplace_pattern baseline = inplace_pattern
break break
inplace_optimizer = gof.SeqOptimizer(out2in(gemm_pattern_1),
out2in(dot_to_gemm),
insert_inplace_optimizer)
###################### ######################
# DimShuffle lifters # # DimShuffle lifters #
...@@ -294,6 +308,43 @@ def local_fill_sink(node): ...@@ -294,6 +308,43 @@ def local_fill_sink(node):
################ ################
class Canonizer(gof.LocalOptimizer): class Canonizer(gof.LocalOptimizer):
"""
Simplification tool.
Usage: Canonizer(main, inverse, reciprocal, calculate)
* main: a suitable Op class that is commutative, associative and takes
one to an arbitrary number of inputs, e.g. Add or Mul
* inverse: an Op class such that inverse(main(x, y), y) == x
e.g. Sub or Div
* reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y)
e.g. Neg or Inv
* calculate: function that takes a list of numpy.ndarray instances for
the numerator, another list for the denumerator, and calculates
inverse(main(*num), main(*denum)). It takes a keyword argument,
aslist. If True, the value should be returned as a list of one
element, unless the value is such that value = main(). In that
case, the return value should be an empty list.
The result is a local_optimizer. It is best used with a TopoOptimizer in
in_to_out order.
Examples:
T = theano.tensor
add_canonizer = Canonizer(T.add, T.sub, T.neg, lambda n, d: sum(n) - sum(d))
mul_canonizer = Canonizer(T.mul, T.div, T.inv, lambda n, d: prod(n) / prod(d))
Examples of optimizations mul_canonizer can perform:
x / x -> 1
(x * y) / x -> y
x / y / x -> 1 / y
x / y / z -> x / (y * z)
x / (y / z) -> (x * z) / y
(a / b) * (b / c) * (c / d) -> a / d
(2.0 * x) / (4.0 * y) -> (0.5 * x) / y
2 * x / 2 -> x
"""
def __init__(self, main, inverse, reciprocal, calculate): def __init__(self, main, inverse, reciprocal, calculate):
self.main = main self.main = main
...@@ -567,12 +618,6 @@ def local_greedy_distributor(node): ...@@ -567,12 +618,6 @@ def local_greedy_distributor(node):
def out2in(*local_opts):
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts), order = 'out_to_in')
def in2out(*local_opts):
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts), order = 'in_to_out')
def _math_optimizer(): def _math_optimizer():
pass_1 = in2out(local_fill_sink) pass_1 = in2out(local_fill_sink)
pass_2 = out2in(local_dimshuffle_lift, local_shape_lift, local_fill_lift)#, local_fill_cut) pass_2 = out2in(local_dimshuffle_lift, local_shape_lift, local_fill_lift)#, local_fill_cut)
...@@ -593,219 +638,12 @@ def _math_optimizer(): ...@@ -593,219 +638,12 @@ def _math_optimizer():
mul_to_neg) mul_to_neg)
math_optimizer = _math_optimizer() math_optimizer = _math_optimizer()
# class Canonizer(gof.LocalOptimizer):
# def __init__(self, main, inverse, reciprocal, simplify_constants, constant_op):
# self.main = main
# self.inverse = inverse
# self.reciprocal = reciprocal
# self.simplify_constants = simplify_constants
# self.constant_op = constant_op
# def get_num_denum(self, input, depth):
# if depth == 0 or input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]:
# return [input], []
# num = []
# denum = []
# parent = input.owner
# pairs = [self.get_num_denum(input, depth - 1) for input in parent.inputs]
# if parent.op == self.main:
# num = reduce(list.__iadd__, map(operator.itemgetter(0), pairs))
# denum = reduce(list.__iadd__, map(operator.itemgetter(1), pairs))
# elif parent.op == self.inverse:
# num = pairs[0][0] + pairs[1][1]
# denum = pairs[0][1] + pairs[1][0]
# elif parent.op == self.reciprocal:
# num = pairs[0][1]
# denum = pairs[0][0]
# return num, denum
# def deep_num_denum(self, node):
# op = node.op
# if op == self.main:
# num, denum = self.get_num_denum(inputs)
# elif op == self.inverse:
# assert len(inputs) == 2
# n1, d1 = self.get_num_denum(inputs[:1])
# n2, d2 = self.get_num_denum(inputs[1:])
# num, denum = n1+d2, d1+n2
# elif op == self.reciprocal:
# denum, num = self.get_num_denum(inputs)
# else:
# num, denum = [node.outputs[0]], []
# return num, denum
# def get_num_denum(self, inputs):
# num = []
# denum = []
# for input in inputs:
# if input.owner is None:
# num.append(input)
# continue
# parent = input.owner
# if parent.op == self.main:
# num += parent.inputs
# elif parent.op == self.inverse:
# num += parent.inputs[:1]
# denum += parent.inputs[1:]
# elif parent.op == self.reciprocal:
# denum += parent.inputs
# else:
# num.append(input)
# return num, denum
# def merge_num_denum(self, num, denum, outtype):
# ln, ld = len(num), len(denum)
# if not ln and not ld:
# return outtype.filter(self.simplify_constants([], []))
# if not ln:
# return self.reciprocal(self.merge_num_denum(denum, [], outtype))
# if not ld:
# if ln == 1:
# return num[0]
# else:
# return self.main(*num)
# return self.inverse(self.merge_num_denum(num, [], outtype),
# self.merge_num_denum(denum, [], outtype))
# def get_constant(self, v):
# if isinstance(v, gof.Constant):
# return v.data
# if v.owner and isinstance(v.owner.op, DimShuffle):
# return self.get_constant(v.owner.inputs[0])
# return None
# def simplify(self, num, denum):
# numct, denumct = [], []
# ncc, dcc = 0, 0
# for v in list(num):
# if v in denum:
# num.remove(v)
# denum.remove(v)
# continue
# ct = self.get_constant(v)
# if ct is not None:
# ncc += 1
# num.remove(v)
# numct.append(ct)
# for v in list(denum):
# ct = self.get_constant(v)
# if ct is not None:
# dcc += 1
# denum.remove(v)
# denumct.append(ct)
# ct = self.simplify_constants(numct, denumct)
# if ct is None:
# return ncc+dcc>0, None, num, denum
# ctop = self.constant_op.get(ct)
# if ctop is not None:
# return True, ctop, num, denum
# return not (ncc==1 and dcc==0), None, [ct]+num, denum
# def transform(self, node):
# op = node.op
# inputs = node.inputs
# if op == self.main:
# num, denum = self.get_num_denum(inputs)
# elif op == self.inverse:
# assert len(inputs) == 2
# n1, d1 = self.get_num_denum(inputs[:1])
# n2, d2 = self.get_num_denum(inputs[1:])
# num, denum = n1+d2, d1+n2
# elif op == self.reciprocal:
# denum, num = self.get_num_denum(inputs)
# else:
# return False
# change, ctop, num2, denum2 = self.simplify(num, denum)
# if change:
# num, denum = num2, denum2
# # print node, ct, num, denum
# # ctop = ct != [] and self.constant_op.get(ct[0], None)
# # if not ctop:
# # num = ct + num
# new = self.merge_num_denum(num, denum, node.outputs[0].type)
# if ctop:
# new = ctop(new)
# print new.owner.op, op, new.owner.inputs, inputs
# if new.owner and new.owner.op == op and all((new_input.owner new.owner.inputs == inputs:
# return False
# return [new]
# @gof.local_optimizer # @gof.local_optimizer
# def local_cut_middlemen(node): # def local_clique_fusion(node):
# op = node.op # aaaaaaaaaaaaaaaaaaaaaaa
# if isinstance(op, Elemwise):
# aaaaaaa
# # @gof.local_optimizer
# # def local_merge_mul(node):
# # op = node.op
# # if op != mul:
# # return False
# # num, denum = _get_num_denum(node.inputs)
# # if num == node.inputs and denum == []:
# # return False
# # return _
...@@ -816,290 +654,6 @@ math_optimizer = _math_optimizer() ...@@ -816,290 +654,6 @@ math_optimizer = _math_optimizer()
# class Lift(gof.LocalOptimizer):
# def __init__(self, op, lifters, chooser):
# self.op = op
# self.lifters = lifters
# self.chooser = chooser
# def op_key(self):
# return self.op
# def transform(self, node):
# if not node.op == self.op:
# return False
# candidates = [node.inputs[0]]
# seen = set(candidates)
# while True:
# candidate = candidates.pop()
# for lifter in self.lifters:
# new_candidates = lifter(candidate)
# if not new_candidates:
# break
# else:
# candidates.append(candidate)
# new_op = self.op(self.chooser(candidates))
# return new_op
# class Canonizer(gof.Optimizer):
# """
# Simplification tool.
# Usage: Canonizer(main, inverse, reciprocal, mainfn, invfn, recfn, transform)
# * main: a suitable Op class that is commutative, associative and takes
# one to an arbitrary number of inputs, e.g. Add or Mul
# * inverse: an Op class such that inverse(main(x, y), y) == x
# e.g. Sub or Div
# * reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y)
# e.g. Neg or Inv
# * mainfn, invfn, recfn: functions that behave just like the previous three
# Ops, but on true scalars (e.g. their impl)
# * transform: a function that maps (numerator, denominatur) where numerator
# and denominator are lists of Result instances, to new lists
# where further simplifications may have been applied.
# Examples:
# add_canonizer = Canonizer(Add, Sub, Neg, lambda *inputs: sum(inputs), ...)
# mul_canonizer = Canonizer(Mul, Div, Inv, lambda *inputs: product(inputs), ...)
# Examples of optimizations mul_canonizer can perform:
# x / x -> 1
# (x * y) / x -> y
# x / y / x -> 1 / y
# x / y / z -> x / (y * z)
# x / (y / z) -> (x * z) / y
# (a / b) * (b / c) * (c / d) -> a / d
# (2.0 * x) / (4.0 * y) -> (0.5 * x) / y
# 2 * x / 2 -> x
# """
# def __init__(self, main, inverse, reciprocal, mainfn, invfn, recfn, transform = None):
# self.main = main
# self.inverse = inverse
# self.reciprocal = reciprocal
# self.mainfn = mainfn
# self.invfn = invfn
# self.recfn = recfn
# self.neutral = mainfn()
# self.transform = transform
# def apply(self, env):
# def edge(r):
# return r.owner is None
# def follow(r):
# return None if r.owner is None else r.owner.inputs
# def canonize(r):
# next = follow(r)
# if next is None:
# return
# def flatten(r, nclients_check = True):
# # Collapses a tree of main/inverse/reciprocal Ops (aka Mul/Div/Inv or Add/Sub/Neg)
# # into a list of numerators and a list of denominators
# # e.g. (x*(1/y))*(x/(z/a)) aka Mul(Mul(x, (Inv, y)), Div(x, Div(z, a))) -> [x, x, a], [z, y]
# if edge(r):
# return [r], []
# node = r.owner
# op = node.op
# results = [r2.type == r.type and flatten(r2) or ([r2], []) for r2 in node.inputs]
# if op == self.main and (not nclients_check or env.nclients(r) == 1):
# nums = [x[0] for x in results]
# denums = [x[1] for x in results]
# elif op == self.inverse and (not nclients_check or env.nclients(r) == 1):
# # num, denum of the second argument are added to the denum, num respectively
# nums = [results[0][0], results[1][1]]
# denums = [results[0][1], results[1][0]]
# elif op == self.reciprocal and (not nclients_check or env.nclients(r) == 1):
# # num, denum of the sole argument are added to the denum, num respectively
# nums = [results[0][1]]
# denums = [results[0][0]]
# else:
# return [r], []
# return reduce(list.__add__, nums), reduce(list.__add__, denums)
# num, denum = flatten(r, False)
# if (num, denum) == ([r], []):
# for input in (follow(r) or []):
# canonize(input)
# return
# # Terms that are both in the num and denum lists cancel each other
# for d in list(denum):
# if d in list(num):
# # list.remove only removes the element once
# num.remove(d)
# denum.remove(d)
# # We identify the constants in num and denum
# numct, num = gof.utils.partition(lambda factor: isinstance(factor, gof.Constant) and factor.data is not None, num)
# denumct, denum = gof.utils.partition(lambda factor: isinstance(factor, gof.Constant) and factor.data is not None, denum)
# #print numct, num
# #print denumct, denum
# print num, denum
# # All constants in num and denum are combined into a single constant which we add to num (unless it's a neutral constant)
# v = self.invfn(self.mainfn(*[x.data for x in numct]), self.mainfn(*[x.data for x in denumct]))
# if v != self.neutral:
# num.insert(0, C(v))
# # We optimize the num and denum lists further if requested
# if self.transform is not None:
# num, denum = self.transform(env, num, denum)
# def make(factors):
# # Combines the factors using self.main (aka Mul) depending
# # on the number of elements.
# n = len(factors)
# if n == 0:
# return None
# elif n == 1:
# return factors[0]
# else:
# return self.main(*factors)
# numr, denumr = make(num), make(denum)
# if numr is None:
# if denumr is None:
# # Everything cancelled each other so we're left with
# # the neutral element.
# new_r = gof.Constant(r.type, self.neutral)
# else:
# # There's no numerator so we use reciprocal
# new_r = self.reciprocal(denumr)
# else:
# if denumr is None:
# new_r = numr
# else:
# new_r = self.inverse(numr, denumr)
# # Hopefully this won't complain!
# env.replace(r, new_r)
# for factor in num + denum:
# canonize(factor)
# for output in env.outputs:
# canonize(output)
# _mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# _divfn = lambda x, y: x / y
# _invfn = lambda x: 1 / x
# mul_canonizer = Canonizer(T.mul, T.div, T.inv, _mulfn, _divfn, _invfn)
# class DimShuffleLifter(opt.Optimizer):
# """
# Usage: lift_dimshuffle.optimize(env)
# "Lifts" DimShuffle through Broadcast operations and merges
# consecutive DimShuffles. Basically, applies the following
# transformations on the whole graph:
# DimShuffle(Broadcast(x, y)) => Broadcast(DimShuffle(x), DimShuffle(y))
# DimShuffle(DimShuffle(x)) => DimShuffle(x)
# After this transform, clusters of Broadcast operations are
# void of DimShuffle operations.
# """
# def apply(self, env):
# seen = set()
# def lift(r):
# if r in seen:
# return
# seen.add(r)
# if env.edge(r):
# return
# op = r.owner
# if isinstance(op, DimShuffle):
# in_op = op.inputs[0].owner
# if isinstance(in_op, DimShuffle):
# # DimShuffle(DimShuffle(x)) => DimShuffle(x)
# new_order = [x == 'x' and 'x' or in_op.new_order[x] for x in op.new_order]
# if new_order == range(len(new_order)):
# repl = in_op.inputs[0]
# else:
# repl = DimShuffle(in_op.inputs[0], new_order).out
# env.replace(r, repl)
# lift(repl)
# return
# elif isinstance(in_op, Broadcast):
# # DimShuffle(Broadcast(x, y)) => Broadcast(DimShuffle(x), DimShuffle(y))
# repl = Broadcast(in_op.scalar_opclass,
# [DimShuffle(input, op.new_order).out for input in in_op.inputs],
# in_op.inplace_pattern).out
# env.replace(r, repl)
# r = repl
# op = r.owner
# for next_r in op.inputs:
# lift(next_r)
# for output in env.outputs:
# lift(output)
# lift_dimshuffle = DimShuffleLifter()
# def find_cliques(env, through_broadcast = False): # def find_cliques(env, through_broadcast = False):
# """ # """
# Usage: find_cliques(env, through_broadcast = False) # Usage: find_cliques(env, through_broadcast = False)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论