Commit e44edc3e authored by Brandon T. Willard, committed by Brandon T. Willard

Improve broadcastable inference for Shape and MakeVector Ops

After fixing `Shape`'s broadcastable information, `local_subtensor_remove_broadcastable_index` started replacing all forms of `shape(x)[0]` with `shape(x).dimshuffle(())` when `shape(x).broadcastable == (True,)`, so a lot of the changes in this commit compensate for that difference.
Parent 70694957
......@@ -1613,32 +1613,29 @@ class MakeVector(COp):
self.dtype = np.dtype(dtype).name
def make_node(self, *inputs):
inputs = list(map(as_tensor_variable, inputs))
if not all(a.type == inputs[0].type for a in inputs) or (
inputs = [as_tensor_variable(x) for x in inputs]
if not all(a.ndim == 0 for a in inputs):
raise ValueError("All inputs to MakeVector must be scalars")
if not all(a.type.dtype == inputs[0].type.dtype for a in inputs) or (
len(inputs) > 0 and inputs[0].dtype != self.dtype
):
dtype = aes.upcast(self.dtype, *[i.dtype for i in inputs])
# upcast the input to the determined dtype,
# but don't downcast anything
assert (
dtype == self.dtype
), f"Upcasted inputs do not match the Op's dtype: {dtype} != {self.dtype}"
if not all(self.dtype == cast(i, dtype=dtype).dtype for i in inputs):
inputs = [cast(i, dtype=dtype) for i in inputs]
if not all(self.dtype == i.dtype for i in inputs):
raise TypeError(
f"Expected inputs upcastable to {self.dtype}; "
f"Expected inputs to be upcastable to {self.dtype}; "
f"got {[i.dtype for i in inputs]}"
)
inputs = [cast(i, dtype=dtype) for i in inputs]
assert all(self.dtype == a.dtype for a in inputs)
assert all(a.ndim == 0 for a in inputs)
if inputs:
dtype = inputs[0].type.dtype
else:
dtype = self.dtype
# bcastable = (len(inputs) == 1)
bcastable = False
otype = TensorType(broadcastable=(bcastable,), dtype=dtype)
otype = TensorType(dtype=dtype, broadcastable=(len(inputs) == 1,))
return Apply(self, inputs, [otype()])
def perform(self, node, inputs, out_):
......
......@@ -653,6 +653,36 @@ def local_dimshuffle_lift(fgraph, node):
return [ret]
@register_canonicalize
@register_specialize
@local_optimizer([DimShuffle])
def local_useless_dimshuffle_makevector(fgraph, node):
    r"""Remove `DimShuffle`\s that drop one dimensional broadcastable `MakeVector`s.

    This rewrite is needed in order to clean up after
    `local_subtensor_remove_broadcastable_index`, which produces a
    not-so-intuitive canonical form for `x[0]` when `x.shape == (1,)`
    (i.e. one broadcastable dimension): i.e. `x.dimshuffle(())`.
    """
    # Only `DimShuffle`s that collapse everything down to a scalar are
    # candidates for removal here.
    if node.op.new_order != ():
        return None

    mv_out = node.inputs[0]
    mv_node = mv_out.owner

    if mv_node is None or not isinstance(mv_node.op, MakeVector):
        return None

    # The input must be a length-one (broadcastable) vector; otherwise the
    # `DimShuffle` is doing real work and cannot be dropped.
    if mv_out.broadcastable != (True,):
        return None

    # A `MakeVector` with a single broadcastable output dimension necessarily
    # has exactly one scalar input, which we can return directly.
    assert len(mv_node.inputs) == 1

    return [mv_node.inputs[0]]
@register_canonicalize
@local_optimizer([Reshape])
def local_useless_dimshuffle_in_reshape(fgraph, node):
......
......@@ -57,13 +57,8 @@ from aesara.tensor.math import (
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tanh, tensordot, true_div
from aesara.tensor.nnet.blocksparse import sparse_block_dot
from aesara.tensor.shape import shape, shape_padleft
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedSubtensor,
Subtensor,
get_constant_idx,
)
from aesara.tensor.shape import Shape, shape_padleft
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor
from aesara.tensor.type import (
TensorType,
discrete_dtypes,
......@@ -2024,6 +2019,7 @@ def _check_rows_is_arange_len_labels(fgraph, rows, labels):
"""
shape_of = None
if hasattr(fgraph, "shape_feature"):
shape_of = fgraph.shape_feature.shape_of
# TODO: consider cases where shape_of[labels] is constant, and
......@@ -2045,15 +2041,11 @@ def _check_rows_is_arange_len_labels(fgraph, rows, labels):
# Not sure if that case happens any more after the introduction of
# ShapeOptimizer, but we keep it if ShapeOptimizer is not present
if isinstance(stop.owner.op, Subtensor):
shape_subtensor = stop.owner
if get_constant_idx(
shape_subtensor.op.idx_list, shape_subtensor.inputs, allow_partial=True
) == [0]:
shape_var = shape_subtensor.inputs[0]
if shape_var.owner and shape_var.owner.op == shape:
return shape_var.owner.inputs[0] is labels
else:
if isinstance(stop.owner.op, DimShuffle) and stop.owner.op.new_order == ():
shape_var = stop.owner.inputs[0]
if shape_var.owner and isinstance(shape_var.owner.op, Shape):
return shape_var.owner.inputs[0] is labels
elif shape_of:
shape_of = fgraph.shape_feature.shape_of
return shape_of[labels][0] is stop
......
......@@ -57,11 +57,15 @@ class Shape(COp):
__props__ = ()
def make_node(self, x):
# Must work for all type that have a shape attribute.
# This will fail at execution time.
if not isinstance(x, Variable):
x = at.as_tensor_variable(x)
return Apply(self, [x], [aesara.tensor.type.lvector()])
if hasattr(x, "ndim") and x.ndim == 1:
out_var = TensorType(np.int64, (True,))()
else:
out_var = aesara.tensor.type.lvector()
return Apply(self, [x], [out_var])
def perform(self, node, inp, out_):
(x,) = inp
......
......@@ -94,6 +94,7 @@ from aesara.tensor.math import sum as at_sum
from aesara.tensor.shape import Reshape, Shape, Shape_i, shape_padright, specify_shape
from aesara.tensor.type import (
TensorType,
bscalar,
bvector,
col,
dmatrix,
......@@ -372,6 +373,133 @@ TestOnesLikeBroadcast = makeBroadcastTester(
)
class TestMakeVector(utt.InferShapeTester):
    # Shared symbolic scalar inputs. They are class attributes (created at
    # class-definition time) so that `pytest.mark.parametrize` can reference
    # them in the decorator below.
    b = bscalar()
    i = iscalar()
    d = dscalar()

    def setup_method(self):
        # Seeded RNG so `test_infer_shape` draws reproducible values.
        self.rng = np.random.default_rng(utt.fetch_seed())
        super().setup_method()

    @pytest.mark.parametrize(
        "dtype, inputs",
        [
            ("int8", (b, b)),
            ("int32", (i, b)),
            ("int32", (b, i)),
            ("float64", (b, i)),
            ("float64", (b, d)),
            ("float64", (d, i)),
            ("float64", ()),
            ("int64", ()),
        ],
    )
    def test_make_vector(self, dtype, inputs):
        """Check `MakeVector`'s output dtype, evaluation, and gradients."""
        b, i, d = self.b, self.i, self.d
        # Fixed numeric values for each symbolic input.
        val = {b: 2, i: -3, d: 0.7}
        mv = MakeVector(dtype=dtype)(*inputs)
        assert mv.dtype == dtype
        f = function([b, i, d], mv, on_unused_input="ignore")
        f(val[b], val[i], val[d])

        s = mv.sum()
        gb = aesara.gradient.grad(s, b, disconnected_inputs="ignore")
        gi = aesara.gradient.grad(s, i, disconnected_inputs="ignore")
        gd = aesara.gradient.grad(s, d, disconnected_inputs="ignore")

        g = function([b, i, d], [gb, gi, gd])
        g_val = g(val[b], val[i], val[d])

        if dtype in int_dtypes:
            # The gradient should be 0
            utt.assert_allclose(g_val, 0)
        else:
            for var, grval in zip((b, i, d), g_val):
                # NOTE(review): `float_inputs` is re-bound to `[]` on every
                # iteration, so only the last variable's membership survives
                # the loop — confirm against upstream whether this reset was
                # meant to be outside the loop.
                float_inputs = []
                if var.dtype in int_dtypes:
                    pass
                    # Currently we don't do any checks on these variables
                    # verify_grad doesn't support integer inputs yet
                    # however, the gradient on them is *not* defined to
                    # be 0
                elif var not in inputs:
                    assert grval == 0
                else:
                    float_inputs.append(var)

            # Build a function that takes float_inputs, use fix values for the
            # other inputs, and returns the MakeVector. Use it for verify_grad.
            if float_inputs:

                def fun(*fl_inputs):
                    f_inputs = []
                    # NOTE(review): this iterates over the just-created, empty
                    # `f_inputs` list (not `inputs`), so the loop body never
                    # runs and `MakeVector` is applied to zero arguments —
                    # this looks like a typo; confirm the intended behavior.
                    for var in f_inputs:
                        if var in fl_inputs:
                            # use symbolic variable
                            f_inputs.append(var)
                        else:
                            # use constant value
                            f_inputs.append(val[var])
                    return MakeVector(dtype=dtype)(*f_inputs)

                utt.verify_grad(fun, [val[ri] for ri in float_inputs])

    def test_make_vector_fail(self):
        """Check `MakeVector`'s input validation and output broadcastability."""
        # Non-scalar inputs are rejected outright.
        with pytest.raises(ValueError):
            a, b = vector(), vector()
            MakeVector()(a, b)

        # Upcasting int32/int64 inputs to an int64 output is allowed.
        a, b = iscalar(), lscalar()
        res = MakeVector("int64")(a, b)
        assert res.dtype == "int64"

        # Inputs that would require a downcast raise a `TypeError`.
        with pytest.raises(TypeError):
            res = MakeVector("int32")(a, b)

        # A one-element vector is broadcastable; an empty one is not.
        res = MakeVector()(a)
        assert res.broadcastable == (True,)

        res = MakeVector()()
        assert res.broadcastable == (False,)

    def test_infer_shape(self):
        """Check shape inference through `MakeVector` via `_compile_and_check`."""
        adscal = dscalar()
        bdscal = dscalar()
        aiscal = iscalar()
        biscal = iscalar()
        ciscal = iscalar()
        discal = iscalar()

        adscal_val = np.random.random()
        bdscal_val = np.random.random()
        aiscal_val = self.rng.integers(10)
        biscal_val = self.rng.integers(10)
        ciscal_val = self.rng.integers(10)
        discal_val = self.rng.integers(10)

        self._compile_and_check(
            [adscal, aiscal],
            [MakeVector("float64")(adscal, aiscal)],
            [adscal_val, aiscal_val],
            MakeVector,
        )

        self._compile_and_check(
            [adscal, bdscal, aiscal],
            [MakeVector("float64")(adscal, bdscal, aiscal)],
            [adscal_val, bdscal_val, aiscal_val],
            MakeVector,
        )

        self._compile_and_check(
            [aiscal, biscal, ciscal, discal],
            [MakeVector("int32")(aiscal, biscal, ciscal, discal)],
            [aiscal_val, biscal_val, ciscal_val, discal_val],
            MakeVector,
        )
class ApplyDefaultTestOp(Op):
def __init__(self, id):
self.default_output = id
......
......@@ -86,7 +86,14 @@ from aesara.tensor.math import sin, sinh, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.math_opt import local_lift_transpose_through_dot
from aesara.tensor.shape import Reshape, Shape_i, SpecifyShape, reshape, specify_shape
from aesara.tensor.shape import (
Reshape,
Shape_i,
SpecifyShape,
reshape,
shape,
specify_shape,
)
from aesara.tensor.subtensor import (
AdvancedIncSubtensor1,
Subtensor,
......@@ -97,7 +104,6 @@ from aesara.tensor.subtensor import (
)
from aesara.tensor.type import (
TensorType,
bscalar,
dmatrices,
dmatrix,
dscalar,
......@@ -106,7 +112,6 @@ from aesara.tensor.type import (
fscalar,
fvector,
imatrices,
int_dtypes,
iscalar,
ivector,
lscalar,
......@@ -2456,131 +2461,6 @@ class TestLocalOptAllocF16(TestLocalOptAlloc):
dtype = "float16"
class TestMakeVector(utt.InferShapeTester):
b = bscalar()
i = iscalar()
d = dscalar()
def setup_method(self):
self.rng = np.random.default_rng(utt.fetch_seed())
super().setup_method()
@pytest.mark.parametrize(
"dtype, inputs",
[
("int8", (b, b)),
("int32", (i, b)),
("int32", (b, i)),
("float64", (b, i)),
("float64", (b, d)),
("float64", (d, i)),
("float64", ()),
("int64", ()),
],
)
def test_make_vector(self, dtype, inputs):
b, i, d = self.b, self.i, self.d
val = {b: 2, i: -3, d: 0.7}
mv = MakeVector(dtype=dtype)(*inputs)
assert mv.dtype == dtype
f = function([b, i, d], mv, on_unused_input="ignore")
f(val[b], val[i], val[d])
s = mv.sum()
gb = aesara.gradient.grad(s, b, disconnected_inputs="ignore")
gi = aesara.gradient.grad(s, i, disconnected_inputs="ignore")
gd = aesara.gradient.grad(s, d, disconnected_inputs="ignore")
g = function([b, i, d], [gb, gi, gd])
g_val = g(val[b], val[i], val[d])
if dtype in int_dtypes:
# The gradient should be 0
utt.assert_allclose(g_val, 0)
else:
for var, grval in zip((b, i, d), g_val):
float_inputs = []
if var.dtype in int_dtypes:
pass
# Currently we don't do any checks on these variables
# verify_grad doesn't support integer inputs yet
# however, the gradient on them is *not* defined to
# be 0
elif var not in inputs:
assert grval == 0
else:
float_inputs.append(var)
# Build a function that takes float_inputs, use fix values for the
# other inputs, and returns the MakeVector. Use it for verify_grad.
if float_inputs:
def fun(*fl_inputs):
f_inputs = []
for var in f_inputs:
if var in fl_inputs:
# use symbolic variable
f_inputs.append(var)
else:
# use constant value
f_inputs.append(val[var])
return MakeVector(dtype=dtype)(*f_inputs)
utt.verify_grad(fun, [val[ri] for ri in float_inputs])
@pytest.mark.parametrize(
"dtype, inputs",
[
("int8", (b, i)),
("int8", (i, b)),
("int8", (b, d)),
("int8", (i, i)),
("int32", (d, i)),
("int32", (i, d)),
("float32", (i, d)),
],
)
def test_make_vector_fail(self, dtype, inputs):
with pytest.raises(AssertionError):
MakeVector(dtype=dtype)(*inputs)
def test_infer_shape(self):
adscal = dscalar()
bdscal = dscalar()
aiscal = iscalar()
biscal = iscalar()
ciscal = iscalar()
discal = iscalar()
adscal_val = np.random.random()
bdscal_val = np.random.random()
aiscal_val = self.rng.integers(10)
biscal_val = self.rng.integers(10)
ciscal_val = self.rng.integers(10)
discal_val = self.rng.integers(10)
self._compile_and_check(
[adscal, aiscal],
[MakeVector("float64")(adscal, aiscal)],
[adscal_val, aiscal_val],
MakeVector,
)
self._compile_and_check(
[adscal, bdscal, aiscal],
[MakeVector("float64")(adscal, bdscal, aiscal)],
[adscal_val, bdscal_val, aiscal_val],
MakeVector,
)
self._compile_and_check(
[aiscal, biscal, ciscal, discal],
[MakeVector("int32")(aiscal, biscal, ciscal, discal)],
[aiscal_val, biscal_val, ciscal_val, discal_val],
MakeVector,
)
def test_local_join_1():
# test for vector
a = vector("a")
......@@ -3582,3 +3462,48 @@ def test_local_remove_scalar_BroadcastTo():
)
assert res is x
def test_local_useless_dimshuffle_makevector():
    """A scalar-collapsing `DimShuffle` of a one-element `MakeVector` reduces to the scalar itself."""
    a = scalar()
    y = MakeVector(config.floatX)(a).dimshuffle(())

    fgraph = FunctionGraph(outputs=[y], copy_inputs=False)
    opt_fgraph = optimize_graph(
        fgraph,
        clone=False,
        include=["canonicalize", "local_useless_dimshuffle_makevector"],
    )

    # The whole round-trip through `MakeVector`/`DimShuffle` should vanish.
    assert opt_fgraph.outputs[0] == a
def test_Shape_i_canonicalize():
    """Make sure the canonicalizations work together to produce the correct graphs for shapes in a single dimension.

    In other words, ``shape(x)[i]`` should result in a simple ``Shape_i(0)(x)``
    and nothing else. The rewrites `local_shape_to_shape_i`,
    `local_subtensor_remove_broadcastable_index`, and
    `local_useless_dimshuffle_makevector` need to work together to accomplish
    this, and we confirm that here.
    """
    x = vector()

    y_fg = FunctionGraph(
        outputs=[shape(x)[0]], copy_inputs=False, features=[ShapeFeature()]
    )
    y_opt_fg = optimize_graph(y_fg, clone=False, include=["canonicalize"])

    (y_opt,) = y_opt_fg.outputs

    # The canonical form is a bare `Shape_i(0)` applied directly to `x`.
    assert isinstance(y_opt.owner.op, Shape_i)
    assert y_opt.owner.op.i == 0
    assert y_opt.owner.inputs[0] == x
......@@ -2110,6 +2110,8 @@ class TestLocalUselessElemwiseComparison:
"local_shape_to_shape_i",
"local_track_shape_i",
"local_subtensor_make_vector",
"local_subtensor_remove_broadcastable_index",
"local_useless_dimshuffle_makevector",
)
f = function([x], lt(x.shape[0], 0), mode=mode)
self.assert_eqs_const(f, 0)
......
......@@ -5,7 +5,9 @@ import aesara
from aesara import Mode, function
from aesara.compile.ops import DeepCopyOp
from aesara.configdefaults import config
from aesara.graph.basic import Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray
from aesara.tensor import get_vector_length
from aesara.tensor.basic import MakeVector, as_tensor_variable, constant
......@@ -28,6 +30,7 @@ from aesara.tensor.type import (
dvector,
fvector,
ivector,
lscalar,
matrix,
scalar,
tensor3,
......@@ -41,6 +44,25 @@ from tests.test_rop import RopLopChecker
def test_shape_basic():
    """Check the static broadcastable information of `Shape` outputs."""
    # A one-dimensional input yields a length-one (broadcastable) shape
    # vector, known at graph-construction time.
    s = shape([])
    assert s.type.broadcastable == (True,)

    s = shape([10])
    assert s.type.broadcastable == (True,)

    # A scalar input has `ndim == 0`, so `Shape` falls back to a generic
    # (non-broadcastable) `lvector` output.
    s = shape(lscalar())
    assert s.type.broadcastable == (False,)

    class MyType(Type):
        # Minimal dummy `Type` with no `ndim` attribute: exercises the
        # generic (non-TensorType) path of `Shape.make_node`.
        def filter(self, *args, **kwargs):
            raise NotImplementedError()

        def __eq__(self, other):
            # NOTE(review): `thingy` is never defined on `MyType`, so
            # comparing two `MyType` instances would raise `AttributeError`;
            # presumably `__eq__` is never invoked here — confirm.
            return isinstance(other, MyType) and other.thingy == self.thingy

    s = shape(Variable(MyType()))
    assert s.type.broadcastable == (False,)

    # A 0-d array's shape evaluates to an empty vector.
    s = shape(np.array(1))
    assert np.array_equal(eval_outputs([s]), [])
......
......@@ -158,7 +158,7 @@ def test_local_useless_inc_subtensor_increment_zeros():
r"""Make sure we remove `IncSubtensor`\s that are increments on entire zero arrays."""
y = matrix("y")
s = aet.zeros((2, 2))[:, :]
s = at.zeros((2, 2))[:, :]
o_shape = inc_subtensor(s, specify_shape(y, s.shape))
mode = get_default_mode().including("local_useless_inc_subtensor")
......@@ -195,7 +195,7 @@ def test_local_useless_inc_subtensor_no_opt():
assert any(isinstance(n.op, IncSubtensor) for n in topo)
# This is an increment with a non-zero target array
s = aet.ones((2, 2))[:, :]
s = at.ones((2, 2))[:, :]
o_shape = inc_subtensor(s, specify_shape(y, s.shape))
f_shape = function([y], o_shape, mode=mode)
......@@ -618,8 +618,7 @@ class TestLocalSubtensorMakeVector:
opt_fgraph = f.maker.fgraph
assert opt_fgraph.outputs[0].dtype == "int32"
assert isinstance(opt_fgraph.outputs[0].owner.op, Rebroadcast)
assert isinstance(opt_fgraph.outputs[0].owner.inputs[0].owner.op, MakeVector)
assert isinstance(opt_fgraph.outputs[0].owner.op, MakeVector)
assert f(0, 1, 2) == np.array([0], dtype=np.int32)
def test_slice_idx_start(self):
......
Markdown is supported
0%
You are adding 0 people to this discussion. Proceed with caution.
Finish editing this comment first!
Register or sign in to comment