提交 34375f41 · 作者: Brandon T. Willard · 提交者: Brandon T. Willard

Generalize broadcastables inference and fix Alloc broadcastables case

Closes #692
上级 eff08cb3
......@@ -19,7 +19,7 @@ from aesara.graph.utils import MethodNotDefined
from aesara.link.c.interface import HideC
from aesara.scalar import bool as bool_t
from aesara.scalar import int32 as int32_t
from aesara.tensor.basic import Alloc, AllocEmpty, Join, Split, alloc_validate_shape
from aesara.tensor.basic import Alloc, AllocEmpty, Join, Split, infer_broadcastable
from aesara.tensor.shape import Reshape
from aesara.tensor.type import TensorType, values_eq_approx_always_true
......@@ -909,7 +909,7 @@ class GpuAlloc(HideC, Alloc):
def make_node(self, value, *shape):
value = as_gpuarray_variable(value, context_name=self.context_name)
sh, bcast = alloc_validate_shape(shape)
sh, bcast = infer_broadcastable(shape)
if value.ndim > len(sh):
TypeError(
"The GpuAlloc value to use has more dimensions "
......@@ -1071,7 +1071,7 @@ class GpuAllocEmpty(HideC, AllocEmpty):
)
def make_node(self, *shape):
sh, bcast = alloc_validate_shape(shape)
sh, bcast = infer_broadcastable(shape)
output = GpuArrayType(
dtype=self.dtype, broadcastable=bcast, context_name=self.context_name
)()
......
......@@ -22,7 +22,9 @@ from aesara import compile, config, printing
from aesara import scalar as aes
from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefined
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import COp, Op
from aesara.graph.opt_utils import optimize_graph
from aesara.graph.params_type import ParamsType
from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray
......@@ -1324,43 +1326,44 @@ def identity_like(x):
return eye(x.shape[0], x.shape[1], k=0, dtype=x.dtype)
def alloc_validate_shape(shape):
    """Validate the entries of `shape` and flag the dimensions equal to ``1``.

    This is the helper that the call sites in this change set stop using in
    favour of `infer_broadcastable`; it is kept so existing callers still work.

    Parameters
    ----------
    shape
        Iterable of shape entries; each one is passed through
        `as_tensor_variable`.

    Returns
    -------
    tuple
        ``(sh, bcast)`` where ``sh`` is the list of converted shape entries and
        ``bcast`` is a list with one truth value per entry, true when
        `get_scalar_constant_value` reports the entry to be ``1``.

    Raises
    ------
    TypeError
        If an entry does not have an integer dtype or is not 0-dimensional.
    """
    sh = [as_tensor_variable(s) for s in shape]
    bcast = []
    for i, s in enumerate(sh):

        def err_str():
            # Only build the verbose description when the user asked for it.
            if config.exception_verbosity == "high":
                return "\n" + min_informative_str(s)
            else:
                return str(s)

        if s.type.dtype not in integer_dtypes:
            s_as_str = err_str()
            raise TypeError(
                "Shape arguments to Alloc must be integers, "
                f"but argument {i} is not for apply node: {s_as_str}"
            )
        if s.ndim != 0:
            s_as_str = err_str()
            # The two literals are concatenated into a single message; a comma
            # between them would pass two separate exception arguments.
            raise TypeError(
                "Each shape dimension to Alloc must be a scalar, "
                f"but dimension {i} have {int(s.ndim)} dimensions for apply node: {s_as_str}"
            )

        # if s is constant 1, then we're broadcastable in that dim
        try:
            const_shp = get_scalar_constant_value(s)
        except NotScalarConstantError:
            const_shp = None
        bcast.append(1 == const_shp)
    return sh, bcast


