提交 1f0dfd36 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move Rebroadcast tests to tests.tensor.test_basic

上级 6f2a60ba
...@@ -4,9 +4,7 @@ import numpy as np ...@@ -4,9 +4,7 @@ import numpy as np
from aesara import function from aesara import function
from aesara.compile.ops import as_op from aesara.compile.ops import as_op
from aesara.configdefaults import config from aesara.tensor.type import dmatrix, dvector
from aesara.tensor.basic import Rebroadcast
from aesara.tensor.type import TensorType, dmatrix, dtensor4, dvector
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -81,29 +79,3 @@ class TestOpDecorator(utt.InferShapeTester): ...@@ -81,29 +79,3 @@ class TestOpDecorator(utt.InferShapeTester):
m2 = pickle.loads(s) m2 = pickle.loads(s)
assert m2.owner.op == m.owner.op assert m2.owner.op == m.owner.op
class TestRebroadcast(utt.InferShapeTester):
def test_rebroadcast(self):
rng = np.random.RandomState(3453)
# Rebroadcast
adtens4 = dtensor4()
adict = [(0, False), (1, True), (2, False), (3, True)]
adtens4_val = rng.rand(2, 1, 3, 1).astype(config.floatX)
self._compile_and_check(
[adtens4],
[Rebroadcast(*adict)(adtens4)],
[adtens4_val],
Rebroadcast,
warn=False,
)
adtens4_bro = TensorType("float64", (True, True, True, False))()
bdict = [(0, True), (1, False), (2, False), (3, False)]
adtens4_bro_val = rng.rand(1, 1, 1, 3).astype(config.floatX)
self._compile_and_check(
[adtens4_bro],
[Rebroadcast(*bdict)(adtens4_bro)],
[adtens4_bro_val],
Rebroadcast,
)
...@@ -33,6 +33,7 @@ from aesara.tensor.basic import ( ...@@ -33,6 +33,7 @@ from aesara.tensor.basic import (
Join, Join,
MakeVector, MakeVector,
PermuteRowElements, PermuteRowElements,
Rebroadcast,
ScalarFromTensor, ScalarFromTensor,
Split, Split,
TensorFromScalar, TensorFromScalar,
...@@ -3009,6 +3010,32 @@ class TestBroadcast: ...@@ -3009,6 +3010,32 @@ class TestBroadcast:
assert isinstance(topo[1].op, MakeVector) assert isinstance(topo[1].op, MakeVector)
class TestRebroadcast(utt.InferShapeTester):
def test_rebroadcast(self):
rng = np.random.RandomState(3453)
# Rebroadcast
adtens4 = dtensor4()
adict = [(0, False), (1, True), (2, False), (3, True)]
adtens4_val = rng.rand(2, 1, 3, 1).astype(config.floatX)
self._compile_and_check(
[adtens4],
[Rebroadcast(*adict)(adtens4)],
[adtens4_val],
Rebroadcast,
warn=False,
)
adtens4_bro = TensorType("float64", (True, True, True, False))()
bdict = [(0, True), (1, False), (2, False), (3, False)]
adtens4_bro_val = rng.rand(1, 1, 1, 3).astype(config.floatX)
self._compile_and_check(
[adtens4_bro],
[Rebroadcast(*bdict)(adtens4_bro)],
[adtens4_bro_val],
Rebroadcast,
)
def test_len(): def test_len():
for shape_ in [(5,), (3, 4), (7, 4, 6)]: for shape_ in [(5,), (3, 4), (7, 4, 6)]:
x = tensor(dtype="floatX", broadcastable=(False,) * len(shape_)) x = tensor(dtype="floatX", broadcastable=(False,) * len(shape_))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论