提交 9768905a authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Fix local_subtensor_SpecifyShape_lift

上级 434a70b9
......@@ -1632,7 +1632,7 @@ def local_subtensor_shape_constant(fgraph, node):
@register_canonicalize
@local_optimizer([Subtensor])
def local_subtensor_SpecifyShape_lift(fgraph, node):
"""Lift ``specify_shape(x, s)[i]`` to ``specify_shape(x[i], s[i])``."""
"""Lift ``specify_shape(x, s)[i_1, ..., i_n]`` to ``specify_shape(x[i1, ... , i_n], s[n:])``."""
if not isinstance(node.op, Subtensor):
return False
......@@ -1650,19 +1650,14 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
indices = get_idx_list(node.inputs, node.op.idx_list)
if len(indices) > 1 or len(indices) == 0:
return False
if isinstance(indices[0], slice) or isinstance(
getattr(indices[0], "type", None), SliceType
if any(
isinstance(index, slice) or isinstance(getattr(index, "type", None), SliceType)
for index in indices
):
return False
new_obj_arg = obj_arg[indices]
# No need to specify shape for scalar outputs
if new_obj_arg.ndim == 0:
new_shape_arg = as_tensor([], dtype=np.int64)
else:
new_shape_arg = as_tensor(shape_arg[indices], dtype=np.int64, ndim=1)
return [specify_shape(new_obj_arg, new_shape_arg)]
return [new_obj_arg]
return [specify_shape(new_obj_arg, shape_arg[len(indices) :])]
......@@ -2102,6 +2102,13 @@ def test_local_subtensor_shape_constant():
@pytest.mark.parametrize(
"x, s, idx, x_val, s_val",
[
(
vector(),
(iscalar(),),
(1,),
np.array([1, 2], dtype=config.floatX),
np.array([2], dtype=np.int64),
),
(
matrix(),
(iscalar(), iscalar()),
......@@ -2110,16 +2117,38 @@ def test_local_subtensor_shape_constant():
np.array([2, 2], dtype=np.int64),
),
(
vector(),
(iscalar(),),
(1,),
np.array([1, 2], dtype=config.floatX),
np.array([2], dtype=np.int64),
matrix(),
(iscalar(), iscalar()),
(0,),
np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
np.array([2, 3], dtype=np.int64),
),
(
matrix(),
(iscalar(), iscalar()),
(1, 1),
np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
np.array([2, 3], dtype=np.int64),
),
(
tensor3(),
(iscalar(), iscalar(), iscalar()),
(-1,),
np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
np.array([2, 3, 5], dtype=np.int64),
),
(
tensor3(),
(iscalar(), iscalar(), iscalar()),
(-1, 0),
np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
np.array([2, 3, 5], dtype=np.int64),
),
],
)
def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
y = specify_shape(x, s)[idx]
assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape)
opts = OptimizationQuery(include=[None])
no_opt_mode = Mode(optimizer=opts)
......@@ -2130,7 +2159,12 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
# This optimization should appear in the canonicalizations
y_opt = optimize_graph(y, clone=False)
assert isinstance(y_opt.owner.op, SpecifyShape)
if y.ndim == 0:
# SpecifyShape should be removed altogether
assert isinstance(y_opt.owner.op, Subtensor)
assert y_opt.owner.inputs[0] is x
else:
assert isinstance(y_opt.owner.op, SpecifyShape)
y_opt_fn = function([x] + list(s), y_opt, on_unused_input="ignore")
y_opt_val = y_opt_fn(*([x_val] + [s_ for s_ in s_val]))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论