提交 e752fc3d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

CAReduce loop reordering C-impl

上级 00a8a883
import builtins import builtins
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from textwrap import dedent
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
import numpy as np import numpy as np
...@@ -361,12 +362,14 @@ class FixedOpCAReduce(CAReduce): ...@@ -361,12 +362,14 @@ class FixedOpCAReduce(CAReduce):
class NonZeroDimsCAReduce(FixedOpCAReduce): class NonZeroDimsCAReduce(FixedOpCAReduce):
def _c_all(self, node, name, inames, onames, sub): def _c_all(self, node, name, input_names, output_names, sub):
decl, checks, alloc, loop, end = super()._c_all(node, name, inames, onames, sub) setup, alloc, loop, cast = super()._c_all(
node, name, input_names, output_names, sub
)
# We add an additional check for zero-sized dimensions (This seems like # We add an additional check for zero-sized dimensions (This seems like
# something that could be enabled in `elemwise_cgen.make_checks`.) # something that could be enabled in `elemwise_cgen.make_checks`.)
iname = inames[0] [iname] = input_names
axis = self.axis axis = self.axis
if axis is None: if axis is None:
...@@ -378,8 +381,9 @@ class NonZeroDimsCAReduce(FixedOpCAReduce): ...@@ -378,8 +381,9 @@ class NonZeroDimsCAReduce(FixedOpCAReduce):
pattern_ = str(pattern)[1:-1] pattern_ = str(pattern)[1:-1]
decl += f"""int tosum[]={{{pattern_}}};""" setup = f"int tosum[]={{{pattern_}}};" + setup
alloc += f""" alloc += dedent(
f"""
for(int i=0;i<PyArray_NDIM({iname});i++){{ for(int i=0;i<PyArray_NDIM({iname});i++){{
if(PyArray_DIMS({iname})[i]==0 && tosum[i]){{ if(PyArray_DIMS({iname})[i]==0 && tosum[i]){{
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
...@@ -388,7 +392,8 @@ class NonZeroDimsCAReduce(FixedOpCAReduce): ...@@ -388,7 +392,8 @@ class NonZeroDimsCAReduce(FixedOpCAReduce):
}} }}
}} }}
""" """
return decl, checks, alloc, loop, end )
return setup, alloc, loop, cast
class Max(NonZeroDimsCAReduce): class Max(NonZeroDimsCAReduce):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论