Unverified commit d0420e3d, authored by Frédéric Bastien, committed by GitHub

Merge pull request #6628 from abergeron/no_dev20_smallint

Don't use the dev20 version of GpuAdvancedIncSubtensor1 for small ints
...@@ -1118,7 +1118,8 @@ def local_gpua_advanced_incsubtensor1(op, context_name, inputs, outputs): ...@@ -1118,7 +1118,8 @@ def local_gpua_advanced_incsubtensor1(op, context_name, inputs, outputs):
set_instead_of_inc = op.set_instead_of_inc set_instead_of_inc = op.set_instead_of_inc
if (x.ndim == 1 and y.ndim == 0 and if (x.ndim == 1 and y.ndim == 0 and
config.deterministic == 'default'): config.deterministic == 'default' and
x.dtype not in ('int8', 'int16')):
x = x.dimshuffle(0, 'x') x = x.dimshuffle(0, 'x')
y = y.dimshuffle('x', 'x') y = y.dimshuffle('x', 'x')
ret = GpuAdvancedIncSubtensor1_dev20( ret = GpuAdvancedIncSubtensor1_dev20(
...@@ -1126,7 +1127,8 @@ def local_gpua_advanced_incsubtensor1(op, context_name, inputs, outputs): ...@@ -1126,7 +1127,8 @@ def local_gpua_advanced_incsubtensor1(op, context_name, inputs, outputs):
ret = GpuDimShuffle(ret.type.broadcastable, [0])(ret) ret = GpuDimShuffle(ret.type.broadcastable, [0])(ret)
return ret return ret
elif (x.ndim != 2 or y.ndim != 2 or elif (x.ndim != 2 or y.ndim != 2 or
config.deterministic == 'more'): config.deterministic == 'more' or
x.dtype in ('int8', 'int16')):
return GpuAdvancedIncSubtensor1( return GpuAdvancedIncSubtensor1(
set_instead_of_inc=set_instead_of_inc) set_instead_of_inc=set_instead_of_inc)
else: else:
......
...@@ -155,7 +155,8 @@ def test_advinc_subtensor1_vector_scalar(): ...@@ -155,7 +155,8 @@ def test_advinc_subtensor1_vector_scalar():
shp = (3,) shp = (3,)
for dtype1, dtype2 in [('float32', 'int8'), ('float32', 'float64'), for dtype1, dtype2 in [('float32', 'int8'), ('float32', 'float64'),
('float16', 'int8'), ('float16', 'float64'), ('float16', 'int8'), ('float16', 'float64'),
('float16', 'float16')]: ('float16', 'float16'), ('int8', 'int8'),
('int16', 'int16')]:
shared = gpuarray_shared_constructor shared = gpuarray_shared_constructor
xval = np.arange(np.prod(shp), dtype=dtype1).reshape(shp) + 1 xval = np.arange(np.prod(shp), dtype=dtype1).reshape(shp) + 1
yval = np.asarray(10, dtype=dtype2) yval = np.asarray(10, dtype=dtype2)
......
Markdown formatting supported
0%
You are about to add 0 people to this discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment