提交 f8a1d6b9 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename aesara.tensor.extra_ops.RepeatOp to Repeat

上级 67d7f140
......@@ -38,7 +38,7 @@ from aesara.tensor.extra_ops import (
FillDiagonal,
FillDiagonalOffset,
RavelMultiIndex,
RepeatOp,
Repeat,
Unique,
UnravelIndex,
)
......@@ -877,8 +877,8 @@ def jax_funcify_DiffOp(op, **kwargs):
return diffop
@jax_funcify.register(RepeatOp)
def jax_funcify_RepeatOp(op, **kwargs):
@jax_funcify.register(Repeat)
def jax_funcify_Repeat(op, **kwargs):
axis = op.axis
def repeatop(x, repeats, axis=axis):
......
......@@ -656,7 +656,7 @@ def compress(condition, x, axis=None):
return x.take(indices, axis=axis)
class RepeatOp(Op):
class Repeat(Op):
# See the repeat function for docstring
__props__ = ("axis",)
......@@ -800,7 +800,7 @@ def repeat(x, repeats, axis=None):
raise ValueError("The dimension of repeats should not exceed 1.")
if repeats.ndim == 1 and not repeats.broadcastable[0]:
return RepeatOp(axis=axis)(x, repeats)
return Repeat(axis=axis)(x, repeats)
else:
if repeats.ndim == 1:
repeats = repeats[0]
......
......@@ -20,7 +20,7 @@ from aesara.tensor.extra_ops import (
FillDiagonal,
FillDiagonalOffset,
RavelMultiIndex,
RepeatOp,
Repeat,
SearchsortedOp,
Unique,
UnravelIndex,
......@@ -437,14 +437,14 @@ class TestCompress(utt.InferShapeTester):
assert np.allclose(tested, expected)
class TestRepeatOp(utt.InferShapeTester):
class TestRepeat(utt.InferShapeTester):
def _possible_axis(self, ndim):
return [None] + list(range(ndim)) + [-i for i in range(ndim)]
def setup_method(self):
super().setup_method()
self.op_class = RepeatOp
self.op = RepeatOp()
self.op_class = Repeat
self.op = Repeat()
# uint64 always fails
# int64 and uint32 also fail if python int are 32-bit
if LOCAL_BITWIDTH == 64:
......@@ -452,7 +452,7 @@ class TestRepeatOp(utt.InferShapeTester):
if LOCAL_BITWIDTH == 32:
self.numpy_unsupported_dtypes = ("uint32", "int64", "uint64")
def test_repeatOp(self):
def test_basic(self):
for ndim in [1, 3]:
x = TensorType(config.floatX, [False] * ndim)()
a = np.random.random((10,) * ndim).astype(config.floatX)
......@@ -489,7 +489,7 @@ class TestRepeatOp(utt.InferShapeTester):
assert np.allclose(np.repeat(a, r, axis=axis), f(a))
assert not np.any(
[
isinstance(n.op, RepeatOp)
isinstance(n.op, Repeat)
for n in f.maker.fgraph.toposort()
]
)
......@@ -501,7 +501,7 @@ class TestRepeatOp(utt.InferShapeTester):
assert np.allclose(np.repeat(a, r[0], axis=axis), f(a, r))
assert not np.any(
[
isinstance(n.op, RepeatOp)
isinstance(n.op, Repeat)
for n in f.maker.fgraph.toposort()
]
)
......@@ -524,7 +524,7 @@ class TestRepeatOp(utt.InferShapeTester):
else:
self._compile_and_check(
[x, r_var],
[RepeatOp(axis=axis)(x, r_var)],
[Repeat(axis=axis)(x, r_var)],
[a, r],
self.op_class,
)
......@@ -541,7 +541,7 @@ class TestRepeatOp(utt.InferShapeTester):
self._compile_and_check(
[x, r_var],
[RepeatOp(axis=axis)(x, r_var)],
[Repeat(axis=axis)(x, r_var)],
[a, r],
self.op_class,
)
......@@ -551,15 +551,15 @@ class TestRepeatOp(utt.InferShapeTester):
a = np.random.random((10,) * ndim).astype(config.floatX)
for axis in self._possible_axis(ndim):
utt.verify_grad(lambda x: RepeatOp(axis=axis)(x, 3), [a])
utt.verify_grad(lambda x: Repeat(axis=axis)(x, 3), [a])
def test_broadcastable(self):
x = TensorType(config.floatX, [False, True, False])()
r = RepeatOp(axis=1)(x, 2)
r = Repeat(axis=1)(x, 2)
assert r.broadcastable == (False, False, False)
r = RepeatOp(axis=1)(x, 1)
r = Repeat(axis=1)(x, 1)
assert r.broadcastable == (False, True, False)
r = RepeatOp(axis=0)(x, 2)
r = Repeat(axis=0)(x, 2)
assert r.broadcastable == (False, True, False)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论