提交 ff98ab8f authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Coerce dtype __props__ to string due to invalid hash of `np.dtype()` objects

上级 864ebd1a
......@@ -96,7 +96,7 @@ class Binomial(Op):
def __init__(self, format, dtype):
self.format = format
self.dtype = dtype
self.dtype = np.dtype(dtype).name
def make_node(self, n, p, shape):
n = pt.as_tensor_variable(n)
......
......@@ -1090,6 +1090,8 @@ class Tri(Op):
def __init__(self, dtype=None):
if dtype is None:
dtype = config.floatX
else:
dtype = np.dtype(dtype).name
self.dtype = dtype
def make_node(self, N, M, k):
......@@ -1368,6 +1370,8 @@ class Eye(Op):
def __init__(self, dtype=None):
if dtype is None:
dtype = config.floatX
else:
dtype = np.dtype(dtype).name
self.dtype = dtype
def make_node(self, n, m, k):
......@@ -3225,7 +3229,7 @@ class ARange(COp):
__props__ = ("dtype",)
def __init__(self, dtype):
self.dtype = dtype
self.dtype = np.dtype(dtype).name
def make_node(self, start, stop, step):
from math import ceil
......@@ -3407,7 +3411,8 @@ def arange(start, stop=None, step=1, dtype=None):
# We use the same dtype as numpy instead of the result of
# the upcast.
dtype = str(numpy_dtype)
else:
dtype = np.dtype(dtype).name
if dtype not in _arange:
_arange[dtype] = ARange(dtype)
return _arange[dtype](start, stop, step)
......
......@@ -1234,8 +1234,8 @@ class CAReduce(COp):
else:
self.axis = tuple(axis)
self.dtype = dtype
self.acc_dtype = acc_dtype
self.dtype = dtype if dtype is None else np.dtype(dtype).name
self.acc_dtype = acc_dtype if acc_dtype is None else np.dtype(acc_dtype).name
self.upcast_discrete_output = upcast_discrete_output
@property
......
......@@ -25,7 +25,7 @@ class LoadFromDisk(Op):
__props__ = ("dtype", "shape", "mmap_mode")
def __init__(self, dtype, shape, mmap_mode=None):
self.dtype = np.dtype(dtype) # turn "float64" into np.float64
self.dtype = np.dtype(dtype).name
self.shape = shape
if mmap_mode not in (None, "c"):
raise ValueError(
......
......@@ -112,6 +112,8 @@ class RandomVariable(Op):
else:
self.signature = safe_signature(self.ndims_params, [self.ndim_supp])
if isinstance(dtype, np.dtype):
dtype = dtype.name
self.dtype = dtype or getattr(self, "dtype", None)
self.inplace = (
......
......@@ -2869,6 +2869,18 @@ class TestARange:
assert np.arange(1.3, 17.48, 2.67).shape == arange(1.3, 17.48, 2.67).type.shape
assert np.arange(-64, 64).shape == arange(-64, 64).type.shape
def test_c_cache_bug(self):
# Regression test for bug caused by issues in hash of `np.dtype()` objects
# https://github.com/numpy/numpy/issues/17864
end = iscalar("end")
arange1 = ARange(np.dtype("float64"))(0, end, 1)
arange2 = ARange("float64")(0, end + 1, 1)
assert arange1.owner.op == arange2.owner.op
assert hash(arange1.owner.op) == hash(arange2.owner.op)
fn = function([end], [arange1, arange2])
res1, res2 = fn(10)
np.testing.assert_array_equal(res1, res2[:-1], strict=True)
class TestNdGrid:
def setup_method(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论