提交 07706002 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Group subtensor specify_shape lift tests in class

上级 ccbab653
......@@ -247,106 +247,106 @@ def test_local_subtensor_of_alloc():
assert xval.__getitem__(slices).shape == val.shape
@pytest.mark.parametrize(
"x, s, idx, x_val, s_val",
[
(
vector(),
(iscalar(),),
(1,),
np.array([1, 2], dtype=config.floatX),
np.array([2], dtype=np.int64),
),
(
matrix(),
(iscalar(), iscalar()),
(1,),
np.array([[1, 2], [3, 4]], dtype=config.floatX),
np.array([2, 2], dtype=np.int64),
),
(
matrix(),
(iscalar(), iscalar()),
(0,),
np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
np.array([2, 3], dtype=np.int64),
),
(
matrix(),
(iscalar(), iscalar()),
(1, 1),
np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
np.array([2, 3], dtype=np.int64),
),
(
tensor3(),
(iscalar(), iscalar(), iscalar()),
(-1,),
np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
np.array([2, 3, 5], dtype=np.int64),
),
(
tensor3(),
(iscalar(), iscalar(), iscalar()),
(-1, 0),
np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
np.array([2, 3, 5], dtype=np.int64),
),
],
)
def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
y = specify_shape(x, s)[idx]
assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape)
rewrites = RewriteDatabaseQuery(include=[None])
no_rewrites_mode = Mode(optimizer=rewrites)
y_val_fn = function([x, *s], y, on_unused_input="ignore", mode=no_rewrites_mode)
y_val = y_val_fn(*([x_val, *s_val]))
# This optimization should appear in the canonicalizations
y_opt = rewrite_graph(y, clone=False)
if y.ndim == 0:
# SpecifyShape should be removed altogether
assert isinstance(y_opt.owner.op, Subtensor)
assert y_opt.owner.inputs[0] is x
else:
assert isinstance(y_opt.owner.op, SpecifyShape)
y_opt_fn = function([x, *s], y_opt, on_unused_input="ignore")
y_opt_val = y_opt_fn(*([x_val, *s_val]))
assert np.allclose(y_val, y_opt_val)
@pytest.mark.parametrize(
"x, s, idx",
[
(
matrix(),
(iscalar(), iscalar()),
(slice(1, None),),
),
(
matrix(),
(iscalar(), iscalar()),
(slicetype(),),
),
(
matrix(),
(iscalar(), iscalar()),
(1, 0),
),
],
)
def test_local_subtensor_SpecifyShape_lift_fail(x, s, idx):
y = specify_shape(x, s)[idx]
# This optimization should appear in the canonicalizations
y_opt = rewrite_graph(y, clone=False)
assert not isinstance(y_opt.owner.op, SpecifyShape)
class TestLocalSubtensorSpecifyShapeLift:
@pytest.mark.parametrize(
"x, s, idx, x_val, s_val",
[
(
vector(),
(iscalar(),),
(1,),
np.array([1, 2], dtype=config.floatX),
np.array([2], dtype=np.int64),
),
(
matrix(),
(iscalar(), iscalar()),
(1,),
np.array([[1, 2], [3, 4]], dtype=config.floatX),
np.array([2, 2], dtype=np.int64),
),
(
matrix(),
(iscalar(), iscalar()),
(0,),
np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
np.array([2, 3], dtype=np.int64),
),
(
matrix(),
(iscalar(), iscalar()),
(1, 1),
np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
np.array([2, 3], dtype=np.int64),
),
(
tensor3(),
(iscalar(), iscalar(), iscalar()),
(-1,),
np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
np.array([2, 3, 5], dtype=np.int64),
),
(
tensor3(),
(iscalar(), iscalar(), iscalar()),
(-1, 0),
np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
np.array([2, 3, 5], dtype=np.int64),
),
],
)
def test_local_subtensor_SpecifyShape_lift(self, x, s, idx, x_val, s_val):
y = specify_shape(x, s)[idx]
assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape)
rewrites = RewriteDatabaseQuery(include=[None])
no_rewrites_mode = Mode(optimizer=rewrites)
y_val_fn = function([x, *s], y, on_unused_input="ignore", mode=no_rewrites_mode)
y_val = y_val_fn(*([x_val, *s_val]))
# This optimization should appear in the canonicalizations
y_opt = rewrite_graph(y, clone=False)
if y.ndim == 0:
# SpecifyShape should be removed altogether
assert isinstance(y_opt.owner.op, Subtensor)
assert y_opt.owner.inputs[0] is x
else:
assert isinstance(y_opt.owner.op, SpecifyShape)
y_opt_fn = function([x, *s], y_opt, on_unused_input="ignore")
y_opt_val = y_opt_fn(*([x_val, *s_val]))
assert np.allclose(y_val, y_opt_val)
@pytest.mark.parametrize(
"x, s, idx",
[
(
matrix(),
(iscalar(), iscalar()),
(slice(1, None),),
),
(
matrix(),
(iscalar(), iscalar()),
(slicetype(),),
),
(
matrix(),
(iscalar(), iscalar()),
(1, 0),
),
],
)
def test_local_subtensor_SpecifyShape_lift_fail(self, x, s, idx):
y = specify_shape(x, s)[idx]
# This optimization should appear in the canonicalizations
y_opt = rewrite_graph(y, clone=False)
assert not isinstance(y_opt.owner.op, SpecifyShape)
class TestLocalSubtensorMakeVector:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论