提交 65bdde40 authored 作者: Frederic's avatar Frederic

use simpler elemwise code for [broadcasted] scalar.

上级 74ba5cf4
......@@ -1091,7 +1091,10 @@ class Elemwise(Op):
%(undefs)s
}
""" % locals()
if all([o.ndim <= 1 for o in node.outputs]):
if all([o.ndim <= 1 for o in node.outputs] or
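# Use simpler code when output ndim == 0 or 1
# or for broadcasted scalar.
# NOTE(review): the `or` is inside the outer all(...) call, so the
# non-empty list short-circuits it and the broadcastable check below
# is never evaluated; the first all(...) should be closed before `or`.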
all(node.outputs[0].broadcastable)):
if nnested:
all_code = [("", "")] * (nnested - 1) + [("", code)] + [""]
else:
......@@ -1113,7 +1116,9 @@ class Elemwise(Op):
# If all inputs and outputs are contiguous
# and the scalar op defines optimized code for that case,
# use it! The scalar_op needs to check the broadcast flag itself.
if all([o.ndim >= 1 for o in node.outputs]):
if (all([o.ndim >= 1 for o in node.outputs]) and
# Don't use the contig code for broadcasted scalar.
not all(node.outputs[0].broadcastable)):
contig = None
try:
contig = self.scalar_op.c_code_contiguous(
......@@ -1159,10 +1164,10 @@ class Elemwise(Op):
z = zip(inames + onames, inputs + node.outputs)
cond1 = ' && '.join(["PyArray_ISCONTIGUOUS(%s)" % arr
for arr, var in z
if not all(var.broadcastable)])
if not all(var.broadcastable)])
cond2 = ' && '.join(["PyArray_ISFORTRAN(%s)" % arr
for arr, var in z
if not all(var.broadcastable)])
if not all(var.broadcastable)])
loop = """
if((%(cond1)s) || (%(cond2)s)){
%(contig)s
......@@ -1188,7 +1193,7 @@ class Elemwise(Op):
return support_code
def c_code_cache_version_apply(self, node):
version = [10] # the version corresponding to the c code in this Op
version = [11] # the version corresponding to the c code in this Op
# now we insert versions for the ops on which we depend...
scalar_node = Apply(self.scalar_op,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论