提交 f267e481 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

commit

...@@ -102,6 +102,19 @@ from theano.tensor import * ...@@ -102,6 +102,19 @@ from theano.tensor import *
from theano.sandbox import pprint from theano.sandbox import pprint
class _test_greedy_distribute(unittest.TestCase):
def test_main(self):
a, b, c, d, x, y, z = matrices('abcdxyz')
e = (a/z + b/x) * x * z
g = Env([a,b,c,d,x,y,z], [e])
print pprint.pp.process(g.outputs[0])
mul_canonizer.optimize(g)
gof.TopoOptimizer(gof.LocalOptGroup(local_fill_cut, local_fill_lift), order = 'out_to_in').optimize(g)
gof.TopoOptimizer(gof.LocalOptGroup(local_greedy_distributor), order = 'out_to_in').optimize(g)
print pprint.pp.process(g.outputs[0])
class _test_canonize(unittest.TestCase): class _test_canonize(unittest.TestCase):
def test_muldiv(self): def test_muldiv(self):
......
...@@ -615,6 +615,9 @@ class Log(UnaryScalarOp): ...@@ -615,6 +615,9 @@ class Log(UnaryScalarOp):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return gz / x, return gz / x,
def c_code(self, node, name, (x, ), (z, ), sub): def c_code(self, node, name, (x, ), (z, ), sub):
#todo: the version using log2 seems to be very slightly faster
# on some machines for some reason, check if it's worth switching
#return "%(z)s = log2(%(x)s) * 0.69314718055994529;" % locals()
return "%(z)s = log(%(x)s);" % locals() return "%(z)s = log(%(x)s);" % locals()
log = Log(upgrade_to_float, name = 'log') log = Log(upgrade_to_float, name = 'log')
......
...@@ -6,6 +6,17 @@ import scalar ...@@ -6,6 +6,17 @@ import scalar
import tensor as T import tensor as T
import numpy as N import numpy as N
import operator import operator
import itertools
# Utilities
def out2in(*local_opts):
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts), order = 'out_to_in')
def in2out(*local_opts):
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts), order = 'in_to_out')
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c) # gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
# Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0) # Transforms d -= a * dot(b, c) into gemm(d, -a, b, c, 1.0)
...@@ -30,7 +41,7 @@ dot_to_gemm = gof.PatternSub((T.dot, 'a', 'b'), ...@@ -30,7 +41,7 @@ dot_to_gemm = gof.PatternSub((T.dot, 'a', 'b'),
@gof.optimizer @gof.optimizer
def inplace_optimizer(self, env): def insert_inplace_optimizer(self, env):
""" """
Usage: inplace_optimizer.optimize(env) Usage: inplace_optimizer.optimize(env)
...@@ -64,6 +75,10 @@ def inplace_optimizer(self, env): ...@@ -64,6 +75,10 @@ def inplace_optimizer(self, env):
baseline = inplace_pattern baseline = inplace_pattern
break break
inplace_optimizer = gof.SeqOptimizer(out2in(gemm_pattern_1),
out2in(dot_to_gemm),
insert_inplace_optimizer)
###################### ######################
# DimShuffle lifters # # DimShuffle lifters #
...@@ -293,6 +308,43 @@ def local_fill_sink(node): ...@@ -293,6 +308,43 @@ def local_fill_sink(node):
################ ################
class Canonizer(gof.LocalOptimizer): class Canonizer(gof.LocalOptimizer):
"""
Simplification tool.
Usage: Canonizer(main, inverse, reciprocal, calculate)
* main: a suitable Op class that is commutative, associative and takes
one to an arbitrary number of inputs, e.g. Add or Mul
* inverse: an Op class such that inverse(main(x, y), y) == x
e.g. Sub or Div
* reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y)
e.g. Neg or Inv
* calculate: function that takes a list of numpy.ndarray instances for
the numerator, another list for the denumerator, and calculates
inverse(main(*num), main(*denum)). It takes a keyword argument,
aslist. If True, the value should be returned as a list of one
element, unless the value is such that value = main(). In that
case, the return value should be an empty list.
The result is a local_optimizer. It is best used with a TopoOptimizer in
in_to_out order.
Examples:
T = theano.tensor
add_canonizer = Canonizer(T.add, T.sub, T.neg, lambda n, d: sum(n) - sum(d))
mul_canonizer = Canonizer(T.mul, T.div, T.inv, lambda n, d: prod(n) / prod(d))
Examples of optimizations mul_canonizer can perform:
x / x -> 1
(x * y) / x -> y
x / y / x -> 1 / y
x / y / z -> x / (y * z)
x / (y / z) -> (x * z) / y
(a / b) * (b / c) * (c / d) -> a / d
(2.0 * x) / (4.0 * y) -> (0.5 * x) / y
2 * x / 2 -> x
"""
def __init__(self, main, inverse, reciprocal, calculate): def __init__(self, main, inverse, reciprocal, calculate):
self.main = main self.main = main
...@@ -332,11 +384,14 @@ class Canonizer(gof.LocalOptimizer): ...@@ -332,11 +384,14 @@ class Canonizer(gof.LocalOptimizer):
return self.inverse(self.merge_num_denum(num, []), return self.inverse(self.merge_num_denum(num, []),
self.merge_num_denum(denum, [])) self.merge_num_denum(denum, []))
def get_constant(self, v): @classmethod
def get_constant(cls, v):
if isinstance(v, N.generic):
return v
if isinstance(v, gof.Constant): if isinstance(v, gof.Constant):
return v.data return v.data
if v.owner and isinstance(v.owner.op, DimShuffle): if v.owner and isinstance(v.owner.op, DimShuffle):
return self.get_constant(v.owner.inputs[0]) return cls.get_constant(v.owner.inputs[0])
return None return None
def simplify(self, num, denum): def simplify(self, num, denum):
...@@ -366,7 +421,9 @@ class Canonizer(gof.LocalOptimizer): ...@@ -366,7 +421,9 @@ class Canonizer(gof.LocalOptimizer):
denum.remove(v) denum.remove(v)
denumct.append(ct) denumct.append(ct)
ct = self.calculate(numct, denumct, aslist = True) ct = self.calculate(numct, denumct, aslist = True)
if len(ct) and ncc == 1 and dcc == 0: # if len(ct) and ncc == 1 and dcc == 0:
# return orig_num, orig_denum
if orig_num and ct == self.get_constant(orig_num[0]):
return orig_num, orig_denum return orig_num, orig_denum
return ct + num, denum return ct + num, denum
...@@ -398,6 +455,7 @@ class Canonizer(gof.LocalOptimizer): ...@@ -398,6 +455,7 @@ class Canonizer(gof.LocalOptimizer):
new = T.fill(out, new) new = T.fill(out, new)
return [new] return [new]
def mul_calculate(num, denum, aslist = False): def mul_calculate(num, denum, aslist = False):
v = reduce(N.multiply, num, 1.0) / reduce(N.multiply, denum, 1.0) v = reduce(N.multiply, num, 1.0) / reduce(N.multiply, denum, 1.0)
if aslist: if aslist:
...@@ -408,7 +466,26 @@ def mul_calculate(num, denum, aslist = False): ...@@ -408,7 +466,26 @@ def mul_calculate(num, denum, aslist = False):
return v return v
local_mul_canonizer = Canonizer(T.mul, T.div, T.inv, mul_calculate) local_mul_canonizer = Canonizer(T.mul, T.div, T.inv, mul_calculate)
mul_canonizer = gof.TopoOptimizer(gof.LocalOptGroup(local_mul_canonizer, local_fill_sink), order = 'in_to_out')
@gof.local_optimizer
def local_neg_to_mul(node):
if node.op == T.neg:
return [-1.0 * node.inputs[0]]
else:
return False
@gof.local_optimizer
def local_mul_to_neg(node):
if node.op == T.mul and local_mul_canonizer.get_constant(node.inputs[0]) == -1.0:
return [-local_mul_canonizer.merge_num_denum(node.inputs[1:], [])]
else:
return False
neg_to_mul = gof.TopoOptimizer(gof.LocalOptGroup(local_neg_to_mul), order = 'out_to_in')
mul_to_neg = gof.TopoOptimizer(gof.LocalOptGroup(local_mul_to_neg), order = 'out_to_in')
mul_canonizer = gof.TopoOptimizer(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut, local_fill_sink), order = 'in_to_out')
def add_calculate(num, denum, aslist = False): def add_calculate(num, denum, aslist = False):
v = reduce(N.add, num, 0.0) - reduce(N.add, denum, 0.0) v = reduce(N.add, num, 0.0) - reduce(N.add, denum, 0.0)
...@@ -420,7 +497,7 @@ def add_calculate(num, denum, aslist = False): ...@@ -420,7 +497,7 @@ def add_calculate(num, denum, aslist = False):
return v return v
local_add_canonizer = Canonizer(T.add, T.sub, T.neg, add_calculate) local_add_canonizer = Canonizer(T.add, T.sub, T.neg, add_calculate)
add_canonizer = gof.TopoOptimizer(gof.LocalOptGroup(local_add_canonizer, local_fill_sink), order = 'in_to_out') add_canonizer = gof.TopoOptimizer(gof.LocalOptGroup(local_add_canonizer, local_fill_cut, local_fill_sink), order = 'in_to_out')
################## ##################
...@@ -429,488 +506,144 @@ add_canonizer = gof.TopoOptimizer(gof.LocalOptGroup(local_add_canonizer, local_f ...@@ -429,488 +506,144 @@ add_canonizer = gof.TopoOptimizer(gof.LocalOptGroup(local_add_canonizer, local_f
def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore = 0): def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore = 0):
score = len(num) + len(denum) # score is number of operations saved, higher is better # each pair in pos_pairs and neg_pairs is a num/denum pair. this
new_pos_pairs = itertools.starmap(local_mul_canonizer.simplify, # function attempts to add num and denum to the corresponding parts
[(n+num, d+denum) for (n, d) in plus_pairs]) # of each pair, and counts how many multiplications/divisions can
new_neg_pairs = itertools.starmap(local_mul_canonizer.simplify, # be saved in that way.
[(n+num, d+denum) for (n, d) in plus_pairs])
# each division is counted like div_cost multiplications
# (typically, division costs more so we are willing to multiply more
# in order to divide less)
# 1.5 was obtained through an informal test and may very well be
# platform dependent
div_cost = 1.5
score = len(num) + div_cost * len(denum) # score is number of operations saved, higher is better
new_pos_pairs = list(itertools.starmap(local_mul_canonizer.simplify,
[(n+num, d+denum) for (n, d) in pos_pairs]))
new_neg_pairs = list(itertools.starmap(local_mul_canonizer.simplify,
[(n+num, d+denum) for (n, d) in neg_pairs]))
for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs + new_neg_pairs): for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs + new_neg_pairs):
# We calculate how many operations we are saving with the new num and denum # We calculate how many operations we are saving with the new num and denum
score += len(n) + len(d) - len(nn) - len(dd) score += len(n) + div_cost * len(d) - len(nn) - div_cost * len(dd)
if score < minscore: if score <= minscore:
# the change is not applied because it adds too many operations
return False, pos_pairs, neg_pairs return False, pos_pairs, neg_pairs
return True, new_pos_pairs, new_neg_pairs return True, new_pos_pairs, new_neg_pairs
def attempt_distribution(factor, num, denum):
# we try to insert each num and each denum in the factor
# returns: changes?, new_factor, new_num, new_denum
# if there are changes, new_num and new_denum contain all the numerators
# and denumerators that could not be distributed in the factor
pos, neg = local_add_canonizer.get_num_denum(factor)
if len(pos) == 1 and not neg:
return False, factor, num, denum
pos_pairs = map(local_mul_canonizer.get_num_denum, pos)
neg_pairs = map(local_mul_canonizer.get_num_denum, neg)
change = False
for n in list(num):
success, pos_pairs, neg_pairs = distribute_greedy(pos_pairs, neg_pairs, [n], [])
if success:
change = True
num.remove(n)
for d in list(denum):
success, pos_pairs, neg_pairs = distribute_greedy(pos_pairs, neg_pairs, [], [d])
if success:
change = True
denum.remove(d)
if not change:
return change, factor, num, denum
else:
return change, local_add_canonizer.merge_num_denum(
list(itertools.starmap(local_mul_canonizer.merge_num_denum, pos_pairs)),
list(itertools.starmap(local_mul_canonizer.merge_num_denum, neg_pairs))), num, denum
@gof.local_optimizer @gof.local_optimizer
def local_greedy_distributor(node): def local_greedy_distributor(node):
""" """
This optimization tries to apply distributivity of multiplication
to addition in order to reduce the number of multiplications
and/or divisions that must be done. The algorithm weighs division
more than multiplication to account for the former's slightly
greater computational cost.
The following expressions are simplified: The following expressions are simplified:
((a/x + b/y) * x * y) --> a*y + b*x 1. ((a/x + b/y) * x * y) --> a*y + b*x
((a/x + b) * x) --> a + b*x 2. ((a/x + b) * x) --> a + b*x
The following expressions are not: The following expressions are not simplified:
((a + b) * x) -X-> a*x + b*x 3. ((a + b) * x) -/-> a*x + b*x
This optimization aims to reduce computational cost. It may also
increase numerical stability, e.g. when x and/or y tend to 0 in
example 1.
""" """
out = node.outputs[0] out = node.outputs[0]
num, denum = local_mul_canonizer.get_num_denum(out) num, denum = local_mul_canonizer.get_num_denum(out)
if len(num) == 1 and not denum: if len(num) == 1 and not denum:
return False return False
new_num = []
for entry in num:
pos, neg = local_add_canonizer.get_num_denum(entry)
if len(pos) == 1 and not neg:
new_num.append(entry)
continue
pos_pairs = map(local_mul_canonizer.get_num_denum, pos)
neg_pairs = map(local_mul_canonizer.get_num_denum, neg)
# class Canonizer(gof.LocalOptimizer):
# def __init__(self, main, inverse, reciprocal, simplify_constants, constant_op):
# self.main = main
# self.inverse = inverse
# self.reciprocal = reciprocal
# self.simplify_constants = simplify_constants
# self.constant_op = constant_op
# def get_num_denum(self, input, depth):
# if depth == 0 or input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]:
# return [input], []
# num = []
# denum = []
# parent = input.owner
# pairs = [self.get_num_denum(input, depth - 1) for input in parent.inputs]
# if parent.op == self.main:
# num = reduce(list.__iadd__, map(operator.itemgetter(0), pairs))
# denum = reduce(list.__iadd__, map(operator.itemgetter(1), pairs))
# elif parent.op == self.inverse:
# num = pairs[0][0] + pairs[1][1]
# denum = pairs[0][1] + pairs[1][0]
# elif parent.op == self.reciprocal:
# num = pairs[0][1]
# denum = pairs[0][0]
# return num, denum
# def deep_num_denum(self, node):
# op = node.op
# if op == self.main:
# num, denum = self.get_num_denum(inputs)
# elif op == self.inverse:
# assert len(inputs) == 2
# n1, d1 = self.get_num_denum(inputs[:1])
# n2, d2 = self.get_num_denum(inputs[1:])
# num, denum = n1+d2, d1+n2
# elif op == self.reciprocal:
# denum, num = self.get_num_denum(inputs)
# else:
# num, denum = [node.outputs[0]], []
# return num, denum
# def get_num_denum(self, inputs):
# num = []
# denum = []
# for input in inputs:
# if input.owner is None:
# num.append(input)
# continue
# parent = input.owner
# if parent.op == self.main:
# num += parent.inputs
# elif parent.op == self.inverse:
# num += parent.inputs[:1]
# denum += parent.inputs[1:]
# elif parent.op == self.reciprocal:
# denum += parent.inputs
# else:
# num.append(input)
# return num, denum
# def merge_num_denum(self, num, denum, outtype):
# ln, ld = len(num), len(denum)
# if not ln and not ld:
# return outtype.filter(self.simplify_constants([], []))
# if not ln:
# return self.reciprocal(self.merge_num_denum(denum, [], outtype))
# if not ld:
# if ln == 1:
# return num[0]
# else:
# return self.main(*num)
# return self.inverse(self.merge_num_denum(num, [], outtype),
# self.merge_num_denum(denum, [], outtype))
# def get_constant(self, v):
# if isinstance(v, gof.Constant):
# return v.data
# if v.owner and isinstance(v.owner.op, DimShuffle):
# return self.get_constant(v.owner.inputs[0])
# return None
# def simplify(self, num, denum):
# numct, denumct = [], []
# ncc, dcc = 0, 0
# for v in list(num):
# if v in denum:
# num.remove(v)
# denum.remove(v)
# continue
# ct = self.get_constant(v)
# if ct is not None:
# ncc += 1
# num.remove(v)
# numct.append(ct)
# for v in list(denum):
# ct = self.get_constant(v)
# if ct is not None:
# dcc += 1
# denum.remove(v)
# denumct.append(ct)
# ct = self.simplify_constants(numct, denumct)
# if ct is None:
# return ncc+dcc>0, None, num, denum
# ctop = self.constant_op.get(ct)
# if ctop is not None:
# return True, ctop, num, denum
# return not (ncc==1 and dcc==0), None, [ct]+num, denum
# def transform(self, node):
# op = node.op
# inputs = node.inputs
# if op == self.main:
# num, denum = self.get_num_denum(inputs)
# elif op == self.inverse:
# assert len(inputs) == 2
# n1, d1 = self.get_num_denum(inputs[:1])
# n2, d2 = self.get_num_denum(inputs[1:])
# num, denum = n1+d2, d1+n2
# elif op == self.reciprocal:
# denum, num = self.get_num_denum(inputs)
# else:
# return False
# change, ctop, num2, denum2 = self.simplify(num, denum)
# if change:
# num, denum = num2, denum2
# # print node, ct, num, denum
# # ctop = ct != [] and self.constant_op.get(ct[0], None)
# # if not ctop:
# # num = ct + num
# new = self.merge_num_denum(num, denum, node.outputs[0].type)
# if ctop:
# new = ctop(new)
# print new.owner.op, op, new.owner.inputs, inputs
# if new.owner and new.owner.op == op and all((new_input.owner new.owner.inputs == inputs:
# return False
# return [new]
# @gof.local_optimizer
# def local_cut_middlemen(node):
# op = node.op
# if isinstance(op, Elemwise):
# aaaaaaa
# # @gof.local_optimizer
# # def local_merge_mul(node):
# # op = node.op
# # if op != mul:
# # return False
# # num, denum = _get_num_denum(node.inputs)
# # if num == node.inputs and denum == []:
# # return False
# # return _
new_num, new_denum = [], []
change = False
for candidate in list(num):
if candidate not in num:
continue
num.remove(candidate)
_change, candidate, num, denum = attempt_distribution(candidate, num, denum)
change |= _change
if change:
new_num.append(candidate)
for candidate in list(denum):
if candidate not in denum:
continue
denum.remove(candidate)
_change, candidate, denum, num = attempt_distribution(candidate, denum, num)
change |= _change
if change:
new_denum.append(candidate)
if not change:
return False
new_num += num
new_denum += denum
return [local_mul_canonizer.merge_num_denum(new_num, new_denum)]
def _math_optimizer():
pass_1 = in2out(local_fill_sink)
pass_2 = out2in(local_dimshuffle_lift, local_shape_lift, local_fill_lift)#, local_fill_cut)
pass_3 = out2in(local_subtensor_make_vector, local_fill_cut)
# class Lift(gof.LocalOptimizer):
# def __init__(self, op, lifters, chooser):
# self.op = op
# self.lifters = lifters
# self.chooser = chooser
# def op_key(self):
# return self.op
# def transform(self, node):
# if not node.op == self.op:
# return False
# candidates = [node.inputs[0]]
# seen = set(candidates)
# while True:
# candidate = candidates.pop()
# for lifter in self.lifters:
# new_candidates = lifter(candidate)
# if not new_candidates:
# break
# else:
# candidates.append(candidate)
# new_op = self.op(self.chooser(candidates))
# return new_op
# class Canonizer(gof.Optimizer):
# """
# Simplification tool.
# Usage: Canonizer(main, inverse, reciprocal, mainfn, invfn, recfn, transform)
# * main: a suitable Op class that is commutative, associative and takes
# one to an arbitrary number of inputs, e.g. Add or Mul
# * inverse: an Op class such that inverse(main(x, y), y) == x
# e.g. Sub or Div
# * reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y)
# e.g. Neg or Inv
# * mainfn, invfn, recfn: functions that behave just like the previous three
# Ops, but on true scalars (e.g. their impl)
# * transform: a function that maps (numerator, denominatur) where numerator
# and denominator are lists of Result instances, to new lists
# where further simplifications may have been applied.
# Examples:
# add_canonizer = Canonizer(Add, Sub, Neg, lambda *inputs: sum(inputs), ...)
# mul_canonizer = Canonizer(Mul, Div, Inv, lambda *inputs: product(inputs), ...)
# Examples of optimizations mul_canonizer can perform: canonizer = in2out(local_add_canonizer,
# x / x -> 1 local_mul_canonizer,
# (x * y) / x -> y local_fill_sink)
# x / y / x -> 1 / y
# x / y / z -> x / (y * z)
# x / (y / z) -> (x * z) / y
# (a / b) * (b / c) * (c / d) -> a / d
# (2.0 * x) / (4.0 * y) -> (0.5 * x) / y
# 2 * x / 2 -> x
# """
# def __init__(self, main, inverse, reciprocal, mainfn, invfn, recfn, transform = None):
# self.main = main
# self.inverse = inverse
# self.reciprocal = reciprocal
# self.mainfn = mainfn
# self.invfn = invfn
# self.recfn = recfn
# self.neutral = mainfn()
# self.transform = transform
# def apply(self, env):
# def edge(r):
# return r.owner is None
# def follow(r):
# return None if r.owner is None else r.owner.inputs
# def canonize(r):
# next = follow(r)
# if next is None:
# return
# def flatten(r, nclients_check = True):
# # Collapses a tree of main/inverse/reciprocal Ops (aka Mul/Div/Inv or Add/Sub/Neg)
# # into a list of numerators and a list of denominators
# # e.g. (x*(1/y))*(x/(z/a)) aka Mul(Mul(x, (Inv, y)), Div(x, Div(z, a))) -> [x, x, a], [z, y]
# if edge(r):
# return [r], []
# node = r.owner
# op = node.op
# results = [r2.type == r.type and flatten(r2) or ([r2], []) for r2 in node.inputs]
# if op == self.main and (not nclients_check or env.nclients(r) == 1):
# nums = [x[0] for x in results]
# denums = [x[1] for x in results]
# elif op == self.inverse and (not nclients_check or env.nclients(r) == 1):
# # num, denum of the second argument are added to the denum, num respectively
# nums = [results[0][0], results[1][1]]
# denums = [results[0][1], results[1][0]]
# elif op == self.reciprocal and (not nclients_check or env.nclients(r) == 1):
# # num, denum of the sole argument are added to the denum, num respectively
# nums = [results[0][1]]
# denums = [results[0][0]]
# else:
# return [r], []
# return reduce(list.__add__, nums), reduce(list.__add__, denums)
# num, denum = flatten(r, False)
# if (num, denum) == ([r], []):
# for input in (follow(r) or []):
# canonize(input)
# return
# # Terms that are both in the num and denum lists cancel each other
# for d in list(denum):
# if d in list(num):
# # list.remove only removes the element once
# num.remove(d)
# denum.remove(d)
# # We identify the constants in num and denum
# numct, num = gof.utils.partition(lambda factor: isinstance(factor, gof.Constant) and factor.data is not None, num)
# denumct, denum = gof.utils.partition(lambda factor: isinstance(factor, gof.Constant) and factor.data is not None, denum)
# #print numct, num
# #print denumct, denum
# print num, denum
# # All constants in num and denum are combined into a single constant which we add to num (unless it's a neutral constant)
# v = self.invfn(self.mainfn(*[x.data for x in numct]), self.mainfn(*[x.data for x in denumct]))
# if v != self.neutral:
# num.insert(0, C(v))
# # We optimize the num and denum lists further if requested
# if self.transform is not None:
# num, denum = self.transform(env, num, denum)
# def make(factors):
# # Combines the factors using self.main (aka Mul) depending
# # on the number of elements.
# n = len(factors)
# if n == 0:
# return None
# elif n == 1:
# return factors[0]
# else:
# return self.main(*factors)
# numr, denumr = make(num), make(denum)
# if numr is None:
# if denumr is None:
# # Everything cancelled each other so we're left with
# # the neutral element.
# new_r = gof.Constant(r.type, self.neutral)
# else:
# # There's no numerator so we use reciprocal
# new_r = self.reciprocal(denumr)
# else:
# if denumr is None:
# new_r = numr
# else:
# new_r = self.inverse(numr, denumr)
# # Hopefully this won't complain!
# env.replace(r, new_r)
# for factor in num + denum:
# canonize(factor)
# for output in env.outputs:
# canonize(output)
# _mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# _divfn = lambda x, y: x / y
# _invfn = lambda x: 1 / x
# mul_canonizer = Canonizer(T.mul, T.div, T.inv, _mulfn, _divfn, _invfn)
pass_4 = out2in(local_greedy_distributor)
return gof.SeqOptimizer(pass_1,
pass_2,
pass_3,
neg_to_mul,
canonizer,
pass_4,
mul_to_neg)
math_optimizer = _math_optimizer()
# @gof.local_optimizer
# def local_clique_fusion(node):
# aaaaaaaaaaaaaaaaaaaaaaa
...@@ -920,63 +653,6 @@ def local_greedy_distributor(node): ...@@ -920,63 +653,6 @@ def local_greedy_distributor(node):
# class DimShuffleLifter(opt.Optimizer):
# """
# Usage: lift_dimshuffle.optimize(env)
# "Lifts" DimShuffle through Broadcast operations and merges
# consecutive DimShuffles. Basically, applies the following
# transformations on the whole graph:
# DimShuffle(Broadcast(x, y)) => Broadcast(DimShuffle(x), DimShuffle(y))
# DimShuffle(DimShuffle(x)) => DimShuffle(x)
# After this transform, clusters of Broadcast operations are
# void of DimShuffle operations.
# """
# def apply(self, env):
# seen = set()
# def lift(r):
# if r in seen:
# return
# seen.add(r)
# if env.edge(r):
# return
# op = r.owner
# if isinstance(op, DimShuffle):
# in_op = op.inputs[0].owner
# if isinstance(in_op, DimShuffle):
# # DimShuffle(DimShuffle(x)) => DimShuffle(x)
# new_order = [x == 'x' and 'x' or in_op.new_order[x] for x in op.new_order]
# if new_order == range(len(new_order)):
# repl = in_op.inputs[0]
# else:
# repl = DimShuffle(in_op.inputs[0], new_order).out
# env.replace(r, repl)
# lift(repl)
# return
# elif isinstance(in_op, Broadcast):
# # DimShuffle(Broadcast(x, y)) => Broadcast(DimShuffle(x), DimShuffle(y))
# repl = Broadcast(in_op.scalar_opclass,
# [DimShuffle(input, op.new_order).out for input in in_op.inputs],
# in_op.inplace_pattern).out
# env.replace(r, repl)
# r = repl
# op = r.owner
# for next_r in op.inputs:
# lift(next_r)
# for output in env.outputs:
# lift(output)
# lift_dimshuffle = DimShuffleLifter()
# def find_cliques(env, through_broadcast = False): # def find_cliques(env, through_broadcast = False):
# """ # """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论