提交 79d98f1e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add benchmark test for FusionRewriter

上级 842bc52a
......@@ -48,7 +48,7 @@ from pytensor.tensor.math import round as at_round
from pytensor.tensor.math import sin, sinh, sqr, sqrt
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.math import tan, tanh, true_div, xor
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
from pytensor.tensor.rewriting.elemwise import FusionOptimizer, local_dimshuffle_lift
from pytensor.tensor.rewriting.shape import local_useless_dimshuffle_in_reshape
from pytensor.tensor.shape import reshape
from pytensor.tensor.type import (
......@@ -302,6 +302,29 @@ class TestFusion:
fwx = fw + fx
ftanx = tan(fx)
def large_fuseable_graph(self, n):
factors = []
sd = dscalar()
means = dvector()
cst_05 = at.constant(0.5)
cst_m05 = at.constant(-0.5)
cst_2 = at.constant(2)
cst_m2 = at.constant(-2)
ones = at.constant(np.ones(10))
for i in range(n):
f = cst_m05 * sd**cst_m2 * (ones - means[i]) ** cst_2 + cst_05 * log(
cst_05 * (sd**cst_m2) / np.pi
)
factors.append(at_sum(f))
logp = add(*factors)
vars = [sd, means]
dlogp = [pytensor.grad(logp, v) for v in vars]
return vars, dlogp
@pytest.mark.parametrize(
"case",
[
......@@ -1059,35 +1082,9 @@ class TestFusion:
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
def test_big_fusion(self):
# In the past, pickle of Composite generated in that case
# crashed with max recursion limit. So we were not able to
# generate C code in that case.
factors = []
sd = dscalar()
means = dvector()
cst_05 = at.constant(0.5)
cst_m05 = at.constant(-0.5)
cst_2 = at.constant(2)
cst_m2 = at.constant(-2)
ones = at.constant(np.ones(10))
n = 85
if config.mode in ["DebugMode", "DEBUG_MODE"]:
n = 10
for i in range(n):
f = cst_m05 * sd**cst_m2 * (ones - means[i]) ** cst_2 + cst_05 * log(
cst_05 * (sd**cst_m2) / np.pi
)
factors.append(at_sum(f))
logp = add(*factors)
vars = [sd, means]
# Make sure that C compilation is used
mode = Mode("cvm", self.rewrites)
dlogp = function(vars, [pytensor.grad(logp, v) for v in vars], mode=mode)
dlogp = function(*self.large_fuseable_graph(n=85), mode=mode)
# Make sure something was fused
assert any(
......@@ -1362,6 +1359,18 @@ class TestFusion:
func = pytensor.function([], [logp, grad_logp], mode="FAST_RUN")
benchmark(func)
@pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
def test_rewrite_benchmark(self, benchmark):
inps, outs = self.large_fuseable_graph(n=25)
fg = FunctionGraph(inps, outs)
opt = FusionOptimizer()
def rewrite_func():
nb_replacement = opt.apply(fg.clone())[2]
return nb_replacement
assert benchmark(rewrite_func) == 103
class TimesN(aes.basic.UnaryScalarOp):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论