提交 bf3df169 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixes here and there, doc

上级 7158e6d8
...@@ -42,7 +42,7 @@ class _test_opts(unittest.TestCase): ...@@ -42,7 +42,7 @@ class _test_opts(unittest.TestCase):
# x, y, z = inputs() # x, y, z = inputs()
# a, b, c, d = more_inputs() # a, b, c, d = more_inputs()
# # e = (2.0 * x) / (2.0 * y) # # e = (2.0 * x) / (2.0 * y)
# # e = (2.0 * x) / (4.0 * y) # e = (2.0 * x) / (4.0 * y)
# # e = x / (y / z) # # e = x / (y / z)
# # e = (x * y) / x # # e = (x * y) / x
# # e = (x / y) * (y / z) * (z / x) # # e = (x / y) * (y / z) * (z / x)
...@@ -71,11 +71,12 @@ class _test_opts(unittest.TestCase): ...@@ -71,11 +71,12 @@ class _test_opts(unittest.TestCase):
# # e = (a - b) + (b - c) + (c - d) # # e = (a - b) + (b - c) + (c - d)
# # e = x + -y # # e = x + -y
# # e = a - b - b + a + b + c + b - c # # e = a - b - b + a + b + c + b - c
# e = x + log(y) - x + y # # e = x + log(y) - x + y
# e = 2.0 + x + 4.0
# g = Env([x, y, z, a, b, c, d], [e]) # g = Env([x, y, z, a, b, c, d], [e])
# print g # print g
# gof.ConstantFinder().optimize(g) # gof.ConstantFinder().optimize(g)
# addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs) # addfn = lambda *inputs: sum(inputs)
# subfn = lambda x, y: x - y # subfn = lambda x, y: x - y
# negfn = lambda x: -x # negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g) # Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
......
...@@ -58,7 +58,7 @@ class BaseTensor(Result): ...@@ -58,7 +58,7 @@ class BaseTensor(Result):
# filter # filter
# #
def filter(self, arr): def filter(self, arr):
"""cast to an L{numpy.ndarray} and ensure arr has correct rank, shape""" """Cast to an L{numpy.ndarray} and ensure arr has correct rank and shape."""
if not (isinstance(arr, numpy.ndarray) \ if not (isinstance(arr, numpy.ndarray) \
and arr.dtype==self.dtype): and arr.dtype==self.dtype):
arr = numpy.asarray(arr, dtype = self.dtype) arr = numpy.asarray(arr, dtype = self.dtype)
...@@ -102,6 +102,9 @@ class BaseTensor(Result): ...@@ -102,6 +102,9 @@ class BaseTensor(Result):
# Description for constant folding # Description for constant folding
# #
def desc(self): def desc(self):
"""
Returns a hashable description of this BaseTensor.
"""
if self.data is not None: if self.data is not None:
return (BaseTensor, self.dtype, self.broadcastable, self.data.data[:]) return (BaseTensor, self.dtype, self.broadcastable, self.data.data[:])
else: else:
...@@ -210,6 +213,7 @@ class BaseTensor(Result): ...@@ -210,6 +213,7 @@ class BaseTensor(Result):
}; };
""" """
return template % dict(nbits = 64, half_nbits = 32) + template % dict(nbits = 128, half_nbits = 64) return template % dict(nbits = 64, half_nbits = 32) + template % dict(nbits = 128, half_nbits = 64)
# todo: use C templating
############################ ############################
......
...@@ -7,6 +7,19 @@ import scalar ...@@ -7,6 +7,19 @@ import scalar
class InplaceOptimizer(opt.OpSpecificOptimizer): class InplaceOptimizer(opt.OpSpecificOptimizer):
"""
Usage: inplace_optimizer.optimize(env)
Attempts to replace all Broadcast ops by versions of them
that operate inplace. It operates greedily: for each Broadcast
Op that is encountered, for each output, tries each input to
see if it can operate inplace on that input. If so, makes the
change and go to the next output or Broadcast Op.
Examples:
x + y + z -> x += y += z
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
"""
opclass = Broadcast opclass = Broadcast
...@@ -24,6 +37,7 @@ class InplaceOptimizer(opt.OpSpecificOptimizer): ...@@ -24,6 +37,7 @@ class InplaceOptimizer(opt.OpSpecificOptimizer):
continue continue
candidate_inputs.remove(candidate_input) candidate_inputs.remove(candidate_input)
op = new_op op = new_op
baseline = inplace_pattern
break break
inplace_optimizer = InplaceOptimizer() inplace_optimizer = InplaceOptimizer()
...@@ -32,6 +46,8 @@ inplace_optimizer = InplaceOptimizer() ...@@ -32,6 +46,8 @@ inplace_optimizer = InplaceOptimizer()
class DimShuffleLifter(opt.Optimizer): class DimShuffleLifter(opt.Optimizer):
""" """
Usage: lift_dimshuffle.optimize(env)
"Lifts" DimShuffle through Broadcast operations and merges "Lifts" DimShuffle through Broadcast operations and merges
consecutive DimShuffles. Basically, applies the following consecutive DimShuffles. Basically, applies the following
transformations on the whole graph: transformations on the whole graph:
...@@ -46,9 +62,6 @@ class DimShuffleLifter(opt.Optimizer): ...@@ -46,9 +62,6 @@ class DimShuffleLifter(opt.Optimizer):
def apply(self, env): def apply(self, env):
seen = set() seen = set()
def merge(ord1, ord2):
return [x == 'x' and 'x' or ord1[x] for x in ord2]
def lift(r): def lift(r):
if r in seen: if r in seen:
...@@ -62,6 +75,7 @@ class DimShuffleLifter(opt.Optimizer): ...@@ -62,6 +75,7 @@ class DimShuffleLifter(opt.Optimizer):
if isinstance(op, DimShuffle): if isinstance(op, DimShuffle):
in_op = op.inputs[0].owner in_op = op.inputs[0].owner
if isinstance(in_op, DimShuffle): if isinstance(in_op, DimShuffle):
# DimShuffle(DimShuffle(x)) => DimShuffle(x)
new_order = [x == 'x' and 'x' or in_op.new_order[x] for x in op.new_order] new_order = [x == 'x' and 'x' or in_op.new_order[x] for x in op.new_order]
if new_order == range(len(new_order)): if new_order == range(len(new_order)):
repl = in_op.inputs[0] repl = in_op.inputs[0]
...@@ -71,6 +85,7 @@ class DimShuffleLifter(opt.Optimizer): ...@@ -71,6 +85,7 @@ class DimShuffleLifter(opt.Optimizer):
lift(repl) lift(repl)
return return
elif isinstance(in_op, Broadcast): elif isinstance(in_op, Broadcast):
# DimShuffle(Broadcast(x, y)) => Broadcast(DimShuffle(x), DimShuffle(y))
repl = Broadcast(in_op.scalar_opclass, repl = Broadcast(in_op.scalar_opclass,
[DimShuffle(input, op.new_order).out for input in in_op.inputs], [DimShuffle(input, op.new_order).out for input in in_op.inputs],
in_op.inplace_pattern).out in_op.inplace_pattern).out
...@@ -87,9 +102,24 @@ lift_dimshuffle = DimShuffleLifter() ...@@ -87,9 +102,24 @@ lift_dimshuffle = DimShuffleLifter()
def find_cliques(env, through_broadcast = False): def find_cliques(env, through_broadcast = False):
"""
Usage: find_cliques(env, through_broadcast = False)
Returns a list of pairs where each pair contains a list
of inputs and a list of outputs such that Env(inputs, outputs)
contains nothing but Broadcast Ops.
If through_broadcast is False, the cliques will only be
allowed to broadcast over the inputs, which means, for
example, that vector operations will not be mixed with
matrix operations.
"""
def seek_from(r): def seek_from(r):
# walks through the graph until it encounters a
# non-Broadcast operation or (if through_broadcast
# is False) a Result which needs to be broadcasted.
op = r.owner op = r.owner
if r in env.inputs \ if r in env.inputs \
or r in env.orphans() \ or r in env.orphans() \
...@@ -103,6 +133,10 @@ def find_cliques(env, through_broadcast = False): ...@@ -103,6 +133,10 @@ def find_cliques(env, through_broadcast = False):
ret = set() ret = set()
if not through_broadcast: if not through_broadcast:
# check each dimension over all the inputs - if the broadcastable
# fields are not all 0 or all 1 for a particular dimension, then
# broadcasting will be performed along it on the inputs where the
# value is 1 and we will stop.
if any(any(bc) and not all(bc) if any(any(bc) and not all(bc)
for bc in zip(*[input.broadcastable for input in op.inputs])): for bc in zip(*[input.broadcastable for input in op.inputs])):
ret.update(op.inputs) ret.update(op.inputs)
...@@ -111,6 +145,7 @@ def find_cliques(env, through_broadcast = False): ...@@ -111,6 +145,7 @@ def find_cliques(env, through_broadcast = False):
for input in op.inputs: for input in op.inputs:
res = seek_from(input) res = seek_from(input)
if res is None: if res is None:
# input is a leaf of our search
ret.add(input) ret.add(input)
else: else:
ret.update(res) ret.update(res)
...@@ -124,11 +159,14 @@ def find_cliques(env, through_broadcast = False): ...@@ -124,11 +159,14 @@ def find_cliques(env, through_broadcast = False):
return return
clique_inputs = seek_from(r) clique_inputs = seek_from(r)
if clique_inputs is None: if clique_inputs is None:
# Not in a clique, keep going
op = r.owner op = r.owner
if op is not None: if op is not None:
for input in op.inputs: for input in op.inputs:
find_cliques_helper(input) find_cliques_helper(input)
else: else:
# We found a clique, add it to the list and
# jump to the leaves.
cliques.append((clique_inputs, [r])) cliques.append((clique_inputs, [r]))
for input in clique_inputs: for input in clique_inputs:
find_cliques_helper(input) find_cliques_helper(input)
...@@ -142,6 +180,24 @@ def find_cliques(env, through_broadcast = False): ...@@ -142,6 +180,24 @@ def find_cliques(env, through_broadcast = False):
class CliqueOptimizer(opt.Optimizer): class CliqueOptimizer(opt.Optimizer):
"""
Usage: CliqueOptimizer(through_broadcast = False,
scalar_optimizer = None,
make_composite = False).optimize(env)
Finds cliques of Broadcast operations in the env and does either
or both of two things:
* Apply scalar_optimizer on the clique as if the clique was a
group of scalar operations. scalar_optimizer can be any optimization
which applies on scalars. If it is None, no optimization is done.
* Replace the clique with a single Op, optimized to perform the
computations properly. If make_composite is False, no such replacement
is done.
Note: it is recommended to run the lift_dimshuffle optimization before
this one.
"""
def __init__(self, through_broadcast = False, scalar_optimizer = None, make_composite = False): def __init__(self, through_broadcast = False, scalar_optimizer = None, make_composite = False):
self.through_broadcast = through_broadcast self.through_broadcast = through_broadcast
...@@ -152,20 +208,25 @@ class CliqueOptimizer(opt.Optimizer): ...@@ -152,20 +208,25 @@ class CliqueOptimizer(opt.Optimizer):
if self.scalar_optimizer is None and not self.make_composite: if self.scalar_optimizer is None and not self.make_composite:
# there's nothing to do with the cliques... # there's nothing to do with the cliques...
return return
cliques = find_cliques(env, self.through_broadcast) cliques = find_cliques(env, self.through_broadcast)
opt = self.scalar_optimizer opt = self.scalar_optimizer
def build_scalar_clique(r, env, equiv): def build_scalar_clique(r, env, equiv):
# Maps a clique of Broadcast Ops to a clique of Scalar Ops with the same
# structure and equivalent operations. equiv contains the mapping.
if r in equiv: if r in equiv:
return equiv[r] return equiv[r]
op = r.owner op = r.owner
if r in env.inputs or r in env.orphans(): if r in env.inputs or r in env.orphans():
# For each leave we make a Scalar of the corresponding dtype
s = scalar.Scalar(dtype = r.dtype) s = scalar.Scalar(dtype = r.dtype)
_r = r _r = r
if isinstance(r.owner, DimShuffle) and all(x == 'x' for x in r.owner.new_order): if isinstance(r.owner, DimShuffle) and all(x == 'x' for x in r.owner.new_order):
_r = r.owner.inputs[0] _r = r.owner.inputs[0]
if (getattr(r, 'constant', False) or getattr(_r, 'constant', False)) \ if (getattr(r, 'constant', False) or getattr(_r, 'constant', False)) \
and _r.broadcastable == (): and _r.broadcastable == ():
# If we have a constant tensor we map it to a constant scalar.
s.data = _r.data s.data = _r.data
s.constant = True s.constant = True
equiv[r] = s equiv[r] = s
...@@ -184,15 +245,18 @@ class CliqueOptimizer(opt.Optimizer): ...@@ -184,15 +245,18 @@ class CliqueOptimizer(opt.Optimizer):
s_g = Env([equiv[r] for r in g.inputs], s_g = Env([equiv[r] for r in g.inputs],
[equiv[r] for r in g.outputs]) [equiv[r] for r in g.outputs])
if opt is not None: if opt is not None:
equiv2 = dict() equiv2 = dict() # reverse mapping, from Scalar Op to Tensor Op
for k, v in equiv.items(): for k, v in equiv.items():
equiv2[v] = k equiv2[v] = k
def transform(op, equiv): def transform(op, equiv):
# We get a scalar op and we return an equivalent op on tensors.
return Broadcast(op.__class__, [equiv[input] for input in op.inputs]) return Broadcast(op.__class__, [equiv[input] for input in op.inputs])
s_g.add_feature(sync_to(env, equiv2, transform)) s_g.add_feature(sync_to(env, equiv2, transform)) # Any change to s_g will now be transferred to g
opt.optimize(s_g) opt.optimize(s_g)
if self.make_composite: if self.make_composite:
def follow_inplace(r): def follow_inplace(r):
# Tries to find the earliest r2 in g such that r destroys r2
# If no such r2 is found, returns None
op = r.owner op = r.owner
if op is None or r in g.inputs or r in g.orphans(): if op is None or r in g.inputs or r in g.orphans():
return None return None
...@@ -211,6 +275,8 @@ class CliqueOptimizer(opt.Optimizer): ...@@ -211,6 +275,8 @@ class CliqueOptimizer(opt.Optimizer):
for i, output in enumerate(g.outputs): for i, output in enumerate(g.outputs):
destroyed = follow_inplace(output) destroyed = follow_inplace(output)
if destroyed is not None and destroyed in g.inputs: if destroyed is not None and destroyed in g.inputs:
# we transfer the inplace operation only if it is
# an input that is destroyed
inplace_pattern[i] = g.inputs.index(destroyed) inplace_pattern[i] = g.inputs.index(destroyed)
C = scalar.composite(s_g.inputs, s_g.outputs) C = scalar.composite(s_g.inputs, s_g.outputs)
ec = Broadcast(C, g.inputs, inplace_pattern = inplace_pattern) ec = Broadcast(C, g.inputs, inplace_pattern = inplace_pattern)
...@@ -218,6 +284,17 @@ class CliqueOptimizer(opt.Optimizer): ...@@ -218,6 +284,17 @@ class CliqueOptimizer(opt.Optimizer):
def sync_to(target, equiv, transform): def sync_to(target, equiv, transform):
"""
Usage: sync_to(target, equiv, transform)
* target: an Env
* equiv: a dictionary that maps results and ops to results and ops
in target
* transform: a function that takes (op, equiv) as inputs and
returns a new op.
Returns a Feature that can be added to an Env and mirrors all
modifications to that env with modifications to the target env.
"""
class Synchronize(gof.Listener, gof.Constraint): class Synchronize(gof.Listener, gof.Constraint):
...@@ -259,44 +336,3 @@ def sync_to(target, equiv, transform): ...@@ -259,44 +336,3 @@ def sync_to(target, equiv, transform):
return Synchronize return Synchronize
"""
This variable is used in compile.prog as the optimizer for all programs built
using either compile.single, compile.to_func, and compile.prog.
Old code::
if 0:
def optimizer(lst):
begin = gof.SeqOptimizer([])
end = gof.SeqOptimizer([gof.DummyRemover])
seq_opt = gof.SeqOptimizer(begin + lst + end)
return gof.PythonOpt(gof.MergeOptMerge(seq_opt))
if 0:
optimizer_begin = gof.SeqOptimizer([opt for name, opt in [
['double_transpose_eliminator', pattern_opt((transpose, (transpose, 'x')), 'x')],
['addxx_to_twice', pattern_opt((add_elemwise, 'x', 'x'), (twice, 'x'))],
['twice_to_itwice', op_sub(twice, itwice)],
['mulxx_to_sqr', pattern_opt((mul_elemwise, 'x', 'x'), (sqr, 'x'))],
['sqr_to_isqr', op_sub(sqr, isqr)],
['add_to_iadd', op_sub(add_elemwise, iadd_elemwise)],
['add_to_iadd_reverse', pattern_opt((add_elemwise, 'x', 'y'),
(iadd_elemwise, 'y', 'x'))]]])
# ['remove_copies', gof.OpRemover(array_copy)],
# [None, gof.DummyRemover] # has to be at the end
"""
...@@ -23,6 +23,39 @@ logpow = Pattern((Log, (Pow, 'x', 'y')), ...@@ -23,6 +23,39 @@ logpow = Pattern((Log, (Pow, 'x', 'y')),
class Canonizer(gof.Optimizer): class Canonizer(gof.Optimizer):
"""
Simplification tool.
Usage: Canonizer(main, inverse, reciprocal, mainfn, invfn, recfn, transform)
* main: a suitable Op class that is commutative, associative and takes
one to an arbitrary number of inputs, e.g. Add or Mul
* inverse: an Op class such that inverse(main(x, y), y) == x
e.g. Sub or Div
* reciprocal: a function such that main(x, reciprocal(y)) == inverse(x, y)
e.g. Neg or Inv
* mainfn, invfn, recfn: functions that behave just like the previous three
Ops, but on true scalars (e.g. their impl)
* transform: a function that maps (numerator, denominatur) where numerator
and denominator are lists of Result instances, to new lists
where further simplifications may have been applied.
Examples:
add_canonizer = Canonizer(Add, Sub, Neg, lambda *inputs: sum(inputs), ...)
mul_canonizer = Canonizer(Mul, Div, Inv, lambda *inputs: product(inputs), ...)
Examples of optimizations mul_canonizer can perform:
x / x -> 1
(x * y) / x -> y
x / y / x -> 1 / y
x / y / z -> x / (y * z)
x / (y / z) -> (x * z) / y
(a / b) * (b / c) * (c / d) -> a / d
(2.0 * x) / (4.0 * y) -> (0.5 * x) / y
2 * x / 2 -> x
"""
def __init__(self, main, inverse, reciprocal, mainfn, invfn, recfn, transform = None): def __init__(self, main, inverse, reciprocal, mainfn, invfn, recfn, transform = None):
self.main = main self.main = main
...@@ -37,10 +70,15 @@ class Canonizer(gof.Optimizer): ...@@ -37,10 +70,15 @@ class Canonizer(gof.Optimizer):
def apply(self, env): def apply(self, env):
def canonize(r): def canonize(r):
if r in env.inputs or r in env.orphans(): if r in env.inputs or r in env.orphans():
return return
def flatten(r, nclients_check = True): def flatten(r, nclients_check = True):
# Collapses a tree of main/inverse/reciprocal Ops (aka Mul/Div/Inv or Add/Sub/Neg)
# into a list of numerators and a list of denominators
# e.g. (x*(1/y))*(x/(z/a)) aka Mul(Mul(x, (Inv, y)), Div(x, Div(z, a))) -> [x, x, a], [z, y]
op = r.owner op = r.owner
if op is None or r in env.inputs or r in env.orphans(): if op is None or r in env.inputs or r in env.orphans():
return [r], [] return [r], []
...@@ -50,9 +88,11 @@ class Canonizer(gof.Optimizer): ...@@ -50,9 +88,11 @@ class Canonizer(gof.Optimizer):
nums = [x[0] for x in results] nums = [x[0] for x in results]
denums = [x[1] for x in results] denums = [x[1] for x in results]
elif isinstance(op, self.inverse) and (not nclients_check or env.nclients(r) == 1): elif isinstance(op, self.inverse) and (not nclients_check or env.nclients(r) == 1):
# num, denum of the second argument are added to the denum, num respectively
nums = [results[0][0], results[1][1]] nums = [results[0][0], results[1][1]]
denums = [results[0][1], results[1][0]] denums = [results[0][1], results[1][0]]
elif isinstance(op, self.reciprocal) and (not nclients_check or env.nclients(r) == 1): elif isinstance(op, self.reciprocal) and (not nclients_check or env.nclients(r) == 1):
# num, denum of the sole argument are added to the denum, num respectively
nums = [results[0][1]] nums = [results[0][1]]
denums = [results[0][0]] denums = [results[0][0]]
else: else:
...@@ -69,23 +109,30 @@ class Canonizer(gof.Optimizer): ...@@ -69,23 +109,30 @@ class Canonizer(gof.Optimizer):
for input in r.owner.inputs: for input in r.owner.inputs:
canonize(input) canonize(input)
return return
# Terms that are both in the num and denum lists cancel each other
for d in list(denum): for d in list(denum):
if d in list(num): if d in list(num):
# list.remove only removes the element once
num.remove(d) num.remove(d)
denum.remove(d) denum.remove(d)
# We identify the constants in num and denum
numct, num = utils.partition(lambda factor: getattr(factor, 'constant', False) and factor.data is not None, num) numct, num = utils.partition(lambda factor: getattr(factor, 'constant', False) and factor.data is not None, num)
denumct, denum = utils.partition(lambda factor: getattr(factor, 'constant', False) and factor.data is not None, denum) denumct, denum = utils.partition(lambda factor: getattr(factor, 'constant', False) and factor.data is not None, denum)
# All constants in num and denum are combined into a single constant which we add to num (unless it's a neutral constant)
v = self.invfn(self.mainfn(*[x.data for x in numct]), self.mainfn(*[x.data for x in denumct])) v = self.invfn(self.mainfn(*[x.data for x in numct]), self.mainfn(*[x.data for x in denumct]))
if v != self.neutral: if v != self.neutral:
num.insert(0, C(v)) num.insert(0, C(v))
# We optimize the num and denum lists further if requested
if self.transform is not None: if self.transform is not None:
num, denum = self.transform(env, num, denum) num, denum = self.transform(env, num, denum)
def make(factors): def make(factors):
# Combines the factors using self.main (aka Mul) depending
# on the number of elements.
n = len(factors) n = len(factors)
if n == 0: if n == 0:
return None return None
...@@ -98,10 +145,13 @@ class Canonizer(gof.Optimizer): ...@@ -98,10 +145,13 @@ class Canonizer(gof.Optimizer):
if numr is None: if numr is None:
if denumr is None: if denumr is None:
# Everything cancelled each other so we're left with
# the neutral element.
new_r = Scalar(dtype = r.dtype) new_r = Scalar(dtype = r.dtype)
new_r.constant = True new_r.constant = True
new_r.data = self.neutral new_r.data = self.neutral
else: else:
# There's no numerator so we use reciprocal
new_r = self.reciprocal(denumr).out new_r = self.reciprocal(denumr).out
else: else:
if denumr is None: if denumr is None:
...@@ -109,6 +159,7 @@ class Canonizer(gof.Optimizer): ...@@ -109,6 +159,7 @@ class Canonizer(gof.Optimizer):
else: else:
new_r = self.inverse(numr, denumr).out new_r = self.inverse(numr, denumr).out
# Hopefully this won't complain!
env.replace(r, new_r) env.replace(r, new_r)
for factor in num + denum: for factor in num + denum:
...@@ -119,11 +170,28 @@ class Canonizer(gof.Optimizer): ...@@ -119,11 +170,28 @@ class Canonizer(gof.Optimizer):
def group_powers(env, num, denum): def group_powers(env, num, denum):
"""
Plugin for Canonizer: use as Canonizer(..., transform = group_powers)
Takes num, denum such that mul(*num) / mul(*denum) is in env
and searches for instances of exp(x) or x**y in order to group
together powers of the same variable. Returns num2, denum2 in
which the grouping has been done.
Note: this function does not modify env.
Examples:
group_powers([x, exp(x), exp(y)], [exp(z)]) -> [x, exp(x+y-z)], []
"""
# maps a base to the list of powers it is raised to in the
# numerator/denominator lists.
num_powers = {} num_powers = {}
denum_powers = {} denum_powers = {}
def populate(d, seq): def populate(d, seq):
# For each instance of exp or pow in seq, removes it from seq
# and does d[base].append(power).
for factor in list(seq): for factor in list(seq):
op = factor.owner op = factor.owner
if op is None or factor in env.inputs or factor in env.orphans(): if op is None or factor in env.inputs or factor in env.orphans():
...@@ -139,6 +207,8 @@ def group_powers(env, num, denum): ...@@ -139,6 +207,8 @@ def group_powers(env, num, denum):
populate(denum_powers, denum) populate(denum_powers, denum)
for x in set(num_powers.keys() + denum_powers.keys()): for x in set(num_powers.keys() + denum_powers.keys()):
# we append base ** (num_powers[base] - denum_powers[base])
# to the num list
try: num_ys = num_powers.pop(x) try: num_ys = num_powers.pop(x)
except KeyError: num_ys = [] except KeyError: num_ys = []
...@@ -148,6 +218,7 @@ def group_powers(env, num, denum): ...@@ -148,6 +218,7 @@ def group_powers(env, num, denum):
num_r = num_ys and add(*num_ys) or C(0) num_r = num_ys and add(*num_ys) or C(0)
denum_r = denum_ys and add(*denum_ys) or C(0) denum_r = denum_ys and add(*denum_ys) or C(0)
if x == 'e': if x == 'e':
num.append(exp(num_r - denum_r)) num.append(exp(num_r - denum_r))
else: else:
......
...@@ -80,17 +80,14 @@ def astensor(data, broadcastable=None, name=None): ...@@ -80,17 +80,14 @@ def astensor(data, broadcastable=None, name=None):
if isinstance(data, BaseTensor): if isinstance(data, BaseTensor):
if broadcastable is not None and list(data.broadcastable) != list(broadcastable): if broadcastable is not None and list(data.broadcastable) != list(broadcastable):
raise TypeError("The data to wrap as a Tensor has the wrong broadcastable pattern. Expected %s, got %s." % (broadcastable, data.broadcastable)) raise TypeError("The data to wrap as a Tensor has the wrong broadcastable pattern. Expected %s, got %s." % (broadcastable, data.broadcastable))
if isinstance(data, Tensor) and (name is None or name == data.name): if name is not None and name != data.name:
return data raise ValueError("Cannot rename an existing Tensor.")
else: return data
t = Tensor(data.dtype, data.broadcastable, name = name)
t.data = data
return t
elif isinstance(data, Result): elif isinstance(data, Result):
data = data.data raise TypeError("Cannot make a Tensor out of a non-Tensor result.")
if data is None and broadcastable is None: if data is None and broadcastable is None:
raise TypeError("Cannot make a Tensor out of None or a Result with no data.") raise TypeError("Cannot make a Tensor out of None.")
data = numpy.asarray(data) data = numpy.asarray(data)
if broadcastable is None: if broadcastable is None:
...@@ -107,38 +104,6 @@ s2t.astensor = astensor ...@@ -107,38 +104,6 @@ s2t.astensor = astensor
# Supporting Ops # Supporting Ops
############################ ############################
def _scalar_switch(normal_f, scalar_f, scalar_f_reverse = None):
"""a decorator for operators before broadcasting works properly"""
def f(x, y):
def as_tensor(obj):
if isinstance(obj, Tensor):
return obj
else:
return astensor(obj)
x, y = as_tensor(x), as_tensor(y)
if 0 not in y.broadcastable:
return scalar_f(x, y)
if 0 not in x.broadcastable:
if scalar_f_reverse:
return scalar_f_reverse(y, x)
else:
raise TypeError("You cannot do this operation on a scalar.")
return normal_f(x, y)
return f
def _assert_same_shapes(x, *rest):
"""Ensure that all inputs to the function impl have the same size (foils numpy's broadcasting)"""
shape = x.shape
for other in rest:
if other.shape != shape:
raise ValueError(_assert_same_shapes.E_shape, shape, other.shape)
_assert_same_shapes.E_shape = "The dimensions of the inputs do not match."
def _assert_tensor_scalar(x, a):
"""ensure that the second input is a scalar"""
if numpy.product(a.shape) != 1:
raise ValueError("The second argument must be a scalar.")
# this has a different name, because _as_tensor is the function which ops use # this has a different name, because _as_tensor is the function which ops use
# to upcast their arguments... this internal-use function is a good place to put debugging stuff, better than the global astensor. # to upcast their arguments... this internal-use function is a good place to put debugging stuff, better than the global astensor.
_as_tensor = astensor _as_tensor = astensor
...@@ -450,8 +415,6 @@ class Gemm(_Op): ...@@ -450,8 +415,6 @@ class Gemm(_Op):
return ['<iostream>'] return ['<iostream>']
def c_libraries(self): def c_libraries(self):
return blas.ldflags() return blas.ldflags()
#def c_var_names(self):
# return [['_z', '_a', '_x', '_y', '_b'], ['_zout']]
def c_validate_update(self, *args): def c_validate_update(self, *args):
return "" return ""
def c_validate_update_cleanup(self, *args): def c_validate_update_cleanup(self, *args):
...@@ -612,125 +575,3 @@ class Gemm(_Op): ...@@ -612,125 +575,3 @@ class Gemm(_Op):
""" % dict(locals(), **sub) """ % dict(locals(), **sub)
gemm = gof.op.constructor(Gemm) gemm = gof.op.constructor(Gemm)
if 0:
##########################
# Comparisons
##########################
# Less-than
class lt_elemwise(_Elemwise):
def __init__(self, *args):
raise NotImplementedError()
class lt_scalar_r(_Elemwise):
def __init__(self, *args):
raise NotImplementedError()
# Less-than or equal
class le_elemwise(_Elemwise):
def __init__(self, *args):
raise NotImplementedError()
class le_scalar_r(_Elemwise):
def __init__(self, *args):
raise NotImplementedError()
# Greater-than or equal
class gt_elemwise(_Elemwise):
def __init__(self, *args):
raise NotImplementedError()
class gt_scalar_r(_Elemwise):
def __init__(self, *args):
raise NotImplementedError()
# Greater-than or equal
class ge_elemwise(_Elemwise):
def __init__(self, *args):
raise NotImplementedError()
class ge_scalar_r(_Elemwise):
def __init__(self, *args):
raise NotImplementedError()
if 0:
def _broadcastable_pattern(pattern):
def factory(data = None, name = None, dtype=None):
if data:
assert len(data.shape) == len(pattern)
if dtype is not None:
assert dtype is data.dtype
dtype = data.dtype
rval = Tensor(dtype, pattern, name)
rval.data = data
else:
rval = Tensor(dtype, pattern, name)
return rval
return factory
row = _broadcastable_pattern([1, 0])
col = _broadcastable_pattern([0, 1])
matrix = _broadcastable_pattern([0, 0])
if 0: #old __init__ code
"""Create a Tensor
If data is given:
- constant defaults to True
- if dtype is given, it must match data.dtype
- otherwise: default is data.dtype
- if broadcastable is given, len(broadcastable) must match len(data.shape)
- otherwise: if it is constant, it defaults to 1 where shape[i]==1
- if it is not constant, it defaults to 0s
If data is not given:
- constant defaults to False
"""
if dtype is None or broadcastable is None:
if data is None:
raise TypeError("Provide non-None data to complete the dtype and broadcastable flags.")
data = numpy.asarray(data)
if constant is None:
constant = True
dtype = data.dtype
if constant:
broadcastable = [1*(x == 1) for x in data.shape]
else:
broadcastable = [0] * len(data.shape)
if 0:
def tensor__new__(cls, *args, **kwargs):
"""__new__ is overloaded to handle the special form Tensor(x) when x is
a Tensor or an Op whose default output is a Tensor. In these cases, the
argument x is returned, and a new Tensor is not created.
"""
if len(args) == 1:
a = args[0]
t = super(Tensor, cls).__new__(cls, *args, **kwargs)
t.__init__(*args, **kwargs)
return t
# def upcast(dtype, *dtypes):
# z = numpy.zeros((), dtype = dtype)
# for dtype in dtypes:
# z = z + numpy.zeros((), dtype = dtype)
# return str(z.dtype)
# for dtype in i_dtypes:
# if dtype is None:
# raise TypeError("Expected a Tensor.")
# upcasted = upcast(*i_dtypes)
# return [upcasted] * self.nout
# # try:
# # dmap = self.destroy_map()
# # except AttributeError:
# # dmap = {}
# # rval = []
# # for i in xrange(self.nout):
# # if i in dmap:
# # destroyed = dmap[output]
# # if len(destroyed) != 1:
# # raise TypeError("Cannot infer dtype of output %s because it destroys more than one input." % output)
# # rval.append(destroyed[0])
# # else:
# # rval.append(upcasted)
# # return rval
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论