提交 4d261b30 作者: Brandon T. Willard 提交者: Brandon T. Willard

Change infer_broadcastable to infer_static_shape

上级 e4b15e48
...@@ -10,7 +10,7 @@ import warnings ...@@ -10,7 +10,7 @@ import warnings
from collections.abc import Sequence from collections.abc import Sequence
from functools import partial from functools import partial
from numbers import Number from numbers import Number
from typing import Optional from typing import TYPE_CHECKING, Optional
from typing import Sequence as TypeSequence from typing import Sequence as TypeSequence
from typing import Tuple, Union from typing import Tuple, Union
from typing import cast as type_cast from typing import cast as type_cast
...@@ -68,6 +68,10 @@ from aesara.tensor.type import ( ...@@ -68,6 +68,10 @@ from aesara.tensor.type import (
from aesara.tensor.var import TensorConstant, TensorVariable, get_unique_value from aesara.tensor.var import TensorConstant, TensorVariable, get_unique_value
if TYPE_CHECKING:
from aesara.tensor import TensorLike
def __oplist_tag(thing, tag): def __oplist_tag(thing, tag):
tags = getattr(thing, "__oplist_tags", []) tags = getattr(thing, "__oplist_tags", [])
tags.append(tag) tags.append(tag)
...@@ -1334,11 +1338,25 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None): ...@@ -1334,11 +1338,25 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
return eye(_x.shape[0], _x.shape[1], k=0, dtype=dtype) return eye(_x.shape[0], _x.shape[1], k=0, dtype=dtype)
def infer_broadcastable(shape): def infer_static_shape(
"""Infer the broadcastable dimensions for `shape`. shape: Union[Variable, TypeSequence[Union[Variable, int]]]
) -> Tuple[TypeSequence["TensorLike"], TypeSequence[Optional[int]]]:
"""Infer the static shapes implied by the potentially symbolic elements in `shape`.
`shape` will be validated and constant folded. As a result, this function
can be expensive and shouldn't be used unless absolutely necessary.
It mostly exists as a hold-over from pre-static shape times, when it was
required in order to produce correct broadcastable arrays and prevent
some graphs from being unusable. Now, it is no longer strictly required,
so don't use it unless you want the same shape graphs to be rewritten
multiple times during graph construction.
Returns
-------
A validated sequence of symbolic shape values, and a sequence of
``None``/``int`` values that can be used as `TensorType.shape` values.
`shape` will be validated and constant folded in order to determine
which dimensions are broadcastable (i.e. equal to ``1``).
""" """
from aesara.tensor.rewriting.basic import topo_constant_folding from aesara.tensor.rewriting.basic import topo_constant_folding
from aesara.tensor.rewriting.shape import ShapeFeature from aesara.tensor.rewriting.shape import ShapeFeature
...@@ -1362,9 +1380,10 @@ def infer_broadcastable(shape): ...@@ -1362,9 +1380,10 @@ def infer_broadcastable(shape):
clone=True, clone=True,
) )
folded_shape = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs folded_shape = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs
static_shape = tuple(
bcast = tuple(getattr(s, "data", s) == 1 for s in folded_shape) s.data.item() if isinstance(s, Constant) else None for s in folded_shape
return sh, bcast )
return sh, static_shape
class Alloc(COp): class Alloc(COp):
...@@ -1394,7 +1413,7 @@ class Alloc(COp): ...@@ -1394,7 +1413,7 @@ class Alloc(COp):
def make_node(self, value, *shape): def make_node(self, value, *shape):
v = as_tensor_variable(value) v = as_tensor_variable(value)
sh, bcast = infer_broadcastable(shape) sh, static_shape = infer_static_shape(shape)
if v.ndim > len(sh): if v.ndim > len(sh):
raise TypeError( raise TypeError(
"The Alloc value to use has more dimensions" "The Alloc value to use has more dimensions"
...@@ -1402,7 +1421,7 @@ class Alloc(COp): ...@@ -1402,7 +1421,7 @@ class Alloc(COp):
v.ndim, v.ndim,
len(sh), len(sh),
) )
otype = TensorType(dtype=v.dtype, shape=bcast) otype = TensorType(dtype=v.dtype, shape=static_shape)
return Apply(self, [v] + sh, [otype()]) return Apply(self, [v] + sh, [otype()])
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
...@@ -3823,8 +3842,8 @@ class AllocEmpty(COp): ...@@ -3823,8 +3842,8 @@ class AllocEmpty(COp):
return np.dtype(self.dtype).num return np.dtype(self.dtype).num
def make_node(self, *_shape): def make_node(self, *_shape):
_shape, bcast = infer_broadcastable(_shape) _shape, static_shape = infer_static_shape(_shape)
otype = TensorType(dtype=self.dtype, shape=bcast) otype = TensorType(dtype=self.dtype, shape=static_shape)
output = otype() output = otype()
output.tag.values_eq_approx = values_eq_approx_always_true output.tag.values_eq_approx = values_eq_approx_always_true
......
...@@ -1646,9 +1646,9 @@ class BroadcastTo(COp): ...@@ -1646,9 +1646,9 @@ class BroadcastTo(COp):
def make_node(self, a, *shape): def make_node(self, a, *shape):
a = at.as_tensor_variable(a) a = at.as_tensor_variable(a)
shape, bcast = at.infer_broadcastable(shape) shape, static_shape = at.infer_static_shape(shape)
out = TensorType(dtype=a.type.dtype, shape=bcast)() out = TensorType(dtype=a.type.dtype, shape=static_shape)()
# Attempt to prevent in-place operations on this view-based output # Attempt to prevent in-place operations on this view-based output
out.tag.indestructible = True out.tag.indestructible = True
...@@ -1670,11 +1670,14 @@ class BroadcastTo(COp): ...@@ -1670,11 +1670,14 @@ class BroadcastTo(COp):
d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims) d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)
# Determine the dimensions that were broadcast # Determine the dimensions that were broadcast
_, shape_bcast = at.infer_broadcastable(shape) _, static_shape = at.infer_static_shape(shape)
# TODO: This needs to be performed at run-time when static shape
# information isn't available.
bcast_sums = [ bcast_sums = [
i i
for i, (a_b, s_b) in enumerate(zip(a.broadcastable, shape_bcast[-a.ndim :])) for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
if a_b and not s_b if a_s == 1 and s_s != 1
] ]
if bcast_sums: if bcast_sums:
......
...@@ -14,7 +14,7 @@ from aesara.tensor.basic import ( ...@@ -14,7 +14,7 @@ from aesara.tensor.basic import (
constant, constant,
get_scalar_constant_value, get_scalar_constant_value,
get_vector_length, get_vector_length,
infer_broadcastable, infer_static_shape,
) )
from aesara.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType from aesara.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes
...@@ -322,7 +322,7 @@ class RandomVariable(Op): ...@@ -322,7 +322,7 @@ class RandomVariable(Op):
) )
shape = self._infer_shape(size, dist_params) shape = self._infer_shape(size, dist_params)
_, bcast = infer_broadcastable(shape) _, static_shape = infer_static_shape(shape)
dtype = self.dtype or dtype dtype = self.dtype or dtype
if dtype == "floatX": if dtype == "floatX":
...@@ -336,7 +336,7 @@ class RandomVariable(Op): ...@@ -336,7 +336,7 @@ class RandomVariable(Op):
dtype_idx = constant(dtype, dtype="int64") dtype_idx = constant(dtype, dtype="int64")
dtype = all_dtypes[dtype_idx.data] dtype = all_dtypes[dtype_idx.data]
outtype = TensorType(dtype=dtype, shape=bcast) outtype = TensorType(dtype=dtype, shape=static_shape)
out_var = outtype() out_var = outtype()
inputs = (rng, size, dtype_idx) + dist_params inputs = (rng, size, dtype_idx) + dist_params
outputs = (rng.type(), out_var) outputs = (rng.type(), out_var)
......
...@@ -276,8 +276,9 @@ class TestLocalCanonicalizeAlloc: ...@@ -276,8 +276,9 @@ class TestLocalCanonicalizeAlloc:
assert a.owner and isinstance(a.owner.op, Alloc) assert a.owner and isinstance(a.owner.op, Alloc)
# `local_useless_alloc` should replace the `Alloc` with an `Assert` # `local_useless_alloc` should attempt to replace the `Alloc` with an
with pytest.raises(AssertionError): # `Assert` and fail when the static shape information conflicts.
with pytest.raises(TypeError):
f = function([], a, mode=rewrite_mode) f = function([], a, mode=rewrite_mode)
x = at.as_tensor(self.rng.standard_normal((6, 7))) x = at.as_tensor(self.rng.standard_normal((6, 7)))
......
...@@ -55,7 +55,7 @@ from aesara.tensor.basic import ( ...@@ -55,7 +55,7 @@ from aesara.tensor.basic import (
get_vector_length, get_vector_length,
horizontal_stack, horizontal_stack,
identity_like, identity_like,
infer_broadcastable, infer_static_shape,
inverse_permutation, inverse_permutation,
join, join,
make_vector, make_vector,
...@@ -796,20 +796,20 @@ class TestAlloc: ...@@ -796,20 +796,20 @@ class TestAlloc:
def test_infer_broadcastable(): def test_infer_broadcastable():
with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"): with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
infer_broadcastable([constant(1.0)]) infer_static_shape([constant(1.0)])
with config.change_flags(exception_verbosity="high"), pytest.raises( with config.change_flags(exception_verbosity="high"), pytest.raises(
TypeError, match=r"A\. x" TypeError, match=r"A\. x"
): ):
infer_broadcastable([dscalar("x")]) infer_static_shape([dscalar("x")])
with pytest.raises(ValueError, match=".*could not be cast to have 0 dimensions"): with pytest.raises(ValueError, match=".*could not be cast to have 0 dimensions"):
infer_broadcastable((as_tensor_variable([[1, 2]]),)) infer_static_shape((as_tensor_variable([[1, 2]]),))
constant_size = constant([1]) constant_size = constant([1])
specify_size = specify_shape(constant_size, [1]) specify_size = specify_shape(constant_size, [1])
sh, bcast = infer_broadcastable(specify_size) sh, static_shape = infer_static_shape(specify_size)
assert bcast == (True,) assert static_shape == (1,)
# This is slow for the ('int8', 3) version. # This is slow for the ('int8', 3) version.
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论