提交 d77ccbce authored 作者: Olivier Breuleux's avatar Olivier Breuleux

Add and Mul now take an arbitrary number of inputs

上级 7a0b002f
...@@ -66,9 +66,9 @@ class _test_composite(unittest.TestCase): ...@@ -66,9 +66,9 @@ class _test_composite(unittest.TestCase):
assert c.outputs[0].data == 6.0 assert c.outputs[0].data == 6.0
assert c.outputs[1].data == 7.0 assert c.outputs[1].data == 7.0
assert c.outputs[2].data == 0.5 assert c.outputs[2].data == 0.5
g = env([x, y], c.outputs) g = env([x, y, z], c.outputs)
fn = gof.DualLinker(g).make_function() fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == [6.0, 7.0, 0.5] assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -136,8 +136,11 @@ class Broadcast(Op, Destroyer): ...@@ -136,8 +136,11 @@ class Broadcast(Op, Destroyer):
assert len(set([len(input.broadcastable) for input in inputs])) == 1 assert len(set([len(input.broadcastable) for input in inputs])) == 1
except (AssertionError, AttributeError): except (AssertionError, AttributeError):
raise TypeError("All inputs to a Broadcast subclass must be Tensor instances and their broadcastable fields must all have the same length.", self.__class__) raise TypeError("All inputs to a Broadcast subclass must be Tensor instances and their broadcastable fields must all have the same length.", self.__class__)
self.nin = scalar_opclass.nin
self.nout = scalar_opclass.nout self.shadow = scalar_opclass(*[Scalar(dtype = t.dtype) for t in inputs])
self.nin = self.shadow.nin
self.nout = self.shadow.nout
out_broadcastables = [[1*all(bcast) for bcast in zip(*[input.broadcastable for input in inputs])]] * self.nout out_broadcastables = [[1*all(bcast) for bcast in zip(*[input.broadcastable for input in inputs])]] * self.nout
if inplace_pattern: if inplace_pattern:
...@@ -158,8 +161,7 @@ class Broadcast(Op, Destroyer): ...@@ -158,8 +161,7 @@ class Broadcast(Op, Destroyer):
self.outputs = [Tensor(dtype = dtype, broadcastable = broadcastable) for dtype, broadcastable in zip(out_dtypes, out_broadcastables)] self.outputs = [Tensor(dtype = dtype, broadcastable = broadcastable) for dtype, broadcastable in zip(out_dtypes, out_broadcastables)]
self.inplace_pattern = inplace_pattern self.inplace_pattern = inplace_pattern
self.scalar_opclass = scalar_opclass self.scalar_opclass = scalar_opclass
self.shadow = scalar_opclass(*[Scalar(dtype = t.dtype) for t in self.inputs]) self.ufunc = numpy.frompyfunc(self.shadow.impl, self.shadow.nin, self.shadow.nout)
self.ufunc = numpy.frompyfunc(self.shadow.impl, scalar_opclass.nin, scalar_opclass.nout)
def clone_with_new_inputs(self, *new_inputs): def clone_with_new_inputs(self, *new_inputs):
return Broadcast(self.scalar_opclass, new_inputs, self.inplace_pattern) return Broadcast(self.scalar_opclass, new_inputs, self.inplace_pattern)
...@@ -389,8 +391,10 @@ class CAReduce(Op): ...@@ -389,8 +391,10 @@ class CAReduce(Op):
def __init__(self, scalar_opclass, inputs, dimensions_to_reduce = None): def __init__(self, scalar_opclass, inputs, dimensions_to_reduce = None):
inputs = map(astensor, inputs) inputs = map(astensor, inputs)
self.shadow = scalar_opclass(*[Scalar(dtype = inputs[0].dtype) for i in xrange(len(inputs) + 1)])
if scalar_opclass.nin != 2 or scalar_opclass.nout != 1: if self.shadow.nin != 2 or self.shadow.nout != 1:
raise NotImplementedError("CAReduce only supports binary functions with a single output.") raise NotImplementedError("CAReduce only supports binary functions with a single output.")
if len(inputs) != 1: if len(inputs) != 1:
raise TypeError("Only one argument expected.") raise TypeError("Only one argument expected.")
...@@ -403,8 +407,7 @@ class CAReduce(Op): ...@@ -403,8 +407,7 @@ class CAReduce(Op):
self.dimensions_to_reduce = dimensions_to_reduce self.dimensions_to_reduce = dimensions_to_reduce
self.scalar_opclass = scalar_opclass self.scalar_opclass = scalar_opclass
self.shadow = scalar_opclass(*[Scalar(dtype = inputs[0].dtype) for i in xrange(scalar_opclass.nin)]) self.ufunc = numpy.frompyfunc(self.shadow.impl, self.shadow.nin, self.shadow.nout)
self.ufunc = numpy.frompyfunc(self.shadow.impl, scalar_opclass.nin, scalar_opclass.nout)
def desc(self): def desc(self):
return (self.__class__, self.scalar_opclass, tuple(self.dimensions_to_reduce)) return (self.__class__, self.scalar_opclass, tuple(self.dimensions_to_reduce))
......
...@@ -36,6 +36,16 @@ def difference(seq1, seq2): ...@@ -36,6 +36,16 @@ def difference(seq1, seq2):
# -> use O(len(seq1) * len(seq2)) algo # -> use O(len(seq1) * len(seq2)) algo
return [x for x in seq1 if x not in seq2] return [x for x in seq1 if x not in seq2]
def partition(f, seq):
    """Split *seq* into two lists according to the predicate *f*.

    Returns a pair ``(seqt, seqf)`` where ``seqt`` contains the elements
    for which ``f(elem)`` is true and ``seqf`` the remaining ones, each
    preserving the original order.  ``f`` is called exactly once per
    element.
    """
    buckets = ([], [])
    for item in seq:
        # index 0 = predicate holds, index 1 = predicate fails
        buckets[0 if f(item) else 1].append(item)
    return buckets
