提交 e3fabae3 authored 作者: Michael Osthege's avatar Michael Osthege

Make SpecifyShape work for scalar shapes too

Also check exception messages in the tests.
上级 4cda2e52
...@@ -371,22 +371,33 @@ class SpecifyShape(COp): ...@@ -371,22 +371,33 @@ class SpecifyShape(COp):
def make_node(self, x, shape): def make_node(self, x, shape):
if not isinstance(x, Variable): if not isinstance(x, Variable):
x = aet.as_tensor_variable(x) x = aet.as_tensor_variable(x)
shape = aet.as_tensor_variable(shape) if shape == () or shape == []:
if shape.ndim > 1: tshape = aet.constant([], dtype="int64")
raise AssertionError() else:
if shape.dtype not in aesara.tensor.type.integer_dtypes: tshape = aet.as_tensor_variable(shape, ndim=1)
raise AssertionError() if tshape.dtype not in aesara.tensor.type.integer_dtypes:
if isinstance(shape, TensorConstant) and shape.data.size != x.ndim: raise AssertionError(
raise AssertionError() f"The `shape` must be an integer type. Got {tshape.dtype} instead."
return Apply(self, [x, shape], [x.type()]) )
if isinstance(tshape, TensorConstant) and tshape.data.size != x.ndim:
ndim = len(tshape.data)
raise AssertionError(
f"Input `x` is {x.ndim}-dimensional and will never match a {ndim}-dimensional shape."
)
return Apply(self, [x, tshape], [x.type()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
x, shape = inp x, shape = inp
(out,) = out_ (out,) = out_
if x.ndim != shape.size: ndim = len(shape)
raise AssertionError() if x.ndim != ndim:
raise AssertionError(
f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
)
if not np.all(x.shape == shape): if not np.all(x.shape == shape):
raise AssertionError(f"Got shape {x.shape}, expected {shape}") raise AssertionError(
f"SpecifyShape: Got shape {x.shape}, expected {tuple(shape)}."
)
out[0] = x out[0] = x
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
...@@ -804,10 +815,10 @@ register_specify_shape_c_code( ...@@ -804,10 +815,10 @@ register_specify_shape_c_code(
""" """
if (PyArray_NDIM(%(iname)s) != PyArray_DIMS(%(shape)s)[0]) { if (PyArray_NDIM(%(iname)s) != PyArray_DIMS(%(shape)s)[0]) {
PyErr_Format(PyExc_AssertionError, PyErr_Format(PyExc_AssertionError,
"SpecifyShape: vector of shape has %%d elements," "SpecifyShape: Got %%d dimensions, expected %%d dimensions.",
" but the input has %%d dimensions.", PyArray_NDIM(%(iname)s),
PyArray_DIMS(%(shape)s)[0], PyArray_DIMS(%(shape)s)[0]
PyArray_NDIM(%(iname)s)); );
%(fail)s; %(fail)s;
} }
for(int i = 0; i < PyArray_NDIM(%(iname)s); i++){ for(int i = 0; i < PyArray_NDIM(%(iname)s); i++){
......
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
import pytest import pytest
import aesara import aesara
from aesara import function from aesara import Mode, function
from aesara.compile.ops import DeepCopyOp from aesara.compile.ops import DeepCopyOp
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
...@@ -28,6 +28,7 @@ from aesara.tensor.type import ( ...@@ -28,6 +28,7 @@ from aesara.tensor.type import (
fvector, fvector,
ivector, ivector,
matrix, matrix,
scalar,
tensor3, tensor3,
vector, vector,
) )
...@@ -306,6 +307,56 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -306,6 +307,56 @@ class TestSpecifyShape(utt.InferShapeTester):
def shortDescription(self): def shortDescription(self):
return None return None
def test_check_inputs(self):
with pytest.raises(AssertionError, match="must be an integer type"):
specify_shape([[1, 2, 3], [4, 5, 6]], (2.2, 3))
specify_shape([[1, 2, 3], [4, 5, 6]], (2, 3))
# Incompatible dimensionality is detected right away
with pytest.raises(AssertionError, match="will never match"):
specify_shape(
matrix(),
[
4,
],
)
pass
def test_scalar_shapes(self):
with pytest.raises(AssertionError, match="will never match"):
specify_shape(vector(), shape=())
with pytest.raises(AssertionError, match="will never match"):
specify_shape(matrix(), shape=[])
x = scalar()
y = specify_shape(x, shape=())
f = aesara.function([x], y, mode=self.mode)
assert f(15) == 15
pass
def test_python_perform(self):
x = scalar()
s = vector(dtype="int32")
y = specify_shape(x, s)
f = aesara.function([x, s], y, mode=Mode("py"))
assert f(12, ()) == 12
with pytest.raises(
AssertionError,
match=r"Got 0 dimensions \(shape \(\)\), expected 1 dimensions with shape \(2,\).",
):
f(12, (2,))
x = matrix()
s = vector(dtype="int32")
y = specify_shape(x, s)
f = aesara.function([x, s], y, mode=Mode("py"))
f(np.ones((2, 3)).astype(config.floatX), (2, 3))
with pytest.raises(
AssertionError, match=r"Got shape \(3, 4\), expected \(2, 3\)."
):
f(np.ones((3, 4)).astype(config.floatX), (2, 3))
pass
def test_bad_shape(self): def test_bad_shape(self):
# Test that at run time we raise an exception when the shape # Test that at run time we raise an exception when the shape
# is not the one specified # is not the one specified
...@@ -316,7 +367,9 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -316,7 +367,9 @@ class TestSpecifyShape(utt.InferShapeTester):
f = aesara.function([x], specify_shape(x, [2]), mode=self.mode) f = aesara.function([x], specify_shape(x, [2]), mode=self.mode)
f(xval) f(xval)
xval = np.random.rand(3).astype(config.floatX) xval = np.random.rand(3).astype(config.floatX)
with pytest.raises(AssertionError): expected = r"(Got shape \(3,\), expected \(2,\))"
expected += r"|(dim 0 of input has shape 3, expected 2.)"
with pytest.raises(AssertionError, match=expected):
f(xval) f(xval)
assert isinstance( assert isinstance(
...@@ -336,9 +389,13 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -336,9 +389,13 @@ class TestSpecifyShape(utt.InferShapeTester):
self.input_type, self.input_type,
) )
f(xval) f(xval)
for shape_ in [(1, 3), (2, 2), (5, 5)]: for shape_ in [(4, 3), (2, 8)]:
xval = np.random.rand(*shape_).astype(config.floatX) xval = np.random.rand(*shape_).astype(config.floatX)
with pytest.raises(AssertionError): s_exp = str(shape_).replace("(", r"\(").replace(")", r"\)")
expected = rf"(Got shape {s_exp}, expected \(2, 3\).)"
expected += r"|(dim 0 of input has shape 4, expected 2)"
expected += r"|(dim 1 of input has shape 8, expected 3)"
with pytest.raises(AssertionError, match=expected):
f(xval) f(xval)
def test_bad_number_of_shape(self): def test_bad_number_of_shape(self):
...@@ -348,9 +405,9 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -348,9 +405,9 @@ class TestSpecifyShape(utt.InferShapeTester):
x = vector() x = vector()
shape_vec = ivector() shape_vec = ivector()
xval = np.random.rand(2).astype(config.floatX) xval = np.random.rand(2).astype(config.floatX)
with pytest.raises(AssertionError): with pytest.raises(AssertionError, match="will never match"):
specify_shape(x, []) specify_shape(x, [])
with pytest.raises(AssertionError): with pytest.raises(AssertionError, match="will never match"):
specify_shape(x, [2, 2]) specify_shape(x, [2, 2])
f = aesara.function([x, shape_vec], specify_shape(x, shape_vec), mode=self.mode) f = aesara.function([x, shape_vec], specify_shape(x, shape_vec), mode=self.mode)
...@@ -360,15 +417,19 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -360,15 +417,19 @@ class TestSpecifyShape(utt.InferShapeTester):
.type, .type,
self.input_type, self.input_type,
) )
with pytest.raises(AssertionError): expected = r"(Got 1 dimensions \(shape \(2,\)\), expected 0 dimensions with shape \(\).)"
expected += r"|(Got 1 dimensions, expected 0 dimensions.)"
with pytest.raises(AssertionError, match=expected):
f(xval, []) f(xval, [])
with pytest.raises(AssertionError): expected = r"(Got 1 dimensions \(shape \(2,\)\), expected 2 dimensions with shape \(2, 2\).)"
expected += r"|(SpecifyShape: Got 1 dimensions, expected 2 dimensions.)"
with pytest.raises(AssertionError, match=expected):
f(xval, [2, 2]) f(xval, [2, 2])
x = matrix() x = matrix()
xval = np.random.rand(2, 3).astype(config.floatX) xval = np.random.rand(2, 3).astype(config.floatX)
for shape_ in [(), (1,), (2, 3, 4)]: for shape_ in [(), (1,), (2, 3, 4)]:
with pytest.raises(AssertionError): with pytest.raises(AssertionError, match="will never match"):
specify_shape(x, shape_) specify_shape(x, shape_)
f = aesara.function( f = aesara.function(
[x, shape_vec], specify_shape(x, shape_vec), mode=self.mode [x, shape_vec], specify_shape(x, shape_vec), mode=self.mode
...@@ -383,7 +444,10 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -383,7 +444,10 @@ class TestSpecifyShape(utt.InferShapeTester):
.type, .type,
self.input_type, self.input_type,
) )
with pytest.raises(AssertionError): s_exp = str(shape_).replace("(", r"\(").replace(")", r"\)")
expected = rf"(Got 2 dimensions \(shape \(2, 3\)\), expected {len(shape_)} dimensions with shape {s_exp}.)"
expected += rf"|(SpecifyShape: Got 2 dimensions, expected {len(shape_)} dimensions.)"
with pytest.raises(AssertionError, match=expected):
f(xval, shape_) f(xval, shape_)
def test_infer_shape(self): def test_infer_shape(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论