提交 d9b3924f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify Unique Op

上级 fa0ab9de
...@@ -41,7 +41,7 @@ from pytensor.tensor.math import ( ...@@ -41,7 +41,7 @@ from pytensor.tensor.math import (
) )
from pytensor.tensor.math import max as pt_max from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.shape import specify_broadcastable from pytensor.tensor.shape import Shape_i, specify_broadcastable
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -1194,23 +1194,22 @@ class Unique(Op): ...@@ -1194,23 +1194,22 @@ class Unique(Op):
self.return_index = return_index self.return_index = return_index
self.return_inverse = return_inverse self.return_inverse = return_inverse
self.return_counts = return_counts self.return_counts = return_counts
if axis is not None and axis < 0:
raise ValueError("Axis cannot be negative.")
self.axis = axis self.axis = axis
def make_node(self, x): def make_node(self, x):
x = ptb.as_tensor_variable(x) x = ptb.as_tensor_variable(x)
self_axis = self.axis axis = self.axis
if self_axis is None: if axis is None:
out_shape = (None,) out_shape = (None,)
else: else:
if self_axis < 0: if axis >= x.type.ndim:
self_axis += x.type.ndim
if self_axis < 0 or self_axis >= x.type.ndim:
raise ValueError( raise ValueError(
f"Unique axis {self.axis} is outside of input ndim = {x.type.ndim}" f"Axis {axis} out of range for input {x} with ndim={x.type.ndim}."
) )
out_shape = tuple( out_shape = tuple(
s if s == 1 and axis != self_axis else None None if dim == axis else s for dim, s in enumerate(x.type.shape)
for axis, s in enumerate(x.type.shape)
) )
outputs = [TensorType(dtype=x.dtype, shape=out_shape)()] outputs = [TensorType(dtype=x.dtype, shape=out_shape)()]
...@@ -1224,60 +1223,37 @@ class Unique(Op): ...@@ -1224,60 +1223,37 @@ class Unique(Op):
return Apply(self, [x], outputs) return Apply(self, [x], outputs)
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
x = inputs[0] [x] = inputs
z = output_storage outs = np.unique(
param = {} x,
if self.return_index: return_index=self.return_index,
param["return_index"] = True return_inverse=self.return_inverse,
if self.return_inverse: return_counts=self.return_counts,
param["return_inverse"] = True axis=self.axis,
if self.return_counts: )
param["return_counts"] = True if isinstance(outs, tuple):
if self.axis is not None:
param["axis"] = self.axis
outs = np.unique(x, **param)
if (
(not self.return_inverse)
and (not self.return_index)
and (not self.return_counts)
):
z[0][0] = outs
else:
for i in range(len(outs)): for i in range(len(outs)):
z[i][0] = outs[i] output_storage[i][0] = outs[i]
else:
output_storage[0][0] = outs
def infer_shape(self, fgraph, node, i0_shapes): def infer_shape(self, fgraph, node, i0_shapes):
ret = fgraph.shape_feature.default_infer_shape(fgraph, node, i0_shapes) [x_shape] = i0_shapes
if self.axis is not None: shape0_op = Shape_i(0)
self_axis = self.axis out_shapes = [(shape0_op(out),) for out in node.outputs]
ndim = len(i0_shapes[0])
if self_axis < 0: axis = self.axis
self_axis += ndim if axis is not None:
if self_axis < 0 or self_axis >= ndim: shape = list(x_shape)
raise RuntimeError( shape[axis] = Shape_i(axis)(node.outputs[0])
f"Unique axis `{self.axis}` is outside of input ndim = {ndim}." out_shapes[0] = tuple(shape)
)
ret[0] = tuple(
fgraph.shape_feature.shape_ir(i, node.outputs[0]) for i in range(ndim)
)
if self.return_inverse: if self.return_inverse:
if self.axis is None: shape = prod(x_shape) if self.axis is None else x_shape[axis]
shape = (prod(i0_shapes[0]),) return_index_out_idx = 2 if self.return_index else 1
else: out_shapes[return_index_out_idx] = (shape,)
shape = (i0_shapes[0][self_axis],)
if self.return_index: return out_shapes
ret[2] = shape
return ret
ret[1] = shape
return ret
return ret
def __setstate__(self, state):
self.__dict__.update(state)
# For backwards compatibility with pickled instances of Unique that
# did not have the axis parameter specified
if "axis" not in state:
self.axis = None
def unique( def unique(
...@@ -1293,6 +1269,9 @@ def unique( ...@@ -1293,6 +1269,9 @@ def unique(
* the number of times each unique value comes up in the input array * the number of times each unique value comes up in the input array
""" """
ar = as_tensor_variable(ar)
if axis is not None:
axis = normalize_axis_index(axis, ar.ndim)
return Unique(return_index, return_inverse, return_counts, axis)(ar) return Unique(return_index, return_inverse, return_counts, axis)(ar)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论