提交 5229feba authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement C code for ExtractDiagonal and ARange

Set the `view` flag of ExtractDiag to True by default and respect it
上级 d9a8471b
...@@ -3207,13 +3207,14 @@ def tile( ...@@ -3207,13 +3207,14 @@ def tile(
return A_replicated.reshape(tiled_shape) return A_replicated.reshape(tiled_shape)
class ARange(Op): class ARange(COp):
"""Create an array containing evenly spaced values within a given interval. """Create an array containing evenly spaced values within a given interval.
Parameters and behaviour are the same as numpy.arange(). Parameters and behaviour are the same as numpy.arange().
""" """
# TODO: Arange should work with scalars as inputs, not arrays
__props__ = ("dtype",) __props__ = ("dtype",)
def __init__(self, dtype): def __init__(self, dtype):
...@@ -3293,13 +3294,30 @@ class ARange(Op): ...@@ -3293,13 +3294,30 @@ class ARange(Op):
) )
] ]
def perform(self, node, inp, out_): def perform(self, node, inputs, output_storage):
start, stop, step = inp start, stop, step = inputs
(out,) = out_ output_storage[0][0] = np.arange(
start = start.item() start.item(), stop.item(), step.item(), dtype=self.dtype
stop = stop.item() )
step = step.item()
out[0] = np.arange(start, stop, step, dtype=self.dtype) def c_code(self, node, nodename, input_names, output_names, sub):
[start_name, stop_name, step_name] = input_names
[out_name] = output_names
typenum = np.dtype(self.dtype).num
return f"""
double start = ((dtype_{start_name}*)PyArray_DATA({start_name}))[0];
double stop = ((dtype_{stop_name}*)PyArray_DATA({stop_name}))[0];
double step = ((dtype_{step_name}*)PyArray_DATA({step_name}))[0];
//printf("start: %f, stop: %f, step: %f\\n", start, stop, step);
Py_XDECREF({out_name});
{out_name} = (PyArrayObject*) PyArray_Arange(start, stop, step, {typenum});
if (!{out_name}) {{
{sub["fail"]}
}}
"""
def c_code_cache_version(self):
return (0,)
def connection_pattern(self, node): def connection_pattern(self, node):
return [[True], [False], [True]] return [[True], [False], [True]]
...@@ -3685,8 +3703,7 @@ def inverse_permutation(perm): ...@@ -3685,8 +3703,7 @@ def inverse_permutation(perm):
) )
# TODO: optimization to insert ExtractDiag with view=True class ExtractDiag(COp):
class ExtractDiag(Op):
""" """
Return specified diagonals. Return specified diagonals.
...@@ -3742,7 +3759,7 @@ class ExtractDiag(Op): ...@@ -3742,7 +3759,7 @@ class ExtractDiag(Op):
__props__ = ("offset", "axis1", "axis2", "view") __props__ = ("offset", "axis1", "axis2", "view")
def __init__(self, offset=0, axis1=0, axis2=1, view=False): def __init__(self, offset=0, axis1=0, axis2=1, view=True):
self.view = view self.view = view
if self.view: if self.view:
self.view_map = {0: [0]} self.view_map = {0: [0]}
...@@ -3765,24 +3782,74 @@ class ExtractDiag(Op): ...@@ -3765,24 +3782,74 @@ class ExtractDiag(Op):
if x.ndim < 2: if x.ndim < 2:
raise ValueError("ExtractDiag needs an input with 2 or more dimensions", x) raise ValueError("ExtractDiag needs an input with 2 or more dimensions", x)
out_shape = [ if (dim1 := x.type.shape[self.axis1]) is not None and (
st_dim dim2 := x.type.shape[self.axis2]
for i, st_dim in enumerate(x.type.shape) ) is not None:
if i not in (self.axis1, self.axis2) offset = self.offset
] + [None] if offset > 0:
diag_size = int(np.clip(dim2 - offset, 0, dim1))
elif offset < 0:
diag_size = int(np.clip(dim1 + offset, 0, dim2))
else:
diag_size = int(np.minimum(dim1, dim2))
else:
diag_size = None
out_shape = (
*(
dim
for i, dim in enumerate(x.type.shape)
if i not in (self.axis1, self.axis2)
),
diag_size,
)
return Apply( return Apply(
self, self,
[x], [x],
[x.type.clone(dtype=x.dtype, shape=tuple(out_shape))()], [x.type.clone(dtype=x.dtype, shape=out_shape)()],
) )
def perform(self, node, inputs, outputs): def perform(self, node, inputs, output_storage):
(x,) = inputs (x,) = inputs
(z,) = outputs out = x.diagonal(self.offset, self.axis1, self.axis2)
z[0] = x.diagonal(self.offset, self.axis1, self.axis2) if self.view:
if not self.view: try:
z[0] = z[0].copy() out.flags.writeable = True
except ValueError:
# We can't make this array writable
out = out.copy()
else:
out = out.copy()
output_storage[0][0] = out
def c_code(self, node, nodename, input_names, output_names, sub):
    """Return the C source that extracts the diagonal of the single input.

    The generated code calls ``PyArray_Diagonal`` with this Op's ``offset``,
    ``axis1`` and ``axis2``. When ``self.view`` is true and the input array is
    writeable, the resulting view is flagged writeable and returned as is;
    otherwise a copy of the diagonal is returned.

    Parameters
    ----------
    node, nodename
        Unused here; part of the ``c_code`` signature.
    input_names
        One-element sequence with the C variable name of the input array.
    output_names
        One-element sequence with the C variable name of the output array.
    sub
        Substitution mapping; only ``sub["fail"]`` (the failure-handling
        snippet) is read.

    Returns
    -------
    str
        The C code for this node.
    """
    [x_name] = input_names
    [out_name] = output_names
    fail = sub["fail"]
    return f"""
    Py_XDECREF({out_name});

    {out_name} = (PyArrayObject*) PyArray_Diagonal({x_name}, {self.offset}, {self.axis1}, {self.axis2});
    if (!{out_name}) {{
        {fail}  // Error already set by Numpy
    }}

    if ({int(self.view)} && PyArray_ISWRITEABLE({x_name})) {{
        // Make output writeable if input was writeable
        PyArray_ENABLEFLAGS({out_name}, NPY_ARRAY_WRITEABLE);
    }} else {{
        // Make a copy
        PyArrayObject *{out_name}_copy = (PyArrayObject*) PyArray_Copy({out_name});
        Py_DECREF({out_name});
        // Rebind the output before the failure check so it never holds the
        // pointer that was just released: on failure it is NULL, which is
        // safe if the failure/cleanup code decrefs the output (presumably it
        // does -- confirm against the C-Op cleanup conventions).
        {out_name} = {out_name}_copy;
        if (!{out_name}) {{
            {fail}  // Error already set by Numpy
        }}
    }}
    """

def c_code_cache_version(self):
    """Return the version tag of the C implementation.

    Bumped to 1 because the generated code changed (output is rebound before
    the copy-failure check), so previously compiled modules are not reused.
    """
    return (1,)
def grad(self, inputs, gout): def grad(self, inputs, gout):
# Avoid circular import # Avoid circular import
...@@ -3829,19 +3896,6 @@ class ExtractDiag(Op): ...@@ -3829,19 +3896,6 @@ class ExtractDiag(Op):
out_shape.append(diag_size) out_shape.append(diag_size)
return [tuple(out_shape)] return [tuple(out_shape)]
def __setstate__(self, state):
self.__dict__.update(state)
if self.view:
self.view_map = {0: [0]}
if "offset" not in state:
self.offset = 0
if "axis1" not in state:
self.axis1 = 0
if "axis2" not in state:
self.axis2 = 1
def extract_diag(x): def extract_diag(x):
warnings.warn( warnings.warn(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论