提交 68d9b808 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Fix ruff lints

上级 e9ee7d0a
......@@ -307,7 +307,7 @@ def raise_with_op(
if exc_info is None:
exc_info = sys.exc_info()
exc_type, exc_value, exc_trace = exc_info
if exc_type == KeyboardInterrupt:
if exc_type is KeyboardInterrupt:
# print a simple traceback from KeyboardInterrupt
raise exc_value.with_traceback(exc_trace)
......
......@@ -35,7 +35,7 @@ class TestFunctionGraph:
assert all(type(a) is type(b) for a, b in zip(func.inputs, new_func.inputs))
assert all(type(a) is type(b) for a, b in zip(func.outputs, new_func.outputs))
assert all(
type(a.op) is type(b.op) # noqa: E721
type(a.op) is type(b.op)
for a, b in zip(func.apply_nodes, new_func.apply_nodes)
)
assert all(a.type == b.type for a, b in zip(func.variables, new_func.variables))
......
......@@ -5,7 +5,7 @@ from pytensor.compile import get_mode
jax = pytest.importorskip("jax")
import jax.errors
from jax import errors
import pytensor
import pytensor.tensor.basic as ptb
......@@ -200,7 +200,7 @@ class TestJaxSplit:
fn = pytensor.function([a, split_axis], a_splits, mode="JAX")
# Same as above: an AttributeError takes precedence over the `TracerIntegerConversionError`
# Both errors are included for backwards compatibility
with pytest.raises((AttributeError, jax.errors.TracerIntegerConversionError)):
with pytest.raises((AttributeError, errors.TracerIntegerConversionError)):
fn(np.zeros((6, 6), dtype=pytensor.config.floatX), 0)
......
......@@ -228,7 +228,7 @@ def grab_scan_node(output):
ri = grab_scan_node(i)
if ri is not None:
rval += ri
if rval is []:
if rval == []:
return None
else:
return rval
......
......@@ -49,7 +49,7 @@ class TestLoadTensor:
path = Variable(Generic(), None)
x = load(path, "int32", (None,), mmap_mode="c")
fn = function([path], x)
assert type(fn(self.filename)) == np.core.memmap
assert isinstance(fn(self.filename), np.core.memmap)
def teardown_method(self):
(pytensor.config.compiledir / "_test.npy").unlink()
......@@ -12,8 +12,7 @@ scipy = pytest.importorskip("scipy")
from functools import partial
import scipy.special
import scipy.stats
from scipy import special, stats
from pytensor import function, grad
from pytensor import tensor as pt
......@@ -45,43 +44,43 @@ mode_no_scipy = get_default_mode()
def scipy_special_gammau(k, x):
return scipy.special.gammaincc(k, x) * scipy.special.gamma(k)
return special.gammaincc(k, x) * special.gamma(k)
def scipy_special_gammal(k, x):
return scipy.special.gammainc(k, x) * scipy.special.gamma(k)
return special.gammainc(k, x) * special.gamma(k)
# Precomputing the result is brittle (it has been broken before!):
# if we make any modification to the random number generation here,
# the input random numbers will change, and so will the output.
expected_erf = scipy.special.erf
expected_erfc = scipy.special.erfc
expected_erfinv = scipy.special.erfinv
expected_erfcinv = scipy.special.erfcinv
expected_owenst = scipy.special.owens_t
expected_gamma = scipy.special.gamma
expected_gammaln = scipy.special.gammaln
expected_psi = scipy.special.psi
expected_tri_gamma = partial(scipy.special.polygamma, 1)
expected_chi2sf = scipy.stats.chi2.sf
expected_gammainc = scipy.special.gammainc
expected_gammaincc = scipy.special.gammaincc
expected_erf = special.erf
expected_erfc = special.erfc
expected_erfinv = special.erfinv
expected_erfcinv = special.erfcinv
expected_owenst = special.owens_t
expected_gamma = special.gamma
expected_gammaln = special.gammaln
expected_psi = special.psi
expected_tri_gamma = partial(special.polygamma, 1)
expected_chi2sf = stats.chi2.sf
expected_gammainc = special.gammainc
expected_gammaincc = special.gammaincc
expected_gammau = scipy_special_gammau
expected_gammal = scipy_special_gammal
expected_gammaincinv = scipy.special.gammaincinv
expected_gammainccinv = scipy.special.gammainccinv
expected_j0 = scipy.special.j0
expected_j1 = scipy.special.j1
expected_jv = scipy.special.jv
expected_i0 = scipy.special.i0
expected_i1 = scipy.special.i1
expected_iv = scipy.special.iv
expected_ive = scipy.special.ive
expected_erfcx = scipy.special.erfcx
expected_sigmoid = scipy.special.expit
expected_hyp2f1 = scipy.special.hyp2f1
expected_betaincinv = scipy.special.betaincinv
expected_gammaincinv = special.gammaincinv
expected_gammainccinv = special.gammainccinv
expected_j0 = special.j0
expected_j1 = special.j1
expected_jv = special.jv
expected_i0 = special.i0
expected_i1 = special.i1
expected_iv = special.iv
expected_ive = special.ive
expected_erfcx = special.erfcx
expected_sigmoid = special.expit
expected_hyp2f1 = special.hyp2f1
expected_betaincinv = special.betaincinv
TestErfBroadcast = makeBroadcastTester(
op=pt.erf,
......@@ -837,14 +836,14 @@ _good_broadcast_ternary_betainc = dict(
TestBetaincBroadcast = makeBroadcastTester(
op=pt.betainc,
expected=scipy.special.betainc,
expected=special.betainc,
good=_good_broadcast_ternary_betainc,
grad=_good_broadcast_ternary_betainc,
)
TestBetaincInplaceBroadcast = makeBroadcastTester(
op=inplace.betainc_inplace,
expected=scipy.special.betainc,
expected=special.betainc,
good=_good_broadcast_ternary_betainc,
grad=_good_broadcast_ternary_betainc,
inplace=True,
......@@ -936,13 +935,13 @@ _good_broadcast_ternary_betaincinv = dict(
TestBetaincinvBroadcast = makeBroadcastTester(
op=pt.betaincinv,
expected=scipy.special.betaincinv,
expected=special.betaincinv,
good=_good_broadcast_ternary_betaincinv,
)
TestBetaincinvInplaceBroadcast = makeBroadcastTester(
op=inplace.betaincinv_inplace,
expected=scipy.special.betaincinv,
expected=special.betaincinv,
good=_good_broadcast_ternary_betaincinv,
inplace=True,
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论