Commit 181d5566 authored by Ricardo, committed by Brandon T. Willard

Allow partial shape information in `SpecifyShape` `Op`

Parent 383600bc
......@@ -335,7 +335,7 @@ def jax_funcify_Shape_i(op, **kwargs):
@jax_funcify.register(SpecifyShape)
def jax_funcify_SpecifyShape(op, **kwargs):
def specifyshape(x, shape):
def specifyshape(x, *shape):
assert x.ndim == len(shape)
assert jnp.all(x.shape == tuple(shape)), (
"got shape",
......
......@@ -2,6 +2,7 @@ import operator
import warnings
from contextlib import contextmanager
from functools import singledispatch
from textwrap import dedent
import numba
import numba.np.unsafe.ndarray as numba_ndarray
......@@ -40,7 +41,7 @@ from aesara.tensor.subtensor import (
Subtensor,
)
from aesara.tensor.type import TensorType
from aesara.tensor.type_other import MakeSlice
from aesara.tensor.type_other import MakeSlice, NoneConst
def numba_njit(*args, **kwargs):
......@@ -609,13 +610,28 @@ def numba_funcify_Reshape(op, **kwargs):
@numba_funcify.register(SpecifyShape)
def numba_funcify_SpecifyShape(op, node, **kwargs):
    """Generate a Numba implementation of `SpecifyShape`.

    `SpecifyShape` now takes one shape input per dimension (any of which may
    be `NoneConst`, meaning "unspecified"), so the check function is built as
    a source string with one ``assert`` per *specified* dimension and then
    compiled with Numba.
    """
    shape_inputs = node.inputs[1:]
    shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]

    # NOTE: the original comprehension reused the name `shape_input_names`
    # as the loop target, shadowing the list above; use a distinct singular
    # name so the code reads unambiguously (behavior is unchanged).
    func_conditions = [
        f"assert x.shape[{i}] == {name}"
        for i, (shape_input, name) in enumerate(zip(shape_inputs, shape_input_names))
        # `NoneConst` entries mean "don't check this dimension".
        if shape_input is not NoneConst
    ]

    func = dedent(
        f"""
        def specify_shape(x, {create_arg_string(shape_input_names)}):
            {"; ".join(func_conditions)}
            return x
        """
    )

    specify_shape = compile_function_src(func, "specify_shape", globals())
    return numba_njit(specify_shape)
