提交 fbc28965 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix overly strict type checks across the entire codebase

上级 44066869
......@@ -680,15 +680,16 @@ def _lessbroken_deepcopy(a):
# This logic is also in link.py
from aesara.graph.type import _cdata_type
if type(a) in (np.ndarray, np.memmap):
if isinstance(a, (np.ndarray, np.memmap)):
rval = a.copy(order="K")
elif type(a) is _cdata_type:
elif isinstance(a, _cdata_type):
# This is not copyable (and should be used for constant data).
rval = a
else:
rval = copy.deepcopy(a)
assert type(rval) == type(a), (type(rval), type(a))
if isinstance(rval, np.ndarray):
assert rval.dtype == a.dtype
return rval
......@@ -2006,7 +2007,7 @@ class _Linker(LocalLinker):
# HACK TO LOOK LIKE A REAL DESTRUCTIVE ACTION
# TOOK PLACE
if (
(type(dr_vals[r][0]) in (np.ndarray, np.memmap))
isinstance(dr_vals[r][0], (np.ndarray, np.memmap))
and (dr_vals[r][0].dtype == storage_map[r][0].dtype)
and (dr_vals[r][0].shape == storage_map[r][0].shape)
):
......
......@@ -399,7 +399,7 @@ def pfunc(
if profile is True:
profile = ProfileStats(message=name)
# profile -> object
elif type(profile) == str:
elif isinstance(profile, str):
profile = ProfileStats(message=profile)
# profile is typically either False or an object at this point.
# No need to block other objects being passed through though. It might be
......
......@@ -731,7 +731,7 @@ class Function:
message = str(profile.message) + " copy"
profile = aesara.compile.profiling.ProfileStats(message=message)
# profile -> object
elif type(profile) == str:
elif isinstance(profile, str):
profile = aesara.compile.profiling.ProfileStats(message=profile)
f_cpy = maker.__class__(
......
......@@ -308,7 +308,7 @@ class Mode:
):
if linker is None:
linker = config.linker
if type(optimizer) == str and optimizer == "default":
if isinstance(optimizer, str) and optimizer == "default":
optimizer = config.optimizer
self.__setstate__((linker, optimizer))
......
......@@ -2263,7 +2263,7 @@ class ConvMetaOptimizer(LocalMetaOptimizer):
):
return result
if type(node.op) in [AbstractConv2d, AbstractConv3d]:
if isinstance(node.op, (AbstractConv2d, AbstractConv3d)):
img, kern = node.inputs
for (var, shape) in zip((img, kern), shapes):
result[var] = aesara.shared(
......@@ -2273,7 +2273,9 @@ class ConvMetaOptimizer(LocalMetaOptimizer):
borrow=True,
)
if type(node.op) in [AbstractConv2d_gradWeights, AbstractConv3d_gradWeights]:
if isinstance(
node.op, (AbstractConv2d_gradWeights, AbstractConv3d_gradWeights)
):
img, top, kshape = node.inputs
tshp = get_conv_output_shape(
......@@ -2295,7 +2297,7 @@ class ConvMetaOptimizer(LocalMetaOptimizer):
borrow=True,
)
if type(node.op) in [AbstractConv2d_gradInputs, AbstractConv3d_gradInputs]:
if isinstance(node.op, (AbstractConv2d_gradInputs, AbstractConv3d_gradInputs)):
kern, top, ishape = node.inputs
tshp = get_conv_output_shape(
......
......@@ -358,7 +358,7 @@ def inplace_allocempty(op, idx):
@local_optimizer([op], inplace=True)
@wraps(maker)
def opt(fgraph, node):
if type(node.op) != op or node.op.inplace:
if not isinstance(node.op, op) or node.op.inplace:
return
inputs = list(node.inputs)
alloc = inputs[idx]
......@@ -460,7 +460,7 @@ def op_lifter(OP, cuda_only=False):
def f(maker):
def local_opt(fgraph, node):
if type(node.op) in OP:
if isinstance(node.op, OP):
# Either one of our inputs is on the gpu or
# all of our clients are on the gpu
replace = False
......
......@@ -331,7 +331,7 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
@register_opt2([mrg_uniform], "fast_compile")
def local_gpua_mrg_graph(fgraph, op, context_name, inputs, outputs):
if (
type(op) == mrg_uniform
isinstance(op, mrg_uniform)
and isinstance(inputs[0].type, GpuArrayType)
and (inputs[0].owner is None or not isinstance(inputs[0].owner.op, GpuFromHost))
):
......
......@@ -279,7 +279,7 @@ class GpuArrayType(CType):
# fallthrough to ndim check
elif allow_downcast or (
allow_downcast is None
and type(data) == float
and isinstance(data, float)
and self.dtype == config.floatX
):
if not isinstance(data, gpuarray.GpuArray):
......@@ -427,7 +427,7 @@ class GpuArrayType(CType):
def convert_variable(self, var):
vt = var.type
if (
type(self) == type(vt)
isinstance(vt, type(self))
and self.typecode == vt.typecode
and self.ndim == vt.ndim
and self.context_name == vt.context_name
......
......@@ -386,7 +386,7 @@ def Lop(f, wrt, eval_points, consider_constant=None, disconnected_inputs="raise"
coordinates of the tensor element in the last
If `f` is a list/tuple, then return a list/tuple with the results.
"""
if type(eval_points) not in (list, tuple):
if not isinstance(eval_points, (list, tuple)):
eval_points = [eval_points]
using_list = isinstance(wrt, list)
......
......@@ -268,13 +268,13 @@ def ff_2p72(rstate):
def mrg_next_value(rstate, new_rstate, NORM, mask, offset):
# TODO : need description for method, parameter and return
x11, x12, x13, x21, x22, x23 = rstate
assert type(x11) == np.int32
assert isinstance(x11, np.int32)
i0, i7, i9, i15, i16, i22, i24 = np_int32_vals
# first component
y1 = ((x12 & MASK12) << i22) + (x12 >> i9) + ((x13 & MASK13) << i7) + (x13 >> i24)
assert type(y1) == np.int32
assert isinstance(y1, np.int32)
if y1 < 0 or y1 >= M1: # must also check overflow
y1 -= M1
y1 += x13
......@@ -287,11 +287,11 @@ def mrg_next_value(rstate, new_rstate, NORM, mask, offset):
# second component
y1 = ((x21 & MASK2) << i15) + (MULT2 * (x21 >> i16))
assert type(y1) == np.int32
assert isinstance(y1, np.int32)
if y1 < 0 or y1 >= M2:
y1 -= M2
y2 = ((x23 & MASK2) << i15) + (MULT2 * (x23 >> i16))
assert type(y2) == np.int32
assert isinstance(y2, np.int32)
if y2 < 0 or y2 >= M2:
y2 -= M2
y2 += x23
......
......@@ -961,7 +961,7 @@ class transfer_type(MetaObject):
__props__ = ("transfer",)
def __init__(self, *transfer):
assert all(type(x) in [int, str] or x is None for x in transfer)
assert all(isinstance(x, (int, str)) or x is None for x in transfer)
self.transfer = transfer
def __str__(self):
......@@ -4363,7 +4363,7 @@ class Compositef32:
else:
ni = i
mapping[i] = ni
if type(node.op) in self.special:
if isinstance(node.op, tuple(self.special.keys())):
self.special[type(node.op)](node, mapping)
continue
new_node = node.clone_with_new_inputs(
......
......@@ -1073,7 +1073,7 @@ def scan(
pass
scan_inputs += [arg]
scan_outs = local_op(*scan_inputs)
if type(scan_outs) not in (list, tuple):
if not isinstance(scan_outs, (list, tuple)):
scan_outs = [scan_outs]
##
# Step 9. Figure out which outs are update rules for shared variables
......
......@@ -2778,7 +2778,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
allow_gc=self.allow_gc,
)
outputs = local_op(*outer_inputs)
if type(outputs) not in (list, tuple):
if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
# Re-order the gradients correctly
gradients = [DisconnectedType()()]
......@@ -2922,7 +2922,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if self.info.n_shared_outs > 0:
rop_self_outputs = rop_self_outputs[: -self.info.n_shared_outs]
rop_outs = Rop(rop_self_outputs, rop_of_inputs, inner_eval_points)
if type(rop_outs) not in (list, tuple):
if not isinstance(rop_outs, (list, tuple)):
rop_outs = [rop_outs]
# Step 2. Figure out what corresponds to what in the scan
......@@ -3112,7 +3112,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
allow_gc=self.allow_gc,
)
outputs = local_op(*scan_inputs)
if type(outputs) not in (list, tuple):
if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
# Select only the result of the R_op results
final_outs = []
......
......@@ -1200,7 +1200,7 @@ def save_mem_new_scan(fgraph, node):
# 2.1 outputs of the function
# => output needs all its intermediate values
if type(cl) == str:
if isinstance(cl, str):
# if the node is actually an output, then
# we need to store the entire thing
global_nsteps = None
......@@ -1263,13 +1263,13 @@ def save_mem_new_scan(fgraph, node):
if isinstance(stop, Variable):
global_nsteps["sym"] += [stop]
# not if it is maxsize
elif type(stop) == int and stop == maxsize:
elif isinstance(stop, int) and stop == maxsize:
global_nsteps = None
# yes if it is a int k, 0 < k < maxsize
elif type(stop) == int and global_nsteps["real"] < stop:
elif isinstance(stop, int) and global_nsteps["real"] < stop:
global_nsteps["real"] = stop
# yes if it is a int k, 0 < k < maxsize
elif type(stop) == int and stop > 0:
elif isinstance(stop, int) and stop > 0:
pass
# not otherwise
else:
......@@ -1311,7 +1311,7 @@ def save_mem_new_scan(fgraph, node):
for i, out in enumerate(node.outputs[:c_outs]):
# look at all its clients
for cl, _ in fgraph.clients[out]:
if type(cl) == str:
if isinstance(cl, str):
store_steps[i] = 0
break
elif not isinstance(cl.op, Subtensor):
......@@ -2275,7 +2275,7 @@ def push_out_dot1_scan(fgraph, node):
)
new_outs = new_op(*_scan_inputs)
if type(new_outs) not in (list, tuple):
if not isinstance(new_outs, (list, tuple)):
new_outs = [new_outs]
# We need now to pair correctly the new outputs
......
......@@ -262,7 +262,7 @@ class InplaceElemwiseOptimizer(GlobalOptimizer):
for node in list(io_toposort(fgraph.inputs, fgraph.outputs)):
op = node.op
# gpuarray GpuElemwise inherit from Elemwise
if not type(op) == self.op:
if not isinstance(op, self.op):
continue
# If big graph and the outputs are scalar, do not make it
# inplace.
......
......@@ -223,7 +223,7 @@ class DimShuffle(ExternalCOp):
(res,) = inp
(storage,) = out
if type(res) != np.ndarray and type(res) != np.memmap:
if not isinstance(res, (np.ndarray, np.memmap)):
raise TypeError(res)
res = res.transpose(self.transposition)
......
......@@ -1695,7 +1695,12 @@ def local_reduce_broadcastable(fgraph, node):
new_reduced = reduced.dimshuffle(*pattern)
if new_axis:
if type(node.op) == CAReduce:
# This happen for at_max(), at_min()
# This case handles `CAReduce` instances
# (e.g. generated by `scalar_elemwise`), and not the
# scalar `Op`-specific subclasses
# TODO FIXME: This highlights a major design flaw in
# `CAReduce` (or at least our use of it), and it needs
# to be fixed
new_op = node.op.__class__(node.op.scalar_op, axis=new_axis)
else:
new_op = node.op.__class__(axis=new_axis)
......
......@@ -736,7 +736,9 @@ class AbstractBatchNormTrainGrad(Op):
aesara.gradient.DisconnectedType()(),
]
return [
aesara.gradient.DisconnectedType()() if (type(r) == int and r == 0) else r
aesara.gradient.DisconnectedType()()
if (isinstance(r, int) and r == 0)
else r
for r in results
]
......
......@@ -195,7 +195,7 @@ class Images2Neibs(COp):
ten4, neib_shape, neib_step = inp
(z,) = out_
# GpuImages2Neibs should not run this perform in DebugMode
if type(self) != Images2Neibs:
if not isinstance(self, Images2Neibs):
raise aesara.graph.utils.MethodNotDefined()
def CEIL_INTDIV(a, b):
......
......@@ -656,7 +656,7 @@ def local_subtensor_of_alloc(fgraph, node):
if nw_val.ndim > len(nw_dims):
return False
rval = alloc(nw_val, *nw_dims)
if type(rval) not in (list, tuple):
if not isinstance(rval, (list, tuple)):
rval = [rval]
if rval[0].type != node.outputs[0].type:
# It happen that the make_node() isn't able to infer the same pattern.
......
......@@ -345,7 +345,7 @@ There are several ways to make sure that equality testing works properly:
.. testcode::
def __eq__(self, other):
return type(self) is Double and type(other) is Double
return type(self) == type(other)
#. Override :meth:`Double.__new__` to always return the same instance.
#. Hide the Double class and only advertise a single instance of it.
......
......@@ -209,7 +209,7 @@ class multiple_outputs_numeric_grad:
def scan_project_sum(*args, **kwargs):
rng = RandomStream(123)
scan_outputs, updates = scan(*args, **kwargs)
if type(scan_outputs) not in [list, tuple]:
if not isinstance(scan_outputs, (list, tuple)):
scan_outputs = [scan_outputs]
# we should ignore the random-state updates so that
# the uniform numbers are the same every evaluation and on every call
......
......@@ -339,15 +339,17 @@ def test_local_subtensor_remove_broadcastable_index():
z8 = y3[0, :, 0, :, 0]
f = function([x], [z1, z2, z3, z4, z5, z6, z7, z8], mode=mode)
for elem in f.maker.fgraph.toposort():
assert type(elem.op) not in [
Subtensor,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
]
assert not isinstance(
elem.op,
(
Subtensor,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
),
)
rng = np.random.default_rng(seed=utt.fetch_seed())
xn = rng.random((5, 5))
f(xn)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论