提交 07706002 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Group subtensor specify_shape lift tests in class

上级 ccbab653
...@@ -247,106 +247,106 @@ def test_local_subtensor_of_alloc(): ...@@ -247,106 +247,106 @@ def test_local_subtensor_of_alloc():
assert xval.__getitem__(slices).shape == val.shape assert xval.__getitem__(slices).shape == val.shape
@pytest.mark.parametrize( class TestLocalSubtensorSpecifyShapeLift:
"x, s, idx, x_val, s_val", @pytest.mark.parametrize(
[ "x, s, idx, x_val, s_val",
( [
vector(), (
(iscalar(),), vector(),
(1,), (iscalar(),),
np.array([1, 2], dtype=config.floatX), (1,),
np.array([2], dtype=np.int64), np.array([1, 2], dtype=config.floatX),
), np.array([2], dtype=np.int64),
( ),
matrix(), (
(iscalar(), iscalar()), matrix(),
(1,), (iscalar(), iscalar()),
np.array([[1, 2], [3, 4]], dtype=config.floatX), (1,),
np.array([2, 2], dtype=np.int64), np.array([[1, 2], [3, 4]], dtype=config.floatX),
), np.array([2, 2], dtype=np.int64),
( ),
matrix(), (
(iscalar(), iscalar()), matrix(),
(0,), (iscalar(), iscalar()),
np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX), (0,),
np.array([2, 3], dtype=np.int64), np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
), np.array([2, 3], dtype=np.int64),
( ),
matrix(), (
(iscalar(), iscalar()), matrix(),
(1, 1), (iscalar(), iscalar()),
np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX), (1, 1),
np.array([2, 3], dtype=np.int64), np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
), np.array([2, 3], dtype=np.int64),
( ),
tensor3(), (
(iscalar(), iscalar(), iscalar()), tensor3(),
(-1,), (iscalar(), iscalar(), iscalar()),
np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)), (-1,),
np.array([2, 3, 5], dtype=np.int64), np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
), np.array([2, 3, 5], dtype=np.int64),
( ),
tensor3(), (
(iscalar(), iscalar(), iscalar()), tensor3(),
(-1, 0), (iscalar(), iscalar(), iscalar()),
np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)), (-1, 0),
np.array([2, 3, 5], dtype=np.int64), np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
), np.array([2, 3, 5], dtype=np.int64),
], ),
) ],
def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val): )
y = specify_shape(x, s)[idx] def test_local_subtensor_SpecifyShape_lift(self, x, s, idx, x_val, s_val):
assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape) y = specify_shape(x, s)[idx]
assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape)
rewrites = RewriteDatabaseQuery(include=[None])
no_rewrites_mode = Mode(optimizer=rewrites) rewrites = RewriteDatabaseQuery(include=[None])
no_rewrites_mode = Mode(optimizer=rewrites)
y_val_fn = function([x, *s], y, on_unused_input="ignore", mode=no_rewrites_mode)
y_val = y_val_fn(*([x_val, *s_val])) y_val_fn = function([x, *s], y, on_unused_input="ignore", mode=no_rewrites_mode)
y_val = y_val_fn(*([x_val, *s_val]))
# This optimization should appear in the canonicalizations
y_opt = rewrite_graph(y, clone=False) # This optimization should appear in the canonicalizations
y_opt = rewrite_graph(y, clone=False)
if y.ndim == 0:
# SpecifyShape should be removed altogether if y.ndim == 0:
assert isinstance(y_opt.owner.op, Subtensor) # SpecifyShape should be removed altogether
assert y_opt.owner.inputs[0] is x assert isinstance(y_opt.owner.op, Subtensor)
else: assert y_opt.owner.inputs[0] is x
assert isinstance(y_opt.owner.op, SpecifyShape) else:
assert isinstance(y_opt.owner.op, SpecifyShape)
y_opt_fn = function([x, *s], y_opt, on_unused_input="ignore")
y_opt_val = y_opt_fn(*([x_val, *s_val])) y_opt_fn = function([x, *s], y_opt, on_unused_input="ignore")
y_opt_val = y_opt_fn(*([x_val, *s_val]))
assert np.allclose(y_val, y_opt_val)
assert np.allclose(y_val, y_opt_val)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, s, idx", "x, s, idx",
[ [
( (
matrix(), matrix(),
(iscalar(), iscalar()), (iscalar(), iscalar()),
(slice(1, None),), (slice(1, None),),
), ),
( (
matrix(), matrix(),
(iscalar(), iscalar()), (iscalar(), iscalar()),
(slicetype(),), (slicetype(),),
), ),
( (
matrix(), matrix(),
(iscalar(), iscalar()), (iscalar(), iscalar()),
(1, 0), (1, 0),
), ),
], ],
) )
def test_local_subtensor_SpecifyShape_lift_fail(x, s, idx): def test_local_subtensor_SpecifyShape_lift_fail(self, x, s, idx):
y = specify_shape(x, s)[idx] y = specify_shape(x, s)[idx]
# This optimization should appear in the canonicalizations # This optimization should appear in the canonicalizations
y_opt = rewrite_graph(y, clone=False) y_opt = rewrite_graph(y, clone=False)
assert not isinstance(y_opt.owner.op, SpecifyShape) assert not isinstance(y_opt.owner.op, SpecifyShape)
class TestLocalSubtensorMakeVector: class TestLocalSubtensorMakeVector:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论