提交 181d5566 作者: Ricardo 提交者: Brandon T. Willard

Allow partial shape information in `SpecifyShape` `Op`

上级 383600bc
...@@ -335,7 +335,7 @@ def jax_funcify_Shape_i(op, **kwargs): ...@@ -335,7 +335,7 @@ def jax_funcify_Shape_i(op, **kwargs):
@jax_funcify.register(SpecifyShape) @jax_funcify.register(SpecifyShape)
def jax_funcify_SpecifyShape(op, **kwargs): def jax_funcify_SpecifyShape(op, **kwargs):
def specifyshape(x, shape): def specifyshape(x, *shape):
assert x.ndim == len(shape) assert x.ndim == len(shape)
assert jnp.all(x.shape == tuple(shape)), ( assert jnp.all(x.shape == tuple(shape)), (
"got shape", "got shape",
......
...@@ -2,6 +2,7 @@ import operator ...@@ -2,6 +2,7 @@ import operator
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from functools import singledispatch from functools import singledispatch
from textwrap import dedent
import numba import numba
import numba.np.unsafe.ndarray as numba_ndarray import numba.np.unsafe.ndarray as numba_ndarray
...@@ -40,7 +41,7 @@ from aesara.tensor.subtensor import ( ...@@ -40,7 +41,7 @@ from aesara.tensor.subtensor import (
Subtensor, Subtensor,
) )
from aesara.tensor.type import TensorType from aesara.tensor.type import TensorType
from aesara.tensor.type_other import MakeSlice from aesara.tensor.type_other import MakeSlice, NoneConst
def numba_njit(*args, **kwargs): def numba_njit(*args, **kwargs):
...@@ -609,13 +610,28 @@ def numba_funcify_Reshape(op, **kwargs): ...@@ -609,13 +610,28 @@ def numba_funcify_Reshape(op, **kwargs):
@numba_funcify.register(SpecifyShape) @numba_funcify.register(SpecifyShape)
def numba_funcify_SpecifyShape(op, **kwargs): def numba_funcify_SpecifyShape(op, node, **kwargs):
@numba_njit shape_inputs = node.inputs[1:]
def specifyshape(x, shape): shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]
assert np.array_equal(x.shape, shape)
func_conditions = [
f"assert x.shape[{i}] == {shape_input_names}"
for i, (shape_input, shape_input_names) in enumerate(
zip(shape_inputs, shape_input_names)
)
if shape_input is not NoneConst
]
func = dedent(
f"""
def specify_shape(x, {create_arg_string(shape_input_names)}):
{"; ".join(func_conditions)}
return x return x
"""
)
return specifyshape specify_shape = compile_function_src(func, "specify_shape", globals())
return numba_njit(specify_shape)
def int_to_float_fn(inputs, out_dtype): def int_to_float_fn(inputs, out_dtype):
......
...@@ -64,6 +64,7 @@ from aesara.tensor.basic import ( ...@@ -64,6 +64,7 @@ from aesara.tensor.basic import (
join, join,
ones_like, ones_like,
patternbroadcast, patternbroadcast,
stack,
switch, switch,
tensor_copy, tensor_copy,
unbroadcast, unbroadcast,
...@@ -75,7 +76,14 @@ from aesara.tensor.exceptions import NotScalarConstantError, ShapeError ...@@ -75,7 +76,14 @@ from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, broadcast_shape from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, broadcast_shape
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
from aesara.tensor.math import eq from aesara.tensor.math import eq
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, shape_padleft from aesara.tensor.shape import (
Reshape,
Shape,
Shape_i,
SpecifyShape,
shape_i,
shape_padleft,
)
from aesara.tensor.sort import TopKOp from aesara.tensor.sort import TopKOp
from aesara.tensor.subtensor import Subtensor, get_idx_list from aesara.tensor.subtensor import Subtensor, get_idx_list
from aesara.tensor.type import ( from aesara.tensor.type import (
...@@ -84,6 +92,7 @@ from aesara.tensor.type import ( ...@@ -84,6 +92,7 @@ from aesara.tensor.type import (
discrete_dtypes, discrete_dtypes,
integer_dtypes, integer_dtypes,
) )
from aesara.tensor.type_other import NoneConst
from aesara.tensor.var import TensorConstant from aesara.tensor.var import TensorConstant
from aesara.utils import NoDuplicateOptWarningFilter from aesara.utils import NoDuplicateOptWarningFilter
...@@ -3521,7 +3530,14 @@ def local_Shape_of_SpecifyShape(fgraph, node): ...@@ -3521,7 +3530,14 @@ def local_Shape_of_SpecifyShape(fgraph, node):
if not isinstance(getattr(specified_shape.owner, "op", None), SpecifyShape): if not isinstance(getattr(specified_shape.owner, "op", None), SpecifyShape):
return False return False
return [specified_shape.owner.inputs[1].astype(np.int64)] x, *shape = specified_shape.owner.inputs
# Replace `NoneConst` by `shape_i`
for i, sh in enumerate(shape):
if NoneConst.equals(sh):
shape[i] = shape_i(x, i, fgraph)
return [stack(shape).astype(np.int64)]
@register_useless @register_useless
......
import warnings import warnings
from numbers import Number from numbers import Number
from textwrap import dedent
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
import numpy as np import numpy as np
import aesara import aesara
from aesara.gradient import DisconnectedType from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Variable
from aesara.link.c.op import COp from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType from aesara.link.c.params_type import ParamsType
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
...@@ -15,7 +16,8 @@ from aesara.tensor import _get_vector_length ...@@ -15,7 +16,8 @@ from aesara.tensor import _get_vector_length
from aesara.tensor import basic as at from aesara.tensor import basic as at
from aesara.tensor import get_vector_length from aesara.tensor import get_vector_length
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.type import TensorType, int_dtypes, tensor from aesara.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
from aesara.tensor.type_other import NoneConst
from aesara.tensor.var import TensorConstant, TensorVariable from aesara.tensor.var import TensorConstant, TensorVariable
...@@ -362,28 +364,6 @@ def register_shape_i_c_code(typ, code, check_input, version=()): ...@@ -362,28 +364,6 @@ def register_shape_i_c_code(typ, code, check_input, version=()):
Shape_i.c_code_and_version[typ] = (code, check_input, version) Shape_i.c_code_and_version[typ] = (code, check_input, version)
def register_specify_shape_c_code(typ, code, version=(), c_support_code_apply=None):
"""
Tell SpecifyShape how to generate C code for an Aesara Type.
Parameters
----------
typ : Aesara type
It must be the Aesara class itself and not an instance of the class.
code : C code
Checks the shape and returns a view for the Aesara type 'typ'.
Use %(iname)s and %(oname)s for the input and output C variable names
respectively. %(shape)s is the vector of shape of %(iname)s.
Check that its length is good.
version
A number indicating the version of the code, for cache.
c_support_code_apply
Extra code.
"""
SpecifyShape.c_code_and_version[typ] = (code, version, c_support_code_apply)
class SpecifyShape(COp): class SpecifyShape(COp):
""" """
L{Op} that puts into the graph the user-provided shape. L{Op} that puts into the graph the user-provided shape.
...@@ -396,33 +376,29 @@ class SpecifyShape(COp): ...@@ -396,33 +376,29 @@ class SpecifyShape(COp):
Notes Notes
----- -----
Maybe in the future we will never do the assert! Maybe in the future we will never do the assert!
We currently don't support specifying partial shape information.
TODO : test this op with sparse. Do C code for them too.
""" """
view_map = {0: [0]} view_map = {0: [0]}
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version: Dict = {}
__props__ = () __props__ = ()
_f16_ok = True _f16_ok = True
def make_node(self, x, shape): def make_node(self, x, *shape):
if not isinstance(x, Variable): from aesara.tensor.basic import get_scalar_constant_value
x = at.as_tensor_variable(x)
shape = at.as_tensor_variable(shape, ndim=1) x = at.as_tensor_variable(x)
if isinstance(shape, Constant): shape = tuple(
shape = tuple(shape.data) NoneConst
else: if (s is None or NoneConst.equals(s))
shape = tuple(at.as_tensor_variable(s, ndim=0) for s in shape) else at.as_tensor_variable(s, ndim=0)
for s in shape
)
if any(s.dtype not in aesara.tensor.type.integer_dtypes for s in shape): if any(
s.dtype not in aesara.tensor.type.integer_dtypes
for s in shape
if hasattr(s, "dtype")
):
raise TypeError("Shape values must be integer types") raise TypeError("Shape values must be integer types")
if len(shape) != x.type.ndim: if len(shape) != x.type.ndim:
...@@ -430,102 +406,127 @@ class SpecifyShape(COp): ...@@ -430,102 +406,127 @@ class SpecifyShape(COp):
f"Input `x` is {x.type.ndim}-dimensional and will never match a shape of length {len(shape)}." f"Input `x` is {x.type.ndim}-dimensional and will never match a shape of length {len(shape)}."
) )
if isinstance(x.type, TensorType) and all(isinstance(s, Number) for s in shape): type_shape = [None] * x.ndim
out_var = x.type.clone(shape=shape)() for i, (xts, s) in enumerate(zip(x.type.shape, shape)):
if xts is not None:
type_shape[i] = xts
else: else:
out_var = x.type() try:
type_s = get_scalar_constant_value(s)
if type_s is not None:
type_shape[i] = int(type_s)
except NotScalarConstantError:
pass
out_var = x.type.clone(shape=type_shape)()
in_shape = at.as_tensor_variable(shape, ndim=1) return Apply(self, [x, *shape], [out_var])
return Apply(self, [x, in_shape], [out_var])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
x, shape = inp x, *shape = inp
(out,) = out_ (out,) = out_
ndim = len(shape) ndim = len(shape)
if x.ndim != ndim: if x.ndim != ndim:
raise AssertionError( raise AssertionError(
f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}." f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
) )
if x.shape != tuple(shape): if not all(xs == s for xs, s in zip(x.shape, shape) if s is not None):
raise AssertionError( raise AssertionError(
f"SpecifyShape: Got shape {x.shape}, expected {tuple(shape)}." f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}."
) )
out[0] = x out[0] = x
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
xshape, sshape = shapes xshape, *_ = shapes
shape = node.inputs[1:]
new_shape = [] new_shape = []
for dim in range(node.inputs[0].type.ndim): for dim in range(node.inputs[0].type.ndim):
s = shape[dim]
try: try:
s = at.get_scalar_constant_value(node.inputs[1][dim]) s = at.get_scalar_constant_value(s)
s = at.as_tensor_variable(s) # We assume that `None` shapes are always retrieved by
new_shape.append(s) # `get_scalar_constant_value`, and only in that case do we default to
# the shape of the input variable
if s is None:
s = xshape[dim]
except NotScalarConstantError: except NotScalarConstantError:
new_shape.append(node.inputs[1][dim]) pass
new_shape.append(at.as_tensor_variable(s))
assert len(new_shape) == len(xshape) assert len(new_shape) == len(xshape)
return [new_shape] return [new_shape]
def connection_pattern(self, node): def connection_pattern(self, node):
return [[True], [False]] return [[True], *[[False]] * len(node.inputs[1:])]
def grad(self, inp, grads): def grad(self, inp, grads):
x, s = inp x, *shape = inp
(gz,) = grads (gz,) = grads
# Should I set an SpecifyShape on gz? I think so # Should I set an SpecifyShape on gz? I think so
# But I don't do it now as we need to make an optimization # But I don't do it now as we need to make an optimization
# to remove that op from the graph to don't block other optimization # to remove that op from the graph to don't block other optimization
# Should I do an optimizer that will remove the SpecifyShape? # Should I do an optimizer that will remove the SpecifyShape?
# I think Yes # I think Yes
return [gz, aesara.gradient.DisconnectedType()()] # return [specify_shape(gz, s)] + [aesara.gradient.DisconnectedType()() for _ in range(len(shape))]
return [specify_shape(gz, s), aesara.gradient.DisconnectedType()()] return [gz] + [aesara.gradient.DisconnectedType()() for _ in range(len(shape))]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
# It means that the this op sits on top of a non-differentiable # It means that this op sits on top of a non-differentiable path
# path
return [None] return [None]
return self.make_node(eval_points[0], *inputs[1:]).outputs return self.make_node(eval_points[0], *inputs[1:]).outputs
def c_support_code_apply(self, node, name): def c_code(self, node, name, i_names, o_names, sub):
itype = node.inputs[0].type.__class__ if not isinstance(node.inputs[0].type, DenseTensorType):
if itype in self.c_code_and_version: raise NotImplementedError(
_, _, support_code = self.c_code_and_version[itype] f"Specify_shape c_code not implemented for input type {node.inputs[0].type}"
if support_code: )
return support_code
return super().c_support_code_apply(node, name)
def c_code(self, node, name, inames, onames, sub): x_name, *shape_names = i_names
iname, shape = inames (o_name,) = o_names
(oname,) = onames
fail = sub["fail"] fail = sub["fail"]
itype = node.inputs[0].type.__class__ code = dedent(
if itype in self.c_code_and_version: f"""
code, version, _ = self.c_code_and_version[itype] if (PyArray_NDIM({x_name}) != {len(shape_names)}) {{
return code % locals() PyErr_Format(PyExc_AssertionError,
"SpecifyShape: Got %d dimensions, expected %d dimensions.",
PyArray_NDIM({x_name}), {len(shape_names)}
);
{fail};
}}
"""
)
raise NotImplementedError() for i, (shp_name, shp) in enumerate(zip(shape_names, node.inputs[1:])):
if NoneConst.equals(shp):
continue
code += dedent(
f"""
if (py_{shp_name} != Py_None){{
dtype_{shp_name} shp = ((dtype_{shp_name}*)PyArray_GETPTR1({shp_name}, 0))[0];
if (PyArray_DIMS({x_name})[{i}] != shp) {{
PyErr_Format(PyExc_AssertionError,
"SpecifyShape: dim %d of input has shape %d, expected %d.",
{i}, PyArray_DIMS({x_name})[{i}], shp
);
{fail};
}}
}}
"""
)
def c_code_cache_version(self): code += dedent(
version = [] f"""
# If any of the c code is unversioned, we have to return () Py_XDECREF({o_name});
# Else, we will return a list of (type name, version) pairs. {o_name} = {x_name};
for t, (c, v, _) in sorted( Py_XINCREF({o_name});
self.c_code_and_version.items(), key=lambda pair: str(pair[0]) """
):
if not v:
warnings.warn(
"Type %s has C code for SpecifyShape, but it "
"has no version. You should add a 'version' "
"keyword arg when calling "
"register_specify_shape_c_code." % t,
stacklevel=2,
) )
return () return code
version.append((str(t), v))
return tuple(version) def c_code_cache_version(self):
return (2,)
_specify_shape = SpecifyShape() _specify_shape = SpecifyShape()
...@@ -537,29 +538,31 @@ def specify_shape( ...@@ -537,29 +538,31 @@ def specify_shape(
int, List[Union[int, Variable]], Tuple[Union[int, Variable]], Variable int, List[Union[int, Variable]], Tuple[Union[int, Variable]], Variable
], ],
): ):
"""Specify a fixed shape for a `Variable`.""" """Specify a fixed shape for a `Variable`.
if not isinstance(x, Variable): If a dimension's shape value is ``None``, the size of that dimension is not considered fixed/static at runtime.
x = at.as_tensor_variable(x) """
if np.ndim(shape) == 0: if not isinstance(shape, (tuple, list)):
shape = at.as_tensor_variable([shape]) shape = (shape,)
# If shape is a symbolic 1d vector of fixed length, we separate the items into a
# tuple with one entry per shape dimension
if len(shape) == 1 and shape[0] is not None:
shape_vector = at.as_tensor_variable(shape[0])
if shape_vector.ndim == 1:
try: try:
_ = get_vector_length(shape) shape = tuple(shape_vector)
except ValueError: except ValueError:
raise ValueError("Shape must have fixed dimensions") raise ValueError("Shape vector must have fixed dimensions")
if isinstance(shape, Constant): return _specify_shape(x, *shape)
shape = tuple(shape.data)
return _specify_shape(x, shape)
@_get_vector_length.register(SpecifyShape) @_get_vector_length.register(SpecifyShape)
def _get_vector_length_SpecifyShape(op, var): def _get_vector_length_SpecifyShape(op, var):
try: try:
return at.get_scalar_constant_value(var.owner.inputs[1]) return at.get_scalar_constant_value(var.owner.inputs[1]).item()
except NotScalarConstantError: except NotScalarConstantError:
raise ValueError(f"Length of {var} cannot be determined") raise ValueError(f"Length of {var} cannot be determined")
...@@ -882,34 +885,3 @@ register_shape_i_c_code( ...@@ -882,34 +885,3 @@ register_shape_i_c_code(
""", """,
version=3, version=3,
) )
register_specify_shape_c_code(
TensorType,
"""
if (PyArray_NDIM(%(iname)s) != PyArray_DIMS(%(shape)s)[0]) {
PyErr_Format(PyExc_AssertionError,
"SpecifyShape: Got %%d dimensions, expected %%d dimensions.",
PyArray_NDIM(%(iname)s),
PyArray_DIMS(%(shape)s)[0]
);
%(fail)s;
}
for(int i = 0; i < PyArray_NDIM(%(iname)s); i++){
dtype_%(shape)s shp = ((dtype_%(shape)s*)PyArray_GETPTR1(%(shape)s,
i))[0];
if (PyArray_DIMS(%(iname)s)[i] != shp) {
PyErr_Format(PyExc_AssertionError,
"SpecifyShape: dim %%d of input has shape %%d,"
" expected %%d.",
i, PyArray_DIMS(%(iname)s)[i],
shp);
%(fail)s;
}
}
Py_XDECREF(%(oname)s);
%(oname)s = %(iname)s;
Py_XINCREF(%(oname)s);
""",
version=1,
)
...@@ -1646,7 +1646,7 @@ def local_subtensor_SpecifyShape_lift(fgraph, node): ...@@ -1646,7 +1646,7 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
return False return False
obj_arg = specify_shape_node.owner.inputs[0] obj_arg = specify_shape_node.owner.inputs[0]
shape_arg = specify_shape_node.owner.inputs[1] shape_arg = specify_shape_node.owner.inputs[1:]
indices = get_idx_list(node.inputs, node.op.idx_list) indices = get_idx_list(node.inputs, node.op.idx_list)
......
...@@ -185,7 +185,7 @@ def test_jax_specify_shape(): ...@@ -185,7 +185,7 @@ def test_jax_specify_shape():
with config.change_flags(compute_test_value="off"): with config.change_flags(compute_test_value="off"):
x = SpecifyShape()(at.as_tensor_variable(x_np), (2, 3)) x = SpecifyShape()(at.as_tensor_variable(x_np), *(2, 3))
x_fg = FunctionGraph([], [x]) x_fg = FunctionGraph([], [x])
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
......
...@@ -896,10 +896,15 @@ def test_Reshape_scalar(): ...@@ -896,10 +896,15 @@ def test_Reshape_scalar():
(1, 1), (1, 1),
True, True,
), ),
(
set_test_value(at.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
(1, None),
False,
),
], ],
) )
def test_SpecifyShape(v, shape, fails): def test_SpecifyShape(v, shape, fails):
g = SpecifyShape()(v, shape) g = SpecifyShape()(v, *shape)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if not fails else pytest.raises(AssertionError) cm = contextlib.suppress() if not fails else pytest.raises(AssertionError)
with cm: with cm:
......
...@@ -2975,6 +2975,24 @@ def test_local_Shape_of_SpecifyShape(shape): ...@@ -2975,6 +2975,24 @@ def test_local_Shape_of_SpecifyShape(shape):
assert shape in fgraph.variables assert shape in fgraph.variables
@pytest.mark.parametrize(
"s1",
[lscalar(), iscalar()],
)
def test_local_Shape_of_SpecifyShape_partial(s1):
x = matrix()
s = specify_shape(x, (s1, None)).shape
fgraph = FunctionGraph(outputs=[s], clone=False)
assert any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes)
_ = optimize_graph(fgraph, clone=False)
assert x in fgraph.variables
assert s1 in fgraph.variables
assert not any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes)
def test_local_Shape_i_of_broadcastable(): def test_local_Shape_i_of_broadcastable():
x = tensor(np.float64, [False, True]) x = tensor(np.float64, [False, True])
s = Shape_i(1)(x) s = Shape_i(1)(x)
......
...@@ -344,13 +344,13 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -344,13 +344,13 @@ class TestSpecifyShape(utt.InferShapeTester):
specify_shape([[1, 2, 3], [4, 5, 6]], (2.2, 3)) specify_shape([[1, 2, 3], [4, 5, 6]], (2.2, 3))
with pytest.raises(TypeError, match="must be integer types"): with pytest.raises(TypeError, match="must be integer types"):
_specify_shape([[1, 2, 3], [4, 5, 6]], (2.2, 3)) _specify_shape([[1, 2, 3], [4, 5, 6]], *(2.2, 3))
with pytest.raises(ValueError, match="will never match"): with pytest.raises(ValueError, match="will never match"):
specify_shape(matrix(), [4]) specify_shape(matrix(), [4])
with pytest.raises(ValueError, match="will never match"): with pytest.raises(ValueError, match="will never match"):
_specify_shape(matrix(), [4]) _specify_shape(matrix(), *[4])
with pytest.raises(ValueError, match="must have fixed dimensions"): with pytest.raises(ValueError, match="must have fixed dimensions"):
specify_shape(matrix(), vector(dtype="int32")) specify_shape(matrix(), vector(dtype="int32"))
...@@ -378,6 +378,14 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -378,6 +378,14 @@ class TestSpecifyShape(utt.InferShapeTester):
f = aesara.function([x], y, mode=self.mode) f = aesara.function([x], y, mode=self.mode)
assert f([15]) == [15] assert f([15]) == [15]
def test_partial_shapes(self):
x = matrix()
s1 = lscalar()
y = specify_shape(x, (s1, None))
f = aesara.function([x, s1], y, mode=self.mode)
assert f(np.zeros((2, 5), dtype=config.floatX), 2).shape == (2, 5)
assert f(np.zeros((3, 5), dtype=config.floatX), 3).shape == (3, 5)
def test_fixed_shapes(self): def test_fixed_shapes(self):
x = vector() x = vector()
shape = as_tensor_variable([2]) shape = as_tensor_variable([2])
...@@ -385,6 +393,15 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -385,6 +393,15 @@ class TestSpecifyShape(utt.InferShapeTester):
assert y.type.shape == (2,) assert y.type.shape == (2,)
assert y.shape.equals(shape) assert y.shape.equals(shape)
def test_fixed_partial_shapes(self):
x = TensorType("floatX", (None, None))("x")
y = specify_shape(x, (None, 5))
assert y.type.shape == (None, 5)
x = TensorType("floatX", (3, None))("x")
y = specify_shape(x, (None, 5))
assert y.type.shape == (3, 5)
def test_python_perform(self): def test_python_perform(self):
"""Test the Python `Op.perform` implementation.""" """Test the Python `Op.perform` implementation."""
x = scalar() x = scalar()
...@@ -403,13 +420,20 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -403,13 +420,20 @@ class TestSpecifyShape(utt.InferShapeTester):
with pytest.raises(AssertionError, match="SpecifyShape:.*"): with pytest.raises(AssertionError, match="SpecifyShape:.*"):
assert f([1], (2,)) == [1] assert f([1], (2,)) == [1]
x = matrix()
y = specify_shape(x, (None, 2))
f = aesara.function([x], y, mode=Mode("py"))
assert f(np.zeros((3, 2), dtype=config.floatX)).shape == (3, 2)
with pytest.raises(AssertionError, match="SpecifyShape:.*"):
assert f(np.zeros((3, 3), dtype=config.floatX))
def test_bad_shape(self): def test_bad_shape(self):
"""Test that at run-time we raise an exception when the shape is not the one specified.""" """Test that at run-time we raise an exception when the shape is not the one specified."""
specify_shape = SpecifyShape() specify_shape = SpecifyShape()
x = vector() x = vector()
xval = np.random.random((2)).astype(config.floatX) xval = np.random.random((2)).astype(config.floatX)
f = aesara.function([x], specify_shape(x, [2]), mode=self.mode) f = aesara.function([x], specify_shape(x, 2), mode=self.mode)
assert np.array_equal(f(xval), xval) assert np.array_equal(f(xval), xval)
...@@ -426,7 +450,7 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -426,7 +450,7 @@ class TestSpecifyShape(utt.InferShapeTester):
x = matrix() x = matrix()
xval = np.random.random((2, 3)).astype(config.floatX) xval = np.random.random((2, 3)).astype(config.floatX)
f = aesara.function([x], specify_shape(x, [2, 3]), mode=self.mode) f = aesara.function([x], specify_shape(x, 2, 3), mode=self.mode)
assert isinstance( assert isinstance(
[n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0] [n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0]
.inputs[0] .inputs[0]
...@@ -441,6 +465,13 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -441,6 +465,13 @@ class TestSpecifyShape(utt.InferShapeTester):
with pytest.raises(AssertionError, match="SpecifyShape:.*"): with pytest.raises(AssertionError, match="SpecifyShape:.*"):
f(xval) f(xval)
s = iscalar("s")
f = aesara.function([x, s], specify_shape(x, None, s), mode=self.mode)
x_val = np.zeros((3, 2), dtype=config.floatX)
assert f(x_val, 2).shape == (3, 2)
with pytest.raises(AssertionError, match="SpecifyShape:.*"):
f(xval, 3)
def test_infer_shape(self): def test_infer_shape(self):
rng = np.random.default_rng(3453) rng = np.random.default_rng(3453)
adtens4 = dtensor4() adtens4 = dtensor4()
...@@ -454,6 +485,19 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -454,6 +485,19 @@ class TestSpecifyShape(utt.InferShapeTester):
SpecifyShape, SpecifyShape,
) )
def test_infer_shape_partial(self):
rng = np.random.default_rng(3453)
adtens4 = dtensor4()
aivec = [iscalar(), iscalar(), None, iscalar()]
aivec_val = [3, 4, 5]
adtens4_val = rng.random((3, 4, 2, 5))
self._compile_and_check(
[adtens4, *(ivec for ivec in aivec if ivec is not None)],
[specify_shape(adtens4, aivec)],
[adtens4_val, *aivec_val],
SpecifyShape,
)
class TestRopLop(RopLopChecker): class TestRopLop(RopLopChecker):
def test_shape(self): def test_shape(self):
......
...@@ -474,7 +474,7 @@ def makeSharedTester( ...@@ -474,7 +474,7 @@ def makeSharedTester(
assert np.all(self.ref_fct(specify_shape_fct()) == self.ref_fct(x1_2)) assert np.all(self.ref_fct(specify_shape_fct()) == self.ref_fct(x1_2))
topo_specify = specify_shape_fct.maker.fgraph.toposort() topo_specify = specify_shape_fct.maker.fgraph.toposort()
if aesara.config.mode != "FAST_COMPILE": if aesara.config.mode != "FAST_COMPILE":
assert len(topo_specify) == 4 assert len(topo_specify) == 3
# Test that we put the shape info into the graph # Test that we put the shape info into the graph
shape_constant_fct = aesara.function([], x1_specify_shape.shape) shape_constant_fct = aesara.function([], x1_specify_shape.shape)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论