Unverified commit 3c47f74a authored by Brandon T. Willard, committed by GitHub

Merge pull request #173 from brandonwillard/fix-type-checking-issues

Use isinstance to check types in Type filter methods
......@@ -19,9 +19,7 @@ from theano.scalar.basic import (
ComplexError,
Composite,
InRange,
IntDiv,
Scalar,
TrueDiv,
add,
and_,
arccos,
......@@ -44,7 +42,6 @@ from theano.scalar.basic import (
expm1,
float16,
float32,
float64,
floats,
int8,
int32,
......@@ -395,24 +392,6 @@ def test_mod_complex_fail():
x % y
def test_div_types():
    """Integer `//` produces `IntDiv`; `/` always produces `TrueDiv`."""
    i8, i32 = int8(), int32()
    c64, f64, f32 = complex64(), float64(), float32()

    # Floor division of two integer scalars maps to `IntDiv`.
    for int_div_expr in (i8 // i32, i32 // i8):
        assert isinstance(int_div_expr.owner.op, IntDiv)

    # True division maps to `TrueDiv` regardless of operand dtypes.
    true_div_exprs = (
        i32 / f64,
        i32 / f32,
        f32 / i8,
        f64 / i32,
        f64 / f32,
        f32 / c64,
        i8 / c64,
    )
    for true_div_expr in true_div_exprs:
        assert isinstance(true_div_expr.owner.op, TrueDiv)
def test_grad_gt():
x = float32(name="x")
y = float32(name="y")
......
import numpy as np
from theano import change_flags
from theano.scalar.basic import (
IntDiv,
Scalar,
TrueDiv,
complex64,
float32,
float64,
int8,
int32,
)
def test_div_types():
    """Check the `Op` chosen for `//` and `/` on scalar variables.

    `//` between integer scalars must build an `IntDiv` node, while `/`
    must build a `TrueDiv` node for every dtype combination tried here.
    """
    a = int8()
    b = int32()
    c = complex64()
    d = float64()
    f = float32()

    int_div_cases = [a // b, b // a]
    true_div_cases = [b / d, b / f, f / a, d / b, d / f, f / c, a / c]

    assert all(isinstance(expr.owner.op, IntDiv) for expr in int_div_cases)
    assert all(isinstance(expr.owner.op, TrueDiv) for expr in true_div_cases)
def test_filter_float_subclass():
    """Make sure `Scalar.filter` can handle `float` subclasses."""
    with change_flags(floatX="float64"):
        scalar_type = Scalar("float64")
        value = np.array([np.nan], dtype="float64")[0]
        # A NumPy float64 scalar *is* a Python `float` subclass; filtering
        # must keep it one.
        assert isinstance(value, float)
        assert isinstance(scalar_type.filter(value), float)

    with change_flags(floatX="float32"):
        # Try again, except this time the value isn't a Python `float`,
        # only a NumPy floating scalar.
        scalar_type = Scalar("float32")
        value = np.array([np.nan], dtype="float32")[0]
        assert isinstance(value, np.floating)
        assert isinstance(scalar_type.filter(value), np.floating)
import os.path as path
from tempfile import mkdtemp
import numpy as np
import pytest
from theano import change_flags, config
from theano.tensor.type import TensorType
def test_filter_variable():
    """`TensorType.filter` must reject a symbolic variable as a value."""
    tensor_type = TensorType(config.floatX, [])
    symbolic_value = tensor_type()
    with pytest.raises(TypeError):
        tensor_type.filter(symbolic_value)
def test_filter_strict():
    """In strict mode, `TensorType.filter` refuses anything needing conversion."""
    tensor_type = TensorType(config.floatX, [])
    # Neither a Python int nor an integer ndarray matches floatX exactly,
    # so both must be rejected when `strict=True`.
    for bad_value in (1, np.array(1, dtype=int)):
        with pytest.raises(TypeError):
            tensor_type.filter(bad_value, strict=True)
def test_filter_ndarray_subclass():
    """Make sure `TensorType.filter` can handle NumPy `ndarray` subclasses."""
    tensor_type = TensorType(config.floatX, [False])

    class _CustomArray(np.ndarray):
        pass

    value = np.array([1.0], dtype=config.floatX).view(_CustomArray)
    assert isinstance(value, _CustomArray)

    filtered = tensor_type.filter(value)
    # The subclass instance must pass through untouched: same object,
    # subclass preserved (no copy, no cast to base `ndarray`).
    assert filtered is value
    assert isinstance(filtered, _CustomArray)
def test_filter_float_subclass():
    """Make sure `TensorType.filter` can handle `float` subclasses."""
    with change_flags(floatX="float64"):
        test_type = TensorType("float64", broadcastable=[])
        nan = np.array([np.nan], dtype="float64")[0]
        # NumPy float64 scalars subclass the builtin `float`.  Check against
        # `float` directly: the `np.float` alias was deprecated in NumPy 1.20
        # and removed in later releases, so using it raises AttributeError.
        assert isinstance(nan, float) and not isinstance(nan, np.ndarray)
        filtered_nan = test_type.filter(nan)
        assert isinstance(filtered_nan, np.ndarray)

    with change_flags(floatX="float32"):
        # Try again, except this time `nan` isn't a `float` subclass,
        # only a NumPy floating scalar.
        test_type = TensorType("float32", broadcastable=[])
        nan = np.array([np.nan], dtype="float32")[0]
        assert isinstance(nan, np.floating) and not isinstance(nan, np.ndarray)
        filtered_nan = test_type.filter(nan)
        assert isinstance(filtered_nan, np.ndarray)
def test_filter_memmap():
    """Make sure `TensorType.filter` can handle NumPy `memmap` subclasses.

    The original version built a `data` array that was never used; here the
    memmap is actually populated with it so the filter sees real contents.
    """
    filename = path.join(mkdtemp(), "newfile.dat")
    fp = np.memmap(filename, dtype=config.floatX, mode="w+", shape=(3, 4))
    fp[...] = np.arange(12, dtype=config.floatX).reshape((3, 4))

    test_type = TensorType(config.floatX, [False, False])

    # Like other "safe" `ndarray` subclasses, a `memmap` must pass through
    # `filter` unchanged: casting it would load the whole file into memory.
    res = test_type.filter(fp)
    assert res is fp
......@@ -445,7 +445,7 @@ def makeTester(
new_v = []
for inp in v:
if type(inp) is np.ndarray and inp.size > 0:
if isinstance(inp, np.ndarray) and inp.size > 0:
f, fname = mkstemp()
self.tmp_files.append((f, fname))
new_inp = np.memmap(
......
......@@ -13,6 +13,7 @@ import time
import traceback
import warnings
from collections import OrderedDict, defaultdict, deque
from collections.abc import Iterable
from functools import reduce
import numpy as np
......@@ -3152,7 +3153,7 @@ def copy_stack_trace(from_var, to_var):
# Store stack traces from from_var
tr = []
if type(from_var) is list:
if isinstance(from_var, Iterable) and not isinstance(from_var, graph.Variable):
# If from_var is a list, store concatenated stack traces
for v in from_var:
tr += getattr(v.tag, "trace", [])
......@@ -3167,7 +3168,7 @@ def copy_stack_trace(from_var, to_var):
tr = [tr]
# Copy over stack traces to to_var
if type(to_var) is list:
if isinstance(to_var, Iterable) and not isinstance(to_var, graph.Variable):
# Copy over stack traces from from_var to each variable in
# to_var, including the stack_trace of the to_var before
for v in to_var:
......
......@@ -345,14 +345,12 @@ def ifelse(condition, then_branch, else_branch, name=None):
"""
rval_type = None
if type(then_branch) is list:
rval_type = list
elif type(then_branch) is tuple:
rval_type = tuple
if type(then_branch) not in (list, tuple):
if isinstance(then_branch, (list, tuple)):
rval_type = type(then_branch)
else:
then_branch = [then_branch]
if type(else_branch) not in (list, tuple):
if not isinstance(else_branch, (list, tuple)):
else_branch = [else_branch]
# Some of the elements might be converted into another type,
......
......@@ -109,7 +109,7 @@ class PersistentNdarrayID:
return name
def __call__(self, obj):
if type(obj) is np.ndarray:
if isinstance(obj, np.ndarray):
if id(obj) not in self.seen:
def write_array(f):
......
......@@ -346,7 +346,7 @@ class Scalar(Type):
allow_downcast
or (
allow_downcast is None
and type(data) is float
and isinstance(data, (float, np.floating))
and self.dtype == theano.config.floatX
)
or data == converted_data
......
......@@ -1429,7 +1429,7 @@ class ScanSaveMem(gof.Optimizer):
flag_store = True
orphane_outs = [
i for i, x in enumerate(store_steps) if (type(x) is int) and (x < 0)
i for i, x in enumerate(store_steps) if isinstance(x, int) and (x < 0)
]
flag_store = flag_store or (len(orphane_outs) > 0)
# 3. is there anything to change ?
......@@ -1448,7 +1448,7 @@ class ScanSaveMem(gof.Optimizer):
offset = 1 + op.n_seqs + op.n_mit_mot
for idx, _val in enumerate(store_steps[op.n_mit_mot :]):
i = idx + op.n_mit_mot
if not (type(_val) is int and _val <= 0 and i not in required):
if not (isinstance(_val, int) and _val <= 0 and i not in required):
if idx + op.n_mit_mot in required:
val = 1
......@@ -1611,7 +1611,7 @@ class ScanSaveMem(gof.Optimizer):
for k, old in enumerate(old_outs):
# Get the correct slice
cnf_slice, old_slices = slices[pos][k]
if type(cnf_slice[0]) is slice:
if isinstance(cnf_slice[0], slice):
start = (
cnf_slice[0].start
- nw_steps
......
......@@ -3236,7 +3236,7 @@ def merge_two_slices(slice1, len1, slice2, len2):
constant_folding,
]
if type(slice1) is not slice:
if not isinstance(slice1, slice):
raise ValueError(
(
"First provided slice should actually be of type"
......@@ -3247,7 +3247,7 @@ def merge_two_slices(slice1, len1, slice2, len2):
sl1, reverse1 = get_canonical_form_slice(slice1, len1)
sl2, reverse2 = get_canonical_form_slice(slice2, len2)
if type(sl2) is not slice:
if not isinstance(sl2, slice):
if reverse1 is None:
# The first slice is not in reverse, which makes things a lot
# more clear.
......@@ -3398,7 +3398,7 @@ def local_subtensor_merge(node):
pos_1 = 0
while (pos_1 < len(slices1)) and (pos_2 < len(slices2)):
slice1 = slices1[pos_1]
if type(slice1) is slice:
if isinstance(slice1, slice):
merged_slices.append(
merge_two_slices(
slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2]
......@@ -4360,7 +4360,9 @@ def local_useless_switch(node):
"""
if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ts.Switch):
cond = tt.extract_constant(node.inputs[0], only_process_constants=True)
if (type(cond) is np.ndarray and cond.ndim == 0) or isinstance(cond, np.number):
if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
cond, np.number
):
if cond == 0:
correct_out = node.inputs[2]
else:
......
......@@ -92,29 +92,25 @@ class TensorType(Type):
"shared) variable instead of a numeric array?"
)
if (type(data) is np.ndarray) and (data.dtype == self.numpy_dtype):
if data.dtype.num != self.numpy_dtype.num:
data = theano._asarray(data, dtype=self.dtype)
# -- now fall through to ndim check
elif (type(data) is np.memmap) and (data.dtype == self.numpy_dtype):
if isinstance(data, np.memmap) and (data.dtype == self.numpy_dtype):
# numpy.memmap is a "safe" subclass of ndarray,
# so we can use it wherever we expect a base ndarray.
# however, casting it would defeat the purpose of not
# loading the whole data into memory
pass
elif isinstance(data, np.ndarray) and (data.dtype == self.numpy_dtype):
if data.dtype.num != self.numpy_dtype.num:
data = theano._asarray(data, dtype=self.dtype)
# -- now fall through to ndim check
elif strict:
# If any of the two conditions above was not met,
# we raise a meaningful TypeError.
if not (type(data) is np.ndarray):
raise TypeError(
"%s expected a ndarray object." % self, data, type(data)
)
if not isinstance(data, np.ndarray):
raise TypeError(f"{self} expected a ndarray object (got {type(data)}).")
if data.dtype != self.numpy_dtype:
raise TypeError(
("%s expected a ndarray object with " "dtype = %s (got %s).")
% (self, self.numpy_dtype, data.dtype)
f"{self} expected an ndarray with dtype={self.numpy_dtype} (got {data.dtype})."
)
raise AssertionError("This point should never be reached.")
else:
if allow_downcast:
# Convert to self.dtype, regardless of the type of data
......@@ -145,7 +141,7 @@ class TensorType(Type):
raise TypeError(err_msg)
elif (
allow_downcast is None
and type(data) is float
and isinstance(data, (float, np.floating))
and self.dtype == theano.config.floatX
):
# Special case where we allow downcasting of Python float
......@@ -177,7 +173,7 @@ class TensorType(Type):
'2) set "allow_input_downcast=True" when calling '
'"function".' % (self, data, converted_data, self.dtype)
)
raise TypeError(err_msg, data)
raise TypeError(err_msg)
if self.ndim != data.ndim:
raise TypeError(
......
......@@ -937,7 +937,7 @@ class TensorConstantSignature(tuple):
self._sum = self.no_nan.sum()
# The following 2 lines are needede as in Python 3.3 with NumPy
# 1.7.1, numpy.ndarray and numpy.memmap aren't hashable.
if type(self._sum) is np.memmap:
if isinstance(self._sum, np.memmap):
self._sum = np.asarray(self._sum).item()
if self.has_nan and self.no_nan.mask.all():
# In this case the sum is not properly computed by numpy.
......
Markdown formatting is supported.
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment.