提交 62c90bd2 authored 作者: bergstrj@iro.umontreal.ca's avatar bergstrj@iro.umontreal.ca

merged

...@@ -164,27 +164,34 @@ class _test_CAReduce(unittest.TestCase): ...@@ -164,27 +164,34 @@ class _test_CAReduce(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
# x = modes.build(Tensor('float64', [0, 0], name = 'x')) # x = modes.build(Tensor('int32', [0, 0], name = 'x'))
# y = modes.build(Tensor('float64', [0, 0], name = 'y')) # y = modes.build(Tensor('int32', [0, 0], name = 'y'))
# e = Broadcast(SquareDiff, (x, y), {0:0}).out # # x = modes.build(Tensor('float64', [0, 0], name = 'x'))
# # y = modes.build(Tensor('float64', [0, 0], name = 'y'))
# e = Broadcast(Pow, (x, y)).out
# f = gof.CLinker(env([x, y], [e])).make_function(inplace = False) # f = gof.CLinker(env([x, y], [e])).make_function(inplace = False)
# xv = numpy.random.rand(1000, 1000) # # xv = numpy.random.rand(1000, 1000)
# yv = numpy.random.rand(1000, 1000) # # yv = numpy.random.rand(1000, 1000)
# zv = numpy.random.rand(1000, 1000) # # zv = numpy.random.rand(1000, 1000)
# xv = numpy.random.randint(1, 5, (1000, 1000))
# yv = numpy.random.randint(1, 5, (1000, 1000))
# add = numpy.frompyfunc(lambda x, y: x + y, 2, 1) # add = numpy.frompyfunc(lambda x, y: x + y, 2, 1)
# t0 = time.time() # # t0 = time.time()
# for i in xrange(100): # # for i in xrange(100):
# xv -= yv # # xv / yv
# xv *= xv # # print time.time() - t0
# # xv += yv
# print time.time() - t0
# t0 = time.time() # t0 = time.time()
# for i in xrange(100): # for i in xrange(100):
# f(xv, yv) # f(xv, yv)
# print time.time() - t0 # print time.time() - t0
# speed ratios:
# add : 1
# mul : 1
# div : 2
# pow : 20
......
...@@ -66,9 +66,9 @@ class _test_composite(unittest.TestCase): ...@@ -66,9 +66,9 @@ class _test_composite(unittest.TestCase):
assert c.outputs[0].data == 6.0 assert c.outputs[0].data == 6.0
assert c.outputs[1].data == 7.0 assert c.outputs[1].data == 7.0
assert c.outputs[2].data == 0.5 assert c.outputs[2].data == 0.5
g = env([x, y], c.outputs) g = env([x, y, z], c.outputs)
fn = gof.DualLinker(g).make_function() fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == [6.0, 7.0, 0.5] assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -13,8 +13,16 @@ def inputs(): ...@@ -13,8 +13,16 @@ def inputs():
x = Scalar('float64', name = 'x') x = Scalar('float64', name = 'x')
y = Scalar('float64', name = 'y') y = Scalar('float64', name = 'y')
z = Scalar('float64', name = 'z') z = Scalar('float64', name = 'z')
a = Scalar('float64', name = 'a')
return x, y, z return x, y, z
def more_inputs():
    """Return four extra float64 scalar variables (a, b, c, d) for tests."""
    return tuple(Scalar('float64', name = name) for name in 'abcd')
class _test_opts(unittest.TestCase): class _test_opts(unittest.TestCase):
...@@ -24,9 +32,106 @@ class _test_opts(unittest.TestCase): ...@@ -24,9 +32,106 @@ class _test_opts(unittest.TestCase):
g = Env([x], [e]) g = Env([x], [e])
assert str(g) == "[Pow(x, 2.0)]" assert str(g) == "[Pow(x, 2.0)]"
gof.ConstantFinder().optimize(g) gof.ConstantFinder().optimize(g)
opt2.optimize(g) pow2sqr_float.optimize(g)
assert str(g) == "[Sqr(x)]" assert str(g) == "[Sqr(x)]"
# class _test_canonize(unittest.TestCase):
# def test_muldiv(self):
# x, y, z = inputs()
# a, b, c, d = more_inputs()
# # e = (2.0 * x) / (2.0 * y)
# # e = (2.0 * x) / (4.0 * y)
# # e = x / (y / z)
# # e = (x * y) / x
# # e = (x / y) * (y / z) * (z / x)
# # e = (a / b) * (b / c) * (c / d)
# # e = (a * b) / (b * c) / (c * d)
# # e = 2 * x / 2
# # e = x / y / x
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# gof.ConstantFinder().optimize(g)
# mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# divfn = lambda x, y: x / y
# invfn = lambda x: 1 / x
# Canonizer(Mul, Div, Inv, mulfn, divfn, invfn).optimize(g)
# print g
# def test_plusmin(self):
# x, y, z = inputs()
# a, b, c, d = more_inputs()
# # e = x - x
# # e = (2.0 + x) - (2.0 + y)
# # e = (2.0 + x) - (4.0 + y)
# # e = x - (y - z)
# # e = (x + y) - x
# # e = (x - y) + (y - z) + (z - x)
# # e = (a - b) + (b - c) + (c - d)
# # e = x + -y
# # e = a - b - b + a + b + c + b - c
# e = x + log(y) - x + y
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# gof.ConstantFinder().optimize(g)
# addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
# print g
# def test_both(self):
# x, y, z = inputs()
# a, b, c, d = more_inputs()
# e0 = (x * y / x)
# e = e0 + e0 - e0
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# gof.ConstantFinder().optimize(g)
# mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# divfn = lambda x, y: x / y
# invfn = lambda x: 1 / x
# Canonizer(Mul, Div, Inv, mulfn, divfn, invfn).optimize(g)
# addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
# print g
# def test_group_powers(self):
# x, y, z = inputs()
# a, b, c, d = more_inputs()
# # e = x * exp(y) * exp(z)
# # e = x * pow(x, y) * pow(x, z)
# # e = pow(x, y) / pow(x, z)
# # e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0)
# # e = pow(x - x, y)
# # e = pow(x, 2.0 + y - 7.0)
# # e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) / pow(x, z)
# # e = pow(x, 2.0 + y - 7.0 - z)
# # e = x ** y / x ** y
# # e = x ** y / x ** (y - 1.0)
# e = exp(x) * a * exp(y) / exp(z)
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# gof.ConstantFinder().optimize(g)
# mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# divfn = lambda x, y: x / y
# invfn = lambda x: 1 / x
# Canonizer(Mul, Div, Inv, mulfn, divfn, invfn, group_powers).optimize(g)
# print g
# addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs)
# subfn = lambda x, y: x - y
# negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g)
# print g
# pow2one_float.optimize(g)
# pow2x_float.optimize(g)
# print g
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -219,6 +219,47 @@ def make_broadcast_tester(op_class, expected, checks = {}, **kwargs): ...@@ -219,6 +219,47 @@ def make_broadcast_tester(op_class, expected, checks = {}, **kwargs):
return make_tester(name, op_class, expected, checks, **kwargs) return make_tester(name, op_class, expected, checks, **kwargs)
def make_broadcast_tester_unary(op_class, expected, checks = {}, **kwargs):
    """Build a Tester class for a unary broadcasting Op.

    op_class -- the Op class under test
    expected -- function computing the reference output from the inputs
    checks   -- extra named check functions forwarded to make_tester
                (read-only here; never mutated)
    kwargs   -- forwarded to make_tester after consuming these flags:
        nonzero  -- if true, sample test data away from zero
        positive -- if true, sample non-negative test data
        inplace  -- if true, cast the reference output to the input dtype
                    and check that the output aliases input 0

    Returns the class produced by make_tester.
    """
    _randint = randint
    _rand = rand
    # dict.pop consumes the flag whether present or not, replacing the
    # deprecated has_key()/del idiom; a falsy value behaves as before.
    if kwargs.pop('nonzero', False):
        _randint = banzero(_randint)
        _rand = banzero(_rand)
    if kwargs.pop('positive', False):
        _randint = banneg(_randint)
        _rand = banneg(_rand)
    # NOTE(review): the 'int' datasets use _rand rather than _randint --
    # presumably intentional, but worth confirming against
    # make_broadcast_tester.
    _good_broadcast = dict(normal = (_rand(2, 3), ),
                           int = (_rand(2, 3), ))
    _bad_build_broadcast = dict()
    _bad_runtime_broadcast = dict()
    _grad_broadcast = dict(normal = (_rand(2, 3), ),
                           int = (_rand(2, 3), ))
    kwargs.setdefault('good', _good_broadcast)
    kwargs.setdefault('bad_build', _bad_build_broadcast)
    kwargs.setdefault('bad_runtime', _bad_runtime_broadcast)
    kwargs.setdefault('grad', _grad_broadcast)
    name = op_class.__name__ + "Tester"
    if kwargs.pop('inplace', False):
        _expected = expected
        # An inplace op writes into input 0, so the reference output must
        # be cast to that input's dtype before comparison.
        expected = lambda *inputs: numpy.array(_expected(*inputs), dtype = inputs[0].dtype)
        checks = dict(checks,
                      inplace_check = lambda inputs, outputs: inputs[0] is outputs[0])
    return make_tester(name, op_class, expected, checks, **kwargs)
...@@ -264,11 +305,11 @@ def make_broadcast_tester(op_class, expected, checks = {}, **kwargs): ...@@ -264,11 +305,11 @@ def make_broadcast_tester(op_class, expected, checks = {}, **kwargs):
# good = _pow_good) # good = _pow_good)
# AbsTester = make_broadcast_tester(op_class = Abs, AbsTester = make_broadcast_tester_unary(op_class = Abs,
# expected = lambda x: abs(x)) expected = lambda x: abs(x))
# AbsInplaceTester = make_broadcast_tester(op_class = AbsInplace, AbsInplaceTester = make_broadcast_tester_unary(op_class = AbsInplace,
# expected = lambda x: abs(x), expected = lambda x: abs(x),
# inplace = True) inplace = True)
# ExpTester = make_broadcast_tester(op_class = Exp, # ExpTester = make_broadcast_tester(op_class = Exp,
# expected = lambda x: numpy.exp(x)) # expected = lambda x: numpy.exp(x))
......
...@@ -136,8 +136,11 @@ class Broadcast(Op, Destroyer): ...@@ -136,8 +136,11 @@ class Broadcast(Op, Destroyer):
assert len(set([len(input.broadcastable) for input in inputs])) == 1 assert len(set([len(input.broadcastable) for input in inputs])) == 1
except (AssertionError, AttributeError): except (AssertionError, AttributeError):
raise TypeError("All inputs to a Broadcast subclass must be Tensor instances and their broadcastable fields must all have the same length.", self.__class__) raise TypeError("All inputs to a Broadcast subclass must be Tensor instances and their broadcastable fields must all have the same length.", self.__class__)
self.nin = scalar_opclass.nin
self.nout = scalar_opclass.nout self.shadow = scalar_opclass(*[Scalar(dtype = t.dtype) for t in inputs])
self.nin = self.shadow.nin
self.nout = self.shadow.nout
out_broadcastables = [[1*all(bcast) for bcast in zip(*[input.broadcastable for input in inputs])]] * self.nout out_broadcastables = [[1*all(bcast) for bcast in zip(*[input.broadcastable for input in inputs])]] * self.nout
if inplace_pattern: if inplace_pattern:
...@@ -158,8 +161,7 @@ class Broadcast(Op, Destroyer): ...@@ -158,8 +161,7 @@ class Broadcast(Op, Destroyer):
self.outputs = [Tensor(dtype = dtype, broadcastable = broadcastable) for dtype, broadcastable in zip(out_dtypes, out_broadcastables)] self.outputs = [Tensor(dtype = dtype, broadcastable = broadcastable) for dtype, broadcastable in zip(out_dtypes, out_broadcastables)]
self.inplace_pattern = inplace_pattern self.inplace_pattern = inplace_pattern
self.scalar_opclass = scalar_opclass self.scalar_opclass = scalar_opclass
self.shadow = scalar_opclass(*[Scalar(dtype = t.dtype) for t in self.inputs]) self.ufunc = numpy.frompyfunc(self.shadow.impl, self.shadow.nin, self.shadow.nout)
self.ufunc = numpy.frompyfunc(self.shadow.impl, scalar_opclass.nin, scalar_opclass.nout)
def clone_with_new_inputs(self, *new_inputs): def clone_with_new_inputs(self, *new_inputs):
return Broadcast(self.scalar_opclass, new_inputs, self.inplace_pattern) return Broadcast(self.scalar_opclass, new_inputs, self.inplace_pattern)
...@@ -390,7 +392,9 @@ class CAReduce(Op): ...@@ -390,7 +392,9 @@ class CAReduce(Op):
def __init__(self, scalar_opclass, inputs, dimensions_to_reduce = None): def __init__(self, scalar_opclass, inputs, dimensions_to_reduce = None):
inputs = map(astensor, inputs) inputs = map(astensor, inputs)
if scalar_opclass.nin != 2 or scalar_opclass.nout != 1: self.shadow = scalar_opclass(*[Scalar(dtype = inputs[0].dtype) for i in xrange(len(inputs) + 1)])
if self.shadow.nin != 2 or self.shadow.nout != 1:
raise NotImplementedError("CAReduce only supports binary functions with a single output.") raise NotImplementedError("CAReduce only supports binary functions with a single output.")
if len(inputs) != 1: if len(inputs) != 1:
raise TypeError("Only one argument expected.") raise TypeError("Only one argument expected.")
...@@ -403,8 +407,7 @@ class CAReduce(Op): ...@@ -403,8 +407,7 @@ class CAReduce(Op):
self.dimensions_to_reduce = dimensions_to_reduce self.dimensions_to_reduce = dimensions_to_reduce
self.scalar_opclass = scalar_opclass self.scalar_opclass = scalar_opclass
self.shadow = scalar_opclass(*[Scalar(dtype = inputs[0].dtype) for i in xrange(scalar_opclass.nin)]) self.ufunc = numpy.frompyfunc(self.shadow.impl, self.shadow.nin, self.shadow.nout)
self.ufunc = numpy.frompyfunc(self.shadow.impl, scalar_opclass.nin, scalar_opclass.nout)
def desc(self): def desc(self):
return (self.__class__, self.scalar_opclass, tuple(self.dimensions_to_reduce)) return (self.__class__, self.scalar_opclass, tuple(self.dimensions_to_reduce))
......
...@@ -262,6 +262,8 @@ def as_string(i, o, ...@@ -262,6 +262,8 @@ def as_string(i, o,
exist for viewing convenience). exist for viewing convenience).
""" """
orph = orphans(i, o)
multi = set() multi = set()
seen = set() seen = set()
for output in o: for output in o:
...@@ -273,7 +275,7 @@ def as_string(i, o, ...@@ -273,7 +275,7 @@ def as_string(i, o,
for op in ops(i, o): for op in ops(i, o):
for input in op.inputs: for input in op.inputs:
op2 = input.owner op2 = input.owner
if input in i or op2 is None: if input in i or input in orph or op2 is None:
continue continue
if op2 in seen: if op2 in seen:
multi.add(op2) multi.add(op2)
...@@ -286,7 +288,7 @@ def as_string(i, o, ...@@ -286,7 +288,7 @@ def as_string(i, o,
return multi.index(x) + 1 return multi.index(x) + 1
def describe(r): def describe(r):
if r.owner is not None and r not in i: if r.owner is not None and r not in i and r not in orph:
op = r.owner op = r.owner
idx = op.outputs.index(r) idx = op.outputs.index(r)
if idx == op._default_output_idx: if idx == op._default_output_idx:
......
...@@ -276,7 +276,7 @@ class GuardedOp(Op): ...@@ -276,7 +276,7 @@ class GuardedOp(Op):
try: try:
if not old.same_properties(new): if not old.same_properties(new):
raise TypeError("The new input must have the same properties as the previous one.") raise TypeError("The new input must have the same properties as the previous one.")
except AbstractFunction: except AbstractFunctionError:
pass pass
Op.set_input(self, i, new) Op.set_input(self, i, new)
......
...@@ -36,6 +36,16 @@ def difference(seq1, seq2): ...@@ -36,6 +36,16 @@ def difference(seq1, seq2):
# -> use O(len(seq1) * len(seq2)) algo # -> use O(len(seq1) * len(seq2)) algo
return [x for x in seq1 if x not in seq2] return [x for x in seq1 if x not in seq2]
def partition(f, seq):
    """Split *seq* into two lists: elements where f(elem) is true, then the rest.

    The predicate is evaluated exactly once per element and relative
    order is preserved in both output lists.
    """
    matched, rest = [], []
    for item in seq:
        (matched if f(item) else rest).append(item)
    return matched, rest
def attr_checker(*attrs): def attr_checker(*attrs):
def f(candidate): def f(candidate):
for attr in attrs: for attr in attrs:
......
...@@ -186,8 +186,13 @@ class Scalar(Result): ...@@ -186,8 +186,13 @@ class Scalar(Result):
class ScalarMixedOp(GuardedOp): def upcast(dtype, *dtypes):
"""Olivier: document this stuff! -JB""" z = numpy.zeros((), dtype = dtype)
for dtype in dtypes:
z = z + numpy.zeros((), dtype = dtype)
return str(z.dtype)
class ScalarOp(GuardedOp):
nin = -1 nin = -1
nout = 1 nout = 1
...@@ -197,17 +202,16 @@ class ScalarMixedOp(GuardedOp): ...@@ -197,17 +202,16 @@ class ScalarMixedOp(GuardedOp):
if len(inputs) != self.nin: if len(inputs) != self.nin:
raise TypeError("Wrong number of inputs for %s (got %i, expected %i)" \ raise TypeError("Wrong number of inputs for %s (got %i, expected %i)" \
% (self.__class__.__name__, len(inputs), self.nin)) % (self.__class__.__name__, len(inputs), self.nin))
else:
self.nin = len(inputs)
inputs = [as_scalar(input) for input in inputs] inputs = [as_scalar(input) for input in inputs]
i_dtypes = [getattr(input, 'dtype', None) for input in inputs] i_dtypes = [getattr(input, 'dtype', None) for input in inputs]
o_dtypes = self.propagate_dtypes(*i_dtypes) o_dtypes = [upcast(*i_dtypes)] * self.nout
self.inputs = inputs self.inputs = inputs
self.outputs = [Scalar(dtype) for dtype in o_dtypes] self.outputs = [Scalar(dtype) for dtype in o_dtypes]
def propagate_dtypes(self, *inputs):
raise AbstractFunctionError()
def impl(self, *inputs): def impl(self, *inputs):
raise AbstractFunctionError() raise AbstractFunctionError()
...@@ -215,43 +219,45 @@ class ScalarMixedOp(GuardedOp): ...@@ -215,43 +219,45 @@ class ScalarMixedOp(GuardedOp):
raise AbstractFunctionError() raise AbstractFunctionError()
def perform(self): def perform(self):
if self.nout == 1:
self.outputs[0].data = self.impl(*[input.data for input in self.inputs]) self.outputs[0].data = self.impl(*[input.data for input in self.inputs])
else:
results = utils.from_return_values(self.impl(*[input.data for input in self.inputs]))
for output, result in zip(self.outputs, results):
output.data = result
class UnaryScalarOp(ScalarOp):
def upcast(dtype, *dtypes):
    """Return the dtype name produced by numpy promotion of all arguments.

    Promotion is computed by adding zero-dimensional zero arrays of each
    dtype together and reading the result's dtype.
    """
    zeros = [numpy.zeros((), dtype = d) for d in (dtype,) + dtypes]
    return str(sum(zeros[1:], zeros[0]).dtype)
class PureScalarOp(ScalarMixedOp):
    # Default dtype-combination rule: numpy-style upcasting over all
    # input dtypes (delegates to the module-level upcast helper).
    cast_method = lambda self, *args: upcast(*args)

    def propagate_dtypes(self, *i_dtypes):
        """Return one output dtype per output, upcast from all input dtypes.

        Raises TypeError if any input dtype is None, i.e. the input was
        not a Scalar.
        """
        for dtype in i_dtypes:
            if dtype is None:
                raise TypeError("Expected a Scalar.")
        return [self.cast_method(*i_dtypes)] * self.nout
class UnaryScalarOp(PureScalarOp):
nin = 1 nin = 1
class BinaryScalarOp(PureScalarOp): class BinaryScalarOp(ScalarOp):
nin = 2 nin = 2
class Add(ScalarOp):
class Add(BinaryScalarOp):
identity = 0 identity = 0
def impl(self, x, y): def impl(self, *inputs):
return x + y return sum(inputs)
def c_code(self, (x, y), (z, ), sub): def c_code(self, inputs, (z, ), sub):
return "%(z)s = %(x)s + %(y)s;" % locals() if not inputs:
def grad(self, (x, y), (gz, )): return z + " = 0;"
return gz, gz else:
return z + " = " + " + ".join(inputs) + ";"
def grad(self, inputs, (gz, )):
return (gz, ) * len(inputs)
class Mul(ScalarOp):
identity = 1
def impl(self, *inputs):
return numpy.product(inputs)
def c_code(self, inputs, (z, ), sub):
if not inputs:
return z + " = 1;"
else:
return z + " = " + " * ".join(inputs) + ";"
def grad(self, inputs, (gz, )):
return [mul(*([gz] + utils.difference(inputs, [input])))
for input in inputs]
class Sub(BinaryScalarOp): class Sub(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
...@@ -261,14 +267,6 @@ class Sub(BinaryScalarOp): ...@@ -261,14 +267,6 @@ class Sub(BinaryScalarOp):
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return gz, -gz return gz, -gz
class Mul(BinaryScalarOp):
def impl(self, x, y):
return x * y
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s * %(y)s;" % locals()
def grad(self, (x, y), (gz, )):
return gz * y, gz * x
class Div(BinaryScalarOp): class Div(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x / y return x / y
...@@ -302,6 +300,7 @@ class Second(BinaryScalarOp): ...@@ -302,6 +300,7 @@ class Second(BinaryScalarOp):
return None, gz return None, gz
class Identity(UnaryScalarOp): class Identity(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return x return x
...@@ -338,7 +337,8 @@ class Sgn(UnaryScalarOp): ...@@ -338,7 +337,8 @@ class Sgn(UnaryScalarOp):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return None, return None,
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = %(x)s/abs(%(x)s);" % locals() # TODO: C use copysign return "%(z)s = %(x)s/%(prefix)sabs(%(x)s);" \
% dict(locals(), prefix = 'float' in self.inputs[0].dtype and 'f' or '') # TODO: C use copysign
class Inv(UnaryScalarOp): class Inv(UnaryScalarOp):
def impl(self, x): def impl(self, x):
...@@ -457,7 +457,7 @@ def composite(inputs, outputs): ...@@ -457,7 +457,7 @@ def composite(inputs, outputs):
The operations between inputs and outputs (as given by The operations between inputs and outputs (as given by
Env(inputs, outputs).ops()) must all be instances of Env(inputs, outputs).ops()) must all be instances of
PureScalarOp. ScalarOp.
Examples: Examples:
x, y = Scalar(), Scalar() x, y = Scalar(), Scalar()
...@@ -472,8 +472,8 @@ def composite(inputs, outputs): ...@@ -472,8 +472,8 @@ def composite(inputs, outputs):
inputs, outputs = env.inputs, env.outputs inputs, outputs = env.inputs, env.outputs
for op in env.ops(): for op in env.ops():
if not isinstance(op, PureScalarOp): if not isinstance(op, ScalarOp):
raise ValueError("The input env to composite must be exclusively composed of PureScalarOp instances.") raise ValueError("The input env to composite must be exclusively composed of ScalarOp instances.")
subd = dict(zip(inputs, subd = dict(zip(inputs,
["%%(i%i)s"%i for i in range(len(inputs))]) + ["%%(i%i)s"%i for i in range(len(inputs))]) +
...@@ -512,7 +512,7 @@ def composite(inputs, outputs): ...@@ -512,7 +512,7 @@ def composite(inputs, outputs):
# this is not optimal at all eg in add(*1 -> mul(x, y), *1) # this is not optimal at all eg in add(*1 -> mul(x, y), *1)
# it will calculate *1 twice # it will calculate *1 twice
# it also doesn't follow env.toposort but that's (presumably) # it also doesn't follow env.toposort but that's (presumably)
# still correct since we only have pure scalar ops # still correct since we only have scalar ops
if r in env.inputs: if r in env.inputs:
idx = env.inputs.index(r) idx = env.inputs.index(r)
return lambda inputs: inputs[idx] return lambda inputs: inputs[idx]
...@@ -524,7 +524,7 @@ def composite(inputs, outputs): ...@@ -524,7 +524,7 @@ def composite(inputs, outputs):
_impls = [compose_impl(r) for r in env.outputs] _impls = [compose_impl(r) for r in env.outputs]
class Composite(PureScalarOp): class Composite(ScalarOp):
nin = len(inputs) nin = len(inputs)
nout = len(outputs) nout = len(outputs)
......
from scalar import * from scalar import *
from gof import PatternOptimizer from gof import PatternOptimizer as Pattern
from gof import utils
c2 = constant(2.0) C = constant
opt1 = PatternOptimizer((Mul, 'x', 'x'), (Sqr, 'x')) # x**2 -> x*x
opt2 = PatternOptimizer((Pow, 'x', c2), (Sqr, 'x')) pow2sqr_float = Pattern((Pow, 'x', C(2.0)), (Sqr, 'x'))
pow2sqr_int = Pattern((Pow, 'x', C(2)), (Sqr, 'x'))
# x**0 -> 1
pow2one_float = Pattern((Pow, 'x', C(0.0)), C(1.0))
pow2one_int = Pattern((Pow, 'x', C(0)), C(1))
# x**1 -> x
pow2x_float = Pattern((Pow, 'x', C(1.0)), 'x')
pow2x_int = Pattern((Pow, 'x', C(1)), 'x')
# log(x**y) -> y*log(x)
logpow = Pattern((Log, (Pow, 'x', 'y')),
(Mul, 'y', (Log, 'x')))
class Canonizer(gof.Optimizer):
    """Canonicalize graphs built from an algebraic operator triple such as
    (Mul, Div, Inv) or (Add, Sub, Neg).

    Arguments:
      main       -- n-ary Op class (e.g. Mul)
      inverse    -- binary Op class undoing main (e.g. Div)
      reciprocal -- unary Op class (e.g. Inv)
      mainfn, invfn, recfn -- plain Python functions computing the same
                    three operations on constant values; mainfn() called
                    with no arguments must return the neutral element
                    (1 for Mul, 0 for Add)
      transform  -- optional hook called as transform(env, num, denum) to
                    rewrite the factor lists before rebuilding
                    (e.g. group_powers)
    """
    def __init__(self, main, inverse, reciprocal, mainfn, invfn, recfn, transform = None):
        self.main = main
        self.inverse = inverse
        self.reciprocal = reciprocal
        self.mainfn = mainfn
        self.invfn = invfn
        self.recfn = recfn
        # Neutral element of `main`, obtained by applying mainfn to no args.
        self.neutral = mainfn()
        self.transform = transform
    def apply(self, env):
        def canonize(r):
            # Inputs and orphans are leaves; nothing to rewrite there.
            if r in env.inputs or r in env.orphans():
                return
            def flatten(r, nclients_check = True):
                """Return (num, denum): the factors of r landing in the
                numerator / denominator once chains of main, inverse and
                reciprocal are flattened.  Intermediate results are only
                absorbed when they have a single client (unless
                nclients_check is False, as for the root call)."""
                op = r.owner
                if op is None or r in env.inputs or r in env.orphans():
                    return [r], []
                # Only recurse into same-dtype subexpressions; a dtype
                # change acts as a flattening barrier.
                results = [r2.dtype == r.dtype and flatten(r2) or ([r2], []) for r2 in op.inputs]
                if isinstance(op, self.main) and (not nclients_check or env.nclients(r) == 1):
                    # main(a, b, ...): every factor keeps its side.
                    nums = [x[0] for x in results]
                    denums = [x[1] for x in results]
                elif isinstance(op, self.inverse) and (not nclients_check or env.nclients(r) == 1):
                    # inverse(a, b) = a "/" b: b's factors swap sides.
                    nums = [results[0][0], results[1][1]]
                    denums = [results[0][1], results[1][0]]
                elif isinstance(op, self.reciprocal) and (not nclients_check or env.nclients(r) == 1):
                    # reciprocal(a): a's factors swap sides.
                    nums = [results[0][1]]
                    denums = [results[0][0]]
                else:
                    # Opaque op: treat r itself as a single factor.
                    return [r], []
                return reduce(list.__add__, nums), reduce(list.__add__, denums)
            num, denum = flatten(r, False)
            if (num, denum) == ([r], []):
                # r could not be flattened; try canonizing its inputs instead.
                if r.owner is None:
                    return
                else:
                    for input in r.owner.inputs:
                        canonize(input)
                    return
            # Cancel factors present on both sides (x "/" x -> neutral).
            for d in list(denum):
                if d in list(num):
                    num.remove(d)
                    denum.remove(d)
            # Pull out constant factors and fold them into a single value.
            numct, num = utils.partition(lambda factor: getattr(factor, 'constant', False) and factor.data is not None, num)
            denumct, denum = utils.partition(lambda factor: getattr(factor, 'constant', False) and factor.data is not None, denum)
            v = self.invfn(self.mainfn(*[x.data for x in numct]), self.mainfn(*[x.data for x in denumct]))
            if v != self.neutral:
                # Keep the folded constant as the first numerator factor,
                # unless it is the neutral element (then it is dropped).
                num.insert(0, C(v))
            if self.transform is not None:
                num, denum = self.transform(env, num, denum)
            def make(factors):
                # Rebuild one result from a factor list; None if empty.
                n = len(factors)
                if n == 0:
                    return None
                elif n == 1:
                    return factors[0]
                else:
                    return self.main(*factors).out
            numr, denumr = make(num), make(denum)
            if numr is None:
                if denumr is None:
                    # Everything cancelled: replace r by a constant
                    # carrying the neutral element.
                    new_r = Scalar(dtype = r.dtype)
                    new_r.constant = True
                    new_r.data = self.neutral
                else:
                    new_r = self.reciprocal(denumr).out
            else:
                if denumr is None:
                    new_r = numr
                else:
                    new_r = self.inverse(numr, denumr).out
            env.replace(r, new_r)
            # Recurse into the surviving factors.
            for factor in num + denum:
                canonize(factor)
        for output in env.outputs:
            canonize(output)
def group_powers(env, num, denum):
    """Canonizer transform hook: merge exp/pow factors sharing a base.

    pow(x, a) factors found in `num` and pow(x, b) factors found in
    `denum` are removed and replaced by a single pow(x, a_sum - b_sum)
    appended to `num`; exp factors are grouped the same way under the
    sentinel base 'e'.  Returns the rewritten (num, denum) pair.
    """
    num_powers = {}
    denum_powers = {}
    def populate(d, seq):
        # Move every Exp/Pow factor out of seq, recording
        # base -> [exponent, ...] in d.  Inputs and orphans stay put.
        for factor in list(seq):
            op = factor.owner
            if op is None or factor in env.inputs or factor in env.orphans():
                continue
            if isinstance(op, Exp):
                # exp(y) is keyed under the sentinel base 'e'.
                d.setdefault('e', []).append(op.inputs[0])
                seq.remove(factor)
            elif isinstance(op, Pow):
                d.setdefault(op.inputs[0], []).append(op.inputs[1])
                seq.remove(factor)
    populate(num_powers, num)
    populate(denum_powers, denum)
    for x in set(num_powers.keys() + denum_powers.keys()):
        try: num_ys = num_powers.pop(x)
        except KeyError: num_ys = []
        try: denum_ys = denum_powers.pop(x)
        except KeyError: denum_ys = []
        # Sum the exponents gathered on each side (constant 0 if none).
        num_r = num_ys and add(*num_ys) or C(0)
        denum_r = denum_ys and add(*denum_ys) or C(0)
        if x == 'e':
            num.append(exp(num_r - denum_r))
        else:
            num.append(pow(x, num_r - denum_r))
    return num, denum
def simple_factorize(env, num, denum):
    """Placeholder for a common-factor extraction transform (unimplemented).

    The sketches below outline the intended rewrites and the
    factor -> co-factor maps they would be driven by.
    """
    # a*b + a*c -> a*(b+c)
    # a*b + a*c + b*c -> a*(b+c) + b*c
    #                 -> a*b + (a+b)*c
    # => a: {b, c}, b: {a, c}, c: {a, b}
    # a*c + a*d + b*c + b*d
    # => a: {c, d}, b: {c, d}, c: {a, b}, d: {a, b}
    # (a+b*x)*(c+d) --> a*c + a*d + b*x*c + b*x*d
    # => a: {c, d}, b: {xc, xd}, c: {a, bx}, d: {a, bx}, x: {bc, bd}
    pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论