提交 39235a34 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a scalar mean Op

上级 6fe9f839
...@@ -1857,6 +1857,32 @@ class Add(ScalarOp): ...@@ -1857,6 +1857,32 @@ class Add(ScalarOp):
add = Add(upcast_out, name="add") add = Add(upcast_out, name="add")
class Mean(ScalarOp):
identity = 0
commutative = True
associative = False
nfunc_spec = ("mean", 2, 1)
nfunc_variadic = "mean"
def impl(self, *inputs):
return sum(inputs) / len(inputs)
def c_code(self, node, name, inputs, outputs, sub):
(z,) = outputs
if not inputs:
return f"{z} = 0;"
else:
return f"{z} = ({' + '.join(inputs)}) / ((double) {len(inputs)});"
def L_op(self, inputs, outputs, gout):
(gz,) = gout
retval = [gz / len(inputs)] * len(inputs)
return retval
mean = Mean(float_out, name="mean")
class Mul(ScalarOp): class Mul(ScalarOp):
identity = 1 identity = 1
commutative = True commutative = True
......
...@@ -1474,8 +1474,11 @@ def complex_from_polar(abs, angle): ...@@ -1474,8 +1474,11 @@ def complex_from_polar(abs, angle):
class Mean(CAReduce): class Mean(CAReduce):
__props__ = ("axis",)
nfunc_spec = ("mean", 1, 1)
def __init__(self, axis=None): def __init__(self, axis=None):
super().__init__(aes.add, axis) super().__init__(aes.mean, axis)
assert self.axis is None or len(self.axis) == 1 assert self.axis is None or len(self.axis) == 1
def __str__(self): def __str__(self):
......
...@@ -51,6 +51,7 @@ from aesara.scalar.basic import ( ...@@ -51,6 +51,7 @@ from aesara.scalar.basic import (
log1p, log1p,
log2, log2,
log10, log10,
mean,
mul, mul,
neq, neq,
rad2deg, rad2deg,
...@@ -64,7 +65,7 @@ from aesara.scalar.basic import ( ...@@ -64,7 +65,7 @@ from aesara.scalar.basic import (
true_div, true_div,
uint8, uint8,
) )
from aesara.tensor.type import fscalar, imatrix, matrix from aesara.tensor.type import fscalar, imatrix, iscalar, matrix
def test_mul_add_true(): def test_mul_add_true():
...@@ -468,3 +469,31 @@ def test_constant(): ...@@ -468,3 +469,31 @@ def test_constant():
c = constant(2, dtype="float32") c = constant(2, dtype="float32")
assert c.name is None assert c.name is None
assert c.dtype == "float32" assert c.dtype == "float32"
@pytest.mark.parametrize("mode", [Mode("py"), Mode("cvm")])
def test_mean(mode):
a = iscalar("a")
b = iscalar("b")
z = mean(a, b)
z_fn = aesara.function([a, b], z, mode=mode)
res = z_fn(1, 1)
assert np.allclose(res, 1.0)
a = fscalar("a")
b = fscalar("b")
c = fscalar("c")
z = mean(a, b, c)
z_fn = aesara.function([a, b, c], aesara.grad(z, [a]), mode=mode)
res = z_fn(3, 4, 5)
assert np.allclose(res, 1 / 3)
z_fn = aesara.function([a, b, c], aesara.grad(z, [b]), mode=mode)
res = z_fn(3, 4, 5)
assert np.allclose(res, 1 / 3)
z = mean()
z_fn = aesara.function([], z, mode=mode)
assert z_fn() == 0
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论