提交 d9b3924f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify Unique Op

上级 fa0ab9de
......@@ -41,7 +41,7 @@ from pytensor.tensor.math import (
)
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.shape import specify_broadcastable
from pytensor.tensor.shape import Shape_i, specify_broadcastable
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.variable import TensorVariable
......@@ -1194,23 +1194,22 @@ class Unique(Op):
self.return_index = return_index
self.return_inverse = return_inverse
self.return_counts = return_counts
if axis is not None and axis < 0:
raise ValueError("Axis cannot be negative.")
self.axis = axis
def make_node(self, x):
x = ptb.as_tensor_variable(x)
self_axis = self.axis
if self_axis is None:
axis = self.axis
if axis is None:
out_shape = (None,)
else:
if self_axis < 0:
self_axis += x.type.ndim
if self_axis < 0 or self_axis >= x.type.ndim:
if axis >= x.type.ndim:
raise ValueError(
f"Unique axis {self.axis} is outside of input ndim = {x.type.ndim}"
f"Axis {axis} out of range for input {x} with ndim={x.type.ndim}."
)
out_shape = tuple(
s if s == 1 and axis != self_axis else None
for axis, s in enumerate(x.type.shape)
None if dim == axis else s for dim, s in enumerate(x.type.shape)
)
outputs = [TensorType(dtype=x.dtype, shape=out_shape)()]
......@@ -1224,60 +1223,37 @@ class Unique(Op):
return Apply(self, [x], outputs)
def perform(self, node, inputs, output_storage):
x = inputs[0]
z = output_storage
param = {}
if self.return_index:
param["return_index"] = True
if self.return_inverse:
param["return_inverse"] = True
if self.return_counts:
param["return_counts"] = True
if self.axis is not None:
param["axis"] = self.axis
outs = np.unique(x, **param)
if (
(not self.return_inverse)
and (not self.return_index)
and (not self.return_counts)
):
z[0][0] = outs
else:
[x] = inputs
outs = np.unique(
x,
return_index=self.return_index,
return_inverse=self.return_inverse,
return_counts=self.return_counts,
axis=self.axis,
)
if isinstance(outs, tuple):
for i in range(len(outs)):
z[i][0] = outs[i]
output_storage[i][0] = outs[i]
else:
output_storage[0][0] = outs
def infer_shape(self, fgraph, node, i0_shapes):
ret = fgraph.shape_feature.default_infer_shape(fgraph, node, i0_shapes)
if self.axis is not None:
self_axis = self.axis
ndim = len(i0_shapes[0])
if self_axis < 0:
self_axis += ndim
if self_axis < 0 or self_axis >= ndim:
raise RuntimeError(
f"Unique axis `{self.axis}` is outside of input ndim = {ndim}."
)
ret[0] = tuple(
fgraph.shape_feature.shape_ir(i, node.outputs[0]) for i in range(ndim)
)
[x_shape] = i0_shapes
shape0_op = Shape_i(0)
out_shapes = [(shape0_op(out),) for out in node.outputs]
axis = self.axis
if axis is not None:
shape = list(x_shape)
shape[axis] = Shape_i(axis)(node.outputs[0])
out_shapes[0] = tuple(shape)
if self.return_inverse:
if self.axis is None:
shape = (prod(i0_shapes[0]),)
else:
shape = (i0_shapes[0][self_axis],)
if self.return_index:
ret[2] = shape
return ret
ret[1] = shape
return ret
return ret
def __setstate__(self, state):
self.__dict__.update(state)
# For backwards compatibility with pickled instances of Unique that
# did not have the axis parameter specified
if "axis" not in state:
self.axis = None
shape = prod(x_shape) if self.axis is None else x_shape[axis]
return_index_out_idx = 2 if self.return_index else 1
out_shapes[return_index_out_idx] = (shape,)
return out_shapes
def unique(
......@@ -1293,6 +1269,9 @@ def unique(
* the number of times each unique value comes up in the input array
"""
ar = as_tensor_variable(ar)
if axis is not None:
axis = normalize_axis_index(axis, ar.ndim)
return Unique(return_index, return_inverse, return_counts, axis)(ar)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论