提交 19358099 authored 作者: Ben Mares's avatar Ben Mares 提交者: Ricardo Vieira

Fix E721: do not compare types, for exact checks use is / is not

上级 4b6a4440
...@@ -687,7 +687,7 @@ def _lessbroken_deepcopy(a): ...@@ -687,7 +687,7 @@ def _lessbroken_deepcopy(a):
else: else:
rval = copy.deepcopy(a) rval = copy.deepcopy(a)
assert type(rval) == type(a), (type(rval), type(a)) assert type(rval) is type(a), (type(rval), type(a))
if isinstance(rval, np.ndarray): if isinstance(rval, np.ndarray):
assert rval.dtype == a.dtype assert rval.dtype == a.dtype
...@@ -1154,7 +1154,7 @@ class _FunctionGraphEvent: ...@@ -1154,7 +1154,7 @@ class _FunctionGraphEvent:
return str(self.__dict__) return str(self.__dict__)
def __eq__(self, other): def __eq__(self, other):
rval = type(self) == type(other) rval = type(self) is type(other)
if rval: if rval:
# nodes are not compared because this comparison is # nodes are not compared because this comparison is
# supposed to be true for corresponding events that happen # supposed to be true for corresponding events that happen
......
...@@ -246,7 +246,7 @@ class FromFunctionOp(Op): ...@@ -246,7 +246,7 @@ class FromFunctionOp(Op):
self.infer_shape = self._infer_shape self.infer_shape = self._infer_shape
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.__fn == other.__fn return type(self) is type(other) and self.__fn == other.__fn
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ hash(self.__fn) return hash(type(self)) ^ hash(self.__fn)
......
...@@ -748,7 +748,7 @@ class NominalVariable(AtomicVariable[_TypeType]): ...@@ -748,7 +748,7 @@ class NominalVariable(AtomicVariable[_TypeType]):
return True return True
return ( return (
type(self) == type(other) type(self) is type(other)
and self.id == other.id and self.id == other.id
and self.type == other.type and self.type == other.type
) )
......
...@@ -33,7 +33,7 @@ class NullType(Type): ...@@ -33,7 +33,7 @@ class NullType(Type):
raise ValueError("NullType has no values to compare") raise ValueError("NullType has no values to compare")
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) is type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
......
...@@ -57,8 +57,8 @@ class ConstrainedVar(Var): ...@@ -57,8 +57,8 @@ class ConstrainedVar(Var):
return obj return obj
def __eq__(self, other): def __eq__(self, other):
if type(self) == type(other): if type(self) is type(other):
return self.token == other.token and self.constraint == other.constraint return self.token is other.token and self.constraint == other.constraint
return NotImplemented return NotImplemented
def __hash__(self): def __hash__(self):
......
...@@ -229,7 +229,7 @@ class MetaType(ABCMeta): ...@@ -229,7 +229,7 @@ class MetaType(ABCMeta):
if "__eq__" not in dct: if "__eq__" not in dct:
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and tuple( return type(self) is type(other) and tuple(
getattr(self, a) for a in props getattr(self, a) for a in props
) == tuple(getattr(other, a) for a in props) ) == tuple(getattr(other, a) for a in props)
......
...@@ -78,7 +78,7 @@ class IfElse(_NoPythonOp): ...@@ -78,7 +78,7 @@ class IfElse(_NoPythonOp):
self.name = name self.name = name
def __eq__(self, other): def __eq__(self, other):
if type(self) != type(other): if type(self) is not type(other):
return False return False
if self.as_view != other.as_view: if self.as_view != other.as_view:
return False return False
......
...@@ -301,7 +301,7 @@ class Params(dict): ...@@ -301,7 +301,7 @@ class Params(dict):
def __eq__(self, other): def __eq__(self, other):
return ( return (
type(self) == type(other) type(self) is type(other)
and self.__params_type__ == other.__params_type__ and self.__params_type__ == other.__params_type__
and all( and all(
# NB: Params object should have been already filtered. # NB: Params object should have been already filtered.
...@@ -435,7 +435,7 @@ class ParamsType(CType): ...@@ -435,7 +435,7 @@ class ParamsType(CType):
def __eq__(self, other): def __eq__(self, other):
return ( return (
type(self) == type(other) type(self) is type(other)
and self.fields == other.fields and self.fields == other.fields
and self.types == other.types and self.types == other.types
) )
......
...@@ -515,7 +515,7 @@ class EnumType(CType, dict): ...@@ -515,7 +515,7 @@ class EnumType(CType, dict):
def __eq__(self, other): def __eq__(self, other):
return ( return (
type(self) == type(other) type(self) is type(other)
and self.ctype == other.ctype and self.ctype == other.ctype
and len(self) == len(other) and len(self) == len(other)
and len(self.aliases) == len(other.aliases) and len(self.aliases) == len(other.aliases)
......
...@@ -16,7 +16,7 @@ from pytensor.tensor.type import DenseTensorType ...@@ -16,7 +16,7 @@ from pytensor.tensor.type import DenseTensorType
class ExceptionType(Generic): class ExceptionType(Generic):
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) is type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
...@@ -51,7 +51,7 @@ class CheckAndRaise(COp): ...@@ -51,7 +51,7 @@ class CheckAndRaise(COp):
return f"CheckAndRaise{{{self.exc_type}({self.msg})}}" return f"CheckAndRaise{{{self.exc_type}({self.msg})}}"
def __eq__(self, other): def __eq__(self, other):
if type(self) != type(other): if type(self) is not type(other):
return False return False
if self.msg == other.msg and self.exc_type == other.exc_type: if self.msg == other.msg and self.exc_type == other.exc_type:
......
...@@ -1074,7 +1074,7 @@ class unary_out_lookup(MetaObject): ...@@ -1074,7 +1074,7 @@ class unary_out_lookup(MetaObject):
return [rval] return [rval]
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.tbl == other.tbl return type(self) is type(other) and self.tbl == other.tbl
def __hash__(self): def __hash__(self):
return hash(type(self)) # ignore hash of table return hash(type(self)) # ignore hash of table
...@@ -1160,7 +1160,7 @@ class ScalarOp(COp): ...@@ -1160,7 +1160,7 @@ class ScalarOp(COp):
return self.grad(inputs, output_gradients) return self.grad(inputs, output_gradients)
def __eq__(self, other): def __eq__(self, other):
test = type(self) == type(other) and getattr( test = type(self) is type(other) and getattr(
self, "output_types_preference", None self, "output_types_preference", None
) == getattr(other, "output_types_preference", None) ) == getattr(other, "output_types_preference", None)
return test return test
...@@ -4133,7 +4133,7 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph): ...@@ -4133,7 +4133,7 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
if self is other: if self is other:
return True return True
if ( if (
type(self) != type(other) type(self) is not type(other)
or self.nin != other.nin or self.nin != other.nin
or self.nout != other.nout or self.nout != other.nout
): ):
......
...@@ -626,7 +626,7 @@ class Chi2SF(BinaryScalarOp): ...@@ -626,7 +626,7 @@ class Chi2SF(BinaryScalarOp):
raise NotImplementedError("only floatingpoint is implemented") raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) is type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
...@@ -679,7 +679,7 @@ class GammaInc(BinaryScalarOp): ...@@ -679,7 +679,7 @@ class GammaInc(BinaryScalarOp):
raise NotImplementedError("only floatingpoint is implemented") raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) is type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
...@@ -732,7 +732,7 @@ class GammaIncC(BinaryScalarOp): ...@@ -732,7 +732,7 @@ class GammaIncC(BinaryScalarOp):
raise NotImplementedError("only floatingpoint is implemented") raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) is type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
...@@ -1045,7 +1045,7 @@ class GammaU(BinaryScalarOp): ...@@ -1045,7 +1045,7 @@ class GammaU(BinaryScalarOp):
raise NotImplementedError("only floatingpoint is implemented") raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) is type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
...@@ -1083,7 +1083,7 @@ class GammaL(BinaryScalarOp): ...@@ -1083,7 +1083,7 @@ class GammaL(BinaryScalarOp):
raise NotImplementedError("only floatingpoint is implemented") raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) is type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
......
...@@ -1246,7 +1246,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1246,7 +1246,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
return apply_node return apply_node
def __eq__(self, other): def __eq__(self, other):
if type(self) != type(other): if type(self) is not type(other):
return False return False
if self.info != other.info: if self.info != other.info:
......
...@@ -462,7 +462,7 @@ class SparseConstantSignature(tuple): ...@@ -462,7 +462,7 @@ class SparseConstantSignature(tuple):
return ( return (
a == x a == x
and (b.dtype == y.dtype) and (b.dtype == y.dtype)
and (type(b) == type(y)) and (type(b) is type(y))
and (b.shape == y.shape) and (b.shape == y.shape)
and (abs(b - y).sum() < 1e-6 * b.nnz) and (abs(b - y).sum() < 1e-6 * b.nnz)
) )
......
...@@ -107,7 +107,7 @@ class RandomGeneratorType(RandomType[np.random.Generator]): ...@@ -107,7 +107,7 @@ class RandomGeneratorType(RandomType[np.random.Generator]):
return _eq(sa, sb) return _eq(sa, sb)
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) is type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
......
...@@ -1742,7 +1742,7 @@ def local_reduce_broadcastable(fgraph, node): ...@@ -1742,7 +1742,7 @@ def local_reduce_broadcastable(fgraph, node):
ii += 1 ii += 1
new_reduced = reduced.dimshuffle(*pattern) new_reduced = reduced.dimshuffle(*pattern)
if new_axis: if new_axis:
if type(node.op) == CAReduce: if type(node.op) is CAReduce:
# This case handles `CAReduce` instances # This case handles `CAReduce` instances
# (e.g. generated by `scalar_elemwise`), and not the # (e.g. generated by `scalar_elemwise`), and not the
# scalar `Op`-specific subclasses # scalar `Op`-specific subclasses
......
...@@ -370,7 +370,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -370,7 +370,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
return values_eq_approx(a, b, allow_remove_inf, allow_remove_nan, rtol, atol) return values_eq_approx(a, b, allow_remove_inf, allow_remove_nan, rtol, atol)
def __eq__(self, other): def __eq__(self, other):
if type(self) != type(other): if type(self) is not type(other):
return NotImplemented return NotImplemented
return other.dtype == self.dtype and other.shape == self.shape return other.dtype == self.dtype and other.shape == self.shape
...@@ -624,7 +624,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -624,7 +624,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
class DenseTypeMeta(MetaType): class DenseTypeMeta(MetaType):
def __instancecheck__(self, o): def __instancecheck__(self, o):
if type(o) == TensorType or isinstance(o, DenseTypeMeta): if type(o) is TensorType or isinstance(o, DenseTypeMeta):
return True return True
return False return False
......
...@@ -64,7 +64,7 @@ class SliceType(Type[slice]): ...@@ -64,7 +64,7 @@ class SliceType(Type[slice]):
return "slice" return "slice"
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) is type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
......
...@@ -945,7 +945,7 @@ class TensorConstantSignature(tuple): ...@@ -945,7 +945,7 @@ class TensorConstantSignature(tuple):
""" """
def __eq__(self, other): def __eq__(self, other):
if type(self) != type(other): if type(self) is not type(other):
return False return False
try: try:
(t0, d0), (t1, d1) = self, other (t0, d0), (t1, d1) = self, other
...@@ -1105,7 +1105,7 @@ TensorType.constant_type = TensorConstant ...@@ -1105,7 +1105,7 @@ TensorType.constant_type = TensorConstant
class DenseVariableMeta(MetaType): class DenseVariableMeta(MetaType):
def __instancecheck__(self, o): def __instancecheck__(self, o):
if type(o) == TensorVariable or isinstance(o, DenseVariableMeta): if type(o) is TensorVariable or isinstance(o, DenseVariableMeta):
return True return True
return False return False
...@@ -1120,7 +1120,7 @@ class DenseTensorVariable(TensorType, metaclass=DenseVariableMeta): ...@@ -1120,7 +1120,7 @@ class DenseTensorVariable(TensorType, metaclass=DenseVariableMeta):
class DenseConstantMeta(MetaType): class DenseConstantMeta(MetaType):
def __instancecheck__(self, o): def __instancecheck__(self, o):
if type(o) == TensorConstant or isinstance(o, DenseConstantMeta): if type(o) is TensorConstant or isinstance(o, DenseConstantMeta):
return True return True
return False return False
......
...@@ -55,7 +55,7 @@ class TypedListType(CType): ...@@ -55,7 +55,7 @@ class TypedListType(CType):
Two lists are equal if they contain the same type. Two lists are equal if they contain the same type.
""" """
return type(self) == type(other) and self.ttype == other.ttype return type(self) is type(other) and self.ttype == other.ttype
def __hash__(self): def __hash__(self):
return hash((type(self), self.ttype)) return hash((type(self), self.ttype))
......
...@@ -42,7 +42,7 @@ class CustomOpNoPropsNoEq(Op): ...@@ -42,7 +42,7 @@ class CustomOpNoPropsNoEq(Op):
class CustomOpNoProps(CustomOpNoPropsNoEq): class CustomOpNoProps(CustomOpNoPropsNoEq):
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.a == other.a return type(self) is type(other) and self.a == other.a
def __hash__(self): def __hash__(self):
return hash((type(self), self.a)) return hash((type(self), self.a))
......
...@@ -31,8 +31,8 @@ class TestFunctionGraph: ...@@ -31,8 +31,8 @@ class TestFunctionGraph:
s = pickle.dumps(func) s = pickle.dumps(func)
new_func = pickle.loads(s) new_func = pickle.loads(s)
assert all(type(a) == type(b) for a, b in zip(func.inputs, new_func.inputs)) assert all(type(a) is type(b) for a, b in zip(func.inputs, new_func.inputs))
assert all(type(a) == type(b) for a, b in zip(func.outputs, new_func.outputs)) assert all(type(a) is type(b) for a, b in zip(func.outputs, new_func.outputs))
assert all( assert all(
type(a.op) is type(b.op) # noqa: E721 type(a.op) is type(b.op) # noqa: E721
for a, b in zip(func.apply_nodes, new_func.apply_nodes) for a, b in zip(func.apply_nodes, new_func.apply_nodes)
......
...@@ -25,7 +25,7 @@ class MyType(Type): ...@@ -25,7 +25,7 @@ class MyType(Type):
self.thingy = thingy self.thingy = thingy
def __eq__(self, other): def __eq__(self, other):
return type(other) == type(self) and other.thingy == self.thingy return type(other) is type(self) and other.thingy == self.thingy
def __str__(self): def __str__(self):
return str(self.thingy) return str(self.thingy)
......
...@@ -71,7 +71,7 @@ class TDouble(CType): ...@@ -71,7 +71,7 @@ class TDouble(CType):
return (1,) return (1,)
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) is type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
......
...@@ -348,7 +348,7 @@ class TestVerifyGradSparse: ...@@ -348,7 +348,7 @@ class TestVerifyGradSparse:
self.structured = structured self.structured = structured
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) and self.structured == other.structured return (type(self) is type(other)) and self.structured == other.structured
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ hash(self.structured) return hash(type(self)) ^ hash(self.structured)
......
...@@ -3163,7 +3163,7 @@ def test_stack(): ...@@ -3163,7 +3163,7 @@ def test_stack():
sx, sy = dscalar(), dscalar() sx, sy = dscalar(), dscalar()
rval = inplace_func([sx, sy], stack([sx, sy]))(-4.0, -2.0) rval = inplace_func([sx, sy], stack([sx, sy]))(-4.0, -2.0)
assert type(rval) == np.ndarray assert type(rval) is np.ndarray
assert [-4, -2] == list(rval) assert [-4, -2] == list(rval)
......
...@@ -819,7 +819,7 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -819,7 +819,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
assert np.allclose(val, good), (val, good) assert np.allclose(val, good), (val, good)
# Test reuse of output memory # Test reuse of output memory
if type(AdvancedSubtensor1) == AdvancedSubtensor1: if type(AdvancedSubtensor1) is AdvancedSubtensor1:
op = AdvancedSubtensor1() op = AdvancedSubtensor1()
# When idx is a TensorConstant. # When idx is a TensorConstant.
if hasattr(idx, "data"): if hasattr(idx, "data"):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论