def int_to_float_fn(inputs, out_dtype):
......
......@@ -64,6 +64,7 @@ from aesara.tensor.basic import (
join,
ones_like,
patternbroadcast,
stack,
switch,
tensor_copy,
unbroadcast,
......@@ -75,7 +76,14 @@ from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, broadcast_shape
from aesara.tensor.math import all as at_all
from aesara.tensor.math import eq
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, shape_padleft
from aesara.tensor.shape import (
Reshape,
Shape,
Shape_i,
SpecifyShape,
shape_i,
shape_padleft,
)
from aesara.tensor.sort import TopKOp
from aesara.tensor.subtensor import Subtensor, get_idx_list
from aesara.tensor.type import (
......@@ -84,6 +92,7 @@ from aesara.tensor.type import (
discrete_dtypes,
integer_dtypes,
)
from aesara.tensor.type_other import NoneConst
from aesara.tensor.var import TensorConstant
from aesara.utils import NoDuplicateOptWarningFilter
......@@ -3521,7 +3530,14 @@ def local_Shape_of_SpecifyShape(fgraph, node):
if not isinstance(getattr(specified_shape.owner, "op", None), SpecifyShape):
return False
return [specified_shape.owner.inputs[1].astype(np.int64)]
x, *shape = specified_shape.owner.inputs
# Replace `NoneConst` by `shape_i`
for i, sh in enumerate(shape):
if NoneConst.equals(sh):
shape[i] = shape_i(x, i, fgraph)
return [stack(shape).astype(np.int64)]
@register_useless
......
The diff is collapsed.
......@@ -1646,7 +1646,7 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
return False
obj_arg = specify_shape_node.owner.inputs[0]
shape_arg = specify_shape_node.owner.inputs[1]
shape_arg = specify_shape_node.owner.inputs[1:]
indices = get_idx_list(node.inputs, node.op.idx_list)
......
......@@ -185,7 +185,7 @@ def test_jax_specify_shape():
with config.change_flags(compute_test_value="off"):
x = SpecifyShape()(at.as_tensor_variable(x_np), (2, 3))
x = SpecifyShape()(at.as_tensor_variable(x_np), *(2, 3))
x_fg = FunctionGraph([], [x])
with pytest.raises(AssertionError):
......
......@@ -896,10 +896,15 @@ def test_Reshape_scalar():
(1, 1),
True,
),
(
set_test_value(at.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
(1, None),
False,
),
],
)
def test_SpecifyShape(v, shape, fails):
g = SpecifyShape()(v, shape)
g = SpecifyShape()(v, *shape)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if not fails else pytest.raises(AssertionError)
with cm:
......
......@@ -2975,6 +2975,24 @@ def test_local_Shape_of_SpecifyShape(shape):
assert shape in fgraph.variables
@pytest.mark.parametrize(
    "s1",
    [lscalar(), iscalar()],
)
def test_local_Shape_of_SpecifyShape_partial(s1):
    """Rewriting `Shape` of a partially specified `SpecifyShape` should
    remove the `SpecifyShape` node while keeping the symbolic shape input."""
    x = matrix()
    shape_out = specify_shape(x, (s1, None)).shape
    fgraph = FunctionGraph(outputs=[shape_out], clone=False)

    def contains_specify_shape(fg):
        return any(isinstance(node.op, SpecifyShape) for node in fg.apply_nodes)

    assert contains_specify_shape(fgraph)

    optimize_graph(fgraph, clone=False)

    # The original variables survive the rewrite...
    assert x in fgraph.variables
    assert s1 in fgraph.variables
    # ...but the `SpecifyShape` node itself is gone.
    assert not contains_specify_shape(fgraph)
def test_local_Shape_i_of_broadcastable():
x = tensor(np.float64, [False, True])
s = Shape_i(1)(x)
......
......@@ -344,13 +344,13 @@ class TestSpecifyShape(utt.InferShapeTester):
specify_shape([[1, 2, 3], [4, 5, 6]], (2.2, 3))
with pytest.raises(TypeError, match="must be integer types"):
_specify_shape([[1, 2, 3], [4, 5, 6]], (2.2, 3))
_specify_shape([[1, 2, 3], [4, 5, 6]], *(2.2, 3))
with pytest.raises(ValueError, match="will never match"):
specify_shape(matrix(), [4])
with pytest.raises(ValueError, match="will never match"):
_specify_shape(matrix(), [4])
_specify_shape(matrix(), *[4])
with pytest.raises(ValueError, match="must have fixed dimensions"):
specify_shape(matrix(), vector(dtype="int32"))
......@@ -378,6 +378,14 @@ class TestSpecifyShape(utt.InferShapeTester):
f = aesara.function([x], y, mode=self.mode)
assert f([15]) == [15]
def test_partial_shapes(self):
    """A symbolic first dimension plus a `None` (unchecked) second
    dimension should accept inputs of any matching shape."""
    x = matrix()
    s1 = lscalar()
    y = specify_shape(x, (s1, None))
    f = aesara.function([x, s1], y, mode=self.mode)

    for nrows in (2, 3):
        inp = np.zeros((nrows, 5), dtype=config.floatX)
        assert f(inp, nrows).shape == (nrows, 5)
