提交 bf3df169 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixes here and there, doc

上级 7158e6d8
...@@ -42,7 +42,7 @@ class _test_opts(unittest.TestCase): ...@@ -42,7 +42,7 @@ class _test_opts(unittest.TestCase):
# x, y, z = inputs() # x, y, z = inputs()
# a, b, c, d = more_inputs() # a, b, c, d = more_inputs()
# # e = (2.0 * x) / (2.0 * y) # # e = (2.0 * x) / (2.0 * y)
# # e = (2.0 * x) / (4.0 * y) # e = (2.0 * x) / (4.0 * y)
# # e = x / (y / z) # # e = x / (y / z)
# # e = (x * y) / x # # e = (x * y) / x
# # e = (x / y) * (y / z) * (z / x) # # e = (x / y) * (y / z) * (z / x)
...@@ -71,11 +71,12 @@ class _test_opts(unittest.TestCase): ...@@ -71,11 +71,12 @@ class _test_opts(unittest.TestCase):
# # e = (a - b) + (b - c) + (c - d) # # e = (a - b) + (b - c) + (c - d)
# # e = x + -y # # e = x + -y
# # e = a - b - b + a + b + c + b - c # # e = a - b - b + a + b + c + b - c
# e = x + log(y) - x + y # # e = x + log(y) - x + y
# e = 2.0 + x + 4.0
# g = Env([x, y, z, a, b, c, d], [e]) # g = Env([x, y, z, a, b, c, d], [e])
# print g # print g
# gof.ConstantFinder().optimize(g) # gof.ConstantFinder().optimize(g)
# addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs) # addfn = lambda *inputs: sum(inputs)
# subfn = lambda x, y: x - y # subfn = lambda x, y: x - y
# negfn = lambda x: -x # negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g) # Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
......
...@@ -58,7 +58,7 @@ class BaseTensor(Result): ...@@ -58,7 +58,7 @@ class BaseTensor(Result):
# filter # filter
# #
def filter(self, arr): def filter(self, arr):
"""cast to an L{numpy.ndarray} and ensure arr has correct rank, shape""" """Cast to an L{numpy.ndarray} and ensure arr has correct rank and shape."""
if not (isinstance(arr, numpy.ndarray) \ if not (isinstance(arr, numpy.ndarray) \
and arr.dtype==self.dtype): and arr.dtype==self.dtype):
arr = numpy.asarray(arr, dtype = self.dtype) arr = numpy.asarray(arr, dtype = self.dtype)
...@@ -102,6 +102,9 @@ class BaseTensor(Result): ...@@ -102,6 +102,9 @@ class BaseTensor(Result):
# Description for constant folding # Description for constant folding
# #
def desc(self): def desc(self):
"""
Returns a hashable description of this BaseTensor.
"""
if self.data is not None: if self.data is not None:
return (BaseTensor, self.dtype, self.broadcastable, self.data.data[:]) return (BaseTensor, self.dtype, self.broadcastable, self.data.data[:])
else: else:
...@@ -210,6 +213,7 @@ class BaseTensor(Result): ...@@ -210,6 +213,7 @@ class BaseTensor(Result):
}; };
""" """
return template % dict(nbits = 64, half_nbits = 32) + template % dict(nbits = 128, half_nbits = 64) return template % dict(nbits = 64, half_nbits = 32) + template % dict(nbits = 128, half_nbits = 64)
# todo: use C templating
############################ ############################
......
差异被折叠。
...@@ -23,6 +23,39 @@ logpow = Pattern((Log, (Pow, 'x', 'y')), ...@@ -23,6 +23,39 @@ logpow = Pattern((Log, (Pow, 'x', 'y')),
class Canonizer(gof.Optimizer): class Canonizer(gof.Optimizer):
"""
Simplification tool.
Usage: Canonizer(main, inverse, reciprocal, mainfn, invfn, recfn, transform)
* main: a suitable Op class that is commutative, associative and takes
one to an arbitrary number of inputs, e.g. Add or Mul
* inverse: an Op class such that inverse(main(x, y), y) == x
e.g. Sub or Div
* reciprocal: an Op class such that main(x, reciprocal(y)) == inverse(x, y)
e.g. Neg or Inv
* mainfn, invfn, recfn: functions that behave just like the previous three
Ops, but on true scalars (e.g. their impl)
* transform: a function that maps (env, numerator, denominator) where numerator
and denominator are lists of Result instances, to new lists
where further simplifications may have been applied.
Examples:
add_canonizer = Canonizer(Add, Sub, Neg, lambda *inputs: sum(inputs), ...)
mul_canonizer = Canonizer(Mul, Div, Inv, lambda *inputs: product(inputs), ...)
Examples of optimizations mul_canonizer can perform:
x / x -> 1
(x * y) / x -> y
x / y / x -> 1 / y
x / y / z -> x / (y * z)
x / (y / z) -> (x * z) / y
(a / b) * (b / c) * (c / d) -> a / d
(2.0 * x) / (4.0 * y) -> (0.5 * x) / y
2 * x / 2 -> x
"""
def __init__(self, main, inverse, reciprocal, mainfn, invfn, recfn, transform = None): def __init__(self, main, inverse, reciprocal, mainfn, invfn, recfn, transform = None):
self.main = main self.main = main
...@@ -37,10 +70,15 @@ class Canonizer(gof.Optimizer): ...@@ -37,10 +70,15 @@ class Canonizer(gof.Optimizer):
def apply(self, env): def apply(self, env):
def canonize(r): def canonize(r):
if r in env.inputs or r in env.orphans(): if r in env.inputs or r in env.orphans():
return return
def flatten(r, nclients_check = True): def flatten(r, nclients_check = True):
# Collapses a tree of main/inverse/reciprocal Ops (aka Mul/Div/Inv or Add/Sub/Neg)
# into a list of numerators and a list of denominators
# e.g. (x*(1/y))*(x/(z/a)) aka Mul(Mul(x, (Inv, y)), Div(x, Div(z, a))) -> [x, x, a], [z, y]
op = r.owner op = r.owner
if op is None or r in env.inputs or r in env.orphans(): if op is None or r in env.inputs or r in env.orphans():
return [r], [] return [r], []
...@@ -50,9 +88,11 @@ class Canonizer(gof.Optimizer): ...@@ -50,9 +88,11 @@ class Canonizer(gof.Optimizer):
nums = [x[0] for x in results] nums = [x[0] for x in results]
denums = [x[1] for x in results] denums = [x[1] for x in results]
elif isinstance(op, self.inverse) and (not nclients_check or env.nclients(r) == 1): elif isinstance(op, self.inverse) and (not nclients_check or env.nclients(r) == 1):
# num, denum of the second argument are added to the denum, num respectively
nums = [results[0][0], results[1][1]] nums = [results[0][0], results[1][1]]
denums = [results[0][1], results[1][0]] denums = [results[0][1], results[1][0]]
elif isinstance(op, self.reciprocal) and (not nclients_check or env.nclients(r) == 1): elif isinstance(op, self.reciprocal) and (not nclients_check or env.nclients(r) == 1):
# num, denum of the sole argument are added to the denum, num respectively
nums = [results[0][1]] nums = [results[0][1]]
denums = [results[0][0]] denums = [results[0][0]]
else: else:
...@@ -70,22 +110,29 @@ class Canonizer(gof.Optimizer): ...@@ -70,22 +110,29 @@ class Canonizer(gof.Optimizer):
canonize(input) canonize(input)
return return
# Terms that are both in the num and denum lists cancel each other
for d in list(denum): for d in list(denum):
if d in list(num): if d in list(num):
# list.remove only removes the element once
num.remove(d) num.remove(d)
denum.remove(d) denum.remove(d)
# We identify the constants in num and denum
numct, num = utils.partition(lambda factor: getattr(factor, 'constant', False) and factor.data is not None, num) numct, num = utils.partition(lambda factor: getattr(factor, 'constant', False) and factor.data is not None, num)
denumct, denum = utils.partition(lambda factor: getattr(factor, 'constant', False) and factor.data is not None, denum) denumct, denum = utils.partition(lambda factor: getattr(factor, 'constant', False) and factor.data is not None, denum)
# All constants in num and denum are combined into a single constant which we add to num (unless it's a neutral constant)
v = self.invfn(self.mainfn(*[x.data for x in numct]), self.mainfn(*[x.data for x in denumct])) v = self.invfn(self.mainfn(*[x.data for x in numct]), self.mainfn(*[x.data for x in denumct]))
if v != self.neutral: if v != self.neutral:
num.insert(0, C(v)) num.insert(0, C(v))
# We optimize the num and denum lists further if requested
if self.transform is not None: if self.transform is not None:
num, denum = self.transform(env, num, denum) num, denum = self.transform(env, num, denum)
def make(factors): def make(factors):
# Combines the factors using self.main (aka Mul) depending
# on the number of elements.
n = len(factors) n = len(factors)
if n == 0: if n == 0:
return None return None
...@@ -98,10 +145,13 @@ class Canonizer(gof.Optimizer): ...@@ -98,10 +145,13 @@ class Canonizer(gof.Optimizer):
if numr is None: if numr is None:
if denumr is None: if denumr is None:
# Everything cancelled each other so we're left with
# the neutral element.
new_r = Scalar(dtype = r.dtype) new_r = Scalar(dtype = r.dtype)
new_r.constant = True new_r.constant = True
new_r.data = self.neutral new_r.data = self.neutral
else: else:
# There's no numerator so we use reciprocal
new_r = self.reciprocal(denumr).out new_r = self.reciprocal(denumr).out
else: else:
if denumr is None: if denumr is None:
...@@ -109,6 +159,7 @@ class Canonizer(gof.Optimizer): ...@@ -109,6 +159,7 @@ class Canonizer(gof.Optimizer):
else: else:
new_r = self.inverse(numr, denumr).out new_r = self.inverse(numr, denumr).out
# Hopefully this won't complain!
env.replace(r, new_r) env.replace(r, new_r)
for factor in num + denum: for factor in num + denum:
...@@ -119,11 +170,28 @@ class Canonizer(gof.Optimizer): ...@@ -119,11 +170,28 @@ class Canonizer(gof.Optimizer):
def group_powers(env, num, denum): def group_powers(env, num, denum):
"""
Plugin for Canonizer: use as Canonizer(..., transform = group_powers)
Takes num, denum such that mul(*num) / mul(*denum) is in env
and searches for instances of exp(x) or x**y in order to group
together powers of the same variable. Returns num2, denum2 in
which the grouping has been done.
Note: this function does not modify env.
Examples:
group_powers(env, [x, exp(x), exp(y)], [exp(z)]) -> [x, exp(x+y-z)], []
"""
# maps a base to the list of powers it is raised to in the
# numerator/denominator lists.
num_powers = {} num_powers = {}
denum_powers = {} denum_powers = {}
def populate(d, seq): def populate(d, seq):
# For each instance of exp or pow in seq, removes it from seq
# and does d[base].append(power).
for factor in list(seq): for factor in list(seq):
op = factor.owner op = factor.owner
if op is None or factor in env.inputs or factor in env.orphans(): if op is None or factor in env.inputs or factor in env.orphans():
...@@ -139,6 +207,8 @@ def group_powers(env, num, denum): ...@@ -139,6 +207,8 @@ def group_powers(env, num, denum):
populate(denum_powers, denum) populate(denum_powers, denum)
for x in set(num_powers.keys() + denum_powers.keys()): for x in set(num_powers.keys() + denum_powers.keys()):
# we append base ** (num_powers[base] - denum_powers[base])
# to the num list
try: num_ys = num_powers.pop(x) try: num_ys = num_powers.pop(x)
except KeyError: num_ys = [] except KeyError: num_ys = []
...@@ -148,6 +218,7 @@ def group_powers(env, num, denum): ...@@ -148,6 +218,7 @@ def group_powers(env, num, denum):
num_r = num_ys and add(*num_ys) or C(0) num_r = num_ys and add(*num_ys) or C(0)
denum_r = denum_ys and add(*denum_ys) or C(0) denum_r = denum_ys and add(*denum_ys) or C(0)
if x == 'e': if x == 'e':
num.append(exp(num_r - denum_r)) num.append(exp(num_r - denum_r))
else: else:
......
...@@ -80,17 +80,14 @@ def astensor(data, broadcastable=None, name=None): ...@@ -80,17 +80,14 @@ def astensor(data, broadcastable=None, name=None):
if isinstance(data, BaseTensor): if isinstance(data, BaseTensor):
if broadcastable is not None and list(data.broadcastable) != list(broadcastable): if broadcastable is not None and list(data.broadcastable) != list(broadcastable):
raise TypeError("The data to wrap as a Tensor has the wrong broadcastable pattern. Expected %s, got %s." % (broadcastable, data.broadcastable)) raise TypeError("The data to wrap as a Tensor has the wrong broadcastable pattern. Expected %s, got %s." % (broadcastable, data.broadcastable))
if isinstance(data, Tensor) and (name is None or name == data.name): if name is not None and name != data.name:
raise ValueError("Cannot rename an existing Tensor.")
return data return data
else:
t = Tensor(data.dtype, data.broadcastable, name = name)
t.data = data
return t
elif isinstance(data, Result): elif isinstance(data, Result):
data = data.data raise TypeError("Cannot make a Tensor out of a non-Tensor result.")
if data is None and broadcastable is None: if data is None and broadcastable is None:
raise TypeError("Cannot make a Tensor out of None or a Result with no data.") raise TypeError("Cannot make a Tensor out of None.")
data = numpy.asarray(data) data = numpy.asarray(data)
if broadcastable is None: if broadcastable is None:
...@@ -107,38 +104,6 @@ s2t.astensor = astensor ...@@ -107,38 +104,6 @@ s2t.astensor = astensor
# Supporting Ops # Supporting Ops
############################ ############################
def _scalar_switch(normal_f, scalar_f, scalar_f_reverse = None):
"""a decorator for operators before broadcasting works properly"""
def f(x, y):
def as_tensor(obj):
if isinstance(obj, Tensor):
return obj
else:
return astensor(obj)
x, y = as_tensor(x), as_tensor(y)
if 0 not in y.broadcastable:
return scalar_f(x, y)
if 0 not in x.broadcastable:
if scalar_f_reverse:
return scalar_f_reverse(y, x)
else:
raise TypeError("You cannot do this operation on a scalar.")
return normal_f(x, y)
return f
def _assert_same_shapes(x, *rest):
"""Ensure that all inputs to the function impl have the same size (foils numpy's broadcasting)"""
shape = x.shape
for other in rest:
if other.shape != shape:
raise ValueError(_assert_same_shapes.E_shape, shape, other.shape)
_assert_same_shapes.E_shape = "The dimensions of the inputs do not match."
def _assert_tensor_scalar(x, a):
"""ensure that the second input is a scalar"""
if numpy.product(a.shape) != 1:
raise ValueError("The second argument must be a scalar.")
# this has a different name, because _as_tensor is the function which ops use # this has a different name, because _as_tensor is the function which ops use
# to upcast their arguments... this internal-use function is a good place to put debugging stuff, better than the global astensor. # to upcast their arguments... this internal-use function is a good place to put debugging stuff, better than the global astensor.
_as_tensor = astensor _as_tensor = astensor
...@@ -450,8 +415,6 @@ class Gemm(_Op): ...@@ -450,8 +415,6 @@ class Gemm(_Op):
return ['<iostream>'] return ['<iostream>']
def c_libraries(self): def c_libraries(self):
return blas.ldflags() return blas.ldflags()
#def c_var_names(self):
# return [['_z', '_a', '_x', '_y', '_b'], ['_zout']]
def c_validate_update(self, *args): def c_validate_update(self, *args):
return "" return ""
def c_validate_update_cleanup(self, *args): def c_validate_update_cleanup(self, *args):
...@@ -612,125 +575,3 @@ class Gemm(_Op): ...@@ -612,125 +575,3 @@ class Gemm(_Op):
""" % dict(locals(), **sub) """ % dict(locals(), **sub)
gemm = gof.op.constructor(Gemm) gemm = gof.op.constructor(Gemm)
if 0:
##########################
# Comparisons
##########################
# Less-than
class lt_elemwise(_Elemwise):
def __init__(self, *args):
raise NotImplementedError()
class lt_scalar_r(_Elemwise):
def __init__(self, *args):
raise NotImplementedError()
# Less-than or equal
class le_elemwise(_Elemwise):
def __init__(self, *args):
raise NotImplementedError()
class le_scalar_r(_Elemwise):
def __init__(self, *args):
raise NotImplementedError()
# Greater-than
class gt_elemwise(_Elemwise):
def __init__(self, *args):
raise NotImplementedError()
class gt_scalar_r(_Elemwise):
def __init__(self, *args):
raise NotImplementedError()
# Greater-than or equal
class ge_elemwise(_Elemwise):
def __init__(self, *args):
raise NotImplementedError()
class ge_scalar_r(_Elemwise):
def __init__(self, *args):
raise NotImplementedError()
if 0:
def _broadcastable_pattern(pattern):
def factory(data = None, name = None, dtype=None):
if data:
assert len(data.shape) == len(pattern)
if dtype is not None:
assert dtype is data.dtype
dtype = data.dtype
rval = Tensor(dtype, pattern, name)
rval.data = data
else:
rval = Tensor(dtype, pattern, name)
return rval
return factory
row = _broadcastable_pattern([1, 0])
col = _broadcastable_pattern([0, 1])
matrix = _broadcastable_pattern([0, 0])
if 0: #old __init__ code
"""Create a Tensor
If data is given:
- constant defaults to True
- if dtype is given, it must match data.dtype
- otherwise: default is data.dtype
- if broadcastable is given, len(broadcastable) must match len(data.shape)
- otherwise: if it is constant, it defaults to 1 where shape[i]==1
- if it is not constant, it defaults to 0s
If data is not given:
- constant defaults to False
"""
if dtype is None or broadcastable is None:
if data is None:
raise TypeError("Provide non-None data to complete the dtype and broadcastable flags.")
data = numpy.asarray(data)
if constant is None:
constant = True
dtype = data.dtype
if constant:
broadcastable = [1*(x == 1) for x in data.shape]
else:
broadcastable = [0] * len(data.shape)
if 0:
def tensor__new__(cls, *args, **kwargs):
"""__new__ is overloaded to handle the special form Tensor(x) when x is
a Tensor or an Op whose default output is a Tensor. In these cases, the
argument x is returned, and a new Tensor is not created.
"""
if len(args) == 1:
a = args[0]
t = super(Tensor, cls).__new__(cls, *args, **kwargs)
t.__init__(*args, **kwargs)
return t
# def upcast(dtype, *dtypes):
# z = numpy.zeros((), dtype = dtype)
# for dtype in dtypes:
# z = z + numpy.zeros((), dtype = dtype)
# return str(z.dtype)
# for dtype in i_dtypes:
# if dtype is None:
# raise TypeError("Expected a Tensor.")
# upcasted = upcast(*i_dtypes)
# return [upcasted] * self.nout
# # try:
# # dmap = self.destroy_map()
# # except AttributeError:
# # dmap = {}
# # rval = []
# # for i in xrange(self.nout):
# # if i in dmap:
# # destroyed = dmap[output]
# # if len(destroyed) != 1:
# # raise TypeError("Cannot infer dtype of output %s because it destroys more than one input." % output)
# # rval.append(destroyed[0])
# # else:
# # rval.append(upcasted)
# # return rval
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论