提交 3d716fd6 authored 作者: Frederic Bastien's avatar Frederic Bastien

Make GpuAdvancedIncSubtensor1_dev20 support mixed dtype for x and y. Most of the…

Make GpuAdvancedIncSubtensor1_dev20 support mixed dtype for x and y. Most of the code was already supporting this.
上级 170aff07
......@@ -1036,9 +1036,7 @@ def local_gpua_advanced_incsubtensor(op, context_name, inputs, outputs):
set_instead_of_inc = op.set_instead_of_inc
compute_capability = int(context.bin_id[-2])
if (compute_capability < 2 or x.ndim != 2 or y.ndim != 2 or
x.type.dtype != y.type.dtype):
if compute_capability < 2 or x.ndim != 2 or y.ndim != 2:
return GpuAdvancedIncSubtensor1(
set_instead_of_inc=set_instead_of_inc)
else:
......
......@@ -803,7 +803,6 @@ class GpuAdvancedIncSubtensor1_dev20(GpuKernelBase, HideC,
y_ = as_gpuarray_variable(y, ctx_name)
ilist_ = as_gpuarray_variable(ilist, ctx_name)
assert x_.type.dtype == y_.type.dtype
assert x_.type.ndim >= y_.type.ndim
if ilist_.type.dtype not in tensor.integer_dtypes:
......
......@@ -13,6 +13,7 @@ from ..subtensor import (GpuIncSubtensor, GpuSubtensor,
GpuAdvancedSubtensor1,
GpuAdvancedSubtensor,
GpuAdvancedIncSubtensor1,
GpuAdvancedIncSubtensor1_dev20,
GpuDiagonal)
from ..type import gpuarray_shared_constructor
......@@ -63,6 +64,28 @@ def test_advinc_subtensor1():
assert numpy.allclose(rval, rep)
def test_advinc_subtensor1_dtype():
# Test the mixed dtype case
shp = (3, 3)
for dtype1, dtype2 in [('float32', 'int8'), ('float32', 'float64')]:
shared = gpuarray_shared_constructor
xval = numpy.arange(numpy.prod(shp), dtype=dtype1).reshape(shp) + 1
yval = numpy.empty((2,) + shp[1:], dtype=dtype2)
yval[:] = 10
x = shared(xval, name='x')
y = tensor.tensor(dtype=yval.dtype,
broadcastable=(False,) * len(shp),
name='y')
expr = tensor.advanced_inc_subtensor1(x, y, [0, 2])
f = theano.function([y], expr, mode=mode_with_gpu)
assert sum([isinstance(node.op, GpuAdvancedIncSubtensor1_dev20)
for node in f.maker.fgraph.toposort()]) == 1
rval = f(yval)
rep = xval.copy()
rep[[0, 2]] += yval
assert numpy.allclose(rval, rep)
def test_incsub_f16():
shp = (3, 3)
shared = gpuarray_shared_constructor
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论