broadcast, reduce are working in python and c

上级 88f7858f
......@@ -8,31 +8,23 @@ from scalar_ops import *
def inputs():
x = modes.build_eval(as_scalar(1.0, 'x'))
y = modes.build_eval(as_scalar(2.0, 'y'))
z = modes.build_eval(as_scalar(3.0, 'z'))
x = modes.build(as_scalar(1.0, 'x'))
y = modes.build(as_scalar(2.0, 'y'))
z = modes.build(as_scalar(3.0, 'z'))
return x, y, z
def env(inputs, outputs, validate = True, features = []):
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate)
class _test_ScalarOps(unittest.TestCase):
def test_0(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
assert e.data == 1.5
def test_1(self):
def test_straightforward(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
g = env([x, y], [e])
fn = gof.cc.CLinker(g).make_function()
fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == 1.5
assert e.data == 1.5
if __name__ == '__main__':
......
......@@ -691,7 +691,7 @@ class t_gemm(unittest.TestCase):
self.rand(3,5), self.rand(5,4), 1.0)
def test12(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), -1.0)
t_gemm = None
if __name__ == '__main__':
unittest.main()
......@@ -108,53 +108,53 @@ class BaseTensor(ResultBase):
#
# C codegen stubs
#
def c_declare(self):
def c_declare(self, name, sub):
return """
PyArrayObject* %%(name)s;
int type_num_%%(name)s;
typedef %(dtype)s dtype_%%(name)s;
""" % dict(dtype = self.dtype_specs()[1])
PyArrayObject* %(name)s;
int type_num_%(name)s;
typedef %(dtype)s dtype_%(name)s;
""" % dict(sub, name = name, dtype = self.dtype_specs()[1])
def c_init(self):
def c_init(self, name, sub):
return """
%%(name)s = NULL;
type_num_%%(name)s = %(type_num)s;
""" % dict(type_num = self.dtype_specs()[2])
%(name)s = NULL;
type_num_%(name)s = %(type_num)s;
""" % dict(sub, name = name, type_num = self.dtype_specs()[2])
def c_extract(self):
def c_extract(self, name, sub):
return """
%%(name)s = NULL;
type_num_%%(name)s = %(type_num)s;
if (py_%%(name)s == Py_None) {
// We can either fail here or set %%(name)s to NULL and rely on Ops using
%(name)s = NULL;
type_num_%(name)s = %(type_num)s;
if (py_%(name)s == Py_None) {
// We can either fail here or set %(name)s to NULL and rely on Ops using
// tensors to handle the NULL case, but if they fail to do so they'll end up
// with nasty segfaults, so this is public service.
PyErr_SetString(PyExc_ValueError, "expected an ndarray, not None");
%%(fail)s
//%%(name)s = NULL;
%(fail)s
//%(name)s = NULL;
}
else if (!PyArray_Check(py_%%(name)s)) {
else if (!PyArray_Check(py_%(name)s)) {
PyErr_SetString(PyExc_ValueError, "expected an ndarray");
%%(fail)s
%(fail)s
}
else if (((PyArrayObject*)py_%%(name)s)->descr->type_num != %(type_num)s) {
else if (((PyArrayObject*)py_%(name)s)->descr->type_num != %(type_num)s) {
PyErr_SetString(PyExc_ValueError, "expected %(type_num)s");
%%(fail)s
%(fail)s
}
else {
%%(name)s = (PyArrayObject*)(py_%%(name)s);
Py_XINCREF(%%(name)s);
%(name)s = (PyArrayObject*)(py_%(name)s);
Py_XINCREF(%(name)s);
}
""" % dict(type_num = self.dtype_specs()[2])
""" % dict(sub, name = name, type_num = self.dtype_specs()[2])
def c_cleanup(self):
def c_cleanup(self, name, sub):
return """
if (%(name)s) {
Py_XDECREF(%(name)s);
}
"""
""" % locals()
def c_sync(self):
def c_sync(self, name, sub):
return """
if (!%(name)s) {
Py_XDECREF(py_%(name)s);
......@@ -165,7 +165,7 @@ class BaseTensor(ResultBase):
py_%(name)s = (PyObject*)%(name)s;
Py_XINCREF(py_%(name)s);
}
"""
""" % locals()
def c_headers(self):
return []
......
......@@ -16,8 +16,8 @@ exec_opt.optimizer = None
def default_optimizer(env):
#TODO: pass tests with these un-commented
default_optimizer.const(env)
default_optimizer.merge(env)
# default_optimizer.const(env)
# default_optimizer.merge(env)
pass
default_optimizer.merge = gof.opt.MergeOptimizer()
default_optimizer.const = gof.opt.ConstantFinder()
......
......@@ -61,21 +61,41 @@ class Elemwise(Op):
return ret
def c_validate_update(self):
def c_validate_update(self, input_names, output_names, sub):
sub = dict(sub)
icvn, ocvn = self.c_var_names()
for real, tosub in zip(input_names + output_names, icvn + ocvn):
sub[tosub] = real
(valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code()
return valupd
return valupd % sub
def c_validate_update_cleanup(self, input_names, output_names, sub):
sub = dict(sub)
icvn, ocvn = self.c_var_names()
for real, tosub in zip(input_names + output_names, icvn + ocvn):
sub[tosub] = real
def c_validate_update_cleanup(self):
(valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code()
return valupd_cleanup
return valupd_cleanup % sub
def c_code(self, input_names, output_names, sub):
sub = dict(sub)
icvn, ocvn = self.c_var_names()
for real, tosub in zip(input_names + output_names, icvn + ocvn):
sub[tosub] = real
def c_code(self):
(valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code()
return code
return code % sub
def c_code_cleanup(self, input_names, output_names, sub):
sub = dict(sub)
icvn, ocvn = self.c_var_names()
for real, tosub in zip(input_names + output_names, icvn + ocvn):
sub[tosub] = real
def c_code_cleanup(self):
(valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code()
return code_cleanup
return code_cleanup % sub
@classmethod
def inplace_version(cls, dmap = {0:0}):
......
......@@ -25,20 +25,20 @@ class Double(ResultBase):
# def c_is_simple(self): return True
def c_declare(self):
return "double %(name)s; void* %(name)s_bad_thing;"
def c_declare(self, name, sub):
return "double %(name)s; void* %(name)s_bad_thing;" % locals()
def c_init(self):
def c_init(self, name, sub):
return """
%(name)s = 0;
%(name)s_bad_thing = malloc(100000);
//printf("Initializing %(name)s\\n");
"""
""" % locals()
def c_literal(self):
return str(self.data)
def c_extract(self):
def c_extract(self, name, sub):
return """
if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "not a double!");
......@@ -47,23 +47,23 @@ class Double(ResultBase):
%(name)s = PyFloat_AsDouble(py_%(name)s);
%(name)s_bad_thing = NULL;
//printf("Extracting %(name)s\\n");
"""
""" % dict(locals(), **sub)
def c_sync(self):
def c_sync(self, name, sub):
return """
Py_XDECREF(py_%(name)s);
py_%(name)s = PyFloat_FromDouble(%(name)s);
if (!py_%(name)s)
py_%(name)s = Py_None;
//printf("Syncing %(name)s\\n");
"""
""" % locals()
def c_cleanup(self):
def c_cleanup(self, name, sub):
return """
//printf("Cleaning up %(name)s\\n");
if (%(name)s_bad_thing)
free(%(name)s_bad_thing);
"""
""" % locals()
class MyOp(Op):
......@@ -80,43 +80,43 @@ class MyOp(Op):
class Unary(MyOp):
nin = 1
def c_var_names(self):
return [['x'], ['z']]
# def c_var_names(self):
# return [['x'], ['z']]
class Binary(MyOp):
nin = 2
def c_var_names(self):
return [['x', 'y'], ['z']]
# def c_var_names(self):
# return [['x', 'y'], ['z']]
class Add(Binary):
def c_code(self):
return "%(z)s = %(x)s + %(y)s;"
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s + %(y)s;" % locals()
def perform(self):
self.outputs[0].data = self.inputs[0].data + self.inputs[1].data
class Sub(Binary):
def c_code(self):
return "%(z)s = %(x)s - %(y)s;"
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s - %(y)s;" % locals()
def perform(self):
self.outputs[0].data = -10 # erroneous
class Mul(Binary):
def c_code(self):
return "%(z)s = %(x)s * %(y)s;"
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s * %(y)s;" % locals()
def perform(self):
self.outputs[0].data = self.inputs[0].data * self.inputs[1].data
class Div(Binary):
def c_validate_update(self):
def c_validate_update(self, (x, y), (z, ), sub):
return """
if (%(y)s == 0.0) {
PyErr_SetString(PyExc_ZeroDivisionError, "division by zero");
%(fail)s
}
"""
def c_code(self):
return "%(z)s = %(x)s / %(y)s;"
""" % dict(locals(), **sub)
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s / %(y)s;" % locals()
def perform(self):
self.outputs[0].data = self.inputs[0].data / self.inputs[1].data
......
......@@ -66,15 +66,19 @@ class CodeBlock:
to jump to. It should also contain a key called 'failure_var' that contains
the name of the variable that contains the error code.
"""
self.declare = declare % sub
behavior_sub = copy(sub)
behavior_sub['fail'] = "{%(failure_var)s = %(id)s; goto __label_%(id)i;}" % sub
self.behavior = behavior % behavior_sub
self.declare = declare #% sub
# behavior_sub = copy(sub)
# behavior_sub['fail'] = "{%(failure_var)s = %(id)s; goto __label_%(id)i;}" % sub
self.behavior = behavior #% behavior_sub
# the dummy is because gcc throws an error when a label's right next to a closing
# brace (maybe there's an ignore flag for that...)
# we need the label even if cleanup is empty because the behavior block jumps there
# on failure
self.cleanup = ("__label_%(id)i:\n" + cleanup + "\ndouble __DUMMY_%(id)i;\n") % sub
self.cleanup = ("__label_%(id)i:\n"%sub + cleanup + "\ndouble __DUMMY_%(id)i;\n"%sub) #% sub
def failure_code(sub):
return "{%(failure_var)s = %(id)s; goto __label_%(id)i;}" % sub
def code_gen(blocks):
......@@ -192,14 +196,14 @@ def struct_gen(args, struct_builders, blocks, sub):
# TODO: add some error checking to make sure storage_<x> are 1-element lists
# and __ERROR is a 3-elements list.
struct_code = """
struct %%(name)s {
struct %(name)s {
PyObject* __ERROR;
%(storage_decl)s
%(struct_decl)s
%%(name)s() {}
~%%(name)s(void) {
%(name)s() {}
~%(name)s(void) {
cleanup();
}
......@@ -232,47 +236,47 @@ def struct_gen(args, struct_builders, blocks, sub):
# The get_<x> functions complete the return value of r.get_<x>()
# with handling of the py_<name> variable.
def get_nothing(r):
def get_nothing(r, name, sub):
""
return ""
def get_c_declare(r):
def get_c_declare(r, name, sub):
pre = """
PyObject* py_%(name)s;
"""
return pre + r.c_declare()
""" % locals()
return pre + r.c_declare(name, sub)
def get_c_init(r):
def get_c_init(r, name, sub):
pre = "" """
py_%(name)s = Py_None;
"""
return pre + r.c_init()
""" % locals()
return pre + r.c_init(name, sub)
def get_c_extract(r):
def get_c_extract(r, name, sub):
pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
Py_XINCREF(py_%(name)s);
"""
return pre + r.c_extract()
""" % locals()
return pre + r.c_extract(name, sub)
def get_c_cleanup(r):
def get_c_cleanup(r, name, sub):
post = """
Py_XDECREF(py_%(name)s);
"""
return r.c_cleanup() + post
""" % locals()
return r.c_cleanup(name, sub) + post
def get_c_sync(r):
def get_c_sync(r, name, sub):
return """
if (!%%(failure_var)s) {
if (!%(failure_var)s) {
%(sync)s
PyObject* old = PyList_GET_ITEM(storage_%%(name)s, 0);
Py_XINCREF(py_%%(name)s);
PyList_SET_ITEM(storage_%%(name)s, 0, py_%%(name)s);
PyObject* old = PyList_GET_ITEM(storage_%(name)s, 0);
Py_XINCREF(py_%(name)s);
PyList_SET_ITEM(storage_%(name)s, 0, py_%(name)s);
Py_XDECREF(old);
}
""" % dict(sync = r.c_sync())
""" % dict(sync = r.c_sync(name, sub), name = name, **sub)
def apply_policy(policy, r):
def apply_policy(policy, r, name, sub):
"""
policy -> list of functions that map a Result to a string,
or a single such function
......@@ -282,8 +286,8 @@ def apply_policy(policy, r):
if isinstance(r, (list, tuple)):
ret = ""
for sub_policy in policy:
ret += sub_policy(r)
return policy(r)
ret += sub_policy(r, name, sub)
return policy(r, name, sub)
......@@ -304,11 +308,15 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub):
name = "V%i" % id
symbol_table[result] = name
sub = copy(sub)
sub['name'] = name
# sub['name'] = name
sub['id'] = id
struct_builder = CodeBlock(*[apply_policy(policy, result) for policy in policies[0]]+[sub]) # struct_declare, struct_behavior, struct_cleanup, sub)
sub['fail'] = failure_code(sub)
struct_builder = CodeBlock(*[apply_policy(policy, result, name, sub)
for policy in policies[0]]+[sub]) # struct_declare, struct_behavior, struct_cleanup, sub)
sub['id'] = id + 1
block = CodeBlock(*[apply_policy(policy, result) for policy in policies[1]]+[sub]) # run_declare, run_behavior, run_cleanup, sub)
sub['fail'] = failure_code(sub)
block = CodeBlock(*[apply_policy(policy, result, name, sub)
for policy in policies[1]]+[sub]) # run_declare, run_behavior, run_cleanup, sub)
return struct_builder, block
......@@ -453,33 +461,39 @@ class CLinker(Linker):
# We populate sub with a mapping from the variable names specified by the op's c_var_names
# method to the actual variable names that we will use.
ivnames, ovnames = op.c_var_names()
## ivnames, ovnames = op.c_var_names()
sub = dict(failure_var = failure_var)
for result, vname in zip(op.inputs + op.outputs, ivnames + ovnames):
sub[vname] = symbol[result]
## for result, vname in zip(op.inputs + op.outputs, ivnames + ovnames):
## sub[vname] = symbol[result]
isyms, osyms = [symbol[r] for r in op.inputs], [symbol[r] for r in op.outputs]
# Make the CodeBlock for c_validate_update
try: validate_behavior = op.c_validate_update()
sub['id'] = id
sub['fail'] = failure_code(sub)
try: validate_behavior = op.c_validate_update(isyms, osyms, sub)
except AbstractFunctionError:
validate_behavior = ""
try: validate_cleanup = op.c_validate_update_cleanup()
try: validate_cleanup = op.c_validate_update_cleanup(isyms, osyms, sub)
except AbstractFunctionError:
validate_cleanup = ""
sub['id'] = id
blocks.append(CodeBlock("", validate_behavior, validate_cleanup, sub))
tasks.append((op, 'validate_update', id))
id += 1
# Make the CodeBlock for c_code
behavior = op.c_code() # this one must be implemented!
sub['id'] = id
sub['fail'] = failure_code(sub)
behavior = op.c_code(isyms, osyms, sub) # this one must be implemented!
try: cleanup = op.c_code_cleanup()
try: cleanup = op.c_code_cleanup(isyms, osyms, sub)
except AbstractFunctionError:
cleanup = ""
sub['id'] = id
blocks.append(CodeBlock("", behavior, cleanup, sub))
tasks.append((op, 'code', id))
id += 1
......@@ -489,7 +503,7 @@ class CLinker(Linker):
args = []
args += ["storage_%s" % symbol[result] for result in utils.uniq(self.inputs + self.outputs + self.orphans)]
struct_code = struct_gen(args, init_blocks, blocks, dict(failure_var = failure_var))
struct_code = struct_gen(args, init_blocks, blocks, dict(failure_var = failure_var, name = "%(name)s"))
# The hash calculated on the code identifies it so weave can cache properly.
# (the hash has to be used outside of the support code because weave does not consider changes in the support code)
......
......@@ -78,16 +78,13 @@ class Env(graph.Graph):
# The inputs and outputs set bound the subgraph this Env operates on.
self.inputs = list(inputs)
self.outputs = list(outputs)
for feature_class in uniq_features(features):
self.add_feature(feature_class, False)
# All ops in the subgraph defined by inputs and outputs are cached in _ops
self._ops = set()
# Ditto for results
self._results = set(self.inputs)
# Set of all the results that are not an output of an op in the subgraph but
# are an input of an op in the subgraph.
# e.g. z for inputs=(x, y) and outputs=(x + (y - z),)
......@@ -95,6 +92,9 @@ class Env(graph.Graph):
# it will be removed from the set of orphans.
self._orphans = set(outputs)
for feature_class in uniq_features(features):
self.add_feature(feature_class, False)
# Maps results to ops that use them:
# if op.inputs[i] == v then (op, i) in self._clients[v]
self._clients = {}
......
......@@ -179,15 +179,15 @@ class Op(object):
# C code generators
#
def c_var_names(self):
"""
Returns ([list of input names], [list of output names]) for
use as C variables.
"""
return [["i%i" % i for i in xrange(len(self.inputs))],
["o%i" % i for i in xrange(len(self.outputs))]]
def c_validate_update(self):
# def c_var_names(self):
# """
# Returns ([list of input names], [list of output names]) for
# use as C variables.
# """
# return [["i%i" % i for i in xrange(len(self.inputs))],
# ["o%i" % i for i in xrange(len(self.outputs))]]
def c_validate_update(self, inputs, outputs, sub):
"""
Returns templated C code that checks that the inputs to this
function can be worked on. If a failure occurs, set an
......@@ -198,13 +198,13 @@ class Op(object):
"""
raise AbstractFunctionError()
def c_validate_update_cleanup(self):
def c_validate_update_cleanup(self, inputs, outputs, sub):
"""
Clean up things allocated by c_validate().
"""
raise AbstractFunctionError()
def c_code(self):
def c_code(self, inputs, outputs, sub):
"""
Returns templated C code that does the computation associated
to this Op. You may assume that input validation and output
......@@ -215,7 +215,7 @@ class Op(object):
"""
raise AbstractFunctionError()
def c_code_cleanup(self):
def c_code_cleanup(self, inputs, outputs, sub):
"""
Clean up things allocated by c_code().
"""
......
......@@ -175,13 +175,13 @@ class ResultBase(object):
"""
return False
def c_declare(self):
def c_declare(self, name, sub):
"""
Declares variables that will be instantiated by c_data_extract.
"""
raise AbstractFunctionError()
def c_extract(self):
def c_extract(self, name, sub):
"""
The code returned from this function must be templated using
"%(name)s", representing the name that the caller wants to
......@@ -193,7 +193,7 @@ class ResultBase(object):
"""
raise AbstractFunctionError()
def c_cleanup(self):
def c_cleanup(self, name, sub):
"""
This returns C code that should deallocate whatever
c_data_extract allocated or decrease the reference counts. Do
......@@ -201,7 +201,7 @@ class ResultBase(object):
"""
raise AbstractFunctionError()
def c_sync(self):
def c_sync(self, name, sub):
"""
The code returned from this function must be templated using "%(name)s",
representing the name that the caller wants to call this Result.
......@@ -297,28 +297,28 @@ class PythonResult(ResultBase):
through %(name)s.
"""
def c_declare(self):
def c_declare(self, name, sub):
return """
PyObject* %(name)s;
"""
""" % locals()
def c_extract(self):
def c_extract(self, name, sub):
return """
Py_XINCREF(py_%(name)s);
%(name)s = py_%(name)s;
"""
""" % locals()
def c_cleanup(self):
def c_cleanup(self, name, sub):
return """
Py_XDECREF(%(name)s);
"""
""" % locals()
def c_sync(self):
def c_sync(self, name, sub):
return """
Py_XDECREF(py_%(name)s);
py_%(name)s = %(name)s;
Py_XINCREF(py_%(name)s);
"""
""" % locals()
def same_properties(self, other):
return False
......
......@@ -18,10 +18,9 @@ def as_scalar(x, name = None):
class Scalar(ResultBase):
def __init__(self, dtype, data = None, name=None):
def __init__(self, dtype, name = None):
ResultBase.__init__(self, role = None, name = name)
self.dtype = dtype
self.constant = False
ResultBase.__init__(self, role = None, data = data, name = name)
def __get_constant(self):
return self._constant
......@@ -40,60 +39,59 @@ class Scalar(ResultBase):
def same_properties(self, other):
return other.dtype == self.dtype
def mergeable(self, other):
return getattr(self, 'constant', False) \
and getattr(other, 'constant', False) \
and self.data == other.data
# def mergeable(self, other):
# return getattr(self, 'constant', False) \
# and getattr(other, 'constant', False) \
# and self.data == other.data
def dtype_specs(self):
return {'float64': (float, 'double', 'PyFloat_Check', 'PyFloat_AsDouble', 'PyFloat_FromDouble')}[self.dtype]
# def py_type(self):
# return {'float64': float}[self.dtype]
# def c_type(self):
# return {'float64': 'double'}[self.dtype]
# def c_from(self):
# return {'float64': 'PyFloat_FromDouble'}[self.dtype]
# def c_as(self):
# return {'float64': 'PyFloat_AsDouble'}[self.dtype]
def c_declare(self):
def c_declare(self, name, sub):
return """
%(dtype)s %%(name)s;
typedef %(dtype)s %%(name)s_dtype;
""" % dict(dtype = self.dtype_specs()[1])
%(dtype)s %(name)s;
typedef %(dtype)s %(name)s_dtype;
""" % dict(name = name, dtype = self.dtype_specs()[1])
def c_init(self):
def c_init(self, name, sub):
return """
%(name)s = 0;
"""
""" % locals()
def c_extract(self):
def c_extract(self, name, sub):
specs = self.dtype_specs()
return """
if (!%(check)s(py_%%(name)s))
%%(fail)s
%%(name)s = (%(dtype)s)%(conv)s(py_%%(name)s);
""" % dict(dtype = specs[1],
if (!%(check)s(py_%(name)s))
%(fail)s
%(name)s = (%(dtype)s)%(conv)s(py_%(name)s);
""" % dict(sub,
name = name,
dtype = specs[1],
check = specs[2],
conv = specs[3])
def c_sync(self):
def c_sync(self, name, sub):
specs = self.dtype_specs()
return """
Py_XDECREF(py_%%(name)s);
py_%%(name)s = %(conv)s((%(dtype)s)%%(name)s);
if (!py_%%(name)s)
py_%%(name)s = Py_None;
""" % dict(dtype = specs[1],
Py_XDECREF(py_%(name)s);
py_%(name)s = %(conv)s((%(dtype)s)%(name)s);
if (!py_%(name)s)
py_%(name)s = Py_None;
""" % dict(name = name,
dtype = specs[1],
conv = specs[4])
def c_cleanup(self):
def c_cleanup(self, name, sub):
return ""
def __copy__(self):
"""
Return a copy of this instance (with its own attributes)
"""
cpy = self.__class__(self.dtype, self.name)
cpy.data = self.data
return cpy
class ScalarMixedOp(GuardedOp):
......@@ -104,8 +102,8 @@ class ScalarMixedOp(GuardedOp):
def __init__(self, *inputs):
if self.nin >= 0:
if len(inputs) != self.nin:
raise TypeError("Wrong number of inputs for %s (got %i, expected %i)") \
% (self, len(inputs), self.nin)
raise TypeError("Wrong number of inputs for %s (got %i, expected %i)" \
% (self.__class__.__name__, len(inputs), self.nin))
i_dtypes = [getattr(input, 'dtype', None) for input in inputs]
o_dtypes = utils.from_return_values(self.propagate_dtypes(*i_dtypes))
......@@ -125,14 +123,14 @@ class ScalarMixedOp(GuardedOp):
def perform(self):
self.outputs[0].data = self.impl(*[input.data for input in self.inputs])
def c_var_names(self):
(self, inames, onames), _1, _2, _3 = inspect.getargspec(self.c_impl)
inames = utils.from_return_values(inames)
onames = utils.from_return_values(onames)
return [inames, onames]
# def c_var_names(self):
# (self, inames, onames), _1, _2, _3 = inspect.getargspec(self.c_impl)
# inames = utils.from_return_values(inames)
# onames = utils.from_return_values(onames)
# return [inames, onames]
def c_code(self):
return self.c_impl(self.inputs, self.outputs)
# def c_code(self):
# return self.c_impl(self.inputs, self.outputs)
def upcast(dtype, *dtypes):
......
......@@ -4,71 +4,91 @@ import math
class Add(BinaryScalarOp):
identity = 0
def impl(self, x, y):
return x + y
def c_impl(self, (x, y), z):
return "%(z)s = %(x)s + %(y)s;"
def grad(self, (x, y), gz):
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s + %(y)s;" % locals()
def grad(self, (x, y), (gz, )):
return gz, gz
class Sub(BinaryScalarOp):
def impl(self, x, y):
return x - y
def c_impl(self, (x, y), z):
return "%(z)s = %(x)s - %(y)s;"
def grad(self, (x, y), gz):
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s - %(y)s;" % locals()
def grad(self, (x, y), (gz, )):
return gz, -gz
class Mul(BinaryScalarOp):
def impl(self, x, y):
return x * y
def c_impl(self, (x, y), z):
return "%(z)s = %(x)s * %(y)s;"
def grad(self, (x, y), gz):
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s * %(y)s;" % locals()
def grad(self, (x, y), (gz, )):
return mul(y, gz), mul(x, gz)
class Div(BinaryScalarOp):
def impl(self, x, y):
return x / y
def c_impl(self, (x, y), z):
return "%(z)s = %(x)s / %(y)s;"
def grad(self, (x, y), gz):
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s / %(y)s;" % locals()
def grad(self, (x, y), (gz, )):
return div(gz, y), -div(mul(x, gz), y*y)
class Pow(BinaryScalarOp):
def impl(self, x, y):
return x ** y
def c_impl(self, (x, y), z):
return "%(z)s = pow(%(x)s, %(y)s);"
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = pow(%(x)s, %(y)s);" % locals()
class First(BinaryScalarOp):
def impl(self, x, y):
return x
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s;" % locals()
class Second(BinaryScalarOp):
def impl(self, x, y):
return y
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(y)s;" % locals()
class SquareDiff(BinaryScalarOp):
def impl(self, x, y):
diff = (x - y)
return diff * diff
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s - %(y)s; %(z)s *= %(z)s;" % locals()
class Neg(UnaryScalarOp):
def impl(self, x):
return -x
def grad(self, x, gz):
def grad(self, (x, ), (gz, )):
return -gz
def c_impl(self, x, z):
return "%(z)s = -%(x)s;"
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = -%(x)s;" % locals()
class Inv(UnaryScalarOp):
def impl(self, x):
return 1 / x
def grad(self, x, gz):
def grad(self, (x, ), (gz, )):
return -gz / (x*x)
def c_impl(self, x, z):
return "%(z)s = 1 / %(x)s;"
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = 1 / %(x)s;" % locals()
class Log(UnaryScalarOp):
def impl(self, x):
return math.log(x)
def c_impl(self, x, z):
return "%(z)s = log(%(x)s);"
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = log(%(x)s);" % locals()
class Exp(UnaryScalarOp):
def impl(self, x):
return math.exp(x)
def c_impl(self, x, z):
return "%(z)s = exp(%(x)s);"
def c_code(self, (x, ), (z, ), sub):
return "%(z)s = exp(%(x)s);" % locals()
# class Sigmoid(UnaryComposite):
......
......@@ -136,8 +136,12 @@ class _Op(BaseTensorOp):
onames = utils.from_return_values(onames)
return [inames, onames]
def c_code(self):
return self.c_impl(self.inputs, self.outputs)
def c_code(self, input_names, output_names, sub):
sub = dict(sub)
icvn, ocvn = self.c_var_names()
for real, tosub in zip(input_names + output_names, icvn + ocvn):
sub[tosub] = real
return self.c_impl(self.inputs, self.outputs) % sub
def c_impl(self, inputs, outputs):
raise AbstractFunctionError()
......@@ -759,7 +763,7 @@ class Gemm(_Op):
return blas.ldflags()
def c_var_names(self):
return [['_z', '_a', '_x', '_y', '_b'], ['_zout']]
def c_validate_update(self):
def c_validate_update(self, (_z, _a, _x, _y, _b), (_zout, ), sub):
return """
if (%(_zout)s)
{
......@@ -770,10 +774,10 @@ class Gemm(_Op):
%(_zout)s = %(_z)s;
Py_INCREF(%(_zout)s);
}
"""
def c_validate_update_cleanup(self):
""" % locals()
def c_validate_update_cleanup(self, ignore, _ignore, __ignore):
return ""
def c_code(self):
def c_code(self, (_z, _a, _x, _y, _b), (_zout, ), sub):
return """
int unit = 0;
......@@ -913,7 +917,7 @@ class Gemm(_Op):
break;
}
"""
""" % dict(locals(), **sub)
gemm = gof.op.constructor(Gemm)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论