def attr_checker(*attrs): def attr_checker(*attrs):
def f(candidate): def f(candidate):
for attr in attrs: for attr in attrs:
......
...@@ -186,28 +186,32 @@ class Scalar(Result): ...@@ -186,28 +186,32 @@ class Scalar(Result):
class ScalarMixedOp(GuardedOp): def upcast(dtype, *dtypes):
"""Olivier: document this stuff! -JB""" z = numpy.zeros((), dtype = dtype)
for dtype in dtypes:
z = z + numpy.zeros((), dtype = dtype)
return str(z.dtype)
class ScalarOp(GuardedOp):
nin = -1 nin = -1
nout = 1 nout = 1
def __init__(self, *inputs): def __init__(self, *inputs):
if self.nin >= 0: if self.nin >= 0:
if len(inputs) != self.nin: if len(inputs) != self.nin:
raise TypeError("Wrong number of inputs for %s (got %i, expected %i)" \ raise TypeError("Wrong number of inputs for %s (got %i, expected %i)" \
% (self.__class__.__name__, len(inputs), self.nin)) % (self.__class__.__name__, len(inputs), self.nin))
else:
self.nin = len(inputs)
inputs = [as_scalar(input) for input in inputs] inputs = [as_scalar(input) for input in inputs]
i_dtypes = [getattr(input, 'dtype', None) for input in inputs] i_dtypes = [getattr(input, 'dtype', None) for input in inputs]
o_dtypes = self.propagate_dtypes(*i_dtypes) o_dtypes = [upcast(*i_dtypes)] * self.nout
self.inputs = inputs self.inputs = inputs
self.outputs = [Scalar(dtype) for dtype in o_dtypes] self.outputs = [Scalar(dtype) for dtype in o_dtypes]
def propagate_dtypes(self, *inputs):
raise AbstractFunctionError()
def impl(self, *inputs): def impl(self, *inputs):
raise AbstractFunctionError() raise AbstractFunctionError()
...@@ -215,43 +219,45 @@ class ScalarMixedOp(GuardedOp): ...@@ -215,43 +219,45 @@ class ScalarMixedOp(GuardedOp):
raise AbstractFunctionError() raise AbstractFunctionError()
def perform(self): def perform(self):
self.outputs[0].data = self.impl(*[input.data for input in self.inputs]) if self.nout == 1:
self.outputs[0].data = self.impl(*[input.data for input in self.inputs])
else:
def upcast(dtype, *dtypes): results = utils.from_return_values(self.impl(*[input.data for input in self.inputs]))
z = numpy.zeros((), dtype = dtype) for output, result in zip(self.outputs, results):
for dtype in dtypes: output.data = result
z = z + numpy.zeros((), dtype = dtype)
return str(z.dtype)
class PureScalarOp(ScalarMixedOp):
cast_method = lambda self, *args: upcast(*args)
def propagate_dtypes(self, *i_dtypes):
for dtype in i_dtypes:
if dtype is None:
raise TypeError("Expected a Scalar.")
return [self.cast_method(*i_dtypes)] * self.nout
class UnaryScalarOp(PureScalarOp): class UnaryScalarOp(ScalarOp):
nin = 1 nin = 1
class BinaryScalarOp(PureScalarOp): class BinaryScalarOp(ScalarOp):
nin = 2 nin = 2
class Add(ScalarOp):
class Add(BinaryScalarOp):
identity = 0 identity = 0
def impl(self, x, y): def impl(self, *inputs):
return x + y return sum(inputs)
def c_code(self, (x, y), (z, ), sub): def c_code(self, inputs, (z, ), sub):
return "%(z)s = %(x)s + %(y)s;" % locals() if not inputs:
def grad(self, (x, y), (gz, )): return z + " = 0;"
return gz, gz else:
return z + " = " + " + ".join(inputs) + ";"
def grad(self, inputs, (gz, )):
return (gz, ) * len(inputs)
class Mul(ScalarOp):
    # Variadic elementwise multiplication: accepts an arbitrary number of
    # scalar inputs (nin stays -1, "any arity", inherited from ScalarOp).
    # NOTE(review): this file is Python-2 era code — the tuple-unpacking
    # parameters `(z, )` / `(gz, )` below are invalid syntax in Python 3.
    identity = 1  # neutral element of multiplication

    def impl(self, *inputs):
        # Python-side implementation: product over all input values.
        # numpy.product of an empty tuple is 1.0, consistent with `identity`.
        return numpy.product(inputs)

    def c_code(self, inputs, (z, ), sub):
        # Emit the C expression computing the product.  `inputs` are C
        # expression strings for each operand, `z` is the output variable.
        if not inputs:
            # Zero operands: fall back to the multiplicative identity.
            return z + " = 1;"
        else:
            return z + " = " + " * ".join(inputs) + ";"

    def grad(self, inputs, (gz, )):
        # d(prod)/d(x_i) = gz * (product of every input except x_i);
        # utils.difference drops the current input from the operand list.
        return [mul(*([gz] + utils.difference(inputs, [input])))
                for input in inputs]
