Commit f9f930c3 — authored by Virgile Andreani, committed by Ricardo Vieira

Simplify some type checking

Parent d62f4b19
......@@ -515,9 +515,7 @@ def construct_pfunc_ins_and_outs(
if not isinstance(params, list | tuple):
raise TypeError("The `params` argument must be a list or a tuple")
if not isinstance(no_default_updates, bool) and not isinstance(
no_default_updates, list
):
if not isinstance(no_default_updates, bool | list):
raise TypeError("The `no_default_update` argument must be a boolean or list")
if len(updates) > 0 and not all(
......
......@@ -207,7 +207,7 @@ class In(SymbolicInput):
if implicit is None:
from pytensor.compile.sharedvalue import SharedVariable
implicit = isinstance(value, Container) or isinstance(value, SharedVariable)
implicit = isinstance(value, Container | SharedVariable)
super().__init__(
variable=variable,
name=name,
......
......@@ -1788,7 +1788,7 @@ def verify_grad(
o_fn = fn_maker(tensor_pt, o_output, name="gradient.py fwd")
o_fn_out = o_fn(*[p.copy() for p in pt])
if isinstance(o_fn_out, tuple) or isinstance(o_fn_out, list):
if isinstance(o_fn_out, tuple | list):
raise TypeError(
"It seems like you are trying to use verify_grad "
"on an Op or a function which outputs a list: there should"
......
......@@ -68,14 +68,14 @@ def get_updates_and_outputs(ls):
"""
# Is `x` a container we can iterate on?
iter_on = None
if isinstance(x, list) or isinstance(x, tuple):
if isinstance(x, list | tuple):
iter_on = x
elif isinstance(x, dict):
iter_on = x.items()
if iter_on is not None:
return all(_filter(y) for y in iter_on)
else:
return isinstance(x, Variable) or isinstance(x, until)
return isinstance(x, Variable | until)
if not _filter(ls):
raise ValueError(
......@@ -840,11 +840,7 @@ def scan(
# add only the non-shared variables and non-constants to the arguments of
# the dummy function [ a function should not get shared variables or
# constants as input ]
dummy_args = [
arg
for arg in args
if (not isinstance(arg, SharedVariable) and not isinstance(arg, Constant))
]
dummy_args = [arg for arg in args if not isinstance(arg, SharedVariable | Constant)]
# when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated
......@@ -1043,16 +1039,14 @@ def scan(
other_inner_args = []
other_scan_args += [
arg
for arg in non_seqs
if (not isinstance(arg, SharedVariable) and not isinstance(arg, Constant))
arg for arg in non_seqs if not isinstance(arg, SharedVariable | Constant)
]
# Step 5.6 all shared variables with no update rules
other_inner_args += [
safe_new(arg, "_copy")
for arg in non_seqs
if (not isinstance(arg, SharedVariable) and not isinstance(arg, Constant))
if not isinstance(arg, SharedVariable | Constant)
]
inner_replacements.update(dict(zip(other_scan_args, other_inner_args)))
......
......@@ -1956,9 +1956,7 @@ def extract_constant(x, elemwise=True, only_process_constants=False):
x = get_underlying_scalar_constant_value(x, elemwise, only_process_constants)
except NotScalarConstantError:
pass
if isinstance(x, ps.ScalarVariable) or isinstance(
x, ps.sharedvar.ScalarSharedVariable
):
if isinstance(x, ps.ScalarVariable | ps.sharedvar.ScalarSharedVariable):
if x.owner and isinstance(x.owner.op, ScalarFromTensor):
x = x.owner.inputs[0]
else:
......
......@@ -2204,9 +2204,7 @@ class TestLocalErf:
assert len(topo) == 2
assert topo[0].op == erf
assert isinstance(topo[1].op, Elemwise)
assert isinstance(topo[1].op.scalar_op, ps.Add) or isinstance(
topo[1].op.scalar_op, ps.Sub
)
assert isinstance(topo[1].op.scalar_op, ps.Add | ps.Sub)
def test_local_erf_minus_one(self):
val = np.asarray([-30, -3, -2, -1, 0, 1, 2, 3, 30], dtype=config.floatX)
......@@ -2227,9 +2225,7 @@ class TestLocalErf:
assert len(topo) == 2
assert topo[0].op == erf
assert isinstance(topo[1].op, Elemwise)
assert isinstance(topo[1].op.scalar_op, ps.Add) or isinstance(
topo[1].op.scalar_op, ps.Sub
)
assert isinstance(topo[1].op.scalar_op, ps.Add | ps.Sub)
@pytest.mark.skipif(
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment