提交 181d5566 authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Allow partial shape information in `SpecifyShape` `Op`

上级 383600bc
...@@ -335,7 +335,7 @@ def jax_funcify_Shape_i(op, **kwargs): ...@@ -335,7 +335,7 @@ def jax_funcify_Shape_i(op, **kwargs):
@jax_funcify.register(SpecifyShape) @jax_funcify.register(SpecifyShape)
def jax_funcify_SpecifyShape(op, **kwargs): def jax_funcify_SpecifyShape(op, **kwargs):
def specifyshape(x, shape): def specifyshape(x, *shape):
assert x.ndim == len(shape) assert x.ndim == len(shape)
assert jnp.all(x.shape == tuple(shape)), ( assert jnp.all(x.shape == tuple(shape)), (
"got shape", "got shape",
......
...@@ -2,6 +2,7 @@ import operator ...@@ -2,6 +2,7 @@ import operator
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from functools import singledispatch from functools import singledispatch
from textwrap import dedent
import numba import numba
import numba.np.unsafe.ndarray as numba_ndarray import numba.np.unsafe.ndarray as numba_ndarray
...@@ -40,7 +41,7 @@ from aesara.tensor.subtensor import ( ...@@ -40,7 +41,7 @@ from aesara.tensor.subtensor import (
Subtensor, Subtensor,
) )
from aesara.tensor.type import TensorType from aesara.tensor.type import TensorType
from aesara.tensor.type_other import MakeSlice from aesara.tensor.type_other import MakeSlice, NoneConst
def numba_njit(*args, **kwargs): def numba_njit(*args, **kwargs):
...@@ -609,13 +610,28 @@ def numba_funcify_Reshape(op, **kwargs): ...@@ -609,13 +610,28 @@ def numba_funcify_Reshape(op, **kwargs):
@numba_funcify.register(SpecifyShape) @numba_funcify.register(SpecifyShape)
def numba_funcify_SpecifyShape(op, **kwargs): def numba_funcify_SpecifyShape(op, node, **kwargs):
@numba_njit shape_inputs = node.inputs[1:]
def specifyshape(x, shape): shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]
assert np.array_equal(x.shape, shape)
return x func_conditions = [
f"assert x.shape[{i}] == {shape_input_names}"
for i, (shape_input, shape_input_names) in enumerate(
zip(shape_inputs, shape_input_names)
)
if shape_input is not NoneConst
]
func = dedent(
f"""
def specify_shape(x, {create_arg_string(shape_input_names)}):
{"; ".join(func_conditions)}
return x
"""
)
return specifyshape specify_shape = compile_function_src(func, "specify_shape", globals())
return numba_njit(specify_shape)
def int_to_float_fn(inputs, out_dtype): def int_to_float_fn(inputs, out_dtype):
......
...@@ -64,6 +64,7 @@ from aesara.tensor.basic import ( ...@@ -64,6 +64,7 @@ from aesara.tensor.basic import (
join, join,
ones_like, ones_like,
patternbroadcast, patternbroadcast,
stack,
switch, switch,
tensor_copy, tensor_copy,
unbroadcast, unbroadcast,
...@@ -75,7 +76,14 @@ from aesara.tensor.exceptions import NotScalarConstantError, ShapeError ...@@ -75,7 +76,14 @@ from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, broadcast_shape from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, broadcast_shape
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
from aesara.tensor.math import eq from aesara.tensor.math import eq
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, shape_padleft from aesara.tensor.shape import (
Reshape,
Shape,
Shape_i,
SpecifyShape,
shape_i,
shape_padleft,
)
from aesara.tensor.sort import TopKOp from aesara.tensor.sort import TopKOp
from aesara.tensor.subtensor import Subtensor, get_idx_list from aesara.tensor.subtensor import Subtensor, get_idx_list
from aesara.tensor.type import ( from aesara.tensor.type import (
...@@ -84,6 +92,7 @@ from aesara.tensor.type import ( ...@@ -84,6 +92,7 @@ from aesara.tensor.type import (
discrete_dtypes, discrete_dtypes,
integer_dtypes, integer_dtypes,
) )
from aesara.tensor.type_other import NoneConst
from aesara.tensor.var import TensorConstant from aesara.tensor.var import TensorConstant
from aesara.utils import NoDuplicateOptWarningFilter from aesara.utils import NoDuplicateOptWarningFilter
...@@ -3521,7 +3530,14 @@ def local_Shape_of_SpecifyShape(fgraph, node): ...@@ -3521,7 +3530,14 @@ def local_Shape_of_SpecifyShape(fgraph, node):
if not isinstance(getattr(specified_shape.owner, "op", None), SpecifyShape): if not isinstance(getattr(specified_shape.owner, "op", None), SpecifyShape):
return False return False
return [specified_shape.owner.inputs[1].astype(np.int64)] x, *shape = specified_shape.owner.inputs
# Replace `NoneConst` by `shape_i`
for i, sh in enumerate(shape):
if NoneConst.equals(sh):
shape[i] = shape_i(x, i, fgraph)
return [stack(shape).astype(np.int64)]
@register_useless @register_useless
......
差异被折叠。
...@@ -1646,7 +1646,7 @@ def local_subtensor_SpecifyShape_lift(fgraph, node): ...@@ -1646,7 +1646,7 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
return False return False
obj_arg = specify_shape_node.owner.inputs[0] obj_arg = specify_shape_node.owner.inputs[0]
shape_arg = specify_shape_node.owner.inputs[1] shape_arg = specify_shape_node.owner.inputs[1:]
indices = get_idx_list(node.inputs, node.op.idx_list) indices = get_idx_list(node.inputs, node.op.idx_list)
......
...@@ -185,7 +185,7 @@ def test_jax_specify_shape(): ...@@ -185,7 +185,7 @@ def test_jax_specify_shape():
with config.change_flags(compute_test_value="off"): with config.change_flags(compute_test_value="off"):
x = SpecifyShape()(at.as_tensor_variable(x_np), (2, 3)) x = SpecifyShape()(at.as_tensor_variable(x_np), *(2, 3))
x_fg = FunctionGraph([], [x]) x_fg = FunctionGraph([], [x])
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
......
...@@ -896,10 +896,15 @@ def test_Reshape_scalar(): ...@@ -896,10 +896,15 @@ def test_Reshape_scalar():
(1, 1), (1, 1),
True, True,
), ),
(
set_test_value(at.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
(1, None),
False,
),
], ],
) )
def test_SpecifyShape(v, shape, fails): def test_SpecifyShape(v, shape, fails):
g = SpecifyShape()(v, shape) g = SpecifyShape()(v, *shape)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if not fails else pytest.raises(AssertionError) cm = contextlib.suppress() if not fails else pytest.raises(AssertionError)
with cm: with cm:
......
...@@ -2975,6 +2975,24 @@ def test_local_Shape_of_SpecifyShape(shape): ...@@ -2975,6 +2975,24 @@ def test_local_Shape_of_SpecifyShape(shape):
assert shape in fgraph.variables assert shape in fgraph.variables
@pytest.mark.parametrize(
"s1",
[lscalar(), iscalar()],
)
def test_local_Shape_of_SpecifyShape_partial(s1):
x = matrix()
s = specify_shape(x, (s1, None)).shape
fgraph = FunctionGraph(outputs=[s], clone=False)
assert any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes)
_ = optimize_graph(fgraph, clone=False)
assert x in fgraph.variables
assert s1 in fgraph.variables
assert not any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes)
def test_local_Shape_i_of_broadcastable(): def test_local_Shape_i_of_broadcastable():
x = tensor(np.float64, [False, True]) x = tensor(np.float64, [False, True])
s = Shape_i(1)(x) s = Shape_i(1)(x)
......
...@@ -344,13 +344,13 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -344,13 +344,13 @@ class TestSpecifyShape(utt.InferShapeTester):
specify_shape([[1, 2, 3], [4, 5, 6]], (2.2, 3)) specify_shape([[1, 2, 3], [4, 5, 6]], (2.2, 3))
with pytest.raises(TypeError, match="must be integer types"): with pytest.raises(TypeError, match="must be integer types"):
_specify_shape([[1, 2, 3], [4, 5, 6]], (2.2, 3)) _specify_shape([[1, 2, 3], [4, 5, 6]], *(2.2, 3))
with pytest.raises(ValueError, match="will never match"): with pytest.raises(ValueError, match="will never match"):
specify_shape(matrix(), [4]) specify_shape(matrix(), [4])
with pytest.raises(ValueError, match="will never match"): with pytest.raises(ValueError, match="will never match"):
_specify_shape(matrix(), [4]) _specify_shape(matrix(), *[4])
with pytest.raises(ValueError, match="must have fixed dimensions"): with pytest.raises(ValueError, match="must have fixed dimensions"):
specify_shape(matrix(), vector(dtype="int32")) specify_shape(matrix(), vector(dtype="int32"))
...@@ -378,6 +378,14 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -378,6 +378,14 @@ class TestSpecifyShape(utt.InferShapeTester):
f = aesara.function([x], y, mode=self.mode) f = aesara.function([x], y, mode=self.mode)
assert f([15]) == [15] assert f([15]) == [15]
def test_partial_shapes(self):
x = matrix()
s1 = lscalar()
y = specify_shape(x, (s1, None))
f = aesara.function([x, s1], y, mode=self.mode)
assert f(np.zeros((2, 5), dtype=config.floatX), 2).shape == (2, 5)
assert f(np.zeros((3, 5), dtype=config.floatX), 3).shape == (3, 5)
def test_fixed_shapes(self): def test_fixed_shapes(self):
x = vector() x = vector()
shape = as_tensor_variable([2]) shape = as_tensor_variable([2])
...@@ -385,6 +393,15 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -385,6 +393,15 @@ class TestSpecifyShape(utt.InferShapeTester):
assert y.type.shape == (2,) assert y.type.shape == (2,)
assert y.shape.equals(shape) assert y.shape.equals(shape)
def test_fixed_partial_shapes(self):
x = TensorType("floatX", (None, None))("x")
y = specify_shape(x, (None, 5))
assert y.type.shape == (None, 5)
x = TensorType("floatX", (3, None))("x")
y = specify_shape(x, (None, 5))
assert y.type.shape == (3, 5)
def test_python_perform(self): def test_python_perform(self):
"""Test the Python `Op.perform` implementation.""" """Test the Python `Op.perform` implementation."""
x = scalar() x = scalar()
...@@ -403,13 +420,20 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -403,13 +420,20 @@ class TestSpecifyShape(utt.InferShapeTester):
with pytest.raises(AssertionError, match="SpecifyShape:.*"): with pytest.raises(AssertionError, match="SpecifyShape:.*"):
assert f([1], (2,)) == [1] assert f([1], (2,)) == [1]
x = matrix()
y = specify_shape(x, (None, 2))
f = aesara.function([x], y, mode=Mode("py"))
assert f(np.zeros((3, 2), dtype=config.floatX)).shape == (3, 2)
with pytest.raises(AssertionError, match="SpecifyShape:.*"):
assert f(np.zeros((3, 3), dtype=config.floatX))
def test_bad_shape(self): def test_bad_shape(self):
"""Test that at run-time we raise an exception when the shape is not the one specified.""" """Test that at run-time we raise an exception when the shape is not the one specified."""
specify_shape = SpecifyShape() specify_shape = SpecifyShape()
x = vector() x = vector()
xval = np.random.random((2)).astype(config.floatX) xval = np.random.random((2)).astype(config.floatX)
f = aesara.function([x], specify_shape(x, [2]), mode=self.mode) f = aesara.function([x], specify_shape(x, 2), mode=self.mode)
assert np.array_equal(f(xval), xval) assert np.array_equal(f(xval), xval)
...@@ -426,7 +450,7 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -426,7 +450,7 @@ class TestSpecifyShape(utt.InferShapeTester):
x = matrix() x = matrix()
xval = np.random.random((2, 3)).astype(config.floatX) xval = np.random.random((2, 3)).astype(config.floatX)
f = aesara.function([x], specify_shape(x, [2, 3]), mode=self.mode) f = aesara.function([x], specify_shape(x, 2, 3), mode=self.mode)
assert isinstance( assert isinstance(
[n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0] [n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0]
.inputs[0] .inputs[0]
...@@ -441,6 +465,13 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -441,6 +465,13 @@ class TestSpecifyShape(utt.InferShapeTester):
with pytest.raises(AssertionError, match="SpecifyShape:.*"): with pytest.raises(AssertionError, match="SpecifyShape:.*"):
f(xval) f(xval)
s = iscalar("s")
f = aesara.function([x, s], specify_shape(x, None, s), mode=self.mode)
x_val = np.zeros((3, 2), dtype=config.floatX)
assert f(x_val, 2).shape == (3, 2)
with pytest.raises(AssertionError, match="SpecifyShape:.*"):
f(xval, 3)
def test_infer_shape(self): def test_infer_shape(self):
rng = np.random.default_rng(3453) rng = np.random.default_rng(3453)
adtens4 = dtensor4() adtens4 = dtensor4()
...@@ -454,6 +485,19 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -454,6 +485,19 @@ class TestSpecifyShape(utt.InferShapeTester):
SpecifyShape, SpecifyShape,
) )
def test_infer_shape_partial(self):
rng = np.random.default_rng(3453)
adtens4 = dtensor4()
aivec = [iscalar(), iscalar(), None, iscalar()]
aivec_val = [3, 4, 5]
adtens4_val = rng.random((3, 4, 2, 5))
self._compile_and_check(
[adtens4, *(ivec for ivec in aivec if ivec is not None)],
[specify_shape(adtens4, aivec)],
[adtens4_val, *aivec_val],
SpecifyShape,
)
class TestRopLop(RopLopChecker): class TestRopLop(RopLopChecker):
def test_shape(self): def test_shape(self):
......
...@@ -474,7 +474,7 @@ def makeSharedTester( ...@@ -474,7 +474,7 @@ def makeSharedTester(
assert np.all(self.ref_fct(specify_shape_fct()) == self.ref_fct(x1_2)) assert np.all(self.ref_fct(specify_shape_fct()) == self.ref_fct(x1_2))
topo_specify = specify_shape_fct.maker.fgraph.toposort() topo_specify = specify_shape_fct.maker.fgraph.toposort()
if aesara.config.mode != "FAST_COMPILE": if aesara.config.mode != "FAST_COMPILE":
assert len(topo_specify) == 4 assert len(topo_specify) == 3
# Test that we put the shape info into the graph # Test that we put the shape info into the graph
shape_constant_fct = aesara.function([], x1_specify_shape.shape) shape_constant_fct = aesara.function([], x1_specify_shape.shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论