提交 fbb066c4 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: Mathieu Germain

Fix crash of set_value when the destination isn't contiguous (#5913)

上级 ddafc3e2
...@@ -132,3 +132,14 @@ class test_shared_options(object): ...@@ -132,3 +132,14 @@ class test_shared_options(object):
class test_shared_options2(object): class test_shared_options2(object):
pass pass
""" """
def test_set_value_non_contiguous():
s = gpuarray_shared_constructor(
np.asarray([[1., 2.], [1., 2.], [5, 6]]))
s.set_value(s.get_value(borrow=True, return_internal_type=True)[::2],
borrow=True)
assert not s.get_value(borrow=True,
return_internal_type=True).flags["C_CONTIGUOUS"]
# In the past, this failed
s.set_value([[0, 0], [1, 1]])
...@@ -292,7 +292,10 @@ class GpuArrayType(Type): ...@@ -292,7 +292,10 @@ class GpuArrayType(Type):
raise TypeError("Non-unit value on shape on a broadcastable" raise TypeError("Non-unit value on shape on a broadcastable"
" dimension.", shp, self.broadcastable) " dimension.", shp, self.broadcastable)
if not isinstance(data, gpuarray.GpuArray): if not isinstance(data, gpuarray.GpuArray):
if old_data is not None and old_data.shape == data.shape: if old_data is not None and old_data.shape == data.shape and (
# write() only work if the destitation is contiguous.
old_data.flags['C_CONTIGUOUS'] or
old_data.flags['F_CONTIGUOUS']):
old_data.write(data) old_data.write(data)
data = old_data data = old_data
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论