提交 844a18e1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename aesara.tensor.random.opt to aesara.tensor.random.rewriting

上级 380fca03
# Initialize `RandomVariable` optimizations # Initialize `RandomVariable` rewrites
import aesara.tensor.random.opt import aesara.tensor.random.rewriting
import aesara.tensor.random.utils
from aesara.tensor.random.basic import *
from aesara.tensor.random.op import RandomState, default_rng
......
...@@ -60,8 +60,8 @@ from aesara.tensor.type import iscalar, scalar, tensor ...@@ -60,8 +60,8 @@ from aesara.tensor.type import iscalar, scalar, tensor
from tests.unittest_tools import create_aesara_param
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) rewrites_query = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
py_mode = Mode("py", opts) py_mode = Mode("py", rewrites_query)
def fixed_scipy_rvs(rvs_name):
......
...@@ -19,7 +19,7 @@ from aesara.tensor.random.basic import ( ...@@ -19,7 +19,7 @@ from aesara.tensor.random.basic import (
uniform,
)
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.opt import ( from aesara.tensor.random.rewriting import (
local_dimshuffle_rv_lift,
local_rv_size_lift,
local_subtensor_rv_lift,
...@@ -31,7 +31,9 @@ from aesara.tensor.type import iscalar, vector ...@@ -31,7 +31,9 @@ from aesara.tensor.type import iscalar, vector
no_mode = Mode("py", RewriteDatabaseQuery(include=[], exclude=[]))
def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng, name=None): def apply_local_rewrite_to_rv(
rewrite, op_fn, dist_op, dist_params, size, rng, name=None
):
dist_params_at = []
for p in dist_params:
p_at = at.as_tensor(p).type()
...@@ -50,20 +52,20 @@ def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng, name=None ...@@ -50,20 +52,20 @@ def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng, name=None
p for p in dist_params_at + size_at if not isinstance(p, (slice, Constant))
]
mode = Mode("py", EquilibriumGraphRewriter([opt], max_use_ratio=100)) mode = Mode("py", EquilibriumGraphRewriter([rewrite], max_use_ratio=100))
f_opt = function( f_rewritten = function(
f_inputs,
dist_st,
mode=mode,
)
(new_out,) = f_opt.maker.fgraph.outputs (new_out,) = f_rewritten.maker.fgraph.outputs
return new_out, f_inputs, dist_st, f_opt return new_out, f_inputs, dist_st, f_rewritten
def test_inplace_optimization(): def test_inplace_rewrites():
out = normal(0, 1)
out.owner.inputs[0].default_update = out.owner.outputs[0]
...@@ -87,7 +89,7 @@ def test_inplace_optimization(): ...@@ -87,7 +89,7 @@ def test_inplace_optimization():
assert np.array_equal(new_out.owner.inputs[1].data, [])
def test_inplace_optimization_extra_props(): def test_inplace_rewrites_extra_props():
class Test(RandomVariable):
name = "test"
ndim_supp = 0
...@@ -183,7 +185,7 @@ def test_inplace_optimization_extra_props(): ...@@ -183,7 +185,7 @@ def test_inplace_optimization_extra_props():
def test_local_rv_size_lift(dist_op, dist_params, size):
rng = shared(np.random.default_rng(1233532), borrow=False)
new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv( new_out, f_inputs, dist_st, f_rewritten = apply_local_rewrite_to_rv(
local_rv_size_lift,
lambda rv: rv,
dist_op,
...@@ -349,7 +351,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ...@@ -349,7 +351,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
rng = shared(np.random.default_rng(1233532), borrow=False)
new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv( new_out, f_inputs, dist_st, f_rewritten = apply_local_rewrite_to_rv(
local_dimshuffle_rv_lift,
lambda rv: rv.dimshuffle(ds_order),
dist_op,
...@@ -377,9 +379,9 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ...@@ -377,9 +379,9 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
arg_values = [p.get_test_value() for p in f_inputs]
res_base = f_base(*arg_values)
res_opt = f_opt(*arg_values) res_rewritten = f_rewritten(*arg_values)
np.testing.assert_allclose(res_base, res_opt, rtol=rtol) np.testing.assert_allclose(res_base, res_rewritten, rtol=rtol)
@pytest.mark.parametrize(
...@@ -472,7 +474,7 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size): ...@@ -472,7 +474,7 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
i_at.tag.test_value = i
indices_at += (i_at,)
new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv( new_out, f_inputs, dist_st, f_rewritten = apply_local_rewrite_to_rv(
local_subtensor_rv_lift,
lambda rv: rv[indices_at],
dist_op,
...@@ -502,9 +504,9 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size): ...@@ -502,9 +504,9 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
arg_values = [p.get_test_value() for p in f_inputs]
res_base = f_base(*arg_values)
res_opt = f_opt(*arg_values) res_rewritten = f_rewritten(*arg_values)
np.testing.assert_allclose(res_base, res_opt, rtol=1e-3) np.testing.assert_allclose(res_base, res_rewritten, rtol=1e-3)
def test_Subtensor_lift_restrictions():
...@@ -615,7 +617,7 @@ def test_Dimshuffle_lift_rename(ds_order, lifted, dist_op, dist_params, size, rt ...@@ -615,7 +617,7 @@ def test_Dimshuffle_lift_rename(ds_order, lifted, dist_op, dist_params, size, rt
rng = shared(np.random.default_rng(1233532), borrow=False)
new_out, *_ = apply_local_opt_to_rv( new_out, *_ = apply_local_rewrite_to_rv(
local_dimshuffle_rv_lift,
lambda rv: rv.dimshuffle(ds_order),
dist_op,
......
...@@ -11,8 +11,8 @@ from tests import unittest_tools as utt ...@@ -11,8 +11,8 @@ from tests import unittest_tools as utt
@pytest.fixture(scope="module", autouse=True)
def set_aesara_flags():
opts = RewriteDatabaseQuery(include=[None], exclude=[]) rewrites_query = RewriteDatabaseQuery(include=[None], exclude=[])
py_mode = Mode("py", opts) py_mode = Mode("py", rewrites_query)
with config.change_flags(mode=py_mode, compute_test_value="warn"):
yield
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论