提交 5229feba authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement C code for ExtractDiag and ARange

Set the view flag of ExtractDiag to True by default and respect it
上级 d9a8471b
......@@ -3207,13 +3207,14 @@ def tile(
return A_replicated.reshape(tiled_shape)
class ARange(Op):
class ARange(COp):
"""Create an array containing evenly spaced values within a given interval.
Parameters and behaviour are the same as numpy.arange().
"""
# TODO: Arange should work with scalars as inputs, not arrays
__props__ = ("dtype",)
def __init__(self, dtype):
......@@ -3293,13 +3294,30 @@ class ARange(Op):
)
]
def perform(self, node, inp, out_):
start, stop, step = inp
(out,) = out_
start = start.item()
stop = stop.item()
step = step.item()
out[0] = np.arange(start, stop, step, dtype=self.dtype)
def perform(self, node, inputs, output_storage):
    """Evaluate the range in Python and store it in the first output cell.

    ``inputs`` must unpack into exactly three array-like objects
    (start, stop, step); each is reduced to a Python scalar with ``.item()``
    before being handed to ``np.arange`` together with ``self.dtype``.
    """
    start, stop, step = inputs
    bounds = (start.item(), stop.item(), step.item())
    output_storage[0][0] = np.arange(*bounds, dtype=self.dtype)
def c_code(self, node, nodename, input_names, output_names, sub):
    """Return the C snippet that builds the range with ``PyArray_Arange``.

    Parameters
    ----------
    input_names
        Exactly three C variable names: start, stop and step. The first
        element of each array's data buffer is read.
    output_names
        Exactly one C variable name for the output array.
    sub
        Substitution mapping; ``sub["fail"]`` is emitted when the NumPy
        call returns NULL.

    Notes
    -----
    The three bounds are read into C ``double`` variables.
    NOTE(review): integer bounds beyond 2**53 would lose precision here,
    unlike ``perform`` which passes Python scalars -- confirm this is
    acceptable.
    """
    [start_name, stop_name, step_name] = input_names
    [out_name] = output_names
    # NumPy type number matching the Op's dtype, embedded as a literal.
    typenum = np.dtype(self.dtype).num
    fail = sub["fail"]
    return f"""
    double start = ((dtype_{start_name}*)PyArray_DATA({start_name}))[0];
    double stop = ((dtype_{stop_name}*)PyArray_DATA({stop_name}))[0];
    double step = ((dtype_{step_name}*)PyArray_DATA({step_name}))[0];
    Py_XDECREF({out_name});
    {out_name} = (PyArrayObject*) PyArray_Arange(start, stop, step, {typenum});
    if (!{out_name}) {{
        {fail}
    }}
    """
# Version tag returned for this Op's generated C code.
# NOTE(review): presumably consumed by the C-module cache, so it should be
# changed whenever the behaviour of the code emitted by c_code changes -- confirm.
def c_code_cache_version(self):
return (0,)
# Returns three single-element rows of booleans: True, False, True.
# NOTE(review): presumably one row per input in (start, stop, step) order,
# i.e. the output is treated as not connected to `stop` -- confirm.
def connection_pattern(self, node):
return [[True], [False], [True]]
......@@ -3685,8 +3703,7 @@ def inverse_permutation(perm):
)
# TODO: optimization to insert ExtractDiag with view=True
class ExtractDiag(Op):
class ExtractDiag(COp):
"""
Return specified diagonals.
......@@ -3742,7 +3759,7 @@ class ExtractDiag(Op):
__props__ = ("offset", "axis1", "axis2", "view")
def __init__(self, offset=0, axis1=0, axis2=1, view=False):
def __init__(self, offset=0, axis1=0, axis2=1, view=True):
self.view = view
if self.view:
self.view_map = {0: [0]}
......@@ -3765,24 +3782,74 @@ class ExtractDiag(Op):
if x.ndim < 2:
raise ValueError("ExtractDiag needs an input with 2 or more dimensions", x)
out_shape = [
st_dim
for i, st_dim in enumerate(x.type.shape)
if (dim1 := x.type.shape[self.axis1]) is not None and (
dim2 := x.type.shape[self.axis2]
) is not None:
offset = self.offset
if offset > 0:
diag_size = int(np.clip(dim2 - offset, 0, dim1))
elif offset < 0:
diag_size = int(np.clip(dim1 + offset, 0, dim2))
else:
diag_size = int(np.minimum(dim1, dim2))
else:
diag_size = None
out_shape = (
*(
dim
for i, dim in enumerate(x.type.shape)
if i not in (self.axis1, self.axis2)
] + [None]
),
diag_size,
)
return Apply(
self,
[x],
[x.type.clone(dtype=x.dtype, shape=tuple(out_shape))()],
[x.type.clone(dtype=x.dtype, shape=out_shape)()],
)
def perform(self, node, inputs, outputs):
def perform(self, node, inputs, output_storage):
(x,) = inputs
(z,) = outputs
z[0] = x.diagonal(self.offset, self.axis1, self.axis2)
if not self.view:
z[0] = z[0].copy()
out = x.diagonal(self.offset, self.axis1, self.axis2)
if self.view:
try:
out.flags.writeable = True
except ValueError:
# We can't make this array writable
out = out.copy()
else:
out = out.copy()
output_storage[0][0] = out
def c_code(self, node, nodename, input_names, output_names, sub):
    """Return the C snippet that extracts the diagonal with ``PyArray_Diagonal``.

    Parameters
    ----------
    input_names
        Exactly one C variable name: the input array.
    output_names
        Exactly one C variable name for the output array.
    sub
        Substitution mapping; ``sub["fail"]`` is emitted whenever a NumPy
        call returns NULL.

    Notes
    -----
    When ``self.view`` is true and the input is writeable, the diagonal
    returned by NumPy is flagged writeable and used directly. In every other
    case a copy is returned instead.
    """
    [x_name] = input_names
    [out_name] = output_names
    fail = sub["fail"]
    return f"""
    Py_XDECREF({out_name});

    {out_name} = (PyArrayObject*) PyArray_Diagonal({x_name}, {self.offset}, {self.axis1}, {self.axis2});
    if (!{out_name}) {{
        {fail} // Error already set by Numpy
    }}

    if ({int(self.view)} && PyArray_ISWRITEABLE({x_name})) {{
        // Make output writeable if input was writeable
        PyArray_ENABLEFLAGS({out_name}, NPY_ARRAY_WRITEABLE);
    }} else {{
        // Make a copy
        PyArrayObject *{out_name}_copy = (PyArrayObject*) PyArray_Copy({out_name});
        Py_DECREF({out_name});
        // Rebind before checking for failure, so that the output variable
        // never holds the pointer that was just released.
        {out_name} = {out_name}_copy;
        if (!{out_name}) {{
            {fail} // Error already set by Numpy
        }}
    }}
    """

def c_code_cache_version(self):
    """Return the version tag of the generated C code.

    Bumped to 1 because the failure handling of the copy branch emitted by
    ``c_code`` changed.
    """
    return (1,)
def grad(self, inputs, gout):
# Avoid circular import
......@@ -3829,19 +3896,6 @@ class ExtractDiag(Op):
out_shape.append(diag_size)
return [tuple(out_shape)]
def __setstate__(self, state):
self.__dict__.update(state)
if self.view:
self.view_map = {0: [0]}
if "offset" not in state:
self.offset = 0
if "axis1" not in state:
self.axis1 = 0
if "axis2" not in state:
self.axis2 = 1
def extract_diag(x):
warnings.warn(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论