提交 205da7f9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in infer_static_shape of graphs involving the shape of scalars

上级 2cef9c0e
...@@ -1406,11 +1406,8 @@ def infer_static_shape( ...@@ -1406,11 +1406,8 @@ def infer_static_shape(
`shape` will be validated and constant folded. As a result, this function `shape` will be validated and constant folded. As a result, this function
can be expensive and shouldn't be used unless absolutely necessary. can be expensive and shouldn't be used unless absolutely necessary.
It mostly exists as a hold-over from pre-static shape times, when it was It is often needed for `Op`s whose static shape and broadcastable flags
required in order to produce correct broadcastable arrays and prevent depend on the values of their inputs, such as `Alloc` and `RandomVariable`.
some graphs from being unusable. Now, it is no longer strictly required,
so don't use it unless you want the same shape graphs to be rewritten
multiple times during graph construction.
Returns Returns
------- -------
......
...@@ -992,12 +992,17 @@ def local_merge_consecutive_specify_shape(fgraph, node): ...@@ -992,12 +992,17 @@ def local_merge_consecutive_specify_shape(fgraph, node):
return [specify_shape(inner_obj, shape)] return [specify_shape(inner_obj, shape)]
_empty_shape = constant([], dtype="int64")
@register_infer_shape @register_infer_shape
@node_rewriter([Shape]) @node_rewriter([Shape])
def local_shape_ground(fgraph, node): def local_shape_ground(fgraph, node):
"""Rewrite shape(x) -> make_vector(x.type.shape) when this is constant.""" """Rewrite shape(x) -> make_vector(x.type.shape) when this is constant."""
[x] = node.inputs [x] = node.inputs
static_shape = x.type.shape static_shape = x.type.shape
if len(static_shape) == 0:
return [_empty_shape]
if not any(dim is None for dim in static_shape): if not any(dim is None for dim in static_shape):
return [stack([constant(dim, dtype="int64") for dim in static_shape])] return [stack([constant(dim, dtype="int64") for dim in static_shape])]
......
...@@ -908,7 +908,7 @@ class TestAlloc: ...@@ -908,7 +908,7 @@ class TestAlloc:
self.check_runtime_broadcast(mode) self.check_runtime_broadcast(mode)
def test_infer_shape(): def test_infer_static_shape():
with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"): with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
infer_static_shape([constant(1.0)]) infer_static_shape([constant(1.0)])
...@@ -925,6 +925,10 @@ def test_infer_shape(): ...@@ -925,6 +925,10 @@ def test_infer_shape():
sh, static_shape = infer_static_shape(specify_size) sh, static_shape = infer_static_shape(specify_size)
assert static_shape == (1,) assert static_shape == (1,)
x = scalar("x")
sh, static_shape = infer_static_shape([x.size])
assert static_shape == (1,)
# This is slow for the ('int8', 3) version. # This is slow for the ('int8', 3) version.
def test_eye(): def test_eye():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论