broadcast, reduce are working in python and c

上级 88f7858f
...@@ -8,31 +8,23 @@ from scalar_ops import * ...@@ -8,31 +8,23 @@ from scalar_ops import *
def inputs(): def inputs():
x = modes.build_eval(as_scalar(1.0, 'x')) x = modes.build(as_scalar(1.0, 'x'))
y = modes.build_eval(as_scalar(2.0, 'y')) y = modes.build(as_scalar(2.0, 'y'))
z = modes.build_eval(as_scalar(3.0, 'z')) z = modes.build(as_scalar(3.0, 'z'))
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []): def env(inputs, outputs, validate = True, features = []):
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate) return Env(inputs, outputs, features = features, consistency_check = validate)
class _test_ScalarOps(unittest.TestCase): class _test_ScalarOps(unittest.TestCase):
def test_0(self): def test_straightforward(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
assert e.data == 1.5
def test_1(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
g = env([x, y], [e]) g = env([x, y], [e])
fn = gof.cc.CLinker(g).make_function() fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == 1.5 assert fn(1.0, 2.0) == 1.5
assert e.data == 1.5
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -691,7 +691,7 @@ class t_gemm(unittest.TestCase): ...@@ -691,7 +691,7 @@ class t_gemm(unittest.TestCase):
self.rand(3,5), self.rand(5,4), 1.0) self.rand(3,5), self.rand(5,4), 1.0)
def test12(self): self.cmp(self.rand(3,4), -1.0, def test12(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), -1.0) self.rand(3,5), self.rand(5,4), -1.0)
t_gemm = None
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -108,53 +108,53 @@ class BaseTensor(ResultBase): ...@@ -108,53 +108,53 @@ class BaseTensor(ResultBase):
# #
# C codegen stubs # C codegen stubs
# #
def c_declare(self): def c_declare(self, name, sub):
return """ return """
PyArrayObject* %%(name)s; PyArrayObject* %(name)s;
int type_num_%%(name)s; int type_num_%(name)s;
typedef %(dtype)s dtype_%%(name)s; typedef %(dtype)s dtype_%(name)s;
""" % dict(dtype = self.dtype_specs()[1]) """ % dict(sub, name = name, dtype = self.dtype_specs()[1])
def c_init(self): def c_init(self, name, sub):
return """ return """
%%(name)s = NULL; %(name)s = NULL;
type_num_%%(name)s = %(type_num)s; type_num_%(name)s = %(type_num)s;
""" % dict(type_num = self.dtype_specs()[2]) """ % dict(sub, name = name, type_num = self.dtype_specs()[2])
def c_extract(self): def c_extract(self, name, sub):
return """ return """
%%(name)s = NULL; %(name)s = NULL;
type_num_%%(name)s = %(type_num)s; type_num_%(name)s = %(type_num)s;
if (py_%%(name)s == Py_None) { if (py_%(name)s == Py_None) {
// We can either fail here or set %%(name)s to NULL and rely on Ops using // We can either fail here or set %(name)s to NULL and rely on Ops using
// tensors to handle the NULL case, but if they fail to do so they'll end up // tensors to handle the NULL case, but if they fail to do so they'll end up
// with nasty segfaults, so this is public service. // with nasty segfaults, so this is public service.
PyErr_SetString(PyExc_ValueError, "expected an ndarray, not None"); PyErr_SetString(PyExc_ValueError, "expected an ndarray, not None");
%%(fail)s %(fail)s
//%%(name)s = NULL; //%(name)s = NULL;
} }
else if (!PyArray_Check(py_%%(name)s)) { else if (!PyArray_Check(py_%(name)s)) {
PyErr_SetString(PyExc_ValueError, "expected an ndarray"); PyErr_SetString(PyExc_ValueError, "expected an ndarray");
%%(fail)s %(fail)s
} }
else if (((PyArrayObject*)py_%%(name)s)->descr->type_num != %(type_num)s) { else if (((PyArrayObject*)py_%(name)s)->descr->type_num != %(type_num)s) {
PyErr_SetString(PyExc_ValueError, "expected %(type_num)s"); PyErr_SetString(PyExc_ValueError, "expected %(type_num)s");
%%(fail)s %(fail)s
} }
else { else {
%%(name)s = (PyArrayObject*)(py_%%(name)s); %(name)s = (PyArrayObject*)(py_%(name)s);
Py_XINCREF(%%(name)s); Py_XINCREF(%(name)s);
} }
""" % dict(type_num = self.dtype_specs()[2]) """ % dict(sub, name = name, type_num = self.dtype_specs()[2])
def c_cleanup(self): def c_cleanup(self, name, sub):
return """ return """
if (%(name)s) { if (%(name)s) {
Py_XDECREF(%(name)s); Py_XDECREF(%(name)s);
} }
""" """ % locals()
def c_sync(self): def c_sync(self, name, sub):
return """ return """
if (!%(name)s) { if (!%(name)s) {
Py_XDECREF(py_%(name)s); Py_XDECREF(py_%(name)s);
...@@ -165,7 +165,7 @@ class BaseTensor(ResultBase): ...@@ -165,7 +165,7 @@ class BaseTensor(ResultBase):
py_%(name)s = (PyObject*)%(name)s; py_%(name)s = (PyObject*)%(name)s;
Py_XINCREF(py_%(name)s); Py_XINCREF(py_%(name)s);
} }
""" """ % locals()
def c_headers(self): def c_headers(self):
return [] return []
......
...@@ -16,8 +16,8 @@ exec_opt.optimizer = None ...@@ -16,8 +16,8 @@ exec_opt.optimizer = None
def default_optimizer(env): def default_optimizer(env):
#TODO: pass tests with these un-commented #TODO: pass tests with these un-commented
default_optimizer.const(env) # default_optimizer.const(env)
default_optimizer.merge(env) # default_optimizer.merge(env)
pass pass
default_optimizer.merge = gof.opt.MergeOptimizer() default_optimizer.merge = gof.opt.MergeOptimizer()
default_optimizer.const = gof.opt.ConstantFinder() default_optimizer.const = gof.opt.ConstantFinder()
......
...@@ -61,21 +61,41 @@ class Elemwise(Op): ...@@ -61,21 +61,41 @@ class Elemwise(Op):
return ret return ret
def c_validate_update(self): def c_validate_update(self, input_names, output_names, sub):
sub = dict(sub)
icvn, ocvn = self.c_var_names()
for real, tosub in zip(input_names + output_names, icvn + ocvn):
sub[tosub] = real
(valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code() (valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code()
return valupd return valupd % sub
def c_validate_update_cleanup(self, input_names, output_names, sub):
sub = dict(sub)
icvn, ocvn = self.c_var_names()
for real, tosub in zip(input_names + output_names, icvn + ocvn):
sub[tosub] = real
def c_validate_update_cleanup(self):
(valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code() (valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code()
return valupd_cleanup return valupd_cleanup % sub
def c_code(self, input_names, output_names, sub):
sub = dict(sub)
icvn, ocvn = self.c_var_names()
for real, tosub in zip(input_names + output_names, icvn + ocvn):
sub[tosub] = real
def c_code(self):
(valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code() (valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code()
return code return code % sub
def c_code_cleanup(self, input_names, output_names, sub):
sub = dict(sub)
icvn, ocvn = self.c_var_names()
for real, tosub in zip(input_names + output_names, icvn + ocvn):
sub[tosub] = real
def c_code_cleanup(self):
(valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code() (valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code()
return code_cleanup return code_cleanup % sub
@classmethod @classmethod
def inplace_version(cls, dmap = {0:0}): def inplace_version(cls, dmap = {0:0}):
......
...@@ -25,20 +25,20 @@ class Double(ResultBase): ...@@ -25,20 +25,20 @@ class Double(ResultBase):
# def c_is_simple(self): return True # def c_is_simple(self): return True
def c_declare(self): def c_declare(self, name, sub):
return "double %(name)s; void* %(name)s_bad_thing;" return "double %(name)s; void* %(name)s_bad_thing;" % locals()
def c_init(self): def c_init(self, name, sub):
return """ return """
%(name)s = 0; %(name)s = 0;
%(name)s_bad_thing = malloc(100000); %(name)s_bad_thing = malloc(100000);
//printf("Initializing %(name)s\\n"); //printf("Initializing %(name)s\\n");
""" """ % locals()
def c_literal(self): def c_literal(self):
return str(self.data) return str(self.data)
def c_extract(self): def c_extract(self, name, sub):
return """ return """
if (!PyFloat_Check(py_%(name)s)) { if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "not a double!"); PyErr_SetString(PyExc_TypeError, "not a double!");
...@@ -47,23 +47,23 @@ class Double(ResultBase): ...@@ -47,23 +47,23 @@ class Double(ResultBase):
%(name)s = PyFloat_AsDouble(py_%(name)s); %(name)s = PyFloat_AsDouble(py_%(name)s);
%(name)s_bad_thing = NULL; %(name)s_bad_thing = NULL;
//printf("Extracting %(name)s\\n"); //printf("Extracting %(name)s\\n");
""" """ % dict(locals(), **sub)
def c_sync(self): def c_sync(self, name, sub):
return """ return """
Py_XDECREF(py_%(name)s); Py_XDECREF(py_%(name)s);
py_%(name)s = PyFloat_FromDouble(%(name)s); py_%(name)s = PyFloat_FromDouble(%(name)s);
if (!py_%(name)s) if (!py_%(name)s)
py_%(name)s = Py_None; py_%(name)s = Py_None;
//printf("Syncing %(name)s\\n"); //printf("Syncing %(name)s\\n");
""" """ % locals()
def c_cleanup(self): def c_cleanup(self, name, sub):
return """ return """
//printf("Cleaning up %(name)s\\n"); //printf("Cleaning up %(name)s\\n");
if (%(name)s_bad_thing) if (%(name)s_bad_thing)
free(%(name)s_bad_thing); free(%(name)s_bad_thing);
""" """ % locals()
class MyOp(Op): class MyOp(Op):
...@@ -80,43 +80,43 @@ class MyOp(Op): ...@@ -80,43 +80,43 @@ class MyOp(Op):
class Unary(MyOp): class Unary(MyOp):
nin = 1 nin = 1
def c_var_names(self): # def c_var_names(self):
return [['x'], ['z']] # return [['x'], ['z']]
class Binary(MyOp): class Binary(MyOp):
nin = 2 nin = 2
def c_var_names(self): # def c_var_names(self):
return [['x', 'y'], ['z']] # return [['x', 'y'], ['z']]
class Add(Binary): class Add(Binary):
def c_code(self): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s + %(y)s;" return "%(z)s = %(x)s + %(y)s;" % locals()
def perform(self): def perform(self):
self.outputs[0].data = self.inputs[0].data + self.inputs[1].data self.outputs[0].data = self.inputs[0].data + self.inputs[1].data
class Sub(Binary): class Sub(Binary):
def c_code(self): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s - %(y)s;" return "%(z)s = %(x)s - %(y)s;" % locals()
def perform(self): def perform(self):
self.outputs[0].data = -10 # erroneous self.outputs[0].data = -10 # erroneous
class Mul(Binary): class Mul(Binary):
def c_code(self): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s * %(y)s;" return "%(z)s = %(x)s * %(y)s;" % locals()
def perform(self): def perform(self):
self.outputs[0].data = self.inputs[0].data * self.inputs[1].data self.outputs[0].data = self.inputs[0].data * self.inputs[1].data
class Div(Binary): class Div(Binary):
def c_validate_update(self): def c_validate_update(self, (x, y), (z, ), sub):
return """ return """
if (%(y)s == 0.0) { if (%(y)s == 0.0) {
PyErr_SetString(PyExc_ZeroDivisionError, "division by zero"); PyErr_SetString(PyExc_ZeroDivisionError, "division by zero");
%(fail)s %(fail)s
} }
""" """ % dict(locals(), **sub)
def c_code(self): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s / %(y)s;" return "%(z)s = %(x)s / %(y)s;" % locals()
def perform(self): def perform(self):
self.outputs[0].data = self.inputs[0].data / self.inputs[1].data self.outputs[0].data = self.inputs[0].data / self.inputs[1].data
......
...@@ -66,15 +66,19 @@ class CodeBlock: ...@@ -66,15 +66,19 @@ class CodeBlock:
to jump to. It should also contain a key called 'failure_var' that contains to jump to. It should also contain a key called 'failure_var' that contains
the name of the variable that contains the error code. the name of the variable that contains the error code.
""" """
self.declare = declare % sub self.declare = declare #% sub
behavior_sub = copy(sub) # behavior_sub = copy(sub)
behavior_sub['fail'] = "{%(failure_var)s = %(id)s; goto __label_%(id)i;}" % sub # behavior_sub['fail'] = "{%(failure_var)s = %(id)s; goto __label_%(id)i;}" % sub
self.behavior = behavior % behavior_sub self.behavior = behavior #% behavior_sub
# the dummy is because gcc throws an error when a label's right next to a closing # the dummy is because gcc throws an error when a label's right next to a closing
# brace (maybe there's an ignore flag for that...) # brace (maybe there's an ignore flag for that...)
# we need the label even if cleanup is empty because the behavior block jumps there # we need the label even if cleanup is empty because the behavior block jumps there
# on failure # on failure
self.cleanup = ("__label_%(id)i:\n" + cleanup + "\ndouble __DUMMY_%(id)i;\n") % sub self.cleanup = ("__label_%(id)i:\n"%sub + cleanup + "\ndouble __DUMMY_%(id)i;\n"%sub) #% sub
def failure_code(sub):
return "{%(failure_var)s = %(id)s; goto __label_%(id)i;}" % sub
def code_gen(blocks): def code_gen(blocks):
...@@ -192,14 +196,14 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -192,14 +196,14 @@ def struct_gen(args, struct_builders, blocks, sub):
# TODO: add some error checking to make sure storage_<x> are 1-element lists # TODO: add some error checking to make sure storage_<x> are 1-element lists
# and __ERROR is a 3-elements list. # and __ERROR is a 3-elements list.
struct_code = """ struct_code = """
struct %%(name)s { struct %(name)s {
PyObject* __ERROR; PyObject* __ERROR;
%(storage_decl)s %(storage_decl)s
%(struct_decl)s %(struct_decl)s
%%(name)s() {} %(name)s() {}
~%%(name)s(void) { ~%(name)s(void) {
cleanup(); cleanup();
} }
...@@ -232,47 +236,47 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -232,47 +236,47 @@ def struct_gen(args, struct_builders, blocks, sub):
# The get_<x> functions complete the return value of r.get_<x>() # The get_<x> functions complete the return value of r.get_<x>()
# with handling of the py_<name> variable. # with handling of the py_<name> variable.
def get_nothing(r): def get_nothing(r, name, sub):
"" ""
return "" return ""
def get_c_declare(r): def get_c_declare(r, name, sub):
pre = """ pre = """
PyObject* py_%(name)s; PyObject* py_%(name)s;
""" """ % locals()
return pre + r.c_declare() return pre + r.c_declare(name, sub)
def get_c_init(r): def get_c_init(r, name, sub):
pre = "" """ pre = "" """
py_%(name)s = Py_None; py_%(name)s = Py_None;
""" """ % locals()
return pre + r.c_init() return pre + r.c_init(name, sub)
def get_c_extract(r): def get_c_extract(r, name, sub):
pre = """ pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0); py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
Py_XINCREF(py_%(name)s); Py_XINCREF(py_%(name)s);
""" """ % locals()
return pre + r.c_extract() return pre + r.c_extract(name, sub)
def get_c_cleanup(r): def get_c_cleanup(r, name, sub):
post = """ post = """
Py_XDECREF(py_%(name)s); Py_XDECREF(py_%(name)s);
""" """ % locals()
return r.c_cleanup() + post return r.c_cleanup(name, sub) + post
def get_c_sync(r): def get_c_sync(r, name, sub):
return """ return """
if (!%%(failure_var)s) { if (!%(failure_var)s) {
%(sync)s %(sync)s
PyObject* old = PyList_GET_ITEM(storage_%%(name)s, 0); PyObject* old = PyList_GET_ITEM(storage_%(name)s, 0);
Py_XINCREF(py_%%(name)s); Py_XINCREF(py_%(name)s);
PyList_SET_ITEM(storage_%%(name)s, 0, py_%%(name)s); PyList_SET_ITEM(storage_%(name)s, 0, py_%(name)s);
Py_XDECREF(old); Py_XDECREF(old);
} }
""" % dict(sync = r.c_sync()) """ % dict(sync = r.c_sync(name, sub), name = name, **sub)
def apply_policy(policy, r): def apply_policy(policy, r, name, sub):
""" """
policy -> list of functions that map a Result to a string, policy -> list of functions that map a Result to a string,
or a single such function or a single such function
...@@ -282,8 +286,8 @@ def apply_policy(policy, r): ...@@ -282,8 +286,8 @@ def apply_policy(policy, r):
if isinstance(r, (list, tuple)): if isinstance(r, (list, tuple)):
ret = "" ret = ""
for sub_policy in policy: for sub_policy in policy:
ret += sub_policy(r) ret += sub_policy(r, name, sub)
return policy(r) return policy(r, name, sub)
...@@ -304,11 +308,15 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub): ...@@ -304,11 +308,15 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub):
name = "V%i" % id name = "V%i" % id
symbol_table[result] = name symbol_table[result] = name
sub = copy(sub) sub = copy(sub)
sub['name'] = name # sub['name'] = name
sub['id'] = id sub['id'] = id
struct_builder = CodeBlock(*[apply_policy(policy, result) for policy in policies[0]]+[sub]) # struct_declare, struct_behavior, struct_cleanup, sub) sub['fail'] = failure_code(sub)
struct_builder = CodeBlock(*[apply_policy(policy, result, name, sub)
for policy in policies[0]]+[sub]) # struct_declare, struct_behavior, struct_cleanup, sub)
sub['id'] = id + 1 sub['id'] = id + 1
block = CodeBlock(*[apply_policy(policy, result) for policy in policies[1]]+[sub]) # run_declare, run_behavior, run_cleanup, sub) sub['fail'] = failure_code(sub)
block = CodeBlock(*[apply_policy(policy, result, name, sub)
for policy in policies[1]]+[sub]) # run_declare, run_behavior, run_cleanup, sub)
return struct_builder, block return struct_builder, block
...@@ -453,33 +461,39 @@ class CLinker(Linker): ...@@ -453,33 +461,39 @@ class CLinker(Linker):
# We populate sub with a mapping from the variable names specified by the op's c_var_names # We populate sub with a mapping from the variable names specified by the op's c_var_names
# method to the actual variable names that we will use. # method to the actual variable names that we will use.
ivnames, ovnames = op.c_var_names() ## ivnames, ovnames = op.c_var_names()
sub = dict(failure_var = failure_var) sub = dict(failure_var = failure_var)
for result, vname in zip(op.inputs + op.outputs, ivnames + ovnames): ## for result, vname in zip(op.inputs + op.outputs, ivnames + ovnames):
sub[vname] = symbol[result] ## sub[vname] = symbol[result]
isyms, osyms = [symbol[r] for r in op.inputs], [symbol[r] for r in op.outputs]
# Make the CodeBlock for c_validate_update # Make the CodeBlock for c_validate_update
try: validate_behavior = op.c_validate_update() sub['id'] = id
sub['fail'] = failure_code(sub)
try: validate_behavior = op.c_validate_update(isyms, osyms, sub)
except AbstractFunctionError: except AbstractFunctionError:
validate_behavior = "" validate_behavior = ""
try: validate_cleanup = op.c_validate_update_cleanup() try: validate_cleanup = op.c_validate_update_cleanup(isyms, osyms, sub)
except AbstractFunctionError: except AbstractFunctionError:
validate_cleanup = "" validate_cleanup = ""
sub['id'] = id
blocks.append(CodeBlock("", validate_behavior, validate_cleanup, sub)) blocks.append(CodeBlock("", validate_behavior, validate_cleanup, sub))
tasks.append((op, 'validate_update', id)) tasks.append((op, 'validate_update', id))
id += 1 id += 1
# Make the CodeBlock for c_code # Make the CodeBlock for c_code
behavior = op.c_code() # this one must be implemented! sub['id'] = id
sub['fail'] = failure_code(sub)
behavior = op.c_code(isyms, osyms, sub) # this one must be implemented!
try: cleanup = op.c_code_cleanup() try: cleanup = op.c_code_cleanup(isyms, osyms, sub)
except AbstractFunctionError: except AbstractFunctionError:
cleanup = "" cleanup = ""
sub['id'] = id
blocks.append(CodeBlock("", behavior, cleanup, sub)) blocks.append(CodeBlock("", behavior, cleanup, sub))
tasks.append((op, 'code', id)) tasks.append((op, 'code', id))
id += 1 id += 1
...@@ -489,7 +503,7 @@ class CLinker(Linker): ...@@ -489,7 +503,7 @@ class CLinker(Linker):
args = [] args = []
args += ["storage_%s" % symbol[result] for result in utils.uniq(self.inputs + self.outputs + self.orphans)] args += ["storage_%s" % symbol[result] for result in utils.uniq(self.inputs + self.outputs + self.orphans)]
struct_code = struct_gen(args, init_blocks, blocks, dict(failure_var = failure_var)) struct_code = struct_gen(args, init_blocks, blocks, dict(failure_var = failure_var, name = "%(name)s"))
# The hash calculated on the code identifies it so weave can cache properly. # The hash calculated on the code identifies it so weave can cache properly.
# (the hash has to be used outside of the support code because weave does not consider changes in the support code) # (the hash has to be used outside of the support code because weave does not consider changes in the support code)
......
...@@ -78,16 +78,13 @@ class Env(graph.Graph): ...@@ -78,16 +78,13 @@ class Env(graph.Graph):
# The inputs and outputs set bound the subgraph this Env operates on. # The inputs and outputs set bound the subgraph this Env operates on.
self.inputs = list(inputs) self.inputs = list(inputs)
self.outputs = list(outputs) self.outputs = list(outputs)
for feature_class in uniq_features(features):
self.add_feature(feature_class, False)
# All ops in the subgraph defined by inputs and outputs are cached in _ops # All ops in the subgraph defined by inputs and outputs are cached in _ops
self._ops = set() self._ops = set()
# Ditto for results # Ditto for results
self._results = set(self.inputs) self._results = set(self.inputs)
# Set of all the results that are not an output of an op in the subgraph but # Set of all the results that are not an output of an op in the subgraph but
# are an input of an op in the subgraph. # are an input of an op in the subgraph.
# e.g. z for inputs=(x, y) and outputs=(x + (y - z),) # e.g. z for inputs=(x, y) and outputs=(x + (y - z),)
...@@ -95,6 +92,9 @@ class Env(graph.Graph): ...@@ -95,6 +92,9 @@ class Env(graph.Graph):
# it will be removed from the set of orphans. # it will be removed from the set of orphans.
self._orphans = set(outputs) self._orphans = set(outputs)
for feature_class in uniq_features(features):
self.add_feature(feature_class, False)
# Maps results to ops that use them: # Maps results to ops that use them:
# if op.inputs[i] == v then (op, i) in self._clients[v] # if op.inputs[i] == v then (op, i) in self._clients[v]
self._clients = {} self._clients = {}
......
...@@ -179,15 +179,15 @@ class Op(object): ...@@ -179,15 +179,15 @@ class Op(object):
# C code generators # C code generators
# #
def c_var_names(self): # def c_var_names(self):
""" # """
Returns ([list of input names], [list of output names]) for # Returns ([list of input names], [list of output names]) for
use as C variables. # use as C variables.
""" # """
return [["i%i" % i for i in xrange(len(self.inputs))], # return [["i%i" % i for i in xrange(len(self.inputs))],
["o%i" % i for i in xrange(len(self.outputs))]] # ["o%i" % i for i in xrange(len(self.outputs))]]
def c_validate_update(self): def c_validate_update(self, inputs, outputs, sub):
""" """
Returns templated C code that checks that the inputs to this Returns templated C code that checks that the inputs to this
function can be worked on. If a failure occurs, set an function can be worked on. If a failure occurs, set an
...@@ -198,13 +198,13 @@ class Op(object): ...@@ -198,13 +198,13 @@ class Op(object):
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_validate_update_cleanup(self): def c_validate_update_cleanup(self, inputs, outputs, sub):
""" """
Clean up things allocated by c_validate(). Clean up things allocated by c_validate().
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_code(self): def c_code(self, inputs, outputs, sub):
""" """
Returns templated C code that does the computation associated Returns templated C code that does the computation associated
to this Op. You may assume that input validation and output to this Op. You may assume that input validation and output
...@@ -215,7 +215,7 @@ class Op(object): ...@@ -215,7 +215,7 @@ class Op(object):
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_code_cleanup(self): def c_code_cleanup(self, inputs, outputs, sub):
""" """
Clean up things allocated by c_code(). Clean up things allocated by c_code().
""" """
......
...@@ -175,13 +175,13 @@ class ResultBase(object): ...@@ -175,13 +175,13 @@ class ResultBase(object):
""" """
return False return False
def c_declare(self): def c_declare(self, name, sub):
""" """
Declares variables that will be instantiated by c_data_extract. Declares variables that will be instantiated by c_data_extract.
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_extract(self): def c_extract(self, name, sub):
""" """
The code returned from this function must be templated using The code returned from this function must be templated using
"%(name)s", representing the name that the caller wants to "%(name)s", representing the name that the caller wants to
...@@ -193,7 +193,7 @@ class ResultBase(object): ...@@ -193,7 +193,7 @@ class ResultBase(object):
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_cleanup(self): def c_cleanup(self, name, sub):
""" """
This returns C code that should deallocate whatever This returns C code that should deallocate whatever
c_data_extract allocated or decrease the reference counts. Do c_data_extract allocated or decrease the reference counts. Do
...@@ -201,7 +201,7 @@ class ResultBase(object): ...@@ -201,7 +201,7 @@ class ResultBase(object):
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
def c_sync(self): def c_sync(self, name, sub):
""" """
The code returned from this function must be templated using "%(name)s", The code returned from this function must be templated using "%(name)s",
representing the name that the caller wants to call this Result. representing the name that the caller wants to call this Result.
...@@ -297,28 +297,28 @@ class PythonResult(ResultBase): ...@@ -297,28 +297,28 @@ class PythonResult(ResultBase):
through %(name)s. through %(name)s.
""" """
def c_declare(self): def c_declare(self, name, sub):
return """ return """
PyObject* %(name)s; PyObject* %(name)s;
""" """ % locals()
def c_extract(self): def c_extract(self, name, sub):
return """ return """
Py_XINCREF(py_%(name)s); Py_XINCREF(py_%(name)s);
%(name)s = py_%(name)s; %(name)s = py_%(name)s;
""" """ % locals()
def c_cleanup(self): def c_cleanup(self, name, sub):
return """ return """
Py_XDECREF(%(name)s); Py_XDECREF(%(name)s);
""" """ % locals()
def c_sync(self): def c_sync(self, name, sub):
return """ return """
Py_XDECREF(py_%(name)s); Py_XDECREF(py_%(name)s);
py_%(name)s = %(name)s; py_%(name)s = %(name)s;
Py_XINCREF(py_%(name)s); Py_XINCREF(py_%(name)s);
""" """ % locals()
def same_properties(self, other): def same_properties(self, other):
return False return False
......
...@@ -18,10 +18,9 @@ def as_scalar(x, name = None): ...@@ -18,10 +18,9 @@ def as_scalar(x, name = None):
class Scalar(ResultBase): class Scalar(ResultBase):
def __init__(self, dtype, data = None, name=None): def __init__(self, dtype, name = None):
ResultBase.__init__(self, role = None, name = name)
self.dtype = dtype self.dtype = dtype
self.constant = False
ResultBase.__init__(self, role = None, data = data, name = name)
def __get_constant(self): def __get_constant(self):
return self._constant return self._constant
...@@ -40,60 +39,59 @@ class Scalar(ResultBase): ...@@ -40,60 +39,59 @@ class Scalar(ResultBase):
def same_properties(self, other): def same_properties(self, other):
return other.dtype == self.dtype return other.dtype == self.dtype
def mergeable(self, other): # def mergeable(self, other):
return getattr(self, 'constant', False) \ # return getattr(self, 'constant', False) \
and getattr(other, 'constant', False) \ # and getattr(other, 'constant', False) \
and self.data == other.data # and self.data == other.data
def dtype_specs(self): def dtype_specs(self):
return {'float64': (float, 'double', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble')}[self.dtype] return {'float64': (float, 'double', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble')}[self.dtype]
# def py_type(self):
# return {'float64': float}[self.dtype]
# def c_type(self):
# return {'float64': 'double'}[self.dtype]
# def c_from(self):
# return {'float64': 'PyFloat_FromDouble'}[self.dtype]
# def c_as(self):
# return {'float64': 'PyFloat_AsDouble'}[self.dtype]
def c_declare(self): def c_declare(self, name, sub):
return """ return """
%(dtype)s %%(name)s; %(dtype)s %(name)s;
typedef %(dtype)s %%(name)s_dtype; typedef %(dtype)s %(name)s_dtype;
""" % dict(dtype = self.dtype_specs()[1]) """ % dict(name = name, dtype = self.dtype_specs()[1])
def c_init(self): def c_init(self, name, sub):
return """ return """
%(name)s = 0; %(name)s = 0;
""" """ % locals()
def c_extract(self): def c_extract(self, name, sub):
specs = self.dtype_specs() specs = self.dtype_specs()
return """ return """
if (!%(check)s(py_%%(name)s)) if (!%(check)s(py_%(name)s))
%%(fail)s %(fail)s
%%(name)s = (%(dtype)s)%(conv)s(py_%%(name)s); %(name)s = (%(dtype)s)%(conv)s(py_%(name)s);
""" % dict(dtype = specs[1], """ % dict(sub,
name = name,
dtype = specs[1],
check = specs[2], check = specs[2],
conv = specs[3]) conv = specs[3])
def c_sync(self): def c_sync(self, name, sub):
specs = self.dtype_specs() specs = self.dtype_specs()
return """ return """
Py_XDECREF(py_%%(name)s); Py_XDECREF(py_%(name)s);
py_%%(name)s = %(conv)s((%(dtype)s)%%(name)s); py_%(name)s = %(conv)s((%(dtype)s)%(name)s);
if (!py_%%(name)s) if (!py_%(name)s)
py_%%(name)s = Py_None; py_%(name)s = Py_None;
""" % dict(dtype = specs[1], """ % dict(name = name,
dtype = specs[1],
conv = specs[4]) conv = specs[4])
def c_cleanup(self): def c_cleanup(self, name, sub):
return "" return ""
def __copy__(self):
"""
Return a copy of this instance (with its own attributes)
"""
cpy = self.__class__(self.dtype, self.name)
cpy.data = self.data
return cpy
class ScalarMixedOp(GuardedOp): class ScalarMixedOp(GuardedOp):
...@@ -104,8 +102,8 @@ class ScalarMixedOp(GuardedOp): ...@@ -104,8 +102,8 @@ class ScalarMixedOp(GuardedOp):
def __init__(self, *inputs): def __init__(self, *inputs):
if self.nin >= 0: if self.nin >= 0:
if len(inputs) != self.nin: if len(inputs) != self.nin:
raise TypeError("Wrong number of inputs for %s (got %i, expected %i)") \ raise TypeError("Wrong number of inputs for %s (got %i, expected %i)" \
% (self, len(inputs), self.nin) % (self.__class__.__name__, len(inputs), self.nin))
i_dtypes = [getattr(input, 'dtype', None) for input in inputs] i_dtypes = [getattr(input, 'dtype', None) for input in inputs]
o_dtypes = utils.from_return_values(self.propagate_dtypes(*i_dtypes)) o_dtypes = utils.from_return_values(self.propagate_dtypes(*i_dtypes))
...@@ -125,14 +123,14 @@ class ScalarMixedOp(GuardedOp): ...@@ -125,14 +123,14 @@ class ScalarMixedOp(GuardedOp):
def perform(self): def perform(self):
self.outputs[0].data = self.impl(*[input.data for input in self.inputs]) self.outputs[0].data = self.impl(*[input.data for input in self.inputs])
def c_var_names(self): # def c_var_names(self):
(self, inames, onames), _1, _2, _3 = inspect.getargspec(self.c_impl) # (self, inames, onames), _1, _2, _3 = inspect.getargspec(self.c_impl)
inames = utils.from_return_values(inames) # inames = utils.from_return_values(inames)
onames = utils.from_return_values(onames) # onames = utils.from_return_values(onames)
return [inames, onames] # return [inames, onames]
def c_code(self): # def c_code(self):
return self.c_impl(self.inputs, self.outputs) # return self.c_impl(self.inputs, self.outputs)
def upcast(dtype, *dtypes): def upcast(dtype, *dtypes):
......
...@@ -4,71 +4,91 @@ import math ...@@ -4,71 +4,91 @@ import math
class Add(BinaryScalarOp): class Add(BinaryScalarOp):
identity = 0
def impl(self, x, y): def impl(self, x, y):
return x + y return x + y
def c_impl(self, (x, y), z): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s + %(y)s;" return "%(z)s = %(x)s + %(y)s;" % locals()
def grad(self, (x, y), gz): def grad(self, (x, y), (gz, )):
return gz, gz return gz, gz
class Sub(BinaryScalarOp): class Sub(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x - y return x - y
def c_impl(self, (x, y), z): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s - %(y)s;" return "%(z)s = %(x)s - %(y)s;" % locals()
def grad(self, (x, y), gz): def grad(self, (x, y), (gz, )):
return gz, -gz return gz, -gz
class Mul(BinaryScalarOp): class Mul(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x * y return x * y
def c_impl(self, (x, y), z): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s * %(y)s;" return "%(z)s = %(x)s * %(y)s;" % locals()
def grad(self, (x, y), gz): def grad(self, (x, y), (gz, )):
return mul(y, gz), mul(x, gz) return mul(y, gz), mul(x, gz)
class Div(BinaryScalarOp): class Div(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x / y return x / y
def c_impl(self, (x, y), z): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s / %(y)s;" return "%(z)s = %(x)s / %(y)s;" % locals()
def grad(self, (x, y), gz): def grad(self, (x, y), (gz, )):
return div(gz, y), -div(mul(x, gz), y*y) return div(gz, y), -div(mul(x, gz), y*y)
class Pow(BinaryScalarOp): class Pow(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
return x ** y return x ** y
def c_impl(self, (x, y), z): def c_code(self, (x, y), (z, ), sub):
return "%(z)s = pow(%(x)s, %(y)s);" return "%(z)s = pow(%(x)s, %(y)s);" % locals()
class First(BinaryScalarOp):
def impl(self, x, y):
return x
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s;" % locals()
class Second(BinaryScalarOp):
def impl(self, x, y):
return y
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(y)s;" % locals()
class SquareDiff(BinaryScalarOp):
def impl(self, x, y):
diff = (x - y)
return diff * diff
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s - %(y)s; %(z)s *= %(z)s;" % locals()
class Neg(UnaryScalarOp): class Neg(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return -x return -x
def grad(self, x, gz): def grad(self, (x, ), (gz, )):
return -gz return -gz
def c_impl(self, x, z): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = -%(x)s;" return "%(z)s = -%(x)s;" % locals()
class Inv(UnaryScalarOp): class Inv(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return 1 / x return 1 / x
def grad(self, x, gz): def grad(self, (x, ), (gz, )):
return -gz / (x*x) return -gz / (x*x)
def c_impl(self, x, z): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = 1 / %(x)s;" return "%(z)s = 1 / %(x)s;" % locals()
class Log(UnaryScalarOp): class Log(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.log(x) return math.log(x)
def c_impl(self, x, z): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = log(%(x)s);" return "%(z)s = log(%(x)s);" % locals()
class Exp(UnaryScalarOp): class Exp(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return math.exp(x) return math.exp(x)
def c_impl(self, x, z): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = exp(%(x)s);" return "%(z)s = exp(%(x)s);" % locals()
# class Sigmoid(UnaryComposite): # class Sigmoid(UnaryComposite):
......
...@@ -136,8 +136,12 @@ class _Op(BaseTensorOp): ...@@ -136,8 +136,12 @@ class _Op(BaseTensorOp):
onames = utils.from_return_values(onames) onames = utils.from_return_values(onames)
return [inames, onames] return [inames, onames]
def c_code(self): def c_code(self, input_names, output_names, sub):
return self.c_impl(self.inputs, self.outputs) sub = dict(sub)
icvn, ocvn = self.c_var_names()
for real, tosub in zip(input_names + output_names, icvn + ocvn):
sub[tosub] = real
return self.c_impl(self.inputs, self.outputs) % sub
def c_impl(self, inputs, outputs): def c_impl(self, inputs, outputs):
raise AbstractFunctionError() raise AbstractFunctionError()
...@@ -759,7 +763,7 @@ class Gemm(_Op): ...@@ -759,7 +763,7 @@ class Gemm(_Op):
return blas.ldflags() return blas.ldflags()
def c_var_names(self): def c_var_names(self):
return [['_z', '_a', '_x', '_y', '_b'], ['_zout']] return [['_z', '_a', '_x', '_y', '_b'], ['_zout']]
def c_validate_update(self): def c_validate_update(self, (_z, _a, _x, _y, _b), (_zout, ), sub):
return """ return """
if (%(_zout)s) if (%(_zout)s)
{ {
...@@ -770,10 +774,10 @@ class Gemm(_Op): ...@@ -770,10 +774,10 @@ class Gemm(_Op):
%(_zout)s = %(_z)s; %(_zout)s = %(_z)s;
Py_INCREF(%(_zout)s); Py_INCREF(%(_zout)s);
} }
""" """ % locals()
def c_validate_update_cleanup(self): def c_validate_update_cleanup(self, ignore, _ignore, __ignore):
return "" return ""
def c_code(self): def c_code(self, (_z, _a, _x, _y, _b), (_zout, ), sub):
return """ return """
int unit = 0; int unit = 0;
...@@ -913,7 +917,7 @@ class Gemm(_Op): ...@@ -913,7 +917,7 @@ class Gemm(_Op):
break; break;
} }
""" """ % dict(locals(), **sub)
gemm = gof.op.constructor(Gemm) gemm = gof.op.constructor(Gemm)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论