提交 e44edc3e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Improve broadcastable inference for Shape and MakeVector Ops

After fixing `Shape`'s broadcastable information, `local_subtensor_remove_broadcastable_index` started replacing all forms of `shape(x)[0]` with `shape(x).dimshuffle(())` when `shape(x).broadcastable == (True,)`, so a lot of the changes in this commit compensate for that difference.
上级 70694957
...@@ -1613,32 +1613,29 @@ class MakeVector(COp): ...@@ -1613,32 +1613,29 @@ class MakeVector(COp):
self.dtype = np.dtype(dtype).name self.dtype = np.dtype(dtype).name
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = list(map(as_tensor_variable, inputs)) inputs = [as_tensor_variable(x) for x in inputs]
if not all(a.type == inputs[0].type for a in inputs) or (
if not all(a.ndim == 0 for a in inputs):
raise ValueError("All inputs to MakeVector must be scalars")
if not all(a.type.dtype == inputs[0].type.dtype for a in inputs) or (
len(inputs) > 0 and inputs[0].dtype != self.dtype len(inputs) > 0 and inputs[0].dtype != self.dtype
): ):
dtype = aes.upcast(self.dtype, *[i.dtype for i in inputs]) dtype = aes.upcast(self.dtype, *[i.dtype for i in inputs])
# upcast the input to the determined dtype, inputs = [cast(i, dtype=dtype) for i in inputs]
# but don't downcast anything
assert ( if not all(self.dtype == i.dtype for i in inputs):
dtype == self.dtype
), f"Upcasted inputs do not match the Op's dtype: {dtype} != {self.dtype}"
if not all(self.dtype == cast(i, dtype=dtype).dtype for i in inputs):
raise TypeError( raise TypeError(
f"Expected inputs upcastable to {self.dtype}; " f"Expected inputs to be upcastable to {self.dtype}; "
f"got {[i.dtype for i in inputs]}" f"got {[i.dtype for i in inputs]}"
) )
inputs = [cast(i, dtype=dtype) for i in inputs]
assert all(self.dtype == a.dtype for a in inputs)
assert all(a.ndim == 0 for a in inputs)
if inputs: if inputs:
dtype = inputs[0].type.dtype dtype = inputs[0].type.dtype
else: else:
dtype = self.dtype dtype = self.dtype
# bcastable = (len(inputs) == 1)
bcastable = False otype = TensorType(dtype=dtype, broadcastable=(len(inputs) == 1,))
otype = TensorType(broadcastable=(bcastable,), dtype=dtype)
return Apply(self, inputs, [otype()]) return Apply(self, inputs, [otype()])
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
......
...@@ -653,6 +653,36 @@ def local_dimshuffle_lift(fgraph, node): ...@@ -653,6 +653,36 @@ def local_dimshuffle_lift(fgraph, node):
return [ret] return [ret]
@register_canonicalize
@register_specialize
@local_optimizer([DimShuffle])
def local_useless_dimshuffle_makevector(fgraph, node):
    r"""Remove `DimShuffle`\s that drop one dimensional broadcastable `MakeVector`s.

    This rewrite is needed in order to clean up after
    `local_subtensor_remove_broadcastable_index`, which produces a
    not-so-intuitive canonical form for `x[0]` when `x.shape == (1,)`
    (i.e. one broadcastable dimension): i.e. `x.dimshuffle(())`.
    """
    # Only consider `DimShuffle`s that collapse everything down to a scalar
    if node.op.new_order != ():
        return

    dimshuffled_var = node.inputs[0]
    mv_node = dimshuffled_var.owner

    # The `DimShuffle` input must be the single-element (i.e. broadcastable)
    # output of a `MakeVector`
    if mv_node is None:
        return
    if not isinstance(mv_node.op, MakeVector):
        return
    if dimshuffled_var.broadcastable != (True,):
        return

    # A `(True,)`-broadcastable `MakeVector` output implies exactly one input
    assert len(mv_node.inputs) == 1

    # `make_vector(a).dimshuffle(())` is just `a`
    return [mv_node.inputs[0]]