def infer_broadcastable(shape):
    """Infer the broadcastable dimensions for `shape`.

    `shape` will be validated and constant folded in order to determine
    which dimensions are broadcastable (i.e. equal to ``1``).

    Parameters
    ----------
    shape
        Iterable of shape entries; each one is converted with
        ``as_tensor_variable(s, ndim=0)``.

    Returns
    -------
    tuple
        ``(sh, bcast)`` where ``sh`` is the list of validated shape variables
        and ``bcast`` is a tuple with one entry per dimension, the result of
        comparing the constant-folded dimension with ``1``.

    Raises
    ------
    TypeError
        If a shape entry does not have an integer dtype.
    """
    # NOTE(review): imported inside the function, presumably to avoid a
    # circular import with `aesara.tensor.basic_opt` -- confirm.
    from aesara.tensor.basic_opt import ShapeFeature, topo_constant_folding

    def check_type(s):
        if s.type.dtype in integer_dtypes:
            return s

        if config.exception_verbosity == "high":
            s_as_str = "\n" + min_informative_str(s)
        else:
            s_as_str = str(s)

        raise TypeError(f"Shapes must be scalar integers; got {s_as_str}")

    sh = [check_type(as_tensor_variable(s, ndim=0)) for s in shape]

    # Constant-fold a cloned graph of the shape entries so that entries which
    # reduce to a constant can be compared with 1 below.
    shape_fg = FunctionGraph(
        outputs=sh,
        features=[ShapeFeature()],
        clone=True,
    )
    folded_shape = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs

    # Outputs carrying a `data` attribute (presumably folded constants) are
    # compared by value; any other output is compared as-is.
    bcast = tuple(getattr(s, "data", s) == 1 for s in folded_shape)
    return sh, bcast