class Sub(BinaryScalarOp): class Sub(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
...@@ -261,14 +267,6 @@ class Sub(BinaryScalarOp): ...@@ -261,14 +267,6 @@ class Sub(BinaryScalarOp):
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
return gz, -gz return gz, -gz
class Mul(BinaryScalarOp):
def impl(self, x, y):
return x * y
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s * %(y)s;" % locals()
def grad(self, (x, y), (gz, )):
return gz * y, gz * x
class Div(BinaryScalarOp): class Div(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x / y return x / y
...@@ -302,6 +300,7 @@ class Second(BinaryScalarOp): ...@@ -302,6 +300,7 @@ class Second(BinaryScalarOp):
return None, gz return None, gz
class Identity(UnaryScalarOp): class Identity(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return x return x
...@@ -333,7 +332,8 @@ class Sgn(UnaryScalarOp): ...@@ -333,7 +332,8 @@ class Sgn(UnaryScalarOp):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
return None, return None,
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = %(x)s/abs(%(x)s);" % locals() # TODO: C use copysign return "%(z)s = %(x)s/%(prefix)sabs(%(x)s);" \
% dict(locals(), prefix = 'float' in self.inputs[0].dtype and 'f' or '') # TODO: C use copysign
class Inv(UnaryScalarOp): class Inv(UnaryScalarOp):
def impl(self, x): def impl(self, x):
...@@ -405,7 +405,7 @@ def composite(inputs, outputs): ...@@ -405,7 +405,7 @@ def composite(inputs, outputs):
The operations between inputs and outputs (as given by The operations between inputs and outputs (as given by
Env(inputs, outputs).ops()) must all be instances of Env(inputs, outputs).ops()) must all be instances of
PureScalarOp. ScalarOp.
Examples: Examples:
x, y = Scalar(), Scalar() x, y = Scalar(), Scalar()
...@@ -420,8 +420,8 @@ def composite(inputs, outputs): ...@@ -420,8 +420,8 @@ def composite(inputs, outputs):
inputs, outputs = env.inputs, env.outputs inputs, outputs = env.inputs, env.outputs
for op in env.ops(): for op in env.ops():
if not isinstance(op, PureScalarOp): if not isinstance(op, ScalarOp):
raise ValueError("The input env to composite must be exclusively composed of PureScalarOp instances.") raise ValueError("The input env to composite must be exclusively composed of ScalarOp instances.")
subd = dict(zip(inputs, subd = dict(zip(inputs,
["%%(i%i)s"%i for i in range(len(inputs))]) + ["%%(i%i)s"%i for i in range(len(inputs))]) +
...@@ -460,7 +460,7 @@ def composite(inputs, outputs): ...@@ -460,7 +460,7 @@ def composite(inputs, outputs):
# this is not optimal at all eg in add(*1 -> mul(x, y), *1) # this is not optimal at all eg in add(*1 -> mul(x, y), *1)
# it will calculate *1 twice # it will calculate *1 twice
# it also doesn't follow env.toposort but that's (presumably) # it also doesn't follow env.toposort but that's (presumably)
# still correct since we only have pure scalar ops # still correct since we only have scalar ops
if r in env.inputs: if r in env.inputs:
idx = env.inputs.index(r) idx = env.inputs.index(r)
return lambda inputs: inputs[idx] return lambda inputs: inputs[idx]
...@@ -472,7 +472,7 @@ def composite(inputs, outputs): ...@@ -472,7 +472,7 @@ def composite(inputs, outputs):
_impls = [compose_impl(r) for r in env.outputs] _impls = [compose_impl(r) for r in env.outputs]
class Composite(PureScalarOp): class Composite(ScalarOp):
nin = len(inputs) nin = len(inputs)
nout = len(outputs) nout = len(outputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论