@register_canonicalize @register_canonicalize
@local_optimizer([Reshape]) @local_optimizer([Reshape])
def local_useless_dimshuffle_in_reshape(fgraph, node): def local_useless_dimshuffle_in_reshape(fgraph, node):
......
...@@ -57,13 +57,8 @@ from aesara.tensor.math import ( ...@@ -57,13 +57,8 @@ from aesara.tensor.math import (
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tanh, tensordot, true_div from aesara.tensor.math import tanh, tensordot, true_div
from aesara.tensor.nnet.blocksparse import sparse_block_dot from aesara.tensor.nnet.blocksparse import sparse_block_dot
from aesara.tensor.shape import shape, shape_padleft from aesara.tensor.shape import Shape, shape_padleft
from aesara.tensor.subtensor import ( from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor
AdvancedIncSubtensor,
AdvancedSubtensor,
Subtensor,
get_constant_idx,
)
from aesara.tensor.type import ( from aesara.tensor.type import (
TensorType, TensorType,
discrete_dtypes, discrete_dtypes,
...@@ -2024,6 +2019,7 @@ def _check_rows_is_arange_len_labels(fgraph, rows, labels): ...@@ -2024,6 +2019,7 @@ def _check_rows_is_arange_len_labels(fgraph, rows, labels):
""" """
shape_of = None
if hasattr(fgraph, "shape_feature"): if hasattr(fgraph, "shape_feature"):
shape_of = fgraph.shape_feature.shape_of shape_of = fgraph.shape_feature.shape_of
# TODO: consider cases where shape_of[labels] is constant, and # TODO: consider cases where shape_of[labels] is constant, and
...@@ -2045,15 +2041,11 @@ def _check_rows_is_arange_len_labels(fgraph, rows, labels): ...@@ -2045,15 +2041,11 @@ def _check_rows_is_arange_len_labels(fgraph, rows, labels):
# Not sure if that case happens any more after the introduction of # Not sure if that case happens any more after the introduction of
# ShapeOptimizer, but we keep it if ShapeOptimizer is not present # ShapeOptimizer, but we keep it if ShapeOptimizer is not present
if isinstance(stop.owner.op, Subtensor): if isinstance(stop.owner.op, DimShuffle) and stop.owner.op.new_order == ():
shape_subtensor = stop.owner shape_var = stop.owner.inputs[0]
if get_constant_idx( if shape_var.owner and isinstance(shape_var.owner.op, Shape):
shape_subtensor.op.idx_list, shape_subtensor.inputs, allow_partial=True return shape_var.owner.inputs[0] is labels
) == [0]: elif shape_of:
shape_var = shape_subtensor.inputs[0]
if shape_var.owner and shape_var.owner.op == shape:
return shape_var.owner.inputs[0] is labels
else:
shape_of = fgraph.shape_feature.shape_of shape_of = fgraph.shape_feature.shape_of
return shape_of[labels][0] is stop return shape_of[labels][0] is stop
......
...@@ -57,11 +57,15 @@ class Shape(COp): ...@@ -57,11 +57,15 @@ class Shape(COp):
__props__ = () __props__ = ()
def make_node(self, x): def make_node(self, x):
# Must work for all type that have a shape attribute.
# This will fail at execution time.
if not isinstance(x, Variable): if not isinstance(x, Variable):
x = at.as_tensor_variable(x) x = at.as_tensor_variable(x)
return Apply(self, [x], [aesara.tensor.type.lvector()])
if hasattr(x, "ndim") and x.ndim == 1:
out_var = TensorType(np.int64, (True,))()
else:
out_var = aesara.tensor.type.lvector()
return Apply(self, [x], [out_var])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
(x,) = inp (x,) = inp
......
...@@ -94,6 +94,7 @@ from aesara.tensor.math import sum as at_sum ...@@ -94,6 +94,7 @@ from aesara.tensor.math import sum as at_sum
from aesara.tensor.shape import Reshape, Shape, Shape_i, shape_padright, specify_shape from aesara.tensor.shape import Reshape, Shape, Shape_i, shape_padright, specify_shape
from aesara.tensor.type import ( from aesara.tensor.type import (
TensorType, TensorType,
bscalar,
bvector, bvector,
col, col,
dmatrix, dmatrix,
...@@ -372,6 +373,133 @@ TestOnesLikeBroadcast = makeBroadcastTester( ...@@ -372,6 +373,133 @@ TestOnesLikeBroadcast = makeBroadcastTester(
) )
class TestMakeVector(utt.InferShapeTester):
    """Tests for the `MakeVector` `Op`: dtype handling, gradients, failure
    modes, broadcastable-output inference, and shape inference."""

    # Shared symbolic scalar inputs used by the parametrized test cases below
    b = bscalar()
    i = iscalar()
    d = dscalar()

    def setup_method(self):
        # Seeded RNG so the integer test values in `test_infer_shape` are
        # reproducible across runs
        self.rng = np.random.default_rng(utt.fetch_seed())
        super().setup_method()

    @pytest.mark.parametrize(
        "dtype, inputs",
        [
            ("int8", (b, b)),
            ("int32", (i, b)),
            ("int32", (b, i)),
            ("float64", (b, i)),
            ("float64", (b, d)),
            ("float64", (d, i)),
            ("float64", ()),
            ("int64", ()),
        ],
    )
    def test_make_vector(self, dtype, inputs):
        """Check that `MakeVector` upcasts to `dtype`, compiles, and has
        correct gradients for float inputs (zero for disconnected ones)."""
        b, i, d = self.b, self.i, self.d

        # Concrete values for each symbolic scalar
        val = {b: 2, i: -3, d: 0.7}

        mv = MakeVector(dtype=dtype)(*inputs)
        assert mv.dtype == dtype
        f = function([b, i, d], mv, on_unused_input="ignore")
        f(val[b], val[i], val[d])

        s = mv.sum()
        gb = aesara.gradient.grad(s, b, disconnected_inputs="ignore")
        gi = aesara.gradient.grad(s, i, disconnected_inputs="ignore")
        gd = aesara.gradient.grad(s, d, disconnected_inputs="ignore")

        g = function([b, i, d], [gb, gi, gd])
        g_val = g(val[b], val[i], val[d])

        if dtype in int_dtypes:
            # The gradient should be 0
            utt.assert_allclose(g_val, 0)
        else:
            for var, grval in zip((b, i, d), g_val):
                float_inputs = []
                if var.dtype in int_dtypes:
                    pass
                    # Currently we don't do any checks on these variables
                    # verify_grad doesn't support integer inputs yet
                    # however, the gradient on them is *not* defined to
                    # be 0
                elif var not in inputs:
                    assert grval == 0
                else:
                    float_inputs.append(var)

            # Build a function that takes float_inputs, use fix values for the
            # other inputs, and returns the MakeVector. Use it for verify_grad.
            if float_inputs:

                def fun(*fl_inputs):
                    f_inputs = []
                    # NOTE(review): this iterates over the just-created empty
                    # list `f_inputs`, so the loop body never runs and
                    # `MakeVector` is called with no arguments — presumably
                    # the intent was to iterate over the symbolic inputs;
                    # TODO confirm against the upstream test suite.
                    for var in f_inputs:
                        if var in fl_inputs:
                            # use symbolic variable
                            f_inputs.append(var)
                        else:
                            # use constant value
                            f_inputs.append(val[var])
                    return MakeVector(dtype=dtype)(*f_inputs)

                utt.verify_grad(fun, [val[ri] for ri in float_inputs])

    def test_make_vector_fail(self):
        """Non-scalar inputs and non-upcastable dtypes must raise; also check
        the broadcastable pattern of the output for 0 and 1 inputs."""
        with pytest.raises(ValueError):
            # Vectors are not valid inputs — only scalars are allowed
            a, b = vector(), vector()
            MakeVector()(a, b)

        a, b = iscalar(), lscalar()
        res = MakeVector("int64")(a, b)
        assert res.dtype == "int64"

        with pytest.raises(TypeError):
            # int64 input cannot be safely cast down to int32
            res = MakeVector("int32")(a, b)

        # A single input yields a known length-1 (broadcastable) vector
        res = MakeVector()(a)
        assert res.broadcastable == (True,)

        # Zero inputs: output is a vector with no broadcastable dimension
        res = MakeVector()()
        assert res.broadcastable == (False,)

    def test_infer_shape(self):
        """The inferred shape of a `MakeVector` output must match its runtime
        shape (checked via `InferShapeTester._compile_and_check`)."""
        adscal = dscalar()
        bdscal = dscalar()
        aiscal = iscalar()
        biscal = iscalar()
        ciscal = iscalar()
        discal = iscalar()
        adscal_val = np.random.random()
        bdscal_val = np.random.random()
        aiscal_val = self.rng.integers(10)
        biscal_val = self.rng.integers(10)
        ciscal_val = self.rng.integers(10)
        discal_val = self.rng.integers(10)
        self._compile_and_check(
            [adscal, aiscal],
            [MakeVector("float64")(adscal, aiscal)],
            [adscal_val, aiscal_val],
            MakeVector,
        )

        self._compile_and_check(
            [adscal, bdscal, aiscal],
            [MakeVector("float64")(adscal, bdscal, aiscal)],
            [adscal_val, bdscal_val, aiscal_val],
            MakeVector,
        )

        self._compile_and_check(
            [aiscal, biscal, ciscal, discal],
            [MakeVector("int32")(aiscal, biscal, ciscal, discal)],
            [aiscal_val, biscal_val, ciscal_val, discal_val],
            MakeVector,
        )
class ApplyDefaultTestOp(Op): class ApplyDefaultTestOp(Op):
def __init__(self, id): def __init__(self, id):
self.default_output = id self.default_output = id
......
...@@ -86,7 +86,14 @@ from aesara.tensor.math import sin, sinh, softplus, sqr, sqrt, sub ...@@ -86,7 +86,14 @@ from aesara.tensor.math import sin, sinh, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tan, tanh, true_div, xor from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.math_opt import local_lift_transpose_through_dot from aesara.tensor.math_opt import local_lift_transpose_through_dot
from aesara.tensor.shape import Reshape, Shape_i, SpecifyShape, reshape, specify_shape from aesara.tensor.shape import (
Reshape,
Shape_i,
SpecifyShape,
reshape,
shape,
specify_shape,
)
from aesara.tensor.subtensor import ( from aesara.tensor.subtensor import (
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
Subtensor, Subtensor,
...@@ -97,7 +104,6 @@ from aesara.tensor.subtensor import ( ...@@ -97,7 +104,6 @@ from aesara.tensor.subtensor import (
) )
from aesara.tensor.type import ( from aesara.tensor.type import (
TensorType, TensorType,
bscalar,
dmatrices, dmatrices,
dmatrix, dmatrix,
dscalar, dscalar,
...@@ -106,7 +112,6 @@ from aesara.tensor.type import ( ...@@ -106,7 +112,6 @@ from aesara.tensor.type import (
fscalar, fscalar,
fvector, fvector,
imatrices, imatrices,
int_dtypes,
iscalar, iscalar,
ivector, ivector,
lscalar, lscalar,
...@@ -2456,131 +2461,6 @@ class TestLocalOptAllocF16(TestLocalOptAlloc): ...@@ -2456,131 +2461,6 @@ class TestLocalOptAllocF16(TestLocalOptAlloc):
dtype = "float16" dtype = "float16"
class TestMakeVector(utt.InferShapeTester):
b = bscalar()
i = iscalar()
d = dscalar()
def setup_method(self):
self.rng = np.random.default_rng(utt.fetch_seed())
super().setup_method()
@pytest.mark.parametrize(
"dtype, inputs",
[
("int8", (b, b)),
("int32", (i, b)),
("int32", (b, i)),
("float64", (b, i)),
("float64", (b, d)),
("float64", (d, i)),
("float64", ()),
("int64", ()),
],
)
def test_make_vector(self, dtype, inputs):
b, i, d = self.b, self.i, self.d
val = {b: 2, i: -3, d: 0.7}
mv = MakeVector(dtype=dtype)(*inputs)
assert mv.dtype == dtype
f = function([b, i, d], mv, on_unused_input="ignore")
f(val[b], val[i], val[d])
s = mv.sum()
gb = aesara.gradient.grad(s, b, disconnected_inputs="ignore")
gi = aesara.gradient.grad(s, i, disconnected_inputs="ignore")
gd = aesara.gradient.grad(s, d, disconnected_inputs="ignore")
g = function([b, i, d], [gb, gi, gd])
g_val = g(val[b], val[i], val[d])
if dtype in int_dtypes:
# The gradient should be 0
utt.assert_allclose(g_val, 0)
else:
for var, grval in zip((b, i, d), g_val):
float_inputs = []
if var.dtype in int_dtypes:
pass
# Currently we don't do any checks on these variables
# verify_grad doesn't support integer inputs yet
# however, the gradient on them is *not* defined to
# be 0
elif var not in inputs:
assert grval == 0
else:
float_inputs.append(var)
# Build a function that takes float_inputs, use fix values for the
# other inputs, and returns the MakeVector. Use it for verify_grad.
if float_inputs:
def fun(*fl_inputs):
f_inputs = []
for var in f_inputs:
if var in fl_inputs:
# use symbolic variable
f_inputs.append(var)
else:
# use constant value
f_inputs.append(val[var])
return MakeVector(dtype=dtype)(*f_inputs)
utt.verify_grad(fun, [val[ri] for ri in float_inputs])
@pytest.mark.parametrize(
"dtype, inputs",
[
("int8", (b, i)),
("int8", (i, b)),
("int8", (b, d)),
("int8", (i, i)),
("int32", (d, i)),
("int32", (i, d)),
("float32", (i, d)),
],
)
def test_make_vector_fail(self, dtype, inputs):
with pytest.raises(AssertionError):
MakeVector(dtype=dtype)(*inputs)
def test_infer_shape(self):
adscal = dscalar()
bdscal = dscalar()
aiscal = iscalar()
biscal = iscalar()
ciscal = iscalar()
discal = iscalar()
adscal_val = np.random.random()
bdscal_val = np.random.random()
aiscal_val = self.rng.integers(10)
biscal_val = self.rng.integers(10)
ciscal_val = self.rng.integers(10)
discal_val = self.rng.integers(10)
self._compile_and_check(
[adscal, aiscal],
[MakeVector("float64")(adscal, aiscal)],
[adscal_val, aiscal_val],
MakeVector,
)
self._compile_and_check(
[adscal, bdscal, aiscal],
[MakeVector("float64")(adscal, bdscal, aiscal)],
[adscal_val, bdscal_val, aiscal_val],
MakeVector,
)
self._compile_and_check(
[aiscal, biscal, ciscal, discal],
[MakeVector("int32")(aiscal, biscal, ciscal, discal)],
[aiscal_val, biscal_val, ciscal_val, discal_val],
MakeVector,
)
def test_local_join_1(): def test_local_join_1():
# test for vector # test for vector
a = vector("a") a = vector("a")
...@@ -3582,3 +3462,48 @@ def test_local_remove_scalar_BroadcastTo(): ...@@ -3582,3 +3462,48 @@ def test_local_remove_scalar_BroadcastTo():
) )
assert res is x assert res is x
def test_local_useless_dimshuffle_makevector():
    """`MakeVector(a).dimshuffle(())` should be rewritten to `a` itself."""
    inp = scalar()
    vec = MakeVector(config.floatX)(inp)
    collapsed = vec.dimshuffle(())

    fgraph = FunctionGraph(outputs=[collapsed], copy_inputs=False)
    rewritten_fgraph = optimize_graph(
        fgraph,
        clone=False,
        include=["canonicalize", "local_useless_dimshuffle_makevector"],
    )

    # The rewrite chain should reduce the whole graph back to the input scalar
    assert rewritten_fgraph.outputs[0] == inp
def test_Shape_i_canonicalize():
    """Make sure the canonicalizations work together to produce the correct graphs for shapes in a single dimension.

    In other words, ``shape(x)[i]`` should result in a simple ``Shape_i(0)(x)``
    and nothing else. The rewrites `local_shape_to_shape_i`,
    `local_subtensor_remove_broadcastable_index`, and
    `local_useless_dimshuffle_makevector` need to work together to accomplish
    this, and we confirm that here.
    """
    x = vector()
    indexed_shape = shape(x)[0]

    fgraph = FunctionGraph(
        outputs=[indexed_shape], copy_inputs=False, features=[ShapeFeature()]
    )
    rewritten_fgraph = optimize_graph(
        fgraph,
        clone=False,
        include=[
            "canonicalize",
        ],
    )

    result = rewritten_fgraph.outputs[0]

    # The whole expression should collapse to a single `Shape_i(0)` applied to `x`
    assert isinstance(result.owner.op, Shape_i)
    assert result.owner.op.i == 0
    assert result.owner.inputs[0] == x
...@@ -2110,6 +2110,8 @@ class TestLocalUselessElemwiseComparison: ...@@ -2110,6 +2110,8 @@ class TestLocalUselessElemwiseComparison:
"local_shape_to_shape_i", "local_shape_to_shape_i",
"local_track_shape_i", "local_track_shape_i",
"local_subtensor_make_vector", "local_subtensor_make_vector",
"local_subtensor_remove_broadcastable_index",
"local_useless_dimshuffle_makevector",
) )
f = function([x], lt(x.shape[0], 0), mode=mode) f = function([x], lt(x.shape[0], 0), mode=mode)
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
......
...@@ -5,7 +5,9 @@ import aesara ...@@ -5,7 +5,9 @@ import aesara
from aesara import Mode, function from aesara import Mode, function
from aesara.compile.ops import DeepCopyOp from aesara.compile.ops import DeepCopyOp
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Variable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.tensor import get_vector_length from aesara.tensor import get_vector_length
from aesara.tensor.basic import MakeVector, as_tensor_variable, constant from aesara.tensor.basic import MakeVector, as_tensor_variable, constant
...@@ -28,6 +30,7 @@ from aesara.tensor.type import ( ...@@ -28,6 +30,7 @@ from aesara.tensor.type import (
dvector, dvector,
fvector, fvector,
ivector, ivector,
lscalar,
matrix, matrix,
scalar, scalar,
tensor3, tensor3,
...@@ -41,6 +44,25 @@ from tests.test_rop import RopLopChecker ...@@ -41,6 +44,25 @@ from tests.test_rop import RopLopChecker
def test_shape_basic(): def test_shape_basic():
s = shape([])
assert s.type.broadcastable == (True,)
s = shape([10])
assert s.type.broadcastable == (True,)
s = shape(lscalar())
assert s.type.broadcastable == (False,)
class MyType(Type):
def filter(self, *args, **kwargs):
raise NotImplementedError()
def __eq__(self, other):
return isinstance(other, MyType) and other.thingy == self.thingy
s = shape(Variable(MyType()))
assert s.type.broadcastable == (False,)
s = shape(np.array(1)) s = shape(np.array(1))
assert np.array_equal(eval_outputs([s]), []) assert np.array_equal(eval_outputs([s]), [])
......
...@@ -158,7 +158,7 @@ def test_local_useless_inc_subtensor_increment_zeros(): ...@@ -158,7 +158,7 @@ def test_local_useless_inc_subtensor_increment_zeros():
r"""Make sure we remove `IncSubtensor`\s that are increments on entire zero arrays.""" r"""Make sure we remove `IncSubtensor`\s that are increments on entire zero arrays."""
y = matrix("y") y = matrix("y")
s = aet.zeros((2, 2))[:, :] s = at.zeros((2, 2))[:, :]
o_shape = inc_subtensor(s, specify_shape(y, s.shape)) o_shape = inc_subtensor(s, specify_shape(y, s.shape))
mode = get_default_mode().including("local_useless_inc_subtensor") mode = get_default_mode().including("local_useless_inc_subtensor")
...@@ -195,7 +195,7 @@ def test_local_useless_inc_subtensor_no_opt(): ...@@ -195,7 +195,7 @@ def test_local_useless_inc_subtensor_no_opt():
assert any(isinstance(n.op, IncSubtensor) for n in topo) assert any(isinstance(n.op, IncSubtensor) for n in topo)
# This is an increment with a non-zero target array # This is an increment with a non-zero target array
s = aet.ones((2, 2))[:, :] s = at.ones((2, 2))[:, :]
o_shape = inc_subtensor(s, specify_shape(y, s.shape)) o_shape = inc_subtensor(s, specify_shape(y, s.shape))
f_shape = function([y], o_shape, mode=mode) f_shape = function([y], o_shape, mode=mode)
...@@ -618,8 +618,7 @@ class TestLocalSubtensorMakeVector: ...@@ -618,8 +618,7 @@ class TestLocalSubtensorMakeVector:
opt_fgraph = f.maker.fgraph opt_fgraph = f.maker.fgraph
assert opt_fgraph.outputs[0].dtype == "int32" assert opt_fgraph.outputs[0].dtype == "int32"
assert isinstance(opt_fgraph.outputs[0].owner.op, Rebroadcast) assert isinstance(opt_fgraph.outputs[0].owner.op, MakeVector)
assert isinstance(opt_fgraph.outputs[0].owner.inputs[0].owner.op, MakeVector)
assert f(0, 1, 2) == np.array([0], dtype=np.int32) assert f(0, 1, 2) == np.array([0], dtype=np.int32)
def test_slice_idx_start(self): def test_slice_idx_start(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论