class Alloc(COp):
"""Create a `TensorVariable` from an initial value and a desired shape.
alloc(value, shape0, shape1, ..., shapeN)
Usage:
alloc(value, shape0, shape1, ..., shapeN)
Returns an N-dimensional tensor initialized by a value, using something
equivalent to
......@@ -1380,12 +1383,9 @@ class Alloc(COp):
_f16_ok = True
__props__ = ()
def validate_shape(self, shape):
return alloc_validate_shape(shape)
def make_node(self, value, *shape):
v = as_tensor_variable(value)
sh, bcast = alloc_validate_shape(shape)
sh, bcast = infer_broadcastable(shape)
if v.ndim > len(sh):
raise TypeError(
"The Alloc value to use has more dimensions"
......@@ -4102,7 +4102,7 @@ class AllocEmpty(COp):
return np.dtype(self.dtype).num
def make_node(self, *_shape):
_shape, bcast = alloc_validate_shape(_shape)
_shape, bcast = infer_broadcastable(_shape)
otype = TensorType(dtype=self.dtype, broadcastable=bcast)
output = otype()
......@@ -4363,7 +4363,6 @@ __all__ = [
"tensor_copy",
"transfer",
"alloc",
"alloc_validate_shape",
"identity_like",
"eye",
"triu",
......
......@@ -1585,7 +1585,7 @@ class BroadcastTo(Op):
a = aet.as_tensor_variable(a)
shape = aet.as_tensor_variable(shape, ndim=1)
shape, bcast = aet.alloc_validate_shape(shape)
shape, bcast = aet.infer_broadcastable(shape)
out = type(a.type)(dtype=a.type.dtype, broadcastable=bcast)()
......@@ -1609,7 +1609,7 @@ class BroadcastTo(Op):
d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)
# Determine the dimensions that were broadcast
_, shape_bcast = aet.alloc_validate_shape(shape)
_, shape_bcast = aet.infer_broadcastable(shape)
bcast_sums = [
i
for i, (a_b, s_b) in enumerate(zip(a.broadcastable, shape_bcast[-a.ndim :]))
......
......@@ -7,9 +7,7 @@ import numpy as np
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt_utils import optimize_graph
from aesara.misc.safe_asarray import _asarray
from aesara.scalar import ScalarVariable
from aesara.tensor.basic import (
......@@ -17,8 +15,8 @@ from aesara.tensor.basic import (
constant,
get_scalar_constant_value,
get_vector_length,
infer_broadcastable,
)
from aesara.tensor.basic_opt import ShapeFeature, topo_constant_folding
from aesara.tensor.random.type import RandomType
from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes
from aesara.tensor.shape import shape_tuple
......@@ -276,31 +274,6 @@ class RandomVariable(Op):
return shape
@config.change_flags(compute_test_value="off")
def compute_bcast(self, dist_params, size):
    """Work out the broadcast pattern for this distribution's `TensorType`.

    The call runs with the ``compute_test_value`` flag set to ``"off"``.

    Parameters
    ----------
    dist_params: list
        Distribution parameters.
    size: int or Sequence (optional)
        Numpy-like size of the output (i.e. replications).

    Returns
    -------
    list
        One entry per inferred output dimension: the result of comparing
        the constant-folded dimension with ``1``.
    """
    inferred_shape = self._infer_shape(size, dist_params)

    # Each dimension becomes a scalar output of a throwaway graph.
    scalar_dims = []
    for dim in inferred_shape:
        scalar_dims.append(as_tensor_variable(dim, ndim=0))

    dims_fgraph = FunctionGraph(
        outputs=scalar_dims, features=[ShapeFeature()], clone=True
    )
    folded_fgraph = optimize_graph(dims_fgraph, custom_opt=topo_constant_folding)

    # Use the `data` attribute when an output has one; otherwise compare the
    # output itself.
    return [getattr(out, "data", out) == 1 for out in folded_fgraph.outputs]
def infer_shape(self, fgraph, node, input_shapes):
_, size, _, *dist_params = node.inputs
_, size_shape, _, *param_shapes = input_shapes
......@@ -362,7 +335,8 @@ class RandomVariable(Op):
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
)
bcast = self.compute_bcast(dist_params, size)
shape = self._infer_shape(size, dist_params)
_, bcast = infer_broadcastable(shape)
dtype = self.dtype or dtype
if dtype == "floatX":
......
......@@ -129,12 +129,12 @@ def test_RandomVariable_bcast():
s3.tag.test_value = 3
s3 = Assert("testing")(s3, eq(s1, 1))
res = rv.compute_bcast([mu, sd], (s1, s2, s3))
assert res == [False] * 3
res = rv(mu, sd, size=(s1, s2, s3))
assert res.broadcastable == (False,) * 3
size = aet.as_tensor((1, 2, 3), dtype=np.int32).astype(np.int64)
res = rv.compute_bcast([mu, sd], size)
assert res == [True, False, False]
res = rv(mu, sd, size=size)
assert res.broadcastable == (True, False, False)
res = rv(0, 1, size=aet.as_tensor(1, dtype=np.int64))
assert res.broadcastable == (True,)
......
......@@ -59,6 +59,7 @@ from aesara.tensor.basic import (
get_scalar_constant_value,
get_vector_length,
horizontal_stack,
infer_broadcastable,
inverse_permutation,
join,
make_vector,
......@@ -90,7 +91,7 @@ from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import dense_dot, eq
from aesara.tensor.math import sum as aet_sum
from aesara.tensor.shape import Reshape, Shape, Shape_i, shape_padright
from aesara.tensor.shape import Reshape, Shape, Shape_i, shape_padright, specify_shape
from aesara.tensor.type import (
TensorType,
bvector,
......@@ -658,6 +659,24 @@ class TestAlloc:
assert np.array_equal(res, np.full((2, 3), 3, dtype="int64"))
def test_infer_broadcastable():
    """Check `infer_broadcastable` input validation and a constant-size case."""
    # A float constant is not an acceptable shape entry.
    with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
        infer_broadcastable([constant(1.0)])

    # With high verbosity the error message carries the verbose description
    # of the offending variable.
    with config.change_flags(exception_verbosity="high"):
        with pytest.raises(TypeError, match=r"A\. x"):
            infer_broadcastable([dscalar("x")])

    # An entry that cannot be reduced to a scalar is rejected.
    with pytest.raises(ValueError, match=".*could not be cast to have 0 dimensions"):
        infer_broadcastable((as_tensor_variable([[1, 2]]),))

    # A size of constant 1 wrapped in `specify_shape` must still be reported
    # as broadcastable.
    size_const = constant([1])
    size_specified = specify_shape(size_const, [1])
    _, inferred_bcast = infer_broadcastable(size_specified)
    assert inferred_bcast == (True,)
# This is slow for the ('int8', 3) version.
def test_eye():
def check(dtype, N, M_=None, k=0):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论