提交 b8e6b3ea authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor and fix static shape issues in IfElse

上级 01c4a55f
差异被折叠。
...@@ -6,6 +6,7 @@ import pytest ...@@ -6,6 +6,7 @@ import pytest
import aesara import aesara
import aesara.ifelse import aesara.ifelse
import aesara.sparse
import aesara.tensor.basic as at import aesara.tensor.basic as at
from aesara import function from aesara import function
from aesara.compile.mode import Mode, get_mode from aesara.compile.mode import Mode, get_mode
...@@ -14,15 +15,19 @@ from aesara.graph.op import Op ...@@ -14,15 +15,19 @@ from aesara.graph.op import Op
from aesara.ifelse import IfElse, ifelse from aesara.ifelse import IfElse, ifelse
from aesara.link.c.type import generic from aesara.link.c.type import generic
from aesara.tensor.math import eq from aesara.tensor.math import eq
from aesara.tensor.type import col, iscalar, matrix, row, scalar, tensor3, vector from aesara.tensor.type import (
col,
iscalar,
ivector,
matrix,
row,
scalar,
tensor3,
vector,
)
from tests import unittest_tools as utt from tests import unittest_tools as utt
__docformat__ = "restructedtext en"
__authors__ = "Razvan Pascanu " "PyMC Development Team " "Aesara Developers "
__copyright__ = "(c) 2010, Universite de Montreal"
class TestIfelse(utt.OptimizationTestMixin): class TestIfelse(utt.OptimizationTestMixin):
mode = None mode = None
dtype = aesara.config.floatX dtype = aesara.config.floatX
...@@ -41,7 +46,7 @@ class TestIfelse(utt.OptimizationTestMixin): ...@@ -41,7 +46,7 @@ class TestIfelse(utt.OptimizationTestMixin):
with pytest.raises(ValueError): with pytest.raises(ValueError):
IfElse(0)(c, x, x) IfElse(0)(c, x, x)
def test_const_Op_argument(self): def test_const_false_branch(self):
x = vector("x", dtype=self.dtype) x = vector("x", dtype=self.dtype)
y = np.array([2.0, 3.0], dtype=self.dtype) y = np.array([2.0, 3.0], dtype=self.dtype)
c = iscalar("c") c = iscalar("c")
...@@ -321,9 +326,6 @@ class TestIfelse(utt.OptimizationTestMixin): ...@@ -321,9 +326,6 @@ class TestIfelse(utt.OptimizationTestMixin):
ifelse(cond, y, x) ifelse(cond, y, x)
def test_sparse_tensor_error(self): def test_sparse_tensor_error(self):
pytest.importorskip("scipy", minversion="0.7.0")
import aesara.sparse
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
data = rng.random((2, 3)).astype(self.dtype) data = rng.random((2, 3)).astype(self.dtype)
...@@ -527,6 +529,37 @@ class TestIfelse(utt.OptimizationTestMixin): ...@@ -527,6 +529,37 @@ class TestIfelse(utt.OptimizationTestMixin):
res.owner.op.as_view = True res.owner.op.as_view = True
assert str(res.owner).startswith("if{name,inplace}") assert str(res.owner).startswith("if{name,inplace}")
@pytest.mark.parametrize(
"x_shape, y_shape, x_val, y_val, exp_shape",
[
((2,), (3,), np.r_[1.0, 2.0], np.r_[1.0, 2.0, 3.0], (None,)),
((None,), (3,), np.r_[1.0, 2.0], np.r_[1.0, 2.0, 3.0], (None,)),
((3,), (3,), np.r_[1.0, 2.0, 3.0], np.r_[1.0, 2.0, 3.0], (3,)),
((1,), (3,), np.r_[1.0], np.r_[1.0, 2.0, 3.0], (None,)),
],
)
def test_static_branch_shapes(self, x_shape, y_shape, x_val, y_val, exp_shape):
x = at.tensor(dtype=self.dtype, shape=x_shape, name="x")
y = at.tensor(dtype=self.dtype, shape=y_shape, name="y")
c = iscalar("c")
z = IfElse(1)(c, x, y)
assert z.type.shape == exp_shape
f = function([c, x, y], z, mode=self.mode)
x_val = x_val.astype(self.dtype)
y_val = y_val.astype(self.dtype)
val = f(0, x_val, y_val)
assert np.array_equal(val, y_val)
def test_nonscalar_condition(self):
x = vector("x")
y = vector("y")
c = ivector("c")
with pytest.raises(TypeError, match="The condition argument"):
IfElse(1)(c, x, y)
class IfElseIfElseIf(Op): class IfElseIfElseIf(Op):
def __init__(self, inplace=False): def __init__(self, inplace=False):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论