提交 2307d877 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove unused or unreachable code in tensor/sort.py

上级 71592152
import numpy as np
from pytensor.gradient import grad_undefined
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.misc.safe_asarray import _asarray
from pytensor.tensor.basic import arange, as_tensor_variable, switch
from pytensor.tensor.math import eq, ge, mul
from pytensor.tensor.math import eq, ge
from pytensor.tensor.type import TensorType
def _variable_is_none(var):
return isinstance(var, Constant) and var.data is None
def _check_tensor_is_scalar(var):
"""
Checks if a tensor variable is scalar, raise ValueError otherwise
"""
msg = "%(var)s is expected to be 0d tensor, got %(ndim)d"
if var.ndim != 0:
raise ValueError(msg % (var, var.ndim))
class SortOp(Op):
"""
This class is a wrapper for numpy sort function.
......@@ -39,28 +26,16 @@ class SortOp(Op):
def make_node(self, input, axis=-1):
input = as_tensor_variable(input)
axis = as_tensor_variable(axis)
axis = as_tensor_variable(axis, ndim=0, dtype=int)
out_type = input.type()
return Apply(self, [input, axis], [out_type])
def perform(self, node, inputs, output_storage):
a = inputs[0]
axis = inputs[1]
if axis is not None:
if axis != int(axis):
raise ValueError("sort axis must be an integer or None")
axis = int(axis)
a, axis = inputs
z = output_storage[0]
z[0] = np.sort(a, axis, self.kind, self.order)
z[0] = np.sort(a, int(axis), self.kind, self.order)
def infer_shape(self, fgraph, node, inputs_shapes):
if _variable_is_none(node.inputs[1]):
# That means axis = None,
# So the array is flattened before being sorted
return [(mul(*inputs_shapes[0]),)]
# axis should not be None
# So there should be the same number of dimensions
# in the input and output
assert node.inputs[0].ndim == node.outputs[0].ndim
assert inputs_shapes[1] == ()
return [inputs_shapes[0]]
......@@ -172,7 +147,7 @@ class ArgSortOp(Op):
def make_node(self, input, axis=-1):
input = as_tensor_variable(input)
axis = as_tensor_variable(axis)
axis = as_tensor_variable(axis, ndim=0, dtype=int)
return Apply(
self,
[input, axis],
......@@ -180,22 +155,14 @@ class ArgSortOp(Op):
)
def perform(self, node, inputs, output_storage):
a = inputs[0]
axis = inputs[1]
if axis is not None:
if axis != int(axis):
raise ValueError("sort axis must be an integer or None")
axis = int(axis)
a, axis = inputs
z = output_storage[0]
z[0] = _asarray(
np.argsort(a, axis, self.kind, self.order), dtype=node.outputs[0].dtype
np.argsort(a, int(axis), self.kind, self.order),
dtype=node.outputs[0].dtype,
)
def infer_shape(self, fgraph, node, inputs_shapes):
if _variable_is_none(node.inputs[1]):
return [(mul(*inputs_shapes[0]),)]
# axis should not be None, so there should be the same number of
# dimensions in the input and output
assert node.inputs[0].ndim == node.outputs[0].ndim
assert inputs_shapes[1] == ()
return [inputs_shapes[0]]
......@@ -239,66 +206,3 @@ def argsort(a, axis=-1, kind="quicksort", order=None):
a = a.flatten()
axis = 0
return ArgSortOp(kind, order)(a, axis)
def _topk_py_impl(op, x, k, axis, idx_dtype):
ndim = x.ndim
assert -ndim <= axis < ndim
axis %= ndim
if k == 0:
raise ValueError("topk: kth cannot be zero")
elif k > x.shape[axis]:
raise ValueError(
f"topk: kth cannot be larger than the size of specified axis {int(axis)}"
)
if abs(k) == 1:
# negative k means min instead of max
fn_max = [None, np.max, np.min][k]
fn_argmax = [None, np.argmax, np.argmin][k]
if not op.return_indices:
return np.expand_dims(fn_max(x, axis=axis), axis)
elif op.return_values:
zi = np.expand_dims(fn_argmax(x, axis=axis), axis)
idx2 = tuple(
np.arange(s).reshape((s,) + (1,) * (ndim - i - 1)) if i != axis else zi
for i, s in enumerate(x.shape)
)
zv = x[idx2]
return zv, zi.astype(idx_dtype)
else:
zi = np.expand_dims(fn_argmax(x, axis=axis), axis)
return zi.astype(idx_dtype)
if x.shape[axis] == abs(k):
if not op.return_indices:
return x.copy()
else:
l = axis
r = ndim - l
reps = list(x.shape)
reps[axis] = 1
zi = np.arange(abs(k), dtype=idx_dtype)
zi = zi.reshape((1,) * l + (k,) + (1,) * (r - 1))
zi = np.tile(zi, reps)
if op.return_values:
return x.copy(), zi
else:
return zi
idx = [slice(None)] * ndim
idx[axis] = slice(-k, None) if k > 0 else slice(-k)
if not op.return_indices:
zv = np.partition(x, -k, axis=axis)[tuple(idx)]
return zv
elif op.return_values:
zi = np.argpartition(x, -k, axis=axis)[tuple(idx)]
idx2 = tuple(
np.arange(s).reshape((s,) + (1,) * (ndim - i - 1)) if i != axis else zi
for i, s in enumerate(x.shape)
)
zv = x[idx2]
return zv, zi.astype(idx_dtype)
else:
zi = np.argpartition(x, -k, axis=axis)[tuple(idx)]
return zi.astype(idx_dtype)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论