elemwise is functional

上级 b51535d9
from copy import copy from copy import copy
from gof import Op
from gof.utils import AbstractFunctionError
from tensor import _Op
class BasicElemwise(Op): class Elemwise(_Op):
def var_desc(self): def var_desc(self):
raise AbstractFunctionError() raise AbstractFunctionError()
...@@ -42,9 +45,56 @@ class BasicElemwise(Op): ...@@ -42,9 +45,56 @@ class BasicElemwise(Op):
def c_code_finalize(self): def c_code_finalize(self):
raise AbstractFunctionError() raise AbstractFunctionError()
def c_code(self): def __c_code(self):
try: before = self.c_code_init()
except AbstractFunctionError: before = ""
during = self.c_code_foreach()
try: after = self.c_code_finalize()
except AbstractFunctionError: after = ""
idesc, odesc = self.var_desc()
input_loop_vars = [x[0] for x in idesc if x[1]]
output_loop_vars = [x[0] for x in odesc if x[1]]
aliases = {}
try:
dmap = self.destroy_map()
except:
dmap = {}
if dmap != {}:
for (oname, isloopvar), output in zip(odesc, self.outputs):
if isloopvar:
for input in dmap.get(output, []):
iname = idesc[self.inputs.index(input)][0]
aliases["%("+oname+")s"] = "%("+iname+")s"
output_loop_vars.remove(oname)
ret = elemwise_wrap(before, during, after,
["%("+v+")s" for v in input_loop_vars],
["%("+v+")s" for v in output_loop_vars],
aliases)
return ret
def c_validate_update(self):
(valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code()
return valupd
def c_validate_update_cleanup(self):
(valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code()
return valupd_cleanup
def c_code(self):
(valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code()
return code
def c_code_cleanup(self):
(valupd, valupd_cleanup), (code, code_cleanup) = self.__c_code()
return code_cleanup
@classmethod @classmethod
def inplace_version(cls): def inplace_version(cls):
class Ret(cls, Destroyer): class Ret(cls, Destroyer):
...@@ -77,32 +127,32 @@ class Looper: ...@@ -77,32 +127,32 @@ class Looper:
return ret return ret
general_looper = Looper(declare = """ general_looper = Looper(declare = """
PyArrayIterObject* iter_%(name)s = NULL; PyArrayIterObject* %(iter)s_%(name)s = NULL;
dtype_%(name)s%(type_suffix)s %(name)s_i; //dtype_%(name)s& %(name)s_i;
""", """,
init = """ init = """
iter_%(name)s = (PyArrayIterObject*)PyArray_IterNew((PyObject*)%(name)s); %(iter)s_%(name)s = (PyArrayIterObject*)PyArray_IterNew((PyObject*)%(name)s);
if (iter_%(name)s == NULL) { if (%(iter)s_%(name)s == NULL) {
PyErr_SetString(PyExc_ValueError, "Could not make an iterator over variable %(name)s."); PyErr_SetString(PyExc_ValueError, "Could not make an iterator over variable %(name)s.");
%%(fail)s; %%(fail)s;
} }
""", """,
acquire = "%(name)s_i = *((dtype_%(name)s*)iter_%(name)s->dataptr);", acquire = "dtype_%(name)s& %(name)s_i = *((dtype_%(name)s*)%(iter)s_%(name)s->dataptr);",
next = "PyArray_ITER_NEXT(iter_%(name)s);", next = "PyArray_ITER_NEXT(%(iter)s_%(name)s);",
cleanup = "if (iter_%(name)s) Py_DECREF(iter_%(name)s);") cleanup = "if (%(iter)s_%(name)s) Py_DECREF(%(iter)s_%(name)s);")
contiguous_looper = Looper(declare = """ contiguous_looper = Looper(declare = """
dtype_%(name)s* __restrict__ iter_%(name)s = NULL; dtype_%(name)s* %(iter)s_%(name)s = NULL;
dtype_%(name)s%(type_suffix)s %(name)s_i; //dtype_%(name)s %(name)s_i;
""", """,
init = "iter_%(name)s = (dtype_%(name)s*)PyArray_DATA(%(name)s);", init = "%(iter)s_%(name)s = (dtype_%(name)s*)PyArray_DATA(%(name)s);",
acquire = "%(name)s_i = *iter_%(name)s;", acquire = "dtype_%(name)s& %(name)s_i = *%(iter)s_%(name)s;",
next = "iter_%(name)s++;", next = "%(iter)s_%(name)s++;",
cleanup = "") cleanup = "")
alias_looper = Looper(declare = "dtype_%(name)s%(type_suffix)s %(name)s_i;", alias_looper = Looper(declare = "", #dtype_%(name)s%(type_suffix)s %(name)s_i;",
init = "", init = "",
acquire = "%(name)s_i = %(other)s_i;", acquire = "dtype_%(name)s& %(name)s_i = %(other)s_i;",
next = "", next = "",
cleanup = "") cleanup = "")
...@@ -116,7 +166,6 @@ def elemwise_loopcode(loopcode, loopers): ...@@ -116,7 +166,6 @@ def elemwise_loopcode(loopcode, loopers):
template[block_type] = make_block(loopers, block_type) template[block_type] = make_block(loopers, block_type)
code = """ code = """
%(declare)s
%(init)s %(init)s
while (__SIZE--) { while (__SIZE--) {
%(acquire)s %(acquire)s
...@@ -129,106 +178,120 @@ def elemwise_loopcode(loopcode, loopers): ...@@ -129,106 +178,120 @@ def elemwise_loopcode(loopcode, loopers):
%(cleanup)s %(cleanup)s
""" % template """ % template
return code, cleanup return template['declare'], code, cleanup
def elemwise_wrap(beforeloop, inloop, afterloop, input_loop_vars, output_loop_vars, aliases): def elemwise_wrap(beforeloop, inloop, afterloop, input_loop_vars, output_loop_vars, aliases):
if len(input_loop_vars) > 1: if len(input_loop_vars) == 0:
validate = """ raise Exception("You need at least one input loop variable.")
npy_intp nd = %(first_loop_var)s->nd;
npy_intp* dims = %(first_loop_var)s->dimensions; validate = """
npy_intp* dims2; npy_intp nd = %(first_loop_var)s->nd;
""" % dict(first_loop_var = input_loop_vars[0]) npy_intp* dims = %(first_loop_var)s->dimensions;
npy_intp* dims2;
for other_loop_var in input_loop_vars[1:]: """ % dict(first_loop_var = input_loop_vars[0])
validate += """
if (%(other_loop_var)s->nd != nd) { for other_loop_var in input_loop_vars[1:]:
PyErr_SetString(PyExc_ValueError, \"The number of dimensions of the inputs do not match.\"); validate += """
%%(fail)s if (%(other_loop_var)s->nd != nd) {
} PyErr_SetString(PyExc_ValueError, \"The number of dimensions of the inputs do not match.\");
dims2 = %(other_loop_var)s->dimensions; %%(fail)s
for (int i = 0; i < nd; i++) { }
if (dims2[i] != dims[i]) { dims2 = %(other_loop_var)s->dimensions;
PyErr_SetString(PyExc_ValueError, \"The dimensions of the inputs do not match.\"); for (int i = 0; i < nd; i++) {
%%(fail)s; if (dims2[i] != dims[i]) {
} PyErr_SetString(PyExc_ValueError, \"The dimensions of the inputs do not match.\");
%%(fail)s;
} }
""" % dict(other_loop_var = other_loop_var) }
else: """ % dict(other_loop_var = other_loop_var)
validate = ""
update = "" update = ""
for output_loop_var in output_loop_vars: for output_loop_var in output_loop_vars:
update += """ update += """
if (!%(output_loop_var)s) { if (!%(output_loop_var)s) {
%(output_loop_var)s = PyArray_EMPTY(nd, dims, type_num_%(output_loop_var)s, 0); %(output_loop_var)s = (PyArrayObject*)PyArray_EMPTY(nd, dims, type_num_%(output_loop_var)s, 0);
} }
else { else {
PyArray_Dims new_dims(nd, dims); PyArray_Dims new_dims;
PyObject* success = PyArray_Resize(%(output_loop_var)s, &new_dims, 0); new_dims.len = nd;
new_dims.ptr = dims;
PyObject* success = PyArray_Resize(%(output_loop_var)s, &new_dims, 0, PyArray_CORDER);
if (!success) { if (!success) {
// If we can't resize the ndarray we have we can allocate a new one. // If we can't resize the ndarray we have we can allocate a new one.
PyErr_Clear(); PyErr_Clear();
Py_XDECREF(%(output_loop_var)s); Py_XDECREF(%(output_loop_var)s);
%(output_loop_var)s = PyArray_EMPTY(nd, dims, type_num_%(output_loop_var)s, 0); %(output_loop_var)s = (PyArrayObject*)PyArray_EMPTY(nd, dims, type_num_%(output_loop_var)s, 0);
} }
} }
if (!%(output_loop_var)s) { if (!%(output_loop_var)s) {
%%(fail)s %%(fail)s
} }
""" """ % dict(output_loop_var = output_loop_var)
validate_update = validate + update validate_update = validate + update
validate_update_cleanup = ""
general_loopers = [] general_loopers = []
general_loopers += [general_looper.fill(name = loop_var, type_suffix = "") for loop_var in input_loop_vars] general_loopers += [general_looper.fill(name = loop_var, iter = "iter_general_") for loop_var in input_loop_vars + output_loop_vars]
general_loopers += [general_looper.fill(name = loop_var, type_suffix = "&") for loop_var in output_loop_vars] general_loopers += [alias_looper.fill(name = v1, other = v2, iter = "iter_general_") for v1, v2 in aliases.items()]
general_loopers += [general_looper.fill(name = loop_var, type_suffix = "&") for v1, v2 in aliases.items()]
contig_loopers = []
contig_loopers += [contiguous_looper.fill(name = loop_var, iter = "iter_contig_") for loop_var in input_loop_vars + output_loop_vars]
contig_loopers += [alias_looper.fill(name = v1, other = v2, iter = "iter_contig_") for v1, v2 in aliases.items()]
general_decl, general_loop, general_loop_cleanup = elemwise_loopcode(inloop, general_loopers)
contig_decl, contig_loop, contig_loop_cleanup = elemwise_loopcode(inloop, contig_loopers)
contiguity_check = "".join(["all_c_contiguous &= PyArray_ISCARRAY(%(loop_var)s);\n" \
all_loop_vars = loop_vars + writable_loop_vars "all_f_contiguous &= PyArray_ISFARRAY(%(loop_var)s);\n" \
v1 = (loop_vars + writable_loop_vars)[0] % dict(loop_var = loop_var)
template = dict( for loop_var in input_loop_vars + output_loop_vars])
v1 = v1,
check_init = check_init % dict(loop_var = v1), v1 = input_loop_vars[0]
check = "\n".join([check % dict(loop_var = loop_var) for loop_var in loop_vars + writable_loop_vars if loop_var is not v1]),
beforeloop = beforeloop,
general_loop = elemwise_loopcode(
inloop,
general_init, general_next, general_acquire, general_cleanup,
loop_vars, writable_loop_vars, aliases),
contiguous_loop = elemwise_loopcode(
inloop,
contiguous_init, contiguous_next, contiguous_acquire, contiguous_cleanup,
loop_vars, writable_loop_vars, aliases),
contiguity_check = "".join(["all_c_contiguous &= PyArray_ISCARRAY(%(loop_var)s);\n" \
"all_f_contiguous &= PyArray_ISFARRAY(%(loop_var)s);\n" \
% dict(loop_var = loop_var)
for loop_var in all_loop_vars]),
afterloop = afterloop)
code = """ code = """
{ %(general_decl)s
%(check_init)s %(contig_decl)s
%(check)s npy_intp __SIZE = PyArray_SIZE(%(v1)s);
}
npy_intp __elemwise_size = PyArray_SIZE(%(v1)s);
%(beforeloop)s %(beforeloop)s
bool all_c_contiguous = 1; bool all_c_contiguous = 1;
bool all_f_contiguous = 1; bool all_f_contiguous = 1;
%(contiguity_check)s %(contiguity_check)s
if (all_c_contiguous || all_f_contiguous) { if (all_c_contiguous || all_f_contiguous) {
%(contiguous_loop)s %(contig_loop)s
} }
else { else {
%(general_loop)s %(general_loop)s
} }
%(afterloop)s %(afterloop)s
""" % template """ % locals()
return code cleanup = """
if (all_c_contiguous || all_f_contiguous) {
%(contig_loop_cleanup)s
}
else {
%(general_loop_cleanup)s
}
""" % locals()
# print """
# VALIDATE_UPDATE
# %(validate_update)s
# CODE
# %(code)s
# CLEANUP
# %(cleanup)s
# VU_CLEANUP
# %(validate_update_cleanup)s
# """ % locals()
return (validate_update, validate_update_cleanup), (code, cleanup)
...@@ -276,351 +339,351 @@ def elemwise_wrap(beforeloop, inloop, afterloop, input_loop_vars, output_loop_va ...@@ -276,351 +339,351 @@ def elemwise_wrap(beforeloop, inloop, afterloop, input_loop_vars, output_loop_va
import core # import core
def elemwise_loopcode(loopcode, init_template, next_template, acquire_template, cleanup_template, loop_vars, writable_loop_vars, aliases): # def elemwise_loopcode(loopcode, init_template, next_template, acquire_template, cleanup_template, loop_vars, writable_loop_vars, aliases):
all_loop_vars = loop_vars + writable_loop_vars # all_loop_vars = loop_vars + writable_loop_vars
template = dict( # template = dict(
init = "".join([init_template % dict(loop_var = loop_var) for loop_var in all_loop_vars]), # init = "".join([init_template % dict(loop_var = loop_var) for loop_var in all_loop_vars]),
next = "".join([next_template % dict(loop_var = loop_var) for loop_var in all_loop_vars]), # next = "".join([next_template % dict(loop_var = loop_var) for loop_var in all_loop_vars]),
cleanup = "".join([cleanup_template % dict(loop_var = loop_var) for loop_var in all_loop_vars]), # cleanup = "".join([cleanup_template % dict(loop_var = loop_var) for loop_var in all_loop_vars]),
idefs = "".join([("%(loop_var)s_dtype %(loop_var)s_i = " + acquire_template + ";\n") # idefs = "".join([("%(loop_var)s_dtype %(loop_var)s_i = " + acquire_template + ";\n")
% dict(loop_var = loop_var) for loop_var in loop_vars]), # % dict(loop_var = loop_var) for loop_var in loop_vars]),
odefs = "".join([("%(loop_var)s_dtype& %(loop_var)s_i = " + acquire_template + ";\n") # odefs = "".join([("%(loop_var)s_dtype& %(loop_var)s_i = " + acquire_template + ";\n")
% dict(loop_var = loop_var) for loop_var in writable_loop_vars]), # % dict(loop_var = loop_var) for loop_var in writable_loop_vars]),
aliasdefs = "".join(["%(v1)s_dtype %(v1)s_i = %(v2)s_i;\n" % dict(v1=v1, v2=v2) # aliasdefs = "".join(["%(v1)s_dtype %(v1)s_i = %(v2)s_i;\n" % dict(v1=v1, v2=v2)
for v1, v2 in aliases.items()]), # for v1, v2 in aliases.items()]),
loopcode = loopcode # loopcode = loopcode
) # )
code = """ # code = """
%(init)s # %(init)s
while (__elemwise_size--) { # while (__elemwise_size--) {
%(idefs)s # %(idefs)s
%(odefs)s # %(odefs)s
%(aliasdefs)s # %(aliasdefs)s
%(loopcode)s # %(loopcode)s
%(next)s # %(next)s
} # }
%(cleanup)s # %(cleanup)s
""" % template # """ % template
return code # return code
def elemwise_wrap(beforeloop, inloop, afterloop, loop_vars, writable_loop_vars, aliases): # def elemwise_wrap(beforeloop, inloop, afterloop, loop_vars, writable_loop_vars, aliases):
check_init = """ # check_init = """
npy_intp nd = %(loop_var)s->nd; # npy_intp nd = %(loop_var)s->nd;
npy_intp* dims = %(loop_var)s->dimensions; # npy_intp* dims = %(loop_var)s->dimensions;
npy_intp* dims2; # npy_intp* dims2;
""" # """
check = """ # check = """
if (%(loop_var)s->nd != nd) { # if (%(loop_var)s->nd != nd) {
PyErr_SetString(PyExc_ValueError, \"The number of dimensions of the inputs do not match.\"); # PyErr_SetString(PyExc_ValueError, \"The number of dimensions of the inputs do not match.\");
} # }
dims2 = %(loop_var)s->dimensions; # dims2 = %(loop_var)s->dimensions;
for (int i = 0; i < nd; i++) { # for (int i = 0; i < nd; i++) {
if (dims2[i] != dims[i]) { # if (dims2[i] != dims[i]) {
PyErr_SetString(PyExc_ValueError, \"The dimensions of the inputs do not match.\"); # PyErr_SetString(PyExc_ValueError, \"The dimensions of the inputs do not match.\");
return 1; # return 1;
} # }
} # }
""" # """
general_init = "PyArrayIterObject* %(loop_var)s_iter = (PyArrayIterObject*)PyArray_IterNew((PyObject*)%(loop_var)s);\n" # general_init = "PyArrayIterObject* %(loop_var)s_iter = (PyArrayIterObject*)PyArray_IterNew((PyObject*)%(loop_var)s);\n"
# "if (%(loop_var)s_iter == NULL) {\n" \ # # "if (%(loop_var)s_iter == NULL) {\n" \
# " PyErr_SetString(PyExc_ValueError, \"Could not make an iterator over variable %(loop_var)s.\");\n" \ # # " PyErr_SetString(PyExc_ValueError, \"Could not make an iterator over variable %(loop_var)s.\");\n" \
# " return 1;\n" \ # # " return 1;\n" \
# "}\n" # # "}\n"
general_next = "PyArray_ITER_NEXT(%(loop_var)s_iter);\n" # general_next = "PyArray_ITER_NEXT(%(loop_var)s_iter);\n"
general_acquire = "*((%(loop_var)s_dtype*)%(loop_var)s_iter->dataptr)"; # general_acquire = "*((%(loop_var)s_dtype*)%(loop_var)s_iter->dataptr)";
general_cleanup = "if (%(loop_var)s_iter) Py_DECREF(%(loop_var)s_iter);\n"; # general_cleanup = "if (%(loop_var)s_iter) Py_DECREF(%(loop_var)s_iter);\n";
contiguous_init = "%(loop_var)s_dtype* __restrict__ %(loop_var)s_iter = (%(loop_var)s_dtype*)PyArray_DATA(%(loop_var)s);\n" # contiguous_init = "%(loop_var)s_dtype* __restrict__ %(loop_var)s_iter = (%(loop_var)s_dtype*)PyArray_DATA(%(loop_var)s);\n"
contiguous_next = "%(loop_var)s_iter++;\n" # contiguous_next = "%(loop_var)s_iter++;\n"
contiguous_acquire = "*%(loop_var)s_iter" # contiguous_acquire = "*%(loop_var)s_iter"
contiguous_cleanup = "" # contiguous_cleanup = ""
all_loop_vars = loop_vars + writable_loop_vars # all_loop_vars = loop_vars + writable_loop_vars
v1 = (loop_vars + writable_loop_vars)[0] # v1 = (loop_vars + writable_loop_vars)[0]
template = dict( # template = dict(
v1 = v1, # v1 = v1,
check_init = check_init % dict(loop_var = v1), # check_init = check_init % dict(loop_var = v1),
check = "\n".join([check % dict(loop_var = loop_var) for loop_var in loop_vars + writable_loop_vars if loop_var is not v1]), # check = "\n".join([check % dict(loop_var = loop_var) for loop_var in loop_vars + writable_loop_vars if loop_var is not v1]),
beforeloop = beforeloop, # beforeloop = beforeloop,
general_loop = elemwise_loopcode( # general_loop = elemwise_loopcode(
inloop, # inloop,
general_init, general_next, general_acquire, general_cleanup, # general_init, general_next, general_acquire, general_cleanup,
loop_vars, writable_loop_vars, aliases), # loop_vars, writable_loop_vars, aliases),
contiguous_loop = elemwise_loopcode( # contiguous_loop = elemwise_loopcode(
inloop, # inloop,
contiguous_init, contiguous_next, contiguous_acquire, contiguous_cleanup, # contiguous_init, contiguous_next, contiguous_acquire, contiguous_cleanup,
loop_vars, writable_loop_vars, aliases), # loop_vars, writable_loop_vars, aliases),
contiguity_check = "".join(["all_c_contiguous &= PyArray_ISCARRAY(%(loop_var)s);\n" \ # contiguity_check = "".join(["all_c_contiguous &= PyArray_ISCARRAY(%(loop_var)s);\n" \
"all_f_contiguous &= PyArray_ISFARRAY(%(loop_var)s);\n" \ # "all_f_contiguous &= PyArray_ISFARRAY(%(loop_var)s);\n" \
% dict(loop_var = loop_var) # % dict(loop_var = loop_var)
for loop_var in all_loop_vars]), # for loop_var in all_loop_vars]),
afterloop = afterloop) # afterloop = afterloop)
code = """ # code = """
{ # {
%(check_init)s # %(check_init)s
%(check)s # %(check)s
} # }
npy_intp __elemwise_size = PyArray_SIZE(%(v1)s); # npy_intp __elemwise_size = PyArray_SIZE(%(v1)s);
%(beforeloop)s # %(beforeloop)s
bool all_c_contiguous = 1; # bool all_c_contiguous = 1;
bool all_f_contiguous = 1; # bool all_f_contiguous = 1;
%(contiguity_check)s # %(contiguity_check)s
if (all_c_contiguous || all_f_contiguous) { # if (all_c_contiguous || all_f_contiguous) {
%(contiguous_loop)s # %(contiguous_loop)s
} # }
else { # else {
%(general_loop)s # %(general_loop)s
} # }
%(afterloop)s # %(afterloop)s
""" % template # """ % template
return code # return code
class elemwise(omega_op): # class elemwise(omega_op):
@staticmethod # @staticmethod
def __clsinit__(cls, name, bases, dct): # def __clsinit__(cls, name, bases, dct):
for fname in ['c_init', 'c_foreach', 'c_finalize']: # for fname in ['c_init', 'c_foreach', 'c_finalize']:
gof.make_static(cls, fname) # gof.make_static(cls, fname)
# make impl, grad, etc. static methods # # make impl, grad, etc. static methods
omega_op.__clsinit__(cls, name, bases, dct) # omega_op.__clsinit__(cls, name, bases, dct)
def TOGO_specs(self): # def TOGO_specs(self):
try: # try:
return self.specs(*[input.spec for input in self.inputs]) # return self.specs(*[input.spec for input in self.inputs])
except NotImplementedError: # except NotImplementedError:
inames, onames = self.variable_names() # inames, onames = self.variable_names()
linames, lonames = self.loop_variables() # linames, lonames = self.loop_variables()
for oname in onames: # for oname in onames:
if oname not in lonames: # if oname not in lonames:
raise Exception("cannot infer a specification automatically for variable " \ # raise Exception("cannot infer a specification automatically for variable " \
"%s.%s because it is not part of the elementwise loop - "\ # "%s.%s because it is not part of the elementwise loop - "\
"please override the specs method" % (self.__class__.__name__, oname)) # "please override the specs method" % (self.__class__.__name__, oname))
shape, dtype = None, None # shape, dtype = None, None
for iname, input in zip(inames, self.inputs): # for iname, input in zip(inames, self.inputs):
if iname in linames: # if iname in linames:
if input.spec: # if input.spec:
shape = input.spec[2] # shape = input.spec[2]
if shape is None: # if shape is None:
raise Exception("cannot infer a specification automatically for output variables " \ # raise Exception("cannot infer a specification automatically for output variables " \
"because there is no input variable in the loop from which to get the shape, "\ # "because there is no input variable in the loop from which to get the shape, "\
"or their shape is unknown") # "or their shape is unknown")
try: # try:
dtype = core.upcast(*[input.spec[1] # dtype = core.upcast(*[input.spec[1]
for iname, input in zip(inames, self.inputs) # for iname, input in zip(inames, self.inputs)
if input.spec[0] is numpy.ndarray]) # if input.spec[0] is numpy.ndarray])
except IndexError: # except IndexError:
raise Exception("not all numpy inputs are specified") # raise Exception("not all numpy inputs are specified")
dmap = self.destroy_map() # dmap = self.destroy_map()
res = [] # res = []
for output in self.outputs: # for output in self.outputs:
inplace_inputs = dmap.get(output, []) # inplace_inputs = dmap.get(output, [])
if inplace_inputs: # if inplace_inputs:
assert len(inplace_inputs) == 1 # assert len(inplace_inputs) == 1
res.append(inplace_inputs[0].spec) # res.append(inplace_inputs[0].spec)
else: # else:
res.append((numpy.ndarray, dtype, shape)) # res.append((numpy.ndarray, dtype, shape))
if self.nout == 1: # if self.nout == 1:
return res[0] # return res[0]
else: # else:
return res # return res
def TOGO_alloc(self, except_list = []): # def TOGO_alloc(self, except_list = []):
dmap = self.destroy_map() # dmap = self.destroy_map()
vmap = self.view_map() # vmap = self.view_map()
gof.PythonOp.alloc(self, except_list = except_list + dmap.keys()) # gof.PythonOp.alloc(self, except_list = except_list + dmap.keys())
for output, (input, ) in dmap.items(): # for output, (input, ) in dmap.items():
if output not in except_list: # if output not in except_list:
output.set_value(input.data) # output.set_value(input.data)
def refresh_shape(self): # def refresh_shape(self):
"""Make the output have the right stuff""" # """Make the output have the right stuff"""
if len(self.outputs) > 1: # if len(self.outputs) > 1:
raise NotImplementedError('multiple outputs') # raise NotImplementedError('multiple outputs')
dmap = self.destroy_map() # dmap = self.destroy_map()
vmap = self.view_map() # vmap = self.view_map()
if dmap != {} or vmap != {}: # if dmap != {} or vmap != {}:
raise NotImplementedError('destroys or views confuse things', # raise NotImplementedError('destroys or views confuse things',
self.__class__, dmap, vmap) # self.__class__, dmap, vmap)
# take the shape of the leftmost loop_variable input # # take the shape of the leftmost loop_variable input
inames, onames = self.variable_names() # inames, onames = self.variable_names()
linames, lonames = self.loop_variables() # linames, lonames = self.loop_variables()
unknown_output_names = [n for n in onames if n not in lonames] # unknown_output_names = [n for n in onames if n not in lonames]
if len(unknown_output_names): # if len(unknown_output_names):
raise Exception("cannot infer a specification automatically for variables " \ # raise Exception("cannot infer a specification automatically for variables " \
"%s.{%s} because it is not part of the elementwise loop - "\ # "%s.{%s} because it is not part of the elementwise loop - "\
"please override the specs method" % # "please override the specs method" %
(self.__class__.__name__, str(unknown_output_names))) # (self.__class__.__name__, str(unknown_output_names)))
# shape is leftmost loop-variable input # # shape is leftmost loop-variable input
input_loop_shapes = [i.shape for n,i in zip(inames, self.inputs) if n in linames] # input_loop_shapes = [i.shape for n,i in zip(inames, self.inputs) if n in linames]
if len(input_loop_shapes) == 0: # if len(input_loop_shapes) == 0:
raise Exception("cannot infer a specification automatically for output variables " \ # raise Exception("cannot infer a specification automatically for output variables " \
"because there is no input loop variable ") # "because there is no input loop variable ")
for i in xrange(1,len(input_loop_shapes)): # for i in xrange(1,len(input_loop_shapes)):
if input_loop_shapes[i] != input_loop_shapes[0]: # if input_loop_shapes[i] != input_loop_shapes[0]:
raise Exception("Input loop variables have different shapes", self.__class__) # raise Exception("Input loop variables have different shapes", self.__class__)
return input_loop_shapes[0] # return input_loop_shapes[0]
def refresh_dtype(self): # def refresh_dtype(self):
return core.upcast(*[i.dtype for i in self.inputs if hasattr(i, 'dtype')]) # return core.upcast(*[i.dtype for i in self.inputs if hasattr(i, 'dtype')])
@classmethod # @classmethod
def set_impl(cls, impl): # def set_impl(cls, impl):
gof.lib.make_static(cls, 'impl') # gof.lib.make_static(cls, 'impl')
@staticmethod # @staticmethod
def is_loop_var(name): # def is_loop_var(name):
return name.endswith("_i") # return name.endswith("_i")
@staticmethod # @staticmethod
def extract_name(name): # def extract_name(name):
if name.endswith("_i"): # if name.endswith("_i"):
return name[:-2] # return name[:-2]
else: # else:
return name # return name
@classmethod # @classmethod
def variable_names(cls): # def variable_names(cls):
(inames, onames), _1, _2, _3 = inspect.getargspec(cls.c_foreach) # (inames, onames), _1, _2, _3 = inspect.getargspec(cls.c_foreach)
spec = ([cls.extract_name(name) for name in inames], # spec = ([cls.extract_name(name) for name in inames],
[cls.extract_name(name) for name in onames]) # [cls.extract_name(name) for name in onames])
if cls.c_init is not elemwise.c_init: # if cls.c_init is not elemwise.c_init:
(inames, onames), _1, _2, _3 = inspect.getargspec(cls.c_init) # (inames, onames), _1, _2, _3 = inspect.getargspec(cls.c_init)
assert spec == (list(inames), list(onames)) # assert spec == (list(inames), list(onames))
if cls.c_finalize is not elemwise.c_finalize: # if cls.c_finalize is not elemwise.c_finalize:
(inames, onames), _1, _2, _3 = inspect.getargspec(cls.c_finalize) # (inames, onames), _1, _2, _3 = inspect.getargspec(cls.c_finalize)
assert spec == (list(inames), list(onames)) # assert spec == (list(inames), list(onames))
return spec # return spec
@classmethod # @classmethod
def loop_variables(cls): # def loop_variables(cls):
(inames, onames), _1, _2, _3 = inspect.getargspec(cls.c_foreach) # (inames, onames), _1, _2, _3 = inspect.getargspec(cls.c_foreach)
return ([cls.extract_name(name) for name in inames if cls.is_loop_var(name)], # return ([cls.extract_name(name) for name in inames if cls.is_loop_var(name)],
[cls.extract_name(name) for name in onames if cls.is_loop_var(name)]) # [cls.extract_name(name) for name in onames if cls.is_loop_var(name)])
def _c_init(self): # def _c_init(self):
return self.c_init(self.inputs, self.outputs) # return self.c_init(self.inputs, self.outputs)
def c_init(inputs, outputs): # def c_init(inputs, outputs):
return "" # return ""
def _c_foreach(self): # def _c_foreach(self):
return self.c_foreach(self.inputs, self.outputs) # return self.c_foreach(self.inputs, self.outputs)
def c_foreach(inputs, outputs): # def c_foreach(inputs, outputs):
raise NotImplementedError() # raise NotImplementedError()
def _c_finalize(self): # def _c_finalize(self):
return self.c_finalize(self.inputs, self.outputs) # return self.c_finalize(self.inputs, self.outputs)
def c_finalize(inputs, outputs): # def c_finalize(inputs, outputs):
return "" # return ""
def c_code(self, converters = None, elemwise_wrap = elemwise_wrap): # def c_code(self, converters = None, elemwise_wrap = elemwise_wrap):
def mangle(name): # def mangle(name):
if name.endswith("_i"): # if name.endswith("_i"):
return name[:-2] # return name[:-2]
else: # else:
return name # return name
try: # try:
self._c_impl() # self._c_impl()
raise Exception("c_impl is not used by elemwise ops - define behavior in c_foreach instead") # raise Exception("c_impl is not used by elemwise ops - define behavior in c_foreach instead")
except NotImplementedError: # except NotImplementedError:
pass # pass
before = self._c_init() # before = self._c_init()
during = self._c_foreach() # during = self._c_foreach()
after = self._c_finalize() # after = self._c_finalize()
(inames, onames) = self.variable_names() # (inames, onames) = self.variable_names()
(linames, lonames) = self.loop_variables() # (linames, lonames) = self.loop_variables()
aliases = {} # aliases = {}
dmap = self.destroy_map() # dmap = self.destroy_map()
if dmap != {}: # if dmap != {}:
for oname, output in zip(onames, self.outputs): # for oname, output in zip(onames, self.outputs):
if oname in lonames: # if oname in lonames:
for input in dmap.get(output, []): # for input in dmap.get(output, []):
aliases[inames[self.inputs.index(input)]] = oname # aliases[inames[self.inputs.index(input)]] = oname
behavior = elemwise_wrap(before, during, after, # behavior = elemwise_wrap(before, during, after,
[name for name in linames if name not in aliases], # [name for name in linames if name not in aliases],
lonames, # lonames,
aliases) # aliases)
return cgen(self.__class__.__name__, behavior, inames + onames, self.inputs + self.outputs, converters) # return cgen(self.__class__.__name__, behavior, inames + onames, self.inputs + self.outputs, converters)
@classmethod # @classmethod
def inplace_version(cls, dmap = {0: 0}): # def inplace_version(cls, dmap = {0: 0}):
inames, onames = cls.variable_names() # inames, onames = cls.variable_names()
linames, lonames = cls.loop_variables() # linames, lonames = cls.loop_variables()
for i, oname in enumerate(onames): # for i, oname in enumerate(onames):
if i in dmap: # if i in dmap:
assert oname in lonames # assert oname in lonames
class C(cls): # class C(cls):
def destroy_map(self): # def destroy_map(self):
assert cls.destroy_map(self) == {} # assert cls.destroy_map(self) == {}
ret = {} # ret = {}
for output, input in dmap.items(): # for output, input in dmap.items():
ret[self.outputs[output]] = [self.inputs[input]] # ret[self.outputs[output]] = [self.inputs[input]]
return ret # return ret
def _impl(self): # def _impl(self):
if self.impl is not cls.impl: # if self.impl is not cls.impl:
# If the user sets his own inplace operation, we use it # # If the user sets his own inplace operation, we use it
return cls._impl(self) # return cls._impl(self)
else: # else:
res = cls._impl(self) # res = cls._impl(self)
if isinstance(res, (list, tuple)): # if isinstance(res, (list, tuple)):
res = copy.copy(res) # res = copy.copy(res)
else: # else:
res = [res] # res = [res]
for output, input in dmap.items(): # for output, input in dmap.items():
# The default implementation returned a copy, so we just # # The default implementation returned a copy, so we just
# overwrite the original input with the contents of that copy # # overwrite the original input with the contents of that copy
# This is not meant to be efficient, only correct. # # This is not meant to be efficient, only correct.
# # #
# TODO: change this to use set_value_inplace # # TODO: change this to use set_value_inplace
a = self.inputs[input].data # a = self.inputs[input].data
a[:] = res[output] # a[:] = res[output]
res[output] = a # res[output] = a
if len(res) == 1: # if len(res) == 1:
return res[0] # return res[0]
else: # else:
return res # return res
if dmap == {0:0}: # if dmap == {0:0}:
C.__name__ = cls.__name__ + "_inplace" % dmap # C.__name__ = cls.__name__ + "_inplace" % dmap
else: # else:
C.__name__ = cls.__name__ + "_inplace%s" % dmap # C.__name__ = cls.__name__ + "_inplace%s" % dmap
return C # return C
...@@ -338,7 +338,7 @@ class CLinker(Linker): ...@@ -338,7 +338,7 @@ class CLinker(Linker):
except AbstractFunctionError: except AbstractFunctionError:
validate_behavior = "" validate_behavior = ""
try: validate_behavior = op.c_validate_update_cleanup() try: validate_cleanup = op.c_validate_update_cleanup()
except AbstractFunctionError: except AbstractFunctionError:
validate_cleanup = "" validate_cleanup = ""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论