def test_fixed_shapes(self):
x = vector()
shape = as_tensor_variable([2])
......@@ -385,6 +393,15 @@ class TestSpecifyShape(utt.InferShapeTester):
assert y.type.shape == (2,)
assert y.shape.equals(shape)
def test_fixed_partial_shapes(self):
    """Partial static shape info is merged with the input type's own."""
    # Input with no static shape info: only the specified dim becomes fixed.
    inp = TensorType("floatX", (None, None))("x")
    out = specify_shape(inp, (None, 5))
    assert out.type.shape == (None, 5)

    # Existing static info on the input is kept and combined.
    inp = TensorType("floatX", (3, None))("x")
    out = specify_shape(inp, (None, 5))
    assert out.type.shape == (3, 5)
def test_python_perform(self):
"""Test the Python `Op.perform` implementation."""
x = scalar()
......@@ -403,13 +420,20 @@ class TestSpecifyShape(utt.InferShapeTester):
with pytest.raises(AssertionError, match="SpecifyShape:.*"):
assert f([1], (2,)) == [1]
x = matrix()
y = specify_shape(x, (None, 2))
f = aesara.function([x], y, mode=Mode("py"))
assert f(np.zeros((3, 2), dtype=config.floatX)).shape == (3, 2)
with pytest.raises(AssertionError, match="SpecifyShape:.*"):
assert f(np.zeros((3, 3), dtype=config.floatX))
def test_bad_shape(self):
"""Test that at run-time we raise an exception when the shape is not the one specified."""
specify_shape = SpecifyShape()
x = vector()
xval = np.random.random((2)).astype(config.floatX)
f = aesara.function([x], specify_shape(x, [2]), mode=self.mode)
f = aesara.function([x], specify_shape(x, 2), mode=self.mode)
assert np.array_equal(f(xval), xval)
......@@ -426,7 +450,7 @@ class TestSpecifyShape(utt.InferShapeTester):
x = matrix()
xval = np.random.random((2, 3)).astype(config.floatX)
f = aesara.function([x], specify_shape(x, [2, 3]), mode=self.mode)
f = aesara.function([x], specify_shape(x, 2, 3), mode=self.mode)
assert isinstance(
[n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0]
.inputs[0]
......@@ -441,6 +465,13 @@ class TestSpecifyShape(utt.InferShapeTester):
with pytest.raises(AssertionError, match="SpecifyShape:.*"):
f(xval)
s = iscalar("s")
f = aesara.function([x, s], specify_shape(x, None, s), mode=self.mode)
x_val = np.zeros((3, 2), dtype=config.floatX)
assert f(x_val, 2).shape == (3, 2)
with pytest.raises(AssertionError, match="SpecifyShape:.*"):
f(xval, 3)
def test_infer_shape(self):
rng = np.random.default_rng(3453)
adtens4 = dtensor4()
......@@ -454,6 +485,19 @@ class TestSpecifyShape(utt.InferShapeTester):
SpecifyShape,
)
def test_infer_shape_partial(self):
    """Shape inference should work when some dims are left unspecified."""
    rng = np.random.default_rng(3453)
    adtens4 = dtensor4()
    # One symbolic scalar per checked dim; `None` leaves dim 2 unchecked.
    shape_parts = [iscalar(), iscalar(), None, iscalar()]
    shape_vals = [3, 4, 5]
    adtens4_val = rng.random((3, 4, 2, 5))

    scalar_inputs = [s for s in shape_parts if s is not None]
    self._compile_and_check(
        [adtens4, *scalar_inputs],
        [specify_shape(adtens4, shape_parts)],
        [adtens4_val, *shape_vals],
        SpecifyShape,
    )
class TestRopLop(RopLopChecker):
def test_shape(self):
......
......@@ -474,7 +474,7 @@ def makeSharedTester(
assert np.all(self.ref_fct(specify_shape_fct()) == self.ref_fct(x1_2))
topo_specify = specify_shape_fct.maker.fgraph.toposort()
if aesara.config.mode != "FAST_COMPILE":
assert len(topo_specify) == 4
assert len(topo_specify) == 3
# Test that we put the shape info into the graph
shape_constant_fct = aesara.function([], x1_specify_shape.shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论