提交 34375f41 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Generalize broadcastables inference and fix Alloc broadcastables case

Closes #692
上级 eff08cb3
...@@ -19,7 +19,7 @@ from aesara.graph.utils import MethodNotDefined ...@@ -19,7 +19,7 @@ from aesara.graph.utils import MethodNotDefined
from aesara.link.c.interface import HideC from aesara.link.c.interface import HideC
from aesara.scalar import bool as bool_t from aesara.scalar import bool as bool_t
from aesara.scalar import int32 as int32_t from aesara.scalar import int32 as int32_t
from aesara.tensor.basic import Alloc, AllocEmpty, Join, Split, alloc_validate_shape from aesara.tensor.basic import Alloc, AllocEmpty, Join, Split, infer_broadcastable
from aesara.tensor.shape import Reshape from aesara.tensor.shape import Reshape
from aesara.tensor.type import TensorType, values_eq_approx_always_true from aesara.tensor.type import TensorType, values_eq_approx_always_true
...@@ -909,7 +909,7 @@ class GpuAlloc(HideC, Alloc): ...@@ -909,7 +909,7 @@ class GpuAlloc(HideC, Alloc):
def make_node(self, value, *shape): def make_node(self, value, *shape):
value = as_gpuarray_variable(value, context_name=self.context_name) value = as_gpuarray_variable(value, context_name=self.context_name)
sh, bcast = alloc_validate_shape(shape) sh, bcast = infer_broadcastable(shape)
if value.ndim > len(sh): if value.ndim > len(sh):
TypeError( TypeError(
"The GpuAlloc value to use has more dimensions " "The GpuAlloc value to use has more dimensions "
...@@ -1071,7 +1071,7 @@ class GpuAllocEmpty(HideC, AllocEmpty): ...@@ -1071,7 +1071,7 @@ class GpuAllocEmpty(HideC, AllocEmpty):
) )
def make_node(self, *shape): def make_node(self, *shape):
sh, bcast = alloc_validate_shape(shape) sh, bcast = infer_broadcastable(shape)
output = GpuArrayType( output = GpuArrayType(
dtype=self.dtype, broadcastable=bcast, context_name=self.context_name dtype=self.dtype, broadcastable=bcast, context_name=self.context_name
)() )()
......
...@@ -22,7 +22,9 @@ from aesara import compile, config, printing ...@@ -22,7 +22,9 @@ from aesara import compile, config, printing
from aesara import scalar as aes from aesara import scalar as aes
from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefined from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefined
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import COp, Op from aesara.graph.op import COp, Op
from aesara.graph.opt_utils import optimize_graph
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
...@@ -1324,43 +1326,44 @@ def identity_like(x): ...@@ -1324,43 +1326,44 @@ def identity_like(x):
return eye(x.shape[0], x.shape[1], k=0, dtype=x.dtype) return eye(x.shape[0], x.shape[1], k=0, dtype=x.dtype)
def alloc_validate_shape(shape): def infer_broadcastable(shape):
sh = [as_tensor_variable(s) for s in shape] """Infer the broadcastable dimensions for `shape`.
bcast = []
for i, s in enumerate(sh):
def err_str(): `shape` will be validated and constant folded in order to determine
if config.exception_verbosity == "high": which dimensions are broadcastable (i.e. equal to ``1``).
return "\n" + min_informative_str(s) """
else: from aesara.tensor.basic_opt import ShapeFeature, topo_constant_folding
return str(s)
if s.type.dtype not in integer_dtypes: def check_type(s):
s_as_str = err_str() if s.type.dtype in integer_dtypes:
raise TypeError( return s
"Shape arguments to Alloc must be integers, "
f"but argument {i} is not for apply node: {s_as_str}" if config.exception_verbosity == "high":
) s_as_str = "\n" + min_informative_str(s)
if s.ndim != 0: else:
s_as_str = err_str() s_as_str = str(s)
raise TypeError(
"Each shape dimension to Alloc must be a scalar, ", raise TypeError(f"Shapes must be scalar integers; got {s_as_str}")
f"but dimension {i} have {int(s.ndim)} dimensions for apply node: {s_as_str}",
)
# if s is constant 1, then we're broadcastable in that dim sh = [check_type(as_tensor_variable(s, ndim=0)) for s in shape]
try:
const_shp = get_scalar_constant_value(s) shape_fg = FunctionGraph(
except NotScalarConstantError: outputs=sh,
const_shp = None features=[ShapeFeature()],
bcast.append(1 == const_shp) clone=True,
)
folded_shape = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs
bcast = tuple(getattr(s, "data", s) == 1 for s in folded_shape)
return sh, bcast return sh, bcast
class Alloc(COp): class Alloc(COp):
"""Create a `TensorVariable` from an initial value and a desired shape. """Create a `TensorVariable` from an initial value and a desired shape.
alloc(value, shape0, shape1, ..., shapeN) Usage:
alloc(value, shape0, shape1, ..., shapeN)
Returns an N-dimensional tensor initialized by a value, using something Returns an N-dimensional tensor initialized by a value, using something
equivalent to equivalent to
...@@ -1380,12 +1383,9 @@ class Alloc(COp): ...@@ -1380,12 +1383,9 @@ class Alloc(COp):
_f16_ok = True _f16_ok = True
__props__ = () __props__ = ()
def validate_shape(self, shape):
return alloc_validate_shape(shape)
def make_node(self, value, *shape): def make_node(self, value, *shape):
v = as_tensor_variable(value) v = as_tensor_variable(value)
sh, bcast = alloc_validate_shape(shape) sh, bcast = infer_broadcastable(shape)
if v.ndim > len(sh): if v.ndim > len(sh):
raise TypeError( raise TypeError(
"The Alloc value to use has more dimensions" "The Alloc value to use has more dimensions"
...@@ -4102,7 +4102,7 @@ class AllocEmpty(COp): ...@@ -4102,7 +4102,7 @@ class AllocEmpty(COp):
return np.dtype(self.dtype).num return np.dtype(self.dtype).num
def make_node(self, *_shape): def make_node(self, *_shape):
_shape, bcast = alloc_validate_shape(_shape) _shape, bcast = infer_broadcastable(_shape)
otype = TensorType(dtype=self.dtype, broadcastable=bcast) otype = TensorType(dtype=self.dtype, broadcastable=bcast)
output = otype() output = otype()
...@@ -4363,7 +4363,6 @@ __all__ = [ ...@@ -4363,7 +4363,6 @@ __all__ = [
"tensor_copy", "tensor_copy",
"transfer", "transfer",
"alloc", "alloc",
"alloc_validate_shape",
"identity_like", "identity_like",
"eye", "eye",
"triu", "triu",
......
...@@ -1585,7 +1585,7 @@ class BroadcastTo(Op): ...@@ -1585,7 +1585,7 @@ class BroadcastTo(Op):
a = aet.as_tensor_variable(a) a = aet.as_tensor_variable(a)
shape = aet.as_tensor_variable(shape, ndim=1) shape = aet.as_tensor_variable(shape, ndim=1)
shape, bcast = aet.alloc_validate_shape(shape) shape, bcast = aet.infer_broadcastable(shape)
out = type(a.type)(dtype=a.type.dtype, broadcastable=bcast)() out = type(a.type)(dtype=a.type.dtype, broadcastable=bcast)()
...@@ -1609,7 +1609,7 @@ class BroadcastTo(Op): ...@@ -1609,7 +1609,7 @@ class BroadcastTo(Op):
d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims) d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)
# Determine the dimensions that were broadcast # Determine the dimensions that were broadcast
_, shape_bcast = aet.alloc_validate_shape(shape) _, shape_bcast = aet.infer_broadcastable(shape)
bcast_sums = [ bcast_sums = [
i i
for i, (a_b, s_b) in enumerate(zip(a.broadcastable, shape_bcast[-a.ndim :])) for i, (a_b, s_b) in enumerate(zip(a.broadcastable, shape_bcast[-a.ndim :]))
......
...@@ -7,9 +7,7 @@ import numpy as np ...@@ -7,9 +7,7 @@ import numpy as np
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt_utils import optimize_graph
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.scalar import ScalarVariable from aesara.scalar import ScalarVariable
from aesara.tensor.basic import ( from aesara.tensor.basic import (
...@@ -17,8 +15,8 @@ from aesara.tensor.basic import ( ...@@ -17,8 +15,8 @@ from aesara.tensor.basic import (
constant, constant,
get_scalar_constant_value, get_scalar_constant_value,
get_vector_length, get_vector_length,
infer_broadcastable,
) )
from aesara.tensor.basic_opt import ShapeFeature, topo_constant_folding
from aesara.tensor.random.type import RandomType from aesara.tensor.random.type import RandomType
from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes
from aesara.tensor.shape import shape_tuple from aesara.tensor.shape import shape_tuple
...@@ -276,31 +274,6 @@ class RandomVariable(Op): ...@@ -276,31 +274,6 @@ class RandomVariable(Op):
return shape return shape
@config.change_flags(compute_test_value="off")
def compute_bcast(self, dist_params, size):
"""Compute the broadcast array for this distribution's `TensorType`.
Parameters
----------
dist_params: list
Distribution parameters.
size: int or Sequence (optional)
Numpy-like size of the output (i.e. replications).
"""
shape = self._infer_shape(size, dist_params)
shape_fg = FunctionGraph(
outputs=[as_tensor_variable(s, ndim=0) for s in shape],
features=[ShapeFeature()],
clone=True,
)
folded_shape = optimize_graph(
shape_fg, custom_opt=topo_constant_folding
).outputs
return [getattr(s, "data", s) == 1 for s in folded_shape]
def infer_shape(self, fgraph, node, input_shapes): def infer_shape(self, fgraph, node, input_shapes):
_, size, _, *dist_params = node.inputs _, size, _, *dist_params = node.inputs
_, size_shape, _, *param_shapes = input_shapes _, size_shape, _, *param_shapes = input_shapes
...@@ -362,7 +335,8 @@ class RandomVariable(Op): ...@@ -362,7 +335,8 @@ class RandomVariable(Op):
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType" "The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
) )
bcast = self.compute_bcast(dist_params, size) shape = self._infer_shape(size, dist_params)
_, bcast = infer_broadcastable(shape)
dtype = self.dtype or dtype dtype = self.dtype or dtype
if dtype == "floatX": if dtype == "floatX":
......
...@@ -129,12 +129,12 @@ def test_RandomVariable_bcast(): ...@@ -129,12 +129,12 @@ def test_RandomVariable_bcast():
s3.tag.test_value = 3 s3.tag.test_value = 3
s3 = Assert("testing")(s3, eq(s1, 1)) s3 = Assert("testing")(s3, eq(s1, 1))
res = rv.compute_bcast([mu, sd], (s1, s2, s3)) res = rv(mu, sd, size=(s1, s2, s3))
assert res == [False] * 3 assert res.broadcastable == (False,) * 3
size = aet.as_tensor((1, 2, 3), dtype=np.int32).astype(np.int64) size = aet.as_tensor((1, 2, 3), dtype=np.int32).astype(np.int64)
res = rv.compute_bcast([mu, sd], size) res = rv(mu, sd, size=size)
assert res == [True, False, False] assert res.broadcastable == (True, False, False)
res = rv(0, 1, size=aet.as_tensor(1, dtype=np.int64)) res = rv(0, 1, size=aet.as_tensor(1, dtype=np.int64))
assert res.broadcastable == (True,) assert res.broadcastable == (True,)
......
...@@ -59,6 +59,7 @@ from aesara.tensor.basic import ( ...@@ -59,6 +59,7 @@ from aesara.tensor.basic import (
get_scalar_constant_value, get_scalar_constant_value,
get_vector_length, get_vector_length,
horizontal_stack, horizontal_stack,
infer_broadcastable,
inverse_permutation, inverse_permutation,
join, join,
make_vector, make_vector,
...@@ -90,7 +91,7 @@ from aesara.tensor.elemwise import DimShuffle ...@@ -90,7 +91,7 @@ from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import dense_dot, eq from aesara.tensor.math import dense_dot, eq
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.shape import Reshape, Shape, Shape_i, shape_padright from aesara.tensor.shape import Reshape, Shape, Shape_i, shape_padright, specify_shape
from aesara.tensor.type import ( from aesara.tensor.type import (
TensorType, TensorType,
bvector, bvector,
...@@ -658,6 +659,24 @@ class TestAlloc: ...@@ -658,6 +659,24 @@ class TestAlloc:
assert np.array_equal(res, np.full((2, 3), 3, dtype="int64")) assert np.array_equal(res, np.full((2, 3), 3, dtype="int64"))
def test_infer_broadcastable():
with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
infer_broadcastable([constant(1.0)])
with config.change_flags(exception_verbosity="high"), pytest.raises(
TypeError, match=r"A\. x"
):
infer_broadcastable([dscalar("x")])
with pytest.raises(ValueError, match=".*could not be cast to have 0 dimensions"):
infer_broadcastable((as_tensor_variable([[1, 2]]),))
constant_size = constant([1])
specify_size = specify_shape(constant_size, [1])
sh, bcast = infer_broadcastable(specify_size)
assert bcast == (True,)
# This is slow for the ('int8', 3) version. # This is slow for the ('int8', 3) version.
def test_eye(): def test_eye():
def check(dtype, N, M_=None, k=0): def check(dtype, N, M_=None, k=0):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论