提交 844a18e1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename aesara.tensor.random.opt to aesara.tensor.random.rewriting

上级 380fca03
# Initialize `RandomVariable` optimizations
import aesara.tensor.random.opt
# Initialize `RandomVariable` rewrites
import aesara.tensor.random.rewriting
import aesara.tensor.random.utils
from aesara.tensor.random.basic import *
from aesara.tensor.random.op import RandomState, default_rng
......
......@@ -60,8 +60,8 @@ from aesara.tensor.type import iscalar, scalar, tensor
from tests.unittest_tools import create_aesara_param
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
py_mode = Mode("py", opts)
rewrites_query = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
py_mode = Mode("py", rewrites_query)
def fixed_scipy_rvs(rvs_name):
......
......@@ -19,7 +19,7 @@ from aesara.tensor.random.basic import (
uniform,
)
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.opt import (
from aesara.tensor.random.rewriting import (
local_dimshuffle_rv_lift,
local_rv_size_lift,
local_subtensor_rv_lift,
......@@ -31,7 +31,9 @@ from aesara.tensor.type import iscalar, vector
no_mode = Mode("py", RewriteDatabaseQuery(include=[], exclude=[]))
def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng, name=None):
def apply_local_rewrite_to_rv(
rewrite, op_fn, dist_op, dist_params, size, rng, name=None
):
dist_params_at = []
for p in dist_params:
p_at = at.as_tensor(p).type()
......@@ -50,20 +52,20 @@ def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng, name=None
p for p in dist_params_at + size_at if not isinstance(p, (slice, Constant))
]
mode = Mode("py", EquilibriumGraphRewriter([opt], max_use_ratio=100))
mode = Mode("py", EquilibriumGraphRewriter([rewrite], max_use_ratio=100))
f_opt = function(
f_rewritten = function(
f_inputs,
dist_st,
mode=mode,
)
(new_out,) = f_opt.maker.fgraph.outputs
(new_out,) = f_rewritten.maker.fgraph.outputs
return new_out, f_inputs, dist_st, f_opt
return new_out, f_inputs, dist_st, f_rewritten
def test_inplace_optimization():
def test_inplace_rewrites():
out = normal(0, 1)
out.owner.inputs[0].default_update = out.owner.outputs[0]
......@@ -87,7 +89,7 @@ def test_inplace_optimization():
assert np.array_equal(new_out.owner.inputs[1].data, [])
def test_inplace_optimization_extra_props():
def test_inplace_rewrites_extra_props():
class Test(RandomVariable):
name = "test"
ndim_supp = 0
......@@ -183,7 +185,7 @@ def test_inplace_optimization_extra_props():
def test_local_rv_size_lift(dist_op, dist_params, size):
rng = shared(np.random.default_rng(1233532), borrow=False)
new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv(
new_out, f_inputs, dist_st, f_rewritten = apply_local_rewrite_to_rv(
local_rv_size_lift,
lambda rv: rv,
dist_op,
......@@ -349,7 +351,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
rng = shared(np.random.default_rng(1233532), borrow=False)
new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv(
new_out, f_inputs, dist_st, f_rewritten = apply_local_rewrite_to_rv(
local_dimshuffle_rv_lift,
lambda rv: rv.dimshuffle(ds_order),
dist_op,
......@@ -377,9 +379,9 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
arg_values = [p.get_test_value() for p in f_inputs]
res_base = f_base(*arg_values)
res_opt = f_opt(*arg_values)
res_rewritten = f_rewritten(*arg_values)
np.testing.assert_allclose(res_base, res_opt, rtol=rtol)
np.testing.assert_allclose(res_base, res_rewritten, rtol=rtol)
@pytest.mark.parametrize(
......@@ -472,7 +474,7 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
i_at.tag.test_value = i
indices_at += (i_at,)
new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv(
new_out, f_inputs, dist_st, f_rewritten = apply_local_rewrite_to_rv(
local_subtensor_rv_lift,
lambda rv: rv[indices_at],
dist_op,
......@@ -502,9 +504,9 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
arg_values = [p.get_test_value() for p in f_inputs]
res_base = f_base(*arg_values)
res_opt = f_opt(*arg_values)
res_rewritten = f_rewritten(*arg_values)
np.testing.assert_allclose(res_base, res_opt, rtol=1e-3)
np.testing.assert_allclose(res_base, res_rewritten, rtol=1e-3)
def test_Subtensor_lift_restrictions():
......@@ -615,7 +617,7 @@ def test_Dimshuffle_lift_rename(ds_order, lifted, dist_op, dist_params, size, rt
rng = shared(np.random.default_rng(1233532), borrow=False)
new_out, *_ = apply_local_opt_to_rv(
new_out, *_ = apply_local_rewrite_to_rv(
local_dimshuffle_rv_lift,
lambda rv: rv.dimshuffle(ds_order),
dist_op,
......
......@@ -11,8 +11,8 @@ from tests import unittest_tools as utt
@pytest.fixture(scope="module", autouse=True)
def set_aesara_flags():
opts = RewriteDatabaseQuery(include=[None], exclude=[])
py_mode = Mode("py", opts)
rewrites_query = RewriteDatabaseQuery(include=[None], exclude=[])
py_mode = Mode("py", rewrites_query)
with config.change_flags(mode=py_mode, compute_test_value="warn"):
yield
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论