Commit 9e24b10a authored by Ricardo Vieira, committed by Ricardo Vieira

Remove Mean Op

This Op does not really fit the CAReduce API, as it requires an extra piece of information (the number of elements along the reduced axis) during the loop. A better solution would be a fused Elemwise+CAReduce.
Parent 1a3af4b2
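For context on the rationale: the removed in-place kernel (first hunk below) updated a running mean as res += (arr - res) / (i + 1), which depends on the loop counter i, state that the generic CAReduce reduction loop does not expose to the scalar op. A minimal sketch of the two approaches, in plain Python standing in for the generated C/Numba code (function names are illustrative, not part of the codebase):

# Running mean: needs the element index i inside the reduction loop,
# which the CAReduce API does not provide to the scalar op.
def running_mean(values):
    res = 0.0
    for i, x in enumerate(values):
        res += (x - res) / (i + 1)
    return res

# Fused Elemwise+CAReduce alternative: reduce with an associative op (sum),
# then apply an elementwise division by the known element count.
def sum_then_divide(values):
    total = 0.0
    for x in values:
        total += x  # plain CAReduce accumulation, no index needed
    return total / len(values)

assert abs(running_mean([1.0, 2.0, 6.0]) - sum_then_divide([1.0, 2.0, 6.0])) < 1e-12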
@@ -34,7 +34,6 @@ from pytensor.scalar.basic import (
     Add,
     Composite,
     IntDiv,
-    Mean,
     Mul,
     ScalarMaximum,
     ScalarMinimum,
@@ -77,11 +76,6 @@ def scalar_in_place_fn_Sub(op, idx, res, arr):
     return f"{res}[{idx}] -= {arr}"
 
 
-@scalar_in_place_fn.register(Mean)
-def scalar_in_place_fn_Mean(op, idx, res, arr):
-    return f"{res}[{idx}] += ({arr} - {res}[{idx}]) / (i + 1)"
-
-
 @scalar_in_place_fn.register(Mul)
 def scalar_in_place_fn_Mul(op, idx, res, arr):
     return f"{res}[{idx}] *= {arr}"
...
@@ -1871,32 +1871,6 @@ class Add(ScalarOp):
 add = Add(upcast_out, name="add")
 
 
-class Mean(ScalarOp):
-    identity = 0
-    commutative = True
-    associative = False
-    nfunc_spec = ("mean", 2, 1)
-    nfunc_variadic = "mean"
-
-    def impl(self, *inputs):
-        return sum(inputs) / len(inputs)
-
-    def c_code(self, node, name, inputs, outputs, sub):
-        (z,) = outputs
-        if not inputs:
-            return f"{z} = 0;"
-        else:
-            return f"{z} = ({' + '.join(inputs)}) / ((double) {len(inputs)});"
-
-    def L_op(self, inputs, outputs, gout):
-        (gz,) = gout
-        retval = [gz / len(inputs)] * len(inputs)
-        return retval
-
-
-mean = Mean(float_out, name="mean")
-
-
 class Mul(ScalarOp):
     identity = 1
     commutative = True
...
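Note that the removed scalar Op declared associative = False, which pinpoints the structural mismatch with CAReduce: a reduction built by repeatedly folding a binary op only computes the right answer when that op is associative. A quick worked check (plain Python, for illustration only):

# Folding a binary mean pairwise does not reproduce the n-ary mean:
pairwise = ((1 + 2) / 2 + 6) / 2   # mean(mean(1, 2), 6) = 3.75
true_mean = (1 + 2 + 6) / 3        # mean(1, 2, 6) = 3.0
assert pairwise != true_mean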
@@ -1316,63 +1316,7 @@ def complex_from_polar(abs, angle):
     """Return complex-valued tensor from polar coordinate specification."""
 
 
-class Mean(FixedOpCAReduce):
-    __props__ = ("axis",)
-    nfunc_spec = ("mean", 1, 1)
-
-    def __init__(self, axis=None):
-        super().__init__(ps.mean, axis)
-        assert self.axis is None or len(self.axis) == 1
-
-    def __str__(self):
-        if self.axis is not None:
-            args = ", ".join(str(x) for x in self.axis)
-            return f"Mean{{{args}}}"
-        else:
-            return "Mean"
-
-    def _output_dtype(self, idtype):
-        # we want to protect against overflow
-        return "float64"
-
-    def perform(self, node, inp, out):
-        (input,) = inp
-        (output,) = out
-        if self.axis is None:
-            axis = None
-        else:
-            axis = self.axis[0]
-        # numpy.asarray is needed as otherwise we can end up with a
-        # numpy scalar.
-        output[0] = np.asarray(np.mean(input, dtype="float64", axis=axis))
-
-    def c_code(self, node, name, inames, onames, sub):
-        ret = super().c_code(node, name, inames, onames, sub)
-        if self.axis is not None:
-            return ret
-        # TODO: c_code perform support only axis is None
-        return (
-            ret
-            + f"""
-            *((double *)PyArray_DATA({onames[0]})) /= PyArray_SIZE({inames[0]});
-            """
-        )
-
-    def clone(self, **kwargs):
-        axis = kwargs.get("axis", self.axis)
-        return type(self)(axis=axis)
-
-    # TODO: implement the grad. When done and tested, you can make this the default
-    # version.
-    # def grad(self, (x,), (gout,)):
-    #     import pdb;pdb.set_trace()
-    #     return grad(mean(x, self.axis, op=False),[x])
-
-
-def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None):
+def mean(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
     """
     Computes the mean value along the given axis(es) of a tensor `input`.
...
@@ -1397,25 +1341,6 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None)
     be in a float type). If None, then we use the same rules as `sum()`.
     """
     input = as_tensor_variable(input)
-    if op:
-        if dtype not in (None, "float64"):
-            raise NotImplementedError(
-                "The Mean op does not support the dtype argument, "
-                "and will always use float64. If you want to specify "
-                "the dtype, call tensor.mean(..., op=False).",
-                dtype,
-            )
-        if acc_dtype not in (None, "float64"):
-            raise NotImplementedError(
-                "The Mean op does not support the acc_dtype argument, "
-                "and will always use float64. If you want to specify "
-                "acc_dtype, call tensor.mean(..., op=False).",
-                dtype,
-            )
-        out = Mean(axis)(input)
-        if keepdims:
-            out = makeKeepDims(input, out, axis)
-        return out
 
     if dtype is not None:
         # The summation will be done with the specified dtype.
...
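With the op= escape hatch removed above, every mean call now takes the sum-then-divide path, so dtype and acc_dtype apply uniformly instead of raising NotImplementedError. A usage sketch of the user-facing behavior (assuming a standard PyTensor setup; the variable names are illustrative):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
m = pt.mean(x, axis=0)   # lowered to a summation followed by an
                         # elementwise division, not a dedicated Mean Op
f = pytensor.function([x], m)

data = np.arange(6, dtype="float64").reshape(3, 2)
assert np.allclose(f(data), data.mean(axis=0))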
@@ -16,7 +16,7 @@ from pytensor.gradient import grad
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
+from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 from tests.link.numba.test_basic import (
     compare_numba_and_py,
@@ -256,18 +256,6 @@ def test_Dimshuffle_non_contiguous():
             0,
             set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
         ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
-            0,
-            set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
-        ),
-        (
-            lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
-            0,
-            set_test_value(
-                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-            ),
-        ),
         (
             lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
                 axis=axis, dtype=dtype, acc_dtype=acc_dtype
...
@@ -43,7 +43,6 @@ from pytensor.scalar.basic import (
     log1p,
     log2,
     log10,
-    mean,
     mul,
     neg,
     neq,
@@ -58,7 +57,7 @@ from pytensor.scalar.basic import (
     true_div,
     uint8,
 )
-from pytensor.tensor.type import fscalar, imatrix, iscalar, matrix
+from pytensor.tensor.type import fscalar, imatrix, matrix
 from tests.link.test_link import make_function
@@ -521,34 +520,6 @@ def test_constant():
     assert c.dtype == "float32"
 
 
-@pytest.mark.parametrize("mode", [Mode("py"), Mode("cvm")])
-def test_mean(mode):
-    a = iscalar("a")
-    b = iscalar("b")
-    z = mean(a, b)
-    z_fn = pytensor.function([a, b], z, mode=mode)
-    res = z_fn(1, 1)
-    assert np.allclose(res, 1.0)
-
-    a = fscalar("a")
-    b = fscalar("b")
-    c = fscalar("c")
-    z = mean(a, b, c)
-    z_fn = pytensor.function([a, b, c], pytensor.grad(z, [a]), mode=mode)
-    res = z_fn(3, 4, 5)
-    assert np.allclose(res, 1 / 3)
-
-    z_fn = pytensor.function([a, b, c], pytensor.grad(z, [b]), mode=mode)
-    res = z_fn(3, 4, 5)
-    assert np.allclose(res, 1 / 3)
-
-    z = mean()
-    z_fn = pytensor.function([], z, mode=mode)
-    assert z_fn() == 0
-
-
 def test_shape():
     a = float32("a")
     assert isinstance(a.type, ScalarType)
...
@@ -40,7 +40,6 @@ from pytensor.tensor.math import (
     Argmax,
     Dot,
     Max,
-    Mean,
     Prod,
     ProdWithoutZeros,
     Sum,
@@ -2587,15 +2586,6 @@ def test_mod_compile():
 class TestInferShape(utt.InferShapeTester):
-    def test_Mean(self):
-        adtens3 = dtensor3()
-        adtens3_val = random(3, 4, 5)
-        aiscal_val = 2
-        self._compile_and_check([adtens3], [Mean(None)(adtens3)], [adtens3_val], Mean)
-        self._compile_and_check(
-            [adtens3], [Mean(aiscal_val)(adtens3)], [adtens3_val], Mean
-        )
-
     def test_Max(self):
         adtens3 = dtensor3()
         adtens3_val = random(4, 5, 3)
...