提交 68d9b808 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Fix ruff lints

上级 e9ee7d0a
...@@ -307,7 +307,7 @@ def raise_with_op( ...@@ -307,7 +307,7 @@ def raise_with_op(
if exc_info is None: if exc_info is None:
exc_info = sys.exc_info() exc_info = sys.exc_info()
exc_type, exc_value, exc_trace = exc_info exc_type, exc_value, exc_trace = exc_info
if exc_type == KeyboardInterrupt: if exc_type is KeyboardInterrupt:
# print a simple traceback from KeyboardInterrupt # print a simple traceback from KeyboardInterrupt
raise exc_value.with_traceback(exc_trace) raise exc_value.with_traceback(exc_trace)
......
...@@ -35,7 +35,7 @@ class TestFunctionGraph: ...@@ -35,7 +35,7 @@ class TestFunctionGraph:
assert all(type(a) is type(b) for a, b in zip(func.inputs, new_func.inputs)) assert all(type(a) is type(b) for a, b in zip(func.inputs, new_func.inputs))
assert all(type(a) is type(b) for a, b in zip(func.outputs, new_func.outputs)) assert all(type(a) is type(b) for a, b in zip(func.outputs, new_func.outputs))
assert all( assert all(
type(a.op) is type(b.op) # noqa: E721 type(a.op) is type(b.op)
for a, b in zip(func.apply_nodes, new_func.apply_nodes) for a, b in zip(func.apply_nodes, new_func.apply_nodes)
) )
assert all(a.type == b.type for a, b in zip(func.variables, new_func.variables)) assert all(a.type == b.type for a, b in zip(func.variables, new_func.variables))
......
...@@ -5,7 +5,7 @@ from pytensor.compile import get_mode ...@@ -5,7 +5,7 @@ from pytensor.compile import get_mode
jax = pytest.importorskip("jax") jax = pytest.importorskip("jax")
import jax.errors from jax import errors
import pytensor import pytensor
import pytensor.tensor.basic as ptb import pytensor.tensor.basic as ptb
...@@ -200,7 +200,7 @@ class TestJaxSplit: ...@@ -200,7 +200,7 @@ class TestJaxSplit:
fn = pytensor.function([a, split_axis], a_splits, mode="JAX") fn = pytensor.function([a, split_axis], a_splits, mode="JAX")
# Same as above, an AttributeError surpasses the `TracerIntegerConversionError` # Same as above, an AttributeError surpasses the `TracerIntegerConversionError`
# Both errors are included for backwards compatibility # Both errors are included for backwards compatibility
with pytest.raises((AttributeError, jax.errors.TracerIntegerConversionError)): with pytest.raises((AttributeError, errors.TracerIntegerConversionError)):
fn(np.zeros((6, 6), dtype=pytensor.config.floatX), 0) fn(np.zeros((6, 6), dtype=pytensor.config.floatX), 0)
......
...@@ -228,7 +228,7 @@ def grab_scan_node(output): ...@@ -228,7 +228,7 @@ def grab_scan_node(output):
ri = grab_scan_node(i) ri = grab_scan_node(i)
if ri is not None: if ri is not None:
rval += ri rval += ri
if rval is []: if rval == []:
return None return None
else: else:
return rval return rval
......
...@@ -49,7 +49,7 @@ class TestLoadTensor: ...@@ -49,7 +49,7 @@ class TestLoadTensor:
path = Variable(Generic(), None) path = Variable(Generic(), None)
x = load(path, "int32", (None,), mmap_mode="c") x = load(path, "int32", (None,), mmap_mode="c")
fn = function([path], x) fn = function([path], x)
assert type(fn(self.filename)) == np.core.memmap assert isinstance(fn(self.filename), np.core.memmap)
def teardown_method(self): def teardown_method(self):
(pytensor.config.compiledir / "_test.npy").unlink() (pytensor.config.compiledir / "_test.npy").unlink()
...@@ -12,8 +12,7 @@ scipy = pytest.importorskip("scipy") ...@@ -12,8 +12,7 @@ scipy = pytest.importorskip("scipy")
from functools import partial from functools import partial
import scipy.special from scipy import special, stats
import scipy.stats
from pytensor import function, grad from pytensor import function, grad
from pytensor import tensor as pt from pytensor import tensor as pt
...@@ -45,43 +44,43 @@ mode_no_scipy = get_default_mode() ...@@ -45,43 +44,43 @@ mode_no_scipy = get_default_mode()
def scipy_special_gammau(k, x): def scipy_special_gammau(k, x):
return scipy.special.gammaincc(k, x) * scipy.special.gamma(k) return special.gammaincc(k, x) * special.gamma(k)
def scipy_special_gammal(k, x): def scipy_special_gammal(k, x):
return scipy.special.gammainc(k, x) * scipy.special.gamma(k) return special.gammainc(k, x) * special.gamma(k)
# Precomputing the result is brittle(it have been broken!) # Precomputing the result is brittle(it have been broken!)
# As if we do any modification to random number here, # As if we do any modification to random number here,
# The input random number will change and the output! # The input random number will change and the output!
expected_erf = scipy.special.erf expected_erf = special.erf
expected_erfc = scipy.special.erfc expected_erfc = special.erfc
expected_erfinv = scipy.special.erfinv expected_erfinv = special.erfinv
expected_erfcinv = scipy.special.erfcinv expected_erfcinv = special.erfcinv
expected_owenst = scipy.special.owens_t expected_owenst = special.owens_t
expected_gamma = scipy.special.gamma expected_gamma = special.gamma
expected_gammaln = scipy.special.gammaln expected_gammaln = special.gammaln
expected_psi = scipy.special.psi expected_psi = special.psi
expected_tri_gamma = partial(scipy.special.polygamma, 1) expected_tri_gamma = partial(special.polygamma, 1)
expected_chi2sf = scipy.stats.chi2.sf expected_chi2sf = stats.chi2.sf
expected_gammainc = scipy.special.gammainc expected_gammainc = special.gammainc
expected_gammaincc = scipy.special.gammaincc expected_gammaincc = special.gammaincc
expected_gammau = scipy_special_gammau expected_gammau = scipy_special_gammau
expected_gammal = scipy_special_gammal expected_gammal = scipy_special_gammal
expected_gammaincinv = scipy.special.gammaincinv expected_gammaincinv = special.gammaincinv
expected_gammainccinv = scipy.special.gammainccinv expected_gammainccinv = special.gammainccinv
expected_j0 = scipy.special.j0 expected_j0 = special.j0
expected_j1 = scipy.special.j1 expected_j1 = special.j1
expected_jv = scipy.special.jv expected_jv = special.jv
expected_i0 = scipy.special.i0 expected_i0 = special.i0
expected_i1 = scipy.special.i1 expected_i1 = special.i1
expected_iv = scipy.special.iv expected_iv = special.iv
expected_ive = scipy.special.ive expected_ive = special.ive
expected_erfcx = scipy.special.erfcx expected_erfcx = special.erfcx
expected_sigmoid = scipy.special.expit expected_sigmoid = special.expit
expected_hyp2f1 = scipy.special.hyp2f1 expected_hyp2f1 = special.hyp2f1
expected_betaincinv = scipy.special.betaincinv expected_betaincinv = special.betaincinv
TestErfBroadcast = makeBroadcastTester( TestErfBroadcast = makeBroadcastTester(
op=pt.erf, op=pt.erf,
...@@ -837,14 +836,14 @@ _good_broadcast_ternary_betainc = dict( ...@@ -837,14 +836,14 @@ _good_broadcast_ternary_betainc = dict(
TestBetaincBroadcast = makeBroadcastTester( TestBetaincBroadcast = makeBroadcastTester(
op=pt.betainc, op=pt.betainc,
expected=scipy.special.betainc, expected=special.betainc,
good=_good_broadcast_ternary_betainc, good=_good_broadcast_ternary_betainc,
grad=_good_broadcast_ternary_betainc, grad=_good_broadcast_ternary_betainc,
) )
TestBetaincInplaceBroadcast = makeBroadcastTester( TestBetaincInplaceBroadcast = makeBroadcastTester(
op=inplace.betainc_inplace, op=inplace.betainc_inplace,
expected=scipy.special.betainc, expected=special.betainc,
good=_good_broadcast_ternary_betainc, good=_good_broadcast_ternary_betainc,
grad=_good_broadcast_ternary_betainc, grad=_good_broadcast_ternary_betainc,
inplace=True, inplace=True,
...@@ -936,13 +935,13 @@ _good_broadcast_ternary_betaincinv = dict( ...@@ -936,13 +935,13 @@ _good_broadcast_ternary_betaincinv = dict(
TestBetaincinvBroadcast = makeBroadcastTester( TestBetaincinvBroadcast = makeBroadcastTester(
op=pt.betaincinv, op=pt.betaincinv,
expected=scipy.special.betaincinv, expected=special.betaincinv,
good=_good_broadcast_ternary_betaincinv, good=_good_broadcast_ternary_betaincinv,
) )
TestBetaincinvInplaceBroadcast = makeBroadcastTester( TestBetaincinvInplaceBroadcast = makeBroadcastTester(
op=inplace.betaincinv_inplace, op=inplace.betaincinv_inplace,
expected=scipy.special.betaincinv, expected=special.betaincinv,
good=_good_broadcast_ternary_betaincinv, good=_good_broadcast_ternary_betaincinv,
inplace=True, inplace=True,
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论