提交 aad5d35c authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Implement DimShuffle lifting optimization for RandomVariables

This optimization does *not* preserve equality between the numeric results of the untransformed and transformed graphs when the RNGs and seeds are equal. The reason is that the underlying sampler methods themselves are not implemented in Theano, so we cannot apply the requisite DimShuffle-like operations to the intermediate samples used to generate multiple replications and/or independent variates. For example, sampling a normal of size (3, 2) requires a draw of size (3, 2) from a standard normal and we can't transpose that (3, 2) array. If we could, then we would be able to maintain numerical equality between graphs.
上级 a328dd5d
import numpy as np
import pytest
import theano.tensor as tt
from theano import change_flags, config, shared
from theano.compile.function import function
from theano.compile.mode import Mode
from theano.gof.fg import FunctionGraph
from theano.gof.graph import Constant
from theano.gof.opt import EquilibriumOptimizer
from theano.gof.optdb import Query
from theano.tensor.random.basic import normal
from theano.tensor.elemwise import DimShuffle
from theano.tensor.random.basic import dirichlet, multivariate_normal, normal, poisson
from theano.tensor.random.opt import lift_rv_shapes, local_dimshuffle_rv_lift
# Optimizer query that enables *only* the in-place RNG rewrite; used to check
# that `random_make_inplace` fires.
opts = Query(include=["random_make_inplace"], exclude=[])
inplace_mode = Mode("py", opts)
# Mode with no rewrites at all, for reference/unoptimized evaluation.
# NOTE: the original (a merged diff) rebuilt `inplace_mode` a second time with
# an equivalent but distinct Query/Mode; the duplicate assignment is removed.
no_mode = Mode("py", Query(include=[], exclude=[]))
def test_inplace_optimization():
......@@ -30,3 +38,277 @@ def test_inplace_optimization():
np.array_equal(a.data, b.data)
for a, b in zip(new_out.owner.inputs[1:], out.owner.inputs[1:])
)
def check_shape_lifted_rv(rv, params, size, rng):
    """Build `rv`, lift its ``size`` into the parameters, and compare draws.

    Constructs symbolic inputs for `params` and `size`, applies
    `lift_rv_shapes` to the resulting `RandomVariable` node, and asserts that
    the lifted graph has an empty ``size`` yet samples the same values as the
    original graph.
    """

    def _symbolic(value):
        # Symbolic variable of the same type as `value`, tagged with a test
        # value so `compute_test_value` machinery can run.
        var = tt.as_tensor(value).type()
        var.tag.test_value = value
        return var

    tt_params = [_symbolic(p) for p in params]
    tt_size = [_symbolic(s) for s in size]

    rv = rv(*tt_params, size=tt_size, rng=rng)
    rv_lifted = lift_rv_shapes(rv.owner)

    # Make sure the size input is empty
    assert np.array_equal(rv_lifted.inputs[1].data, [])

    f_ref = function(tt_params + tt_size, rv, mode=no_mode)
    f_lifted = function(tt_params + tt_size, rv_lifted.outputs[1], mode=no_mode)

    # Both graphs start from the same shared RNG, so their draws must match.
    arg_values = params + size
    assert np.array_equal(f_ref(*arg_values), f_lifted(*arg_values))
@change_flags(compute_test_value="raise")
def test_lift_rv_shapes():
    """Exercise `lift_rv_shapes` on scalar, broadcast, and multivariate cases."""
    rng = shared(np.random.RandomState(1233532), borrow=False)

    # Each case: (distribution op, parameter values, size).
    cases = [
        # Scalar parameters, no `size`.
        (
            normal,
            [
                np.array(1.0, dtype=config.floatX),
                np.array(5.0, dtype=config.floatX),
            ],
            [],
        ),
        # A vector parameter broadcast against a scalar one.
        (
            normal,
            [
                np.array([0.0, 1.0], dtype=config.floatX),
                np.array(5.0, dtype=config.floatX),
            ],
            [3, 2],
        ),
        # A multivariate distribution with replication dimensions.
        (
            multivariate_normal,
            [
                np.array([[0], [10], [100]], dtype=config.floatX),
                np.diag(np.array([1e-6], dtype=config.floatX)),
            ],
            [2, 3],
        ),
        # A distribution with a single, matrix-valued parameter.
        (
            dirichlet,
            [np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX)],
            [2, 3],
        ),
    ]

    for dist_op, test_params, test_size in cases:
        check_shape_lifted_rv(dist_op, test_params, test_size, rng)
