Commit ea528820, authored and committed by Brandon T. Willard

Make local_rv_size_lift a local optimization and simplify tests

Parent: 5db98be1
--- a/aesara/tensor/random/opt.py
+++ b/aesara/tensor/random/opt.py
@@ -40,19 +40,21 @@ optdb.register(
 )


-def lift_rv_shapes(node):
-    """Lift `RandomVariable`'s shape-related parameters.
+@local_optimizer(tracks=None)
+def local_rv_size_lift(fgraph, node):
+    """Lift the ``size`` parameter in a ``RandomVariable``.

-    In other words, this will broadcast the distribution parameters and
-    extra dimensions added by the `size` parameter.
+    In other words, this will broadcast the distribution parameters by adding
+    the extra dimensions implied by the ``size`` parameter, and remove the
+    ``size`` parameter in the process.

-    For example, ``normal([0.0, 1.0], 5.0, size=(3, 2))`` becomes
-    ``normal([[0., 1.], [0., 1.], [0., 1.]], [[5., 5.], [5., 5.], [5., 5.]])``.
+    For example, ``normal(0, 1, size=(1, 2))`` becomes
+    ``normal([[0, 0]], [[1, 1]], size=())``.

     """

     if not isinstance(node.op, RandomVariable):
-        return False
+        return

     rng, size, dtype, *dist_params = node.inputs
@@ -65,13 +67,15 @@ def lift_rv_shapes(node):
             )
             for p in dist_params
         ]
+    else:
+        return

     new_node = node.op.make_node(rng, None, dtype, *dist_params)

     if config.compute_test_value != "off":
         compute_test_value(new_node)

-    return new_node
+    return new_node.outputs

 @local_optimizer([DimShuffle])
...
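For intuition, the new docstring example amounts to an ordinary NumPy-style broadcast of the distribution parameters up to the requested size. A minimal sketch of that semantics, using plain NumPy rather than the symbolic graph the optimization actually rewrites:

```python
import numpy as np

# normal(0, 1, size=(1, 2)): `size` alone supplies the output shape.
loc, scale, size = np.array(0.0), np.array(1.0), (1, 2)

# After the lift the parameters carry the shape and `size` becomes ():
# normal([[0, 0]], [[1, 1]], size=()).
lifted_loc = np.broadcast_to(loc, size)      # array([[0., 0.]])
lifted_scale = np.broadcast_to(scale, size)  # array([[1., 1.]])
```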
--- a/tests/tensor/random/test_opt.py
+++ b/tests/tensor/random/test_opt.py
@@ -19,26 +19,59 @@ from aesara.tensor.random.basic import (
 )
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.random.opt import (
-    lift_rv_shapes,
     local_dimshuffle_rv_lift,
+    local_rv_size_lift,
     local_subtensor_rv_lift,
 )
 from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
 from aesara.tensor.type import iscalar, vector


+inplace_mode = Mode(
+    "py", OptimizationQuery(include=["random_make_inplace"], exclude=[])
+)
 no_mode = Mode("py", OptimizationQuery(include=[], exclude=[]))

+def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng):
+    dist_params_aet = []
+    for p in dist_params:
+        p_aet = aet.as_tensor(p).type()
+        p_aet.tag.test_value = p
+        dist_params_aet.append(p_aet)
+
+    size_aet = []
+    for s in size:
+        s_aet = iscalar()
+        s_aet.tag.test_value = s
+        size_aet.append(s_aet)
+
+    dist_st = op_fn(dist_op(*dist_params_aet, size=size_aet, rng=rng))
+
+    f_inputs = [
+        p for p in dist_params_aet + size_aet if not isinstance(p, (slice, Constant))
+    ]
+
+    mode = Mode("py", EquilibriumOptimizer([opt], max_use_ratio=100))
+
+    f_opt = function(
+        f_inputs,
+        dist_st,
+        mode=mode,
+    )
+
+    (new_out,) = f_opt.maker.fgraph.outputs
+
+    return new_out, f_inputs, dist_st, f_opt
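As a usage sketch (the parameter values here are illustrative, not from the diff), every test below drives this helper the same way: build the RV, wrap it with `op_fn`, compile with only the optimization under test, and inspect the rewritten output:

```python
rng = shared(np.random.RandomState(1233532), borrow=False)

new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv(
    local_rv_size_lift,  # the rewrite under test
    lambda rv: rv,       # no extra op applied on top of the RandomVariable
    normal,
    [np.array(0.0, dtype=config.floatX), np.array(1.0, dtype=config.floatX)],
    [3, 2],
    rng,
)

# `new_out` is the output of the rewritten graph; for a successful size
# lift its RandomVariable node should now have an empty `size` input.
```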

 def test_inplace_optimization():

     out = normal(0, 1)

     assert out.owner.op.inplace is False

-    inplace_mode = Mode(
-        "py", OptimizationQuery(include=["random_make_inplace"], exclude=[])
-    )
-
     f = function(
         [],
         out,
@@ -55,80 +88,62 @@ def test_inplace_optimization():
     )

-def check_shape_lifted_rv(rv, params, size, rng):
-    aet_params = []
-    for p in params:
-        p_aet = aet.as_tensor(p)
-        p_aet = p_aet.type()
-        p_aet.tag.test_value = p
-        aet_params.append(p_aet)
-
-    aet_size = []
-    for s in size:
-        s_aet = aet.as_tensor(s)
-        s_aet = s_aet.type()
-        s_aet.tag.test_value = s
-        aet_size.append(s_aet)
-
-    rv = rv(*aet_params, size=aet_size, rng=rng)
-    rv_lifted = lift_rv_shapes(rv.owner)
-
-    # Make sure the size input is empty
-    assert np.array_equal(rv_lifted.inputs[1].data, [])
-
-    f_ref = function(
-        aet_params + aet_size,
-        rv,
-        mode=no_mode,
-    )
-    f_lifted = function(
-        aet_params + aet_size,
-        rv_lifted.outputs[1],
-        mode=no_mode,
-    )
-
-    f_ref_val = f_ref(*(params + size))
-    f_lifted_val = f_lifted(*(params + size))
-
-    assert np.array_equal(f_ref_val, f_lifted_val)
-
-
 @config.change_flags(compute_test_value="raise")
-def test_lift_rv_shapes():
+@pytest.mark.parametrize(
+    "dist_op, dist_params, size",
+    [
+        (
+            normal,
+            [
+                np.array(1.0, dtype=config.floatX),
+                np.array(5.0, dtype=config.floatX),
+            ],
+            [],
+        ),
+        (
+            normal,
+            [
+                np.array([0.0, 1.0], dtype=config.floatX),
+                np.array(5.0, dtype=config.floatX),
+            ],
+            [],
+        ),
+        (
+            normal,
+            [
+                np.array([0.0, 1.0], dtype=config.floatX),
+                np.array(5.0, dtype=config.floatX),
+            ],
+            [3, 2],
+        ),
+        (
+            multivariate_normal,
+            [
+                np.array([[0], [10], [100]], dtype=config.floatX),
+                np.diag(np.array([1e-6], dtype=config.floatX)),
+            ],
+            [2, 3],
+        ),
+        (
+            dirichlet,
+            [np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX)],
+            [2, 3],
+        ),
+    ],
+)
+def test_local_rv_size_lift(dist_op, dist_params, size):
     rng = shared(np.random.RandomState(1233532), borrow=False)

-    test_params = [
-        np.array(1.0, dtype=config.floatX),
-        np.array(5.0, dtype=config.floatX),
-    ]
-    test_size = []
-    check_shape_lifted_rv(normal, test_params, test_size, rng)
-
-    test_params = [
-        np.array([0.0, 1.0], dtype=config.floatX),
-        np.array(5.0, dtype=config.floatX),
-    ]
-    test_size = []
-    check_shape_lifted_rv(normal, test_params, test_size, rng)
-
-    test_params = [
-        np.array([0.0, 1.0], dtype=config.floatX),
-        np.array(5.0, dtype=config.floatX),
-    ]
-    test_size = [3, 2]
-    check_shape_lifted_rv(normal, test_params, test_size, rng)
-
-    test_params = [
-        np.array([[0], [10], [100]], dtype=config.floatX),
-        np.diag(np.array([1e-6], dtype=config.floatX)),
-    ]
-    test_size = [2, 3]
-    check_shape_lifted_rv(multivariate_normal, test_params, test_size, rng)
-
-    test_params = [
-        np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX)
-    ]
-    test_size = [2, 3]
-    check_shape_lifted_rv(dirichlet, test_params, test_size, rng)
+    new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv(
+        local_rv_size_lift,
+        lambda rv: rv,
+        dist_op,
+        dist_params,
+        size,
+        rng,
+    )
+
+    assert aet.get_vector_length(new_out.owner.inputs[1]) == 0
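The assertion works because `new_out.owner.inputs[1]` is the `size` input of the rewritten `RandomVariable` node (cf. `rng, size, dtype, *dist_params = node.inputs` in the optimizer above) and `aet.get_vector_length` reads that vector's static length. A small sketch of the before/after contrast, reusing this module's imports:

```python
# Before the rewrite, `size` is a length-2 vector encoding (3, 2):
before = normal(0, 1, size=(3, 2))
assert aet.get_vector_length(before.owner.inputs[1]) == 2

# After local_rv_size_lift, the broadcasted parameters carry the shape
# and the `size` input is empty, so its static length is 0.
```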
 @pytest.mark.parametrize(
@@ -274,36 +289,15 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
     rng = shared(np.random.RandomState(1233532), borrow=False)

-    dist_params_aet = []
-    for p in dist_params:
-        p_aet = aet.as_tensor(p).type()
-        p_aet.tag.test_value = p
-        dist_params_aet.append(p_aet)
-
-    size_aet = []
-    for s in size:
-        s_aet = iscalar()
-        s_aet.tag.test_value = s
-        size_aet.append(s_aet)
-
-    dist_st = dist_op(*dist_params_aet, size=size_aet, rng=rng).dimshuffle(ds_order)
-
-    f_inputs = [
-        p for p in dist_params_aet + size_aet if not isinstance(p, (slice, Constant))
-    ]
-
-    mode = Mode(
-        "py", EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100)
-    )
-
-    f_opt = function(
-        f_inputs,
-        dist_st,
-        mode=mode,
-    )
-
-    (new_out,) = f_opt.maker.fgraph.outputs
+    new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv(
+        local_dimshuffle_rv_lift,
+        lambda rv: rv.dimshuffle(ds_order),
+        dist_op,
+        dist_params,
+        size,
+        rng,
+    )

     if lifted:
         assert new_out.owner.op == dist_op
         assert all(
@@ -407,23 +401,10 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
 )
 @config.change_flags(compute_test_value_opt="raise", compute_test_value="raise")
 def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
+    from aesara.tensor.subtensor import as_index_constant
+
     rng = shared(np.random.RandomState(1233532), borrow=False)

-    dist_params_aet = []
-    for p in dist_params:
-        p_aet = aet.as_tensor(p).type()
-        p_aet.tag.test_value = p
-        dist_params_aet.append(p_aet)
-
-    size_aet = []
-    for s in size:
-        s_aet = iscalar()
-        s_aet.tag.test_value = s
-        size_aet.append(s_aet)
-
-    from aesara.tensor.subtensor import as_index_constant
-
     indices_aet = ()
     for i in indices:
         i_aet = as_index_constant(i)
@@ -431,26 +412,15 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
         i_aet.tag.test_value = i
         indices_aet += (i_aet,)

-    dist_st = dist_op(*dist_params_aet, size=size_aet, rng=rng)[indices_aet]
-
-    f_inputs = [
-        p
-        for p in dist_params_aet + size_aet + list(indices_aet)
-        if not isinstance(p, (slice, Constant))
-    ]
-
-    mode = Mode(
-        "py", EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100)
-    )
-
-    f_opt = function(
-        f_inputs,
-        dist_st,
-        mode=mode,
-    )
-
-    (new_out,) = f_opt.maker.fgraph.outputs
+    new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv(
+        local_subtensor_rv_lift,
+        lambda rv: rv[indices_aet],
+        dist_op,
+        dist_params,
+        size,
+        rng,
+    )

     if lifted:
         assert isinstance(new_out.owner.op, RandomVariable)
         assert all(
...