提交 f9f930c3 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Simplify some type checking

上级 d62f4b19
...@@ -515,9 +515,7 @@ def construct_pfunc_ins_and_outs( ...@@ -515,9 +515,7 @@ def construct_pfunc_ins_and_outs(
if not isinstance(params, list | tuple): if not isinstance(params, list | tuple):
raise TypeError("The `params` argument must be a list or a tuple") raise TypeError("The `params` argument must be a list or a tuple")
if not isinstance(no_default_updates, bool) and not isinstance( if not isinstance(no_default_updates, bool | list):
no_default_updates, list
):
raise TypeError("The `no_default_update` argument must be a boolean or list") raise TypeError("The `no_default_update` argument must be a boolean or list")
if len(updates) > 0 and not all( if len(updates) > 0 and not all(
......
...@@ -207,7 +207,7 @@ class In(SymbolicInput): ...@@ -207,7 +207,7 @@ class In(SymbolicInput):
if implicit is None: if implicit is None:
from pytensor.compile.sharedvalue import SharedVariable from pytensor.compile.sharedvalue import SharedVariable
implicit = isinstance(value, Container) or isinstance(value, SharedVariable) implicit = isinstance(value, Container | SharedVariable)
super().__init__( super().__init__(
variable=variable, variable=variable,
name=name, name=name,
......
...@@ -1788,7 +1788,7 @@ def verify_grad( ...@@ -1788,7 +1788,7 @@ def verify_grad(
o_fn = fn_maker(tensor_pt, o_output, name="gradient.py fwd") o_fn = fn_maker(tensor_pt, o_output, name="gradient.py fwd")
o_fn_out = o_fn(*[p.copy() for p in pt]) o_fn_out = o_fn(*[p.copy() for p in pt])
if isinstance(o_fn_out, tuple) or isinstance(o_fn_out, list): if isinstance(o_fn_out, tuple | list):
raise TypeError( raise TypeError(
"It seems like you are trying to use verify_grad " "It seems like you are trying to use verify_grad "
"on an Op or a function which outputs a list: there should" "on an Op or a function which outputs a list: there should"
......
...@@ -68,14 +68,14 @@ def get_updates_and_outputs(ls): ...@@ -68,14 +68,14 @@ def get_updates_and_outputs(ls):
""" """
# Is `x` a container we can iterate on? # Is `x` a container we can iterate on?
iter_on = None iter_on = None
if isinstance(x, list) or isinstance(x, tuple): if isinstance(x, list | tuple):
iter_on = x iter_on = x
elif isinstance(x, dict): elif isinstance(x, dict):
iter_on = x.items() iter_on = x.items()
if iter_on is not None: if iter_on is not None:
return all(_filter(y) for y in iter_on) return all(_filter(y) for y in iter_on)
else: else:
return isinstance(x, Variable) or isinstance(x, until) return isinstance(x, Variable | until)
if not _filter(ls): if not _filter(ls):
raise ValueError( raise ValueError(
...@@ -840,11 +840,7 @@ def scan( ...@@ -840,11 +840,7 @@ def scan(
# add only the non-shared variables and non-constants to the arguments of # add only the non-shared variables and non-constants to the arguments of
# the dummy function [ a function should not get shared variables or # the dummy function [ a function should not get shared variables or
# constants as input ] # constants as input ]
dummy_args = [ dummy_args = [arg for arg in args if not isinstance(arg, SharedVariable | Constant)]
arg
for arg in args
if (not isinstance(arg, SharedVariable) and not isinstance(arg, Constant))
]
# when we apply the lambda expression we get a mixture of update rules # when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated # and outputs that needs to be separated
...@@ -1043,16 +1039,14 @@ def scan( ...@@ -1043,16 +1039,14 @@ def scan(
other_inner_args = [] other_inner_args = []
other_scan_args += [ other_scan_args += [
arg arg for arg in non_seqs if not isinstance(arg, SharedVariable | Constant)
for arg in non_seqs
if (not isinstance(arg, SharedVariable) and not isinstance(arg, Constant))
] ]
# Step 5.6 all shared variables with no update rules # Step 5.6 all shared variables with no update rules
other_inner_args += [ other_inner_args += [
safe_new(arg, "_copy") safe_new(arg, "_copy")
for arg in non_seqs for arg in non_seqs
if (not isinstance(arg, SharedVariable) and not isinstance(arg, Constant)) if not isinstance(arg, SharedVariable | Constant)
] ]
inner_replacements.update(dict(zip(other_scan_args, other_inner_args))) inner_replacements.update(dict(zip(other_scan_args, other_inner_args)))
......
...@@ -1956,9 +1956,7 @@ def extract_constant(x, elemwise=True, only_process_constants=False): ...@@ -1956,9 +1956,7 @@ def extract_constant(x, elemwise=True, only_process_constants=False):
x = get_underlying_scalar_constant_value(x, elemwise, only_process_constants) x = get_underlying_scalar_constant_value(x, elemwise, only_process_constants)
except NotScalarConstantError: except NotScalarConstantError:
pass pass
if isinstance(x, ps.ScalarVariable) or isinstance( if isinstance(x, ps.ScalarVariable | ps.sharedvar.ScalarSharedVariable):
x, ps.sharedvar.ScalarSharedVariable
):
if x.owner and isinstance(x.owner.op, ScalarFromTensor): if x.owner and isinstance(x.owner.op, ScalarFromTensor):
x = x.owner.inputs[0] x = x.owner.inputs[0]
else: else:
......
...@@ -2204,9 +2204,7 @@ class TestLocalErf: ...@@ -2204,9 +2204,7 @@ class TestLocalErf:
assert len(topo) == 2 assert len(topo) == 2
assert topo[0].op == erf assert topo[0].op == erf
assert isinstance(topo[1].op, Elemwise) assert isinstance(topo[1].op, Elemwise)
assert isinstance(topo[1].op.scalar_op, ps.Add) or isinstance( assert isinstance(topo[1].op.scalar_op, ps.Add | ps.Sub)
topo[1].op.scalar_op, ps.Sub
)
def test_local_erf_minus_one(self): def test_local_erf_minus_one(self):
val = np.asarray([-30, -3, -2, -1, 0, 1, 2, 3, 30], dtype=config.floatX) val = np.asarray([-30, -3, -2, -1, 0, 1, 2, 3, 30], dtype=config.floatX)
...@@ -2227,9 +2225,7 @@ class TestLocalErf: ...@@ -2227,9 +2225,7 @@ class TestLocalErf:
assert len(topo) == 2 assert len(topo) == 2
assert topo[0].op == erf assert topo[0].op == erf
assert isinstance(topo[1].op, Elemwise) assert isinstance(topo[1].op, Elemwise)
assert isinstance(topo[1].op.scalar_op, ps.Add) or isinstance( assert isinstance(topo[1].op.scalar_op, ps.Add | ps.Sub)
topo[1].op.scalar_op, ps.Sub
)
@pytest.mark.skipif( @pytest.mark.skipif(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论