@pytest.mark.parametrize(
    "ds_order, lifted, dist_op, dist_params, size, rtol",
    [
        (
            (1, 0, 2),
            True,
            normal,
            (
                np.arange(2 * 2 * 2).reshape((2, 2, 2)).astype(config.floatX),
                np.array(1e-6).astype(config.floatX),
            ),
            (),
            1e-3,
        ),
        (
            (0, 1, 2),
            True,
            normal,
            (np.array(0).astype(config.floatX), np.array(1e-6).astype(config.floatX)),
            (2, 1, 2),
            1e-3,
        ),
        (
            (0, 2, 1),
            True,
            normal,
            (np.array(0).astype(config.floatX), np.array(1e-6).astype(config.floatX)),
            (2, 1, 2),
            1e-3,
        ),
        (
            (1, 0, 2),
            True,
            normal,
            (np.array(0).astype(config.floatX), np.array(1e-6).astype(config.floatX)),
            (2, 1, 2),
            1e-3,
        ),
        (
            (0, 2, 1),
            True,
            normal,
            (
                np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
                np.array([[1e-6, 2e-6]], dtype=config.floatX),
            ),
            (3, 2, 2),
            1e-3,
        ),
        (
            ("x", 0, 2, 1, "x"),
            True,
            normal,
            (
                np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
                np.array([[1e-6, 2e-6]], dtype=config.floatX),
            ),
            (3, 2, 2),
            1e-3,
        ),
        (
            ("x", 0, "x", 2, "x", 1, "x"),
            True,
            normal,
            (
                np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
                np.array([[1e-6, 2e-6]], dtype=config.floatX),
            ),
            (3, 2, 2),
            1e-3,
        ),
        (
            ("x", 0, 2, 1, "x"),
            True,
            normal,
            (
                np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
                np.array([[1e-6, 2e-6]], dtype=config.floatX),
            ),
            (3, 2, 2),
            1e-3,
        ),
        (
            ("x", 1, 0, 2, "x"),
            False,
            normal,
            (
                np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
                np.array([[1e-6, 2e-6]], dtype=config.floatX),
            ),
            (3, 2, 2),
            1e-3,
        ),
        # Only one distribution parameter
        (
            (0, 2, 1),
            True,
            poisson,
            (np.array([[10, 50], [100, 150]], dtype=config.floatX),),
            (3, 2, 2),
            1,
        ),
        # A multi-dimensional case
        (
            (0, 2, 1),
            False,
            multivariate_normal,
            (
                np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
                np.eye(2).astype(config.floatX) * 1e-6,
            ),
            (3,),
            1e-3,
        ),
    ],
)
@change_flags(compute_test_value_opt="raise", compute_test_value="raise")
def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
    """Check that `local_dimshuffle_rv_lift` fires exactly when expected.

    When `lifted` is true, the `DimShuffle` must be pushed into the
    `RandomVariable`'s parameters and the sampled values must stay close to
    the unoptimized graph's (within `rtol`); otherwise the `DimShuffle` must
    remain the output node.
    """
    rng = shared(np.random.RandomState(1233532), borrow=False)

    def _with_test_value(var, value):
        # Tag a symbolic variable with its concrete test value.
        var.tag.test_value = value
        return var

    dist_params_tt = [
        _with_test_value(tt.as_tensor(p).type(), p) for p in dist_params
    ]
    size_tt = [_with_test_value(tt.iscalar(), s) for s in size]

    dist_st = dist_op(*dist_params_tt, size=size_tt, rng=rng).dimshuffle(ds_order)

    # Constants cannot be function inputs.
    f_inputs = [
        arg
        for arg in dist_params_tt + size_tt
        if not isinstance(arg, (slice, Constant))
    ]

    mode = Mode(
        "py", EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100)
    )
    f_opt = function(f_inputs, dist_st, mode=mode)

    (new_out,) = f_opt.maker.fgraph.outputs

    if not lifted:
        # The rewrite must have refused, leaving the `DimShuffle` on top.
        assert isinstance(new_out.owner.op, DimShuffle)
        return

    # The `DimShuffle` was pushed into the `RandomVariable`'s parameters.
    assert new_out.owner.op == dist_op
    assert all(
        isinstance(i.owner.op, DimShuffle)
        for i in new_out.owner.inputs[3:]
        if i.owner
    )

    # Compare draws from the unoptimized and optimized graphs.
    f_base = function(f_inputs, dist_st, mode=no_mode)
    arg_values = [p.get_test_value() for p in f_inputs]
    np.testing.assert_allclose(f_base(*arg_values), f_opt(*arg_values), rtol=rtol)
