提交 100d62b1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Fix misc. naive type checks

This commit replaces a few instances of `type(x) is y` with `isinstance(x, y)`.
上级 2507f620
...@@ -445,7 +445,7 @@ def makeTester( ...@@ -445,7 +445,7 @@ def makeTester(
new_v = [] new_v = []
for inp in v: for inp in v:
if type(inp) is np.ndarray and inp.size > 0: if isinstance(inp, np.ndarray) and inp.size > 0:
f, fname = mkstemp() f, fname = mkstemp()
self.tmp_files.append((f, fname)) self.tmp_files.append((f, fname))
new_inp = np.memmap( new_inp = np.memmap(
......
...@@ -13,6 +13,7 @@ import time ...@@ -13,6 +13,7 @@ import time
import traceback import traceback
import warnings import warnings
from collections import OrderedDict, defaultdict, deque from collections import OrderedDict, defaultdict, deque
from collections.abc import Iterable
from functools import reduce from functools import reduce
import numpy as np import numpy as np
...@@ -3152,7 +3153,7 @@ def copy_stack_trace(from_var, to_var): ...@@ -3152,7 +3153,7 @@ def copy_stack_trace(from_var, to_var):
# Store stack traces from from_var # Store stack traces from from_var
tr = [] tr = []
if type(from_var) is list: if isinstance(from_var, Iterable) and not isinstance(from_var, graph.Variable):
# If from_var is a list, store concatenated stack traces # If from_var is a list, store concatenated stack traces
for v in from_var: for v in from_var:
tr += getattr(v.tag, "trace", []) tr += getattr(v.tag, "trace", [])
...@@ -3167,7 +3168,7 @@ def copy_stack_trace(from_var, to_var): ...@@ -3167,7 +3168,7 @@ def copy_stack_trace(from_var, to_var):
tr = [tr] tr = [tr]
# Copy over stack traces to to_var # Copy over stack traces to to_var
if type(to_var) is list: if isinstance(to_var, Iterable) and not isinstance(to_var, graph.Variable):
# Copy over stack traces from from_var to each variable in # Copy over stack traces from from_var to each variable in
# to_var, including the stack_trace of the to_var before # to_var, including the stack_trace of the to_var before
for v in to_var: for v in to_var:
......
...@@ -345,14 +345,12 @@ def ifelse(condition, then_branch, else_branch, name=None): ...@@ -345,14 +345,12 @@ def ifelse(condition, then_branch, else_branch, name=None):
""" """
rval_type = None rval_type = None
if type(then_branch) is list: if isinstance(then_branch, (list, tuple)):
rval_type = list rval_type = type(then_branch)
elif type(then_branch) is tuple: else:
rval_type = tuple
if type(then_branch) not in (list, tuple):
then_branch = [then_branch] then_branch = [then_branch]
if type(else_branch) not in (list, tuple):
if not isinstance(else_branch, (list, tuple)):
else_branch = [else_branch] else_branch = [else_branch]
# Some of the elements might be converted into another type, # Some of the elements might be converted into another type,
......
...@@ -109,7 +109,7 @@ class PersistentNdarrayID: ...@@ -109,7 +109,7 @@ class PersistentNdarrayID:
return name return name
def __call__(self, obj): def __call__(self, obj):
if type(obj) is np.ndarray: if isinstance(obj, np.ndarray):
if id(obj) not in self.seen: if id(obj) not in self.seen:
def write_array(f): def write_array(f):
......
...@@ -1429,7 +1429,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1429,7 +1429,7 @@ class ScanSaveMem(gof.Optimizer):
flag_store = True flag_store = True
orphane_outs = [ orphane_outs = [
i for i, x in enumerate(store_steps) if (type(x) is int) and (x < 0) i for i, x in enumerate(store_steps) if isinstance(x, int) and (x < 0)
] ]
flag_store = flag_store or (len(orphane_outs) > 0) flag_store = flag_store or (len(orphane_outs) > 0)
# 3. is there anything to change ? # 3. is there anything to change ?
...@@ -1448,7 +1448,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1448,7 +1448,7 @@ class ScanSaveMem(gof.Optimizer):
offset = 1 + op.n_seqs + op.n_mit_mot offset = 1 + op.n_seqs + op.n_mit_mot
for idx, _val in enumerate(store_steps[op.n_mit_mot :]): for idx, _val in enumerate(store_steps[op.n_mit_mot :]):
i = idx + op.n_mit_mot i = idx + op.n_mit_mot
if not (type(_val) is int and _val <= 0 and i not in required): if not (isinstance(_val, int) and _val <= 0 and i not in required):
if idx + op.n_mit_mot in required: if idx + op.n_mit_mot in required:
val = 1 val = 1
...@@ -1611,7 +1611,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1611,7 +1611,7 @@ class ScanSaveMem(gof.Optimizer):
for k, old in enumerate(old_outs): for k, old in enumerate(old_outs):
# Get the correct slice # Get the correct slice
cnf_slice, old_slices = slices[pos][k] cnf_slice, old_slices = slices[pos][k]
if type(cnf_slice[0]) is slice: if isinstance(cnf_slice[0], slice):
start = ( start = (
cnf_slice[0].start cnf_slice[0].start
- nw_steps - nw_steps
......
...@@ -3236,7 +3236,7 @@ def merge_two_slices(slice1, len1, slice2, len2): ...@@ -3236,7 +3236,7 @@ def merge_two_slices(slice1, len1, slice2, len2):
constant_folding, constant_folding,
] ]
if type(slice1) is not slice: if not isinstance(slice1, slice):
raise ValueError( raise ValueError(
( (
"First provided slice should actually be of type" "First provided slice should actually be of type"
...@@ -3247,7 +3247,7 @@ def merge_two_slices(slice1, len1, slice2, len2): ...@@ -3247,7 +3247,7 @@ def merge_two_slices(slice1, len1, slice2, len2):
sl1, reverse1 = get_canonical_form_slice(slice1, len1) sl1, reverse1 = get_canonical_form_slice(slice1, len1)
sl2, reverse2 = get_canonical_form_slice(slice2, len2) sl2, reverse2 = get_canonical_form_slice(slice2, len2)
if type(sl2) is not slice: if not isinstance(sl2, slice):
if reverse1 is None: if reverse1 is None:
# The first slice is not in reverse, which makes things a lot # The first slice is not in reverse, which makes things a lot
# more clear. # more clear.
...@@ -3398,7 +3398,7 @@ def local_subtensor_merge(node): ...@@ -3398,7 +3398,7 @@ def local_subtensor_merge(node):
pos_1 = 0 pos_1 = 0
while (pos_1 < len(slices1)) and (pos_2 < len(slices2)): while (pos_1 < len(slices1)) and (pos_2 < len(slices2)):
slice1 = slices1[pos_1] slice1 = slices1[pos_1]
if type(slice1) is slice: if isinstance(slice1, slice):
merged_slices.append( merged_slices.append(
merge_two_slices( merge_two_slices(
slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2] slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2]
...@@ -4360,7 +4360,9 @@ def local_useless_switch(node): ...@@ -4360,7 +4360,9 @@ def local_useless_switch(node):
""" """
if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ts.Switch): if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ts.Switch):
cond = tt.extract_constant(node.inputs[0], only_process_constants=True) cond = tt.extract_constant(node.inputs[0], only_process_constants=True)
if (type(cond) is np.ndarray and cond.ndim == 0) or isinstance(cond, np.number): if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
cond, np.number
):
if cond == 0: if cond == 0:
correct_out = node.inputs[2] correct_out = node.inputs[2]
else: else:
......
...@@ -937,7 +937,7 @@ class TensorConstantSignature(tuple): ...@@ -937,7 +937,7 @@ class TensorConstantSignature(tuple):
self._sum = self.no_nan.sum() self._sum = self.no_nan.sum()
# The following 2 lines are needede as in Python 3.3 with NumPy # The following 2 lines are needede as in Python 3.3 with NumPy
# 1.7.1, numpy.ndarray and numpy.memmap aren't hashable. # 1.7.1, numpy.ndarray and numpy.memmap aren't hashable.
if type(self._sum) is np.memmap: if isinstance(self._sum, np.memmap):
self._sum = np.asarray(self._sum).item() self._sum = np.asarray(self._sum).item()
if self.has_nan and self.no_nan.mask.all(): if self.has_nan and self.no_nan.mask.all():
# In this case the sum is not properly computed by numpy. # In this case the sum is not properly computed by numpy.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论