提交 e3fabae3 authored 作者: Michael Osthege's avatar Michael Osthege

Make SpecifyShape work for scalar shapes too

Also check exception messages in the tests.
上级 4cda2e52
......@@ -371,22 +371,33 @@ class SpecifyShape(COp):
def make_node(self, x, shape):
if not isinstance(x, Variable):
x = aet.as_tensor_variable(x)
shape = aet.as_tensor_variable(shape)
if shape.ndim > 1:
raise AssertionError()
if shape.dtype not in aesara.tensor.type.integer_dtypes:
raise AssertionError()
if isinstance(shape, TensorConstant) and shape.data.size != x.ndim:
raise AssertionError()
return Apply(self, [x, shape], [x.type()])
if shape == () or shape == []:
tshape = aet.constant([], dtype="int64")
else:
tshape = aet.as_tensor_variable(shape, ndim=1)
if tshape.dtype not in aesara.tensor.type.integer_dtypes:
raise AssertionError(
f"The `shape` must be an integer type. Got {tshape.dtype} instead."
)
if isinstance(tshape, TensorConstant) and tshape.data.size != x.ndim:
ndim = len(tshape.data)
raise AssertionError(
f"Input `x` is {x.ndim}-dimensional and will never match a {ndim}-dimensional shape."
)
return Apply(self, [x, tshape], [x.type()])
def perform(self, node, inp, out_):
x, shape = inp
(out,) = out_
if x.ndim != shape.size:
raise AssertionError()
ndim = len(shape)
if x.ndim != ndim:
raise AssertionError(
f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
)
if not np.all(x.shape == shape):
raise AssertionError(f"Got shape {x.shape}, expected {shape}")
raise AssertionError(
f"SpecifyShape: Got shape {x.shape}, expected {tuple(shape)}."
)
out[0] = x
def infer_shape(self, fgraph, node, shapes):
......@@ -804,10 +815,10 @@ register_specify_shape_c_code(
"""
if (PyArray_NDIM(%(iname)s) != PyArray_DIMS(%(shape)s)[0]) {
PyErr_Format(PyExc_AssertionError,
"SpecifyShape: vector of shape has %%d elements,"
" but the input has %%d dimensions.",
PyArray_DIMS(%(shape)s)[0],
PyArray_NDIM(%(iname)s));
"SpecifyShape: Got %%d dimensions, expected %%d dimensions.",
PyArray_NDIM(%(iname)s),
PyArray_DIMS(%(shape)s)[0]
);
%(fail)s;
}
for(int i = 0; i < PyArray_NDIM(%(iname)s); i++){
......
......@@ -2,7 +2,7 @@ import numpy as np
import pytest
import aesara
from aesara import function
from aesara import Mode, function
from aesara.compile.ops import DeepCopyOp
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
......@@ -28,6 +28,7 @@ from aesara.tensor.type import (
fvector,
ivector,
matrix,
scalar,
tensor3,
vector,
)
......@@ -306,6 +307,56 @@ class TestSpecifyShape(utt.InferShapeTester):
def shortDescription(self):
return None
def test_check_inputs(self):
with pytest.raises(AssertionError, match="must be an integer type"):
specify_shape([[1, 2, 3], [4, 5, 6]], (2.2, 3))
specify_shape([[1, 2, 3], [4, 5, 6]], (2, 3))
# Incompatible dimensionality is detected right away
with pytest.raises(AssertionError, match="will never match"):
specify_shape(
matrix(),
[
4,
],
)
pass
def test_scalar_shapes(self):
with pytest.raises(AssertionError, match="will never match"):
specify_shape(vector(), shape=())
with pytest.raises(AssertionError, match="will never match"):
specify_shape(matrix(), shape=[])
x = scalar()
y = specify_shape(x, shape=())
f = aesara.function([x], y, mode=self.mode)
assert f(15) == 15
pass
def test_python_perform(self):
x = scalar()
s = vector(dtype="int32")
y = specify_shape(x, s)
f = aesara.function([x, s], y, mode=Mode("py"))
assert f(12, ()) == 12
with pytest.raises(
AssertionError,
match=r"Got 0 dimensions \(shape \(\)\), expected 1 dimensions with shape \(2,\).",
):
f(12, (2,))
x = matrix()
s = vector(dtype="int32")
y = specify_shape(x, s)
f = aesara.function([x, s], y, mode=Mode("py"))
f(np.ones((2, 3)).astype(config.floatX), (2, 3))
with pytest.raises(
AssertionError, match=r"Got shape \(3, 4\), expected \(2, 3\)."
):
f(np.ones((3, 4)).astype(config.floatX), (2, 3))
pass
def test_bad_shape(self):
# Test that at run time we raise an exception when the shape
# is not the one specified
......@@ -316,7 +367,9 @@ class TestSpecifyShape(utt.InferShapeTester):
f = aesara.function([x], specify_shape(x, [2]), mode=self.mode)
f(xval)
xval = np.random.rand(3).astype(config.floatX)
with pytest.raises(AssertionError):
expected = r"(Got shape \(3,\), expected \(2,\))"
expected += r"|(dim 0 of input has shape 3, expected 2.)"
with pytest.raises(AssertionError, match=expected):
f(xval)
assert isinstance(
......@@ -336,9 +389,13 @@ class TestSpecifyShape(utt.InferShapeTester):
self.input_type,
)
f(xval)
for shape_ in [(1, 3), (2, 2), (5, 5)]:
for shape_ in [(4, 3), (2, 8)]:
xval = np.random.rand(*shape_).astype(config.floatX)
with pytest.raises(AssertionError):
s_exp = str(shape_).replace("(", r"\(").replace(")", r"\)")
expected = rf"(Got shape {s_exp}, expected \(2, 3\).)"
expected += r"|(dim 0 of input has shape 4, expected 2)"
expected += r"|(dim 1 of input has shape 8, expected 3)"
with pytest.raises(AssertionError, match=expected):
f(xval)
def test_bad_number_of_shape(self):
......@@ -348,9 +405,9 @@ class TestSpecifyShape(utt.InferShapeTester):
x = vector()
shape_vec = ivector()
xval = np.random.rand(2).astype(config.floatX)
with pytest.raises(AssertionError):
with pytest.raises(AssertionError, match="will never match"):
specify_shape(x, [])
with pytest.raises(AssertionError):
with pytest.raises(AssertionError, match="will never match"):
specify_shape(x, [2, 2])
f = aesara.function([x, shape_vec], specify_shape(x, shape_vec), mode=self.mode)
......@@ -360,15 +417,19 @@ class TestSpecifyShape(utt.InferShapeTester):
.type,
self.input_type,
)
with pytest.raises(AssertionError):
expected = r"(Got 1 dimensions \(shape \(2,\)\), expected 0 dimensions with shape \(\).)"
expected += r"|(Got 1 dimensions, expected 0 dimensions.)"
with pytest.raises(AssertionError, match=expected):
f(xval, [])
with pytest.raises(AssertionError):
expected = r"(Got 1 dimensions \(shape \(2,\)\), expected 2 dimensions with shape \(2, 2\).)"
expected += r"|(SpecifyShape: Got 1 dimensions, expected 2 dimensions.)"
with pytest.raises(AssertionError, match=expected):
f(xval, [2, 2])
x = matrix()
xval = np.random.rand(2, 3).astype(config.floatX)
for shape_ in [(), (1,), (2, 3, 4)]:
with pytest.raises(AssertionError):
with pytest.raises(AssertionError, match="will never match"):
specify_shape(x, shape_)
f = aesara.function(
[x, shape_vec], specify_shape(x, shape_vec), mode=self.mode
......@@ -383,7 +444,10 @@ class TestSpecifyShape(utt.InferShapeTester):
.type,
self.input_type,
)
with pytest.raises(AssertionError):
s_exp = str(shape_).replace("(", r"\(").replace(")", r"\)")
expected = rf"(Got 2 dimensions \(shape \(2, 3\)\), expected {len(shape_)} dimensions with shape {s_exp}.)"
expected += rf"|(SpecifyShape: Got 2 dimensions, expected {len(shape_)} dimensions.)"
with pytest.raises(AssertionError, match=expected):
f(xval, shape_)
def test_infer_shape(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论