提交 f982cbd7 authored 作者: Frederic's avatar Frederic

Elemwise now special-cases the generated code when all inputs are contiguous.

This does not work yet when the inputs are Fortran-contiguous, as the outputs are always C-contiguous for now.
上级 6ab3dddc
......@@ -1100,8 +1100,9 @@ class Elemwise(Op):
# If all inputs and outputs are contiguous
# and the scalar op define optimized code for that case
# use it!
# use it! The scalar_op needs to check the broadcast flags itself.
if all([o.ndim >= 1 for o in node.outputs]):
contig = None
try:
contig = self.scalar_op.c_code_contiguous(
node,
......@@ -1109,19 +1110,54 @@ class Elemwise(Op):
_inames,
onames,
sub)
# PyArray_ISONESEGMENT(arr)
# return true if arr is fortran or c contiguous.
cond = ' && '.join(["PyArray_ISONESEGMENT(%s)" % arr
for arr in _inames + onames])
except theano.gof.utils.MethodNotDefined:
# Try to make one generic version; this will help the
# compiler vectorize the code, as there won't be as
# many pointers and the strides will be hard-coded.
if all([io.broadcastable == node.outputs[0].broadcastable or
all(io.broadcastable)
for io in node.inputs + node.outputs]):
z = onames[0]
contig = """
// All output have the same size
npy_intp n = PyArray_SIZE(%(z)s);
""" % locals()
index = ""
for x, var in zip(inames + onames,
inputs + node.outputs):
if not all(var.broadcastable):
contig += """
dtype_%(x)s * %(x)s_ptr = (dtype_%(x)s*) PyArray_DATA(%(x)s);
""" % locals()
index += """
dtype_%(x)s& %(x)s_i = %(x)s_ptr[i];
""" % locals()
else:
contig += """
dtype_%(x)s& %(x)s_i = ((dtype_%(x)s*) PyArray_DATA(%(x)s))[0];
""" % locals()
contig += """
for(int i=0; i<n; i++){
%(index)s
%(task_code)s;
}
""" % locals()
if contig is not None:
z = zip(inames + onames, inputs + node.outputs)
cond1 = ' && '.join(["PyArray_ISCONTIGUOUS(%s)" % arr
for arr, var in z
if not all(var.broadcastable)])
cond2 = ' && '.join(["PyArray_ISFORTRAN(%s)" % arr
for arr, var in z
if not all(var.broadcastable)])
loop = """
if(%(cond)s){
if((%(cond1)s) || (%(cond2)s)){
%(contig)s
}else{
%(loop)s
}
""" % locals()
except theano.gof.utils.MethodNotDefined:
pass
return decl, checks, alloc, loop
def c_code(self, node, nodename, inames, onames, sub):
......@@ -1140,7 +1176,7 @@ class Elemwise(Op):
return support_code
def c_code_cache_version_apply(self, node):
version = [8] # the version corresponding to the c code in this Op
version = [9] # the version corresponding to the c code in this Op
# now we insert versions for the ops on which we depend...
scalar_node = Apply(self.scalar_op,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论