提交 2507f620 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Use isinstance to check types in filter

This commit also introduces new `test_type` modules with direct tests for the filter methods.
上级 55b36403
......@@ -19,9 +19,7 @@ from theano.scalar.basic import (
ComplexError,
Composite,
InRange,
IntDiv,
Scalar,
TrueDiv,
add,
and_,
arccos,
......@@ -44,7 +42,6 @@ from theano.scalar.basic import (
expm1,
float16,
float32,
float64,
floats,
int8,
int32,
......@@ -395,24 +392,6 @@ def test_mod_complex_fail():
x % y
def test_div_types():
a = int8()
b = int32()
c = complex64()
d = float64()
f = float32()
assert isinstance((a // b).owner.op, IntDiv)
assert isinstance((b // a).owner.op, IntDiv)
assert isinstance((b / d).owner.op, TrueDiv)
assert isinstance((b / f).owner.op, TrueDiv)
assert isinstance((f / a).owner.op, TrueDiv)
assert isinstance((d / b).owner.op, TrueDiv)
assert isinstance((d / f).owner.op, TrueDiv)
assert isinstance((f / c).owner.op, TrueDiv)
assert isinstance((a / c).owner.op, TrueDiv)
def test_grad_gt():
x = float32(name="x")
y = float32(name="y")
......
import numpy as np
from theano import change_flags
from theano.scalar.basic import (
IntDiv,
Scalar,
TrueDiv,
complex64,
float32,
float64,
int8,
int32,
)
def test_div_types():
a = int8()
b = int32()
c = complex64()
d = float64()
f = float32()
assert isinstance((a // b).owner.op, IntDiv)
assert isinstance((b // a).owner.op, IntDiv)
assert isinstance((b / d).owner.op, TrueDiv)
assert isinstance((b / f).owner.op, TrueDiv)
assert isinstance((f / a).owner.op, TrueDiv)
assert isinstance((d / b).owner.op, TrueDiv)
assert isinstance((d / f).owner.op, TrueDiv)
assert isinstance((f / c).owner.op, TrueDiv)
assert isinstance((a / c).owner.op, TrueDiv)
def test_filter_float_subclass():
"""Make sure `Scalar.filter` can handle `float` subclasses."""
with change_flags(floatX="float64"):
test_type = Scalar("float64")
nan = np.array([np.nan], dtype="float64")[0]
assert isinstance(nan, float)
filtered_nan = test_type.filter(nan)
assert isinstance(filtered_nan, float)
with change_flags(floatX="float32"):
# Try again, except this time `nan` isn't a `float`
test_type = Scalar("float32")
nan = np.array([np.nan], dtype="float32")[0]
assert isinstance(nan, np.floating)
filtered_nan = test_type.filter(nan)
assert isinstance(filtered_nan, np.floating)
import os.path as path
from tempfile import mkdtemp
import numpy as np
import pytest
from theano import change_flags, config
from theano.tensor.type import TensorType
def test_filter_variable():
test_type = TensorType(config.floatX, [])
with pytest.raises(TypeError):
test_type.filter(test_type())
def test_filter_strict():
test_type = TensorType(config.floatX, [])
with pytest.raises(TypeError):
test_type.filter(1, strict=True)
with pytest.raises(TypeError):
test_type.filter(np.array(1, dtype=int), strict=True)
def test_filter_ndarray_subclass():
"""Make sure `TensorType.filter` can handle NumPy `ndarray` subclasses."""
test_type = TensorType(config.floatX, [False])
class MyNdarray(np.ndarray):
pass
test_val = np.array([1.0], dtype=config.floatX).view(MyNdarray)
assert isinstance(test_val, MyNdarray)
res = test_type.filter(test_val)
assert isinstance(res, MyNdarray)
assert res is test_val
def test_filter_float_subclass():
"""Make sure `TensorType.filter` can handle `float` subclasses."""
with change_flags(floatX="float64"):
test_type = TensorType("float64", broadcastable=[])
nan = np.array([np.nan], dtype="float64")[0]
assert isinstance(nan, np.float) and not isinstance(nan, np.ndarray)
filtered_nan = test_type.filter(nan)
assert isinstance(filtered_nan, np.ndarray)
with change_flags(floatX="float32"):
# Try again, except this time `nan` isn't a `float`
test_type = TensorType("float32", broadcastable=[])
nan = np.array([np.nan], dtype="float32")[0]
assert isinstance(nan, np.floating) and not isinstance(nan, np.ndarray)
filtered_nan = test_type.filter(nan)
assert isinstance(filtered_nan, np.ndarray)
def test_filter_memmap():
"""Make sure `TensorType.filter` can handle NumPy `memmap`s subclasses."""
data = np.arange(12, dtype=config.floatX)
data.resize((3, 4))
filename = path.join(mkdtemp(), "newfile.dat")
fp = np.memmap(filename, dtype=config.floatX, mode="w+", shape=(3, 4))
test_type = TensorType(config.floatX, [False, False])
res = test_type.filter(fp)
assert res is fp
......@@ -346,7 +346,7 @@ class Scalar(Type):
allow_downcast
or (
allow_downcast is None
and type(data) is float
and isinstance(data, (float, np.floating))
and self.dtype == theano.config.floatX
)
or data == converted_data
......
......@@ -92,29 +92,25 @@ class TensorType(Type):
"shared) variable instead of a numeric array?"
)
if (type(data) is np.ndarray) and (data.dtype == self.numpy_dtype):
if data.dtype.num != self.numpy_dtype.num:
data = theano._asarray(data, dtype=self.dtype)
# -- now fall through to ndim check
elif (type(data) is np.memmap) and (data.dtype == self.numpy_dtype):
if isinstance(data, np.memmap) and (data.dtype == self.numpy_dtype):
# numpy.memmap is a "safe" subclass of ndarray,
# so we can use it wherever we expect a base ndarray.
# however, casting it would defeat the purpose of not
# loading the whole data into memory
pass
elif isinstance(data, np.ndarray) and (data.dtype == self.numpy_dtype):
if data.dtype.num != self.numpy_dtype.num:
data = theano._asarray(data, dtype=self.dtype)
# -- now fall through to ndim check
elif strict:
# If any of the two conditions above was not met,
# we raise a meaningful TypeError.
if not (type(data) is np.ndarray):
raise TypeError(
"%s expected a ndarray object." % self, data, type(data)
)
if not isinstance(data, np.ndarray):
raise TypeError(f"{self} expected a ndarray object (got {type(data)}).")
if data.dtype != self.numpy_dtype:
raise TypeError(
("%s expected a ndarray object with " "dtype = %s (got %s).")
% (self, self.numpy_dtype, data.dtype)
f"{self} expected an ndarray with dtype={self.numpy_dtype} (got {data.dtype})."
)
raise AssertionError("This point should never be reached.")
else:
if allow_downcast:
# Convert to self.dtype, regardless of the type of data
......@@ -145,7 +141,7 @@ class TensorType(Type):
raise TypeError(err_msg)
elif (
allow_downcast is None
and type(data) is float
and isinstance(data, (float, np.floating))
and self.dtype == theano.config.floatX
):
# Special case where we allow downcasting of Python float
......@@ -177,7 +173,7 @@ class TensorType(Type):
'2) set "allow_input_downcast=True" when calling '
'"function".' % (self, data, converted_data, self.dtype)
)
raise TypeError(err_msg, data)
raise TypeError(err_msg)
if self.ndim != data.ndim:
raise TypeError(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论