提交 1d3a1623 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Added linspace, geomspace and logspace to aesara.tensor.extra_ops

上级 980c4c2c
...@@ -134,6 +134,9 @@ from aesara.tensor.extra_ops import ( # noqa ...@@ -134,6 +134,9 @@ from aesara.tensor.extra_ops import ( # noqa
squeeze, squeeze,
unique, unique,
unravel_index, unravel_index,
linspace,
logspace,
geomspace,
) )
from aesara.tensor.shape import ( # noqa from aesara.tensor.shape import ( # noqa
reshape, reshape,
......
...@@ -1636,6 +1636,29 @@ class BroadcastTo(Op): ...@@ -1636,6 +1636,29 @@ class BroadcastTo(Op):
broadcast_to_ = BroadcastTo() broadcast_to_ = BroadcastTo()
def geomspace(start, end, steps, base=10.0):
from aesara.tensor.math import log
start = at.as_tensor_variable(start)
end = at.as_tensor_variable(end)
return base ** linspace(log(start) / log(base), log(end) / log(base), steps)
def logspace(start, end, steps, base=10.0):
start = at.as_tensor_variable(start)
end = at.as_tensor_variable(end)
return base ** linspace(start, end, steps)
def linspace(start, end, steps):
start = at.as_tensor_variable(start)
end = at.as_tensor_variable(end)
arr = at.arange(steps)
arr = at.shape_padright(arr, max(start.ndim, end.ndim))
multiplier = (end - start) / (steps - 1)
return start + arr * multiplier
def broadcast_to( def broadcast_to(
x: TensorVariable, shape: Union[TensorVariable, Tuple[Variable]] x: TensorVariable, shape: Union[TensorVariable, Tuple[Variable]]
) -> TensorVariable: ) -> TensorVariable:
......
...@@ -36,6 +36,9 @@ from aesara.tensor.extra_ops import ( ...@@ -36,6 +36,9 @@ from aesara.tensor.extra_ops import (
diff, diff,
fill_diagonal, fill_diagonal,
fill_diagonal_offset, fill_diagonal_offset,
geomspace,
linspace,
logspace,
ravel_multi_index, ravel_multi_index,
repeat, repeat,
searchsorted, searchsorted,
...@@ -65,6 +68,11 @@ from aesara.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH ...@@ -65,6 +68,11 @@ from aesara.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
from tests import unittest_tools as utt from tests import unittest_tools as utt
def set_test_value(x, v):
x.tag.test_value = v
return x
def test_cpu_contiguous(): def test_cpu_contiguous():
a = fmatrix("a") a = fmatrix("a")
i = iscalar("i") i = iscalar("i")
...@@ -1222,3 +1230,28 @@ def test_broadcast_arrays(): ...@@ -1222,3 +1230,28 @@ def test_broadcast_arrays():
assert np.array_equal(x_bcast_val, x_bcast_exp) assert np.array_equal(x_bcast_val, x_bcast_exp)
assert np.array_equal(y_bcast_val, y_bcast_exp) assert np.array_equal(y_bcast_val, y_bcast_exp)
@pytest.mark.parametrize(
"start, stop, num_samples",
[
(1, 10, 50),
(np.array([5, 6]), np.array([[10, 10], [10, 10]]), 25),
(1, np.array([5, 6]), 30),
],
)
def test_space_ops(start, stop, num_samples):
z = linspace(start, stop, num_samples)
aesara_res = function(inputs=[], outputs=z)()
numpy_res = np.linspace(start, stop, num=num_samples)
assert np.allclose(aesara_res, numpy_res)
z = logspace(start, stop, num_samples)
aesara_res = function(inputs=[], outputs=z)()
numpy_res = np.logspace(start, stop, num=num_samples)
assert np.allclose(aesara_res, numpy_res)
z = geomspace(start, stop, num_samples)
aesara_res = function(inputs=[], outputs=z)()
numpy_res = np.geomspace(start, stop, num=num_samples)
assert np.allclose(aesara_res, numpy_res)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论