提交 0a242f29 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Use pytest-benchmark

上级 382be934
...@@ -115,7 +115,7 @@ jobs: ...@@ -115,7 +115,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
shell: bash -l {0} shell: bash -l {0}
run: | run: |
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov sympy mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.55" numba-scipy; fi if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.55" numba-scipy; fi
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib
pip install -e ./ pip install -e ./
...@@ -132,7 +132,7 @@ jobs: ...@@ -132,7 +132,7 @@ jobs:
if [[ $FAST_COMPILE == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,mode=FAST_COMPILE; fi if [[ $FAST_COMPILE == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,mode=FAST_COMPILE; fi
if [[ $FLOAT32 == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,floatX=float32; fi if [[ $FLOAT32 == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,floatX=float32; fi
export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
python -m pytest -x -r A --verbose --runslow --cov=pytensor/ --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART python -m pytest -x -r A --verbose --runslow --cov=pytensor/ --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART --benchmark-skip
env: env:
MATRIX_ID: ${{ steps.matrix-id.outputs.id }} MATRIX_ID: ${{ steps.matrix-id.outputs.id }}
MKL_THREADING_LAYER: GNU MKL_THREADING_LAYER: GNU
......
...@@ -30,6 +30,7 @@ dependencies: ...@@ -30,6 +30,7 @@ dependencies:
- pytest - pytest
- pytest-cov - pytest-cov
- pytest-xdist - pytest-xdist
- pytest-benchmark
# For building docs # For building docs
- sphinx>=5.1.0 - sphinx>=5.1.0
- sphinx_rtd_theme - sphinx_rtd_theme
......
...@@ -86,6 +86,7 @@ tests = [ ...@@ -86,6 +86,7 @@ tests = [
"pre-commit", "pre-commit",
"pytest-cov>=2.6.1", "pytest-cov>=2.6.1",
"coverage>=5.1", "coverage>=5.1",
"pytest-benchmark",
] ]
rtd = [ rtd = [
"sphinx>=1.3.0", "sphinx>=1.3.0",
......
import numpy as np import numpy as np
import pytest import pytest
import scipy.special
import pytensor
import pytensor.tensor as at
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value from pytensor.graph.op import get_test_value
...@@ -98,3 +101,24 @@ def test_softmax_grad(axis): ...@@ -98,3 +101,24 @@ def test_softmax_grad(axis):
out = SoftmaxGrad(axis=axis)(dy, sm) out = SoftmaxGrad(axis=axis)(dy, sm)
fgraph = FunctionGraph([dy, sm], [out]) fgraph = FunctionGraph([dy, sm], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.parametrize("size", [(10, 10), (1000, 1000), (10000, 10000)])
@pytest.mark.parametrize("axis", [0, 1])
def test_logsumexp_benchmark(size, axis, benchmark):
    """Benchmark a JAX-compiled, numerically stable log-sum-exp graph.

    The graph subtracts the per-axis max (guarded against ``inf`` via
    ``switch``) before exponentiating and adds it back after the log —
    the standard stable LSE formulation.  The benchmarked result is
    checked against ``scipy.special.logsumexp``.
    """
    x_sym = at.matrix("X")
    x_peak = at.max(x_sym, axis=axis, keepdims=True)
    # An all-inf slice would make (x - max) NaN; zero the guard instead.
    x_peak = at.switch(at.isinf(x_peak), 0, x_peak)
    lse_sym = at.log(at.sum(at.exp(x_sym - x_peak), axis=axis, keepdims=True)) + x_peak

    data = np.random.normal(size=size)
    lse_fn = pytensor.function([x_sym], lse_sym, mode="JAX")

    # Warm-up call so JAX JIT compilation happens outside the timed region.
    _ = lse_fn(data)

    timed_res = benchmark(lse_fn, data)

    expected = scipy.special.logsumexp(data, axis=axis, keepdims=True)
    np.testing.assert_array_almost_equal(timed_res, expected)
import contextlib import contextlib
import inspect import inspect
from typing import TYPE_CHECKING, Callable, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Tuple, Union
from unittest import mock from unittest import mock
import numba import numba
...@@ -190,7 +190,7 @@ def compare_numba_and_py( ...@@ -190,7 +190,7 @@ def compare_numba_and_py(
numba_mode=numba_mode, numba_mode=numba_mode,
py_mode=py_mode, py_mode=py_mode,
updates=None, updates=None,
): ) -> Tuple[Callable, Any]:
"""Function to compare python graph output and Numba compiled output for testing equality """Function to compare python graph output and Numba compiled output for testing equality
In the tests below computational graphs are defined in PyTensor. These graphs are then passed to In the tests below computational graphs are defined in PyTensor. These graphs are then passed to
...@@ -209,6 +209,10 @@ def compare_numba_and_py( ...@@ -209,6 +209,10 @@ def compare_numba_and_py(
updates updates
Updates to be passed to `pytensor.function`. Updates to be passed to `pytensor.function`.
Returns
-------
The compiled PyTensor function and its last computed result.
""" """
if assert_fn is None: if assert_fn is None:
...@@ -248,7 +252,7 @@ def compare_numba_and_py( ...@@ -248,7 +252,7 @@ def compare_numba_and_py(
else: else:
assert_fn(numba_res, py_res) assert_fn(numba_res, py_res)
return numba_res return pytensor_numba_fn, numba_res
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -159,7 +159,7 @@ def test_xit_xot_types( ...@@ -159,7 +159,7 @@ def test_xit_xot_types(
assert np.allclose(res_val, output_vals) assert np.allclose(res_val, output_vals)
def test_scan_multiple_output(): def test_scan_multiple_output(benchmark):
"""Test a scan implementation of a SEIR model. """Test a scan implementation of a SEIR model.
SEIR model definition: SEIR model definition:
...@@ -244,7 +244,9 @@ def test_scan_multiple_output(): ...@@ -244,7 +244,9 @@ def test_scan_multiple_output():
gamma_val, gamma_val,
delta_val, delta_val,
] ]
compare_numba_and_py(out_fg, test_input_vals) scan_fn, _ = compare_numba_and_py(out_fg, test_input_vals)
benchmark(scan_fn, *test_input_vals)
@config.change_flags(compute_test_value="raise") @config.change_flags(compute_test_value="raise")
......
...@@ -32,7 +32,7 @@ def test_Alloc(v, shape): ...@@ -32,7 +32,7 @@ def test_Alloc(v, shape):
g = at.alloc(v, *shape) g = at.alloc(v, *shape)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
(numba_res,) = compare_numba_and_py( _, (numba_res,) = compare_numba_and_py(
g_fg, g_fg,
[ [
i.tag.test_value i.tag.test_value
......
...@@ -13,7 +13,6 @@ import os ...@@ -13,7 +13,6 @@ import os
import pickle import pickle
import shutil import shutil
import sys import sys
import timeit
from collections import OrderedDict from collections import OrderedDict
from tempfile import mkdtemp from tempfile import mkdtemp
...@@ -2179,15 +2178,13 @@ def test_cvm_exception_handling(mode): ...@@ -2179,15 +2178,13 @@ def test_cvm_exception_handling(mode):
@pytest.mark.skipif( @pytest.mark.skipif(
not config.cxx, reason="G++ not available, so we need to skip this test." not config.cxx, reason="G++ not available, so we need to skip this test."
) )
def test_cython_performance(): def test_cython_performance(benchmark):
# This implicitly confirms that the Cython version is being used # This implicitly confirms that the Cython version is being used
from pytensor.scan import scan_perform_ext # noqa: F401 from pytensor.scan import scan_perform_ext # noqa: F401
# Python usually out-performs PyTensor below 100 iterations # Python usually out-performs PyTensor below 100 iterations
N = 200 N = 200
n_timeit = 50
M = -1 / np.arange(1, 11).astype(config.floatX) M = -1 / np.arange(1, 11).astype(config.floatX)
r = np.arange(N * 10).astype(config.floatX).reshape(N, 10) r = np.arange(N * 10).astype(config.floatX).reshape(N, 10)
...@@ -2216,17 +2213,11 @@ def test_cython_performance(): ...@@ -2216,17 +2213,11 @@ def test_cython_performance():
# Make sure we're actually computing a `Scan` # Make sure we're actually computing a `Scan`
assert any(isinstance(node.op, Scan) for node in f_cvm.maker.fgraph.apply_nodes) assert any(isinstance(node.op, Scan) for node in f_cvm.maker.fgraph.apply_nodes)
cvm_res = f_cvm() cvm_res = benchmark(f_cvm)
# Make sure the results are the same between the two implementations # Make sure the results are the same between the two implementations
assert np.allclose(cvm_res, py_res) assert np.allclose(cvm_res, py_res)
python_duration = timeit.timeit(lambda: f_py(), number=n_timeit)
cvm_duration = timeit.timeit(lambda: f_cvm(), number=n_timeit)
print(f"python={python_duration}, cvm={cvm_duration}")
assert cvm_duration <= python_duration
@config.change_flags(mode="FAST_COMPILE", compute_test_value="raise") @config.change_flags(mode="FAST_COMPILE", compute_test_value="raise")
def test_compute_test_values(): def test_compute_test_values():
...@@ -2662,7 +2653,7 @@ class TestExamples: ...@@ -2662,7 +2653,7 @@ class TestExamples:
n_result = numpy_implementation(v_vsample) n_result = numpy_implementation(v_vsample)
utt.assert_allclose(t_result, n_result) utt.assert_allclose(t_result, n_result)
def test_reordering(self): def test_reordering(self, benchmark):
"""Test re-ordering of inputs. """Test re-ordering of inputs.
some rnn with multiple outputs and multiple inputs; other some rnn with multiple outputs and multiple inputs; other
...@@ -2722,14 +2713,14 @@ class TestExamples: ...@@ -2722,14 +2713,14 @@ class TestExamples:
v_x[i] = np.dot(v_u1[i], vW_in1) + v_u2[i] * vW_in2 + np.dot(v_x[i - 1], vW) v_x[i] = np.dot(v_u1[i], vW_in1) + v_u2[i] * vW_in2 + np.dot(v_x[i - 1], vW)
v_y[i] = np.dot(v_x[i - 1], vWout) + v_y[i - 1] v_y[i] = np.dot(v_x[i - 1], vWout) + v_y[i - 1]
(pytensor_dump1, pytensor_dump2, pytensor_x, pytensor_y) = f4( (pytensor_dump1, pytensor_dump2, pytensor_x, pytensor_y) = benchmark(
v_u1, v_u2, v_x0, v_y0, vW_in1 f4, v_u1, v_u2, v_x0, v_y0, vW_in1
) )
utt.assert_allclose(pytensor_x, v_x) utt.assert_allclose(pytensor_x, v_x)
utt.assert_allclose(pytensor_y, v_y) utt.assert_allclose(pytensor_y, v_y)
def test_scan_as_tensor_on_gradients(self): def test_scan_as_tensor_on_gradients(self, benchmark):
to_scan = dvector("to_scan") to_scan = dvector("to_scan")
seq = dmatrix("seq") seq = dmatrix("seq")
f1 = dscalar("f1") f1 = dscalar("f1")
...@@ -2743,7 +2734,12 @@ class TestExamples: ...@@ -2743,7 +2734,12 @@ class TestExamples:
function(inputs=[to_scan, seq, f1], outputs=scanned, allow_input_downcast=True) function(inputs=[to_scan, seq, f1], outputs=scanned, allow_input_downcast=True)
t_grad = grad(scanned.sum(), wrt=[to_scan, f1], consider_constant=[seq]) t_grad = grad(scanned.sum(), wrt=[to_scan, f1], consider_constant=[seq])
function(inputs=[to_scan, seq, f1], outputs=t_grad, allow_input_downcast=True) benchmark(
function,
inputs=[to_scan, seq, f1],
outputs=t_grad,
allow_input_downcast=True,
)
def caching_nsteps_by_scan_op(self): def caching_nsteps_by_scan_op(self):
W = matrix("weights") W = matrix("weights")
...@@ -3060,7 +3056,7 @@ class TestExamples: ...@@ -3060,7 +3056,7 @@ class TestExamples:
utt.assert_allclose(outputs, expected_outputs) utt.assert_allclose(outputs, expected_outputs)
@pytest.mark.slow @pytest.mark.slow
def test_hessian_bug_grad_grad_two_scans(self): def test_hessian_bug_grad_grad_two_scans(self, benchmark):
# Bug reported by Bitton Tenessi # Bug reported by Bitton Tenessi
# NOTE : The test to reproduce the bug reported by Bitton Tenessi # NOTE : The test to reproduce the bug reported by Bitton Tenessi
# was modified from its original version to be faster to run. # was modified from its original version to be faster to run.
...@@ -3094,7 +3090,7 @@ class TestExamples: ...@@ -3094,7 +3090,7 @@ class TestExamples:
H = hessian(cost, W) H = hessian(cost, W)
print(".", file=sys.stderr) print(".", file=sys.stderr)
f = function([W, n_steps], H) f = function([W, n_steps], H)
f(np.ones((8,), dtype="float32"), 1) benchmark(f, np.ones((8,), dtype="float32"), 1)
def test_grad_connectivity_matrix(self): def test_grad_connectivity_matrix(self):
def inner_fn(x_tm1, y_tm1, z_tm1): def inner_fn(x_tm1, y_tm1, z_tm1):
...@@ -3710,7 +3706,7 @@ class TestExamples: ...@@ -3710,7 +3706,7 @@ class TestExamples:
utt.assert_allclose(pytensor_x, v_x) utt.assert_allclose(pytensor_x, v_x)
utt.assert_allclose(pytensor_y, v_y) utt.assert_allclose(pytensor_y, v_y)
def test_multiple_outs_taps(self): def test_multiple_outs_taps(self, benchmark):
l = 5 l = 5
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
...@@ -3805,6 +3801,8 @@ class TestExamples: ...@@ -3805,6 +3801,8 @@ class TestExamples:
np.testing.assert_almost_equal(res[1], ny1) np.testing.assert_almost_equal(res[1], ny1)
np.testing.assert_almost_equal(res[2], ny2) np.testing.assert_almost_equal(res[2], ny2)
benchmark(f, v_u1, v_u2, v_x0, v_y0, vW_in1)
def _grad_mout_helper(self, n_iters, mode): def _grad_mout_helper(self, n_iters, mode):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
n_hid = 3 n_hid = 3
......
...@@ -620,7 +620,7 @@ class TestPushOutAddScan: ...@@ -620,7 +620,7 @@ class TestPushOutAddScan:
vB = rng.uniform(size=(5, 5)).astype(config.floatX) vB = rng.uniform(size=(5, 5)).astype(config.floatX)
utt.assert_allclose(f(vA, vB), np.dot(vA.T, vB)) utt.assert_allclose(f(vA, vB), np.dot(vA.T, vB))
def test_pregreedy_optimizer(self): def test_pregreedy_optimizer(self, benchmark):
W = at.zeros((5, 4)) W = at.zeros((5, 4))
bv = at.zeros((5,)) bv = at.zeros((5,))
bh = at.zeros((4,)) bh = at.zeros((4,))
...@@ -634,7 +634,9 @@ class TestPushOutAddScan: ...@@ -634,7 +634,9 @@ class TestPushOutAddScan:
n_steps=2, n_steps=2,
) )
# TODO FIXME: Make this a real test and assert something. # TODO FIXME: Make this a real test and assert something.
function([v], chain)(np.zeros((3, 5), dtype=config.floatX)) chain_fn = function([v], chain)
benchmark(chain_fn, np.zeros((3, 5), dtype=config.floatX))
def test_machine_translation(self): def test_machine_translation(self):
""" """
...@@ -1291,7 +1293,7 @@ class TestSaveMem: ...@@ -1291,7 +1293,7 @@ class TestSaveMem:
] ]
assert len(scan_nodes) == 1 assert len(scan_nodes) == 1
def test_savemem_opt(self): def test_savemem_opt(self, benchmark):
y0 = shared(np.ones((2, 10))) y0 = shared(np.ones((2, 10)))
[y1, y2], updates = scan( [y1, y2], updates = scan(
lambda y: [y, y], lambda y: [y, y],
...@@ -1299,7 +1301,8 @@ class TestSaveMem: ...@@ -1299,7 +1301,8 @@ class TestSaveMem:
n_steps=5, n_steps=5,
) )
# TODO FIXME: Make this a real test and assert something. # TODO FIXME: Make this a real test and assert something.
function([], y2.sum(), mode=self.mode)() fn = function([], y2.sum(), mode=self.mode)
benchmark(fn)
def test_savemem_opt_0_step(self): def test_savemem_opt_0_step(self):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论