def test_Dimshuffle_lift_restrictions():
    """The lift must respect non-`DimShuffle` clients of the `RandomVariable`."""
    rng = shared(np.random.RandomState(1233532), borrow=False)

    x = normal(tt.arange(2).reshape((2,)), 100, size=(2, 2, 2), rng=rng)
    y = x.dimshuffle(1, 0, 2)

    def _lift_and_get_node(out):
        # Run the equilibrium optimizer on a graph for `out` and return the
        # owner of the second input of the resulting output node.
        fg = FunctionGraph([rng], [out], clone=False)
        EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg)
        return fg.outputs[0].owner.inputs[1].owner

    # The non-`Dimshuffle` client depends on the RNG state, so we can't
    # perform the lift
    ds_node = _lift_and_get_node(x - y)
    assert ds_node == y.owner
    assert isinstance(ds_node.op, DimShuffle)
    assert ds_node.inputs[0].owner.op == normal

    # The non-`Dimshuffle` client doesn't depend on the RNG state, so we can
    # perform the lift
    lifted_node = _lift_and_get_node(tt.ones(x.shape) - y)
    assert lifted_node.op == normal
    assert isinstance(lifted_node.inputs[-1].owner.op, DimShuffle)
    assert isinstance(lifted_node.inputs[-2].owner.op, DimShuffle)
import theano.tensor as tt
from theano import config
from theano.compile import optdb
from theano.compile.ops import Shape
from theano.gof.op import compute_test_value
from theano.gof.opt import local_optimizer
from theano.tensor.elemwise import DimShuffle
from theano.tensor.extra_ops import broadcast_to
from theano.tensor.opt import in2out
from theano.tensor.random.op import RandomVariable
from theano.tensor.random.utils import broadcast_params
@local_optimizer([RandomVariable])
......@@ -9,11 +16,8 @@ def random_make_inplace(fgraph, node):
op = node.op
if isinstance(op, RandomVariable) and not op.inplace:
name, ndim_supp, ndims_params, dtype, _ = op._props()
new_op = type(op)(name, ndim_supp, ndims_params, dtype, True)
# rng, size, dtype, *dist_params = node.inputs
return new_op.make_node(*node.inputs).outputs
return False
......@@ -26,3 +30,170 @@ optdb.register(
"fast_run",
"inplace",
)
def lift_rv_shapes(node):
    """Lift `RandomVariable`'s shape-related parameters.

    In other words, this will broadcast the distribution parameters and
    extra dimensions added by the `size` parameter.

    For example, ``normal([0.0, 1.0], 5.0, size=(3, 2))`` becomes
    ``normal([[0., 1.], [0., 1.], [0., 1.]], [[5., 5.], [5., 5.], [5., 5.]])``.

    Parameters
    ----------
    node
        An apply node of a `RandomVariable` `Op`.

    Returns
    -------
    A new apply node with an empty ``size`` and broadcasted parameters, or
    ``False`` when `node` does not belong to a `RandomVariable`.
    """
    if not isinstance(node.op, RandomVariable):
        return False

    rng, size, dtype, *dist_params = node.inputs

    # Broadcast the parameters against one another first.
    dist_params = broadcast_params(dist_params, node.op.ndims_params)

    # Fold the replication dimensions given by `size` into the parameters.
    # NOTE(review): when `ndim_supp > 0` the full (broadcasted) parameter
    # shape is appended after `size` — confirm this matches the intended
    # support-dimension layout for multivariate variates.
    if node.op.ndim_supp > 0:
        new_params = [
            broadcast_to(param, tuple(size) + tuple(param.shape))
            for param in dist_params
        ]
    else:
        new_params = [broadcast_to(param, size) for param in dist_params]

    # Passing `None` for `size` yields a node with an empty `size` input.
    return node.op.make_node(rng, None, dtype, *new_params)
