提交 d77ccbce authored 作者: Olivier Breuleux's avatar Olivier Breuleux

Add and Mul now take an arbitrary number of inputs

上级 7a0b002f
......@@ -66,9 +66,9 @@ class _test_composite(unittest.TestCase):
assert c.outputs[0].data == 6.0
assert c.outputs[1].data == 7.0
assert c.outputs[2].data == 0.5
g = env([x, y], c.outputs)
g = env([x, y, z], c.outputs)
fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == [6.0, 7.0, 0.5]
assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]
if __name__ == '__main__':
......
......@@ -136,8 +136,11 @@ class Broadcast(Op, Destroyer):
assert len(set([len(input.broadcastable) for input in inputs])) == 1
except (AssertionError, AttributeError):
raise TypeError("All inputs to a Broadcast subclass must be Tensor instances and their broadcastable fields must all have the same length.", self.__class__)
self.nin = scalar_opclass.nin
self.nout = scalar_opclass.nout
self.shadow = scalar_opclass(*[Scalar(dtype = t.dtype) for t in inputs])
self.nin = self.shadow.nin
self.nout = self.shadow.nout
out_broadcastables = [[1*all(bcast) for bcast in zip(*[input.broadcastable for input in inputs])]] * self.nout
if inplace_pattern:
......@@ -158,8 +161,7 @@ class Broadcast(Op, Destroyer):
self.outputs = [Tensor(dtype = dtype, broadcastable = broadcastable) for dtype, broadcastable in zip(out_dtypes, out_broadcastables)]
self.inplace_pattern = inplace_pattern
self.scalar_opclass = scalar_opclass
self.shadow = scalar_opclass(*[Scalar(dtype = t.dtype) for t in self.inputs])
self.ufunc = numpy.frompyfunc(self.shadow.impl, scalar_opclass.nin, scalar_opclass.nout)
self.ufunc = numpy.frompyfunc(self.shadow.impl, self.shadow.nin, self.shadow.nout)
def clone_with_new_inputs(self, *new_inputs):
return Broadcast(self.scalar_opclass, new_inputs, self.inplace_pattern)
......@@ -389,8 +391,10 @@ class CAReduce(Op):
def __init__(self, scalar_opclass, inputs, dimensions_to_reduce = None):
inputs = map(astensor, inputs)
self.shadow = scalar_opclass(*[Scalar(dtype = inputs[0].dtype) for i in xrange(len(inputs) + 1)])
if scalar_opclass.nin != 2 or scalar_opclass.nout != 1:
if self.shadow.nin != 2 or self.shadow.nout != 1:
raise NotImplementedError("CAReduce only supports binary functions with a single output.")
if len(inputs) != 1:
raise TypeError("Only one argument expected.")
......@@ -403,8 +407,7 @@ class CAReduce(Op):
self.dimensions_to_reduce = dimensions_to_reduce
self.scalar_opclass = scalar_opclass
self.shadow = scalar_opclass(*[Scalar(dtype = inputs[0].dtype) for i in xrange(scalar_opclass.nin)])
self.ufunc = numpy.frompyfunc(self.shadow.impl, scalar_opclass.nin, scalar_opclass.nout)
self.ufunc = numpy.frompyfunc(self.shadow.impl, self.shadow.nin, self.shadow.nout)
def desc(self):
return (self.__class__, self.scalar_opclass, tuple(self.dimensions_to_reduce))
......
......@@ -36,6 +36,16 @@ def difference(seq1, seq2):
# -> use O(len(seq1) * len(seq2)) algo
return [x for x in seq1 if x not in seq2]
def partition(f, seq):
seqt = []
seqf = []
for elem in seq:
if f(elem):
seqt.append(elem)
else:
seqf.append(elem)
return seqt, seqf
def attr_checker(*attrs):
def f(candidate):
for attr in attrs:
......
......@@ -186,28 +186,32 @@ class Scalar(Result):
class ScalarMixedOp(GuardedOp):
"""Olivier: document this stuff! -JB"""
def upcast(dtype, *dtypes):
z = numpy.zeros((), dtype = dtype)
for dtype in dtypes:
z = z + numpy.zeros((), dtype = dtype)
return str(z.dtype)
class ScalarOp(GuardedOp):
nin = -1
nout = 1
def __init__(self, *inputs):
if self.nin >= 0:
if len(inputs) != self.nin:
raise TypeError("Wrong number of inputs for %s (got %i, expected %i)" \
% (self.__class__.__name__, len(inputs), self.nin))
else:
self.nin = len(inputs)
inputs = [as_scalar(input) for input in inputs]
i_dtypes = [getattr(input, 'dtype', None) for input in inputs]
o_dtypes = self.propagate_dtypes(*i_dtypes)
o_dtypes = [upcast(*i_dtypes)] * self.nout
self.inputs = inputs
self.outputs = [Scalar(dtype) for dtype in o_dtypes]
def propagate_dtypes(self, *inputs):
raise AbstractFunctionError()
def impl(self, *inputs):
raise AbstractFunctionError()
......@@ -215,43 +219,45 @@ class ScalarMixedOp(GuardedOp):
raise AbstractFunctionError()
def perform(self):
self.outputs[0].data = self.impl(*[input.data for input in self.inputs])
def upcast(dtype, *dtypes):
z = numpy.zeros((), dtype = dtype)
for dtype in dtypes:
z = z + numpy.zeros((), dtype = dtype)
return str(z.dtype)
class PureScalarOp(ScalarMixedOp):
cast_method = lambda self, *args: upcast(*args)
def propagate_dtypes(self, *i_dtypes):
for dtype in i_dtypes:
if dtype is None:
raise TypeError("Expected a Scalar.")
return [self.cast_method(*i_dtypes)] * self.nout
if self.nout == 1:
self.outputs[0].data = self.impl(*[input.data for input in self.inputs])
else:
results = utils.from_return_values(self.impl(*[input.data for input in self.inputs]))
for output, result in zip(self.outputs, results):
output.data = result
class UnaryScalarOp(PureScalarOp):
class UnaryScalarOp(ScalarOp):
nin = 1
class BinaryScalarOp(PureScalarOp):
class BinaryScalarOp(ScalarOp):
nin = 2
class Add(BinaryScalarOp):
class Add(ScalarOp):
identity = 0
def impl(self, x, y):
return x + y
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s + %(y)s;" % locals()
def grad(self, (x, y), (gz, )):
return gz, gz
def impl(self, *inputs):
return sum(inputs)
def c_code(self, inputs, (z, ), sub):
if not inputs:
return z + " = 0;"
else:
return z + " = " + " + ".join(inputs) + ";"
def grad(self, inputs, (gz, )):
return (gz, ) * len(inputs)
class Mul(ScalarOp):
identity = 1
def impl(self, *inputs):
return numpy.product(inputs)
def c_code(self, inputs, (z, ), sub):
if not inputs:
return z + " = 1;"
else:
return z + " = " + " * ".join(inputs) + ";"
def grad(self, inputs, (gz, )):
return [mul(*([gz] + utils.difference(inputs, [input])))
for input in inputs]
class Sub(BinaryScalarOp):
def impl(self, x, y):
......@@ -261,14 +267,6 @@ class Sub(BinaryScalarOp):
def grad(self, (x, y), (gz, )):
return gz, -gz
class Mul(BinaryScalarOp):
def impl(self, x, y):
return x * y
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s * %(y)s;" % locals()
def grad(self, (x, y), (gz, )):
return gz * y, gz * x
class Div(BinaryScalarOp):
def impl(self, x, y):
return x / y
......@@ -302,6 +300,7 @@ class Second(BinaryScalarOp):
return None, gz
class Identity(UnaryScalarOp):
def impl(self, x):
return x
......@@ -333,7 +332,8 @@ class Sgn(UnaryScalarOp):
def grad(self, (x, ), (gz, )):
return None,
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = %(x)s/abs(%(x)s);" % locals() # TODO: C use copysign
return "%(z)s = %(x)s/%(prefix)sabs(%(x)s);" \
% dict(locals(), prefix = 'float' in self.inputs[0].dtype and 'f' or '') # TODO: C use copysign
class Inv(UnaryScalarOp):
def impl(self, x):
......@@ -405,7 +405,7 @@ def composite(inputs, outputs):
The operations between inputs and outputs (as given by
Env(inputs, outputs).ops()) must all be instances of
PureScalarOp.
ScalarOp.
Examples:
x, y = Scalar(), Scalar()
......@@ -420,8 +420,8 @@ def composite(inputs, outputs):
inputs, outputs = env.inputs, env.outputs
for op in env.ops():
if not isinstance(op, PureScalarOp):
raise ValueError("The input env to composite must be exclusively composed of PureScalarOp instances.")
if not isinstance(op, ScalarOp):
raise ValueError("The input env to composite must be exclusively composed of ScalarOp instances.")
subd = dict(zip(inputs,
["%%(i%i)s"%i for i in range(len(inputs))]) +
......@@ -460,7 +460,7 @@ def composite(inputs, outputs):
# this is not optimal at all eg in add(*1 -> mul(x, y), *1)
# it will calculate *1 twice
# it also doesn't follow env.toposort but that's (presumably)
# still correct since we only have pure scalar ops
# still correct since we only have scalar ops
if r in env.inputs:
idx = env.inputs.index(r)
return lambda inputs: inputs[idx]
......@@ -472,7 +472,7 @@ def composite(inputs, outputs):
_impls = [compose_impl(r) for r in env.outputs]
class Composite(PureScalarOp):
class Composite(ScalarOp):
nin = len(inputs)
nout = len(outputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论