提交 a3c57764 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

made submodules for compile, scalar and sparse

上级 444ea1e4
## PENDING REWRITE OF scalar_opt.py
# unittest
# from gof import Result, Op, Env, modes
# import gof
# from scalar import *
# from scalar_opt import *
# def inputs():
# return floats('xyz')
# def more_inputs():
# return floats('abcd')
# class _test_opts(unittest.TestCase):
# def test_pow_to_sqr(self):
# x, y, z = floats('xyz')
# e = x ** 2.0
# g = Env([x], [e])
# assert str(g) == "[pow(x, 2.0)]"
# pow2sqr_float.optimize(g)
# assert str(g) == "[sqr(x)]"
# class _test_canonize(unittest.TestCase):
# # def test_muldiv(self):
# # x, y, z = inputs()
# # a, b, c, d = more_inputs()
# # # e = (2.0 * x) / (2.0 * y)
# # # e = (2.0 * x) / (4.0 * y)
# # # e = x / (y / z)
# # # e = (x * y) / x
# # # e = (x / y) * (y / z) * (z / x)
# # # e = (a / b) * (b / c) * (c / d)
# # # e = (a * b) / (b * c) / (c * d)
# # # e = 2 * x / 2
# # e = x / y / x
# # g = Env([x, y, z, a, b, c, d], [e])
# # print g
# # mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# # divfn = lambda x, y: x / y
# # invfn = lambda x: 1 / x
# # Canonizer(mul, div, inv, mulfn, divfn, invfn).optimize(g)
# # print g
# # def test_plusmin(self):
# # x, y, z = inputs()
# # a, b, c, d = more_inputs()
# # # e = x - x
# # # e = (2.0 + x) - (2.0 + y)
# # # e = (2.0 + x) - (4.0 + y)
# # # e = x - (y - z)
# # # e = (x + y) - x
# # # e = (x - y) + (y - z) + (z - x)
# # # e = (a - b) + (b - c) + (c - d)
# # # e = x + -y
# # # e = a - b - b + a + b + c + b - c
# # # e = x + log(y) - x + y
# # e = 2.0 + x + 4.0
# # g = Env([x, y, z, a, b, c, d], [e])
# # print g
# # gof.ConstantFinder().optimize(g)
# # addfn = lambda *inputs: sum(inputs)
# # subfn = lambda x, y: x - y
# # negfn = lambda x: -x
# # Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
# # print g
# # def test_both(self):
# # x, y, z = inputs()
# # a, b, c, d = more_inputs()
# # e0 = (x * y / x)
# # e = e0 + e0 - e0
# # g = Env([x, y, z, a, b, c, d], [e])
# # print g
# # gof.ConstantFinder().optimize(g)
# # mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# # divfn = lambda x, y: x / y
# # invfn = lambda x: 1 / x
# # Canonizer(Mul, Div, Inv, mulfn, divfn, invfn).optimize(g)
# # addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs)
# # subfn = lambda x, y: x - y
# # negfn = lambda x: -x
# # Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
# # print g
# # def test_group_powers(self):
# # x, y, z, a, b, c, d = floats('xyzabcd')
# ###################
# # c1, c2 = constant(1.), constant(2.)
# # #e = pow(x, c1) * pow(x, y) / pow(x, 7.0) # <-- fucked
# # #f = -- moving from div(mul.out, pow.out) to pow(x, sub.out)
# # e = div(mul(pow(x, 2.0), pow(x, y)), pow(x, 7.0))
# # g = Env([x, y, z, a, b, c, d], [e])
# # print g
# # print g.inputs, g.outputs, g.orphans
# # f = sub(add(2.0, y), add(7.0))
# # g.replace(e, pow(x, f))
# # print g
# # print g.inputs, g.outputs, g.orphans
# # g.replace(f, sub(add(2.0, y), add(7.0))) # -- moving from sub(add.out, add.out) to sub(add.out, add.out)
# # print g
# # print g.inputs, g.outputs, g.orphans
# ###################
# # # e = x * exp(y) * exp(z)
# # # e = x * pow(x, y) * pow(x, z)
# # # e = pow(x, y) / pow(x, z)
# # e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) # <-- fucked
# # # e = pow(x - x, y)
# # # e = pow(x, 2.0 + y - 7.0)
# # # e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) / pow(x, z)
# # # e = pow(x, 2.0 + y - 7.0 - z)
# # # e = x ** y / x ** y
# # # e = x ** y / x ** (y - 1.0)
# # # e = exp(x) * a * exp(y) / exp(z)
# # g = Env([x, y, z, a, b, c, d], [e])
# # g.extend(gof.PrintListener(g))
# # print g, g.orphans
# # mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# # divfn = lambda x, y: x / y
# # invfn = lambda x: 1 / x
# # Canonizer(mul, div, inv, mulfn, divfn, invfn, group_powers).optimize(g)
# # print g, g.orphans
# # addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs)
# # subfn = lambda x, y: x - y
# # negfn = lambda x: -x
# # Canonizer(add, sub, neg, addfn, subfn, negfn).optimize(g)
# # print g, g.orphans
# # pow2one_float.optimize(g)
# # pow2x_float.optimize(g)
# # print g, g.orphans
# if __name__ == '__main__':
# unittest.main()
#import function
from function import *
...@@ -5,9 +5,8 @@ import cPickle ...@@ -5,9 +5,8 @@ import cPickle
from functools import partial from functools import partial
import numpy import numpy
import gof from .. import gof
import sys import sys
from copy import copy from copy import copy
......
from basic import *
from basic import _abs
import opt
import unittest import unittest
from gof import Result, Op, Env from ..gof import Result, Op, Env
import gof from .. import gof
from scalar import * from basic import *
def inputs(): def inputs():
......
## PENDING REWRITE OF opt.py
...@@ -4,9 +4,9 @@ from copy import copy ...@@ -4,9 +4,9 @@ from copy import copy
import numpy import numpy
import gof from .. import gof
from gof import Op, utils, Result, Constant, Type, Apply, Env from ..gof import Op, utils, Result, Constant, Type, Apply, Env
from gof.python25 import partial from ..gof.python25 import partial
def upcast(dtype, *dtypes): def upcast(dtype, *dtypes):
z = numpy.zeros((), dtype = dtype) z = numpy.zeros((), dtype = dtype)
......
# TODO: everything?
from scalar import *
from gof import PatternOptimizer as Pattern
from gof import utils
C = constant
# x**2 -> x*x
pow2sqr_float = Pattern((pow, 'x', C(2.0)), (sqr, 'x'))
pow2sqr_int = Pattern((pow, 'x', C(2)), (sqr, 'x'))
# x**0 -> 1
pow2one_float = Pattern((pow, 'x', C(0.0)), C(1.0))
pow2one_int = Pattern((pow, 'x', C(0)), C(1))
# x**1 -> x
pow2x_float = Pattern((pow, 'x', C(1.0)), 'x')
pow2x_int = Pattern((pow, 'x', C(1)), 'x')
# log(x**y) -> y*log(x)
logpow = Pattern((log, (pow, 'x', 'y')),
(mul, 'y', (log, 'x')))
class Canonizer(gof.Optimizer):
"""
Simplification tool.
Usage: Canonizer(main, inverse, reciprocal, mainfn, invfn, recfn, transform)
* main: a suitable Op class that is commutative, associative and takes
one to an arbitrary number of inputs, e.g. Add or Mul
* inverse: an Op class such that inverse(main(x, y), y) == x
e.g. Sub or Div
* reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y)
e.g. Neg or Inv
* mainfn, invfn, recfn: functions that behave just like the previous three
Ops, but on true scalars (e.g. their impl)
* transform: a function that maps (numerator, denominatur) where numerator
and denominator are lists of Result instances, to new lists
where further simplifications may have been applied.
Examples:
add_canonizer = Canonizer(Add, Sub, Neg, lambda *inputs: sum(inputs), ...)
mul_canonizer = Canonizer(Mul, Div, Inv, lambda *inputs: product(inputs), ...)
Examples of optimizations mul_canonizer can perform:
x / x -> 1
(x * y) / x -> y
x / y / x -> 1 / y
x / y / z -> x / (y * z)
x / (y / z) -> (x * z) / y
(a / b) * (b / c) * (c / d) -> a / d
(2.0 * x) / (4.0 * y) -> (0.5 * x) / y
2 * x / 2 -> x
"""
def __init__(self, main, inverse, reciprocal, mainfn, invfn, recfn, transform = None):
self.main = main
self.inverse = inverse
self.reciprocal = reciprocal
self.mainfn = mainfn
self.invfn = invfn
self.recfn = recfn
self.neutral = mainfn()
self.transform = transform
def apply(self, env):
def canonize(r):
next = env.follow(r)
if next is None:
return
def flatten(r, nclients_check = True):
# Collapses a tree of main/inverse/reciprocal Ops (aka Mul/Div/Inv or Add/Sub/Neg)
# into a list of numerators and a list of denominators
# e.g. (x*(1/y))*(x/(z/a)) aka Mul(Mul(x, (Inv, y)), Div(x, Div(z, a))) -> [x, x, a], [z, y]
if env.edge(r):
return [r], []
print "a", r, r.owner, env, env.orphans
node = r.owner
op = node.op
print "b"
results = [r2.type == r.type and flatten(r2) or ([r2], []) for r2 in node.inputs]
if op == self.main and (not nclients_check or env.nclients(r) == 1):
nums = [x[0] for x in results]
denums = [x[1] for x in results]
elif op == self.inverse and (not nclients_check or env.nclients(r) == 1):
# num, denum of the second argument are added to the denum, num respectively
nums = [results[0][0], results[1][1]]
denums = [results[0][1], results[1][0]]
elif op == self.reciprocal and (not nclients_check or env.nclients(r) == 1):
# num, denum of the sole argument are added to the denum, num respectively
nums = [results[0][1]]
denums = [results[0][0]]
else:
return [r], []
return reduce(list.__add__, nums), reduce(list.__add__, denums)
num, denum = flatten(r, False)
if (num, denum) == ([r], []):
for input in (env.follow(r) or []):
canonize(input)
return
# Terms that are both in the num and denum lists cancel each other
for d in list(denum):
if d in list(num):
# list.remove only removes the element once
num.remove(d)
denum.remove(d)
# We identify the constants in num and denum
numct, num = utils.partition(lambda factor: isinstance(factor, Constant) and factor.data is not None, num)
denumct, denum = utils.partition(lambda factor: isinstance(factor, Constant) and factor.data is not None, denum)
# All constants in num and denum are combined into a single constant which we add to num (unless it's a neutral constant)
v = self.invfn(self.mainfn(*[x.data for x in numct]), self.mainfn(*[x.data for x in denumct]))
if v != self.neutral:
num.insert(0, C(v))
# We optimize the num and denum lists further if requested
if self.transform is not None:
num, denum = self.transform(env, num, denum)
def make(factors):
# Combines the factors using self.main (aka Mul) depending
# on the number of elements.
n = len(factors)
if n == 0:
return None
elif n == 1:
return factors[0]
else:
return self.main(*factors)
numr, denumr = make(num), make(denum)
if numr is None:
if denumr is None:
# Everything cancelled each other so we're left with
# the neutral element.
new_r = Constant(r.type, self.neutral)
else:
# There's no numerator so we use reciprocal
new_r = self.reciprocal(denumr)
else:
if denumr is None:
new_r = numr
else:
new_r = self.inverse(numr, denumr)
# Hopefully this won't complain!
env.replace(r, new_r)
for factor in num + denum:
canonize(factor)
for output in env.outputs:
canonize(output)
def group_powers(env, num, denum):
"""
Plugin for Canonizer: use as Canonizer(..., transform = group_powers)
Takes num, denum such that mul(*num) / mul(*denum) is in env
and searches for instances of exp(x) or x**y in order to group
together powers of the same variable. Returns num2, denum2 in
which the grouping has been done.
Note: this function does not modify env.
Examples:
group_powers([x, exp(x), exp(y)], [exp(z)]) -> [x, exp(x+y-z)], []
"""
# maps a base to the list of powers it is raised to in the
# numerator/denominator lists.
num_powers = {}
denum_powers = {}
def populate(d, seq):
# For each instance of exp or pow in seq, removes it from seq
# and does d[base].append(power).
for factor in list(seq):
if env.edge(factor):
continue
node = factor.owner
op = node.op
if op == exp:
d.setdefault('e', []).append(node.inputs[0])
seq.remove(factor)
elif op == pow:
d.setdefault(node.inputs[0], []).append(node.inputs[1])
seq.remove(factor)
populate(num_powers, num)
populate(denum_powers, denum)
for x in set(num_powers.keys() + denum_powers.keys()):
# we append base ** (num_powers[base] - denum_powers[base])
# to the num list
try: num_ys = num_powers.pop(x)
except KeyError: num_ys = []
try: denum_ys = denum_powers.pop(x)
except KeyError: denum_ys = []
num_r = num_ys and add(*num_ys) or C(0)
denum_r = denum_ys and add(*denum_ys) or C(0)
if x == 'e':
num.append(exp(num_r - denum_r))
else:
num.append(pow(x, num_r - denum_r))
return num, denum
def simple_factorize(env, num, denum):
# a*b + a*c -> a*(b+c)
# a*b + a*c + b*c -> a*(b+c) + b*c
# -> a*b + (a+b)*c
# => a: {b, c}, b: {a, c}, c: {a, b}
# a*c + a*d + b*c + b*d
# => a: {c, d}, b: {c, d}, c: {a, b}, d: {a, b}
# (a+b*x)*(c+d) --> a*c + a*d + b*x*c + b*x*d
# => a: {c, d}, b: {xc, xd}, c: {a, bx}, d: {a, bx}, x: {bc, bd}
pass
from basic import *
from sparse import * from sparse import *
import unittest import unittest
import compile from .. import compile
import gradient from .. import gradient
from sparse import _is_dense, _is_sparse, _is_dense_result, _is_sparse_result from basic import _is_dense, _is_sparse, _is_dense_result, _is_sparse_result
from sparse import _mtypes, _mtype_to_str from basic import _mtypes, _mtype_to_str
import random import random
import gof from .. import gof
def eval_outputs(outputs): def eval_outputs(outputs):
return compile.function([], outputs)()[0] return compile.function([], outputs)()[0]
......
...@@ -9,9 +9,8 @@ To read about different sparse formats, see U{http://www-users.cs.umn.edu/~saad/ ...@@ -9,9 +9,8 @@ To read about different sparse formats, see U{http://www-users.cs.umn.edu/~saad/
import numpy import numpy
from scipy import sparse from scipy import sparse
import gof from .. import gof
import gof.op from .. import tensor
import tensor
""" Types of sparse matrices to use for testing """ """ Types of sparse matrices to use for testing """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论