提交 f8a1d6b9 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename aesara.tensor.extra_ops.RepeatOp to Repeat

上级 67d7f140
...@@ -38,7 +38,7 @@ from aesara.tensor.extra_ops import ( ...@@ -38,7 +38,7 @@ from aesara.tensor.extra_ops import (
FillDiagonal, FillDiagonal,
FillDiagonalOffset, FillDiagonalOffset,
RavelMultiIndex, RavelMultiIndex,
RepeatOp, Repeat,
Unique, Unique,
UnravelIndex, UnravelIndex,
) )
...@@ -877,8 +877,8 @@ def jax_funcify_DiffOp(op, **kwargs): ...@@ -877,8 +877,8 @@ def jax_funcify_DiffOp(op, **kwargs):
return diffop return diffop
@jax_funcify.register(RepeatOp) @jax_funcify.register(Repeat)
def jax_funcify_RepeatOp(op, **kwargs): def jax_funcify_Repeat(op, **kwargs):
axis = op.axis axis = op.axis
def repeatop(x, repeats, axis=axis): def repeatop(x, repeats, axis=axis):
......
...@@ -656,7 +656,7 @@ def compress(condition, x, axis=None): ...@@ -656,7 +656,7 @@ def compress(condition, x, axis=None):
return x.take(indices, axis=axis) return x.take(indices, axis=axis)
class RepeatOp(Op): class Repeat(Op):
# See the repeat function for docstring # See the repeat function for docstring
__props__ = ("axis",) __props__ = ("axis",)
...@@ -800,7 +800,7 @@ def repeat(x, repeats, axis=None): ...@@ -800,7 +800,7 @@ def repeat(x, repeats, axis=None):
raise ValueError("The dimension of repeats should not exceed 1.") raise ValueError("The dimension of repeats should not exceed 1.")
if repeats.ndim == 1 and not repeats.broadcastable[0]: if repeats.ndim == 1 and not repeats.broadcastable[0]:
return RepeatOp(axis=axis)(x, repeats) return Repeat(axis=axis)(x, repeats)
else: else:
if repeats.ndim == 1: if repeats.ndim == 1:
repeats = repeats[0] repeats = repeats[0]
......
...@@ -20,7 +20,7 @@ from aesara.tensor.extra_ops import ( ...@@ -20,7 +20,7 @@ from aesara.tensor.extra_ops import (
FillDiagonal, FillDiagonal,
FillDiagonalOffset, FillDiagonalOffset,
RavelMultiIndex, RavelMultiIndex,
RepeatOp, Repeat,
SearchsortedOp, SearchsortedOp,
Unique, Unique,
UnravelIndex, UnravelIndex,
...@@ -437,14 +437,14 @@ class TestCompress(utt.InferShapeTester): ...@@ -437,14 +437,14 @@ class TestCompress(utt.InferShapeTester):
assert np.allclose(tested, expected) assert np.allclose(tested, expected)
class TestRepeatOp(utt.InferShapeTester): class TestRepeat(utt.InferShapeTester):
def _possible_axis(self, ndim): def _possible_axis(self, ndim):
return [None] + list(range(ndim)) + [-i for i in range(ndim)] return [None] + list(range(ndim)) + [-i for i in range(ndim)]
def setup_method(self): def setup_method(self):
super().setup_method() super().setup_method()
self.op_class = RepeatOp self.op_class = Repeat
self.op = RepeatOp() self.op = Repeat()
# uint64 always fails # uint64 always fails
# int64 and uint32 also fail if python int are 32-bit # int64 and uint32 also fail if python int are 32-bit
if LOCAL_BITWIDTH == 64: if LOCAL_BITWIDTH == 64:
...@@ -452,7 +452,7 @@ class TestRepeatOp(utt.InferShapeTester): ...@@ -452,7 +452,7 @@ class TestRepeatOp(utt.InferShapeTester):
if LOCAL_BITWIDTH == 32: if LOCAL_BITWIDTH == 32:
self.numpy_unsupported_dtypes = ("uint32", "int64", "uint64") self.numpy_unsupported_dtypes = ("uint32", "int64", "uint64")
def test_repeatOp(self): def test_basic(self):
for ndim in [1, 3]: for ndim in [1, 3]:
x = TensorType(config.floatX, [False] * ndim)() x = TensorType(config.floatX, [False] * ndim)()
a = np.random.random((10,) * ndim).astype(config.floatX) a = np.random.random((10,) * ndim).astype(config.floatX)
...@@ -489,7 +489,7 @@ class TestRepeatOp(utt.InferShapeTester): ...@@ -489,7 +489,7 @@ class TestRepeatOp(utt.InferShapeTester):
assert np.allclose(np.repeat(a, r, axis=axis), f(a)) assert np.allclose(np.repeat(a, r, axis=axis), f(a))
assert not np.any( assert not np.any(
[ [
isinstance(n.op, RepeatOp) isinstance(n.op, Repeat)
for n in f.maker.fgraph.toposort() for n in f.maker.fgraph.toposort()
] ]
) )
...@@ -501,7 +501,7 @@ class TestRepeatOp(utt.InferShapeTester): ...@@ -501,7 +501,7 @@ class TestRepeatOp(utt.InferShapeTester):
assert np.allclose(np.repeat(a, r[0], axis=axis), f(a, r)) assert np.allclose(np.repeat(a, r[0], axis=axis), f(a, r))
assert not np.any( assert not np.any(
[ [
isinstance(n.op, RepeatOp) isinstance(n.op, Repeat)
for n in f.maker.fgraph.toposort() for n in f.maker.fgraph.toposort()
] ]
) )
...@@ -524,7 +524,7 @@ class TestRepeatOp(utt.InferShapeTester): ...@@ -524,7 +524,7 @@ class TestRepeatOp(utt.InferShapeTester):
else: else:
self._compile_and_check( self._compile_and_check(
[x, r_var], [x, r_var],
[RepeatOp(axis=axis)(x, r_var)], [Repeat(axis=axis)(x, r_var)],
[a, r], [a, r],
self.op_class, self.op_class,
) )
...@@ -541,7 +541,7 @@ class TestRepeatOp(utt.InferShapeTester): ...@@ -541,7 +541,7 @@ class TestRepeatOp(utt.InferShapeTester):
self._compile_and_check( self._compile_and_check(
[x, r_var], [x, r_var],
[RepeatOp(axis=axis)(x, r_var)], [Repeat(axis=axis)(x, r_var)],
[a, r], [a, r],
self.op_class, self.op_class,
) )
...@@ -551,15 +551,15 @@ class TestRepeatOp(utt.InferShapeTester): ...@@ -551,15 +551,15 @@ class TestRepeatOp(utt.InferShapeTester):
a = np.random.random((10,) * ndim).astype(config.floatX) a = np.random.random((10,) * ndim).astype(config.floatX)
for axis in self._possible_axis(ndim): for axis in self._possible_axis(ndim):
utt.verify_grad(lambda x: RepeatOp(axis=axis)(x, 3), [a]) utt.verify_grad(lambda x: Repeat(axis=axis)(x, 3), [a])
def test_broadcastable(self): def test_broadcastable(self):
x = TensorType(config.floatX, [False, True, False])() x = TensorType(config.floatX, [False, True, False])()
r = RepeatOp(axis=1)(x, 2) r = Repeat(axis=1)(x, 2)
assert r.broadcastable == (False, False, False) assert r.broadcastable == (False, False, False)
r = RepeatOp(axis=1)(x, 1) r = Repeat(axis=1)(x, 1)
assert r.broadcastable == (False, True, False) assert r.broadcastable == (False, True, False)
r = RepeatOp(axis=0)(x, 2) r = Repeat(axis=0)(x, 2)
assert r.broadcastable == (False, True, False) assert r.broadcastable == (False, True, False)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论