提交 c6b08584 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Incorporate static shape of Alloc input

上级 6898f749
......@@ -1432,17 +1432,41 @@ class Alloc(COp):
__props__ = ()
def make_node(self, value, *shape):
v = as_tensor_variable(value)
sh, static_shape = infer_static_shape(shape)
if v.ndim > len(sh):
value = as_tensor_variable(value)
shape, static_shape = infer_static_shape(shape)
if value.ndim > len(shape):
raise TypeError(
"The Alloc value to use has more dimensions"
" than the specified dimensions",
v.ndim,
len(sh),
value.ndim,
len(shape),
)
otype = TensorType(dtype=v.dtype, shape=static_shape)
return Apply(self, [v] + sh, [otype()])
# Combine static shape information from value and shape
combined_static_shape = list(static_shape).copy()
new_dims = len(shape) - value.type.ndim
extended_value_static_shape = (None,) * new_dims + value.type.shape
extended_value_broadcastable = (False,) * new_dims + value.type.broadcastable
for i, (v_bc, v_st, sh_st) in enumerate(
zip(
extended_value_broadcastable,
extended_value_static_shape,
static_shape,
)
):
# If value is not broadcastable and we don't know the target static shape: use value static shape
if (not v_bc) and (sh_st is None):
combined_static_shape[i] = v_st
# Otherwise check if static shapes are compatible
elif (v_st is not None) and (sh_st is not None):
# They must match or if not, the value must be broadcastable
if v_st != sh_st and not v_bc:
raise ValueError(
f"Alloc static input type and target shape are incompatible: {value.type} vs {static_shape}"
)
otype = TensorType(dtype=value.dtype, shape=combined_static_shape)
return Apply(self, [value] + shape, [otype()])
def perform(self, node, inputs, out_):
(out,) = out_
......
......@@ -272,27 +272,6 @@ class TestLocalCanonicalizeAlloc:
def setup_method(self):
self.rng = np.random.default_rng(utt.fetch_seed())
def test_inconsistent_constant(self):
x = at.as_tensor(self.rng.standard_normal((3, 7)))
a = at.alloc(x, 6, 7)
assert a.owner and isinstance(a.owner.op, Alloc)
# `local_useless_alloc` should attempt to replace the `Alloc` with an
# `Assert` and fail when the static shape information conflicts.
with pytest.raises(TypeError):
f = function([], a, mode=rewrite_mode)
x = at.as_tensor(self.rng.standard_normal((6, 7)))
a = at.alloc(x, 6, 7)
f = function([], a, mode=rewrite_mode)
# The rewrite should then be applied, and remove Alloc
assert not any(
isinstance(node.op, (Alloc, Assert)) for node in f.maker.fgraph.toposort()
)
def test_inconsistent_shared(self):
# These shapes don't match!
x = shared(self.rng.standard_normal((3, 7)))
......
......@@ -835,6 +835,22 @@ class TestAlloc:
assert y_new.shape.eval({x_new: x_new_test}) == (100,)
assert y_new.eval({x_new: x_new_test}).shape == (100,)
def test_static_shape(self):
x = tensor(shape=(None, 1, 5))
d0 = scalar("d0", dtype=int)
d1 = scalar("d1", dtype=int)
assert at.alloc(x, 3, 1, 5).type.shape == (3, 1, 5)
assert at.alloc(x, 3, 4, 5).type.shape == (3, 4, 5)
assert at.alloc(x, d0, d1, 5).type.shape == (None, None, 5)
assert at.alloc(x, d0, 1, d1).type.shape == (None, 1, 5)
msg = "Alloc static input type and target shape are incompatible"
with pytest.raises(ValueError, match=msg):
at.alloc(x, 3, 1, 1)
with pytest.raises(ValueError, match=msg):
at.alloc(x, 3, 1, 6)
def test_infer_shape():
with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论