提交 b8e6b3ea authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor and fix static shape issues in IfElse

上级 01c4a55f
差异被折叠。
......@@ -6,6 +6,7 @@ import pytest
import aesara
import aesara.ifelse
import aesara.sparse
import aesara.tensor.basic as at
from aesara import function
from aesara.compile.mode import Mode, get_mode
......@@ -14,15 +15,19 @@ from aesara.graph.op import Op
from aesara.ifelse import IfElse, ifelse
from aesara.link.c.type import generic
from aesara.tensor.math import eq
from aesara.tensor.type import col, iscalar, matrix, row, scalar, tensor3, vector
from aesara.tensor.type import (
col,
iscalar,
ivector,
matrix,
row,
scalar,
tensor3,
vector,
)
from tests import unittest_tools as utt
__docformat__ = "restructedtext en"
__authors__ = "Razvan Pascanu " "PyMC Development Team " "Aesara Developers "
__copyright__ = "(c) 2010, Universite de Montreal"
class TestIfelse(utt.OptimizationTestMixin):
mode = None
dtype = aesara.config.floatX
......@@ -41,7 +46,7 @@ class TestIfelse(utt.OptimizationTestMixin):
with pytest.raises(ValueError):
IfElse(0)(c, x, x)
def test_const_Op_argument(self):
def test_const_false_branch(self):
x = vector("x", dtype=self.dtype)
y = np.array([2.0, 3.0], dtype=self.dtype)
c = iscalar("c")
......@@ -321,9 +326,6 @@ class TestIfelse(utt.OptimizationTestMixin):
ifelse(cond, y, x)
def test_sparse_tensor_error(self):
pytest.importorskip("scipy", minversion="0.7.0")
import aesara.sparse
rng = np.random.default_rng(utt.fetch_seed())
data = rng.random((2, 3)).astype(self.dtype)
......@@ -527,6 +529,37 @@ class TestIfelse(utt.OptimizationTestMixin):
res.owner.op.as_view = True
assert str(res.owner).startswith("if{name,inplace}")
@pytest.mark.parametrize(
"x_shape, y_shape, x_val, y_val, exp_shape",
[
((2,), (3,), np.r_[1.0, 2.0], np.r_[1.0, 2.0, 3.0], (None,)),
((None,), (3,), np.r_[1.0, 2.0], np.r_[1.0, 2.0, 3.0], (None,)),
((3,), (3,), np.r_[1.0, 2.0, 3.0], np.r_[1.0, 2.0, 3.0], (3,)),
((1,), (3,), np.r_[1.0], np.r_[1.0, 2.0, 3.0], (None,)),
],
)
def test_static_branch_shapes(self, x_shape, y_shape, x_val, y_val, exp_shape):
x = at.tensor(dtype=self.dtype, shape=x_shape, name="x")
y = at.tensor(dtype=self.dtype, shape=y_shape, name="y")
c = iscalar("c")
z = IfElse(1)(c, x, y)
assert z.type.shape == exp_shape
f = function([c, x, y], z, mode=self.mode)
x_val = x_val.astype(self.dtype)
y_val = y_val.astype(self.dtype)
val = f(0, x_val, y_val)
assert np.array_equal(val, y_val)
def test_nonscalar_condition(self):
x = vector("x")
y = vector("y")
c = ivector("c")
with pytest.raises(TypeError, match="The condition argument"):
IfElse(1)(c, x, y)
class IfElseIfElseIf(Op):
def __init__(self, inplace=False):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论