提交 65bdde40 authored 作者: Frederic's avatar Frederic

use simpler elemwise code for [broadcasted] scalar.

上级 74ba5cf4
...@@ -1091,7 +1091,10 @@ class Elemwise(Op): ...@@ -1091,7 +1091,10 @@ class Elemwise(Op):
%(undefs)s %(undefs)s
} }
""" % locals() """ % locals()
if all([o.ndim <= 1 for o in node.outputs]): if all([o.ndim <= 1 for o in node.outputs] or
# Use simpler code when output ndim == 0 or 1
# or for broadcated scalar.
all(node.outputs[0].broadcastable)):
if nnested: if nnested:
all_code = [("", "")] * (nnested - 1) + [("", code)] + [""] all_code = [("", "")] * (nnested - 1) + [("", code)] + [""]
else: else:
...@@ -1113,7 +1116,9 @@ class Elemwise(Op): ...@@ -1113,7 +1116,9 @@ class Elemwise(Op):
# If all inputs and outputs are contiguous # If all inputs and outputs are contiguous
# and the scalar op define optimized code for that case # and the scalar op define optimized code for that case
# use it! The scalar_op need to check the broadcast flag himself. # use it! The scalar_op need to check the broadcast flag himself.
if all([o.ndim >= 1 for o in node.outputs]): if (all([o.ndim >= 1 for o in node.outputs]) and
# Don't use the contig code for broadcasted scalar.
not all(node.outputs[0].broadcastable)):
contig = None contig = None
try: try:
contig = self.scalar_op.c_code_contiguous( contig = self.scalar_op.c_code_contiguous(
...@@ -1188,7 +1193,7 @@ class Elemwise(Op): ...@@ -1188,7 +1193,7 @@ class Elemwise(Op):
return support_code return support_code
def c_code_cache_version_apply(self, node): def c_code_cache_version_apply(self, node):
version = [10] # the version corresponding to the c code in this Op version = [11] # the version corresponding to the c code in this Op
# now we insert versions for the ops on which we depend... # now we insert versions for the ops on which we depend...
scalar_node = Apply(self.scalar_op, scalar_node = Apply(self.scalar_op,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论