提交 9768905a authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Fix local_subtensor_SpecifyShape_lift

上级 434a70b9
...@@ -1632,7 +1632,7 @@ def local_subtensor_shape_constant(fgraph, node): ...@@ -1632,7 +1632,7 @@ def local_subtensor_shape_constant(fgraph, node):
@register_canonicalize @register_canonicalize
@local_optimizer([Subtensor]) @local_optimizer([Subtensor])
def local_subtensor_SpecifyShape_lift(fgraph, node): def local_subtensor_SpecifyShape_lift(fgraph, node):
"""Lift ``specify_shape(x, s)[i]`` to ``specify_shape(x[i], s[i])``.""" """Lift ``specify_shape(x, s)[i_1, ..., i_n]`` to ``specify_shape(x[i1, ... , i_n], s[n:])``."""
if not isinstance(node.op, Subtensor): if not isinstance(node.op, Subtensor):
return False return False
...@@ -1650,19 +1650,14 @@ def local_subtensor_SpecifyShape_lift(fgraph, node): ...@@ -1650,19 +1650,14 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
indices = get_idx_list(node.inputs, node.op.idx_list) indices = get_idx_list(node.inputs, node.op.idx_list)
if len(indices) > 1 or len(indices) == 0: if any(
return False isinstance(index, slice) or isinstance(getattr(index, "type", None), SliceType)
for index in indices
if isinstance(indices[0], slice) or isinstance(
getattr(indices[0], "type", None), SliceType
): ):
return False return False
new_obj_arg = obj_arg[indices] new_obj_arg = obj_arg[indices]
# No need to specify shape for scalar outputs
if new_obj_arg.ndim == 0: if new_obj_arg.ndim == 0:
new_shape_arg = as_tensor([], dtype=np.int64) return [new_obj_arg]
else: return [specify_shape(new_obj_arg, shape_arg[len(indices) :])]
new_shape_arg = as_tensor(shape_arg[indices], dtype=np.int64, ndim=1)
return [specify_shape(new_obj_arg, new_shape_arg)]
...@@ -2102,6 +2102,13 @@ def test_local_subtensor_shape_constant(): ...@@ -2102,6 +2102,13 @@ def test_local_subtensor_shape_constant():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, s, idx, x_val, s_val", "x, s, idx, x_val, s_val",
[ [
(
vector(),
(iscalar(),),
(1,),
np.array([1, 2], dtype=config.floatX),
np.array([2], dtype=np.int64),
),
( (
matrix(), matrix(),
(iscalar(), iscalar()), (iscalar(), iscalar()),
...@@ -2110,16 +2117,38 @@ def test_local_subtensor_shape_constant(): ...@@ -2110,16 +2117,38 @@ def test_local_subtensor_shape_constant():
np.array([2, 2], dtype=np.int64), np.array([2, 2], dtype=np.int64),
), ),
( (
vector(), matrix(),
(iscalar(),), (iscalar(), iscalar()),
(1,), (0,),
np.array([1, 2], dtype=config.floatX), np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
np.array([2], dtype=np.int64), np.array([2, 3], dtype=np.int64),
),
(
matrix(),
(iscalar(), iscalar()),
(1, 1),
np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
np.array([2, 3], dtype=np.int64),
),
(
tensor3(),
(iscalar(), iscalar(), iscalar()),
(-1,),
np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
np.array([2, 3, 5], dtype=np.int64),
),
(
tensor3(),
(iscalar(), iscalar(), iscalar()),
(-1, 0),
np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
np.array([2, 3, 5], dtype=np.int64),
), ),
], ],
) )
def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val): def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
y = specify_shape(x, s)[idx] y = specify_shape(x, s)[idx]
assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape)
opts = OptimizationQuery(include=[None]) opts = OptimizationQuery(include=[None])
no_opt_mode = Mode(optimizer=opts) no_opt_mode = Mode(optimizer=opts)
...@@ -2130,6 +2159,11 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val): ...@@ -2130,6 +2159,11 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
# This optimization should appear in the canonicalizations # This optimization should appear in the canonicalizations
y_opt = optimize_graph(y, clone=False) y_opt = optimize_graph(y, clone=False)
if y.ndim == 0:
# SpecifyShape should be removed altogether
assert isinstance(y_opt.owner.op, Subtensor)
assert y_opt.owner.inputs[0] is x
else:
assert isinstance(y_opt.owner.op, SpecifyShape) assert isinstance(y_opt.owner.op, SpecifyShape)
y_opt_fn = function([x] + list(s), y_opt, on_unused_input="ignore") y_opt_fn = function([x] + list(s), y_opt, on_unused_input="ignore")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论