提交 100d62b1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Fix misc. naive type checks

This commit replaces a few instances of `type(x) is y` with `isinstance(x, y)`.
上级 2507f620
......@@ -445,7 +445,7 @@ def makeTester(
new_v = []
for inp in v:
if type(inp) is np.ndarray and inp.size > 0:
if isinstance(inp, np.ndarray) and inp.size > 0:
f, fname = mkstemp()
self.tmp_files.append((f, fname))
new_inp = np.memmap(
......
......@@ -13,6 +13,7 @@ import time
import traceback
import warnings
from collections import OrderedDict, defaultdict, deque
from collections.abc import Iterable
from functools import reduce
import numpy as np
......@@ -3152,7 +3153,7 @@ def copy_stack_trace(from_var, to_var):
# Store stack traces from from_var
tr = []
if type(from_var) is list:
if isinstance(from_var, Iterable) and not isinstance(from_var, graph.Variable):
# If from_var is a list, store concatenated stack traces
for v in from_var:
tr += getattr(v.tag, "trace", [])
......@@ -3167,7 +3168,7 @@ def copy_stack_trace(from_var, to_var):
tr = [tr]
# Copy over stack traces to to_var
if type(to_var) is list:
if isinstance(to_var, Iterable) and not isinstance(to_var, graph.Variable):
# Copy over stack traces from from_var to each variable in
# to_var, including the stack_trace of the to_var before
for v in to_var:
......
......@@ -345,14 +345,12 @@ def ifelse(condition, then_branch, else_branch, name=None):
"""
rval_type = None
if type(then_branch) is list:
rval_type = list
elif type(then_branch) is tuple:
rval_type = tuple
if type(then_branch) not in (list, tuple):
if isinstance(then_branch, (list, tuple)):
rval_type = type(then_branch)
else:
then_branch = [then_branch]
if type(else_branch) not in (list, tuple):
if not isinstance(else_branch, (list, tuple)):
else_branch = [else_branch]
# Some of the elements might be converted into another type,
......
......@@ -109,7 +109,7 @@ class PersistentNdarrayID:
return name
def __call__(self, obj):
if type(obj) is np.ndarray:
if isinstance(obj, np.ndarray):
if id(obj) not in self.seen:
def write_array(f):
......
......@@ -1429,7 +1429,7 @@ class ScanSaveMem(gof.Optimizer):
flag_store = True
orphane_outs = [
i for i, x in enumerate(store_steps) if (type(x) is int) and (x < 0)
i for i, x in enumerate(store_steps) if isinstance(x, int) and (x < 0)
]
flag_store = flag_store or (len(orphane_outs) > 0)
# 3. is there anything to change ?
......@@ -1448,7 +1448,7 @@ class ScanSaveMem(gof.Optimizer):
offset = 1 + op.n_seqs + op.n_mit_mot
for idx, _val in enumerate(store_steps[op.n_mit_mot :]):
i = idx + op.n_mit_mot
if not (type(_val) is int and _val <= 0 and i not in required):
if not (isinstance(_val, int) and _val <= 0 and i not in required):
if idx + op.n_mit_mot in required:
val = 1
......@@ -1611,7 +1611,7 @@ class ScanSaveMem(gof.Optimizer):
for k, old in enumerate(old_outs):
# Get the correct slice
cnf_slice, old_slices = slices[pos][k]
if type(cnf_slice[0]) is slice:
if isinstance(cnf_slice[0], slice):
start = (
cnf_slice[0].start
- nw_steps
......
......@@ -3236,7 +3236,7 @@ def merge_two_slices(slice1, len1, slice2, len2):
constant_folding,
]
if type(slice1) is not slice:
if not isinstance(slice1, slice):
raise ValueError(
(
"First provided slice should actually be of type"
......@@ -3247,7 +3247,7 @@ def merge_two_slices(slice1, len1, slice2, len2):
sl1, reverse1 = get_canonical_form_slice(slice1, len1)
sl2, reverse2 = get_canonical_form_slice(slice2, len2)
if type(sl2) is not slice:
if not isinstance(sl2, slice):
if reverse1 is None:
# The first slice is not in reverse, which makes things a lot
# more clear.
......@@ -3398,7 +3398,7 @@ def local_subtensor_merge(node):
pos_1 = 0
while (pos_1 < len(slices1)) and (pos_2 < len(slices2)):
slice1 = slices1[pos_1]
if type(slice1) is slice:
if isinstance(slice1, slice):
merged_slices.append(
merge_two_slices(
slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2]
......@@ -4360,7 +4360,9 @@ def local_useless_switch(node):
"""
if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ts.Switch):
cond = tt.extract_constant(node.inputs[0], only_process_constants=True)
if (type(cond) is np.ndarray and cond.ndim == 0) or isinstance(cond, np.number):
if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
cond, np.number
):
if cond == 0:
correct_out = node.inputs[2]
else:
......
......@@ -937,7 +937,7 @@ class TensorConstantSignature(tuple):
self._sum = self.no_nan.sum()
# The following 2 lines are needede as in Python 3.3 with NumPy
# 1.7.1, numpy.ndarray and numpy.memmap aren't hashable.
if type(self._sum) is np.memmap:
if isinstance(self._sum, np.memmap):
self._sum = np.asarray(self._sum).item()
if self.has_nan and self.no_nan.mask.all():
# In this case the sum is not properly computed by numpy.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论