提交 d62f4b19 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Replace more "if not a or not b" with "if not (a and b)"

上级 f0e9354b
...@@ -15,8 +15,8 @@ def sparse_grad(var): ...@@ -15,8 +15,8 @@ def sparse_grad(var):
""" """
from pytensor.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1 from pytensor.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1
if var.owner is None or not isinstance( if not (
var.owner.op, AdvancedSubtensor | AdvancedSubtensor1 var.owner and isinstance(var.owner.op, AdvancedSubtensor | AdvancedSubtensor1)
): ):
raise TypeError( raise TypeError(
"Sparse gradient is only implemented for AdvancedSubtensor and AdvancedSubtensor1" "Sparse gradient is only implemented for AdvancedSubtensor and AdvancedSubtensor1"
......
...@@ -2134,8 +2134,8 @@ def dense_dot(a, b): ...@@ -2134,8 +2134,8 @@ def dense_dot(a, b):
""" """
a, b = as_tensor_variable(a), as_tensor_variable(b) a, b = as_tensor_variable(a), as_tensor_variable(b)
if not isinstance(a.type, DenseTensorType) or not isinstance( if not (
b.type, DenseTensorType isinstance(a.type, DenseTensorType) and isinstance(b.type, DenseTensorType)
): ):
raise TypeError("The dense dot product is only supported for dense types") raise TypeError("The dense dot product is only supported for dense types")
......
...@@ -658,13 +658,13 @@ def local_cast_cast(fgraph, node): ...@@ -658,13 +658,13 @@ def local_cast_cast(fgraph, node):
and the first cast cause an upcast. and the first cast cause an upcast.
""" """
if not isinstance(node.op, Elemwise) or not isinstance(node.op.scalar_op, ps.Cast): if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Cast)):
return return
x = node.inputs[0] x = node.inputs[0]
if ( if not (
x.owner is None x.owner
or not isinstance(x.owner.op, Elemwise) and isinstance(x.owner.op, Elemwise)
or not isinstance(x.owner.op.scalar_op, ps.Cast) and isinstance(x.owner.op.scalar_op, ps.Cast)
): ):
return return
...@@ -1053,8 +1053,9 @@ def local_merge_switch_same_cond(fgraph, node): ...@@ -1053,8 +1053,9 @@ def local_merge_switch_same_cond(fgraph, node):
Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y) Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
""" """
# node must be binary elemwise or add or mul # node must be binary elemwise or add or mul
if not isinstance(node.op, Elemwise) or not isinstance( if not (
node.op.scalar_op, ps.BinaryScalarOp | ps.Add | ps.Mul isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, ps.BinaryScalarOp | ps.Add | ps.Mul)
): ):
return return
# all inputs must be switch # all inputs must be switch
......
...@@ -473,10 +473,10 @@ def local_useless_dimshuffle_makevector(fgraph, node): ...@@ -473,10 +473,10 @@ def local_useless_dimshuffle_makevector(fgraph, node):
makevector_out = node.inputs[0] makevector_out = node.inputs[0]
if ( if not (
not makevector_out.owner makevector_out.owner
or not isinstance(makevector_out.owner.op, MakeVector) and isinstance(makevector_out.owner.op, MakeVector)
or not makevector_out.broadcastable == (True,) and makevector_out.broadcastable == (True,)
): ):
return return
...@@ -570,8 +570,8 @@ def local_add_mul_fusion(fgraph, node): ...@@ -570,8 +570,8 @@ def local_add_mul_fusion(fgraph, node):
This rewrite is almost useless after the AlgebraicCanonizer is used, This rewrite is almost useless after the AlgebraicCanonizer is used,
but it catches a few edge cases that are not canonicalized by it but it catches a few edge cases that are not canonicalized by it
""" """
if not isinstance(node.op, Elemwise) or not isinstance( if not (
node.op.scalar_op, ps.Add | ps.Mul isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Add | ps.Mul)
): ):
return False return False
...@@ -1094,8 +1094,8 @@ class FusionOptimizer(GraphRewriter): ...@@ -1094,8 +1094,8 @@ class FusionOptimizer(GraphRewriter):
@node_rewriter([Elemwise]) @node_rewriter([Elemwise])
def local_useless_composite_outputs(fgraph, node): def local_useless_composite_outputs(fgraph, node):
"""Remove inputs and outputs of Composite Ops that are not used anywhere.""" """Remove inputs and outputs of Composite Ops that are not used anywhere."""
if not isinstance(node.op, Elemwise) or not isinstance( if not (
node.op.scalar_op, ps.Composite isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.Composite)
): ):
return return
comp = node.op.scalar_op comp = node.op.scalar_op
...@@ -1135,7 +1135,7 @@ def local_careduce_fusion(fgraph, node): ...@@ -1135,7 +1135,7 @@ def local_careduce_fusion(fgraph, node):
elm_node = car_input.owner elm_node = car_input.owner
if elm_node is None or not isinstance(elm_node.op, Elemwise): if not (elm_node and isinstance(elm_node.op, Elemwise)):
return False return False
elm_scalar_op = elm_node.op.scalar_op elm_scalar_op = elm_node.op.scalar_op
......
...@@ -2343,12 +2343,14 @@ def local_log_sum_exp(fgraph, node): ...@@ -2343,12 +2343,14 @@ def local_log_sum_exp(fgraph, node):
else: else:
dimshuffle_op = None dimshuffle_op = None
if not sum_node or not isinstance(sum_node.op, Sum): if not (sum_node and isinstance(sum_node.op, Sum)):
return return
exp_node, axis = sum_node.inputs[0].owner, sum_node.op.axis exp_node, axis = sum_node.inputs[0].owner, sum_node.op.axis
if not exp_node or not ( if not (
isinstance(exp_node.op, Elemwise) and isinstance(exp_node.op.scalar_op, ps.Exp) exp_node
and isinstance(exp_node.op, Elemwise)
and isinstance(exp_node.op.scalar_op, ps.Exp)
): ):
return return
...@@ -2660,7 +2662,7 @@ def local_log_erfc(fgraph, node): ...@@ -2660,7 +2662,7 @@ def local_log_erfc(fgraph, node):
10.0541948,10.0541951,.0000001)] 10.0541948,10.0541951,.0000001)]
""" """
if not node.inputs[0].owner or node.inputs[0].owner.op != erfc: if not (node.inputs[0].owner and node.inputs[0].owner.op == erfc):
return False return False
if hasattr(node.tag, "local_log_erfc_applied"): if hasattr(node.tag, "local_log_erfc_applied"):
...@@ -2725,7 +2727,7 @@ def local_grad_log_erfc_neg(fgraph, node): ...@@ -2725,7 +2727,7 @@ def local_grad_log_erfc_neg(fgraph, node):
if node.inputs[0].owner.op != mul: if node.inputs[0].owner.op != mul:
mul_in = None mul_in = None
y = [] y = []
if not node.inputs[0].owner or node.inputs[0].owner.op != exp: if not (node.inputs[0].owner and node.inputs[0].owner.op == exp):
return False return False
exp_in = node.inputs[0] exp_in = node.inputs[0]
else: else:
...@@ -2749,7 +2751,9 @@ def local_grad_log_erfc_neg(fgraph, node): ...@@ -2749,7 +2751,9 @@ def local_grad_log_erfc_neg(fgraph, node):
if exp_in.owner.inputs[0].owner.op == neg: if exp_in.owner.inputs[0].owner.op == neg:
neg_in = exp_in.owner.inputs[0] neg_in = exp_in.owner.inputs[0]
if not neg_in.owner.inputs[0].owner or neg_in.owner.inputs[0].owner.op != sqr: if not (
neg_in.owner.inputs[0].owner and neg_in.owner.inputs[0].owner.op == sqr
):
return False return False
sqr_in = neg_in.owner.inputs[0] sqr_in = neg_in.owner.inputs[0]
x = sqr_in.owner.inputs[0] x = sqr_in.owner.inputs[0]
...@@ -2794,9 +2798,9 @@ def local_grad_log_erfc_neg(fgraph, node): ...@@ -2794,9 +2798,9 @@ def local_grad_log_erfc_neg(fgraph, node):
return False return False
if len(mul_neg.owner.inputs) == 2: if len(mul_neg.owner.inputs) == 2:
if ( if not (
not mul_neg.owner.inputs[1].owner mul_neg.owner.inputs[1].owner
or mul_neg.owner.inputs[1].owner.op != sqr and mul_neg.owner.inputs[1].owner.op == sqr
): ):
return False return False
sqr_in = mul_neg.owner.inputs[1] sqr_in = mul_neg.owner.inputs[1]
...@@ -2809,10 +2813,10 @@ def local_grad_log_erfc_neg(fgraph, node): ...@@ -2809,10 +2813,10 @@ def local_grad_log_erfc_neg(fgraph, node):
return False return False
if cst2 != -1: if cst2 != -1:
if ( if not (
not erfc_x.owner erfc_x.owner
or erfc_x.owner.op != mul and erfc_x.owner.op == mul
or len(erfc_x.owner.inputs) != 2 and len(erfc_x.owner.inputs) == 2
): ):
# todo implement that case # todo implement that case
return False return False
......
...@@ -360,7 +360,7 @@ class Singleton: ...@@ -360,7 +360,7 @@ class Singleton:
# don't want that, so we check the class. When we add one, we # don't want that, so we check the class. When we add one, we
# add one only to the current class, so all is working # add one only to the current class, so all is working
# correctly. # correctly.
if cls.__instance is None or not isinstance(cls.__instance, cls): if not (cls.__instance and isinstance(cls.__instance, cls)):
cls.__instance = super().__new__(cls) cls.__instance = super().__new__(cls)
return cls.__instance return cls.__instance
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论