broadcast, reduce are working in python and c

上级 88f7858f
...@@ -8,31 +8,23 @@ from scalar_ops import * ...@@ -8,31 +8,23 @@ from scalar_ops import *
def inputs(): def inputs():
x = modes.build_eval(as_scalar(1.0, 'x')) x = modes.build(as_scalar(1.0, 'x'))
y = modes.build_eval(as_scalar(2.0, 'y')) y = modes.build(as_scalar(2.0, 'y'))
z = modes.build_eval(as_scalar(3.0, 'z')) z = modes.build(as_scalar(3.0, 'z'))
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []): def env(inputs, outputs, validate = True, features = []):
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate) return Env(inputs, outputs, features = features, consistency_check = validate)
class _test_ScalarOps(unittest.TestCase): class _test_ScalarOps(unittest.TestCase):
def test_0(self): def test_straightforward(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
assert e.data == 1.5
def test_1(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
g = env([x, y], [e]) g = env([x, y], [e])
fn = gof.cc.CLinker(g).make_function() fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == 1.5 assert fn(1.0, 2.0) == 1.5
assert e.data == 1.5
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -691,7 +691,7 @@ class t_gemm(unittest.TestCase): ...@@ -691,7 +691,7 @@ class t_gemm(unittest.TestCase):
self.rand(3,5), self.rand(5,4), 1.0) self.rand(3,5), self.rand(5,4), 1.0)
def test12(self): self.cmp(self.rand(3,4), -1.0, def test12(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), -1.0) self.rand(3,5), self.rand(5,4), -1.0)
t_gemm = None
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -108,53 +108,53 @@ class BaseTensor(ResultBase): ...@@ -108,53 +108,53 @@ class BaseTensor(ResultBase):
# #
# C codegen stubs # C codegen stubs
# #
def c_declare(self): def c_declare(self, name, sub):
return """ return """
PyArrayObject* %%(name)s; PyArrayObject* %(name)s;
int type_num_%%(name)s; int type_num_%(name)s;
typedef %(dtype)s dtype_%%(name)s; typedef %(dtype)s dtype_%(name)s;
""" % dict(dtype = self.dtype_specs()[1]) """ % dict(sub, name = name, dtype = self.dtype_specs()[1])
def c_init(self): def c_init(self, name, sub):
return """ return """
%%(name)s = NULL; %(name)s = NULL;
type_num_%%(name)s = %(type_num)s; type_num_%(name)s = %(type_num)s;
""" % dict(type_num = self.dtype_specs()[2]) """ % dict(sub, name = name, type_num = self.dtype_specs()[2])
def c_extract(self): def c_extract(self, name, sub):
return """ return """
%%(name)s = NULL; %(name)s = NULL;
type_num_%%(name)s = %(type_num)s; type_num_%(name)s = %(type_num)s;
if (py_%%(name)s == Py_None) { if (py_%(name)s == Py_None) {
// We can either fail here or set %%(name)s to NULL and rely on Ops using // We can either fail here or set %(name)s to NULL and rely on Ops using
// tensors to handle the NULL case, but if they fail to do so they'll end up // tensors to handle the NULL case, but if they fail to do so they'll end up
// with nasty segfaults, so this is public service. // with nasty segfaults, so this is public service.
PyErr_SetString(PyExc_ValueError, "expected an ndarray, not None"); PyErr_SetString(PyExc_ValueError, "expected an ndarray, not None");
%%(fail)s %(fail)s
//%%(name)s = NULL; //%(name)s = NULL;
} }
else if (!PyArray_Check(py_%%(name)s)) { else if (!PyArray_Check(py_%(name)s)) {
PyErr_SetString(PyExc_ValueError, "expected an ndarray"); PyErr_SetString(PyExc_ValueError, "expected an ndarray");
%%(fail)s %(fail)s
} }
else if (((PyArrayObject*)py_%%(name)s)->descr->type_num != %(type_num)s) { else if (((PyArrayObject*)py_%(name)s)->descr->type_num != %(type_num)s) {
PyErr_SetString(PyExc_ValueError, "expected %(type_num)s"); PyErr_SetString(PyExc_ValueError, "expected %(type_num)s");
%%(fail)s %(fail)s
} }
else { else {
%%(name)s = (PyArrayObject*)(py_%%(name)s); %(name)s = (PyArrayObject*)(py_%(name)s);
Py_XINCREF(%%(name)s); Py_XINCREF(%(name)s);
} }
""" % dict(type_num = self.dtype_specs()[2]) """ % dict(sub, name = name, type_num = self.dtype_specs()[2])
def c_cleanup(self): def c_cleanup(self, name, sub):
return """ return """
if (%(name)s) { if (%(name)s) {
Py_XDECREF(%(name)s); Py_XDECREF(%(name)s);
} }
""" """ % locals()
def c_sync(self): def c_sync(self, name, sub):
return """ return """
if (!%(name)s) { if (!%(name)s) {
Py_XDECREF(py_%(name)s); Py_XDECREF(py_%(name)s);
...@@ -165,7 +165,7 @@ class BaseTensor(ResultBase): ...@@ -165,7 +165,7 @@ class BaseTensor(ResultBase):
py_%(name)s = (PyObject*)%(name)s; py_%(name)s = (PyObject*)%(name)s;
Py_XINCREF(py_%(name)s); Py_XINCREF(py_%(name)s);
} }
""" """ % locals()
def c_headers(self): def c_headers(self):
return [] return []
......
...@@ -16,8 +16,8 @@ exec_opt.optimizer = None ...@@ -16,8 +16,8 @@ exec_opt.optimizer = None
def default_optimizer(env): def default_optimizer(env):
#TODO: pass tests with these un-commented #TODO: pass tests with these un-commented
default_optimizer.const(env) # default_optimizer.const(env)
default_optimizer.merge(env) # default_optimizer.merge(env)
pass pass
default_optimizer.merge = gof.opt.MergeOptimizer() default_optimizer.merge = gof.opt.MergeOptimizer()
default_optimizer.const = gof.opt.ConstantFinder() default_optimizer.const = gof.opt.ConstantFinder()
......
...@@ -61,21 +61,41 @@ class Elemwise(Op): ...@@ -61,21 +61,41 @@ class Elemwise(Op):
return ret return ret
def c_validate_update(self): def c_validate_update(self, input_names, output_names, sub):
sub = dict(sub)
icvn, ocvn = self.c_var_names()
for real, tosub in zip(input_names + output_names, icvn + ocvn):
sub[tosub] = real
(valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code() (valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code()
return valupd return valupd % sub
def c_validate_update_cleanup(self, input_names, output_names, sub):
sub = dict(sub)
icvn, ocvn = self.c_var_names()
for real, tosub in zip(input_names + output_names, icvn + ocvn):
sub[tosub] = real
def c_validate_update_cleanup(self):
(valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code() (valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code()
return valupd_cleanup return valupd_cleanup % sub
def c_code(self, input_names, output_names, sub):
sub = dict(sub)
icvn, ocvn = self.c_var_names()
for real, tosub in zip(input_names + output_names, icvn + ocvn):
sub[tosub] = real
def c_code(self):
(valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code() (valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code()
return code return code % sub
def c_code_cleanup(self, input_names, output_names, sub):
sub = dict(sub)
icvn, ocvn = self.c_var_names()
for real, tosub in zip(input_names + output_names, icvn + ocvn):
sub[tosub] = real
def c_code_cleanup(self):
(valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code() (valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code()
return code_cleanup return code_cleanup % sub
@classmethod @classmethod
def inplace_version(cls, dmap = {0:0}): def inplace_version(cls, dmap = {0:0}):
......
import elemwise_cgen as cgen
# foldl(f, fold_inputs, init) => import numpy
# fold_inputs = init; from gof import Op, Viewer, Destroyer
# for loop_inputs in c_order(difference(inputs, fold_inputs)): from tensor import Tensor
# fold_inputs = f(fold_inputs, loop_inputs) from scalar import upcast, Scalar
# a+b+c+d => ((a+b)+c)+d import scalar_ops
import gof
# foldr(f, fold_inputs, init) =>
# fold_inputs = init;
# for loop_inputs in reversed_c_order(difference(inputs, fold_inputs)):
# fold_inputs = f(fold_inputs, loop_inputs)
# a**b**c**d => a**(b**(c**d))
# foldx(f, fold_inputs, init) =>
# fold_inputs = init;
# for loop_inputs in any_order(difference(inputs, fold_inputs)):
# fold_inputs = f(fold_inputs, loop_inputs)
# a+b+c+d => ((a+b)+c)+d
# a+b+c+d => a+(b+(c+d))
# a+b+c+d => (a+b)+(c+d)
# foldx <=> f.associative
# f.associative => (foldl => foldx) and (foldr => foldx)
# z = a*b + b*c + c*d + d*e
# z: (0, 0, 0, 0)
# a: (0, 0, 0, 0) => (0, 0, 1, 0, 0, 1) => loop order: 1, 2, 3, 4, x, x
# b: (0, 0) => (1, 1, 1, 0, 0, 1) => loop order: x, x, x, 1, 2, x
# c: (0, 0, S, 0, 0, S) => (0, 0, S, 0, 0, S) => loop order: 1, 2, 4, 5, 3, 6
# d: (1, 0, 1) => (1, 1, 1, 0, 1, 1) => loop order: x, x, x, 2, x, x
# e: (S, 0, 0, S, 0, 0) => => loop order: 2, 3, 5, 6, 1, 4
# strategy: (broadcasted, folded, fold_method)
# (2, 1, 1, 3, 1), (1, 7, 1, 1, 4)
# (2, 7, 1, 3, 4), (1, 1, 8, 1, 4)
# (2, 7, 8, 3, 4)
# (2, 3, 4, 5), (7, 3, 4, 8)
# (2, 3, 4), (3, 4)
# (2, 3, 4)
class ElemwiseGroup:
def __init__(self):
self.
def compile_env(env):
mappings = {}
order = env.io_toposort()
for op in reversed(order):
if not isinstance(op, Elemwise):
raise TypeError("Unsupported op type for the Elemwise compiler.", op)
for input in op.input_policy:
strategies.setdefault()
def elemwise_op_gen(op, modalities):
"""
* op: z = x + y
modalities: {z: foldx(0, x, y)}
result: Z = sum(Y)
* op: z = x + y
"""
def broadcasting_cgen(op):
template = op.c_foreach()
##################
### DimShuffle ###
##################
class DimShuffle(Op, Viewer): class DimShuffle(Op, Viewer):
...@@ -108,11 +37,14 @@ class DimShuffle(Op, Viewer): ...@@ -108,11 +37,14 @@ class DimShuffle(Op, Viewer):
self.inplace = inplace self.inplace = inplace
self.numorder = [x for x in new_order if type(x) == int] self.numorder = [x for x in new_order if type(x) == int]
self.is_transposition = sorted(new_order) == range(length(ib)) self.is_transposition = sorted(new_order) == range(len(ib))
self.dup_dims = len(set(self.numorder)) != len(self.numorder) self.dup_dims = len(set(self.numorder)) != len(self.numorder)
self.all_dims = len(set(self.numorder)) == len(ib) self.all_dims = len(set(self.numorder)) == len(ib)
if self.dup_dims or not self.all_dims: if self.dup_dims or not self.all_dims:
raise NotImplementedError("You must provide a permutation of *all* the input dimensions with *no duplicates*.") raise NotImplementedError("You must provide a permutation of *all* the input dimensions with *no duplicates*.")
def clone_with_new_inputs(self, *new_inputs):
return DimShuffle(new_inputs[0], self.new_order, self.inplace)
def view_map(self): def view_map(self):
if self.inplace: if self.inplace:
...@@ -124,13 +56,13 @@ class DimShuffle(Op, Viewer): ...@@ -124,13 +56,13 @@ class DimShuffle(Op, Viewer):
res = self.inputs[0].data.transpose(self.numorder) res = self.inputs[0].data.transpose(self.numorder)
shape = list(res.shape) shape = list(res.shape)
new_shape = [] new_shape = []
for entry in new_order: for entry in self.new_order:
if entry == 'x': if entry == 'x':
new_shape.append(1) new_shape.append(1)
else: else:
new_shape.append(shape.pop()) new_shape.append(shape.pop(0))
res = res.reshape(new_shape) res = res.reshape(new_shape)
if not inplace: if not self.inplace:
res = numpy.copy(res) res = numpy.copy(res)
self.outputs[0].data = res self.outputs[0].data = res
...@@ -141,20 +73,40 @@ class DimShuffle(Op, Viewer): ...@@ -141,20 +73,40 @@ class DimShuffle(Op, Viewer):
class Transpose(DimShuffle): class Transpose(DimShuffle):
def __init__(self, input): def __init__(self, input):
DimShuffle.__init__(self, input, range(len(input.broadcastable)-1, -1, -1)) DimShuffle.__init__(self, input, range(len(input.broadcastable)-1, -1, -1), False)
def clone_with_new_inputs(self, *new_inputs):
return Transpose(new_inputs[0])
def __str__(self):
return "%s(%s)" % (self.__class__.__name__, str(self.inputs[0]))
#################
### Broadcast ###
#################
class Broadcast(Op, Destroyer): class Broadcast(Op, Destroyer):
def __init__(self, scalar_opclass, inputs, inplace_pattern): def __init__(self, scalar_opclass, inputs, inplace_pattern = {}):
try: try:
assert len(set([len(input.broadcastable) for input in inputs]) == 1) assert len(set([len(input.broadcastable) for input in inputs])) == 1
except (AssertionError, AttributeError): except (AssertionError, AttributeError):
raise TypeError("All inputs to a Broadcast subclass must be Tensor instances and their broadcastable fields must all have the same length.", self.__class__) raise TypeError("All inputs to a Broadcast subclass must be Tensor instances and their broadcastable fields must all have the same length.", self.__class__)
self.nin = scalar_opclass.nin
self.nout = scalar_opclass.nout
out_broadcastables = [[1*all(bcast) for bcast in zip(*[input.broadcastable for input in inputs])]] * self.nout out_broadcastables = [[1*all(bcast) for bcast in zip(*[input.broadcastable for input in inputs])]] * self.nout
if inplace_pattern:
for overwriter, overwritten in inplace_pattern.items():
for ob, ib in zip(out_broadcastables[overwriter], inputs[overwritten].broadcastable):
if ib and not ob:
raise ValueError("Operation cannot be done inplace on an input with broadcasted dimensions.")
upcasted = upcast(*[input.dtype for input in inputs]) upcasted = upcast(*[input.dtype for input in inputs])
def get_dtype(i): def get_dtype(i):
input_idx = inplace_pattern.get(i, [None]) input_idx = inplace_pattern.get(i, None)
if input_idx is not None: if input_idx is not None:
return inputs[input_idx].dtype return inputs[input_idx].dtype
else: else:
...@@ -164,8 +116,11 @@ class Broadcast(Op, Destroyer): ...@@ -164,8 +116,11 @@ class Broadcast(Op, Destroyer):
self.outputs = [Tensor(dtype = dtype, broadcastable = broadcastable) for dtype, broadcastable in zip(out_dtypes, out_broadcastables)] self.outputs = [Tensor(dtype = dtype, broadcastable = broadcastable) for dtype, broadcastable in zip(out_dtypes, out_broadcastables)]
self.inplace_pattern = inplace_pattern self.inplace_pattern = inplace_pattern
self.scalar_opclass = scalar_opclass self.scalar_opclass = scalar_opclass
self.shadow = scalar_opclass([Scalar(dtype = t.dtype) for t in self.inputs]) self.shadow = scalar_opclass(*[Scalar(dtype = t.dtype) for t in self.inputs])
self.ufunc = numpy.frompyfunc(scalar_opclass.impl, scalar_opclass.nin, scalar_opclass.nout) self.ufunc = numpy.frompyfunc(self.shadow.impl, scalar_opclass.nin, scalar_opclass.nout)
def clone_with_new_inputs(self, *new_inputs):
return Broadcast(self.scalar_opclass, new_inputs, self.inplace_pattern)
def id(self): def id(self):
return (self.__class__, self.scalar_opclass, self.inplace_pattern) return (self.__class__, self.scalar_opclass, self.inplace_pattern)
...@@ -194,7 +149,8 @@ class Broadcast(Op, Destroyer): ...@@ -194,7 +149,8 @@ class Broadcast(Op, Destroyer):
r = transform(scalar_igrad) r = transform(scalar_igrad)
to_sum = [i for i, bcast in enumerate(input.broadcastable) if bcast] to_sum = [i for i, bcast in enumerate(input.broadcastable) if bcast]
if to_sum: if to_sum:
ret.append(Sum(r, to_sum)) sr = Sum(r, axis = to_sum).out
ret.append(sr)
else: else:
ret.append(r) ret.append(r)
return ret return ret
...@@ -204,24 +160,114 @@ class Broadcast(Op, Destroyer): ...@@ -204,24 +160,114 @@ class Broadcast(Op, Destroyer):
if not self.inplace_pattern: if not self.inplace_pattern:
for output in self.outputs: for output in self.outputs:
odat = output.data odat = output.data
shape = [max(values) for values in zip(*[input.data.shape for input in self.inputs])]
if odat is not None: if odat is not None:
odat.resize(self.inputs[0].data.shape) odat.resize(shape)
else: else:
odat = numpy.ndarray(self.inputs[0].data.shape, dtype = output.dtype) odat = numpy.ndarray(shape, dtype = output.dtype)
output_storage.append(odat) output_storage.append(odat)
output.data = odat
else: else:
for i, output in enumerate(self.outputs): for i, output in enumerate(self.outputs):
if i in self.inplace_pattern: if i in self.inplace_pattern:
odat = self.inputs[self.inplace_pattern[i]].data odat = self.inputs[self.inplace_pattern[i]].data
else: else:
odat = output.data odat = output.data
shape = [max(values) for values in zip(*[input.data.shape for input in self.inputs])]
if odat is not None: if odat is not None:
odat.resize(self.inputs[0].data.shape) odat.resize(shape)
else: else:
odat = numpy.ndarray(self.inputs[0].data.shape, dtype = output.dtype) odat = numpy.ndarray(shape, dtype = output.dtype)
output_storage.append(odat) output_storage.append(odat)
output.data = odat
self.ufunc(*([input.data for input in self.inputs] + output_storage)) self.ufunc(*([input.data for input in self.inputs] + output_storage))
def _c_all(self, inames, onames, sub):
defines = ""
undefs = ""
dmap = self.destroy_map()
idtypes = [input.dtype_specs()[1] for input in self.inputs]
real = zip(*[(r, s, r.dtype_specs()[1])
for r, s in zip(self.outputs, onames) if r not in dmap])
if real:
real_outputs, real_onames, real_odtypes = real
else:
real_outputs, real_onames, real_odtypes = [], [], []
aliased = zip(*[(r, s)
for (r, s) in zip(self.outputs, onames) if r in dmap])
if aliased:
aliased_outputs, aliased_onames = aliased
else:
aliased_outputs, aliased_onames = [], []
orders = [[x and 'x' or i for i, x in enumerate(input.broadcastable)] for input in self.inputs]
nnested = len(orders[0])
sub = dict(sub)
for i, (input, iname) in enumerate(zip(self.inputs, inames)):
sub['lv%i' % i] = iname
decl = cgen.make_declare(orders, idtypes, sub)
checks = cgen.make_checks(orders, idtypes, sub)
alloc = ""
for output, oname, odtype in zip(real_outputs, real_onames, real_odtypes):
i += 1
sub['lv%i' % i] = oname
sub['olv'] = oname
alloc += cgen.make_declare([range(nnested)], [odtype], dict(sub, lv0 = oname))
alloc += cgen.make_alloc(orders, odtype, sub)
alloc += cgen.make_checks([range(nnested)], [odtype], dict(sub, lv0 = oname))
for output, oname in zip(aliased_outputs, aliased_onames):
iname = inames[self.inputs.index(dmap[output][0])]
alloc += """
if (%(oname)s) {
Py_XDECREF(%(oname)s);
}
%(oname)s = %(iname)s;
Py_XINCREF(%(oname)s);
""" % locals()
defines += "#define %(oname)s_i %(iname)s_i" % locals()
undefs += "#undef %(oname)s_i" % locals()
task_code = self.shadow.c_code(["%s_i" % s for s in inames],
["%s_i" % s for s in onames],
sub)
task_decl = "".join(["%(dtype)s& %(name)s_i = *%(name)s_iter;\n" % locals() for name, dtype in zip(inames + list(real_onames), idtypes + list(real_odtypes))])
code = """
{
%(defines)s
%(task_decl)s
%(task_code)s
%(undefs)s
}
""" % locals()
if nnested:
all_code = [("", "")] * (nnested - 1) + [("", code)] + [""]
else:
all_code = [code]
loop = cgen.make_loop(orders + [range(nnested)] * len(real_onames), idtypes + list(real_odtypes), all_code, sub)
return decl, checks, alloc, loop
def c_code(self, inames, onames, sub):
code = "\n".join(self._c_all(inames, onames, sub))
return code
def make_broadcast(scalar_opclass, inplace_pattern = {}, name = None):
class New(Broadcast):
def __init__(self, *inputs):
Broadcast.__init__(self, scalar_opclass, inputs, inplace_pattern)
def clone_with_new_inputs(self, *new_inputs):
return New(*new_inputs)
if name is not None:
New.__name__ = name
else:
New.__name__ = "Tensor" + scalar_opclass.__name__
return New
def broadcast(op): def broadcast(op):
def instantiate(*inputs): def instantiate(*inputs):
...@@ -234,8 +280,14 @@ def broadcast(op): ...@@ -234,8 +280,14 @@ def broadcast(op):
else: else:
args.append(DimShuffle(input, ['x']*difference + range(length))) args.append(DimShuffle(input, ['x']*difference + range(length)))
return op(*args) return op(*args)
return instantiate
################
### CAReduce ###
################
class CAReduce(Op): class CAReduce(Op):
""" """
CAReduce(scalar_op, inputs, dimensions_to_reduce = None, init = None, shortcut = False) CAReduce(scalar_op, inputs, dimensions_to_reduce = None, init = None, shortcut = False)
...@@ -269,8 +321,148 @@ class CAReduce(Op): ...@@ -269,8 +321,148 @@ class CAReduce(Op):
def __init__(self, scalar_opclass, inputs, dimensions_to_reduce = None): def __init__(self, scalar_opclass, inputs, dimensions_to_reduce = None):
if scalar_opclass.nin != 2 or scalar_opclass.nout != 1: if scalar_opclass.nin != 2 or scalar_opclass.nout != 1:
raise NotImplementedError("CAReduce only supports binary functions with a single output.") raise NotImplementedError("CAReduce only supports binary functions with a single output.")
if len(inputs) != 1:
raise TypeError("Only one argument expected.")
if dimensions_to_reduce is None:
dimensions_to_reduce = range(len(inputs[0].broadcastable))
self.nin = 1
self.nout = 1
self.inputs = inputs
self.outputs = [Tensor(dtype = inputs[0].dtype,
broadcastable = [x for i, x in enumerate(inputs[0].broadcastable) if i not in dimensions_to_reduce])]
self.dimensions_to_reduce = dimensions_to_reduce
self.scalar_opclass = scalar_opclass
self.shadow = scalar_opclass(*[Scalar(dtype = inputs[0].dtype) for i in xrange(scalar_opclass.nin)])
self.ufunc = numpy.frompyfunc(self.shadow.impl, scalar_opclass.nin, scalar_opclass.nout)
def id(self):
return (self.__class__, self.scalar_opclass, self.dimensions_to_reduce)
def clone_with_new_inputs(self, *new_inputs):
return CAReduce(self.scalar_opclass, new_inputs, self.dimensions_to_reduce)
def perform(self):
result = self.inputs[0].data
for dimension in reversed(sorted(self.dimensions_to_reduce)):
result = self.ufunc.reduce(result, dimension)
self.outputs[0].data = result
def _c_all(self, inames, onames, sub):
input = self.inputs[0]
output = self.outputs[0]
iname = inames[0]
oname = onames[0]
idtype = input.dtype_specs()[1]
odtype = output.dtype_specs()[1]
tosum = self.dimensions_to_reduce
order1 = [i for i in xrange(len(input.broadcastable)) if i not in tosum]
order = order1 + list(tosum)
nnested = len(order1)
sub = dict(sub)
for i, (input, iname) in enumerate(zip(self.inputs, inames)):
sub['lv%i' % i] = iname
decl = cgen.make_declare([order], [idtype], sub)
checks = cgen.make_checks([order], [idtype], sub)
alloc = ""
i += 1
sub['lv%i' % i] = oname
sub['olv'] = oname
alloc += cgen.make_declare([range(nnested) + ['x'] * len(tosum)], [odtype], dict(sub, lv0 = oname))
alloc += cgen.make_alloc([order1], odtype, sub)
alloc += cgen.make_checks([range(nnested) + ['x'] * len(tosum)], [odtype], dict(sub, lv0 = oname))
task0_decl = "%(dtype)s& %(name)s_i = *%(name)s_iter;\n%(name)s_i = %(identity)s;" % dict(dtype = odtype,
name = onames[0],
identity = self.shadow.identity)
task1_decl = "%(dtype)s& %(name)s_i = *%(name)s_iter;\n" % dict(dtype = idtype, name = inames[0])
task1_code = self.shadow.c_code(["%s_i" % onames[0], "%s_i" % inames[0]],
["%s_i" % onames[0]],
sub)
code1 = """
{
%(task1_decl)s
%(task1_code)s
}
""" % locals()
if len(tosum) == 1:
all_code = [("", "")] * nnested + [(task0_decl, code1), ""]
else:
all_code = [("", "")] * nnested + [(task0_decl, "")] + [("", "")] * (len(tosum) - 2) + [("", code1), ""]
# if nnested:
# all_code = [("", "")] * (nnested - 1) + [("", code)] + [""]
# else:
# all_code = [code]
def reduce(op, dimensions_to_reduce): # print [order, range(nnested) + ['x'] * len(tosum)]
loop = cgen.make_loop([order, range(nnested) + ['x'] * len(tosum)], [idtype, odtype], all_code, sub)
return decl, checks, alloc, loop
def c_code(self, inames, onames, sub):
code = "\n".join(self._c_all(inames, onames, sub))
# print code
return code
def __str__(self):
input = self.inputs[0]
if len(input.broadcastable) == len(self.dimensions_to_reduce):
return "%s:%s(%s)" % (self.__class__.__name__,
self.scalar_opclass.__name__,
str(input))
else:
return "%s:%s(%s, axis = %s)" % (self.__class__.__name__,
self.scalar_opclass.__name__,
str(input),
self.dimensions_to_reduce)
def make_reduce(scalar_opclass, name = None):
if getattr(scalar_opclass, 'commutative', True) \
and getattr(scalar_opclass, 'associative', True):
reducer = CAReduce
else:
raise NotImplementedError("The scalar op class to reduce must be commutative and associative.")
class New(reducer):
def __init__(self, *inputs, **kwargs):
reducer.__init__(self, scalar_opclass, inputs, kwargs.get('axis', None))
def clone_with_new_inputs(self, *new_inputs):
return New(*new_inputs, **dict(axis = self.dimensions_to_reduce))
def __str__(self):
input = self.inputs[0]
if len(input.broadcastable) == len(self.dimensions_to_reduce):
return "%s(%s)" % (self.__class__.__name__,
str(input))
else:
return "%s(%s, axis = %s)" % (self.__class__.__name__,
str(input),
self.dimensions_to_reduce)
if name is not None:
New.__name__ = name
else:
New.__name__ = "Reduce" + scalar_opclass.__name__
return New
Sum = make_reduce(scalar_ops.Add, name = 'Sum')
def reduce(op):
if getattr(op, 'commutative', True) and getattr(op, 'associative', True): if getattr(op, 'commutative', True) and getattr(op, 'associative', True):
reducer = CAReduce reducer = CAReduce
else: else:
...@@ -281,16 +473,3 @@ def reduce(op, dimensions_to_reduce): ...@@ -281,16 +473,3 @@ def reduce(op, dimensions_to_reduce):
# class Elemwise(TensorOp):
# def propagate_dtype(self, idtypes):
# raise AbstractFunctionError
# def propagate_broadcastable(self, ibroadcastables):
# raise AbstractFunctionError
# def _calculate_elemwise_strategy(self, input_strategies):
# raise AbstractFunctionError
...@@ -25,20 +25,20 @@ class Double(ResultBase): ...@@ -25,20 +25,20 @@ class Double(ResultBase):
# def c_is_simple(self): return True # def c_is_simple(self): return True
def c_declare(self): def c_declare(self, name, sub):
return "double %(name)s; void* %(name)s_bad_thing;" return "double %(name)s; void* %(name)s_bad_thing;" % locals()
def c_init(self): def c_init(self, name, sub):
return """ return """
%(name)s = 0; %(name)s = 0;
%(name)s_bad_thing = malloc(100000); %(name)s_bad_thing = malloc(100000);
//printf("Initializing %(name)s\\n"); //printf("Initializing %(name)s\\n");
""" """ % locals()
def c_literal(self): def c_literal(self):
return str(self.data) return str(self.data)
def c_extract(self): def c_extract(self, name, sub):
return """ return """
if (!PyFloat_Check(py_%(name)s)) { if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "not a double!"); PyErr_SetString(PyExc_TypeError, "not a double!");
...@@ -47,23 +47,23 @@ class Double(ResultBase): ...@@ -47,23 +47,23 @@ class Double(ResultBase):
%(name)s = PyFloat_AsDouble(py_%(name)s); %(name)s = PyFloat_AsDouble(py_%(name)s);
%(name)s_bad_thing = NULL; %(name)s_bad_thing = NULL;
//printf("Extracting %(name)s\\n"); //printf("Extracting %(name)s\\n");
""" """ % dict(locals(), **sub)
def c_sync(self): def c_sync(self, name, sub):
return """ return """
Py_XDECREF(py_%(name)s); Py_XDECREF(py_%(name)s);
py_%(name)s = PyFloat_FromDouble(%(name)s); py_%(name)s = PyFloat_FromDouble(%(name)s);
if (!py_%(name)s) if (!py_%(name)s)
py_%(name)s = Py_None; py_%(name)s = Py_None;
//printf("Syncing %(name)s\\n"); //printf("Syncing %(name)s\\n");
""" """ % locals()
def c_cleanup(self): def c_cleanup(self, name, sub):
return """ return """
//printf("Cleaning up %(name)s\\n"); //printf("Cleaning up %(name)s\\n");
if (%(name)s_bad_thing) if (%(name)s_bad_thing)
free(%(name)s_bad_thing); free(%(name)s_bad_thing);
""" """ % locals()
class MyOp(Op): class MyOp(Op):
...@@ -80,43 +80,43 @@ class MyOp(Op): ...@@ -80,43 +80,43 @@ class MyOp(Op):
class Unary(MyOp): class Unary(MyOp):
nin = 1 nin = 1
def c_var_names(self): # def c_var_names(self):
return [['x'], ['z']] # return [['x'], ['z']]
class Binary(MyOp): class Binary(MyOp):
nin = 2 nin = 2
def c_var_names(self): # def c_var_names(self):
return [['x', 'y'], ['z']] # return [['x', 'y'], ['z']]
class Add(Binary): class Add(Binary):
def c_code(self): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s + %(y)s;" return "%(z)s = %(x)s + %(y)s;" % locals()
def perform(self): def perform(self):
self.outputs[0].data = self.inputs[0].data + self.inputs[1].data self.outputs[0].data = self.inputs[0].data + self.inputs[1].data
class Sub(Binary): class Sub(Binary):
def c_code(self): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s - %(y)s;" return "%(z)s = %(x)s - %(y)s;" % locals()
def perform(self): def perform(self):
self.outputs[0].data = -10 # erroneous self.outputs[0].data = -10 # erroneous
class Mul(Binary): class Mul(Binary):
def c_code(self): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s * %(y)s;" return "%(z)s = %(x)s * %(y)s;" % locals()
def perform(self): def perform(self):
self.outputs[0].data = self.inputs[0].data * self.inputs[1].data self.outputs[0].data = self.inputs[0].data * self.inputs[1].data
class Div(Binary): class Div(Binary):
def c_validate_update(self): def c_validate_update(self, (x, y), (z, ), sub):
return """ return """
if (%(y)s == 0.0) { if (%(y)s == 0.0) {
PyErr_SetString(PyExc_ZeroDivisionError, "division by zero"); PyErr_SetString(PyExc_ZeroDivisionError, "division by zero");
%(fail)s %(fail)s
} }
""" """ % dict(locals(), **sub)
def c_code(self): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s / %(y)s;" return "%(z)s = %(x)s / %(y)s;" % locals()
def perform(self): def perform(self):
self.outputs[0].data = self.inputs[0].data / self.inputs[1].data self.outputs[0].data = self.inputs[0].data / self.inputs[1].data
......
...@@ -66,15 +66,19 @@ class CodeBlock: ...@@ -66,15 +66,19 @@ class CodeBlock:
to jump to. It should also contain a key called 'failure_var' that contains to jump to. It should also contain a key called 'failure_var' that contains
the name of the variable that contains the error code. the name of the variable that contains the error code.
""" """
self.declare = declare % sub self.declare = declare #% sub
behavior_sub = copy(sub) # behavior_sub = copy(sub)
behavior_sub['fail'] = "{%(failure_var)s = %(id)s; goto __label_%(id)i;}" % sub # behavior_sub['fail'] = "{%(failure_var)s = %(id)s; goto __label_%(id)i;}" % sub
self.behavior = behavior % behavior_sub self.behavior = behavior #% behavior_sub
# the dummy is because gcc throws an error when a label's right next to a closing # the dummy is because gcc throws an error when a label's right next to a closing
# brace (maybe there's an ignore flag for that...) # brace (maybe there's an ignore flag for that...)
# we need the label even if cleanup is empty because the behavior block jumps there # we need the label even if cleanup is empty because the behavior block jumps there
# on failure # on failure
self.cleanup = ("__label_%(id)i:\n" + cleanup + "\ndouble __DUMMY_%(id)i;\n") % sub self.cleanup = ("__label_%(id)i:\n"%sub + cleanup + "\ndouble __DUMMY_%(id)i;\n"%sub) #% sub
def failure_code(sub):
return "{%(failure_var)s = %(id)s; goto __label_%(id)i;}" % sub
def code_gen(blocks): def code_gen(blocks):
...@@ -192,14 +196,14 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -192,14 +196,14 @@ def struct_gen(args, struct_builders, blocks, sub):
# TODO: add some error checking to make sure storage_<x> are 1-element lists # TODO: add some error checking to make sure storage_<x> are 1-element lists
# and __ERROR is a 3-elements list. # and __ERROR is a 3-elements list.
struct_code = """ struct_code = """
struct %%(name)s { struct %(name)s {
PyObject* __ERROR; PyObject* __ERROR;
%(storage_decl)s %(storage_decl)s
%(struct_decl)s %(struct_decl)s
%%(name)s() {} %(name)s() {}
~%%(name)s(void) { ~%(name)s(void) {
cleanup(); cleanup();
} }
...@@ -232,47 +236,47 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -232,47 +236,47 @@ def struct_gen(args, struct_builders, blocks, sub):
# The get_<x> functions complete the return value of r.get_<x>() # The get_<x> functions complete the return value of r.get_<x>()
# with handling of the py_<name> variable. # with handling of the py_<name> variable.
def get_nothing(r): def get_nothing(r, name, sub):
"" ""
return "" return ""
def get_c_declare(r): def get_c_declare(r, name, sub):
pre = """ pre = """
PyObject* py_%(name)s; PyObject* py_%(name)s;
""" """ % locals()
return pre + r.c_declare() return pre + r.c_declare(name, sub)
def get_c_init(r): def get_c_init(r, name, sub):
pre = "" """ pre = "" """
py_%(name)s = Py_None; py_%(name)s = Py_None;
""" """ % locals()
return pre + r.c_init() return pre + r.c_init(name, sub)
def get_c_extract(r): def get_c_extract(r, name, sub):
pre = """ pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0); py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
Py_XINCREF(py_%(name)s); Py_XINCREF(py_%(name)s);
""" """ % locals()
return pre + r.c_extract() return pre + r.c_extract(name, sub)
def get_c_cleanup(r): def get_c_cleanup(r, name, sub):
post = """ post = """
Py_XDECREF(py_%(name)s); Py_XDECREF(py_%(name)s);
""" """ % locals()
return r.c_cleanup() + post return r.c_cleanup(name, sub) + post
def get_c_sync(r): def get_c_sync(r, name, sub):
return """ return """
if (!%%(failure_var)s) { if (!%(failure_var)s) {
%(sync)s %(sync)s
PyObject* old = PyList_GET_ITEM(storage_%%(name)s, 0); PyObject* old = PyList_GET_ITEM(storage_%(name)s, 0);
Py_XINCREF(py_%%(name)s); Py_XINCREF(py_%(name)s);
PyList_SET_ITEM(storage_%%(name)s, 0, py_%%(name)s); PyList_SET_ITEM(storage_%(name)s, 0, py_%(name)s);
Py_XDECREF(old); Py_XDECREF(old);
} }
""" % dict(sync = r.c_sync()) """ % dict(sync = r.c_sync(name, sub), name = name, **sub)
def apply_policy(policy, r): def apply_policy(policy, r, name, sub):
""" """
policy -> list of functions that map a Result to a string, policy -> list of functions that map a Result to a string,
or a single such function or a single such function
...@@ -282,8 +286,8 @@ def apply_policy(policy, r): ...@@ -282,8 +286,8 @@ def apply_policy(policy, r):
if isinstance(r, (list, tuple)): if isinstance(r, (list, tuple)):
ret = "" ret = ""
for sub_policy in policy: for sub_policy in policy:
ret += sub_policy(r) ret += sub_policy(r, name, sub)
return policy(r) return policy(r, name, sub)
...@@ -304,11 +308,15 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub): ...@@ -304,11 +308,15 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub):
name = "V%i" % id name = "V%i" % id
symbol_table[result] = name symbol_table[result] = name
sub = copy(sub) sub = copy(sub)
sub['name'] = name # sub['name'] = name
sub['id'] = id sub['id'] = id
struct_builder = CodeBlock(*[apply_policy(policy, result) for policy in policies[0]]+[sub]) # struct_declare, struct_behavior, struct_cleanup, sub) sub['fail'] = failure_code(sub)
struct_builder = CodeBlock(*[apply_policy(policy, result, name, sub)
for policy in policies[0]]+[sub]) # struct_declare, struct_behavior, struct_cleanup, sub)
sub['id'] = id + 1 sub['id'] = id + 1
block = CodeBlock(*[apply_policy(policy, result) for policy in policies[1]]+[sub]) # run_declare, run_behavior, run_cleanup, sub) sub['fail'] = failure_code(sub)
block = CodeBlock(*[apply_policy(policy, result, name, sub)
for policy in policies[1]]+[sub]) # run_declare, run_behavior, run_cleanup, sub)
return struct_builder, block return struct_builder, block
...@@ -453,33 +461,39 @@ class CLinker(Linker): ...@@ -453,33 +461,39 @@ class CLinker(Linker):
# We populate sub with a mapping from the variable names specified by the op's c_var_names # We populate sub with a mapping from the variable names specified by the op's c_var_names
# method to the actual variable names that we will use. # method to the actual variable names that we will use.
ivnames, ovnames = op.c_var_names() ## ivnames, ovnames = op.c_var_names()
sub = dict(failure_var = failure_var) sub = dict(failure_var = failure_var)
for result, vname in zip(op.inputs + op.outputs, ivnames + ovnames): ## for result, vname in zip(op.inputs + op.outputs, ivnames + ovnames):
sub[vname] = symbol[result] ## sub[vname] = symbol[result]
isyms, osyms = [symbol[r] for r in op.inputs], [symbol[r] for r in op.outputs]
# Make the CodeBlock for c_validate_update # Make the CodeBlock for c_validate_update
try: validate_behavior = op.c_validate_update() sub['id'] = id
sub['fail'] = failure_code(sub)
try: validate_behavior = op.c_validate_update(isyms, osyms, sub)
except AbstractFunctionError: except AbstractFunctionError:
validate_behavior = "" validate_behavior = ""
try: validate_cleanup = op.c_validate_update_cleanup() try: validate_cleanup = op.c_validate_update_cleanup(isyms, osyms, sub)
except AbstractFunctionError: except AbstractFunctionError:
validate_cleanup = "" validate_cleanup = ""
sub['id'] = id
blocks.append(CodeBlock("", validate_behavior, validate_cleanup, sub)) blocks.append(CodeBlock("", validate_behavior, validate_cleanup, sub))
tasks.append((op, 'validate_update', id)) tasks.append((op, 'validate_update', id))
id += 1 id += 1
# Make the CodeBlock for c_code # Make the CodeBlock for c_code
behavior = op.c_code() # this one must be implemented! sub['id'] = id
sub['fail'] = failure_code(sub)
behavior = op.c_code(isyms, osyms, sub) # this one must be implemented!
try: cleanup = op.c_code_cleanup() try: cleanup = op.c_code_cleanup(isyms, osyms, sub)
except AbstractFunctionError: except AbstractFunctionError:
cleanup = "" cleanup = ""
sub['id'] = id
blocks.append(CodeBlock("", behavior, cleanup, sub)) blocks.append(CodeBlock("", behavior, cleanup, sub))
tasks.append((op, 'code', id)) tasks.append((op, 'code', id))
id += 1 id += 1
...@@ -489,7 +503,7 @@ class CLinker(Linker): ...@@ -489,7 +503,7 @@ class CLinker(Linker):
args = [] args = []
args += ["storage_%s" % symbol[result] for result in utils.uniq(self.inputs + self.outputs + self.orphans)] args += ["storage_%s" % symbol[result] for result in utils.uniq(self.inputs + self.outputs + self.orphans)]
struct_code = struct_gen(args, init_blocks, blocks, dict(failure_var = failure_var)) struct_code = struct_gen(args, init_blocks, blocks, dict(failure_var = failure_var, name = "%(name)s"))
# The hash calculated on the code identifies it so weave can cache properly. # The hash calculated on the code identifies it so weave can cache properly.
# (the hash has to be used outside of the support code because weave does not consider changes in the support code) # (the hash has to be used outside of the support code because weave does not consider changes in the support code)
......
...@@ -78,16 +78,13 @@ class Env(graph.Graph): ...@@ -78,16 +78,13 @@ class Env(graph.Graph):
# The inputs and outputs set bound the subgraph this Env operates on. # The inputs and outputs set bound the subgraph this Env operates on.
self.inputs = list(inputs) self.inputs = list(inputs)
self.outputs = list(outputs) self.outputs = list(outputs)
for feature_class in uniq_features(features):
self.add_feature(feature_class, False)
# All ops in the subgraph defined by inputs and outputs are cached in _ops # All ops in the subgraph defined by inputs and outputs are cached in _ops
self._ops = set() self._ops = set()
# Ditto for results # Ditto for results
self._results = set(self.inputs) self._results = set(self.inputs)
# Set of all the results that are not an output of an op in the subgraph but # Set of all the results that are not an output of an op in the subgraph but
# are an input of an op in the subgraph. # are an input of an op in the subgraph.
# e.g. z for inputs=(x, y) and outputs=(x + (y - z),) # e.g. z for inputs=(x, y) and outputs=(x + (y - z),)
...@@ -95,6 +92,9 @@ class Env(graph.Graph): ...@@ -95,6 +92,9 @@ class Env(graph.Graph):
# it will be removed from the set of orphans. # it will be removed from the set of orphans.
self._orphans = set(outputs) self._orphans = set(outputs)
for feature_class in uniq_features(features):
self.add_feature(feature_class, False)
# Maps results to ops that use them: # Maps results to ops that use them:
# if op.inputs[i] == v then (op, i) in self._clients[v] # if op.inputs[i] == v then (op, i) in self._clients[v]
self._clients = {} self._clients = {}
......
...@@ -179,15 +179,15 @@ class Op(object): ...@@ -179,15 +179,15 @@ class Op(object):
# C code generators # C code generators
# #
def c_var_names(self): # def c_var_names(self):
""" # """
Returns ([list of input names], [list of output names]) for # Returns ([list of input names], [list of output names]) for
use as C variables. # use as C variables.
""" # """
return [["i%i" % i for i in xrange(len(self.inputs))], # return [["i%i" % i for i in xrange(len(self.inputs))],
["o%i" % i for i in xrange(len(self.outputs))]] # ["o%i" % i for i in xrange(len(self.outputs))]]
def c_validate_update(self): def c_validate_update(self, inputs, outputs, sub):
""" """
Returns templated C code that checks that the inputs to this Returns templated C code that checks that the inputs to this
function can be worked on. If a failure occurs, set an function can be worked on. If a failure occurs, set an
...@@ -198,13 +198,13 @@ class Op(object): ...@@ -198,13 +198,13 @@ class Op(object):
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_validate_update_cleanup(self): def c_validate_update_cleanup(self, inputs, outputs, sub):
""" """
Clean up things allocated by c_validate(). Clean up things allocated by c_validate().
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_code(self): def c_code(self, inputs, outputs, sub):
""" """
Returns templated C code that does the computation associated Returns templated C code that does the computation associated
to this Op. You may assume that input validation and output to this Op. You may assume that input validation and output
...@@ -215,7 +215,7 @@ class Op(object): ...@@ -215,7 +215,7 @@ class Op(object):
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_code_cleanup(self): def c_code_cleanup(self, inputs, outputs, sub):
""" """
Clean up things allocated by c_code(). Clean up things allocated by c_code().
""" """
......
...@@ -175,13 +175,13 @@ class ResultBase(object): ...@@ -175,13 +175,13 @@ class ResultBase(object):
""" """
return False return False
def c_declare(self): def c_declare(self, name, sub):
""" """
Declares variables that will be instantiated by c_data_extract. Declares variables that will be instantiated by c_data_extract.
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_extract(self): def c_extract(self, name, sub):
""" """
The code returned from this function must be templated using The code returned from this function must be templated using
"%(name)s", representing the name that the caller wants to "%(name)s", representing the name that the caller wants to
...@@ -193,7 +193,7 @@ class ResultBase(object): ...@@ -193,7 +193,7 @@ class ResultBase(object):
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_cleanup(self): def c_cleanup(self, name, sub):
""" """
This returns C code that should deallocate whatever This returns C code that should deallocate whatever
c_data_extract allocated or decrease the reference counts. Do c_data_extract allocated or decrease the reference counts. Do
...@@ -201,7 +201,7 @@ class ResultBase(object): ...@@ -201,7 +201,7 @@ class ResultBase(object):
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_sync(self): def c_sync(self, name, sub):
""" """
The code returned from this function must be templated using "%(name)s", The code returned from this function must be templated using "%(name)s",
representing the name that the caller wants to call this Result. representing the name that the caller wants to call this Result.
...@@ -297,28 +297,28 @@ class PythonResult(ResultBase): ...@@ -297,28 +297,28 @@ class PythonResult(ResultBase):
through %(name)s. through %(name)s.
""" """
def c_declare(self): def c_declare(self, name, sub):
return """ return """
PyObject* %(name)s; PyObject* %(name)s;
""" """ % locals()
def c_extract(self): def c_extract(self, name, sub):
return """ return """
Py_XINCREF(py_%(name)s); Py_XINCREF(py_%(name)s);
%(name)s = py_%(name)s; %(name)s = py_%(name)s;
""" """ % locals()
def c_cleanup(self): def c_cleanup(self, name, sub):
return """ return """
Py_XDECREF(%(name)s); Py_XDECREF(%(name)s);
""" """ % locals()
def c_sync(self): def c_sync(self, name, sub):
return """ return """
Py_XDECREF(py_%(name)s); Py_XDECREF(py_%(name)s);
py_%(name)s = %(name)s; py_%(name)s = %(name)s;
Py_XINCREF(py_%(name)s); Py_XINCREF(py_%(name)s);
""" """ % locals()
def same_properties(self, other): def same_properties(self, other):
return False return False
......
...@@ -18,10 +18,9 @@ def as_scalar(x, name = None): ...@@ -18,10 +18,9 @@ def as_scalar(x, name = None):
class Scalar(ResultBase): class Scalar(ResultBase):
def __init__(self, dtype, data = None, name=None): def __init__(self, dtype, name = None):
ResultBase.__init__(self, role = None, name = name)
self.dtype = dtype self.dtype = dtype
self.constant = False
ResultBase.__init__(self, role = None, data = data, name = name)
def __get_constant(self): def __get_constant(self):
return self._constant return self._constant
...@@ -40,60 +39,59 @@ class Scalar(ResultBase): ...@@ -40,60 +39,59 @@ class Scalar(ResultBase):
def same_properties(self, other): def same_properties(self, other):
return other.dtype == self.dtype return other.dtype == self.dtype
def mergeable(self, other): # def mergeable(self, other):
return getattr(self, 'constant', False) \ # return getattr(self, 'constant', False) \
and getattr(other, 'constant', False) \ # and getattr(other, 'constant', False) \
and self.data == other.data # and self.data == other.data
def dtype_specs(self): def dtype_specs(self):
return {'float64': (float, 'double', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble')}[self.dtype] return {'float64': (float, 'double', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble')}[self.dtype]
# def py_type(self):
# return {'float64': float}[self.dtype]
# def c_type(self):
# return {'float64': 'double'}[self.dtype]
# def c_from(self):
# return {'float64': 'PyFloat_FromDouble'}[self.dtype]
# def c_as(self):
# return {'float64': 'PyFloat_AsDouble'}[self.dtype]
def c_declare(self): def c_declare(self, name, sub):
return """ return """
%(dtype)s %%(name)s; %(dtype)s %(name)s;
typedef %(dtype)s %%(name)s_dtype; typedef %(dtype)s %(name)s_dtype;
""" % dict(dtype = self.dtype_specs()[1]) """ % dict(name = name, dtype = self.dtype_specs()[1])
def c_init(self): def c_init(self, name, sub):
return """ return """
%(name)s = 0; %(name)s = 0;
""" """ % locals()
def c_extract(self): def c_extract(self, name, sub):
specs = self.dtype_specs() specs = self.dtype_specs()
return """ return """
if (!%(check)s(py_%%(name)s)) if (!%(check)s(py_%(name)s))
%%(fail)s %(fail)s
%%(name)s = (%(dtype)s)%(conv)s(py_%%(name)s); %(name)s = (%(dtype)s)%(conv)s(py_%(name)s);
""" % dict(dtype = specs[1], """ % dict(sub,
name = name,
dtype = specs[1],
check = specs[2], check = specs[2],
conv = specs[3]) conv = specs[3])
def c_sync(self): def c_sync(self, name, sub):
specs = self.dtype_specs() specs = self.dtype_specs()
return """ return """
Py_XDECREF(py_%%(name)s); Py_XDECREF(py_%(name)s);
py_%%(name)s = %(conv)s((%(dtype)s)%%(name)s); py_%(name)s = %(conv)s((%(dtype)s)%(name)s);
if (!py_%%(name)s) if (!py_%(name)s)
py_%%(name)s = Py_None; py_%(name)s = Py_None;
""" % dict(dtype = specs[1], """ % dict(name = name,
dtype = specs[1],
conv = specs[4]) conv = specs[4])
def c_cleanup(self): def c_cleanup(self, name, sub):
return "" return ""
def __copy__(self):
"""
Return a copy of this instance (with its own attributes)
"""
cpy = self.__class__(self.dtype, self.name)
cpy.data = self.data
return cpy
class ScalarMixedOp(GuardedOp): class ScalarMixedOp(GuardedOp):
...@@ -104,8 +102,8 @@ class ScalarMixedOp(GuardedOp): ...@@ -104,8 +102,8 @@ class ScalarMixedOp(GuardedOp):
def __init__(self, *inputs): def __init__(self, *inputs):
if self.nin >= 0: if self.nin >= 0:
if len(inputs) != self.nin: if len(inputs) != self.nin:
raise TypeError("Wrong number of inputs for %s (got %i, expected %i)") \ raise TypeError("Wrong number of inputs for %s (got %i, expected %i)" \
% (self, len(inputs), self.nin) % (self.__class__.__name__, len(inputs), self.nin))
i_dtypes = [getattr(input, 'dtype', None) for input in inputs] i_dtypes = [getattr(input, 'dtype', None) for input in inputs]
o_dtypes = utils.from_return_values(self.propagate_dtypes(*i_dtypes)) o_dtypes = utils.from_return_values(self.propagate_dtypes(*i_dtypes))
...@@ -125,14 +123,14 @@ class ScalarMixedOp(GuardedOp): ...@@ -125,14 +123,14 @@ class ScalarMixedOp(GuardedOp):
def perform(self): def perform(self):
self.outputs[0].data = self.impl(*[input.data for input in self.inputs]) self.outputs[0].data = self.impl(*[input.data for input in self.inputs])
def c_var_names(self): # def c_var_names(self):
(self, inames, onames), _1, _2, _3 = inspect.getargspec(self.c_impl) # (self, inames, onames), _1, _2, _3 = inspect.getargspec(self.c_impl)
inames = utils.from_return_values(inames) # inames = utils.from_return_values(inames)
onames = utils.from_return_values(onames) # onames = utils.from_return_values(onames)
return [inames, onames] # return [inames, onames]
def c_code(self): # def c_code(self):
return self.c_impl(self.inputs, self.outputs) # return self.c_impl(self.inputs, self.outputs)
def upcast(dtype, *dtypes): def upcast(dtype, *dtypes):
......
...@@ -4,71 +4,91 @@ import math ...@@ -4,71 +4,91 @@ import math
class Add(BinaryScalarOp): class Add(BinaryScalarOp):
identity = 0
def impl(self, x, y): def impl(self, x, y):
return x + y return x + y
def c_impl(self, (x, y), z): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s + %(y)s;" return "%(z)s = %(x)s + %(y)s;" % locals()
def grad(self, (x, y), gz): def grad(self, (x, y), (gz, )):
return gz, gz return gz, gz
class Sub(BinaryScalarOp): class Sub(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x - y return x - y
def c_impl(self, (x, y), z): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s - %(y)s;" return "%(z)s = %(x)s - %(y)s;" % locals()
def grad(self, (x, y), gz): def grad(self, (x, y), (gz, )):
return gz, -gz return gz, -gz
class Mul(BinaryScalarOp): class Mul(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x * y return x * y
def c_impl(self, (x, y), z): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s * %(y)s;" return "%(z)s = %(x)s * %(y)s;" % locals()
def grad(self, (x, y), gz): def grad(self, (x, y), (gz, )):
return mul(y, gz), mul(x, gz) return mul(y, gz), mul(x, gz)
class Div(BinaryScalarOp): class Div(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x / y return x / y
def c_impl(self, (x, y), z): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s / %(y)s;" return "%(z)s = %(x)s / %(y)s;" % locals()
def grad(self, (x, y), gz): def grad(self, (x, y), (gz, )):
return div(gz, y), -div(mul(x, gz), y*y) return div(gz, y), -div(mul(x, gz), y*y)
class Pow(BinaryScalarOp): class Pow(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x ** y return x ** y
def c_impl(self, (x, y), z): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = pow(%(x)s, %(y)s);" return "%(z)s = pow(%(x)s, %(y)s);" % locals()
class First(BinaryScalarOp):
def impl(self, x, y):
return x
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s;" % locals()
class Second(BinaryScalarOp):
def impl(self, x, y):
return y
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(y)s;" % locals()
class SquareDiff(BinaryScalarOp):
def impl(self, x, y):
diff = (x - y)
return diff * diff
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s - %(y)s; %(z)s *= %(z)s;" % locals()
class Neg(UnaryScalarOp): class Neg(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return -x return -x
def grad(self, x, gz): def grad(self, (x, ), (gz, )):
return -gz return -gz
def c_impl(self, x, z): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = -%(x)s;" return "%(z)s = -%(x)s;" % locals()
class Inv(UnaryScalarOp): class Inv(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return 1 / x return 1 / x
def grad(self, x, gz): def grad(self, (x, ), (gz, )):
return -gz / (x*x) return -gz / (x*x)
def c_impl(self, x, z): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = 1 / %(x)s;" return "%(z)s = 1 / %(x)s;" % locals()
class Log(UnaryScalarOp): class Log(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.log(x) return math.log(x)
def c_impl(self, x, z): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = log(%(x)s);" return "%(z)s = log(%(x)s);" % locals()
class Exp(UnaryScalarOp): class Exp(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.exp(x) return math.exp(x)
def c_impl(self, x, z): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = exp(%(x)s);" return "%(z)s = exp(%(x)s);" % locals()
# class Sigmoid(UnaryComposite): # class Sigmoid(UnaryComposite):
......
...@@ -136,8 +136,12 @@ class _Op(BaseTensorOp): ...@@ -136,8 +136,12 @@ class _Op(BaseTensorOp):
onames = utils.from_return_values(onames) onames = utils.from_return_values(onames)
return [inames, onames] return [inames, onames]
def c_code(self): def c_code(self, input_names, output_names, sub):
return self.c_impl(self.inputs, self.outputs) sub = dict(sub)
icvn, ocvn = self.c_var_names()
for real, tosub in zip(input_names + output_names, icvn + ocvn):
sub[tosub] = real
return self.c_impl(self.inputs, self.outputs) % sub
def c_impl(self, inputs, outputs): def c_impl(self, inputs, outputs):
raise AbstractFunctionError() raise AbstractFunctionError()
...@@ -759,7 +763,7 @@ class Gemm(_Op): ...@@ -759,7 +763,7 @@ class Gemm(_Op):
return blas.ldflags() return blas.ldflags()
def c_var_names(self): def c_var_names(self):
return [['_z', '_a', '_x', '_y', '_b'], ['_zout']] return [['_z', '_a', '_x', '_y', '_b'], ['_zout']]
def c_validate_update(self): def c_validate_update(self, (_z, _a, _x, _y, _b), (_zout, ), sub):
return """ return """
if (%(_zout)s) if (%(_zout)s)
{ {
...@@ -770,10 +774,10 @@ class Gemm(_Op): ...@@ -770,10 +774,10 @@ class Gemm(_Op):
%(_zout)s = %(_z)s; %(_zout)s = %(_z)s;
Py_INCREF(%(_zout)s); Py_INCREF(%(_zout)s);
} }
""" """ % locals()
def c_validate_update_cleanup(self): def c_validate_update_cleanup(self, ignore, _ignore, __ignore):
return "" return ""
def c_code(self): def c_code(self, (_z, _a, _x, _y, _b), (_zout, ), sub):
return """ return """
int unit = 0; int unit = 0;
...@@ -913,7 +917,7 @@ class Gemm(_Op): ...@@ -913,7 +917,7 @@ class Gemm(_Op):
break; break;
} }
""" """ % dict(locals(), **sub)
gemm = gof.op.constructor(Gemm) gemm = gof.op.constructor(Gemm)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论