提交 ff98ab8f authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Coerce dtype __props__ to string due to invalid hash of `np.dtype()` objects

上级 864ebd1a
...@@ -96,7 +96,7 @@ class Binomial(Op): ...@@ -96,7 +96,7 @@ class Binomial(Op):
def __init__(self, format, dtype): def __init__(self, format, dtype):
self.format = format self.format = format
self.dtype = dtype self.dtype = np.dtype(dtype).name
def make_node(self, n, p, shape): def make_node(self, n, p, shape):
n = pt.as_tensor_variable(n) n = pt.as_tensor_variable(n)
......
...@@ -1090,6 +1090,8 @@ class Tri(Op): ...@@ -1090,6 +1090,8 @@ class Tri(Op):
def __init__(self, dtype=None): def __init__(self, dtype=None):
if dtype is None: if dtype is None:
dtype = config.floatX dtype = config.floatX
else:
dtype = np.dtype(dtype).name
self.dtype = dtype self.dtype = dtype
def make_node(self, N, M, k): def make_node(self, N, M, k):
...@@ -1368,6 +1370,8 @@ class Eye(Op): ...@@ -1368,6 +1370,8 @@ class Eye(Op):
def __init__(self, dtype=None): def __init__(self, dtype=None):
if dtype is None: if dtype is None:
dtype = config.floatX dtype = config.floatX
else:
dtype = np.dtype(dtype).name
self.dtype = dtype self.dtype = dtype
def make_node(self, n, m, k): def make_node(self, n, m, k):
...@@ -3225,7 +3229,7 @@ class ARange(COp): ...@@ -3225,7 +3229,7 @@ class ARange(COp):
__props__ = ("dtype",) __props__ = ("dtype",)
def __init__(self, dtype): def __init__(self, dtype):
self.dtype = dtype self.dtype = np.dtype(dtype).name
def make_node(self, start, stop, step): def make_node(self, start, stop, step):
from math import ceil from math import ceil
...@@ -3407,7 +3411,8 @@ def arange(start, stop=None, step=1, dtype=None): ...@@ -3407,7 +3411,8 @@ def arange(start, stop=None, step=1, dtype=None):
# We use the same dtype as numpy instead of the result of # We use the same dtype as numpy instead of the result of
# the upcast. # the upcast.
dtype = str(numpy_dtype) dtype = str(numpy_dtype)
else:
dtype = np.dtype(dtype).name
if dtype not in _arange: if dtype not in _arange:
_arange[dtype] = ARange(dtype) _arange[dtype] = ARange(dtype)
return _arange[dtype](start, stop, step) return _arange[dtype](start, stop, step)
......
...@@ -1234,8 +1234,8 @@ class CAReduce(COp): ...@@ -1234,8 +1234,8 @@ class CAReduce(COp):
else: else:
self.axis = tuple(axis) self.axis = tuple(axis)
self.dtype = dtype self.dtype = dtype if dtype is None else np.dtype(dtype).name
self.acc_dtype = acc_dtype self.acc_dtype = acc_dtype if acc_dtype is None else np.dtype(acc_dtype).name
self.upcast_discrete_output = upcast_discrete_output self.upcast_discrete_output = upcast_discrete_output
@property @property
......
...@@ -25,7 +25,7 @@ class LoadFromDisk(Op): ...@@ -25,7 +25,7 @@ class LoadFromDisk(Op):
__props__ = ("dtype", "shape", "mmap_mode") __props__ = ("dtype", "shape", "mmap_mode")
def __init__(self, dtype, shape, mmap_mode=None): def __init__(self, dtype, shape, mmap_mode=None):
self.dtype = np.dtype(dtype) # turn "float64" into np.float64 self.dtype = np.dtype(dtype).name
self.shape = shape self.shape = shape
if mmap_mode not in (None, "c"): if mmap_mode not in (None, "c"):
raise ValueError( raise ValueError(
......
...@@ -112,6 +112,8 @@ class RandomVariable(Op): ...@@ -112,6 +112,8 @@ class RandomVariable(Op):
else: else:
self.signature = safe_signature(self.ndims_params, [self.ndim_supp]) self.signature = safe_signature(self.ndims_params, [self.ndim_supp])
if isinstance(dtype, np.dtype):
dtype = dtype.name
self.dtype = dtype or getattr(self, "dtype", None) self.dtype = dtype or getattr(self, "dtype", None)
self.inplace = ( self.inplace = (
......
...@@ -2869,6 +2869,18 @@ class TestARange: ...@@ -2869,6 +2869,18 @@ class TestARange:
assert np.arange(1.3, 17.48, 2.67).shape == arange(1.3, 17.48, 2.67).type.shape assert np.arange(1.3, 17.48, 2.67).shape == arange(1.3, 17.48, 2.67).type.shape
assert np.arange(-64, 64).shape == arange(-64, 64).type.shape assert np.arange(-64, 64).shape == arange(-64, 64).type.shape
def test_c_cache_bug(self):
# Regression test for bug caused by issues in hash of `np.dtype()` objects
# https://github.com/numpy/numpy/issues/17864
end = iscalar("end")
arange1 = ARange(np.dtype("float64"))(0, end, 1)
arange2 = ARange("float64")(0, end + 1, 1)
assert arange1.owner.op == arange2.owner.op
assert hash(arange1.owner.op) == hash(arange2.owner.op)
fn = function([end], [arange1, arange2])
res1, res2 = fn(10)
np.testing.assert_array_equal(res1, res2[:-1], strict=True)
class TestNdGrid: class TestNdGrid:
def setup_method(self): def setup_method(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论