Commit d62f4b19 authored by Virgile Andreani, committed by Ricardo Vieira

Replace more "if not a or not b" with "if not (a and b)"

Parent f0e9354b
......@@ -15,8 +15,8 @@ def sparse_grad(var):
"""
from pytensor.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1
if var.owner is None or not isinstance(
var.owner.op, AdvancedSubtensor | AdvancedSubtensor1
if not (
var.owner and isinstance(var.owner.op, AdvancedSubtensor | AdvancedSubtensor1)
):
raise TypeError(
"Sparse gradient is only implemented for AdvancedSubtensor and AdvancedSubtensor1"
......
......@@ -2134,8 +2134,8 @@ def dense_dot(a, b):
"""
a, b = as_tensor_variable(a), as_tensor_variable(b)
if not isinstance(a.type, DenseTensorType) or not isinstance(
b.type, DenseTensorType
if not (
isinstance(a.type, DenseTensorType) and isinstance(b.type, DenseTensorType)
):
raise TypeError("The dense dot product is only supported for dense types")
......
......@@ -658,13 +658,13 @@ def local_cast_cast(fgraph, node):
and the first cast cause an upcast.
"""
if not isinstance(node.op, Elemwise) or not isinstance(node.op.scalar_op, ps.Cast):
if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Cast)):
return
x = node.inputs[0]
if (
x.owner is None
or not isinstance(x.owner.op, Elemwise)
or not isinstance(x.owner.op.scalar_op, ps.Cast)
if not (
x.owner
and isinstance(x.owner.op, Elemwise)
and isinstance(x.owner.op.scalar_op, ps.Cast)
):
return
......@@ -1053,8 +1053,9 @@ def local_merge_switch_same_cond(fgraph, node):
Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
"""
# node must be binary elemwise or add or mul
if not isinstance(node.op, Elemwise) or not isinstance(
node.op.scalar_op, ps.BinaryScalarOp | ps.Add | ps.Mul
if not (
isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, ps.BinaryScalarOp | ps.Add | ps.Mul)
):
return
# all inputs must be switch
......
......@@ -473,10 +473,10 @@ def local_useless_dimshuffle_makevector(fgraph, node):
makevector_out = node.inputs[0]
if (
not makevector_out.owner
or not isinstance(makevector_out.owner.op, MakeVector)
or not makevector_out.broadcastable == (True,)
if not (
makevector_out.owner
and isinstance(makevector_out.owner.op, MakeVector)
and makevector_out.broadcastable == (True,)
):
return
......@@ -570,8 +570,8 @@ def local_add_mul_fusion(fgraph, node):
This rewrite is almost useless after the AlgebraicCanonizer is used,
but it catches a few edge cases that are not canonicalized by it
"""
if not isinstance(node.op, Elemwise) or not isinstance(
node.op.scalar_op, ps.Add | ps.Mul
if not (
isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Add | ps.Mul)
):
return False
......@@ -1094,8 +1094,8 @@ class FusionOptimizer(GraphRewriter):
@node_rewriter([Elemwise])
def local_useless_composite_outputs(fgraph, node):
"""Remove inputs and outputs of Composite Ops that are not used anywhere."""
if not isinstance(node.op, Elemwise) or not isinstance(
node.op.scalar_op, ps.Composite
if not (
isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Composite)
):
return
comp = node.op.scalar_op
......@@ -1135,7 +1135,7 @@ def local_careduce_fusion(fgraph, node):
elm_node = car_input.owner
if elm_node is None or not isinstance(elm_node.op, Elemwise):
if not (elm_node and isinstance(elm_node.op, Elemwise)):
return False
elm_scalar_op = elm_node.op.scalar_op
......
......@@ -2343,12 +2343,14 @@ def local_log_sum_exp(fgraph, node):
else:
dimshuffle_op = None
if not sum_node or not isinstance(sum_node.op, Sum):
if not (sum_node and isinstance(sum_node.op, Sum)):
return
exp_node, axis = sum_node.inputs[0].owner, sum_node.op.axis
if not exp_node or not (
isinstance(exp_node.op, Elemwise) and isinstance(exp_node.op.scalar_op, ps.Exp)
if not (
exp_node
and isinstance(exp_node.op, Elemwise)
and isinstance(exp_node.op.scalar_op, ps.Exp)
):
return
......@@ -2660,7 +2662,7 @@ def local_log_erfc(fgraph, node):
10.0541948,10.0541951,.0000001)]
"""
if not node.inputs[0].owner or node.inputs[0].owner.op != erfc:
if not (node.inputs[0].owner and node.inputs[0].owner.op == erfc):
return False
if hasattr(node.tag, "local_log_erfc_applied"):
......@@ -2725,7 +2727,7 @@ def local_grad_log_erfc_neg(fgraph, node):
if node.inputs[0].owner.op != mul:
mul_in = None
y = []
if not node.inputs[0].owner or node.inputs[0].owner.op != exp:
if not (node.inputs[0].owner and node.inputs[0].owner.op == exp):
return False
exp_in = node.inputs[0]
else:
......@@ -2749,7 +2751,9 @@ def local_grad_log_erfc_neg(fgraph, node):
if exp_in.owner.inputs[0].owner.op == neg:
neg_in = exp_in.owner.inputs[0]
if not neg_in.owner.inputs[0].owner or neg_in.owner.inputs[0].owner.op != sqr:
if not (
neg_in.owner.inputs[0].owner and neg_in.owner.inputs[0].owner.op == sqr
):
return False
sqr_in = neg_in.owner.inputs[0]
x = sqr_in.owner.inputs[0]
......@@ -2794,9 +2798,9 @@ def local_grad_log_erfc_neg(fgraph, node):
return False
if len(mul_neg.owner.inputs) == 2:
if (
not mul_neg.owner.inputs[1].owner
or mul_neg.owner.inputs[1].owner.op != sqr
if not (
mul_neg.owner.inputs[1].owner
and mul_neg.owner.inputs[1].owner.op == sqr
):
return False
sqr_in = mul_neg.owner.inputs[1]
......@@ -2809,10 +2813,10 @@ def local_grad_log_erfc_neg(fgraph, node):
return False
if cst2 != -1:
if (
not erfc_x.owner
or erfc_x.owner.op != mul
or len(erfc_x.owner.inputs) != 2
if not (
erfc_x.owner
and erfc_x.owner.op == mul
and len(erfc_x.owner.inputs) == 2
):
# todo implement that case
return False
......
......@@ -360,7 +360,7 @@ class Singleton:
# don't want that, so we check the class. When we add one, we
# add one only to the current class, so all is working
# correctly.
if cls.__instance is None or not isinstance(cls.__instance, cls):
if not (cls.__instance and isinstance(cls.__instance, cls)):
cls.__instance = super().__new__(cls)
return cls.__instance
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to comment