@local_optimizer([DimShuffle])
def local_dimshuffle_rv_lift(fgraph, node):
    """Lift `DimShuffle`s through `RandomVariable` `Op`s.

    For example, ``normal(mu, std).T == normal(mu.T, std.T)``.

    The basic idea behind this optimization is that we need to separate the
    `DimShuffle`-ing into independent `DimShuffle`s that each occur in two
    distinct sub-spaces: the parameters and ``size`` (i.e. replications)
    sub-spaces.

    If a `DimShuffle` exchanges dimensions across those two sub-spaces, then we
    don't do anything.

    Otherwise, if the `DimShuffle` only exchanges dimensions within each of
    those sub-spaces, we can break it apart and apply the parameter-space
    `DimShuffle` to the `RandomVariable`'s distribution parameters, and then
    apply the replications-space `DimShuffle` to the `RandomVariable`'s
    ``size`` tuple.  The latter is a particularly simple rearranging of a
    tuple, but the former requires a little more work.

    Returns a one-element list holding the new `RandomVariable` output when
    the lift is performed, or ``False`` when it cannot be.
    """
    ds_op = node.op

    if not isinstance(ds_op, DimShuffle):
        return False

    base_rv = node.inputs[0]
    rv_node = base_rv.owner

    # Only univariate (`ndim_supp == 0`) `RandomVariable`s are lifted here.
    if not (
        rv_node and isinstance(rv_node.op, RandomVariable) and rv_node.op.ndim_supp == 0
    ):
        return False

    # If no one else is using the underlying `RandomVariable`, then we can
    # do this; otherwise, the graph would be internally inconsistent.
    # NOTE(review): `Shape` clients are tolerated — presumably because they
    # don't consume the drawn values; confirm against the restriction tests.
    if not all(
        (n == node or isinstance(n.op, Shape)) for n, i in fgraph.clients[base_rv]
    ):
        return False

    rv_op = rv_node.op
    rng, size, dtype, *dist_params = rv_node.inputs

    # We need to know the dimensions that were *not* added by the `size`
    # parameter (i.e. the dimensions corresponding to independent variates with
    # different parameter values)
    num_ind_dims = None
    if len(dist_params) == 1:
        num_ind_dims = dist_params[0].ndim
    else:
        # When there is more than one distribution parameter, assume that all
        # of them will broadcast to the maximum number of dimensions
        num_ind_dims = max(d.ndim for d in dist_params)

    # If the indices in `ds_new_order` are entirely within the replication
    # indices group or the independent variates indices group, then we can apply
    # this optimization.
    ds_new_order = ds_op.new_order

    # Create a map from old index order to new/`DimShuffled` index order,
    # excluding the "x" entries (broadcastable dimension insertions).
    dim_orders = [(n, d) for n, d in enumerate(ds_new_order) if isinstance(d, int)]

    # Find the index at which the replications/independents split occurs
    reps_ind_split_idx = len(dim_orders) - (num_ind_dims + rv_op.ndim_supp)

    # Partition the shuffled dimensions into the two sub-spaces.
    ds_reps_new_dims = dim_orders[:reps_ind_split_idx]
    ds_ind_new_dims = dim_orders[reps_ind_split_idx:]

    # True when every shuffled independent-variates dimension stays within
    # the independent-variates sub-space.
    ds_only_in_ind = ds_ind_new_dims and all(
        d >= reps_ind_split_idx for n, d in ds_ind_new_dims
    )

    if ds_only_in_ind:
        # Update the `size` array to reflect the `DimShuffle`d dimensions,
        # since the trailing dimensions in `size` represent the independent
        # variates dimensions (for univariate distributions, at least)
        new_size = (
            [
                tt.constant(1, dtype="int64") if o == "x" else size[o]
                for o in ds_new_order
            ]
            if tt.get_vector_length(size) > 0
            else size
        )

        # Compute the new axes parameter(s) for the `DimShuffle` that will be
        # applied to the `RandomVariable` parameters (they need to be offset)
        rv_params_new_order = [
            d - reps_ind_split_idx if isinstance(d, int) else d
            for d in ds_new_order[ds_ind_new_dims[0][0] :]
        ]

        # Lift the `DimShuffle`s into the parameters
        # NOTE: The parameters might not be broadcasted against each other, so
        # we can only apply the parts of the `DimShuffle` that are relevant.
        new_dist_params = []
        for d in dist_params:
            if d.ndim < len(ds_ind_new_dims):
                # Keep only the axes this (smaller) parameter actually has,
                # plus any "x" insertions.
                _rv_params_new_order = [
                    o
                    for o in rv_params_new_order
                    if (isinstance(o, int) and o < d.ndim) or o == "x"
                ]
            else:
                _rv_params_new_order = rv_params_new_order

            new_dist_params.append(
                type(ds_op)(d.type.broadcastable, _rv_params_new_order)(d)
            )

        new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)

        if config.compute_test_value != "off":
            compute_test_value(new_node)

        # `outputs[1]` is the drawn variate (the value the `DimShuffle`
        # originally consumed).
        return [new_node.outputs[1]]

    # True when every shuffled replication dimension stays within the
    # replications sub-space.
    ds_only_in_reps = ds_reps_new_dims and all(
        d < reps_ind_split_idx for n, d in ds_reps_new_dims
    )

    if ds_only_in_reps:
        # Update the `size` array to reflect the `DimShuffle`d dimensions.
        # There should be no need to `DimShuffle` now.
        new_size = [
            tt.constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order
        ]

        new_node = rv_op.make_node(rng, new_size, dtype, *dist_params)

        if config.compute_test_value != "off":
            compute_test_value(new_node)

        return [new_node.outputs[1]]

    # The `DimShuffle` crosses the two sub-spaces: do not lift.
    return False
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论