Unverified 提交 3c47f74a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: GitHub

Merge pull request #173 from brandonwillard/fix-type-checking-issues

Use isinstance to check types in Type filter methods
...@@ -19,9 +19,7 @@ from theano.scalar.basic import ( ...@@ -19,9 +19,7 @@ from theano.scalar.basic import (
ComplexError, ComplexError,
Composite, Composite,
InRange, InRange,
IntDiv,
Scalar, Scalar,
TrueDiv,
add, add,
and_, and_,
arccos, arccos,
...@@ -44,7 +42,6 @@ from theano.scalar.basic import ( ...@@ -44,7 +42,6 @@ from theano.scalar.basic import (
expm1, expm1,
float16, float16,
float32, float32,
float64,
floats, floats,
int8, int8,
int32, int32,
...@@ -395,24 +392,6 @@ def test_mod_complex_fail(): ...@@ -395,24 +392,6 @@ def test_mod_complex_fail():
x % y x % y
def test_div_types():
a = int8()
b = int32()
c = complex64()
d = float64()
f = float32()
assert isinstance((a // b).owner.op, IntDiv)
assert isinstance((b // a).owner.op, IntDiv)
assert isinstance((b / d).owner.op, TrueDiv)
assert isinstance((b / f).owner.op, TrueDiv)
assert isinstance((f / a).owner.op, TrueDiv)
assert isinstance((d / b).owner.op, TrueDiv)
assert isinstance((d / f).owner.op, TrueDiv)
assert isinstance((f / c).owner.op, TrueDiv)
assert isinstance((a / c).owner.op, TrueDiv)
def test_grad_gt(): def test_grad_gt():
x = float32(name="x") x = float32(name="x")
y = float32(name="y") y = float32(name="y")
......
import numpy as np
from theano import change_flags
from theano.scalar.basic import (
IntDiv,
Scalar,
TrueDiv,
complex64,
float32,
float64,
int8,
int32,
)
def test_div_types():
a = int8()
b = int32()
c = complex64()
d = float64()
f = float32()
assert isinstance((a // b).owner.op, IntDiv)
assert isinstance((b // a).owner.op, IntDiv)
assert isinstance((b / d).owner.op, TrueDiv)
assert isinstance((b / f).owner.op, TrueDiv)
assert isinstance((f / a).owner.op, TrueDiv)
assert isinstance((d / b).owner.op, TrueDiv)
assert isinstance((d / f).owner.op, TrueDiv)
assert isinstance((f / c).owner.op, TrueDiv)
assert isinstance((a / c).owner.op, TrueDiv)
def test_filter_float_subclass():
"""Make sure `Scalar.filter` can handle `float` subclasses."""
with change_flags(floatX="float64"):
test_type = Scalar("float64")
nan = np.array([np.nan], dtype="float64")[0]
assert isinstance(nan, float)
filtered_nan = test_type.filter(nan)
assert isinstance(filtered_nan, float)
with change_flags(floatX="float32"):
# Try again, except this time `nan` isn't a `float`
test_type = Scalar("float32")
nan = np.array([np.nan], dtype="float32")[0]
assert isinstance(nan, np.floating)
filtered_nan = test_type.filter(nan)
assert isinstance(filtered_nan, np.floating)
import os.path as path
from tempfile import mkdtemp
import numpy as np
import pytest
from theano import change_flags, config
from theano.tensor.type import TensorType
def test_filter_variable():
test_type = TensorType(config.floatX, [])
with pytest.raises(TypeError):
test_type.filter(test_type())
def test_filter_strict():
test_type = TensorType(config.floatX, [])
with pytest.raises(TypeError):
test_type.filter(1, strict=True)
with pytest.raises(TypeError):
test_type.filter(np.array(1, dtype=int), strict=True)
def test_filter_ndarray_subclass():
"""Make sure `TensorType.filter` can handle NumPy `ndarray` subclasses."""
test_type = TensorType(config.floatX, [False])
class MyNdarray(np.ndarray):
pass
test_val = np.array([1.0], dtype=config.floatX).view(MyNdarray)
assert isinstance(test_val, MyNdarray)
res = test_type.filter(test_val)
assert isinstance(res, MyNdarray)
assert res is test_val
def test_filter_float_subclass():
"""Make sure `TensorType.filter` can handle `float` subclasses."""
with change_flags(floatX="float64"):
test_type = TensorType("float64", broadcastable=[])
nan = np.array([np.nan], dtype="float64")[0]
assert isinstance(nan, np.float) and not isinstance(nan, np.ndarray)
filtered_nan = test_type.filter(nan)
assert isinstance(filtered_nan, np.ndarray)
with change_flags(floatX="float32"):
# Try again, except this time `nan` isn't a `float`
test_type = TensorType("float32", broadcastable=[])
nan = np.array([np.nan], dtype="float32")[0]
assert isinstance(nan, np.floating) and not isinstance(nan, np.ndarray)
filtered_nan = test_type.filter(nan)
assert isinstance(filtered_nan, np.ndarray)
def test_filter_memmap():
"""Make sure `TensorType.filter` can handle NumPy `memmap`s subclasses."""
data = np.arange(12, dtype=config.floatX)
data.resize((3, 4))
filename = path.join(mkdtemp(), "newfile.dat")
fp = np.memmap(filename, dtype=config.floatX, mode="w+", shape=(3, 4))
test_type = TensorType(config.floatX, [False, False])
res = test_type.filter(fp)
assert res is fp
...@@ -445,7 +445,7 @@ def makeTester( ...@@ -445,7 +445,7 @@ def makeTester(
new_v = [] new_v = []
for inp in v: for inp in v:
if type(inp) is np.ndarray and inp.size > 0: if isinstance(inp, np.ndarray) and inp.size > 0:
f, fname = mkstemp() f, fname = mkstemp()
self.tmp_files.append((f, fname)) self.tmp_files.append((f, fname))
new_inp = np.memmap( new_inp = np.memmap(
......
...@@ -13,6 +13,7 @@ import time ...@@ -13,6 +13,7 @@ import time
import traceback import traceback
import warnings import warnings
from collections import OrderedDict, defaultdict, deque from collections import OrderedDict, defaultdict, deque
from collections.abc import Iterable
from functools import reduce from functools import reduce
import numpy as np import numpy as np
...@@ -3152,7 +3153,7 @@ def copy_stack_trace(from_var, to_var): ...@@ -3152,7 +3153,7 @@ def copy_stack_trace(from_var, to_var):
# Store stack traces from from_var # Store stack traces from from_var
tr = [] tr = []
if type(from_var) is list: if isinstance(from_var, Iterable) and not isinstance(from_var, graph.Variable):
# If from_var is a list, store concatenated stack traces # If from_var is a list, store concatenated stack traces
for v in from_var: for v in from_var:
tr += getattr(v.tag, "trace", []) tr += getattr(v.tag, "trace", [])
...@@ -3167,7 +3168,7 @@ def copy_stack_trace(from_var, to_var): ...@@ -3167,7 +3168,7 @@ def copy_stack_trace(from_var, to_var):
tr = [tr] tr = [tr]
# Copy over stack traces to to_var # Copy over stack traces to to_var
if type(to_var) is list: if isinstance(to_var, Iterable) and not isinstance(to_var, graph.Variable):
# Copy over stack traces from from_var to each variable in # Copy over stack traces from from_var to each variable in
# to_var, including the stack_trace of the to_var before # to_var, including the stack_trace of the to_var before
for v in to_var: for v in to_var:
......
...@@ -345,14 +345,12 @@ def ifelse(condition, then_branch, else_branch, name=None): ...@@ -345,14 +345,12 @@ def ifelse(condition, then_branch, else_branch, name=None):
""" """
rval_type = None rval_type = None
if type(then_branch) is list: if isinstance(then_branch, (list, tuple)):
rval_type = list rval_type = type(then_branch)
elif type(then_branch) is tuple: else:
rval_type = tuple
if type(then_branch) not in (list, tuple):
then_branch = [then_branch] then_branch = [then_branch]
if type(else_branch) not in (list, tuple):
if not isinstance(else_branch, (list, tuple)):
else_branch = [else_branch] else_branch = [else_branch]
# Some of the elements might be converted into another type, # Some of the elements might be converted into another type,
......
...@@ -109,7 +109,7 @@ class PersistentNdarrayID: ...@@ -109,7 +109,7 @@ class PersistentNdarrayID:
return name return name
def __call__(self, obj): def __call__(self, obj):
if type(obj) is np.ndarray: if isinstance(obj, np.ndarray):
if id(obj) not in self.seen: if id(obj) not in self.seen:
def write_array(f): def write_array(f):
......
...@@ -346,7 +346,7 @@ class Scalar(Type): ...@@ -346,7 +346,7 @@ class Scalar(Type):
allow_downcast allow_downcast
or ( or (
allow_downcast is None allow_downcast is None
and type(data) is float and isinstance(data, (float, np.floating))
and self.dtype == theano.config.floatX and self.dtype == theano.config.floatX
) )
or data == converted_data or data == converted_data
......
...@@ -1429,7 +1429,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1429,7 +1429,7 @@ class ScanSaveMem(gof.Optimizer):
flag_store = True flag_store = True
orphane_outs = [ orphane_outs = [
i for i, x in enumerate(store_steps) if (type(x) is int) and (x < 0) i for i, x in enumerate(store_steps) if isinstance(x, int) and (x < 0)
] ]
flag_store = flag_store or (len(orphane_outs) > 0) flag_store = flag_store or (len(orphane_outs) > 0)
# 3. is there anything to change ? # 3. is there anything to change ?
...@@ -1448,7 +1448,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1448,7 +1448,7 @@ class ScanSaveMem(gof.Optimizer):
offset = 1 + op.n_seqs + op.n_mit_mot offset = 1 + op.n_seqs + op.n_mit_mot
for idx, _val in enumerate(store_steps[op.n_mit_mot :]): for idx, _val in enumerate(store_steps[op.n_mit_mot :]):
i = idx + op.n_mit_mot i = idx + op.n_mit_mot
if not (type(_val) is int and _val <= 0 and i not in required): if not (isinstance(_val, int) and _val <= 0 and i not in required):
if idx + op.n_mit_mot in required: if idx + op.n_mit_mot in required:
val = 1 val = 1
...@@ -1611,7 +1611,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1611,7 +1611,7 @@ class ScanSaveMem(gof.Optimizer):
for k, old in enumerate(old_outs): for k, old in enumerate(old_outs):
# Get the correct slice # Get the correct slice
cnf_slice, old_slices = slices[pos][k] cnf_slice, old_slices = slices[pos][k]
if type(cnf_slice[0]) is slice: if isinstance(cnf_slice[0], slice):
start = ( start = (
cnf_slice[0].start cnf_slice[0].start
- nw_steps - nw_steps
......
...@@ -3236,7 +3236,7 @@ def merge_two_slices(slice1, len1, slice2, len2): ...@@ -3236,7 +3236,7 @@ def merge_two_slices(slice1, len1, slice2, len2):
constant_folding, constant_folding,
] ]
if type(slice1) is not slice: if not isinstance(slice1, slice):
raise ValueError( raise ValueError(
( (
"First provided slice should actually be of type" "First provided slice should actually be of type"
...@@ -3247,7 +3247,7 @@ def merge_two_slices(slice1, len1, slice2, len2): ...@@ -3247,7 +3247,7 @@ def merge_two_slices(slice1, len1, slice2, len2):
sl1, reverse1 = get_canonical_form_slice(slice1, len1) sl1, reverse1 = get_canonical_form_slice(slice1, len1)
sl2, reverse2 = get_canonical_form_slice(slice2, len2) sl2, reverse2 = get_canonical_form_slice(slice2, len2)
if type(sl2) is not slice: if not isinstance(sl2, slice):
if reverse1 is None: if reverse1 is None:
# The first slice is not in reverse, which makes things a lot # The first slice is not in reverse, which makes things a lot
# more clear. # more clear.
...@@ -3398,7 +3398,7 @@ def local_subtensor_merge(node): ...@@ -3398,7 +3398,7 @@ def local_subtensor_merge(node):
pos_1 = 0 pos_1 = 0
while (pos_1 < len(slices1)) and (pos_2 < len(slices2)): while (pos_1 < len(slices1)) and (pos_2 < len(slices2)):
slice1 = slices1[pos_1] slice1 = slices1[pos_1]
if type(slice1) is slice: if isinstance(slice1, slice):
merged_slices.append( merged_slices.append(
merge_two_slices( merge_two_slices(
slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2] slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2]
...@@ -4360,7 +4360,9 @@ def local_useless_switch(node): ...@@ -4360,7 +4360,9 @@ def local_useless_switch(node):
""" """
if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ts.Switch): if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ts.Switch):
cond = tt.extract_constant(node.inputs[0], only_process_constants=True) cond = tt.extract_constant(node.inputs[0], only_process_constants=True)
if (type(cond) is np.ndarray and cond.ndim == 0) or isinstance(cond, np.number): if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
cond, np.number
):
if cond == 0: if cond == 0:
correct_out = node.inputs[2] correct_out = node.inputs[2]
else: else:
......
...@@ -92,29 +92,25 @@ class TensorType(Type): ...@@ -92,29 +92,25 @@ class TensorType(Type):
"shared) variable instead of a numeric array?" "shared) variable instead of a numeric array?"
) )
if (type(data) is np.ndarray) and (data.dtype == self.numpy_dtype): if isinstance(data, np.memmap) and (data.dtype == self.numpy_dtype):
if data.dtype.num != self.numpy_dtype.num:
data = theano._asarray(data, dtype=self.dtype)
# -- now fall through to ndim check
elif (type(data) is np.memmap) and (data.dtype == self.numpy_dtype):
# numpy.memmap is a "safe" subclass of ndarray, # numpy.memmap is a "safe" subclass of ndarray,
# so we can use it wherever we expect a base ndarray. # so we can use it wherever we expect a base ndarray.
# however, casting it would defeat the purpose of not # however, casting it would defeat the purpose of not
# loading the whole data into memory # loading the whole data into memory
pass pass
elif isinstance(data, np.ndarray) and (data.dtype == self.numpy_dtype):
if data.dtype.num != self.numpy_dtype.num:
data = theano._asarray(data, dtype=self.dtype)
# -- now fall through to ndim check
elif strict: elif strict:
# If any of the two conditions above was not met, # If any of the two conditions above was not met,
# we raise a meaningful TypeError. # we raise a meaningful TypeError.
if not (type(data) is np.ndarray): if not isinstance(data, np.ndarray):
raise TypeError( raise TypeError(f"{self} expected a ndarray object (got {type(data)}).")
"%s expected a ndarray object." % self, data, type(data)
)
if data.dtype != self.numpy_dtype: if data.dtype != self.numpy_dtype:
raise TypeError( raise TypeError(
("%s expected a ndarray object with " "dtype = %s (got %s).") f"{self} expected an ndarray with dtype={self.numpy_dtype} (got {data.dtype})."
% (self, self.numpy_dtype, data.dtype)
) )
raise AssertionError("This point should never be reached.")
else: else:
if allow_downcast: if allow_downcast:
# Convert to self.dtype, regardless of the type of data # Convert to self.dtype, regardless of the type of data
...@@ -145,7 +141,7 @@ class TensorType(Type): ...@@ -145,7 +141,7 @@ class TensorType(Type):
raise TypeError(err_msg) raise TypeError(err_msg)
elif ( elif (
allow_downcast is None allow_downcast is None
and type(data) is float and isinstance(data, (float, np.floating))
and self.dtype == theano.config.floatX and self.dtype == theano.config.floatX
): ):
# Special case where we allow downcasting of Python float # Special case where we allow downcasting of Python float
...@@ -177,7 +173,7 @@ class TensorType(Type): ...@@ -177,7 +173,7 @@ class TensorType(Type):
'2) set "allow_input_downcast=True" when calling ' '2) set "allow_input_downcast=True" when calling '
'"function".' % (self, data, converted_data, self.dtype) '"function".' % (self, data, converted_data, self.dtype)
) )
raise TypeError(err_msg, data) raise TypeError(err_msg)
if self.ndim != data.ndim: if self.ndim != data.ndim:
raise TypeError( raise TypeError(
......
...@@ -937,7 +937,7 @@ class TensorConstantSignature(tuple): ...@@ -937,7 +937,7 @@ class TensorConstantSignature(tuple):
self._sum = self.no_nan.sum() self._sum = self.no_nan.sum()
# The following 2 lines are needede as in Python 3.3 with NumPy # The following 2 lines are needede as in Python 3.3 with NumPy
# 1.7.1, numpy.ndarray and numpy.memmap aren't hashable. # 1.7.1, numpy.ndarray and numpy.memmap aren't hashable.
if type(self._sum) is np.memmap: if isinstance(self._sum, np.memmap):
self._sum = np.asarray(self._sum).item() self._sum = np.asarray(self._sum).item()
if self.has_nan and self.no_nan.mask.all(): if self.has_nan and self.no_nan.mask.all():
# In this case the sum is not properly computed by numpy. # In this case the sum is not properly computed by numpy.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论